diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 582b1eef2..99faf0d13 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -172,6 +172,9 @@ class Message(BaseModel): content: str | list[ContentPart] | None = None """The content of the message.""" + name: str | None = None + """Optional name of the sender, used to identify different users in conversation.""" + tool_calls: list[ToolCall] | list[dict] | None = None """The tool calls of the message.""" @@ -198,6 +201,8 @@ def serialize(self, handler): data.pop("tool_calls", None) if self.tool_call_id is None: data.pop("tool_call_id", None) + if self.name is None: + data.pop("name", None) return data diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 690a6404c..d28607bfb 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -9,6 +9,7 @@ import zoneinfo from collections.abc import Coroutine from dataclasses import dataclass, field +from typing import Any from astrbot.api import sp from astrbot.core import logger @@ -145,15 +146,27 @@ async def _get_session_conv( ) -> Conversation: conv_mgr = plugin_context.conversation_manager umo = event.unified_msg_origin + user_name = event.get_sender_name() + avatar = event.get_sender_avatar() cid = await conv_mgr.get_curr_conversation_id(umo) if not cid: - cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) + cid = await conv_mgr.new_conversation(umo, event.get_platform_id(), user_name=user_name, avatar=avatar) conversation = await conv_mgr.get_conversation(umo, cid) if not conversation: - cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) + cid = await conv_mgr.new_conversation(umo, event.get_platform_id(), user_name=user_name, avatar=avatar) conversation = await conv_mgr.get_conversation(umo, cid) if not conversation: raise RuntimeError("无法创建新的对话。") + # 如果已有对话但 user_name 或 avatar 为空,更新它们 + updates: dict[str, Any] = {} + if conversation.user_name is None and user_name: + updates["user_name"] = user_name + if conversation.avatar is None and avatar: + updates["avatar"] = avatar + if updates: + await conv_mgr.db.update_conversation(cid, **updates) + for field, value in updates.items(): + setattr(conversation, field, value) return conversation diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index a0a0c0e2f..150342e33 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -6,6 +6,7 @@ import json from collections.abc import Awaitable, Callable +from datetime import timezone from astrbot.core import sp from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment @@ -58,8 +59,15 @@ async def _trigger_session_deleted(self, unified_msg_origin: str) -> None: def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: """将 ConversationV2 对象转换为 Conversation 对象""" - created_at = int(conv_v2.created_at.timestamp()) - updated_at = int(conv_v2.updated_at.timestamp()) + # SQLite 读回的 datetime 可能丢失时区信息,需要显式标记为 UTC + ca = conv_v2.created_at + if ca.tzinfo is None: + ca = ca.replace(tzinfo=timezone.utc) + ua = conv_v2.updated_at + if ua.tzinfo is None: + ua = ua.replace(tzinfo=timezone.utc) + created_at = int(ca.timestamp()) + updated_at = int(ua.timestamp()) return Conversation( platform_id=conv_v2.platform_id, user_id=conv_v2.user_id, @@ -70,6 +78,8 @@ def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: created_at=created_at, updated_at=updated_at, token_usage=conv_v2.token_usage, + user_name=conv_v2.user_name, + avatar=conv_v2.avatar, ) async def new_conversation( @@ -79,11 +89,15 @@ async def new_conversation( content: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, + user_name: str | None = None, + avatar: str | None = None, ) -> str: """新建对话,并将当前会话的对话转移到新对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + user_name (str | None): 用户名称 + avatar (str | None): 用户头像 URL Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 @@ -101,6 +115,8 @@ async def new_conversation( content=content, title=title, persona_id=persona_id, + user_name=user_name, + avatar=avatar, ) self.session_conversations[unified_msg_origin] = conv.conversation_id await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 7b67b8755..5c0dbf41f 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -145,6 +145,8 @@ async def create_conversation( cid: str | None = None, created_at: datetime.datetime | None = None, updated_at: datetime.datetime | None = None, + user_name: str | None = None, + avatar: str | None = None, ) -> ConversationV2: """Create a new conversation.""" ... @@ -157,6 +159,8 @@ async def update_conversation( persona_id: str | None = None, content: list[dict] | None = None, token_usage: int | None = None, + user_name: str | None = None, + avatar: str | None = None, ) -> None: """Update a conversation's history.""" ... diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 81649c0d7..aac6f6c5b 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -54,6 +54,9 @@ class ConversationV2(TimestampMixin, SQLModel, table=True): ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) + user_name: str | None = Field(default=None, max_length=255) + avatar: str | None = Field(default=None, max_length=512) + """用户头像 URL""" content: list | None = Field(default=None, sa_type=JSON) title: str | None = Field(default=None, max_length=255) @@ -418,6 +421,10 @@ class Conversation: updated_at: int = 0 token_usage: int = 0 """对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。""" + user_name: str | None = None + """发送消息的用户名称""" + avatar: str | None = None + """用户头像 URL""" class Personality(TypedDict): diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 153e13e8b..ad165c012 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -57,39 +57,58 @@ async def initialize(self) -> None: # 确保 personas 表有 folder_id、sort_order、skills 列(前向兼容) await self._ensure_persona_folder_columns(conn) await self._ensure_persona_skills_column(conn) + # 确保 conversations 表有 user_name 列(前向兼容) + await self._ensure_conversation_user_name_column(conn) + # 确保 conversations 表有 avatar 列(前向兼容) + await self._ensure_conversation_avatar_column(conn) await conn.commit() - async def _ensure_persona_folder_columns(self, conn) -> None: - """确保 personas 表有 folder_id 和 sort_order 列。 + async def _ensure_column( + self, conn, table: str, column: str, ddl: str + ) -> None: + """确保指定表有指定列,如果不存在则添加。 这是为了支持旧版数据库的平滑升级。新版数据库通过 SQLModel 的 metadata.create_all 自动创建这些列。 + + Args: + conn: 数据库连接 + table: 表名 + column: 列名 + ddl: ALTER TABLE 语句中的列定义(不包含 ALTER TABLE ... ADD COLUMN 部分) """ - result = await conn.execute(text("PRAGMA table_info(personas)")) + result = await conn.execute(text(f"PRAGMA table_info({table})")) columns = {row[1] for row in result.fetchall()} - if "folder_id" not in columns: + if column not in columns: await conn.execute( - text( - "ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL" - ) - ) - if "sort_order" not in columns: - await conn.execute( - text("ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0") + text(f"ALTER TABLE {table} ADD COLUMN {ddl}") ) + async def _ensure_persona_folder_columns(self, conn) -> None: + """确保 personas 表有 folder_id 和 sort_order 列。""" + await self._ensure_column( + conn, "personas", "folder_id", "folder_id VARCHAR(36) DEFAULT NULL" + ) + await self._ensure_column( + conn, "personas", "sort_order", "sort_order INTEGER DEFAULT 0" + ) + async def _ensure_persona_skills_column(self, conn) -> None: - """确保 personas 表有 skills 列。 + """确保 personas 表有 skills 列。""" + await self._ensure_column(conn, "personas", "skills", "skills JSON") - 这是为了支持旧版数据库的平滑升级。新版数据库通过 SQLModel - 的 metadata.create_all 自动创建这些列。 - """ - result = await conn.execute(text("PRAGMA table_info(personas)")) - columns = {row[1] for row in result.fetchall()} + async def _ensure_conversation_user_name_column(self, conn) -> None: + """确保 conversations 表有 user_name 列。""" + await self._ensure_column( + conn, "conversations", "user_name", "user_name VARCHAR(255) DEFAULT NULL" + ) - if "skills" not in columns: - await conn.execute(text("ALTER TABLE personas ADD COLUMN skills JSON")) + async def _ensure_conversation_avatar_column(self, conn) -> None: + """确保 conversations 表有 avatar 列。""" + await self._ensure_column( + conn, "conversations", "avatar", "avatar VARCHAR(512) DEFAULT NULL" + ) # ==== # Platform Statistics @@ -259,6 +278,8 @@ async def create_conversation( cid=None, created_at=None, updated_at=None, + user_name=None, + avatar=None, ): kwargs = {} if cid: @@ -276,13 +297,15 @@ async def create_conversation( platform_id=platform_id, title=title, persona_id=persona_id, + user_name=user_name, + avatar=avatar, **kwargs, ) session.add(new_conversation) return new_conversation async def update_conversation( - self, cid, title=None, persona_id=None, content=None, token_usage=None + self, cid, title=None, persona_id=None, content=None, token_usage=None, user_name=None, avatar=None ): async with self.get_db() as session: session: AsyncSession @@ -299,6 +322,10 @@ async def update_conversation( values["content"] = content if token_usage is not None: values["token_usage"] = token_usage + if user_name is not None: + values["user_name"] = user_name + if avatar is not None: + values["avatar"] = avatar if not values: return None query = query.values(**values) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index b99a5778b..95efdf63e 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -191,6 +191,12 @@ def get_sender_name(self) -> str: return self.message_obj.sender.nickname return "" + def get_sender_avatar(self) -> str | None: + """获取消息发送者的头像 URL。(可能会返回 None)""" + if hasattr(self.message_obj.sender, 'avatar'): + return self.message_obj.sender.avatar + return None + def set_extra(self, key, value): """设置额外的信息。""" self._extras[key] = value diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 253963322..c8de70db2 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -10,6 +10,8 @@ class MessageMember: user_id: str # 发送者id nickname: str | None = None + avatar: str | None = None + """用户头像 URL""" def __str__(self): # 使用 f-string 来构建返回的字符串表示形式 diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index adc24578f..421f5e7a0 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -1,9 +1,10 @@ import asyncio import os import sys +import time import uuid from collections.abc import Awaitable, Callable -from typing import Any, cast +from typing import Any, NamedTuple, cast import quart from requests import Response @@ -38,6 +39,17 @@ from typing_extensions import override +# 客户信息缓存条目 +class CustomerCacheEntry(NamedTuple): + nickname: str + avatar: str | None + expire_at: float + + +# 客户信息 TTL 缓存(默认 5 分钟) +CUSTOMER_CACHE_TTL = 300 + + class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) @@ -173,6 +185,8 @@ def __init__( # 微信客服 self.kf_name = self.config.get("kf_name", None) + # 客户信息缓存 (external_userid -> CustomerCacheEntry) + self._customer_cache: dict[str, CustomerCacheEntry] = {} if self.kf_name: # inject self.wechat_kf_api = WeChatKF(client=self.client) @@ -184,10 +198,10 @@ def __init__( async def callback(msg: BaseMessage): if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": + token = msg._data["Token"] + kfid = msg._data["OpenKfId"] def get_latest_msg_item() -> dict | None: - token = msg._data["Token"] - kfid = msg._data["OpenKfId"] has_more = 1 ret = {} while has_more: @@ -203,6 +217,7 @@ def get_latest_msg_item() -> dict | None: get_latest_msg_item, ) if msg_new: + msg_new["open_kfid"] = kfid await self.convert_wechat_kf_message(msg_new) return await self.convert_message(msg) @@ -350,11 +365,45 @@ async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: msgtype = msg.get("msgtype") external_userid = cast(str, msg.get("external_userid")) + + # 尝试从缓存获取客户信息 + nickname = external_userid + avatar = None + now = time.time() + cached = self._customer_cache.get(external_userid) + if cached and cached.expire_at > now: + # 缓存命中 + nickname = cached.nickname + avatar = cached.avatar + logger.debug(f"客户信息缓存命中: external_userid={external_userid}") + else: + # 缓存未命中或已过期,调用 API + try: + customer_info = await asyncio.get_event_loop().run_in_executor( + None, + self.wechat_kf_api.batchget_customer, + external_userid, + ) + # 避免在日志中输出完整客户信息(包含昵称、头像等敏感数据) + logger.debug(f"获取客户信息成功: external_userid={external_userid}") + customer_list = customer_info.get("customer_list", []) + if customer_list: + nickname = customer_list[0].get("nickname", external_userid) + avatar = customer_list[0].get("avatar", None) + # 更新缓存 + self._customer_cache[external_userid] = CustomerCacheEntry( + nickname=nickname, + avatar=avatar, + expire_at=now + CUSTOMER_CACHE_TTL, + ) + except Exception as e: + logger.debug(f"获取客户信息失败: {e}") + abm = AstrBotMessage() abm.raw_message = msg abm.raw_message["_wechat_kf_flag"] = None # 方便处理 abm.self_id = msg["open_kfid"] - abm.sender = MessageMember(external_userid, external_userid) + abm.sender = MessageMember(external_userid, nickname, avatar) abm.session_id = external_userid abm.type = MessageType.FRIEND_MESSAGE abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8]) diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py index 513d3603f..ee659fd77 100644 --- a/astrbot/dashboard/routes/conversation.py +++ b/astrbot/dashboard/routes/conversation.py @@ -1,5 +1,6 @@ import json import traceback +from dataclasses import asdict from datetime import datetime from io import BytesIO @@ -88,8 +89,11 @@ async def list_conversations(self): (total_count + page_size - 1) // page_size if total_count > 0 else 1 ) + # 将 Conversation dataclass 对象转换为字典 + conversations_dict = [asdict(conv) for conv in conversations] + result = { - "conversations": conversations, + "conversations": conversations_dict, "pagination": { "page": page, "page_size": page_size, diff --git a/dashboard/src/i18n/locales/en-US/features/conversation.json b/dashboard/src/i18n/locales/en-US/features/conversation.json index 3e8cb7128..b580bda0b 100644 --- a/dashboard/src/i18n/locales/en-US/features/conversation.json +++ b/dashboard/src/i18n/locales/en-US/features/conversation.json @@ -28,6 +28,8 @@ "cid": "Conversation ID", "umo": "Unified Message Origin", "sessionId": "Session ID", + "userName": "User Name", + "avatar": "Avatar", "createdAt": "Created At", "updatedAt": "Updated At", "actions": "Actions" diff --git a/dashboard/src/i18n/locales/zh-CN/features/conversation.json b/dashboard/src/i18n/locales/zh-CN/features/conversation.json index 8a5ca6eb5..3247aa7d4 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/conversation.json +++ b/dashboard/src/i18n/locales/zh-CN/features/conversation.json @@ -28,6 +28,8 @@ "cid": "对话 ID", "umo": "消息会话来源", "sessionId": "会话 ID", + "userName": "用户名", + "avatar": "头像", "createdAt": "创建时间", "updatedAt": "更新时间", "actions": "操作" diff --git a/dashboard/src/views/ConversationPage.vue b/dashboard/src/views/ConversationPage.vue index 2a615b294..5d576b379 100644 --- a/dashboard/src/views/ConversationPage.vue +++ b/dashboard/src/views/ConversationPage.vue @@ -97,6 +97,17 @@ {{ item.sessionInfo.sessionId || tm('status.unknown') }} + + + + @@ -448,6 +459,8 @@ export default { { title: this.tm('table.headers.sessionId'), key: 'sessionId', sortable: true, width: '100px' }, ], }, + { title: this.tm('table.headers.userName'), key: 'user_name', sortable: true, width: '120px' }, + { title: this.tm('table.headers.avatar'), key: 'avatar', sortable: false, width: '80px' }, { title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' }, { title: this.tm('table.headers.updatedAt'), key: 'updated_at', sortable: true, width: '180px' }, { title: this.tm('table.headers.actions'), key: 'actions', sortable: false, align: 'center' }