diff --git a/backend/app/api/dingtalk.py b/backend/app/api/dingtalk.py index 9a431ecf2..b4d0f02d2 100644 --- a/backend/app/api/dingtalk.py +++ b/backend/app/api/dingtalk.py @@ -187,6 +187,22 @@ async def process_dingtalk_message( ) platform_user_id = platform_user.id + # Check for channel commands (/new, /reset) + from app.services.channel_commands import is_channel_command, handle_channel_command + if is_channel_command(user_text): + cmd_result = await handle_channel_command( + db=db, command=user_text, agent_id=agent_id, + user_id=platform_user_id, external_conv_id=conv_id, + source_channel="dingtalk", + ) + await db.commit() + async with httpx.AsyncClient(timeout=10) as _cl_cmd: + await _cl_cmd.post(session_webhook, json={ + "msgtype": "text", + "text": {"content": cmd_result["message"]}, + }) + return + # Find or create session sess = await find_or_create_channel_session( db=db, diff --git a/backend/app/api/feishu.py b/backend/app/api/feishu.py index aff83664a..329e5f244 100644 --- a/backend/app/api/feishu.py +++ b/backend/app/api/feishu.py @@ -446,6 +446,13 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession from app.models.agent import DEFAULT_CONTEXT_WINDOW_SIZE ctx_size = (agent_obj.context_window_size or DEFAULT_CONTEXT_WINDOW_SIZE) if agent_obj else DEFAULT_CONTEXT_WINDOW_SIZE + # Detect channel command early, but defer processing until we have + # resolved the real sender's platform_user_id (see below). Handling the + # command before resolve_channel_user() would attribute the new P2P + # session to the agent creator instead of the actual Feishu sender. + from app.services.channel_commands import is_channel_command, handle_channel_command + _is_cmd = is_channel_command(user_text) + # Pre-resolve session so history lookup uses the UUID (session created later if new) _pre_sess_r = await db.execute( select(__import__('app.models.chat_session', fromlist=['ChatSession']).ChatSession).where( @@ -561,6 +568,27 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession ) platform_user_id = platform_user.id + # Now that the real sender is resolved, handle /new or /reset so the + # replacement P2P session is attributed to the sender (not creator_id). + # Mirrors the user_id rule used by find_or_create_channel_session below: + # group → creator_id (placeholder); P2P → platform_user_id. + if _is_cmd: + _is_group_cmd = (chat_type == "group") + _cmd_user_id = creator_id if _is_group_cmd else platform_user_id + _cmd_result = await handle_channel_command( + db=db, command=user_text, agent_id=agent_id, + user_id=_cmd_user_id, external_conv_id=conv_id, + source_channel="feishu", + ) + await db.commit() + import json as _j_cmd + _cmd_reply = _j_cmd.dumps({"text": _cmd_result["message"]}) + if _is_group_cmd and chat_id: + await feishu_service.send_message(config.app_id, config.app_secret, chat_id, "text", _cmd_reply, receive_id_type="chat_id") + else: + await feishu_service.send_message(config.app_id, config.app_secret, sender_open_id, "text", _cmd_reply) + return {"code": 0, "msg": "command handled"} + # ── Find-or-create a ChatSession via external_conv_id (DB-based, no cache needed) ── from datetime import datetime as _dt, timezone as _tz _is_group = (chat_type == "group") diff --git a/backend/app/services/channel_commands.py b/backend/app/services/channel_commands.py new file mode 100644 index 000000000..3b7df17e9 --- /dev/null +++ b/backend/app/services/channel_commands.py @@ -0,0 +1,79 @@ +"""Channel command handler for external channels (DingTalk, Feishu, etc.) + +Supports slash commands like /new to reset session context. +""" + +import uuid +from datetime import datetime, timezone +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.chat_session import ChatSession +from app.services.channel_session import find_or_create_channel_session + + +COMMANDS = {"/new", "/reset"} + + +def is_channel_command(text: str) -> bool: + """Check if the message is a recognized channel command.""" + stripped = text.strip().lower() + return stripped in COMMANDS + + +async def handle_channel_command( + db: AsyncSession, + command: str, + agent_id: uuid.UUID, + user_id: uuid.UUID, + external_conv_id: str, + source_channel: str, +) -> dict: + """Handle a channel command and return response info. + + Returns: + {"action": "new_session", "message": "..."} + """ + cmd = command.strip().lower() + + if cmd in ("/new", "/reset"): + # Find current session. Scope by source_channel as well so we never + # accidentally archive a session from a different channel that happens + # to share the same external_conv_id (defensive against future changes + # to the per-channel ID prefix scheme). + result = await db.execute( + select(ChatSession).where( + ChatSession.agent_id == agent_id, + ChatSession.external_conv_id == external_conv_id, + ChatSession.source_channel == source_channel, + ) + ) + old_session = result.scalar_one_or_none() + + if old_session: + # Rename old external_conv_id so find_or_create will make a new one + now = datetime.now(timezone.utc) + old_session.external_conv_id = ( + f"{external_conv_id}__archived_{now.strftime('%Y%m%d_%H%M%S')}" + ) + await db.flush() + + # Create new session + new_session = ChatSession( + agent_id=agent_id, + user_id=user_id, + title="New Session", + source_channel=source_channel, + external_conv_id=external_conv_id, + created_at=datetime.now(timezone.utc), + ) + db.add(new_session) + await db.flush() + + return { + "action": "new_session", + "session_id": str(new_session.id), + "message": "已开启新对话,之前的上下文已清除。", + } + + return {"action": "unknown", "message": f"未知命令: {cmd}"} diff --git a/backend/tests/test_channel_commands.py b/backend/tests/test_channel_commands.py new file mode 100644 index 000000000..4d15cafa6 --- /dev/null +++ b/backend/tests/test_channel_commands.py @@ -0,0 +1,205 @@ +"""Unit tests for app.services.channel_commands. + +Covers the two review concerns addressed in PR #342: +1. `handle_channel_command()` must archive only the session matching the + correct `source_channel` (so a Feishu `/new` never archives a DingTalk + session with the same external_conv_id, etc.). +2. The new session it creates must be attributed to the `user_id` passed + in by the caller (not silently falling back to some other id). +""" + +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import Any + +import pytest + +from app.services import channel_commands + + +class _ExecutedQuery: + """Captures WHERE-clause state for an executed SQLAlchemy select().""" + + def __init__(self, statement: Any) -> None: + self.statement = statement + # Extract column names referenced by equality comparisons in the WHERE + # clause. This lets tests assert that source_channel is part of the + # filter without depending on clause order. + self.filter_columns: set[str] = set() + self.filter_values: dict[str, Any] = {} + whereclause = getattr(statement, "whereclause", None) + self._collect(whereclause) + + def _collect(self, clause: Any) -> None: + if clause is None: + return + # BooleanClauseList (AND/OR) has .clauses + sub_clauses = getattr(clause, "clauses", None) + if sub_clauses: + for c in sub_clauses: + self._collect(c) + return + left = getattr(clause, "left", None) + right = getattr(clause, "right", None) + if left is not None: + name = getattr(left, "key", None) or getattr(left, "name", None) + if name: + self.filter_columns.add(name) + if right is not None: + # BindParameter exposes .value + val = getattr(right, "value", None) + if val is not None: + self.filter_values[name] = val + + +class _FakeResult: + def __init__(self, value: Any) -> None: + self._value = value + + def scalar_one_or_none(self) -> Any: + return self._value + + +class FakeDB: + """Minimal AsyncSession stub that records executes / adds / flush / commit.""" + + def __init__(self, lookup_result: Any = None) -> None: + self._lookup_result = lookup_result + self.executed: list[_ExecutedQuery] = [] + self.added: list[Any] = [] + self.flushes = 0 + + async def execute(self, statement, _params=None): # noqa: D401 + self.executed.append(_ExecutedQuery(statement)) + return _FakeResult(self._lookup_result) + + def add(self, obj) -> None: + # Assign an id so handle_channel_command can stringify it. + if getattr(obj, "id", None) is None: + try: + obj.id = uuid.uuid4() + except Exception: + pass + self.added.append(obj) + + async def flush(self) -> None: + self.flushes += 1 + + +@pytest.mark.asyncio +async def test_handle_channel_command_scopes_lookup_by_source_channel(): + """Regression test for review concern #2. + + The session-archive lookup must include `source_channel` in its WHERE + clause so a /new command on one channel never archives a same-external-id + session on another channel. + """ + agent_id = uuid.uuid4() + user_id = uuid.uuid4() + + db = FakeDB(lookup_result=None) # no pre-existing session + + result = await channel_commands.handle_channel_command( + db=db, + command="/new", + agent_id=agent_id, + user_id=user_id, + external_conv_id="feishu_p2p_ou_xxx", + source_channel="feishu", + ) + + assert result["action"] == "new_session" + # Exactly one SELECT for the old-session lookup. + assert len(db.executed) == 1 + q = db.executed[0] + # The WHERE clause must filter on all three columns. + assert "agent_id" in q.filter_columns + assert "external_conv_id" in q.filter_columns + assert "source_channel" in q.filter_columns, ( + "handle_channel_command() must scope the archive lookup by source_channel " + "so it never archives a cross-channel session with a colliding external_conv_id" + ) + assert q.filter_values.get("source_channel") == "feishu" + + +@pytest.mark.asyncio +async def test_handle_channel_command_only_archives_same_channel_session(): + """A session from a different channel must not be archived. + + We simulate the lookup returning None (as it would when the pre-existing + session has a different source_channel) and assert the old record's + external_conv_id is left untouched. + """ + agent_id = uuid.uuid4() + user_id = uuid.uuid4() + + # Simulate a session that belongs to a different channel — the query with + # source_channel='feishu' should miss it, so the fake returns None. + db = FakeDB(lookup_result=None) + + await channel_commands.handle_channel_command( + db=db, + command="/new", + agent_id=agent_id, + user_id=user_id, + external_conv_id="shared_conv_id_xxx", + source_channel="feishu", + ) + + # Only the new session should have been added; no archival mutation. + assert len(db.added) == 1 + new_sess = db.added[0] + assert new_sess.source_channel == "feishu" + assert new_sess.external_conv_id == "shared_conv_id_xxx" + # Confirm the user_id on the new session is the one we passed in. + assert new_sess.user_id == user_id + + +@pytest.mark.asyncio +async def test_handle_channel_command_uses_caller_user_id_for_new_session(): + """Regression test for review concern #1. + + `handle_channel_command()` must attribute the replacement session to the + `user_id` the caller supplied. The Feishu caller now passes the resolved + platform_user_id (instead of creator_id) — this test locks the contract. + """ + agent_id = uuid.uuid4() + sender_platform_user_id = uuid.uuid4() + + # Existing session to be archived. + old_session = SimpleNamespace(external_conv_id="feishu_p2p_ou_zzz") + db = FakeDB(lookup_result=old_session) + + result = await channel_commands.handle_channel_command( + db=db, + command="/reset", + agent_id=agent_id, + user_id=sender_platform_user_id, + external_conv_id="feishu_p2p_ou_zzz", + source_channel="feishu", + ) + + assert result["action"] == "new_session" + # Old session got its external_conv_id renamed to the archived form. + assert old_session.external_conv_id.startswith("feishu_p2p_ou_zzz__archived_") + # New session was added with the caller-supplied user_id. + assert len(db.added) == 1 + new_sess = db.added[0] + assert new_sess.user_id == sender_platform_user_id + assert new_sess.agent_id == agent_id + assert new_sess.source_channel == "feishu" + assert new_sess.external_conv_id == "feishu_p2p_ou_zzz" + + +@pytest.mark.asyncio +async def test_is_channel_command_recognises_slash_commands(): + assert channel_commands.is_channel_command("/new") is True + assert channel_commands.is_channel_command("/reset") is True + assert channel_commands.is_channel_command(" /NEW ") is True + assert channel_commands.is_channel_command("/RESET") is True + # Non-commands + assert channel_commands.is_channel_command("hello") is False + assert channel_commands.is_channel_command("/newish") is False + assert channel_commands.is_channel_command("") is False