# ─── Message Dedup (processing → done state machine) ──────────────────────

# Maximum time a single message may stay in "processing"; after this the slot
# auto-expires so a DingTalk retransmit can be accepted again.
PROCESSING_TTL: float = 180.0
# Retention window for the "done" marker after successful processing.
DONE_TTL: float = 600.0
# Run the in-memory fallback GC once every N dedup checks.
_DEDUP_GC_EVERY: int = 100

# In-memory fallback store: {message_id: (state, expire_at_monotonic)};
# state ∈ {"processing", "done"}.
_processed_messages: dict[str, tuple[str, float]] = {}
_dedup_check_counter: int = 0

# Redis client factory: module-level so tests can monkeypatch it (set to None
# to force the in-memory fallback). On import failure we fall back as well.
try:
    from app.core.events import get_redis as _get_redis_client
except Exception:  # pragma: no cover - import-level safety net
    _get_redis_client = None  # type: ignore[assignment]


def _dedup_now() -> float:
    """Monotonic clock used for all dedup TTL math (immune to wall-clock jumps)."""
    import time as _t
    return _t.monotonic()


def _dedup_gc(now: float) -> None:
    """Periodically drop expired entries from the in-memory fallback store.

    Runs the sweep only every _DEDUP_GC_EVERY calls to keep the per-message
    cost O(1) amortized.
    """
    global _dedup_check_counter
    _dedup_check_counter += 1
    if _dedup_check_counter % _DEDUP_GC_EVERY != 0:
        return
    expired = [k for k, (_, exp) in _processed_messages.items() if exp < now]
    for k in expired:
        del _processed_messages[k]


async def _redis_client_or_none():
    """Return a Redis client, or None when Redis is unavailable (memory fallback)."""
    if _get_redis_client is None:
        return None
    try:
        return await _get_redis_client()
    except Exception:
        return None


def _redis_key(message_id: str) -> str:
    """Redis key holding the dedup state for one message."""
    return f"dingtalk:dedup:{message_id}"


async def acquire_dedup_lock(message_id: str) -> tuple[bool, str]:
    """Claim processing rights for a given DingTalk message_id.

    Returns (accepted, state):
        accepted=True  → caller owns processing, state="new"
        accepted=False → already being processed / already finished,
                         state ∈ {"processing", "done"}

    An empty message_id is always accepted (nothing to deduplicate on).
    """
    if not message_id:
        return True, "new"

    redis = await _redis_client_or_none()
    if redis is not None:
        key = _redis_key(message_id)
        # Try SET NX first; on failure read the current state to distinguish
        # processing vs done. BUGFIX: the key may expire between the failed
        # SET NX and the GET — in that race nobody holds the lock, so retry
        # the claim once instead of silently dropping the message.
        for _ in range(2):
            was_set = await redis.set(
                key, "processing", ex=int(PROCESSING_TTL), nx=True
            )
            if was_set:
                return True, "new"
            existing = await redis.get(key)
            if existing in (b"done", "done"):
                return False, "done"
            if existing is not None:
                return False, "processing"
            # existing is None: key expired mid-check → loop and re-claim.
        return False, "processing"

    # In-memory fallback (single-process only).
    now = _dedup_now()
    hit = _processed_messages.get(message_id)
    if hit and hit[1] > now:
        return False, hit[0]
    _processed_messages[message_id] = ("processing", now + PROCESSING_TTL)
    _dedup_gc(now)
    return True, "new"


async def mark_dedup_done(message_id: str) -> None:
    """Transition processing → done with DONE_TTL."""
    if not message_id:
        return
    redis = await _redis_client_or_none()
    if redis is not None:
        await redis.set(_redis_key(message_id), "done", ex=int(DONE_TTL))
        return
    now = _dedup_now()
    _processed_messages[message_id] = ("done", now + DONE_TTL)


async def release_dedup_lock(message_id: str) -> None:
    """Release the processing lock on failure so a retransmit can re-attempt.

    Only deletes when the current value is "processing", so a concurrent
    "done" marker is never removed by a late failure path.
    """
    if not message_id:
        return
    redis = await _redis_client_or_none()
    if redis is not None:
        try:
            current = await redis.get(_redis_key(message_id))
            if current in (b"processing", "processing"):
                await redis.delete(_redis_key(message_id))
        except Exception:
            pass
        return
    hit = _processed_messages.get(message_id)
    if hit and hit[0] == "processing":
        del _processed_messages[message_id]


# ─── Message Processing (called by Stream callback) ────


async def process_dingtalk_message(
    agent_id: uuid.UUID,
    sender_staff_id: str,
    user_text: str,
    conversation_id: str,
    conversation_type: str,
    session_webhook: str,
    message_id: str = "",
):
    """Process an incoming DingTalk bot message and reply via session webhook.

    Dedup wrapper around _process_dingtalk_message_inner:
    - acquire the processing lock (skip if another worker holds it / done)
    - on success → mark done (DONE_TTL retention)
    - on exception → release the lock so a retransmit can retry, then re-raise
    """
    accepted, state = await acquire_dedup_lock(message_id)
    if not accepted:
        logger.info(
            f"[DingTalk] Skip duplicate message_id={message_id} (state={state})"
        )
        return

    try:
        await _process_dingtalk_message_inner(
            agent_id=agent_id,
            sender_staff_id=sender_staff_id,
            user_text=user_text,
            conversation_id=conversation_id,
            conversation_type=conversation_type,
            session_webhook=session_webhook,
        )
    except Exception:
        logger.exception(
            f"[DingTalk] Processing failed for message_id={message_id}; releasing dedup lock"
        )
        await release_dedup_lock(message_id)
        raise
    else:
        await mark_dedup_done(message_id)
