-
Notifications
You must be signed in to change notification settings - Fork 792
Expand file tree
/
Copy pathadd_handler.py
More file actions
197 lines (164 loc) · 7.21 KB
/
add_handler.py
File metadata and controls
197 lines (164 loc) · 7.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Add handler for memory addition functionality (Class-based version).
This module provides a class-based implementation of add handlers,
using dependency injection for better modularity and testability.
"""
import os
import threading
from pydantic import validate_call
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.component_init import create_per_db_components
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse
from memos.memories.textual.item import (
list_all_fields,
)
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
from memos.types import MessageList
@validate_call
def _validate_messages(messages: MessageList) -> None:
    """Validate that *messages* conforms to the ``MessageList`` schema.

    Defined at module level so the pydantic validation wrapper is built once
    at import time instead of being reconstructed on every request.
    """


class AddHandler(BaseHandler):
    """
    Handler for memory addition operations.
    Handles text memory additions with sync/async support.
    """

    def __init__(self, dependencies: HandlerDependencies):
        """
        Initialize add handler.

        Args:
            dependencies: HandlerDependencies instance providing the shared
                mem cube, reader, scheduler and feedback server.
        """
        super().__init__(dependencies)
        self._validate_dependencies(
            "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server"
        )
        # Per-user component cache for Neo4j one-database-per-user mode,
        # guarded by _cache_lock (see _get_per_user_components).
        self._per_user_cube_cache: dict[str, dict] = {}
        self._cache_lock = threading.Lock()

    def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
        """
        Main handler for add memories endpoint.

        Orchestrates the addition of text memories, supporting concurrent
        processing. When ``add_req.is_feedback`` is set, the request is first
        routed through the feedback pipeline; on any feedback failure we fall
        back to a plain memory add (best-effort semantics).

        Args:
            add_req: Add memory request (deprecated fields are converted in
                model validator)

        Returns:
            MemoryResponse with added memory information
        """
        # NOTE(review): the original logged the full request body with a
        # hardcoded debug timestamp; trimmed to avoid leaking message content
        # into logs on every call.
        self.logger.info(
            f"[AddHandler] handle_add_memories called: user_id={add_req.user_id!r}, "
            f"session_id={add_req.session_id!r}, is_feedback={add_req.is_feedback}"
        )
        self._strip_reserved_info_fields(add_req)
        cube_view = self._build_cube_view(add_req)

        if add_req.is_feedback:
            feedback_response = self._try_handle_feedback(add_req, cube_view)
            if feedback_response is not None:
                return feedback_response
            # Feedback failed: fall through to a regular add below.

        results = cube_view.add_memories(add_req)
        self.logger.info(f"[AddHandler] Final add results count={len(results)}")
        return MemoryResponse(
            message="Memory added successfully",
            data=results,
        )

    def _strip_reserved_info_fields(self, add_req: APIADDRequest) -> None:
        """
        Drop reserved memory-item field names from ``add_req.info`` in place.

        Keys colliding with memory item fields (see ``list_all_fields``) would
        shadow stored item attributes, so they are removed with a warning.
        """
        if not add_req.info:
            return
        exclude_fields = list_all_fields()
        original_len = len(add_req.info)
        add_req.info = {k: v for k, v in add_req.info.items() if k not in exclude_fields}
        if len(add_req.info) < original_len:
            self.logger.warning(f"[AddHandler] info fields can not contain {exclude_fields}.")

    def _try_handle_feedback(
        self, add_req: APIADDRequest, cube_view: MemCubeView
    ) -> "MemoryResponse | None":
        """
        Run the feedback pipeline for a feedback-flagged add request.

        Returns:
            The feedback MemoryResponse on success, or ``None`` on any error
            so the caller can fall back to a regular memory add.
        """
        try:
            messages = add_req.messages
            _validate_messages(messages)
            chat_history = add_req.chat_history if add_req.chat_history else []
            concatenate_chat = chat_history + messages
            # The feedback content is the most recent user turn; everything
            # before it becomes the history. max() raises ValueError when no
            # user turn exists, which is caught by the fallback below.
            last_user_index = max(
                i for i, d in enumerate(concatenate_chat) if d["role"] == "user"
            )
            feedback_req = APIFeedbackRequest(
                user_id=add_req.user_id,
                session_id=add_req.session_id,
                task_id=add_req.task_id,
                history=concatenate_chat[:last_user_index],
                feedback_content=concatenate_chat[last_user_index]["content"],
                writable_cube_ids=add_req.writable_cube_ids,
                async_mode=add_req.async_mode,
                info=add_req.info,
            )
            process_record = cube_view.feedback_memories(feedback_req)
            self.logger.info(
                f"[ADDFeedbackHandler] Final feedback results count={len(process_record)}"
            )
            return MemoryResponse(
                message="Memory feedback successfully",
                data=[process_record],
            )
        except Exception as e:
            # Best-effort: keep the original warning message but attach the
            # traceback so failures are diagnosable, then signal fallback.
            self.logger.warning(f"[ADDFeedbackHandler] Running error: {e}", exc_info=True)
            return None

    @property
    def _is_neo4j_multidb(self) -> bool:
        """Return True when using Neo4j enterprise with one-database-per-user mode."""
        backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "")).lower()
        shared_db = os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true"
        return backend == "neo4j" and not shared_db

    def _get_per_user_components(self, user_id: str) -> dict:
        """Return (creating on first access) per-user graph/mem components.

        Uses double-checked locking so the expensive component creation happens
        only once per user even under concurrent requests.
        """
        if user_id not in self._per_user_cube_cache:
            with self._cache_lock:
                if user_id not in self._per_user_cube_cache:
                    self.logger.info(
                        f"[AddHandler] Creating per-user components for user_id={user_id!r}"
                    )
                    self._per_user_cube_cache[user_id] = create_per_db_components(
                        db_name=user_id,
                        base_components=vars(self.deps),
                    )
        return self._per_user_cube_cache[user_id]

    def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]:
        """
        Normalize target cube ids from add_req.

        Priority:
            1) writable_cube_ids (deprecated mem_cube_id is converted to this
               in model validator)
            2) fallback to user_id
        """
        if add_req.writable_cube_ids:
            # dict.fromkeys de-duplicates while preserving order.
            return list(dict.fromkeys(add_req.writable_cube_ids))
        return [add_req.user_id]

    def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
        """
        Build a cube view for the request's target cubes.

        Returns a SingleCubeView for one target, or a CompositeCubeView
        wrapping one SingleCubeView per target otherwise. In Neo4j multi-db
        mode, per-user components replace the shared ones.
        """
        cube_ids = self._resolve_cube_ids(add_req)
        if self._is_neo4j_multidb:
            per_user = self._get_per_user_components(add_req.user_id)
            naive_mem_cube = per_user["naive_mem_cube"]
            mem_reader = per_user["mem_reader"]
        else:
            naive_mem_cube = self.naive_mem_cube
            mem_reader = self.mem_reader

        def make_view(cube_id: str) -> SingleCubeView:
            # All views share the same components; only the cube id differs.
            return SingleCubeView(
                cube_id=cube_id,
                naive_mem_cube=naive_mem_cube,
                mem_reader=mem_reader,
                mem_scheduler=self.mem_scheduler,
                logger=self.logger,
                feedback_server=self.feedback_server,
                searcher=None,
            )

        if len(cube_ids) == 1:
            return make_view(cube_ids[0])
        return CompositeCubeView(
            cube_views=[make_view(cube_id) for cube_id in cube_ids],
            logger=self.logger,
        )