# ─── Message Dedup (processing → done state machine) ──────────────────────

# Maximum time a single message may stay "processing"; after this the lock
# auto-expires so a DingTalk retransmit can be accepted again.
PROCESSING_TTL: float = 180.0
# How long a successfully handled message_id is remembered for dedup.
DONE_TTL: float = 600.0
# In-memory fallback: run a GC sweep every N-th dedup check.
_DEDUP_GC_EVERY: int = 100

# In-memory store: {message_id: (state, expire_at_monotonic)};
# state ∈ {"processing", "done"}.
_processed_messages: dict[str, tuple[str, float]] = {}
_dedup_check_counter: int = 0

# Redis client factory kept as a module-level, replaceable attribute so tests
# can monkeypatch it to None. On import failure we fall back to memory-only
# dedup (single-process semantics only).
try:
    from app.core.events import get_redis as _get_redis_client
except Exception:  # pragma: no cover - import-layer safety net
    _get_redis_client = None  # type: ignore[assignment]


def _dedup_now() -> float:
    """Monotonic clock for TTL bookkeeping (immune to wall-clock jumps)."""
    import time as _t
    return _t.monotonic()


def _dedup_gc(now: float) -> None:
    """Drop expired entries from the in-memory store every _DEDUP_GC_EVERY calls."""
    global _dedup_check_counter
    _dedup_check_counter += 1
    if _dedup_check_counter % _DEDUP_GC_EVERY != 0:
        return
    expired = [k for k, (_, exp) in _processed_messages.items() if exp < now]
    for k in expired:
        del _processed_messages[k]


async def _redis_client_or_none():
    """Return a Redis client, or None when Redis is unavailable/disabled."""
    if _get_redis_client is None:
        return None
    try:
        return await _get_redis_client()
    except Exception:
        return None


def _redis_key(message_id: str) -> str:
    """Redis key under which a message's dedup state is stored."""
    return f"dingtalk:dedup:{message_id}"


async def acquire_dedup_lock(message_id: str) -> tuple[bool, str]:
    """Claim processing rights for a given DingTalk message_id.

    Returns (accepted, state):
        accepted=True  → caller owns processing, state == "new"
        accepted=False → already in flight or finished,
                         state ∈ {"processing", "done"}
    """
    if not message_id:
        return True, "new"

    redis = await _redis_client_or_none()
    if redis is not None:
        key = _redis_key(message_id)
        # SET NX claims the lock; when it fails, read the current state.
        # FIX: the key can expire between the failed SET NX and the GET, in
        # which case GET returns None. The previous code reported that as
        # (False, "processing"), dropping the message forever even though no
        # one holds the lock — instead, retry the claim once.
        for _ in range(2):
            was_set = await redis.set(
                key, "processing", ex=int(PROCESSING_TTL), nx=True
            )
            if was_set:
                return True, "new"
            existing = await redis.get(key)
            if existing in (b"done", "done"):
                return False, "done"
            if existing is not None:
                return False, "processing"
            # Key vanished between SET NX and GET → loop and re-claim.
        return False, "processing"

    # In-memory fallback (single process only).
    now = _dedup_now()
    hit = _processed_messages.get(message_id)
    if hit and hit[1] > now:
        return False, hit[0]
    _processed_messages[message_id] = ("processing", now + PROCESSING_TTL)
    _dedup_gc(now)
    return True, "new"


async def mark_dedup_done(message_id: str) -> None:
    """Transition processing → done, remembered for DONE_TTL."""
    if not message_id:
        return
    redis = await _redis_client_or_none()
    if redis is not None:
        await redis.set(_redis_key(message_id), "done", ex=int(DONE_TTL))
        return
    now = _dedup_now()
    _processed_messages[message_id] = ("done", now + DONE_TTL)


async def release_dedup_lock(message_id: str) -> None:
    """Release the processing lock on failure so a retransmit can re-attempt.

    Only deletes when the current value is "processing", so a concurrent
    "done" marker is never removed by mistake.
    """
    if not message_id:
        return
    redis = await _redis_client_or_none()
    if redis is not None:
        try:
            # NOTE(review): GET+DELETE is not atomic; a Lua script or
            # WATCH/MULTI would close the small race window — confirm whether
            # that matters for this workload.
            current = await redis.get(_redis_key(message_id))
            if current in (b"processing", "processing"):
                await redis.delete(_redis_key(message_id))
        except Exception:
            pass
        return
    hit = _processed_messages.get(message_id)
    if hit and hit[0] == "processing":
        del _processed_messages[message_id]


