-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsentencetransformer.py
More file actions
74 lines (60 loc) · 2.44 KB
/
sentencetransformer.py
File metadata and controls
74 lines (60 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from sentence_transformers import SentenceTransformer, util
import numpy as np
# Simulated in-memory database (in a real project, this would be PostgreSQL).
# Maps image_id -> {'description': ..., 'vector': ...}; entries are written by
# ImageMatcher.add_image and read by ImageMatcher.find_similar_image.
image_database: dict = {}
class ImageMatcher:
    """Match free-text queries against stored image descriptions.

    Descriptions are encoded with a SentenceTransformer model and kept in the
    module-level ``image_database`` dict. Queries are encoded the same way and
    compared to every stored vector by cosine similarity.
    """

    def __init__(self):
        """Load the 'all-MiniLM-L6-v2' model used for all encodings."""
        # Loading the model might take a few seconds the first time.
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def add_image(self, image_id, description):
        """Encode *description* and store it under *image_id*.

        Args:
            image_id: Key for the entry in ``image_database``. An existing
                entry with the same key is overwritten.
            description: Text describing the image.
        """
        vector = self.model.encode(description)
        image_database[image_id] = {
            'description': description,
            'vector': vector,
        }

    def find_similar_image(self, query_description, threshold=0.8):
        """Find the stored image whose description best matches the query.

        Args:
            query_description: Text to match against stored descriptions.
            threshold: Minimum cosine similarity for a result to count as a
                match.

        Returns:
            A ``(image_id, similarity)`` tuple. ``image_id`` is ``None`` when
            the best similarity is below ``threshold``. When the database is
            empty the result is ``(None, 0.0)``.
        """
        query_vector = self.model.encode(query_description)

        best_match = None
        # Start below every possible score rather than at 0: cosine similarity
        # can be negative, and a 0 start would discard every candidate with a
        # non-positive score and report a made-up "best" similarity of 0.
        highest_similarity = float('-inf')

        # Compare the query with every stored image, keeping the first best.
        for image_id, data in image_database.items():
            score = util.cos_sim(query_vector, data['vector'])[0][0]
            score = float(score)  # Convert tensor element to a plain float
            if score > highest_similarity:
                highest_similarity = score
                best_match = image_id

        if best_match is None:
            # Nothing was comparable (e.g. empty database); keep the
            # historical "no match, similarity 0" result.
            return None, 0.0

        if highest_similarity >= threshold:
            return best_match, highest_similarity
        return None, highest_similarity
# Test the system
def test_matcher():
    """Run a small demo: index four descriptions, then print query results.

    For each query the best match (or the best similarity when nothing
    reaches the matcher's default threshold) is printed to stdout.
    """
    image_matcher = ImageMatcher()

    # Sample descriptions, keyed by image id.
    test_images = {
        'img1': 'a red car in a city street during daytime',
        'img2': 'a blue car parked in an urban area',
        'img3': 'a cat sleeping on a windowsill',
        'img4': 'red automobile on city road in daylight',
    }
    for image_id, description in test_images.items():
        image_matcher.add_image(image_id, description)

    # Queries to run against the indexed descriptions.
    queries = (
        'red vehicle in city during day',
        'cat napping by window',
        'airplane flying in sky',
    )

    print("\nTesting similarity matching:\n")
    for query in queries:
        match_id, similarity = image_matcher.find_similar_image(query)
        print(f"\nQuery: '{query}'")

        if not match_id:
            print(f"No match found above threshold. Best similarity: {similarity:.2f}")
            continue

        print(f"Found match: '{test_images[match_id]}'")
        print(f"Similarity score: {similarity:.2f}")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_matcher()