11from abc import ABC , abstractmethod
22from dataclasses import dataclass
3- from typing import Any , TypeAlias
3+ from typing import Generic , TypeAlias , TypeVar
44
55from rapidfuzz import fuzz , process
66from shapely .geometry .base import BaseGeometry
1111
1212CandidateAttributes : TypeAlias = list [str ]
1313
14+ K = TypeVar ("K" , str , BaseGeometry )
15+
1416
@dataclass(frozen=True)
class MatchResult:
    """Immutable record describing the outcome of one match operation."""

    # The value that was looked up.
    query: str
    # The candidate that was selected for the query.
    matched: str
    # Attribute values attached to the result (a list of strings);
    # presumably those of the matched candidate -- verify against callers.
    attributes: CandidateAttributes
    # Score the matcher assigned to this match.
    score: float
2325
2426
class BaseMatcher(ABC, Generic[K]):
    """Common interface for matchers, parameterized by the key type ``K``.

    ``K`` is the type shared by the query and by the keys of the candidate
    mapping; concrete matchers bind it when they subclass this class.
    """

    @abstractmethod
    def get_similarity(
        self, query: K, candidates: dict[K, CandidateAttributes]
    ) -> MatchResult | None:
        """Match ``query`` against ``candidates``.

        Args:
            query: The item to look up.
            candidates: Mapping from each candidate key to that candidate's
                attributes.

        Returns:
            A ``MatchResult`` or ``None``. This abstract method has no
            implementation of its own and yields ``None`` if reached through
            ``super()``.
        """
3438
3539
36- class FuzzyMatcher (BaseMatcher ):
40+ class FuzzyMatcher (BaseMatcher [ str ] ):
3741 """Matcher that uses fuzzy string matching to compute similarity scores."""
3842
3943 def __init__ (self , threshold : float = 80 , scorer_name : str = "wratio" ):
@@ -86,7 +90,9 @@ def get_similarity(
8690 if no match meets the threshold.
8791 """
8892 candidate_strings = list (candidates .keys ())
89- best_match = self .process .extractOne (query , candidate_strings , scorer = self .scorer )
93+ best_match = self .process .extractOne (
94+ query , candidate_strings , scorer = self .scorer
95+ )
9096
9197 if best_match is None :
9298 return None
@@ -108,13 +114,15 @@ def __str__(self) -> str:
108114 return f"FuzzyMatcher(scorer: { self .scorer .__name__ } )"
109115
110116
111- class SentenceTransformerMatcher (BaseMatcher ):
117+ class SentenceTransformerMatcher (BaseMatcher [ str ] ):
112118 """Matcher that uses sentence transformers to compute similarity scores.
113119
114120 NOTE: Not yet implemented.
115121 """
116122
117- def __init__ (self , model_name : str | None = "sentence-transformers/all-MiniLM-L6-v2" ):
123+ def __init__ (
124+ self , model_name : str | None = "sentence-transformers/all-MiniLM-L6-v2"
125+ ):
118126 from sentence_transformers import SentenceTransformer # noqa: PLC0415
119127
120128 model_name = "sentence-transformers/all-MiniLM-L6-v2"
@@ -128,13 +136,15 @@ def get_similarity(
128136 # cand_embs = self.model.encode(candidates, convert_to_tensor=True)
129137 # scores = cos_sim(query_emb, cand_embs)[0].cpu().numpy()
130138 """Return similarity scores for the candidates using sentence transformers."""
131- raise NotImplementedError ("SentenceTransformerMatcher.get_similarity is not implemented." )
139+ raise NotImplementedError (
140+ "SentenceTransformerMatcher.get_similarity is not implemented."
141+ )
132142
133143 def __str__ (self ) -> str :
134144 return f"TransformerMatcher(scorer: { self .model .__name__ } )"
135145
136146
137- class GeometryMatcher (BaseMatcher ):
147+ class GeometryMatcher (BaseMatcher [ BaseGeometry ] ):
138148 """Match org units using spatial proximity and overlap.
139149
140150 NOTE: Not yet implemented. This is a test implementation.
@@ -198,14 +208,20 @@ def _score(self, ref: BaseGeometry, cand: BaseGeometry) -> float | None:
198208 distance_score = 1.0 - (distance / self .max_distance )
199209
200210 overlap_score = 0.0
201- if self .use_overlap and ref .geom_type == "Polygon" and cand .geom_type == "Polygon" :
211+ if (
212+ self .use_overlap
213+ and ref .geom_type == "Polygon"
214+ and cand .geom_type == "Polygon"
215+ ):
202216 inter = ref .intersection (cand ).area
203217 union = ref .union (cand ).area
204218 if union > 0 :
205219 overlap_score = inter / union
206220
207221 # Final weighted score
208- return (1 - self .overlap_weight ) * distance_score + self .overlap_weight * overlap_score
222+ return (
223+ 1 - self .overlap_weight
224+ ) * distance_score + self .overlap_weight * overlap_score
209225
210226 def _geom_id (self , geom : BaseGeometry ) -> str :
211227 """Return an identifier for the query geometry.
0 commit comments