# ─── Message Processing (called by Stream callback) ────

async def process_dingtalk_message(
    agent_id: uuid.UUID,
    sender_staff_id: str,
    user_text: str,
    conversation_id: str,
    conversation_type: str,
    session_webhook: str,
    message_id: str = "",
):
    """Process an incoming DingTalk bot message and reply via session webhook.

    Dedup wrapper:
    - acquire processing lock
    - on success → mark done
    - on exception → release lock so retransmit can retry
    """
    accepted, state = await acquire_dedup_lock(message_id)
    if not accepted:
        logger.info(
            f"[DingTalk] Skip duplicate message_id={message_id} (state={state})"
        )
        return

    try:
        await _process_dingtalk_message_inner(
            agent_id=agent_id,
            sender_staff_id=sender_staff_id,
            user_text=user_text,
            conversation_id=conversation_id,
            conversation_type=conversation_type,
            session_webhook=session_webhook,
        )
    except Exception:
        logger.exception(
            f"[DingTalk] Processing failed for message_id={message_id}; releasing dedup lock"
        )
        await release_dedup_lock(message_id)
        raise
    else:
        await mark_dedup_done(message_id)
class TTLCache:
    """In-process TTL cache with per-key single-flight factory execution.

    Design notes (single-node cache):
    - Each replica keeps its own copy; duplicated fetches are harmless here.
    - Concurrent callers for the same key share one factory run, which avoids
      a thundering herd against the upstream API.
    - Failed results are stored like any other value; callers that detect an
      invalid value are expected to call invalidate() themselves.
    """

    def __init__(self, default_ttl: float = 60.0) -> None:
        self._default_ttl = default_ttl
        # key → (expire_at_monotonic, cached_value)
        self._store: dict[str, tuple[float, Any]] = {}
        # One lock per key serializes concurrent factory runs.
        self._locks: dict[str, asyncio.Lock] = {}

    def _lock(self, key: str) -> asyncio.Lock:
        """Return (creating on first use) the lock guarding this key."""
        return self._locks.setdefault(key, asyncio.Lock())

    async def get_or_set(
        self,
        key: str,
        factory: Callable[[], Awaitable[Any]],
        ttl: float | None = None,
    ) -> Any:
        """Return the cached value for *key*, running *factory* on a miss."""
        # Lock-free fast path for fresh entries.
        entry = self._store.get(key)
        if entry is not None and entry[0] > time.monotonic():
            return entry[1]

        async with self._lock(key):
            # Re-check under the lock: another waiter may have filled it.
            entry = self._store.get(key)
            moment = time.monotonic()
            if entry is not None and entry[0] > moment:
                return entry[1]
            value = await factory()
            span = self._default_ttl if ttl is None else ttl
            self._store[key] = (moment + span, value)
            return value

    def invalidate(self, key: str) -> None:
        """Drop the cached entry for *key* (no-op when absent)."""
        self._store.pop(key, None)
# DingTalk access_token normally lives 7200s; expire 200s early so a refresh
# happens before the upstream token actually dies.
_token_cache = TTLCache(default_ttl=7000)
# user/get responses stay valid; cache them for 30 minutes.
_user_detail_cache = TTLCache(default_ttl=1800)


async def get_dingtalk_access_token(app_id: str, app_secret: str) -> dict:
    """Get DingTalk access_token, TTL-cached with single-flight.

    Returns {"access_token": ..., "expires_in": ...} on success, or
    {"errcode": ..., "errmsg": ...} on failure; failures are evicted right
    away so the next caller re-fetches.
    """
    cache_key = f"token:{app_id}"

    async def _fetch() -> dict:
        url = "https://oapi.dingtalk.com/gettoken"
        params = {"appkey": app_id, "appsecret": app_secret}
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                resp = await client.get(url, params=params)
                data = resp.json()
            except Exception as e:
                logger.error(f"[DingTalk] Network error getting access_token: {e}")
                return {"errcode": -1, "errmsg": str(e)}
        if data.get("errcode") == 0:
            return {
                "access_token": data.get("access_token"),
                "expires_in": data.get("expires_in"),
            }
        logger.error(f"[DingTalk] Failed to get access_token: {data}")
        return {"errcode": data.get("errcode"), "errmsg": data.get("errmsg")}

    result = await _token_cache.get_or_set(cache_key, _fetch)
    # Do not retain failed lookups: evict so the next call re-fetches.
    if not result.get("access_token"):
        _token_cache.invalidate(cache_key)
    return result


