Skip to content

Commit 37bb553

Browse files
committed
feat: add suggest_cre_mappings() for automatic CRE mapping via embeddings
1 parent a4ae0c5 commit 37bb553

2 files changed

Lines changed: 114 additions & 0 deletions

File tree

application/cmd/cre_main.py

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from alive_progress import alive_bar
2626
from application.prompt_client import prompt_client as prompt_client
2727
from application.utils import gap_analysis
28+
from application.prompt_client.prompt_client import SIMILARITY_THRESHOLD
2829

2930
logging.basicConfig()
3031
logger = logging.getLogger(__name__)
@@ -315,6 +316,58 @@ def register_standard(
315316
conn.set(standard_hash, value="")
316317

317318

319+
def suggest_cre_mappings(
    standard_entries: List[defs.Standard],
    collection: db.Node_collection,
    confidence_threshold: float = SIMILARITY_THRESHOLD,
) -> Dict[str, Any]:
    """
    Suggest a CRE mapping for each Standard entry based on its embedding.

    For every entry the name, section and description are joined into a
    single text, embedded with ``PromptHandler.get_text_embeddings`` and
    matched with ``PromptHandler.get_id_of_most_similar_cre_paginated``.
    Entries whose best match reaches ``confidence_threshold`` are returned
    under ``mapped``; every other embedded entry is returned under
    ``needs_review`` for human review.

    Entries that have no name, section or description cannot be embedded;
    they are skipped (with a warning in the log) and appear in neither list.

    Args:
        standard_entries: list of Standard nodes to map
        collection: database connection handed to the PromptHandler
        confidence_threshold: minimum similarity score to auto-map

    Returns:
        Dict with 'mapped' (high confidence) and 'needs_review'
        (low confidence / no match) lists. Each item holds the keys
        'standard' (``node.todict()``), 'suggested_cre_id' and
        'confidence' (score rounded to 4 decimals, or None when the
        lookup returned no score).
    """
    if not standard_entries:
        logger.warning("suggest_cre_mappings() called with no standard_entries")
        return {"mapped": [], "needs_review": []}

    ph = prompt_client.PromptHandler(database=collection)
    results: Dict[str, Any] = {"mapped": [], "needs_review": []}

    for node in standard_entries:
        text = " ".join(filter(None, [node.name, node.section, node.description]))
        if not text.strip():
            # Nothing to embed; make the drop visible instead of silent.
            logger.warning(
                "suggest_cre_mappings: skipping standard entry with no "
                "name, section or description"
            )
            continue

        embedding = ph.get_text_embeddings(text)
        cre_id, similarity = ph.get_id_of_most_similar_cre_paginated(
            embedding, similarity_threshold=confidence_threshold
        )

        # Compare against None explicitly: a score of 0.0 is a real score
        # and must not be reported as "no confidence available".
        has_score = similarity is not None
        entry = {
            "standard": node.todict(),
            "suggested_cre_id": cre_id,
            "confidence": round(float(similarity), 4) if has_score else None,
        }
        if cre_id and has_score and similarity >= confidence_threshold:
            results["mapped"].append(entry)
        else:
            results["needs_review"].append(entry)

    logger.info(
        "suggest_cre_mappings: %d mapped, %d need review",
        len(results["mapped"]),
        len(results["needs_review"]),
    )
    return results
369+
370+
318371
def parse_standards_from_spreadsheeet(
319372
cre_file: List[Dict[str, Any]],
320373
cache_location: str,

application/tests/cre_main_test.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -783,6 +783,67 @@ def test_add_from_disk(
783783
# main.export_to_osib(file_loc=f"{dir}/osib.yaml", cache=cache)
784784
# mocked_db_connect.assert_called_with(path=cache)
785785
# mocked_cre2osib.assert_called_with([defs.CRE(id="000-000", name="c0")])
786+
787+
@patch.object(prompt_client.PromptHandler, "get_text_embeddings")
@patch.object(prompt_client.PromptHandler, "get_id_of_most_similar_cre_paginated")
def test_suggest_cre_mappings(
    self,
    mock_get_similar_cre,
    mock_get_embeddings,
) -> None:
    """A matched and an unmatched standard end up in separate buckets."""
    # Every embedding lookup yields the same dummy vector.
    mock_get_embeddings.return_value = [0.1] * 768
    # Similarity lookups answer in call order: a confident hit for the
    # first standard, then "no match" for the second.
    mock_get_similar_cre.side_effect = [
        ("cre-db-id-123", 0.85),
        (None, None),
    ]

    well_matched = defs.Standard(
        name="PCI-DSS",
        section="Use strong cryptography to protect data in transit",
        description="All transmissions of cardholder data must be encrypted.",
    )
    unmatched = defs.Standard(
        name="PCI-DSS",
        section="Some vague control with no good match",
        description="",
    )

    suggestions = main.suggest_cre_mappings(
        standard_entries=[well_matched, unmatched],
        collection=self.collection,
    )

    mapped_entries = suggestions["mapped"]
    review_entries = suggestions["needs_review"]
    self.assertEqual(len(mapped_entries), 1)
    self.assertEqual(len(review_entries), 1)

    # The confident hit carries the suggested id, score and source standard.
    confident = mapped_entries[0]
    self.assertEqual(confident["suggested_cre_id"], "cre-db-id-123")
    self.assertEqual(confident["confidence"], 0.85)
    self.assertEqual(confident["standard"]["name"], "PCI-DSS")

    # The unmatched standard is flagged without an id or a score.
    flagged = review_entries[0]
    self.assertIsNone(flagged["suggested_cre_id"])
    self.assertIsNone(flagged["confidence"])

    # One embedding request per standard entry.
    self.assertEqual(mock_get_embeddings.call_count, 2)
839+
@patch.object(prompt_client.PromptHandler, "get_text_embeddings")
def test_suggest_cre_mappings_empty_input(self, mock_get_embeddings) -> None:
    """An empty input list gives empty buckets and requests no embeddings."""
    expected = {"mapped": [], "needs_review": []}

    outcome = main.suggest_cre_mappings(standard_entries=[], collection=self.collection)

    self.assertEqual(outcome, expected)
    mock_get_embeddings.assert_not_called()
786847

787848

788849
if __name__ == "__main__":

0 commit comments

Comments
 (0)