-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
151 lines (137 loc) · 5.85 KB
/
main.py
File metadata and controls
151 lines (137 loc) · 5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from astrbot.api import logger
from astrbot.api.provider import ProviderRequest
from astrbot.api.event import (
filter,
AstrMessageEvent,
MessageEventResult,
ResultContentType,
) # noqa
from astrbot.api.star import Context, Star, register, StarTools
from astrbot.api import logger # noqa
from astrbot.dashboard.server import Response
from .core.starter import ATRIMemoryStarter
from collections import defaultdict
from quart import request
# Plugin data directory obtained from StarTools for the "atri" plugin name;
# it is handed to ATRIMemoryStarter as data_dir_path in on_astrbot_loaded().
PLUGIN_DATA_DIR = StarTools.get_data_dir("atri")
@register("atri", "Soulter", "ATRI - My Dear Moments", "0.0.1")
class ATRIPlugin(Star):
    """Graph-backed long-term memory plugin.

    * Buffers each session's text messages and, every ``sum_threshold``
      messages, asks the summarizer to condense them and stores the result in
      the graph memory.
    * Before each LLM request, searches the graph memory for the sender and
      appends the hits to the request's system prompt.
    * Registers dashboard web APIs to inspect, search and extend the graph.
    """

    # Error message returned by the web APIs before on_astrbot_loaded() ran.
    _NOT_READY_MSG = "记忆层尚未初始化"

    def __init__(self, context: Context):
        super().__init__(context)
        # unified_msg_origin -> messages received since the last summarization attempt.
        self.user_counter = defaultdict(int)
        # Message count that triggers a summarization attempt.
        self.sum_threshold = 10
        # unified_msg_origin -> buffered "User(<identifier>): <message>" lines.
        self.dialogs = defaultdict(list)
        # Both are assigned in on_astrbot_loaded(); None until then so handlers
        # can detect the uninitialized state instead of raising AttributeError.
        self.llm_provider = None
        self.memory_layer = None
        self.context.register_web_api("/alkaid/ltm/graph", self.api_get_graph, ["GET"], "获取记忆图数据")
        self.context.register_web_api("/alkaid/ltm/user_ids", self.api_get_user_ids, ["GET"], "获取所有用户ID")
        self.context.register_web_api("/alkaid/ltm/graph/add", self.api_add_graph, ["POST"], "添加记忆图数据")
        self.context.register_web_api("/alkaid/ltm/graph/search", self.api_search_graph, ["GET"], "搜索记忆图数据")

    async def api_get_graph(self):
        """GET handler: return the memory graph.

        Optional query args ``user_id`` and ``group_id`` are forwarded as a
        filter dict to ``graph_memory.get_graph``.
        """
        if self.memory_layer is None:
            return Response().error(self._NOT_READY_MSG).__dict__
        user_id = request.args.get("user_id")
        group_id = request.args.get("group_id")
        # Named `filters` so the imported `filter` decorator module is not shadowed.
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if group_id:
            filters["group_id"] = group_id
        result = await self.memory_layer.graph_memory.get_graph(filters)
        return Response().ok(data=result).__dict__

    async def api_get_user_ids(self):
        """GET handler: return the result of ``graph_memory.get_user_ids()``."""
        if self.memory_layer is None:
            return Response().error(self._NOT_READY_MSG).__dict__
        result = await self.memory_layer.graph_memory.get_user_ids()
        return Response().ok(data=result).__dict__

    async def api_add_graph(self):
        """POST handler: add text to the memory graph.

        JSON body fields:
            text: text to store (required, non-empty).
            user_id: owner of the memory (optional); coerced to ``str``.
            need_summarize: when truthy, run ``text`` through the summarizer
                first (default False).
        """
        if self.memory_layer is None:
            return Response().error(self._NOT_READY_MSG).__dict__
        data = await request.get_json()
        # Tolerate an empty or non-object JSON payload instead of crashing on .get().
        if not isinstance(data, dict):
            data = {}
        text = data.get("text")
        if not text:
            return Response().error("text 不能为空").__dict__
        user_id = data.get("user_id")
        if user_id is not None:
            # The message hooks store and filter user ids as str(...); keep API
            # input consistent so added memories are found by later searches.
            user_id = str(user_id)
        if data.get("need_summarize", False):
            text = await self.memory_layer.summarizer.summarize(text)
        await self.memory_layer.graph_memory.add_to_graph(
            text=text,
            user_id=user_id,
        )
        return Response().ok("添加成功").__dict__

    async def api_search_graph(self):
        """GET handler: search the graph with ``query``, optionally scoped by ``user_id``."""
        if self.memory_layer is None:
            return Response().error(self._NOT_READY_MSG).__dict__
        query = request.args.get("query")
        if not query:
            return Response().error("query 不能为空").__dict__
        user_id = request.args.get("user_id")
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        result = await self.memory_layer.graph_memory.search_graph(query, filters=filters)
        return Response().ok(data=result).__dict__

    @filter.on_astrbot_loaded()
    async def on_astrbot_loaded(self):
        """Create and initialize the memory layer using the current LLM provider."""
        self.llm_provider = self.context.provider_manager.curr_provider_inst
        memory_layer = ATRIMemoryStarter(
            data_dir_path=PLUGIN_DATA_DIR,
            llm_provider=self.llm_provider,
        )
        await memory_layer.initialize()
        # Publish only after initialize() succeeds so handlers never observe a
        # half-initialized memory layer.
        self.memory_layer = memory_layer

    @filter.on_llm_request()
    async def requesting(self, event: AstrMessageEvent, req: ProviderRequest):
        """Append memories related to ``req.prompt`` to the request's system prompt.

        Searches with ``num_to_retrieval=5``, filtered by the sender's user id
        and, when the event has one, its group id.
        """
        # Nothing to search with (or nothing to search in) -> leave request untouched.
        if self.memory_layer is None or not req.prompt:
            return
        filters = {"user_id": str(event.get_sender_id())}
        group_id = event.get_group_id()
        if group_id:
            filters["group_id"] = str(group_id)
        results = await self.memory_layer.graph_memory.search_graph(
            req.prompt,
            num_to_retrieval=5,
            filters=filters,
        )
        if results:
            # `or ""` guards against a None system prompt.
            req.system_prompt = (req.system_prompt or "") + (
                "\n\nHere are related memories between you and user:\n" + str(results)
            )

    def parse_identifier(self, event: AstrMessageEvent):
        """Return the sender's name, falling back to the sender id when the name is empty."""
        return event.get_sender_name() or event.get_sender_id()

    # TODO: consider filter.after_message_sent() so the bot's own LLM reply
    # (result.get_plain_text()) can be buffered too; streaming results need care.
    @filter.event_message_type(filter.EventMessageType.ALL)
    async def after_message(self, event: AstrMessageEvent):
        """Buffer the message text and periodically commit a summary to graph memory.

        Every ``sum_threshold`` messages per session the buffered dialog is
        summarized. A summary containing ``%None%`` discards the buffer, one
        containing ``%Hold%`` keeps it for the next attempt, and anything else is
        added to the graph and then discarded from the buffer.
        """
        if not event.message_str:  # TODO: handle multimodal messages
            return
        if self.memory_layer is None:
            return
        uid = event.unified_msg_origin
        identifier = self.parse_identifier(event)
        message = event.message_str.replace("\n", " ")
        self.dialogs[uid].append(f"User({identifier}): {message}")
        self.user_counter[uid] += 1
        if self.user_counter[uid] < self.sum_threshold:
            return

        logger.info(
            f"User {uid} has sent {self.user_counter[uid]} messages. Summarizing conversation."
        )
        self.user_counter[uid] = 0
        # Snapshot the buffer: messages appended while we await the summarizer or
        # the graph insert must survive the clean-up below, so only the
        # summarized prefix is removed (the original .clear() dropped them).
        dialog = list(self.dialogs[uid])
        text = await self.memory_layer.summarizer.summarize("\n".join(dialog))
        logger.debug(f"Summarized text: {text}")
        if "%None%" in text:
            logger.info("没有符合总结的内容,跳过这轮总结。")
            del self.dialogs[uid][: len(dialog)]
            return
        if "%Hold%" in text:
            logger.info("对话话题不完整,继续观察。")
            return
        group_id = event.get_group_id()
        await self.memory_layer.graph_memory.add_to_graph(
            text=text,
            user_id=str(event.get_sender_id()),
            # Avoid storing the literal string "None" when there is no group id.
            group_id=str(group_id) if group_id else "",
            username=event.get_sender_name(),
        )
        logger.info("Added to graph.")
        del self.dialogs[uid][: len(dialog)]

    async def terminate(self):
        """Optional async teardown hook, called when the plugin is unloaded or disabled."""