async def get_dingtalk_user_detail(app_id: str, app_secret: str, userid: str) -> dict | None:
    """Fetch user detail from the DingTalk corp API, cached for 30 minutes.

    Returns the API's `result` dict on success, or None when the token or
    the user/get call fails; falsy results ({} or None) are evicted so they
    can be retried.
    """
    cache_key = f"userdetail:{app_id}:{userid}"

    async def _fetch() -> dict | None:
        token_result = await get_dingtalk_access_token(app_id, app_secret)
        access_token = token_result.get("access_token")
        if not access_token:
            return None
        url = "https://oapi.dingtalk.com/topapi/v2/user/get"
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                resp = await client.post(
                    url,
                    params={"access_token": access_token},
                    json={"userid": userid},
                )
                data = resp.json()
            except Exception as e:
                logger.warning(f"[DingTalk] user/get error for {userid}: {e}")
                return None
        if data.get("errcode") == 0:
            return data.get("result", {})
        logger.warning(
            f"[DingTalk] user/get failed for {userid}: {data.get('errmsg')}"
        )
        return None

    result = await _user_detail_cache.get_or_set(cache_key, _fetch)
    # None (failure) and {} (success with missing result field) are both
    # evicted immediately so they do not poison the 30-minute window.
    if not result:
        _user_detail_cache.invalidate(cache_key)
    return result
import asyncio

import pytest

from app.api import dingtalk as dingtalk_api


@pytest.fixture(autouse=True)
def _reset_state(monkeypatch):
    """Each test starts with empty in-memory dedup state and Redis disabled."""
    dingtalk_api._processed_messages.clear()
    dingtalk_api._dedup_check_counter = 0
    # Force the in-memory fallback path by disabling the Redis client factory.
    monkeypatch.setattr(dingtalk_api, "_get_redis_client", None)


async def test_acquire_first_returns_accepted():
    assert await dingtalk_api.acquire_dedup_lock("m-1") == (True, "new")


async def test_acquire_second_while_processing_returns_duplicate():
    await dingtalk_api.acquire_dedup_lock("m-2")
    assert await dingtalk_api.acquire_dedup_lock("m-2") == (False, "processing")


async def test_acquire_after_done_returns_duplicate():
    await dingtalk_api.acquire_dedup_lock("m-3")
    await dingtalk_api.mark_dedup_done("m-3")
    assert await dingtalk_api.acquire_dedup_lock("m-3") == (False, "done")


async def test_release_allows_retry():
    await dingtalk_api.acquire_dedup_lock("m-4")
    await dingtalk_api.release_dedup_lock("m-4")
    assert await dingtalk_api.acquire_dedup_lock("m-4") == (True, "new")


async def test_release_after_done_is_noop():
    await dingtalk_api.acquire_dedup_lock("m-4b")
    await dingtalk_api.mark_dedup_done("m-4b")
    # Releasing after done must not erase the done marker.
    await dingtalk_api.release_dedup_lock("m-4b")
    assert await dingtalk_api.acquire_dedup_lock("m-4b") == (False, "done")


async def test_empty_message_id_always_accepts():
    assert await dingtalk_api.acquire_dedup_lock("") == (True, "new")


async def test_processing_ttl_expires(monkeypatch):
    monkeypatch.setattr(dingtalk_api, "PROCESSING_TTL", 0.05, raising=False)
    await dingtalk_api.acquire_dedup_lock("m-5")
    await asyncio.sleep(0.08)
    assert await dingtalk_api.acquire_dedup_lock("m-5") == (True, "new")
import asyncio

import pytest

from app.services.dingtalk_cache import TTLCache
from app.services import dingtalk_service


class _FakeResp:
    """Minimal httpx.Response stand-in carrying a canned JSON payload."""

    def __init__(self, payload):
        self._payload = payload
        self.status_code = 200

    def json(self):
        return self._payload


