Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions astrbot/core/agent/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ class Message(BaseModel):
content: str | list[ContentPart] | None = None
"""The content of the message."""

name: str | None = None
"""Optional name of the sender, used to identify different users in conversation."""

tool_calls: list[ToolCall] | list[dict] | None = None
"""The tool calls of the message."""

Expand All @@ -198,6 +201,8 @@ def serialize(self, handler):
data.pop("tool_calls", None)
if self.tool_call_id is None:
data.pop("tool_call_id", None)
if self.name is None:
data.pop("name", None)
return data


Expand Down
17 changes: 15 additions & 2 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import zoneinfo
from collections.abc import Coroutine
from dataclasses import dataclass, field
from typing import Any

from astrbot.api import sp
from astrbot.core import logger
Expand Down Expand Up @@ -145,15 +146,27 @@ async def _get_session_conv(
) -> Conversation:
conv_mgr = plugin_context.conversation_manager
umo = event.unified_msg_origin
user_name = event.get_sender_name()
avatar = event.get_sender_avatar()
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
cid = await conv_mgr.new_conversation(umo, event.get_platform_id(), user_name=user_name, avatar=avatar)
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
cid = await conv_mgr.new_conversation(umo, event.get_platform_id(), user_name=user_name, avatar=avatar)
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
# 如果已有对话但 user_name 或 avatar 为空,更新它们
updates: dict[str, Any] = {}
if conversation.user_name is None and user_name:
updates["user_name"] = user_name
if conversation.avatar is None and avatar:
updates["avatar"] = avatar
if updates:
await conv_mgr.db.update_conversation(cid, **updates)
for field, value in updates.items():
setattr(conversation, field, value)
return conversation


Expand Down
20 changes: 18 additions & 2 deletions astrbot/core/conversation_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import json
from collections.abc import Awaitable, Callable
from datetime import timezone

from astrbot.core import sp
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
Expand Down Expand Up @@ -58,8 +59,15 @@ async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:

def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
updated_at = int(conv_v2.updated_at.timestamp())
# SQLite 读回的 datetime 可能丢失时区信息,需要显式标记为 UTC
ca = conv_v2.created_at
if ca.tzinfo is None:
ca = ca.replace(tzinfo=timezone.utc)
ua = conv_v2.updated_at
if ua.tzinfo is None:
ua = ua.replace(tzinfo=timezone.utc)
created_at = int(ca.timestamp())
updated_at = int(ua.timestamp())
return Conversation(
platform_id=conv_v2.platform_id,
user_id=conv_v2.user_id,
Expand All @@ -70,6 +78,8 @@ def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
created_at=created_at,
updated_at=updated_at,
token_usage=conv_v2.token_usage,
user_name=conv_v2.user_name,
avatar=conv_v2.avatar,
)

