Skip to content

Commit 07a74c0

Browse files
committed
Improve the code and add the tests
1 parent 70f28fc commit 07a74c0

9 files changed

Lines changed: 2887 additions & 187 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
name: Tests
2323
runs-on: ubuntu-latest
2424
strategy:
25+
fail-fast: false
2526
matrix:
2627
package: [d2d_development, pyramid_matching]
2728
steps:

pyramid_matching/README.md

Lines changed: 143 additions & 3 deletions
Large diffs are not rendered by default.

pyramid_matching/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ dependencies = [
1313
"rapidfuzz>=3.0.0",
1414
"shapely>=2.0.0",
1515
"polars>=1.0.0",
16+
"pandas>=2.3.1",
1617
]
1718

1819
[project.optional-dependencies]
1920
# These are only needed when running the tests, not when installing the package
20-
dev = ["pytest", "pytest-cov"]
21+
dev = ["pytest", "pytest-cov", "pyarrow"]
2122

2223
[tool.pytest.ini_options]
2324
testpaths = ["tests"]

pyramid_matching/pyramid_matcher/matchers.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Any, TypeAlias
3+
from typing import Generic, TypeAlias, TypeVar
44

55
from rapidfuzz import fuzz, process
66
from shapely.geometry.base import BaseGeometry
@@ -11,29 +11,33 @@
1111

1212
CandidateAttributes: TypeAlias = list[str]
1313

14+
K = TypeVar("K", str, BaseGeometry)
15+
1416

1517
@dataclass(frozen=True)
1618
class MatchResult:
1719
"""Data class to hold the result of a match operation."""
1820

1921
query: str
2022
matched: str
21-
attributes: dict[str, Any]
23+
attributes: CandidateAttributes
2224
score: float
2325

2426

25-
class BaseMatcher(ABC):
27+
class BaseMatcher(ABC, Generic[K]):
2628
"""Abstract base class for matchers that compute similarity scores."""
2729

2830
@abstractmethod
2931
def get_similarity(
30-
self, query: str | BaseGeometry, candidates: dict[str | BaseGeometry, CandidateAttributes]
32+
self,
33+
query: K,
34+
candidates: dict[K, CandidateAttributes],
3135
) -> MatchResult | None:
3236
"""Return the match result for the query among the candidates, or None if there is no match."""
3337
pass
3438

3539

36-
class FuzzyMatcher(BaseMatcher):
40+
class FuzzyMatcher(BaseMatcher[str]):
3741
"""Matcher that uses fuzzy string matching to compute similarity scores."""
3842

3943
def __init__(self, threshold: float = 80, scorer_name: str = "wratio"):
@@ -86,7 +90,9 @@ def get_similarity(
8690
if no match meets the threshold.
8791
"""
8892
candidate_strings = list(candidates.keys())
89-
best_match = self.process.extractOne(query, candidate_strings, scorer=self.scorer)
93+
best_match = self.process.extractOne(
94+
query, candidate_strings, scorer=self.scorer
95+
)
9096

9197
if best_match is None:
9298
return None
@@ -108,13 +114,15 @@ def __str__(self) -> str:
108114
return f"FuzzyMatcher(scorer: {self.scorer.__name__})"
109115

110116

111-
class SentenceTransformerMatcher(BaseMatcher):
117+
class SentenceTransformerMatcher(BaseMatcher[str]):
112118
"""Matcher that uses sentence transformers to compute similarity scores.
113119
114120
NOTE: Not yet implemented.
115121
"""
116122

117-
def __init__(self, model_name: str | None = "sentence-transformers/all-MiniLM-L6-v2"):
123+
def __init__(
124+
self, model_name: str | None = "sentence-transformers/all-MiniLM-L6-v2"
125+
):
118126
from sentence_transformers import SentenceTransformer # noqa: PLC0415
119127

120128
model_name = "sentence-transformers/all-MiniLM-L6-v2"
@@ -128,13 +136,15 @@ def get_similarity(
128136
# cand_embs = self.model.encode(candidates, convert_to_tensor=True)
129137
# scores = cos_sim(query_emb, cand_embs)[0].cpu().numpy()
130138
"""Return similarity scores for the candidates using sentence transformers."""
131-
raise NotImplementedError("SentenceTransformerMatcher.get_similarity is not implemented.")
139+
raise NotImplementedError(
140+
"SentenceTransformerMatcher.get_similarity is not implemented."
141+
)
132142

133143
def __str__(self) -> str:
134144
return f"TransformerMatcher(scorer: {self.model.__name__})"
135145

136146

137-
class GeometryMatcher(BaseMatcher):
147+
class GeometryMatcher(BaseMatcher[BaseGeometry]):
138148
"""Match org units using spatial proximity and overlap.
139149
140150
NOTE: Not yet implemented. This is a test implementation.
@@ -198,14 +208,20 @@ def _score(self, ref: BaseGeometry, cand: BaseGeometry) -> float | None:
198208
distance_score = 1.0 - (distance / self.max_distance)
199209

200210
overlap_score = 0.0
201-
if self.use_overlap and ref.geom_type == "Polygon" and cand.geom_type == "Polygon":
211+
if (
212+
self.use_overlap
213+
and ref.geom_type == "Polygon"
214+
and cand.geom_type == "Polygon"
215+
):
202216
inter = ref.intersection(cand).area
203217
union = ref.union(cand).area
204218
if union > 0:
205219
overlap_score = inter / union
206220

207221
# Final weighted score
208-
return (1 - self.overlap_weight) * distance_score + self.overlap_weight * overlap_score
222+
return (
223+
1 - self.overlap_weight
224+
) * distance_score + self.overlap_weight * overlap_score
209225

210226
def _geom_id(self, geom: BaseGeometry) -> str:
211227
"""Return an identifier for the query geometry.

0 commit comments

Comments
 (0)