class _FakeClient:
    """Async-context HTTP client double that records calls and answers canned data."""

    instances: list["_FakeClient"] = []

    def __init__(self, *args, **kwargs):
        self.get_calls: list[tuple[str, dict]] = []
        self.post_calls: list[tuple[str, dict, dict]] = []
        _FakeClient.instances.append(self)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False

    async def get(self, url, params=None):
        self.get_calls.append((url, params or {}))
        return _FakeResp(
            {"errcode": 0, "access_token": "tok-A", "expires_in": 7200}
        )

    async def post(self, url, params=None, json=None):
        self.post_calls.append((url, params or {}, json or {}))
        return _FakeResp({
            "errcode": 0,
            "result": {
                "userid": json["userid"],
                "name": "Alice",
                "mobile": "13800000000",
                "email": "alice@example.com",
                "unionid": f"UNION-{json['userid']}",
            },
        })


@pytest.fixture(autouse=True)
def _reset_caches(monkeypatch):
    """Fresh caches and a fake HTTP client for every test."""
    dingtalk_service._token_cache._store.clear()
    dingtalk_service._user_detail_cache._store.clear()
    _FakeClient.instances.clear()
    monkeypatch.setattr(dingtalk_service.httpx, "AsyncClient", _FakeClient)


async def test_ttl_cache_hit_miss():
    cache = TTLCache(default_ttl=60)
    seen = []

    async def factory():
        seen.append(None)
        return {"v": len(seen)}

    first = await cache.get_or_set("k", factory)
    second = await cache.get_or_set("k", factory)
    assert first == second == {"v": 1}
    assert len(seen) == 1


async def test_ttl_cache_expires():
    cache = TTLCache(default_ttl=0.05)
    seen = []

    async def factory():
        seen.append(None)
        return len(seen)

    assert await cache.get_or_set("k", factory) == 1
    await asyncio.sleep(0.08)
    assert await cache.get_or_set("k", factory) == 2


async def test_ttl_cache_single_flight():
    cache = TTLCache(default_ttl=60)
    seen = []
    gate = asyncio.Event()

    async def factory():
        await gate.wait()
        seen.append(None)
        return len(seen)

    first = asyncio.create_task(cache.get_or_set("k", factory))
    second = asyncio.create_task(cache.get_or_set("k", factory))
    await asyncio.sleep(0)
    gate.set()
    results = await asyncio.gather(first, second)
    assert results == [1, 1]
    assert len(seen) == 1


async def test_access_token_cached_across_calls():
    t1 = await dingtalk_service.get_dingtalk_access_token("APP", "SEC")
    t2 = await dingtalk_service.get_dingtalk_access_token("APP", "SEC")
    assert t1["access_token"] == t2["access_token"] == "tok-A"
    gets = [
        call
        for inst in _FakeClient.instances
        for call in inst.get_calls
    ]
    assert len(gets) == 1


async def test_user_detail_cached_per_userid():
    d1 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-1")
    d2 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-1")
    d3 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-2")
    assert d1 == d2
    assert d1["unionid"] == "UNION-user-1"
    assert d3["unionid"] == "UNION-user-2"
    user_posts = [
        call
        for inst in _FakeClient.instances
        for call in inst.post_calls
        if "user/get" in call[0]
    ]
    assert len(user_posts) == 2  # two distinct userids → two fetches


class _EmptyResultClient(_FakeClient):
    """Succeeds at the API level but omits the `result` field entirely."""

    async def post(self, url, params=None, json=None):
        self.post_calls.append((url, params or {}, json or {}))
        # DingTalk returns success but result field missing → _fetch coerces to {}
        return _FakeResp({"errcode": 0})


async def test_user_detail_empty_result_is_not_cached(monkeypatch):
    monkeypatch.setattr(dingtalk_service.httpx, "AsyncClient", _EmptyResultClient)
    r1 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-x")
    r2 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-x")
    assert r1 == r2 == {}
    user_posts = [
        call
        for inst in _FakeClient.instances
        for call in inst.post_calls
        if "user/get" in call[0]
    ]
    # Without the fix: only 1 POST (empty dict cached). With fix: 2 POSTs.
    assert len(user_posts) == 2