Skip to content

Commit 97ed07e

Browse files
authored
Misc bug fixes (#208)
Fixed a bunch of things found by asking Claude for a review. It found several nits, a logic bug in the code to determine whether a set of text ranges all lie inside another, and a real doozie: a mutable default in `SemanticRefAccumulator.__init__`. Quite a few new tests were added, too.
1 parent 4e428e2 commit 97ed07e

18 files changed

Lines changed: 893 additions & 73 deletions

src/typeagent/aitools/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def parse_azure_endpoint(
191191
if not azure_endpoint:
192192
raise RuntimeError(f"Environment variable {endpoint_envvar} not found")
193193

194-
m = re.search(r"[?,]api-version=([\d-]+(?:preview)?)", azure_endpoint)
194+
m = re.search(r"[?&]api-version=([\d-]+(?:preview)?)", azure_endpoint)
195195
if not m:
196196
raise RuntimeError(
197197
f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field"

src/typeagent/aitools/vectorbase.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ def fuzzy_lookup_embedding_in_subset(
150150
max_hits: int | None = None,
151151
min_score: float | None = None,
152152
) -> list[ScoredInt]:
153+
ordinals_set = set(ordinals_of_subset)
153154
return self.fuzzy_lookup_embedding(
154-
embedding, max_hits, min_score, lambda i: i in ordinals_of_subset
155+
embedding, max_hits, min_score, lambda i: i in ordinals_set
155156
)
156157

157158
async def fuzzy_lookup(

src/typeagent/emails/email_import.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,8 @@ def _merge_chunks(
263263
yield cur_chunk
264264
cur_chunk = new_chunk
265265
else:
266-
cur_chunk += separator
266+
if cur_chunk:
267+
cur_chunk += separator
267268
cur_chunk += new_chunk
268269

269270
if (len(cur_chunk)) > 0:

src/typeagent/knowpro/answers.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,12 @@ async def get_enclosing_date_range_for_text_range(
404404
start_timestamp = (await messages.get_item(range.start.message_ordinal)).timestamp
405405
if not start_timestamp:
406406
return None
407-
end_timestamp = (
408-
(await messages.get_item(range.end.message_ordinal)).timestamp
409-
if range.end
410-
else None
411-
)
407+
end_timestamp: str | None = None
408+
if range.end:
409+
end_ordinal = range.end.message_ordinal
410+
if end_ordinal < await messages.size():
411+
end_timestamp = (await messages.get_item(end_ordinal)).timestamp
412+
# else: range extends to the end of the conversation; leave as None.
412413
return DateRange(
413414
start=Datetime.fromisoformat(start_timestamp),
414415
end=Datetime.fromisoformat(end_timestamp) if end_timestamp else None,
@@ -535,7 +536,7 @@ def facets_to_merged_facets(facets: list[Facet]) -> MergedFacets:
535536
merged_facets: MergedFacets = {}
536537
for facet in facets:
537538
name = facet.name.lower()
538-
value = str(facet).lower()
539+
value = str(facet.value).lower()
539540
merged_facets.setdefault(name, []).append(value)
540541
return merged_facets
541542

src/typeagent/knowpro/collections.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,14 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None:
9191
)
9292
)
9393
else:
94+
# New related-only match: hit_count stays 0 because
95+
# only exact matches count as direct hits. This matters
96+
# for select_with_hit_count / _matches_with_min_hit_count
97+
# which filter on hit_count to weed out noise.
9498
self.set_match(
9599
Match(
96100
value,
97-
hit_count=1,
101+
hit_count=0,
98102
score=0.0,
99103
related_hit_count=1,
100104
related_score=score,
@@ -250,9 +254,11 @@ def smooth_match_score[T](match: Match[T]) -> None:
250254

251255

252256
class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]):
253-
def __init__(self, search_term_matches: set[str] = set()):
257+
def __init__(self, search_term_matches: set[str] | None = None):
254258
super().__init__()
255-
self.search_term_matches = search_term_matches
259+
self.search_term_matches = (
260+
search_term_matches if search_term_matches is not None else set()
261+
)
256262

257263
def add_term_matches(
258264
self,
@@ -330,8 +336,7 @@ async def group_matches_by_type(
330336
semantic_ref = await semantic_refs.get_item(match.value)
331337
group = groups.get(semantic_ref.knowledge.knowledge_type)
332338
if group is None:
333-
group = SemanticRefAccumulator()
334-
group.search_term_matches = self.search_term_matches
339+
group = SemanticRefAccumulator(self.search_term_matches)
335340
groups[semantic_ref.knowledge.knowledge_type] = group
336341
group.set_match(match)
337342
return groups
@@ -513,11 +518,10 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No
513518
for text_range in text_ranges._ranges:
514519
self.add_range(text_range)
515520

516-
def is_in_range(self, inner_range: TextRange) -> bool:
517-
if len(self._ranges) == 0:
518-
return False
519-
i = bisect.bisect_left(self._ranges, inner_range)
520-
for outer_range in self._ranges[i:]:
521+
def contains_range(self, inner_range: TextRange) -> bool:
522+
# Since ranges are sorted by start, once we pass inner_range's start
523+
# no further range can contain it.
524+
for outer_range in self._ranges:
521525
if outer_range.start > inner_range.start:
522526
break
523527
if inner_range in outer_range:
@@ -544,7 +548,7 @@ def is_range_in_scope(self, inner_range: TextRange) -> bool:
544548
# We have a very simple impl: we don't intersect/union ranges yet.
545549
# Instead, we ensure that the inner range is not rejected by any outer ranges.
546550
for outer_ranges in self.text_ranges:
547-
if not outer_ranges.is_in_range(inner_range):
551+
if not outer_ranges.contains_range(inner_range):
548552
return False
549553
return True
550554

src/typeagent/knowpro/interfaces_search.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
)
1919

2020
__all__ = [
21-
"SearchTerm",
2221
"KnowledgePropertyName",
2322
"PropertySearchTerm",
23+
"SearchSelectExpr",
24+
"SearchTerm",
2425
"SearchTermGroup",
2526
"SearchTermGroupTypes",
27+
"SemanticRefSearchResult",
2628
"WhenFilter",
27-
"SearchSelectExpr",
2829
]
2930

3031

@@ -142,15 +143,3 @@ class SemanticRefSearchResult:
142143

143144
term_matches: set[str]
144145
semantic_ref_matches: list[ScoredSemanticRefOrdinal]
145-
146-
147-
__all__ = [
148-
"KnowledgePropertyName",
149-
"PropertySearchTerm",
150-
"SearchSelectExpr",
151-
"SearchTerm",
152-
"SearchTermGroup",
153-
"SearchTermGroupTypes",
154-
"SemanticRefSearchResult",
155-
"WhenFilter",
156-
]

src/typeagent/knowpro/query.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
Thread,
4646
)
4747
from .kplib import ConcreteEntity
48+
from .utils import aenumerate
4849

4950
# TODO: Move to compilelib.py
5051
type BooleanOp = Literal["and", "or", "or_max"]
@@ -101,11 +102,14 @@ async def get_text_range_for_date_range(
101102
messages = conversation.messages
102103
range_start_ordinal: MessageOrdinal = -1
103104
range_end_ordinal = range_start_ordinal
104-
async for message in messages:
105-
if Datetime.fromisoformat(message.timestamp) in date_range:
105+
async for ordinal, message in aenumerate(messages):
106+
if (
107+
message.timestamp
108+
and Datetime.fromisoformat(message.timestamp) in date_range
109+
):
106110
if range_start_ordinal < 0:
107-
range_start_ordinal = message.ordinal
108-
range_end_ordinal = message.ordinal
111+
range_start_ordinal = ordinal
112+
range_end_ordinal = ordinal
109113
else:
110114
if range_start_ordinal >= 0:
111115
# We have a range, so break.
@@ -696,7 +700,7 @@ class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]):
696700

697701
async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
698702
accumulator = await self.source_expr.eval(context)
699-
filtered = SemanticRefAccumulator(accumulator.search_term_matches)
703+
filtered = SemanticRefAccumulator(set(accumulator.search_term_matches))
700704

701705
# Filter matches asynchronously
702706
filtered_matches = []

src/typeagent/knowpro/search.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,9 @@ class SearchOptions:
9090

9191
def __repr__(self):
9292
parts = []
93-
for key in dir(self):
94-
if not key.startswith("_"):
95-
value = getattr(self, key)
96-
if value is not None:
97-
parts.append(f"{key}={value!r}")
93+
for key, value in vars(self).items():
94+
if not key.startswith("_") and value is not None:
95+
parts.append(f"{key}={value!r}")
9896
return f"{self.__class__.__name__}({', '.join(parts)})"
9997

10098

src/typeagent/knowpro/searchlang.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,9 @@ class LanguageSearchOptions(SearchOptions):
8383

8484
def __repr__(self):
8585
parts = []
86-
for key in dir(self):
87-
if not key.startswith("_"):
88-
value = getattr(self, key)
89-
if value is not None:
90-
parts.append(f"{key}={value!r}")
86+
for key, value in vars(self).items():
87+
if not key.startswith("_") and value is not None:
88+
parts.append(f"{key}={value!r}")
9189
return f"{self.__class__.__name__}({', '.join(parts)})"
9290

9391

@@ -371,6 +369,9 @@ def compile_action_term_as_search_terms(
371369
self.compile_entity_terms_as_search_terms(
372370
action_term.additional_entities, action_group
373371
)
372+
# only append the nested or_max wrapper when created one (use_or_max) and it's non-empty.
373+
if use_or_max and action_group.terms:
374+
term_group.terms.append(action_group)
374375
return term_group
375376

376377
def compile_search_terms(
@@ -609,21 +610,6 @@ def add_entity_name_to_group(
609610
exact_match_value,
610611
)
611612

612-
def add_search_term_to_groupadd_entity_name_to_group(
613-
self,
614-
entity_term: EntityTerm,
615-
property_name: PropertyNames,
616-
term_group: SearchTermGroup,
617-
exact_match_value: bool = False,
618-
) -> None:
619-
if not entity_term.is_name_pronoun:
620-
self.add_property_term_to_group(
621-
property_name.value,
622-
entity_term.name,
623-
term_group,
624-
exact_match_value,
625-
)
626-
627613
def add_property_term_to_group(
628614
self,
629615
property_name: str,

src/typeagent/knowpro/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,18 @@
33

44
"""Utility functions for the knowpro package."""
55

6+
from collections.abc import AsyncIterable
7+
68
from .interfaces import MessageOrdinal, TextLocation, TextRange
79

810

11+
async def aenumerate[T](aiterable: AsyncIterable[T], start: int = 0):
12+
i = start
13+
async for item in aiterable:
14+
yield i, item
15+
i += 1
16+
17+
918
def text_range_from_message_chunk(
1019
message_ordinal: MessageOrdinal,
1120
chunk_ordinal: int = 0,

0 commit comments

Comments
 (0)