Skip to content

Commit 0321bca

Browse files
committed
Improve the code and add the tests
1 parent 70f28fc commit 0321bca

9 files changed

Lines changed: 2918 additions & 187 deletions

File tree

pyramid_matching/README.md

Lines changed: 144 additions & 3 deletions
Large diffs are not rendered by default.

pyramid_matching/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies = [
1717

1818
[project.optional-dependencies]
1919
# These we only need to install when running the test, not when installing the package
20-
dev = ["pytest", "pytest-cov"]
20+
dev = ["pytest", "pytest-cov", "pandas", "pyarrow"]
2121

2222
[tool.pytest.ini_options]
2323
testpaths = ["tests"]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import polars as pl
2+
3+
from .matchers import FuzzyMatcher
4+
from .pyramid_matcher import PyramidMatcher
5+
6+
7+
def main():
8+
"""Example of calling the match_pyramids function."""
9+
dhis2_pyramid = pl.read_csv(
10+
"/home/leyregarrido/01_github_repos/openhexa-ds-developments/pyramid_matching/data/dhis2_pyramid.csv"
11+
)
12+
data = pl.read_csv(
13+
"/home/leyregarrido/01_github_repos/openhexa-ds-developments/pyramid_matching/data/data_to_match.csv"
14+
)
15+
16+
matcher = FuzzyMatcher(threshold=80)
17+
pyramid_matcher = PyramidMatcher(matcher=matcher)
18+
19+
(
20+
matched_data,
21+
matched_data_simplified,
22+
reference_not_matched,
23+
candidate_not_matched,
24+
) = pyramid_matcher.run_matching(
25+
reference_pyramid=dhis2_pyramid,
26+
candidate_pyramid=data,
27+
# levels_to_match=["level_1", "level_2", "level_3", "level_4", "level_5"] # auto
28+
)
29+
30+
print(matched_data.head())
31+
print(matched_data_simplified.head())
32+
print(reference_not_matched.head())
33+
print(candidate_not_matched.head())
34+
35+
36+
if __name__ == "__main__":
37+
main()

pyramid_matching/pyramid_matcher/matchers.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Any, TypeAlias
3+
from typing import Generic, TypeAlias, TypeVar
44

55
from rapidfuzz import fuzz, process
66
from shapely.geometry.base import BaseGeometry
@@ -11,29 +11,33 @@
1111

1212
CandidateAttributes: TypeAlias = list[str]
1313

14+
K = TypeVar("K", str, BaseGeometry)
15+
1416

1517
@dataclass(frozen=True)
1618
class MatchResult:
1719
"""Data class to hold the result of a match operation."""
1820

1921
query: str
2022
matched: str
21-
attributes: dict[str, Any]
23+
attributes: CandidateAttributes
2224
score: float
2325

2426

25-
class BaseMatcher(ABC):
27+
class BaseMatcher(ABC, Generic[K]):
2628
"""Abstract base class for matchers that compute similarity scores."""
2729

2830
@abstractmethod
2931
def get_similarity(
30-
self, query: str | BaseGeometry, candidates: dict[str | BaseGeometry, CandidateAttributes]
32+
self,
33+
query: K,
34+
candidates: dict[K, CandidateAttributes],
3135
) -> MatchResult | None:
3236
"""Return similarity scores for the candidates."""
3337
pass
3438

3539

36-
class FuzzyMatcher(BaseMatcher):
40+
class FuzzyMatcher(BaseMatcher[str]):
3741
"""Matcher that uses fuzzy string matching to compute similarity scores."""
3842

3943
def __init__(self, threshold: float = 80, scorer_name: str = "wratio"):
@@ -86,7 +90,9 @@ def get_similarity(
8690
if no match meets the threshold.
8791
"""
8892
candidate_strings = list(candidates.keys())
89-
best_match = self.process.extractOne(query, candidate_strings, scorer=self.scorer)
93+
best_match = self.process.extractOne(
94+
query, candidate_strings, scorer=self.scorer
95+
)
9096

9197
if best_match is None:
9298
return None
@@ -108,13 +114,15 @@ def __str__(self) -> str:
108114
return f"FuzzyMatcher(scorer: {self.scorer.__name__})"
109115

110116

111-
class SentenceTransformerMatcher(BaseMatcher):
117+
class SentenceTransformerMatcher(BaseMatcher[str]):
112118
"""Matcher that uses sentence transformers to compute similarity scores.
113119
114120
NOTE: Not yet implemented.
115121
"""
116122

117-
def __init__(self, model_name: str | None = "sentence-transformers/all-MiniLM-L6-v2"):
123+
def __init__(
124+
self, model_name: str | None = "sentence-transformers/all-MiniLM-L6-v2"
125+
):
118126
from sentence_transformers import SentenceTransformer # noqa: PLC0415
119127

120128
model_name = "sentence-transformers/all-MiniLM-L6-v2"
@@ -128,13 +136,15 @@ def get_similarity(
128136
# cand_embs = self.model.encode(candidates, convert_to_tensor=True)
129137
# scores = cos_sim(query_emb, cand_embs)[0].cpu().numpy()
130138
"""Return similarity scores for the candidates using sentence transformers."""
131-
raise NotImplementedError("SentenceTransformerMatcher.get_similarity is not implemented.")
139+
raise NotImplementedError(
140+
"SentenceTransformerMatcher.get_similarity is not implemented."
141+
)
132142

133143
def __str__(self) -> str:
134144
return f"TransformerMatcher(scorer: {self.model.__name__})"
135145

136146

137-
class GeometryMatcher(BaseMatcher):
147+
class GeometryMatcher(BaseMatcher[BaseGeometry]):
138148
"""Match org units using spatial proximity and overlap.
139149
140150
NOTE: Not yet implemented. This is a test implementation.
@@ -198,14 +208,20 @@ def _score(self, ref: BaseGeometry, cand: BaseGeometry) -> float | None:
198208
distance_score = 1.0 - (distance / self.max_distance)
199209

200210
overlap_score = 0.0
201-
if self.use_overlap and ref.geom_type == "Polygon" and cand.geom_type == "Polygon":
211+
if (
212+
self.use_overlap
213+
and ref.geom_type == "Polygon"
214+
and cand.geom_type == "Polygon"
215+
):
202216
inter = ref.intersection(cand).area
203217
union = ref.union(cand).area
204218
if union > 0:
205219
overlap_score = inter / union
206220

207221
# Final weighted score
208-
return (1 - self.overlap_weight) * distance_score + self.overlap_weight * overlap_score
222+
return (
223+
1 - self.overlap_weight
224+
) * distance_score + self.overlap_weight * overlap_score
209225

210226
def _geom_id(self, geom: BaseGeometry) -> str:
211227
"""Return an identifier for the query geometry.

0 commit comments

Comments
 (0)