@@ -783,6 +783,67 @@ def test_add_from_disk(
783783 # main.export_to_osib(file_loc=f"{dir}/osib.yaml", cache=cache)
784784 # mocked_db_connect.assert_called_with(path=cache)
785785 # mocked_cre2osib.assert_called_with([defs.CRE(id="000-000", name="c0")])
786+
787+ @patch .object (prompt_client .PromptHandler , "get_text_embeddings" )
788+ @patch .object (prompt_client .PromptHandler , "get_id_of_most_similar_cre_paginated" )
789+ def test_suggest_cre_mappings (
790+ self ,
791+ mock_get_similar_cre ,
792+ mock_get_embeddings ,
793+ ) -> None :
794+ # Arrange
795+ standard_entries = [
796+ defs .Standard (
797+ name = "PCI-DSS" ,
798+ section = "Use strong cryptography to protect data in transit" ,
799+ description = "All transmissions of cardholder data must be encrypted." ,
800+ ),
801+ defs .Standard (
802+ name = "PCI-DSS" ,
803+ section = "Some vague control with no good match" ,
804+ description = "" ,
805+ ),
806+ ]
807+
808+ fake_embedding = [0.1 ] * 768
809+ mock_get_embeddings .return_value = fake_embedding
810+
811+ # First standard maps well, second does not
812+ mock_get_similar_cre .side_effect = [
813+ ("cre-db-id-123" , 0.85 ), # high confidence
814+ (None , None ), # low confidence / no match
815+ ]
816+
817+ # Act
818+ result = main .suggest_cre_mappings (
819+ standard_entries = standard_entries ,
820+ collection = self .collection ,
821+ )
822+
823+ # Assert
824+ self .assertEqual (len (result ["mapped" ]), 1 )
825+ self .assertEqual (len (result ["needs_review" ]), 1 )
826+
827+ mapped = result ["mapped" ][0 ]
828+ self .assertEqual (mapped ["suggested_cre_id" ], "cre-db-id-123" )
829+ self .assertEqual (mapped ["confidence" ], 0.85 )
830+ self .assertEqual (mapped ["standard" ]["name" ], "PCI-DSS" )
831+
832+ review = result ["needs_review" ][0 ]
833+ self .assertIsNone (review ["suggested_cre_id" ])
834+ self .assertIsNone (review ["confidence" ])
835+
836+ # Assert embeddings were called for each standard
837+ self .assertEqual (mock_get_embeddings .call_count , 2 )
838+
839+ @patch .object (prompt_client .PromptHandler , "get_text_embeddings" )
840+ def test_suggest_cre_mappings_empty_input (self , mock_get_embeddings ) -> None :
841+ result = main .suggest_cre_mappings (
842+ standard_entries = [],
843+ collection = self .collection ,
844+ )
845+ self .assertEqual (result , {"mapped" : [], "needs_review" : []})
846+ mock_get_embeddings .assert_not_called ()
786847
787848
788849if __name__ == "__main__" :
0 commit comments