-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_note_embeddings.py
More file actions
124 lines (102 loc) · 3.59 KB
/
generate_note_embeddings.py
File metadata and controls
124 lines (102 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Wildcard import kept first so the explicit imports below take precedence
# over anything utils happens to re-export.
from utils import *  # noqa: F403 -- project helpers (loop_through_directories, process_note, clean_text_for_embedding_model)

import argparse
import hashlib
import os

import dotenv
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer
dotenv.load_dotenv()
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MODEL_NAME = "all-mpnet-base-v2"
def init_vector_db():
client = QdrantClient("http://localhost:6333")
model = SentenceTransformer(MODEL_NAME)
if not any(
[c.name == COLLECTION_NAME for c in client.get_collections().collections]
):
client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(
size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE
),
)
return client, model
client, model = init_vector_db()
import hashlib
secret_key = os.getenv("SECRET_KEY")
assert secret_key, "SECRET_KEY is not set in .env"
def embed_all_notes_into_vector_database(note_path, meta_data, content, root_directory):
note_title = (
("\t" + meta_data["title"])
if "title" in meta_data
else ("\t" + str(meta_data["id"]) if "id" in meta_data else note_path)
)
# hash note_title with secret key
hash_object = hashlib.sha256()
hash_object.update(note_title.encode("utf-8") + secret_key.encode("utf-8"))
note_id = int(hash_object.hexdigest()[:16], 16)
# Check if note is already in the database
embedded_content = model.encode(clean_text_for_embedding_model(content))
search_result = client.search(
collection_name=COLLECTION_NAME, query_vector=embedded_content, limit=1
)
if not search_result or search_result[0].score < 0.95:
# Embed and store the note
embedding = embedded_content.tolist()
client.upsert(
collection_name=COLLECTION_NAME,
points=[
PointStruct(
id=note_id,
payload={
"title": note_title,
"content": content,
}, # TODO: Add note creation time/edit time to payload
vector=embedding,
)
],
)
print(f"Embedded note: {note_title}")
else:
print(f"Note already embedded: {note_title}")
return None, None, False
def query_notes(query):
query_vector = model.encode(clean_text_for_embedding_model(query)).tolist()
results = client.search(
collection_name=COLLECTION_NAME, query_vector=query_vector, limit=10
)
print(f"Query: {query}")
print("Results:")
for result in results:
print(f"- {result.payload['title']} (similarity: {result.score:.2f})")
print(f" Content: {result.payload['content'][:100]}...")
print("-----------------------")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--directories",
nargs="+",
type=str,
help="List of directories to process",
default=[],
required=False,
)
parser.add_argument(
"--files",
nargs="+",
type=str,
help="List of files to process",
required=False,
default=[],
)
args = parser.parse_args()
for directory in args.directories:
loop_through_directories(
directory,
[embed_all_notes_into_vector_database],
clear_bottom_matter=False,
)
for file in args.files:
process_note(
file, "", [embed_all_notes_into_vector_database], clear_bottom_matter=False
)