async def new_conversation(
Expand All @@ -79,11 +89,15 @@ async def new_conversation(
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
user_name: str | None = None,
avatar: str | None = None,
) -> str:
"""新建对话,并将当前会话的对话转移到新对话.

Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
user_name (str | None): 用户名称
avatar (str | None): 用户头像 URL
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串

Expand All @@ -101,6 +115,8 @@ async def new_conversation(
content=content,
title=title,
persona_id=persona_id,
user_name=user_name,
avatar=avatar,
)
self.session_conversations[unified_msg_origin] = conv.conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ async def create_conversation(
cid: str | None = None,
created_at: datetime.datetime | None = None,
updated_at: datetime.datetime | None = None,
user_name: str | None = None,
avatar: str | None = None,
) -> ConversationV2:
"""Create a new conversation."""
...
Expand All @@ -157,6 +159,8 @@ async def update_conversation(
persona_id: str | None = None,
content: list[dict] | None = None,
token_usage: int | None = None,
user_name: str | None = None,
avatar: str | None = None,
) -> None:
"""Update a conversation's history."""
...
Expand Down
7 changes: 7 additions & 0 deletions astrbot/core/db/po.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class ConversationV2(TimestampMixin, SQLModel, table=True):
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
user_name: str | None = Field(default=None, max_length=255)
avatar: str | None = Field(default=None, max_length=512)
"""用户头像 URL"""
content: list | None = Field(default=None, sa_type=JSON)

title: str | None = Field(default=None, max_length=255)
Expand Down Expand Up @@ -418,6 +421,10 @@ class Conversation:
updated_at: int = 0
token_usage: int = 0
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
user_name: str | None = None
"""发送消息的用户名称"""
avatar: str | None = None
"""用户头像 URL"""


class Personality(TypedDict):
Expand Down
67 changes: 47 additions & 20 deletions astrbot/core/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,39 +57,58 @@ async def initialize(self) -> None:
# 确保 personas 表有 folder_id、sort_order、skills 列(前向兼容)
await self._ensure_persona_folder_columns(conn)
await self._ensure_persona_skills_column(conn)
# 确保 conversations 表有 user_name 列(前向兼容)
await self._ensure_conversation_user_name_column(conn)
# 确保 conversations 表有 avatar 列(前向兼容)
await self._ensure_conversation_avatar_column(conn)
await conn.commit()

async def _ensure_persona_folder_columns(self, conn) -> None:
"""确保 personas 表有 folder_id 和 sort_order 列。
async def _ensure_column(
self, conn, table: str, column: str, ddl: str
) -> None:
"""确保指定表有指定列,如果不存在则添加。

这是为了支持旧版数据库的平滑升级。新版数据库通过 SQLModel
的 metadata.create_all 自动创建这些列。

Args:
conn: 数据库连接
table: 表名
column: 列名
ddl: ALTER TABLE 语句中的列定义(不包含 ALTER TABLE ... ADD COLUMN 部分)
"""
result = await conn.execute(text("PRAGMA table_info(personas)"))
result = await conn.execute(text(f"PRAGMA table_info({table})"))
columns = {row[1] for row in result.fetchall()}

if "folder_id" not in columns:
if column not in columns:
await conn.execute(
text(
"ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL"
)
)
if "sort_order" not in columns:
await conn.execute(
text("ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0")
text(f"ALTER TABLE {table} ADD COLUMN {ddl}")
)

async def _ensure_persona_folder_columns(self, conn) -> None:
"""确保 personas 表有 folder_id 和 sort_order 列。"""
await self._ensure_column(
conn, "personas", "folder_id", "folder_id VARCHAR(36) DEFAULT NULL"
)
await self._ensure_column(
conn, "personas", "sort_order", "sort_order INTEGER DEFAULT 0"
)

async def _ensure_persona_skills_column(self, conn) -> None:
"""确保 personas 表有 skills 列。
"""确保 personas 表有 skills 列。"""
await self._ensure_column(conn, "personas", "skills", "skills JSON")

这是为了支持旧版数据库的平滑升级。新版数据库通过 SQLModel
的 metadata.create_all 自动创建这些列。
"""
result = await conn.execute(text("PRAGMA table_info(personas)"))
columns = {row[1] for row in result.fetchall()}
async def _ensure_conversation_user_name_column(self, conn) -> None:
"""确保 conversations 表有 user_name 列。"""
await self._ensure_column(
conn, "conversations", "user_name", "user_name VARCHAR(255) DEFAULT NULL"
)

if "skills" not in columns:
await conn.execute(text("ALTER TABLE personas ADD COLUMN skills JSON"))
async def _ensure_conversation_avatar_column(self, conn) -> None:
"""确保 conversations 表有 avatar 列。"""
await self._ensure_column(
conn, "conversations", "avatar", "avatar VARCHAR(512) DEFAULT NULL"
)

# ====
# Platform Statistics
Expand Down Expand Up @@ -259,6 +278,8 @@ async def create_conversation(
cid=None,
created_at=None,
updated_at=None,
user_name=None,
avatar=None,
):
kwargs = {}
if cid:
Expand All @@ -276,13 +297,15 @@ async def create_conversation(
platform_id=platform_id,
title=title,
persona_id=persona_id,
user_name=user_name,
avatar=avatar,
**kwargs,
)
session.add(new_conversation)
return new_conversation

async def update_conversation(
self, cid, title=None, persona_id=None, content=None, token_usage=None
self, cid, title=None, persona_id=None, content=None, token_usage=None, user_name=None, avatar=None
):
async with self.get_db() as session:
session: AsyncSession
Expand All @@ -299,6 +322,10 @@ async def update_conversation(
values["content"] = content
if token_usage is not None:
values["token_usage"] = token_usage
if user_name is not None:
values["user_name"] = user_name
if avatar is not None:
values["avatar"] = avatar
if not values:
return None
query = query.values(**values)
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/platform/astr_message_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ def get_sender_name(self) -> str:
return self.message_obj.sender.nickname
return ""

def get_sender_avatar(self) -> str | None:
"""获取消息发送者的头像 URL。(可能会返回 None)"""
if hasattr(self.message_obj.sender, 'avatar'):
return self.message_obj.sender.avatar
return None

def set_extra(self, key, value):
"""设置额外的信息。"""
self._extras[key] = value
Expand Down
2 changes: 2 additions & 0 deletions astrbot/core/platform/astrbot_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
class MessageMember:
user_id: str # 发送者id
nickname: str | None = None
avatar: str | None = None
"""用户头像 URL"""

def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
Expand Down
57 changes: 53 additions & 4 deletions astrbot/core/platform/sources/wecom/wecom_adapter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import os
import sys
import time
import uuid
from collections.abc import Awaitable, Callable
from typing import Any, cast
from typing import Any, NamedTuple, cast

import quart
from requests import Response
Expand Down Expand Up @@ -38,6 +39,17 @@
from typing_extensions import override


# 客户信息缓存条目
class CustomerCacheEntry(NamedTuple):
nickname: str
avatar: str | None
expire_at: float


# 客户信息 TTL 缓存(默认 5 分钟)
CUSTOMER_CACHE_TTL = 300


class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
Expand Down Expand Up @@ -173,6 +185,8 @@ def __init__(

# 微信客服
self.kf_name = self.config.get("kf_name", None)
# 客户信息缓存 (external_userid -> CustomerCacheEntry)
self._customer_cache: dict[str, CustomerCacheEntry] = {}
if self.kf_name:
# inject
self.wechat_kf_api = WeChatKF(client=self.client)
Expand All @@ -184,10 +198,10 @@ def __init__(

async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
token = msg._data["Token"]
kfid = msg._data["OpenKfId"]

def get_latest_msg_item() -> dict | None:
token = msg._data["Token"]
kfid = msg._data["OpenKfId"]
has_more = 1
ret = {}
while has_more:
Expand All @@ -203,6 +217,7 @@ def get_latest_msg_item() -> dict | None:
get_latest_msg_item,
)
if msg_new:
msg_new["open_kfid"] = kfid
await self.convert_wechat_kf_message(msg_new)
return
await self.convert_message(msg)
Expand Down Expand Up @@ -350,11 +365,45 @@ async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype")
external_userid = cast(str, msg.get("external_userid"))

# 尝试从缓存获取客户信息
nickname = external_userid
avatar = None
now = time.time()
cached = self._customer_cache.get(external_userid)
if cached and cached.expire_at > now:
# 缓存命中
nickname = cached.nickname
avatar = cached.avatar
logger.debug(f"客户信息缓存命中: external_userid={external_userid}")
else:
# 缓存未命中或已过期,调用 API
try:
customer_info = await asyncio.get_event_loop().run_in_executor(
None,
self.wechat_kf_api.batchget_customer,
external_userid,
)
# 避免在日志中输出完整客户信息(包含昵称、头像等敏感数据)
logger.debug(f"获取客户信息成功: external_userid={external_userid}")
customer_list = customer_info.get("customer_list", [])
if customer_list:
nickname = customer_list[0].get("nickname", external_userid)
avatar = customer_list[0].get("avatar", None)
# 更新缓存
self._customer_cache[external_userid] = CustomerCacheEntry(
nickname=nickname,
avatar=avatar,
expire_at=now + CUSTOMER_CACHE_TTL,
)
except Exception as e:
logger.debug(f"获取客户信息失败: {e}")

abm = AstrBotMessage()
abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
abm.self_id = msg["open_kfid"]
abm.sender = MessageMember(external_userid, external_userid)
abm.sender = MessageMember(external_userid, nickname, avatar)
abm.session_id = external_userid
abm.type = MessageType.FRIEND_MESSAGE
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
Expand Down
Loading
Loading