+ + Body is a verbatim move of the original process_dingtalk_message body, + starting from `import json` onward. The dedup check that used to be at + the top is now handled by the wrapper. + """ import json import httpx from datetime import datetime, timezone @@ -177,13 +341,49 @@ async def process_dingtalk_message( # P2P / single chat conv_id = f"dingtalk_p2p_{sender_staff_id}" - # Resolve channel user via unified service (uses OrgMember + SSO patterns) + # Fetch user detail from DingTalk corp API for cross-channel matching + extra_info: dict = {"unionid": sender_staff_id} + try: + cfg_r = await db.execute( + _select(ChannelConfig).where( + ChannelConfig.agent_id == agent_id, + ChannelConfig.channel_type == "dingtalk", + ) + ) + dt_config = cfg_r.scalar_one_or_none() + if dt_config and dt_config.app_id and dt_config.app_secret: + from app.services.dingtalk_service import get_dingtalk_user_detail + user_detail = await get_dingtalk_user_detail( + dt_config.app_id, dt_config.app_secret, sender_staff_id + ) + if user_detail: + dt_mobile = user_detail.get("mobile", "") + dt_email = user_detail.get("email", "") or user_detail.get("org_email", "") + dt_unionid = user_detail.get("unionid", "") + dt_name = user_detail.get("name", "") + extra_info = { + "unionid": dt_unionid or sender_staff_id, + "name": dt_name, + "mobile": dt_mobile or None, + "email": dt_email or None, + "avatar_url": user_detail.get("avatar", ""), + } + except Exception as e: + logger.warning(f"[DingTalk] Failed to fetch user detail for {sender_staff_id}: {e}") + + # 真实 unionid 可能与 sender_staff_id 不同; 一并作为候选参与 OrgMember 匹配 + real_unionid = extra_info.get("unionid") + candidate_extra_ids: list[str] = [] + if real_unionid and real_unionid != sender_staff_id: + candidate_extra_ids.append(real_unionid) + platform_user = await channel_user_service.resolve_channel_user( db=db, agent=agent_obj, channel_type="dingtalk", external_user_id=sender_staff_id, - extra_info={"unionid": sender_staff_id}, + 
extra_info=extra_info, + extra_ids=candidate_extra_ids, ) platform_user_id = platform_user.id diff --git a/backend/app/services/channel_user_service.py b/backend/app/services/channel_user_service.py index 56fd90d41..4675422ef 100644 --- a/backend/app/services/channel_user_service.py +++ b/backend/app/services/channel_user_service.py @@ -9,7 +9,7 @@ from typing import Any from loguru import logger -from sqlalchemy import select +from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from app.core.security import hash_password @@ -30,6 +30,7 @@ async def resolve_channel_user( channel_type: str, external_user_id: str, extra_info: dict[str, Any] | None = None, + extra_ids: list[str] | None = None, ) -> User: """Resolve channel user identity, find or create platform User. @@ -45,6 +46,8 @@ async def resolve_channel_user( channel_type: "dingtalk" | "wecom" | "feishu" external_user_id: User ID from external platform (staff_id/userid/open_id) extra_info: Optional name/avatar/mobile/email from platform API + extra_ids: Additional candidate identifiers (e.g. real unionid discovered + via user/get) OR-matched against OrgMember.unionid/external_id. 
Returns: Resolved User instance @@ -55,9 +58,13 @@ async def resolve_channel_user( # Step 1: Ensure IdentityProvider exists provider = await self._ensure_provider(db, channel_type, tenant_id) - # Step 2: Try to find OrgMember by external identity + # Step 2: Try to find OrgMember by all candidate identifiers + candidate_ids: list[str] = [external_user_id] + for cid in (extra_ids or []): + if cid and cid not in candidate_ids: + candidate_ids.append(cid) org_member = await self._find_org_member( - db, provider.id, channel_type, external_user_id + db, provider.id, channel_type, candidate_ids ) # Step 3: Resolve User from OrgMember or other means @@ -70,6 +77,13 @@ async def resolve_channel_user( logger.debug( f"[{channel_type}] Found user via linked OrgMember: {user.id}" ) + try: + await self._enrich_user_from_extra_info(db, user, extra_info) + except Exception: + logger.exception( + f"[{channel_type}] enrichment failed for user {user.id}; " + f"continuing without enrichment" + ) return user # Step 4: Try to find User by email/mobile from extra_info @@ -90,36 +104,41 @@ async def resolve_channel_user( f"[{channel_type}] Matched user by mobile: {user.id}" ) - # If found User by email/mobile, link OrgMember if exists (only for org-sync channels) + # If found User by email/mobile, enrich and link OrgMember if user: + try: + await self._enrich_user_from_extra_info(db, user, extra_info) + except Exception: + logger.exception( + f"[{channel_type}] enrichment failed for user {user.id}; " + f"continuing without enrichment" + ) if channel_type in ("feishu", "dingtalk", "wecom"): if org_member and not org_member.user_id: - # Existing shell OrgMember not yet linked → link it + # Existing shell OrgMember not yet linked → link it + backfill ids org_member.user_id = user.id + self._backfill_org_member_ids( + org_member, channel_type, external_user_id, extra_info + ) elif not org_member: - # No OrgMember found by external_id. 
# Explicit exports so underscore-prefixed helpers survive a star import.
__all__ = [
    "_find_org_member",
    "_backfill_org_member_ids",
    "_enrich_user_from_extra_info",
    "_identity_field_in_use",
]

# Which OrgMember identifier columns take part in the OR-match, per channel.
# Channels absent from this table (discord, slack, …) have no org sync.
_ID_COLUMNS = {
    "feishu": ("unionid", "open_id", "external_id"),
    "dingtalk": ("unionid", "external_id"),
    "wecom": ("external_id",),
}


async def _find_org_member(
    self,
    db: "AsyncSession",
    provider_id: "uuid.UUID",
    channel_type: str,
    candidate_ids: list[str],
) -> "OrgMember | None":
    """Find an OrgMember by a list of candidate external identifiers.

    All candidate ids are OR-matched across the channel's identifier
    columns, which covers DingTalk users known both by staff_id and by the
    real unionid discovered via user/get.
    """
    if not candidate_ids:
        return None
    column_names = _ID_COLUMNS.get(channel_type)
    if column_names is None:
        # No org sync for this channel type.
        return None
    try:
        id_match = or_(
            *(getattr(OrgMember, name).in_(candidate_ids) for name in column_names)
        )
        query = select(OrgMember).where(
            OrgMember.provider_id == provider_id,
            OrgMember.status == "active",
            id_match,
        )
        result = await db.execute(query)
        return result.scalar_one_or_none()
    except Exception as e:
        # OrgMember table may not exist or org sync not enabled.
        logger.debug(f"[{channel_type}] OrgMember lookup failed: {e}")
        return None


def _backfill_org_member_ids(
    self,
    member: "OrgMember",
    channel_type: str,
    external_user_id: str,
    extra_info: "dict[str, Any]",
) -> None:
    """Back-fill channel-specific identifiers onto an existing OrgMember.

    Idempotent: only currently-empty fields are filled; non-empty values are
    never overwritten. Does not write to the DB — relies on the caller to
    flush.
    """
    api_unionid = extra_info.get("unionid")

    def fill(attr: str, value) -> None:
        # Assign only when the target field is empty and the value is truthy.
        if value and not getattr(member, attr):
            setattr(member, attr, value)

    if channel_type == "dingtalk":
        fill("external_id", external_user_id)
        fill("unionid", api_unionid)
    elif channel_type == "feishu":
        # Feishu encodes the id kind in its prefix: on_ = unionid, ou_ = open_id.
        if external_user_id.startswith("on_"):
            fill("unionid", external_user_id)
        elif external_user_id.startswith("ou_"):
            fill("open_id", external_user_id)
        fill("external_id", external_user_id)
        fill("unionid", api_unionid)
    elif channel_type == "wecom":
        fill("external_id", external_user_id)


async def _enrich_user_from_extra_info(
    self,
    db: "AsyncSession",
    user: "User",
    extra_info: "dict[str, Any]",
) -> None:
    """Enrich an existing user with name/avatar/mobile/email from extra_info.

    Only fills fields that are currently empty on the user AND, for the
    globally unique Identity.phone / Identity.email columns, not already
    claimed by another Identity (writing a duplicate would raise
    IntegrityError and break the caller). Conflicting values are skipped
    with a warning instead.
    """
    from app.models.user import Identity

    name = extra_info.get("name")
    mobile = extra_info.get("mobile")
    email = extra_info.get("email")
    avatar = extra_info.get("avatar_url")

    updated = False
    if name and not user.display_name:
        user.display_name = name
        updated = True
    if avatar and not user.avatar_url:
        user.avatar_url = avatar
        updated = True

    if user.identity_id and (mobile or email):
        identity = await db.get(Identity, user.identity_id)
        if identity:

            async def _maybe_set(field_name: str, value: str) -> bool:
                # Write identity.<field_name> unless the value is claimed
                # by another Identity; returns whether a write happened.
                if await self._identity_field_in_use(
                    db, getattr(Identity, field_name), value, identity.id
                ):
                    logger.warning(
                        f"[enrich] {field_name}={value} already claimed by another "
                        f"identity; skipping {field_name} backfill for identity {identity.id}"
                    )
                    return False
                setattr(identity, field_name, value)
                return True

            if mobile and not identity.phone:
                updated = await _maybe_set("phone", mobile) or updated
            if email and not identity.email:
                updated = await _maybe_set("email", email) or updated

    if updated:
        await db.flush()


async def _identity_field_in_use(
    self,
    db: "AsyncSession",
    column,
    value: str,
    exclude_identity_id: "uuid.UUID",
) -> bool:
    """Return True when any OTHER Identity already holds *value* on *column*.

    Used to pre-empt IntegrityError on the globally unique Identity.phone /
    Identity.email columns.
    """
    from app.models.user import Identity

    stmt = (
        select(Identity.id)
        .where(column == value, Identity.id != exclude_identity_id)
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none() is not None
+ """ + from app.models.user import Identity + + stmt = select(Identity.id).where( + column == value, Identity.id != exclude_identity_id + ).limit(1) + result = await db.execute(stmt) + return result.scalar_one_or_none() is not None + async def _create_channel_user( self, db: AsyncSession, diff --git a/backend/app/services/dingtalk_cache.py b/backend/app/services/dingtalk_cache.py new file mode 100644 index 000000000..4259c957e --- /dev/null +++ b/backend/app/services/dingtalk_cache.py @@ -0,0 +1,52 @@ +"""进程内 TTL 缓存 helper, 服务于 dingtalk_service 的 token/user_detail。 + +设计要点: +- 单节点内存缓存: 多副本各自缓存, 不致命(token 每副本独立拉取; + user_detail 不变, 重复拉取只是多一次请求, 不会引发业务错误)。 +- single-flight: 同 key 并发调用时合并为一次 factory 执行, 避免 thundering herd。 +- 失败结果不进入缓存: 由调用方在 factory 返回无效值后主动 invalidate。 +""" +from __future__ import annotations + +import asyncio +import time +from typing import Any, Awaitable, Callable + + +class TTLCache: + def __init__(self, default_ttl: float = 60.0) -> None: + self._default_ttl = default_ttl + self._store: dict[str, tuple[float, Any]] = {} + self._locks: dict[str, asyncio.Lock] = {} + + def _lock(self, key: str) -> asyncio.Lock: + lock = self._locks.get(key) + if lock is None: + lock = asyncio.Lock() + self._locks[key] = lock + return lock + + async def get_or_set( + self, + key: str, + factory: Callable[[], Awaitable[Any]], + ttl: float | None = None, + ) -> Any: + now = time.monotonic() + hit = self._store.get(key) + if hit and hit[0] > now: + return hit[1] + + lock = self._lock(key) + async with lock: + hit = self._store.get(key) + now = time.monotonic() + if hit and hit[0] > now: + return hit[1] + value = await factory() + expire_at = now + (ttl if ttl is not None else self._default_ttl) + self._store[key] = (expire_at, value) + return value + + def invalidate(self, key: str) -> None: + self._store.pop(key, None) diff --git a/backend/app/services/dingtalk_service.py b/backend/app/services/dingtalk_service.py index 010fc56a4..dec59fc0b 100644 --- 
# DingTalk access_token usually carries expires_in of 7200s; expire it 200s
# early to leave a refresh margin.
_token_cache = TTLCache(default_ttl=7000)
# user/get results stay valid for 30 minutes.
_user_detail_cache = TTLCache(default_ttl=1800)


async def get_dingtalk_access_token(app_id: str, app_secret: str) -> dict:
    """Get DingTalk access_token, TTL-cached with single-flight.

    Returns {"access_token", "expires_in"} on success, otherwise
    {"errcode", "errmsg"}. Failed results are evicted immediately so the
    next caller re-fetches.
    """
    cache_key = f"token:{app_id}"

    async def _fetch() -> dict:
        # API: https://open.dingtalk.com/document/orgapp/obtain-access_token
        url = "https://oapi.dingtalk.com/gettoken"
        params = {"appkey": app_id, "appsecret": app_secret}
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                resp = await client.get(url, params=params)
                data = resp.json()
            except Exception as e:
                logger.error(f"[DingTalk] Network error getting access_token: {e}")
                return {"errcode": -1, "errmsg": str(e)}
        if data.get("errcode") == 0:
            return {
                "access_token": data.get("access_token"),
                "expires_in": data.get("expires_in"),
            }
        logger.error(f"[DingTalk] Failed to get access_token: {data}")
        return {"errcode": data.get("errcode"), "errmsg": data.get("errmsg")}

    result = await _token_cache.get_or_set(cache_key, _fetch)
    # Never keep failures cached: the next call re-fetches.
    if not result.get("access_token"):
        _token_cache.invalidate(cache_key)
    return result


async def get_dingtalk_user_detail(app_id: str, app_secret: str, userid: str) -> dict | None:
    """Fetch user detail from the DingTalk corp API, cached for 30 minutes.

    Returns the user/get "result" payload, or None on any failure (no token,
    API error, network error). Failures are evicted from the cache so a
    later call can retry.
    """
    cache_key = f"userdetail:{app_id}:{userid}"

    async def _fetch() -> dict | None:
        token_result = await get_dingtalk_access_token(app_id, app_secret)
        access_token = token_result.get("access_token")
        if not access_token:
            return None
        url = "https://oapi.dingtalk.com/topapi/v2/user/get"
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                resp = await client.post(
                    url,
                    params={"access_token": access_token},
                    json={"userid": userid},
                )
                data = resp.json()
            except Exception as e:
                logger.warning(f"[DingTalk] user/get error for {userid}: {e}")
                return None
        if data.get("errcode") == 0:
            return data.get("result", {})
        logger.warning(
            f"[DingTalk] user/get failed for {userid}: {data.get('errmsg')}"
        )
        return None

    result = await _user_detail_cache.get_or_set(cache_key, _fetch)
    if not result:
        _user_detail_cache.invalidate(cache_key)
    return result
main_loop.is_running(): - future = asyncio.run_coroutine_threadsafe( + asyncio.run_coroutine_threadsafe( process_dingtalk_message( agent_id=agent_id, sender_staff_id=sender_staff_id, @@ -115,16 +116,11 @@ async def process(self, callback: dingtalk_stream.CallbackMessage): conversation_id=conversation_id, conversation_type=conversation_type, session_webhook=session_webhook, + message_id=message_id, ), main_loop, ) - # Wait for result (with timeout) - try: - future.result(timeout=120) - except Exception as e: - logger.error(f"[DingTalk Stream] LLM processing error: {e}") - import traceback - traceback.print_exc() + # Fire-and-forget: ACK immediately, do not wait for LLM else: logger.warning("[DingTalk Stream] Main loop not available for dispatch") diff --git a/backend/tests/test_channel_user_service_identity.py b/backend/tests/test_channel_user_service_identity.py new file mode 100644 index 000000000..771941ee6 --- /dev/null +++ b/backend/tests/test_channel_user_service_identity.py @@ -0,0 +1,592 @@ +"""channel_user_service 的 OrgMember 匹配与回填逻辑。 + +不走 DB: 用 FakeSession 吸收 session 方法, monkeypatch 替换查询入口。 +聚焦: resolve_channel_user 如何组合 find → match → backfill → link 这一条链。 +""" +from __future__ import annotations + +import uuid +from types import SimpleNamespace + +import pytest + +from app.services import channel_user_service as cus_mod +from app.services.channel_user_service import channel_user_service + + +class _FakeSession: + """吸收 resolve_channel_user 用到的 session 方法, 行为对业务无副作用。""" + + def __init__(self) -> None: + self.added: list = [] + self.flushed = 0 + + def add(self, obj): + self.added.append(obj) + + async def flush(self): + self.flushed += 1 + + async def get(self, model, key): + return None + + async def execute(self, _query): + class _R: + def scalar_one_or_none(self_inner): + return None + return _R() + + +@pytest.fixture +def fake_session(): + return _FakeSession() + + +@pytest.fixture +def agent(): + return SimpleNamespace(id=uuid.uuid4(), 
tenant_id=uuid.uuid4(), name="A") + + +@pytest.fixture +def patch_provider(monkeypatch): + """跳过 provider 查询; 直接返回固定 IdentityProvider.""" + provider = SimpleNamespace(id=uuid.uuid4(), tenant_id=None, provider_type="dingtalk") + + async def _fake_ensure(self, db, provider_type, tenant_id): + return provider + + monkeypatch.setattr( + cus_mod.ChannelUserService, "_ensure_provider", _fake_ensure + ) + return provider + + +async def test_find_org_member_receives_candidate_ids( + fake_session, agent, patch_provider, monkeypatch +): + captured = {} + + async def _fake_find(self, db, provider_id, channel_type, candidate_ids): + captured["ids"] = list(candidate_ids) + user = SimpleNamespace( + id=uuid.uuid4(), identity_id=None, + display_name="Bob", avatar_url=None, + ) + member = SimpleNamespace(id=uuid.uuid4(), user_id=user.id) + fake_session._preloaded_user = user + return member + + async def _fake_db_get(model, key): + return fake_session._preloaded_user + + monkeypatch.setattr(cus_mod.ChannelUserService, "_find_org_member", _fake_find) + monkeypatch.setattr(fake_session, "get", _fake_db_get, raising=False) + + await channel_user_service.resolve_channel_user( + db=fake_session, + agent=agent, + channel_type="dingtalk", + external_user_id="staff-1", + extra_info={"unionid": "UNION-1"}, + extra_ids=["UNION-1"], + ) + + assert captured["ids"] == ["staff-1", "UNION-1"] + + +async def test_find_org_member_deduplicates_candidate_ids( + fake_session, agent, patch_provider, monkeypatch +): + captured = {} + + async def _fake_find(self, db, provider_id, channel_type, candidate_ids): + captured["ids"] = list(candidate_ids) + user = SimpleNamespace(id=uuid.uuid4(), identity_id=None) + member = SimpleNamespace(id=uuid.uuid4(), user_id=user.id) + fake_session._preloaded_user = user + return member + + monkeypatch.setattr(cus_mod.ChannelUserService, "_find_org_member", _fake_find) + + async def _fake_db_get(model, key): + return fake_session._preloaded_user + + 
monkeypatch.setattr(fake_session, "get", _fake_db_get, raising=False) + + await channel_user_service.resolve_channel_user( + db=fake_session, + agent=agent, + channel_type="dingtalk", + external_user_id="staff-1", + extra_info={"unionid": "staff-1"}, + extra_ids=["staff-1"], + ) + + assert captured["ids"] == ["staff-1"] # 去重 + + +class _RecordingSession: + """Captures the SQL from db.execute without running it.""" + + def __init__(self): + self.last_stmt = None + + async def execute(self, stmt): + self.last_stmt = stmt + + class _R: + def scalar_one_or_none(self_inner): + return None + + return _R() + + +async def test_find_org_member_sql_dingtalk(): + sess = _RecordingSession() + await channel_user_service._find_org_member( + sess, uuid.uuid4(), "dingtalk", ["staff-1", "UNION-1"] + ) + sql = str(sess.last_stmt.compile(compile_kwargs={"literal_binds": True})) + # Isolate the WHERE clause so SELECT-column references don't pollute checks + where_clause = sql.split("WHERE", 1)[1] + # dingtalk: OR over unionid + external_id, NOT open_id IN (...) 
+ assert "org_members.unionid IN" in where_clause + assert "org_members.external_id IN" in where_clause + assert "org_members.open_id IN" not in where_clause + assert "'staff-1'" in where_clause and "'UNION-1'" in where_clause + + +async def test_find_org_member_sql_feishu(): + sess = _RecordingSession() + await channel_user_service._find_org_member( + sess, uuid.uuid4(), "feishu", ["ou_x", "on_y"] + ) + sql = str(sess.last_stmt.compile(compile_kwargs={"literal_binds": True})) + where_clause = sql.split("WHERE", 1)[1] + # feishu: OR over unionid + open_id + external_id + assert "org_members.unionid IN" in where_clause + assert "org_members.open_id IN" in where_clause + assert "org_members.external_id IN" in where_clause + + +async def test_find_org_member_sql_wecom(): + sess = _RecordingSession() + await channel_user_service._find_org_member( + sess, uuid.uuid4(), "wecom", ["userid-1"] + ) + sql = str(sess.last_stmt.compile(compile_kwargs={"literal_binds": True})) + where_clause = sql.split("WHERE", 1)[1] + # wecom: external_id only, no unionid IN / open_id IN in WHERE + assert "org_members.external_id IN" in where_clause + assert "org_members.unionid IN" not in where_clause + assert "org_members.open_id IN" not in where_clause + + +async def test_find_org_member_empty_ids_returns_none_without_execute(): + sess = _RecordingSession() + result = await channel_user_service._find_org_member( + sess, uuid.uuid4(), "dingtalk", [] + ) + assert result is None + assert sess.last_stmt is None # short-circuits, no execute + + +def _make_member(**kwargs): + defaults = dict( + id=uuid.uuid4(), external_id=None, unionid=None, open_id=None, user_id=None, + ) + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +def _fake_provider_for(channel_type: str): + return SimpleNamespace( + id=uuid.uuid4(), tenant_id=None, provider_type=channel_type + ) + + +def test_backfill_dingtalk_fills_external_and_unionid(): + svc = channel_user_service + member = _make_member() + 
svc._backfill_org_member_ids( + member, + channel_type="dingtalk", + external_user_id="staff-carol-777", + extra_info={"unionid": "UNION-CAROL", "mobile": "13800000001"}, + ) + assert member.external_id == "staff-carol-777" + assert member.unionid == "UNION-CAROL" + + +def test_backfill_dingtalk_does_not_overwrite_existing(): + svc = channel_user_service + member = _make_member(external_id="existing-staff", unionid="existing-union") + svc._backfill_org_member_ids( + member, + channel_type="dingtalk", + external_user_id="staff-new", + extra_info={"unionid": "UNION-NEW"}, + ) + assert member.external_id == "existing-staff" + assert member.unionid == "existing-union" + + +def test_backfill_feishu_on_prefix_goes_to_unionid(): + svc = channel_user_service + member = _make_member() + svc._backfill_org_member_ids( + member, + channel_type="feishu", + external_user_id="on_unionid_xxx", + extra_info={}, + ) + assert member.unionid == "on_unionid_xxx" + + +def test_backfill_feishu_ou_prefix_goes_to_openid(): + svc = channel_user_service + member = _make_member() + svc._backfill_org_member_ids( + member, + channel_type="feishu", + external_user_id="ou_openid_xxx", + extra_info={}, + ) + assert member.open_id == "ou_openid_xxx" + + +def test_backfill_wecom_only_fills_external_id(): + svc = channel_user_service + member = _make_member() + svc._backfill_org_member_ids( + member, + channel_type="wecom", + external_user_id="userid-wecom-1", + extra_info={"unionid": "ignored-for-wecom"}, + ) + assert member.external_id == "userid-wecom-1" + assert member.unionid is None + + +async def test_reuse_existing_org_member_triggers_backfill( + fake_session, agent, patch_provider, monkeypatch +): + """email 命中 User → 找到 existing_member → 应回填 dingtalk 标识到 existing_member""" + matched_user = SimpleNamespace( + id=uuid.uuid4(), identity_id=None, + display_name="Carol", avatar_url=None, + ) + existing_member = _make_member(user_id=matched_user.id) + + async def _fake_find_none(self, db, 
provider_id, channel_type, candidate_ids): + return None + + async def _fake_match_email(db, email, tenant_id): + return matched_user + + async def _fake_match_mobile(db, mobile, tenant_id): + return None + + async def _fake_find_existing(self, db, user_id, provider_id, tenant_id): + return existing_member + + monkeypatch.setattr(cus_mod.ChannelUserService, "_find_org_member", _fake_find_none) + monkeypatch.setattr(cus_mod.sso_service, "match_user_by_email", _fake_match_email) + monkeypatch.setattr(cus_mod.sso_service, "match_user_by_mobile", _fake_match_mobile) + monkeypatch.setattr( + cus_mod.ChannelUserService, "_find_existing_org_member_for_user", + _fake_find_existing, + ) + + async def _get_none(model, key): + return None + monkeypatch.setattr(fake_session, "get", _get_none, raising=False) + + await channel_user_service.resolve_channel_user( + db=fake_session, + agent=agent, + channel_type="dingtalk", + external_user_id="staff-carol-777", + extra_info={ + "unionid": "UNION-CAROL", + "mobile": "13800000001", + "email": "carol@example.com", + "name": "Carol", + }, + extra_ids=["UNION-CAROL"], + ) + + assert existing_member.external_id == "staff-carol-777" + assert existing_member.unionid == "UNION-CAROL" + + +async def test_enrich_skips_phone_when_other_identity_uses_it(fake_session, monkeypatch): + """Pre-check: if another Identity already has the phone, skip instead of raising.""" + from app.services.channel_user_service import channel_user_service as svc + from app.services import channel_user_service as cus_mod + + current_identity = SimpleNamespace( + id=uuid.uuid4(), phone=None, email=None, + ) + user = SimpleNamespace( + id=uuid.uuid4(), identity_id=current_identity.id, + display_name=None, avatar_url=None, + ) + + async def _fake_get(model, key): + assert key == current_identity.id + return current_identity + monkeypatch.setattr(fake_session, "get", _fake_get, raising=False) + + # Simulate "another identity has this phone": execute returns a truthy row 
+ other_identity_id = uuid.uuid4() + + async def _fake_execute(stmt): + sql = str(stmt) + + class _R: + def scalar_one_or_none(self_inner): + # Return the other identity's id if the query is looking up + # identities by phone; else None. + if "identities.phone" in sql or "phone =" in sql.lower(): + return other_identity_id + return None + return _R() + + monkeypatch.setattr(fake_session, "execute", _fake_execute, raising=False) + + await svc._enrich_user_from_extra_info( + fake_session, user, {"mobile": "15703300627", "email": None, "name": None} + ) + + # Phone was NOT written, no exception raised + assert current_identity.phone is None + + +async def test_enrich_skips_email_when_other_identity_uses_it(fake_session, monkeypatch): + from app.services.channel_user_service import channel_user_service as svc + + current_identity = SimpleNamespace( + id=uuid.uuid4(), phone=None, email=None, + ) + user = SimpleNamespace( + id=uuid.uuid4(), identity_id=current_identity.id, + display_name=None, avatar_url=None, + ) + + async def _fake_get(model, key): + return current_identity + monkeypatch.setattr(fake_session, "get", _fake_get, raising=False) + + async def _fake_execute(stmt): + sql = str(stmt) + + class _R: + def scalar_one_or_none(self_inner): + if "identities.email" in sql or "email =" in sql.lower(): + return uuid.uuid4() + return None + return _R() + monkeypatch.setattr(fake_session, "execute", _fake_execute, raising=False) + + await svc._enrich_user_from_extra_info( + fake_session, user, + {"mobile": None, "email": "dup@example.com", "name": None}, + ) + + assert current_identity.email is None + + +async def test_enrich_writes_phone_when_no_conflict(fake_session, monkeypatch): + """Happy path: no other identity uses the phone → write succeeds.""" + from app.services.channel_user_service import channel_user_service as svc + + current_identity = SimpleNamespace( + id=uuid.uuid4(), phone=None, email=None, + ) + user = SimpleNamespace( + id=uuid.uuid4(), 
identity_id=current_identity.id, + display_name=None, avatar_url=None, + ) + + async def _fake_get(model, key): + return current_identity + monkeypatch.setattr(fake_session, "get", _fake_get, raising=False) + + async def _fake_execute(stmt): + class _R: + def scalar_one_or_none(self_inner): + return None # no conflict + return _R() + monkeypatch.setattr(fake_session, "execute", _fake_execute, raising=False) + + await svc._enrich_user_from_extra_info( + fake_session, user, {"mobile": "13800000000", "email": None, "name": None} + ) + + assert current_identity.phone == "13800000000" + + +async def test_resolve_continues_when_enrich_raises( + fake_session, agent, patch_provider, monkeypatch +): + """Isolation: even if _enrich raises unexpectedly, resolve still returns the user.""" + from app.services.channel_user_service import channel_user_service as svc + from app.services import channel_user_service as cus_mod + + matched_user = SimpleNamespace( + id=uuid.uuid4(), identity_id=uuid.uuid4(), + display_name=None, avatar_url=None, + ) + + async def _find_linked(self, db, provider_id, channel_type, candidate_ids): + # Return a member already linked to matched_user → Case 1 branch + return SimpleNamespace(id=uuid.uuid4(), user_id=matched_user.id) + + monkeypatch.setattr(cus_mod.ChannelUserService, "_find_org_member", _find_linked) + + async def _db_get(model, key): + if key == matched_user.id: + return matched_user + return None + monkeypatch.setattr(fake_session, "get", _db_get, raising=False) + + async def _boom(self, db, user, extra_info): + raise RuntimeError("simulated enrichment failure") + monkeypatch.setattr( + cus_mod.ChannelUserService, "_enrich_user_from_extra_info", _boom + ) + + # resolve_channel_user should catch the enrichment error and still return the user + result = await svc.resolve_channel_user( + db=fake_session, agent=agent, channel_type="dingtalk", + external_user_id="staff-xyz", + extra_info={"mobile": "13900000000", "email": "x@y.com"}, + ) + 
assert result.id == matched_user.id + + +async def test_reuse_existing_org_member_triggers_backfill_feishu( + fake_session, agent, patch_provider, monkeypatch +): + """Feishu reuse: extra_info has email + unionid; external_user_id is open_id (ou_...) + → backfill should populate open_id (from prefix), external_id, and unionid (from extra_info) + on the reused OrgMember. + """ + import app.services.channel_user_service as cus_mod + + async def _fake_ensure_feishu(self, db, ptype, tid): + return _fake_provider_for("feishu") + + monkeypatch.setattr(cus_mod.ChannelUserService, "_ensure_provider", _fake_ensure_feishu) + + matched_user = SimpleNamespace( + id=uuid.uuid4(), identity_id=None, + display_name=None, avatar_url=None, + ) + existing_member = _make_member(user_id=matched_user.id) + + async def _fake_find_none(self, db, provider_id, channel_type, candidate_ids): + return None + + async def _fake_match_email(db, email, tenant_id): + return matched_user + + async def _fake_match_mobile(db, mobile, tenant_id): + return None + + async def _fake_find_existing(self, db, user_id, provider_id, tenant_id): + return existing_member + + monkeypatch.setattr(cus_mod.ChannelUserService, "_find_org_member", _fake_find_none) + monkeypatch.setattr(cus_mod.sso_service, "match_user_by_email", _fake_match_email) + monkeypatch.setattr(cus_mod.sso_service, "match_user_by_mobile", _fake_match_mobile) + monkeypatch.setattr( + cus_mod.ChannelUserService, + "_find_existing_org_member_for_user", + _fake_find_existing, + ) + + async def _get_none(model, key): + return None + monkeypatch.setattr(fake_session, "get", _get_none, raising=False) + + await channel_user_service.resolve_channel_user( + db=fake_session, + agent=agent, + channel_type="feishu", + external_user_id="ou_abc123xyz", # Feishu open_id prefix + extra_info={ + "unionid": "UNION-FEISHU-X", + "email": "feishu-user@example.com", + "name": "Feishu User", + }, + ) + + # Backfill assertions — open_id from prefix, external_id from 
external_user_id, + # unionid from extra_info + assert existing_member.open_id == "ou_abc123xyz" + assert existing_member.external_id == "ou_abc123xyz" + assert existing_member.unionid == "UNION-FEISHU-X" + + +async def test_reuse_existing_org_member_triggers_backfill_wecom( + fake_session, agent, patch_provider, monkeypatch +): + """WeCom reuse: only external_id should be filled (from userid). + unionid in extra_info must NOT be written to the member (wecom doesn't track unionid + on OrgMember). + """ + import app.services.channel_user_service as cus_mod + + async def _fake_ensure_wecom(self, db, ptype, tid): + return _fake_provider_for("wecom") + + monkeypatch.setattr(cus_mod.ChannelUserService, "_ensure_provider", _fake_ensure_wecom) + + matched_user = SimpleNamespace( + id=uuid.uuid4(), identity_id=None, + display_name=None, avatar_url=None, + ) + existing_member = _make_member(user_id=matched_user.id) + + async def _fake_find_none(self, db, provider_id, channel_type, candidate_ids): + return None + + async def _fake_match_email(db, email, tenant_id): + return matched_user + + async def _fake_match_mobile(db, mobile, tenant_id): + return None + + async def _fake_find_existing(self, db, user_id, provider_id, tenant_id): + return existing_member + + monkeypatch.setattr(cus_mod.ChannelUserService, "_find_org_member", _fake_find_none) + monkeypatch.setattr(cus_mod.sso_service, "match_user_by_email", _fake_match_email) + monkeypatch.setattr(cus_mod.sso_service, "match_user_by_mobile", _fake_match_mobile) + monkeypatch.setattr( + cus_mod.ChannelUserService, + "_find_existing_org_member_for_user", + _fake_find_existing, + ) + + async def _get_none(model, key): + return None + monkeypatch.setattr(fake_session, "get", _get_none, raising=False) + + await channel_user_service.resolve_channel_user( + db=fake_session, + agent=agent, + channel_type="wecom", + external_user_id="wecom-userid-42", + extra_info={ + "unionid": "SHOULD-BE-IGNORED", # wecom doesn't write unionid on 
member + "email": "wecom-user@example.com", + "name": "WeCom User", + }, + ) + + assert existing_member.external_id == "wecom-userid-42" + assert existing_member.unionid is None # not written for wecom + assert existing_member.open_id is None # not written for wecom diff --git a/backend/tests/test_dingtalk_dedup.py b/backend/tests/test_dingtalk_dedup.py new file mode 100644 index 000000000..429db1326 --- /dev/null +++ b/backend/tests/test_dingtalk_dedup.py @@ -0,0 +1,66 @@ +import asyncio + +import pytest + +from app.api import dingtalk as dingtalk_api + + +@pytest.fixture(autouse=True) +def _reset_state(monkeypatch): + dingtalk_api._processed_messages.clear() + dingtalk_api._dedup_check_counter = 0 + # 禁用 Redis: 强制走内存 fallback + monkeypatch.setattr(dingtalk_api, "_get_redis_client", None) + + +async def test_acquire_first_returns_accepted(): + accepted, state = await dingtalk_api.acquire_dedup_lock("m-1") + assert accepted is True + assert state == "new" + + +async def test_acquire_second_while_processing_returns_duplicate(): + await dingtalk_api.acquire_dedup_lock("m-2") + accepted, state = await dingtalk_api.acquire_dedup_lock("m-2") + assert accepted is False + assert state == "processing" + + +async def test_acquire_after_done_returns_duplicate(): + await dingtalk_api.acquire_dedup_lock("m-3") + await dingtalk_api.mark_dedup_done("m-3") + accepted, state = await dingtalk_api.acquire_dedup_lock("m-3") + assert accepted is False + assert state == "done" + + +async def test_release_allows_retry(): + await dingtalk_api.acquire_dedup_lock("m-4") + await dingtalk_api.release_dedup_lock("m-4") + accepted, state = await dingtalk_api.acquire_dedup_lock("m-4") + assert accepted is True + assert state == "new" + + +async def test_release_after_done_is_noop(): + await dingtalk_api.acquire_dedup_lock("m-4b") + await dingtalk_api.mark_dedup_done("m-4b") + await dingtalk_api.release_dedup_lock("m-4b") # 不应把 done 删掉 + accepted, state = await 
dingtalk_api.acquire_dedup_lock("m-4b") + assert accepted is False + assert state == "done" + + +async def test_empty_message_id_always_accepts(): + accepted, state = await dingtalk_api.acquire_dedup_lock("") + assert accepted is True + assert state == "new" + + +async def test_processing_ttl_expires(monkeypatch): + monkeypatch.setattr(dingtalk_api, "PROCESSING_TTL", 0.05, raising=False) + await dingtalk_api.acquire_dedup_lock("m-5") + await asyncio.sleep(0.08) + accepted, state = await dingtalk_api.acquire_dedup_lock("m-5") + assert accepted is True + assert state == "new" diff --git a/backend/tests/test_dingtalk_service_cache.py b/backend/tests/test_dingtalk_service_cache.py new file mode 100644 index 000000000..1ccc6f7db --- /dev/null +++ b/backend/tests/test_dingtalk_service_cache.py @@ -0,0 +1,144 @@ +import asyncio + +import pytest + +from app.services.dingtalk_cache import TTLCache +from app.services import dingtalk_service + + +class _FakeResp: + def __init__(self, payload): + self._payload = payload + self.status_code = 200 + + def json(self): + return self._payload + + +class _FakeClient: + instances: list["_FakeClient"] = [] + + def __init__(self, *args, **kwargs): + self.get_calls: list[tuple[str, dict]] = [] + self.post_calls: list[tuple[str, dict, dict]] = [] + _FakeClient.instances.append(self) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url, params=None): + self.get_calls.append((url, params or {})) + return _FakeResp({"errcode": 0, "access_token": "tok-A", "expires_in": 7200}) + + async def post(self, url, params=None, json=None): + self.post_calls.append((url, params or {}, json or {})) + return _FakeResp({ + "errcode": 0, + "result": { + "userid": json["userid"], + "name": "Alice", + "mobile": "13800000000", + "email": "alice@example.com", + "unionid": f"UNION-{json['userid']}", + }, + }) + + +@pytest.fixture(autouse=True) +def _reset_caches(monkeypatch): + 
dingtalk_service._token_cache._store.clear() + dingtalk_service._user_detail_cache._store.clear() + _FakeClient.instances.clear() + monkeypatch.setattr(dingtalk_service.httpx, "AsyncClient", _FakeClient) + + +async def test_ttl_cache_hit_miss(): + cache = TTLCache(default_ttl=60) + calls = {"n": 0} + + async def factory(): + calls["n"] += 1 + return {"v": calls["n"]} + + v1 = await cache.get_or_set("k", factory) + v2 = await cache.get_or_set("k", factory) + assert v1 == v2 == {"v": 1} + assert calls["n"] == 1 + + +async def test_ttl_cache_expires(): + cache = TTLCache(default_ttl=0.05) + calls = {"n": 0} + + async def factory(): + calls["n"] += 1 + return calls["n"] + + assert await cache.get_or_set("k", factory) == 1 + await asyncio.sleep(0.08) + assert await cache.get_or_set("k", factory) == 2 + + +async def test_ttl_cache_single_flight(): + cache = TTLCache(default_ttl=60) + calls = {"n": 0} + start = asyncio.Event() + + async def factory(): + await start.wait() + calls["n"] += 1 + return calls["n"] + + t1 = asyncio.create_task(cache.get_or_set("k", factory)) + t2 = asyncio.create_task(cache.get_or_set("k", factory)) + await asyncio.sleep(0) + start.set() + r1, r2 = await asyncio.gather(t1, t2) + assert r1 == r2 == 1 + assert calls["n"] == 1 + + +async def test_access_token_cached_across_calls(): + t1 = await dingtalk_service.get_dingtalk_access_token("APP", "SEC") + t2 = await dingtalk_service.get_dingtalk_access_token("APP", "SEC") + assert t1["access_token"] == t2["access_token"] == "tok-A" + gets = [c for inst in _FakeClient.instances for c in inst.get_calls] + assert len(gets) == 1 + + +async def test_user_detail_cached_per_userid(): + d1 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-1") + d2 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-1") + d3 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-2") + assert d1 == d2 + assert d1["unionid"] == "UNION-user-1" + assert d3["unionid"] == 
"UNION-user-2" + user_posts = [ + c for inst in _FakeClient.instances + for c in inst.post_calls if "user/get" in c[0] + ] + assert len(user_posts) == 2 # 2 distinct userids + + +class _EmptyResultClient(_FakeClient): + async def post(self, url, params=None, json=None): + self.post_calls.append((url, params or {}, json or {})) + # DingTalk returns success but result field missing → _fetch coerces to {} + return _FakeResp({"errcode": 0}) + + +async def test_user_detail_empty_result_is_not_cached(monkeypatch): + monkeypatch.setattr(dingtalk_service.httpx, "AsyncClient", _EmptyResultClient) + r1 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-x") + r2 = await dingtalk_service.get_dingtalk_user_detail("APP", "SEC", "user-x") + assert r1 == {} + assert r2 == {} + user_posts = [ + c for inst in _FakeClient.instances + for c in inst.post_calls if "user/get" in c[0] + ] + # Without the fix: only 1 POST (empty dict cached). With fix: 2 POSTs. + assert len(user_posts) == 2