From d09e8b7de643e8c895ad4cbf4f65a9652a145ae5 Mon Sep 17 00:00:00 2001 From: "nap.liu" Date: Wed, 8 Apr 2026 21:38:29 +0800 Subject: [PATCH 1/5] fix: allow unbound users (tenant_id=None) to pass tenant-scoped login check Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/api/auth.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index ef76429b5..5c0cc4e27 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -512,12 +512,18 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes user = valid_users[0] else: # Specific tenant requested (Dedicated Link flow) - # Search for the user record in that tenant - user = next((u for u in valid_users if u.tenant_id == data.tenant_id), None) - + # Search for the user record in that tenant. + # Also allow users with tenant_id=None (SSO-created but not yet + # assigned to a tenant) so they can proceed to company setup. + user = next( + (u for u in valid_users + if u.tenant_id == data.tenant_id or u.tenant_id is None), + None, + ) + # Cross-tenant access check if not user: - # Even platform admins must have a valid record in the targeted tenant + # Even platform admins must have a valid record in the targeted tenant # when logging in via a dedicated tenant URL / tenant_id. 
raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, From efb31312a2e7f66ffa53dbf699684eb99ad4d437 Mon Sep 17 00:00:00 2001 From: "nap.liu" Date: Wed, 8 Apr 2026 20:01:52 +0800 Subject: [PATCH 2/5] feat: unified URL resolution with 5-level fallback chain and multi-tenant subdomain support Co-Authored-By: Claude Opus 4.6 (1M context) --- .../alembic/versions/add_subdomain_prefix.py | 24 +++ .../alembic/versions/add_tenant_is_default.py | 27 ++++ backend/app/api/sso.py | 15 +- backend/app/api/tenants.py | 147 ++++++++++++------ backend/app/core/domain.py | 77 +++++++++ backend/app/models/tenant.py | 5 + backend/app/services/platform_service.py | 24 ++- frontend/src/i18n/en.json | 6 +- frontend/src/i18n/zh.json | 6 +- frontend/src/pages/AdminCompanies.tsx | 55 +++++++ 10 files changed, 319 insertions(+), 67 deletions(-) create mode 100644 backend/alembic/versions/add_subdomain_prefix.py create mode 100644 backend/alembic/versions/add_tenant_is_default.py create mode 100644 backend/app/core/domain.py diff --git a/backend/alembic/versions/add_subdomain_prefix.py b/backend/alembic/versions/add_subdomain_prefix.py new file mode 100644 index 000000000..5fecb85c0 --- /dev/null +++ b/backend/alembic/versions/add_subdomain_prefix.py @@ -0,0 +1,24 @@ +"""Add subdomain_prefix to tenants table.""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy import inspect + +revision: str = "add_subdomain_prefix" +down_revision = "d9cbd43b62e5" +branch_labels = None +depends_on = None + +def upgrade() -> None: + conn = op.get_bind() + inspector = inspect(conn) + columns = [c["name"] for c in inspector.get_columns("tenants")] + if "subdomain_prefix" not in columns: + op.add_column("tenants", sa.Column("subdomain_prefix", sa.String(50), nullable=True)) + indexes = [i["name"] for i in inspector.get_indexes("tenants")] + if "ix_tenants_subdomain_prefix" not in indexes: + op.create_index("ix_tenants_subdomain_prefix", "tenants", ["subdomain_prefix"], unique=True) + +def 
downgrade() -> None: + op.drop_index("ix_tenants_subdomain_prefix", "tenants") + op.drop_column("tenants", "subdomain_prefix") diff --git a/backend/alembic/versions/add_tenant_is_default.py b/backend/alembic/versions/add_tenant_is_default.py new file mode 100644 index 000000000..00e164615 --- /dev/null +++ b/backend/alembic/versions/add_tenant_is_default.py @@ -0,0 +1,27 @@ +"""Add is_default field to tenants table.""" + +from alembic import op +import sqlalchemy as sa + +revision = "add_tenant_is_default" +down_revision = "add_subdomain_prefix" +branch_labels = None +depends_on = None + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + cols = [c['name'] for c in inspector.get_columns('tenants')] + if 'is_default' not in cols: + op.add_column('tenants', sa.Column('is_default', sa.Boolean(), nullable=False, server_default='false')) + conn.execute(sa.text(""" + UPDATE tenants + SET is_default = true + WHERE id = ( + SELECT id FROM tenants WHERE is_active = true ORDER BY created_at ASC LIMIT 1 + ) + AND NOT EXISTS (SELECT 1 FROM tenants WHERE is_default = true) + """)) + +def downgrade() -> None: + op.drop_column('tenants', 'is_default') diff --git a/backend/app/api/sso.py b/backend/app/api/sso.py index 1c5210247..940b5d7b4 100644 --- a/backend/app/api/sso.py +++ b/backend/app/api/sso.py @@ -104,15 +104,12 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De result = await db.execute(query) providers = result.scalars().all() - # Determine the base URL for OAuth callbacks using centralized platform service: - from app.services.platform_service import platform_service - if session.tenant_id: - from app.models.tenant import Tenant - tenant_result = await db.execute(select(Tenant).where(Tenant.id == session.tenant_id)) - tenant_obj = tenant_result.scalar_one_or_none() - public_base = await platform_service.get_tenant_sso_base_url(db, tenant_obj, request) - else: - public_base = await 
platform_service.get_public_base_url(db, request) + # Determine the base URL for OAuth callbacks using unified domain resolution: + from app.core.domain import resolve_base_url + public_base = await resolve_base_url( + db, request=request, + tenant_id=str(session.tenant_id) if session.tenant_id else None + ) auth_urls = [] for p in providers: diff --git a/backend/app/api/tenants.py b/backend/app/api/tenants.py index 1e8c26b14..85793949f 100644 --- a/backend/app/api/tenants.py +++ b/backend/app/api/tenants.py @@ -37,6 +37,9 @@ class TenantOut(BaseModel): is_active: bool sso_enabled: bool = False sso_domain: str | None = None + is_default: bool = False + subdomain_prefix: str | None = None + effective_base_url: str | None = None created_at: datetime | None = None model_config = {"from_attributes": True} @@ -49,6 +52,8 @@ class TenantUpdate(BaseModel): is_active: bool | None = None sso_enabled: bool | None = None sso_domain: str | None = None + subdomain_prefix: str | None = None + is_default: bool | None = None # ─── Helpers ──────────────────────────────────────────── @@ -338,66 +343,80 @@ async def get_registration_config(db: AsyncSession = Depends(get_db)): return {"allow_self_create_company": allowed} -# ─── Public: Resolve Tenant by Domain ─────────────────── - -@router.get("/resolve-by-domain") -async def resolve_tenant_by_domain( - domain: str, - db: AsyncSession = Depends(get_db), -): - """Resolve a tenant by its sso_domain or subdomain slug. +# ─── Public: Check Subdomain Prefix Availability ───────── - sso_domain is stored as a full URL (e.g. "https://acme.clawith.ai" or "http://1.2.3.4:3009"). - The incoming `domain` parameter is the host (without protocol). - - Lookup precedence: - 1. Exact match on tenant.sso_domain ending with the host (strips protocol) - 2. Extract slug from "{slug}.clawith.ai" and match tenant.slug - """ - tenant = None - - # 1. 
Match by stripping protocol from stored sso_domain - # sso_domain = "https://acme.clawith.ai" → compare against "acme.clawith.ai" - for proto in ("https://", "http://"): - result = await db.execute( - select(Tenant).where(Tenant.sso_domain == f"{proto}{domain}") - ) - tenant = result.scalar_one_or_none() - if tenant: - break - - # 2. Try without port (e.g. domain = "1.2.3.4:3009" → try "1.2.3.4") - if not tenant and ":" in domain: - domain_no_port = domain.split(":")[0] - for proto in ("https://", "http://"): - result = await db.execute( - select(Tenant).where(Tenant.sso_domain.like(f"{proto}{domain_no_port}%")) - ) - tenant = result.scalar_one_or_none() - if tenant: - break +@router.get("/check-prefix") +async def check_prefix(prefix: str, db: AsyncSession = Depends(get_db)): + """Check if a subdomain prefix is available.""" + import re as _re + if not _re.match(r'^[a-z0-9]([a-z0-9\-]{0,48}[a-z0-9])?$', prefix): + return {"available": False, "reason": "Invalid format"} + reserved = {"www", "api", "app", "admin", "mail", "smtp", "ftp", "ns1", "ns2", "cdn", "static", "assets"} + if prefix in reserved: + return {"available": False, "reason": "Reserved"} + result = await db.execute(select(Tenant).where(Tenant.subdomain_prefix == prefix)) + if result.scalar_one_or_none(): + return {"available": False, "reason": "Already taken"} + return {"available": True} - # 3. 
Fallback: extract slug from subdomain pattern - if not tenant: - import re - m = re.match(r"^([a-z0-9][a-z0-9\-]*[a-z0-9])\.clawith\.ai$", domain.lower()) - if m: - slug = m.group(1) - result = await db.execute(select(Tenant).where(Tenant.slug == slug)) - tenant = result.scalar_one_or_none() - if not tenant or not tenant.is_active or not tenant.sso_enabled: - raise HTTPException(status_code=404, detail="Tenant not found or not active or SSO not enabled") +# ─── Public: Resolve Tenant by Domain ─────────────────── +def _tenant_response(tenant): return { - "id": tenant.id, + "id": str(tenant.id), "name": tenant.name, "slug": tenant.slug, "sso_enabled": tenant.sso_enabled, "sso_domain": tenant.sso_domain, - "is_active": tenant.is_active, + "subdomain_prefix": tenant.subdomain_prefix, + "is_default": getattr(tenant, 'is_default', False), } + +@router.get("/resolve-by-domain") +async def resolve_by_domain(domain: str, db: AsyncSession = Depends(get_db)): + """Resolve a tenant by the incoming domain/hostname.""" + from app.core.domain import resolve_base_url, _get_global_base_url + from urllib.parse import urlparse + + hostname = domain.split(":")[0] + + # 1. Match sso_domain exactly + result = await db.execute(select(Tenant).where(Tenant.sso_domain == hostname)) + tenant = result.scalar_one_or_none() + if tenant: + return _tenant_response(tenant) + + # 2. Match subdomain_prefix + global_url = await _get_global_base_url(db) + if global_url: + parsed = urlparse(global_url) + global_host = parsed.hostname or "" + if hostname.endswith(f".{global_host}"): + prefix = hostname[: -(len(global_host) + 1)] + result = await db.execute(select(Tenant).where(Tenant.subdomain_prefix == prefix)) + tenant = result.scalar_one_or_none() + if tenant: + return _tenant_response(tenant) + + # 3. 
Match slug as subdomain + parts = hostname.split(".") + if len(parts) >= 2: + slug_candidate = parts[0] + result = await db.execute(select(Tenant).where(Tenant.slug == slug_candidate)) + tenant = result.scalar_one_or_none() + if tenant: + return _tenant_response(tenant) + + # 4. Return default tenant + result = await db.execute(select(Tenant).where(Tenant.is_default == True)) + tenant = result.scalar_one_or_none() + if tenant: + return _tenant_response(tenant) + + raise HTTPException(status_code=404, detail="No tenant found for this domain") + # ─── Authenticated: List / Get ────────────────────────── @router.get("/", response_model=list[TenantOut]) @@ -425,7 +444,10 @@ async def get_tenant( tenant = result.scalar_one_or_none() if not tenant: raise HTTPException(status_code=404, detail="Tenant not found") - return TenantOut.model_validate(tenant) + from app.core.domain import resolve_base_url + out = TenantOut.model_validate(tenant) + out.effective_base_url = await resolve_base_url(db, tenant_id=str(tenant.id)) + return out @router.put("/{tenant_id}", response_model=TenantOut) @@ -444,13 +466,36 @@ async def update_tenant( raise HTTPException(status_code=404, detail="Tenant not found") update_data = data.model_dump(exclude_unset=True) - + # SSO configuration is managed exclusively by the company's own org_admin # via the Enterprise Settings page. Platform admins should not override it here. 
if current_user.role == "platform_admin": update_data.pop("sso_enabled", None) update_data.pop("sso_domain", None) + # Validate subdomain_prefix format if provided + if "subdomain_prefix" in update_data and update_data["subdomain_prefix"] is not None: + import re as _re + prefix = update_data["subdomain_prefix"] + if not _re.match(r'^[a-z0-9]([a-z0-9\-]{0,48}[a-z0-9])?$', prefix): + raise HTTPException(status_code=400, detail="Invalid subdomain prefix format") + reserved = {"www", "api", "app", "admin", "mail", "smtp", "ftp", "ns1", "ns2", "cdn", "static", "assets"} + if prefix in reserved: + raise HTTPException(status_code=400, detail="Subdomain prefix is reserved") + # Check uniqueness + existing = await db.execute( + select(Tenant).where(Tenant.subdomain_prefix == prefix, Tenant.id != tenant_id) + ) + if existing.scalar_one_or_none(): + raise HTTPException(status_code=400, detail="Subdomain prefix already taken") + + # If setting is_default=True, clear other defaults + if update_data.get("is_default") is True: + from sqlalchemy import update as sql_update + await db.execute( + sql_update(Tenant).where(Tenant.id != tenant_id).values(is_default=False) + ) + for field, value in update_data.items(): setattr(tenant, field, value) await db.flush() diff --git a/backend/app/core/domain.py b/backend/app/core/domain.py new file mode 100644 index 000000000..6e2a166bb --- /dev/null +++ b/backend/app/core/domain.py @@ -0,0 +1,77 @@ +"""Domain resolution with fallback chain.""" + +import os + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import Request + +from app.models.system_settings import SystemSetting + + +async def _get_global_base_url(db: AsyncSession): + """Helper: read platform public_base_url from system_settings.""" + # Try DB first + result = await db.execute( + select(SystemSetting).where(SystemSetting.key == "platform") + ) + setting = result.scalar_one_or_none() + if setting and 
setting.value.get("public_base_url"): + return setting.value["public_base_url"].rstrip("/") + # Fallback to ENV + env_url = os.environ.get("PUBLIC_BASE_URL") + if env_url: + return env_url.rstrip("/") + return None + + +async def resolve_base_url( + db: AsyncSession, + request: Request | None = None, + tenant_id: str | None = None, +) -> str: + """Resolve the effective base URL using the fallback chain: + + 1. Tenant-specific sso_domain (if tenant_id provided and tenant has sso_domain) + 2. Tenant subdomain_prefix + global hostname + 3. Platform global public_base_url (from system_settings) + 4. Request origin (from request.base_url) + 5. Hardcoded fallback + + Returns a full URL like "https://acme.example.com" or "http://localhost:3008" + """ + # Level 1 & 2: Tenant-specific + if tenant_id: + from app.models.tenant import Tenant + result = await db.execute(select(Tenant).where(Tenant.id == tenant_id)) + tenant = result.scalar_one_or_none() + if tenant: + # Level 1: complete custom domain + if tenant.sso_domain: + domain = tenant.sso_domain.rstrip("/") + if domain.startswith("http://") or domain.startswith("https://"): + return domain + return f"https://{domain}" + + # Level 2: subdomain prefix + global hostname (skip for default tenant) + if tenant.subdomain_prefix and not getattr(tenant, 'is_default', False): + global_url = await _get_global_base_url(db) + if global_url: + from urllib.parse import urlparse + parsed = urlparse(global_url) + host = f"{tenant.subdomain_prefix}.{parsed.hostname}" + if parsed.port and parsed.port not in (80, 443): + host = f"{host}:{parsed.port}" + return f"{parsed.scheme}://{host}" + + # Level 3: Platform global setting + global_url = await _get_global_base_url(db) + if global_url: + return global_url + + # Level 4: Request origin + if request: + return str(request.base_url).rstrip("/") + + # Level 5: Hardcoded fallback + return "http://localhost:8000" diff --git a/backend/app/models/tenant.py b/backend/app/models/tenant.py index 
113238c62..d122bdc9f 100644 --- a/backend/app/models/tenant.py +++ b/backend/app/models/tenant.py @@ -44,6 +44,11 @@ class Tenant(Base): sso_enabled: Mapped[bool] = mapped_column(Boolean, default=False) sso_domain: Mapped[str | None] = mapped_column(String(255), unique=True, index=True, nullable=True) + # Subdomain prefix for auto-generated tenant URLs (e.g. "acme" → acme.clawith.com) + subdomain_prefix: Mapped[str | None] = mapped_column(String(50), unique=True, index=True, nullable=True) + # Whether this is the platform's default tenant + is_default: Mapped[bool] = mapped_column(Boolean, default=False) + # Trigger limits — defaults for new agents & floor values default_max_triggers: Mapped[int] = mapped_column(Integer, default=20) min_poll_interval_floor: Mapped[int] = mapped_column(Integer, default=5) diff --git a/backend/app/services/platform_service.py b/backend/app/services/platform_service.py index 7bdc61faf..3f73aa8d7 100644 --- a/backend/app/services/platform_service.py +++ b/backend/app/services/platform_service.py @@ -20,23 +20,37 @@ def is_ip_address(self, host: str) -> bool: async def get_public_base_url(self, db: AsyncSession | None = None, request: Request | None = None) -> str: """Resolve the platform's public base URL with priority lookup. - + Priority: 1. Environment variable (PUBLIC_BASE_URL) - from .env or docker - 2. Incoming request's base URL (browser address) - 3. Hardcoded fallback (https://try.clawith.ai) + 2. Database system_settings (platform.public_base_url) + 3. Incoming request's base URL (browser address) + 4. Hardcoded fallback (https://try.clawith.ai) """ # 1. Try environment variable env_url = os.environ.get("PUBLIC_BASE_URL") if env_url: return env_url.rstrip("/") - # 2. Fallback to request (browser address) + # 2. 
Try database system_settings + if db: + try: + from app.models.system_settings import SystemSetting + result = await db.execute( + select(SystemSetting).where(SystemSetting.key == "platform") + ) + setting = result.scalar_one_or_none() + if setting and setting.value and setting.value.get("public_base_url"): + return setting.value["public_base_url"].rstrip("/") + except Exception: + pass + + # 3. Fallback to request (browser address) if request: # Note: request.base_url might include trailing slash return str(request.base_url).rstrip("/") - # 3. Absolute fallback + # 4. Absolute fallback return "https://try.clawith.ai" diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index a5be3fd99..d01fbfbef 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -1724,6 +1724,10 @@ "ssoConfigDesc": "Configure SSO and custom domain for this company.", "ssoEnabled": "Enable SSO", "ssoDomain": "Custom Access Domain", - "ssoDomainPlaceholder": "e.g. acme.clawith.com" + "ssoDomainPlaceholder": "e.g. acme.clawith.com", + "publicUrl": { + "title": "Public URL", + "desc": "The external URL used for SSO callbacks, subdomain generation, and email links. Include the protocol (e.g. https://example.com)." 
+ } } } diff --git a/frontend/src/i18n/zh.json b/frontend/src/i18n/zh.json index ee020b0e1..465cc21b6 100644 --- a/frontend/src/i18n/zh.json +++ b/frontend/src/i18n/zh.json @@ -1619,7 +1619,11 @@ "all": "全部", "filterStatus": "按状态筛选", "enable": "启用", - "noFilterResults": "没有符合当前筛选条件的公司。" + "noFilterResults": "没有符合当前筛选条件的公司。", + "publicUrl": { + "title": "公开访问地址", + "desc": "用于 SSO 回调、子域名生成和邮件链接的外部 URL。请包含协议(如 https://example.com)。" + } }, "companySetup": { "title": "设置你的公司", diff --git a/frontend/src/pages/AdminCompanies.tsx b/frontend/src/pages/AdminCompanies.tsx index 884b0f62d..bd0b66713 100644 --- a/frontend/src/pages/AdminCompanies.tsx +++ b/frontend/src/pages/AdminCompanies.tsx @@ -109,6 +109,10 @@ function PlatformTab() { const [nbSaving, setNbSaving] = useState(false); const [nbSaved, setNbSaved] = useState(false); + // Public Base URL + const [publicBaseUrl, setPublicBaseUrl] = useState(''); + const [publicBaseUrlSaving, setPublicBaseUrlSaving] = useState(false); + const [publicBaseUrlSaved, setPublicBaseUrlSaved] = useState(false); // System email configuration const [systemEmailConfig, setSystemEmailConfig] = useState({ @@ -163,6 +167,16 @@ function PlatformTab() { } }).catch(() => { }); + // Load Public Base URL + const token2 = localStorage.getItem('token'); + fetch('/api/enterprise/system-settings/platform', { + headers: { 'Content-Type': 'application/json', ...(token2 ? 
{ Authorization: `Bearer ${token2}` } : {}) }, + }).then(r => r.json()).then(d => { + if (d?.value?.public_base_url) { + setPublicBaseUrl(d.value.public_base_url); + } + }).catch(() => { }); + // Load System Email fetchJson('/enterprise/system-settings/system_email_platform') .then(d => { @@ -219,6 +233,24 @@ function PlatformTab() { }; + const savePublicBaseUrl = async () => { + setPublicBaseUrlSaving(true); + try { + const token = localStorage.getItem('token'); + await fetch('/api/enterprise/system-settings/platform', { + method: 'PUT', + headers: { 'Content-Type': 'application/json', ...(token ? { Authorization: `Bearer ${token}` } : {}) }, + body: JSON.stringify({ value: { public_base_url: publicBaseUrl.trim() || null } }), + }); + setPublicBaseUrlSaved(true); + setTimeout(() => setPublicBaseUrlSaved(false), 2000); + showToast(t('enterprise.config.saved', 'Saved')); + } catch (e: any) { + showToast(e.message || t('common.error', 'Failed to save'), 'error'); + } + setPublicBaseUrlSaving(false); + }; + const saveEmailConfig = async () => { setEmailConfigSaving(true); try { @@ -334,6 +366,29 @@ function PlatformTab() { + {/* Public Base URL */} +
+
+ {t('admin.publicUrl.title', 'Public URL')} +
+
+ {t('admin.publicUrl.desc', 'The external URL used for SSO callbacks, subdomain generation, and email links. Include the protocol (e.g. https://example.com).')}
+
+ setPublicBaseUrl(e.target.value)} + placeholder="https://clawith.example.com" + style={{ fontSize: '13px', flex: 1, maxWidth: '400px' }} + /> + + {publicBaseUrlSaved && {t('enterprise.config.saved', 'Saved')}} +
+
+ {/* Notification Bar */}
From 9a65575d2990aa3d7afe0119af71dc7745c1f2bc Mon Sep 17 00:00:00 2001 From: "nap.liu" Date: Wed, 8 Apr 2026 20:12:04 +0800 Subject: [PATCH 3/5] feat: add generic OAuth2 SSO login with configurable field mapping Add a generic OAuth2AuthProvider that works with any OAuth2-compliant identity provider (Google, Azure AD, Keycloak, Auth0, custom corporate OAuth2 servers, etc.). Backend: - New OAuth2AuthProvider class with configurable authorize_url, token_url, userinfo_url, client_id, client_secret, scope, and field_mapping - Token exchange uses application/x-www-form-urlencoded (RFC 6749) - Graceful handling of userinfo 401/empty/invalid responses - Configurable field_mapping maps provider fields to Clawith fields (provider_user_id, email, display_name, mobile, avatar_url) - Standard OIDC field fallbacks when no custom mapping is configured - Provider registered in auth_registry as "oauth2" - SSO callback route (GET /auth/oauth2/callback) with session handling - OAuth2 provider type added to SSO config endpoint Frontend: - OAuth2 configuration form with Token URL, UserInfo URL, Scope fields - Field Mapping section for custom provider field names - Save/update via dedicated OAuth2 API endpoints Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/api/auth.py | 20 +- backend/app/api/enterprise.py | 16 +- backend/app/api/sso.py | 91 ++++ backend/app/schemas/schemas.py | 1 + backend/app/services/auth_provider.py | 149 +++++ backend/app/services/auth_registry.py | 1 + backend/app/services/dingtalk_stream.py | 636 +++++++++++++++++++--- frontend/src/i18n/en.json | 25 +- frontend/src/i18n/zh.json | 25 +- frontend/src/pages/EnterpriseSettings.tsx | 62 ++- 10 files changed, 945 insertions(+), 81 deletions(-) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 5c0cc4e27..26ed39cc6 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -916,14 +916,26 @@ async def oauth_callback( raise HTTPException(status_code=404, detail=f"Provider 
'{provider}' not supported") try: - # Exchange code for token - token_data = await auth_provider.exchange_code_for_token(data.code) + # Exchange code for token (pass redirect_uri for OAuth2 providers that require it) + if hasattr(auth_provider, 'exchange_code_for_token') and data.redirect_uri: + token_data = await auth_provider.exchange_code_for_token(data.code, redirect_uri=data.redirect_uri) + else: + token_data = await auth_provider.exchange_code_for_token(data.code) access_token = token_data.get("access_token") if not access_token: raise HTTPException(status_code=400, detail="Failed to get access token from provider") - # Get user info - user_info = await auth_provider.get_user_info(access_token) + # Get user info with fallback to token_data extraction + try: + user_info = await auth_provider.get_user_info(access_token) + except Exception: + if hasattr(auth_provider, 'get_user_info_from_token_data'): + user_info = await auth_provider.get_user_info_from_token_data(token_data) + else: + raise + if not user_info.provider_user_id and hasattr(auth_provider, 'get_user_info_from_token_data'): + # try token_data as last resort + user_info = await auth_provider.get_user_info_from_token_data(token_data) # Find or create user user, is_new = await auth_provider.find_or_create_user(db, user_info) diff --git a/backend/app/api/enterprise.py b/backend/app/api/enterprise.py index 25d98c802..05f29b1de 100644 --- a/backend/app/api/enterprise.py +++ b/backend/app/api/enterprise.py @@ -793,6 +793,7 @@ class OAuth2Config(BaseModel): token_url: str | None = None # OAuth2 token endpoint user_info_url: str | None = None # OAuth2 user info endpoint scope: str | None = "openid profile email" + field_mapping: dict | None = None # Custom field name mapping def to_config_dict(self) -> dict: """Convert to config dict with both naming conventions for compatibility.""" @@ -811,6 +812,8 @@ def to_config_dict(self) -> dict: config["user_info_url"] = self.user_info_url if self.scope: 
config["scope"] = self.scope + if self.field_mapping: + config["field_mapping"] = self.field_mapping return config @classmethod @@ -823,6 +826,7 @@ def from_config_dict(cls, config: dict) -> "OAuth2Config": token_url=config.get("token_url"), user_info_url=config.get("user_info_url"), scope=config.get("scope"), + field_mapping=config.get("field_mapping"), ) @@ -837,6 +841,7 @@ class IdentityProviderOAuth2Create(BaseModel): token_url: str user_info_url: str scope: str | None = "openid profile email" + field_mapping: dict | None = None tenant_id: uuid.UUID | None = None @@ -936,6 +941,7 @@ async def create_oauth2_provider( token_url=data.token_url, user_info_url=data.user_info_url, scope=data.scope, + field_mapping=data.field_mapping, ) config = oauth_config.to_config_dict() @@ -962,6 +968,7 @@ async def create_oauth2_provider( provider_type="oauth2", name=data.name, is_active=data.is_active, + sso_login_enabled=True, config=config, tenant_id=tid ) @@ -981,6 +988,7 @@ class OAuth2ConfigUpdate(BaseModel): token_url: str | None = None user_info_url: str | None = None scope: str | None = None + field_mapping: dict | None = None # Custom field name mapping @router.patch("/identity-providers/{provider_id}/oauth2", response_model=IdentityProviderOut) @@ -1009,7 +1017,7 @@ async def update_oauth2_provider( provider.is_active = data.is_active # Update config fields - if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope]): + if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope, data.field_mapping is not None]): current_config = provider.config.copy() if data.app_id is not None: @@ -1031,6 +1039,12 @@ async def update_oauth2_provider( current_config["user_info_url"] = data.user_info_url if data.scope is not None: current_config["scope"] = data.scope + if data.field_mapping is not None: + # Empty dict or explicit None clears the mapping + if 
data.field_mapping: + current_config["field_mapping"] = data.field_mapping + else: + current_config.pop("field_mapping", None) # Validate the updated config validate_provider_config("oauth2", current_config) diff --git a/backend/app/api/sso.py b/backend/app/api/sso.py index 940b5d7b4..a1f07e441 100644 --- a/backend/app/api/sso.py +++ b/backend/app/api/sso.py @@ -4,6 +4,7 @@ from urllib.parse import quote from fastapi import APIRouter, Depends, HTTPException, Request, status +from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -138,5 +139,95 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De url = f"https://open.work.weixin.qq.com/wwopen/sso/qrConnect?appid={corp_id}&agentid={agent_id}&redirect_uri={quote(redir)}&state={sid}" auth_urls.append({"provider_type": "wecom", "name": p.name, "url": url}) + elif p.provider_type == "oauth2": + from app.services.auth_registry import auth_provider_registry + auth_provider = await auth_provider_registry.get_provider( + db, "oauth2", str(session.tenant_id) if session.tenant_id else None + ) + if auth_provider: + redir = f"{public_base}/api/auth/oauth2/callback" + url = await auth_provider.get_authorization_url(redir, str(sid)) + auth_urls.append({"provider_type": "oauth2", "name": p.name, "url": url}) + return auth_urls + +@router.get("/auth/oauth2/callback") +async def oauth2_callback( + code: str, + state: str = None, + db: AsyncSession = Depends(get_db), +): + """Handle OAuth2 SSO callback -- exchange code for user session.""" + from app.core.security import create_access_token + from fastapi.responses import HTMLResponse + from app.services.auth_registry import auth_provider_registry + + # 1. 
Resolve tenant context from state (= session ID) + tenant_id = None + sid = None + if state: + try: + sid = uuid.UUID(state) + s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid)) + session = s_res.scalar_one_or_none() + if session: + tenant_id = session.tenant_id + except (ValueError, AttributeError): + pass + + # 2. Get OAuth2 provider + auth_provider = await auth_provider_registry.get_provider( + db, "oauth2", str(tenant_id) if tenant_id else None + ) + if not auth_provider: + return HTMLResponse("Auth failed: OAuth2 provider not configured") + + # 3. Exchange code -> token -> user info -> find/create user + try: + token_data = await auth_provider.exchange_code_for_token(code) + access_token = token_data.get("access_token") + if not access_token: + logger.error("OAuth2 token exchange returned no access_token") + return HTMLResponse("Auth failed: token exchange error") + + user_info = await auth_provider.get_user_info(access_token) + if not user_info.provider_user_id: + logger.error("OAuth2 user info missing user ID") + return HTMLResponse("Auth failed: no user ID returned") + + user, is_new = await auth_provider.find_or_create_user( + db, user_info, tenant_id=str(tenant_id) if tenant_id else None + ) + if not user: + return HTMLResponse("Auth failed: user resolution failed") + + except Exception as e: + logger.error("OAuth2 login error: %s", e) + return HTMLResponse(f"Auth failed: {e!s}") + + # 4. Generate JWT, update SSO session + token = create_access_token(str(user.id), user.role) + + if sid: + try: + s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid)) + session = s_res.scalar_one_or_none() + if session: + session.status = "authorized" + session.provider_type = "oauth2" + session.user_id = user.id + session.access_token = token + session.error_msg = None + await db.commit() + return HTMLResponse( + '' + '
SSO login successful. Redirecting...
' + f'' + '' + ) + except Exception as e: + logger.exception("Failed to update SSO session (oauth2): %s", e) + + return HTMLResponse("Logged in successfully.") + diff --git a/backend/app/schemas/schemas.py b/backend/app/schemas/schemas.py index 3870392b9..f0da50b62 100644 --- a/backend/app/schemas/schemas.py +++ b/backend/app/schemas/schemas.py @@ -180,6 +180,7 @@ class OAuthAuthorizeResponse(BaseModel): class OAuthCallbackRequest(BaseModel): code: str state: str + redirect_uri: str = "" class IdentityBindRequest(BaseModel): diff --git a/backend/app/services/auth_provider.py b/backend/app/services/auth_provider.py index d40cd583b..0d2543098 100644 --- a/backend/app/services/auth_provider.py +++ b/backend/app/services/auth_provider.py @@ -637,6 +637,154 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo: ) +class OAuth2AuthProvider(BaseAuthProvider): + """Generic OAuth2 provider (RFC 6749 Authorization Code flow). + + Works with any OAuth2-compliant identity provider (Google, Azure AD, + Keycloak, Auth0, custom corporate OAuth2 servers, etc.). 
+ + Config keys: + client_id, client_secret, authorize_url, token_url, userinfo_url, + scope, field_mapping + """ + + provider_type = "oauth2" + + def __init__(self, provider=None, config=None): + super().__init__(provider, config) + self.client_id = self.config.get("client_id") or self.config.get("app_id", "") + self.client_secret = self.config.get("client_secret") or self.config.get("app_secret", "") + self.authorize_url = self.config.get("authorize_url", "") + self.token_url = self.config.get("token_url", "") + self.userinfo_url = self.config.get("userinfo_url") or self.config.get("user_info_url", "") + self.scope = self.config.get("scope", "openid profile email") + + # Derive token_url / userinfo_url from authorize_url if not provided + if self.authorize_url and not self.token_url: + base = self.authorize_url.rsplit("/", 1)[0] + self.token_url = f"{base}/token" + if self.authorize_url and not self.userinfo_url: + base = self.authorize_url.rsplit("/", 1)[0] + self.userinfo_url = f"{base}/userinfo" + + # Configurable field mapping: provider response field -> Clawith field + self.field_mapping = self.config.get("field_mapping") or {} + + async def get_authorization_url(self, redirect_uri: str, state: str) -> str: + from urllib.parse import quote, urlencode + params = urlencode({ + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": redirect_uri, + "scope": self.scope, + "state": state, + }) + return f"{self.authorize_url}?{params}" + + async def exchange_code_for_token(self, code: str, redirect_uri: str = "") -> dict: + """Exchange authorization code for access token. + + Uses application/x-www-form-urlencoded per RFC 6749 Section 4.1.3. 
+ """ + async with httpx.AsyncClient() as client: + data = { + "grant_type": "authorization_code", + "code": code, + "client_id": self.client_id, + "client_secret": self.client_secret, + } + if redirect_uri: + data["redirect_uri"] = redirect_uri + resp = await client.post( + self.token_url, + data=data, + ) + if resp.status_code != 200: + logger.error( + "OAuth2 token exchange failed (HTTP %s): %s", + resp.status_code, + resp.text[:500], + ) + return {} + return resp.json() + + def _resolve_field(self, data: dict, clawith_field: str, default_keys: list[str]) -> str: + """Resolve a field using user-configured mapping first, then standard fallbacks.""" + custom_key = self.field_mapping.get(clawith_field) + if custom_key and data.get(custom_key): + return str(data[custom_key]) + for key in default_keys: + if data.get(key): + return str(data[key]) + return "" + + async def get_user_info(self, access_token: str) -> ExternalUserInfo: + async with httpx.AsyncClient() as client: + resp = await client.get( + self.userinfo_url, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + # Gracefully handle non-200 responses + if resp.status_code != 200: + logger.error("OAuth2 userinfo returned HTTP %s", resp.status_code) + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id="", + raw_data={"error": f"userinfo HTTP {resp.status_code}"}, + ) + + try: + resp_data = resp.json() + except Exception: + logger.error("OAuth2 userinfo response is not valid JSON") + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id="", + raw_data={"error": "userinfo JSON parse error"}, + ) + + # Some providers wrap payload in {"data": {...}}; standard OIDC is flat + if "data" in resp_data and isinstance(resp_data["data"], dict): + info = resp_data["data"] + else: + info = resp_data + + user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId"]) + name = self._resolve_field(info, "display_name", ["name", 
"preferred_username", "nickname", "userName"]) + email = self._resolve_field(info, "email", ["email"]) + mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"]) + avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"]) + + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id=str(user_id), + name=name, + email=email, + mobile=mobile, + avatar_url=avatar, + raw_data=info, + ) + + async def get_user_info_from_token_data(self, token_data: dict) -> ExternalUserInfo: + """Fallback: extract user info directly from token response data.""" + info = token_data + user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId", "openid"]) + name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname"]) + email = self._resolve_field(info, "email", ["email"]) + mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"]) + avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"]) + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id=str(user_id), + name=name, + email=email, + mobile=mobile, + avatar_url=avatar, + raw_data=info, + ) + + class MicrosoftTeamsAuthProvider(BaseAuthProvider): """Microsoft Teams OAuth provider implementation.""" @@ -658,5 +806,6 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo: "feishu": FeishuAuthProvider, "dingtalk": DingTalkAuthProvider, "wecom": WeComAuthProvider, + "oauth2": OAuth2AuthProvider, "microsoft_teams": MicrosoftTeamsAuthProvider, } diff --git a/backend/app/services/auth_registry.py b/backend/app/services/auth_registry.py index c3e9b0035..3c655bade 100644 --- a/backend/app/services/auth_registry.py +++ b/backend/app/services/auth_registry.py @@ -15,6 +15,7 @@ DingTalkAuthProvider, FeishuAuthProvider, MicrosoftTeamsAuthProvider, + OAuth2AuthProvider, WeComAuthProvider, ) diff --git 
a/backend/app/services/dingtalk_stream.py b/backend/app/services/dingtalk_stream.py index 28a8ba8e2..73cf80b20 100644 --- a/backend/app/services/dingtalk_stream.py +++ b/backend/app/services/dingtalk_stream.py @@ -5,15 +5,401 @@ """ import asyncio +import base64 +import json import threading import uuid -from typing import Dict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import httpx from loguru import logger from sqlalchemy import select +from app.config import get_settings from app.database import async_session from app.models.channel_config import ChannelConfig +from app.services.dingtalk_token import dingtalk_token_manager + + +# ─── DingTalk Media Helpers ───────────────────────────── + + +async def _get_media_download_url( + access_token: str, download_code: str, robot_code: str +) -> Optional[str]: + """Get media file download URL from DingTalk API.""" + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.post( + "https://api.dingtalk.com/v1.0/robot/messageFiles/download", + headers={"x-acs-dingtalk-access-token": access_token}, + json={"downloadCode": download_code, "robotCode": robot_code}, + ) + data = resp.json() + url = data.get("downloadUrl") + if url: + return url + logger.error(f"[DingTalk] Failed to get download URL: {data}") + return None + except Exception as e: + logger.error(f"[DingTalk] Error getting download URL: {e}") + return None + + +async def _download_file(url: str) -> Optional[bytes]: + """Download a file from a URL and return its bytes.""" + try: + async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client: + resp = await client.get(url) + resp.raise_for_status() + return resp.content + except Exception as e: + logger.error(f"[DingTalk] Error downloading file: {e}") + return None + + +async def download_dingtalk_media( + app_key: str, app_secret: str, download_code: str +) -> Optional[bytes]: + """Download a media file from DingTalk using downloadCode. 
+ + Steps: get access_token -> get download URL -> download file bytes. + """ + access_token = await dingtalk_token_manager.get_token(app_key, app_secret) + if not access_token: + return None + + download_url = await _get_media_download_url(access_token, download_code, app_key) + if not download_url: + return None + + return await _download_file(download_url) + + +def _resolve_upload_dir(agent_id: uuid.UUID) -> Path: + """Get the uploads directory for an agent, creating it if needed.""" + settings = get_settings() + upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" + upload_dir.mkdir(parents=True, exist_ok=True) + return upload_dir + + +async def _process_media_message( + msg_data: dict, + app_key: str, + app_secret: str, + agent_id: uuid.UUID, +) -> Tuple[str, Optional[List[str]], Optional[List[str]]]: + """Process a DingTalk message and extract text + media info. + + Returns: + (user_text, image_base64_list, saved_file_paths) + - user_text: text content for the LLM (may include markers) + - image_base64_list: list of base64-encoded image data URIs, or None + - saved_file_paths: list of saved file paths, or None + """ + msgtype = msg_data.get("msgtype", "text") + logger.info(f"[DingTalk] Processing message type: {msgtype}") + + image_base64_list: List[str] = [] + saved_file_paths: List[str] = [] + + if msgtype == "text": + text_content = msg_data.get("text", {}).get("content", "").strip() + return text_content, None, None + + elif msgtype == "picture": + download_code = msg_data.get("content", {}).get("downloadCode", "") + if not download_code: + download_code = msg_data.get("downloadCode", "") + if not download_code: + logger.warning("[DingTalk] Picture message without downloadCode") + return "[User sent an image, but it could not be downloaded]", None, None + + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if not file_bytes: + return "[User sent an image, but download failed]", None, None + + 
upload_dir = _resolve_upload_dir(agent_id) + filename = f"dingtalk_img_{uuid.uuid4().hex[:8]}.jpg" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved image to {save_path} ({len(file_bytes)} bytes)") + + b64_data = base64.b64encode(file_bytes).decode("ascii") + image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" + return ( + f"[User sent an image]\n{image_marker}", + [f"data:image/jpeg;base64,{b64_data}"], + [str(save_path)], + ) + + elif msgtype == "richText": + rich_text = msg_data.get("content", {}).get("richText", []) + text_parts: List[str] = [] + + for section in rich_text: + for item in section if isinstance(section, list) else [section]: + if "text" in item: + text_parts.append(item["text"]) + elif "downloadCode" in item: + file_bytes = await download_dingtalk_media( + app_key, app_secret, item["downloadCode"] + ) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + filename = f"dingtalk_richimg_{uuid.uuid4().hex[:8]}.jpg" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved rich text image to {save_path}") + + b64_data = base64.b64encode(file_bytes).decode("ascii") + image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" + text_parts.append(image_marker) + image_base64_list.append(f"data:image/jpeg;base64,{b64_data}") + saved_file_paths.append(str(save_path)) + + combined_text = "\n".join(text_parts).strip() + if not combined_text: + combined_text = "[User sent a rich text message]" + + return ( + combined_text, + image_base64_list if image_base64_list else None, + saved_file_paths if saved_file_paths else None, + ) + + elif msgtype == "audio": + content = msg_data.get("content", {}) + recognition = content.get("recognition", "") + if recognition: + logger.info(f"[DingTalk] Audio with recognition: {recognition[:80]}") + return f"[Voice message] {recognition}", None, None + + download_code = content.get("downloadCode", 
"") + if download_code: + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + duration = content.get("duration", "unknown") + filename = f"dingtalk_audio_{uuid.uuid4().hex[:8]}.amr" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved audio to {save_path} ({len(file_bytes)} bytes)") + return ( + f"[User sent a voice message, duration {duration}ms, saved to {filename}]", + None, + [str(save_path)], + ) + return "[User sent a voice message, but it could not be processed]", None, None + + elif msgtype == "video": + content = msg_data.get("content", {}) + download_code = content.get("downloadCode", "") + if download_code: + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + duration = content.get("duration", "unknown") + filename = f"dingtalk_video_{uuid.uuid4().hex[:8]}.mp4" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved video to {save_path} ({len(file_bytes)} bytes)") + return ( + f"[User sent a video, duration {duration}ms, saved to {filename}]", + None, + [str(save_path)], + ) + return "[User sent a video, but it could not be downloaded]", None, None + + elif msgtype == "file": + content = msg_data.get("content", {}) + download_code = content.get("downloadCode", "") + original_filename = content.get("fileName", "unknown_file") + if download_code: + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + safe_name = f"dingtalk_{uuid.uuid4().hex[:8]}_{original_filename}" + save_path = upload_dir / safe_name + save_path.write_bytes(file_bytes) + logger.info( + f"[DingTalk] Saved file '{original_filename}' to {save_path} " + f"({len(file_bytes)} bytes)" + ) + return ( + 
f"[file:{original_filename}]", + None, + [str(save_path)], + ) + return f"[User sent file {original_filename}, but it could not be downloaded]", None, None + + else: + logger.warning(f"[DingTalk] Unsupported message type: {msgtype}") + return f"[User sent a {msgtype} message, which is not yet supported]", None, None + + +# ─── DingTalk Media Upload & Send ─────────────────────── + +async def _upload_dingtalk_media( + app_key: str, + app_secret: str, + file_path: str, + media_type: str = "file", +) -> Optional[str]: + """Upload a media file to DingTalk and return the mediaId. + + Args: + app_key: DingTalk app key (robotCode). + app_secret: DingTalk app secret. + file_path: Local file path to upload. + media_type: One of 'image', 'voice', 'video', 'file'. + + Returns: + mediaId string on success, None on failure. + """ + access_token = await dingtalk_token_manager.get_token(app_key, app_secret) + if not access_token: + return None + + file_p = Path(file_path) + if not file_p.exists(): + logger.error(f"[DingTalk] Upload failed: file not found: {file_path}") + return None + + try: + file_bytes = file_p.read_bytes() + async with httpx.AsyncClient(timeout=60) as client: + # Use the legacy oapi endpoint which is more reliable and widely supported. + upload_url = ( + f"https://oapi.dingtalk.com/media/upload" + f"?access_token={access_token}&type={media_type}" + ) + resp = await client.post( + upload_url, + files={"media": (file_p.name, file_bytes)}, + ) + data = resp.json() + # Legacy API returns media_id (snake_case), new API returns mediaId + media_id = data.get("media_id") or data.get("mediaId") + if media_id and data.get("errcode", 0) == 0: + logger.info( + f"[DingTalk] Uploaded {media_type} '{file_p.name}' -> mediaId={media_id[:20]}..." 
+ ) + return media_id + logger.error(f"[DingTalk] Upload failed: {data}") + return None + except Exception as e: + logger.error(f"[DingTalk] Upload error: {e}") + return None + + +async def _send_dingtalk_media_message( + app_key: str, + app_secret: str, + target_id: str, + media_id: str, + media_type: str, + conversation_type: str, + filename: Optional[str] = None, +) -> bool: + """Send a media message via DingTalk proactive message API. + + Args: + app_key: DingTalk app key (robotCode). + app_secret: DingTalk app secret. + target_id: For P2P: sender_staff_id; For group: openConversationId. + media_id: The mediaId from upload. + media_type: One of 'image', 'voice', 'video', 'file'. + conversation_type: '1' for P2P, '2' for group. + filename: Original filename (used for file/video types). + + Returns: + True on success, False on failure. + """ + access_token = await dingtalk_token_manager.get_token(app_key, app_secret) + if not access_token: + return False + + headers = {"x-acs-dingtalk-access-token": access_token} + + # Build msgKey and msgParam based on media_type + if media_type == "image": + msg_key = "sampleImageMsg" + msg_param = json.dumps({"photoURL": media_id}) + elif media_type == "voice": + msg_key = "sampleAudio" + msg_param = json.dumps({"mediaId": media_id, "duration": "3000"}) + elif media_type == "video": + safe_name = filename or "video.mp4" + ext = Path(safe_name).suffix.lstrip(".") or "mp4" + msg_key = "sampleFile" + msg_param = json.dumps({ + "mediaId": media_id, + "fileName": safe_name, + "fileType": ext, + }) + else: + # file + safe_name = filename or "file" + ext = Path(safe_name).suffix.lstrip(".") or "bin" + msg_key = "sampleFile" + msg_param = json.dumps({ + "mediaId": media_id, + "fileName": safe_name, + "fileType": ext, + }) + + try: + async with httpx.AsyncClient(timeout=15) as client: + if conversation_type == "2": + # Group chat + resp = await client.post( + "https://api.dingtalk.com/v1.0/robot/groupMessages/send", + headers=headers, 
+ json={ + "robotCode": app_key, + "openConversationId": target_id, + "msgKey": msg_key, + "msgParam": msg_param, + }, + ) + else: + # P2P chat + resp = await client.post( + "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend", + headers=headers, + json={ + "robotCode": app_key, + "userIds": [target_id], + "msgKey": msg_key, + "msgParam": msg_param, + }, + ) + + data = resp.json() + if resp.status_code >= 400 or data.get("errcode"): + logger.error(f"[DingTalk] Send media failed: {data}") + return False + + logger.info( + f"[DingTalk] Sent {media_type} message to {target_id[:16]}... " + f"(conv_type={conversation_type})" + ) + return True + except Exception as e: + logger.error(f"[DingTalk] Send media error: {e}") + return False + + +# ─── Stream Manager ───────────────────────────────────── + + +def _fire_and_forget(loop, coro): + """Schedule a coroutine on the main loop and log any unhandled exception.""" + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.add_done_callback(lambda f: f.exception() if not f.cancelled() else None) class DingTalkStreamManager: @@ -67,47 +453,72 @@ def _run_client_thread( app_secret: str, stop_event: threading.Event, ): - """Run the DingTalk Stream client in a blocking thread.""" + """Run the DingTalk Stream client with auto-reconnect.""" try: import dingtalk_stream + except ImportError: + logger.warning( + "[DingTalk Stream] dingtalk-stream package not installed. " + "Install with: pip install dingtalk-stream" + ) + self._threads.pop(agent_id, None) + self._stop_events.pop(agent_id, None) + return - # Reference to manager's main loop for async dispatch - main_loop = self._main_loop - - class ClawithChatbotHandler(dingtalk_stream.ChatbotHandler): - """Custom handler that dispatches messages to the Clawith LLM pipeline.""" - - async def process(self, callback: dingtalk_stream.CallbackMessage): - """Handle incoming bot message from DingTalk Stream. 
- - NOTE: The SDK invokes this method in the thread's own asyncio loop, - so we must dispatch to the main FastAPI loop for DB + LLM work. - """ - try: - # Parse the raw data into a ChatbotMessage via class method - incoming = dingtalk_stream.ChatbotMessage.from_dict(callback.data) - - # Extract text content + MAX_RETRIES = 5 + RETRY_DELAYS = [2, 5, 15, 30, 60] # exponential backoff, seconds + + # Reference to manager's main loop for async dispatch + main_loop = self._main_loop + retries = 0 + manager_self = self + + class ClawithChatbotHandler(dingtalk_stream.ChatbotHandler): + """Custom handler that dispatches messages to the Clawith LLM pipeline.""" + + async def process(self, callback: dingtalk_stream.CallbackMessage): + """Handle incoming bot message from DingTalk Stream. + + NOTE: The SDK invokes this method in the thread's own asyncio loop, + so we must dispatch to the main FastAPI loop for DB + LLM work. + """ + try: + # Parse the raw data + incoming = dingtalk_stream.ChatbotMessage.from_dict(callback.data) + msg_data = callback.data if isinstance(callback.data, dict) else json.loads(callback.data) + + msgtype = msg_data.get("msgtype", "text") + sender_staff_id = incoming.sender_staff_id or incoming.sender_id or "" + sender_nick = incoming.sender_nick or "" + message_id = incoming.message_id or "" + conversation_id = incoming.conversation_id or "" + conversation_type = incoming.conversation_type or "1" + session_webhook = incoming.session_webhook or "" + + logger.info( + f"[DingTalk Stream] Received {msgtype} message from {sender_staff_id}" + ) + + if msgtype == "text": + # Plain text: use existing logic text_list = incoming.get_text_list() user_text = " ".join(text_list).strip() if text_list else "" - if not user_text: return dingtalk_stream.AckMessage.STATUS_OK, "empty message" - sender_staff_id = incoming.sender_staff_id or incoming.sender_id or "" - conversation_id = incoming.conversation_id or "" - conversation_type = incoming.conversation_type or "1" - 
session_webhook = incoming.session_webhook or "" - logger.info( - f"[DingTalk Stream] Message from [{incoming.sender_nick}]{sender_staff_id}: {user_text[:80]}" + f"[DingTalk Stream] Text from {sender_staff_id}: {user_text[:80]}" ) - # Dispatch to the main FastAPI event loop for DB + LLM processing from app.api.dingtalk import process_dingtalk_message if main_loop and main_loop.is_running(): - future = asyncio.run_coroutine_threadsafe( + # Add thinking reaction immediately + from app.services.dingtalk_reaction import add_thinking_reaction + _fire_and_forget(main_loop, + add_thinking_reaction(app_key, app_secret, message_id, conversation_id)) + + _fire_and_forget(main_loop, process_dingtalk_message( agent_id=agent_id, sender_staff_id=sender_staff_id, @@ -115,50 +526,134 @@ async def process(self, callback: dingtalk_stream.CallbackMessage): conversation_id=conversation_id, conversation_type=conversation_type, session_webhook=session_webhook, - ), - main_loop, - ) - # Wait for result (with timeout) - try: - future.result(timeout=120) - except Exception as e: - logger.error(f"[DingTalk Stream] LLM processing error: {e}") - import traceback - traceback.print_exc() + sender_nick=sender_nick, + message_id=message_id, + )) else: - logger.warning("[DingTalk Stream] Main loop not available for dispatch") - - return dingtalk_stream.AckMessage.STATUS_OK, "ok" - except Exception as e: - logger.error(f"[DingTalk Stream] Error in message handler: {e}") - import traceback - traceback.print_exc() - return dingtalk_stream.AckMessage.STATUS_SYSTEM_EXCEPTION, str(e) - - credential = dingtalk_stream.Credential(client_id=app_key, client_secret=app_secret) - client = dingtalk_stream.DingTalkStreamClient(credential=credential) - client.register_callback_handler( - dingtalk_stream.chatbot.ChatbotMessage.TOPIC, - ClawithChatbotHandler(), - ) + logger.warning("[DingTalk Stream] Main loop not available") - logger.info(f"[DingTalk Stream] Connecting for agent {agent_id}...") - # start_forever() 
blocks until disconnected - client.start_forever() + else: + # Non-text message: process media in the main loop + if main_loop and main_loop.is_running(): + # Add thinking reaction immediately + from app.services.dingtalk_reaction import add_thinking_reaction + _fire_and_forget(main_loop, + add_thinking_reaction(app_key, app_secret, message_id, conversation_id)) + + _fire_and_forget(main_loop, + manager_self._handle_media_and_dispatch( + msg_data=msg_data, + app_key=app_key, + app_secret=app_secret, + agent_id=agent_id, + sender_staff_id=sender_staff_id, + conversation_id=conversation_id, + conversation_type=conversation_type, + session_webhook=session_webhook, + sender_nick=sender_nick, + message_id=message_id, + )) + else: + logger.warning("[DingTalk Stream] Main loop not available") + + return dingtalk_stream.AckMessage.STATUS_OK, "ok" + except Exception as e: + logger.error(f"[DingTalk Stream] Error in message handler: {e}") + import traceback + traceback.print_exc() + return dingtalk_stream.AckMessage.STATUS_SYSTEM_EXCEPTION, str(e) + + while not stop_event.is_set() and retries <= MAX_RETRIES: + try: + credential = dingtalk_stream.Credential(client_id=app_key, client_secret=app_secret) + client = dingtalk_stream.DingTalkStreamClient(credential=credential) + client.register_callback_handler( + dingtalk_stream.chatbot.ChatbotMessage.TOPIC, + ClawithChatbotHandler(), + ) - except ImportError: - logger.warning( - "[DingTalk Stream] dingtalk-stream package not installed. " - "Install with: pip install dingtalk-stream" + logger.info( + f"[DingTalk Stream] Connecting for agent {agent_id}... 
" + f"(attempt {retries + 1}/{MAX_RETRIES + 1})" + ) + # start_forever() blocks until disconnected + client.start_forever() + + # start_forever returned: connection dropped + if stop_event.is_set(): + break # intentional stop, no retry + + # Reset retries on successful connection (ran for a while then disconnected) + retries = 0 + retries += 1 + logger.warning( + f"[DingTalk Stream] Connection lost for agent {agent_id}, will retry..." + ) + + except Exception as e: + retries += 1 + logger.error( + f"[DingTalk Stream] Connection error for {agent_id} " + f"(attempt {retries}/{MAX_RETRIES + 1}): {e}" + ) + + if retries > MAX_RETRIES: + logger.error( + f"[DingTalk Stream] Agent {agent_id} exhausted all {MAX_RETRIES} retries, giving up" + ) + break + + delay = RETRY_DELAYS[min(retries - 1, len(RETRY_DELAYS) - 1)] + logger.info( + f"[DingTalk Stream] Retrying in {delay}s for agent {agent_id}..." ) - except Exception as e: - logger.error(f"[DingTalk Stream] Client error for {agent_id}: {e}") - import traceback - traceback.print_exc() - finally: - self._threads.pop(agent_id, None) - self._stop_events.pop(agent_id, None) - logger.info(f"[DingTalk Stream] Client stopped for agent {agent_id}") + # Use stop_event.wait so we exit immediately if stopped + if stop_event.wait(timeout=delay): + break # stop was requested during wait + + self._threads.pop(agent_id, None) + self._stop_events.pop(agent_id, None) + logger.info(f"[DingTalk Stream] Client stopped for agent {agent_id}") + + @staticmethod + async def _handle_media_and_dispatch( + msg_data: dict, + app_key: str, + app_secret: str, + agent_id: uuid.UUID, + sender_staff_id: str, + conversation_id: str, + conversation_type: str, + session_webhook: str, + sender_nick: str = "", + message_id: str = "", + ): + """Download media, then dispatch to process_dingtalk_message.""" + from app.api.dingtalk import process_dingtalk_message + + user_text, image_base64_list, saved_file_paths = await _process_media_message( + 
msg_data=msg_data, + app_key=app_key, + app_secret=app_secret, + agent_id=agent_id, + ) + + if not user_text: + logger.info("[DingTalk Stream] Empty content after media processing, skipping") + return + + await process_dingtalk_message( + agent_id=agent_id, + sender_staff_id=sender_staff_id, + user_text=user_text, + conversation_id=conversation_id, + conversation_type=conversation_type, + session_webhook=session_webhook, + image_base64_list=image_base64_list, + saved_file_paths=saved_file_paths, + sender_nick=sender_nick, + message_id=message_id, + ) async def stop_client(self, agent_id: uuid.UUID): """Stop a running Stream client for an agent.""" @@ -167,7 +662,10 @@ async def stop_client(self, agent_id: uuid.UUID): stop_event.set() thread = self._threads.pop(agent_id, None) if thread and thread.is_alive(): - logger.info(f"[DingTalk Stream] Stopping client for agent {agent_id}") + logger.info(f"[DingTalk Stream] Stopping client for agent {agent_id}, waiting for thread...") + thread.join(timeout=5) + if thread.is_alive(): + logger.warning(f"[DingTalk Stream] Thread for {agent_id} did not exit within 5s") async def start_all(self): """Start Stream clients for all configured DingTalk agents.""" diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index d01fbfbef..7a417ae2c 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -1359,7 +1359,30 @@ "messagingTitle": "3. Proactive 1-to-1 Messaging (AI-initiated)", "messagingDesc": "Sending messages to individual users requires a valid WeCom API access token, which can only be obtained from a server IP that has been whitelisted in the self-built app's settings. Unlike Feishu, WeCom mandates IP-level restrictions on all API calls — there is no token-only authentication option.", "footerText": "Due to the above constraints, WeCom integration currently cannot be easily set up by most users. 
We are actively exploring alternative approaches — including WeCom ISV (service provider) registration and lower-friction API options — or we may advocate for WeCom to relax these restrictions for SaaS platforms. Configuration will be re-enabled once a viable path is available." - } + }, + "avatarUrlField": "Avatar URL Field", + "avatarUrlFieldPlaceholder": "Default: picture", + "authorizeUrlPlaceholder": "https://sso.example.com/oauth2/authorize", + "clearAllMappings": "Clear All Field Mappings", + "clearField": "Clear this field", + "deleteConfirmProvider": "Are you sure you want to delete this configuration?", + "emailField": "Email Field", + "emailFieldPlaceholder": "Default: email", + "fieldMapping": "Field Mapping", + "fieldMappingHint": "Optional, leave empty to use standard OIDC fields", + "hasCustomMapping": "Custom field mapping configured", + "mobileField": "Mobile Field", + "mobileFieldPlaceholder": "Default: phone_number", + "nameField": "Name Field", + "nameFieldPlaceholder": "Default: name", + "oauth2": "OAuth2", + "oauth2Desc": "Generic OIDC Provider", + "scope": "Scope", + "scopePlaceholder": "openid profile email", + "tokenUrlPlaceholder": "Leave empty to auto-derive from Authorize URL", + "userIdField": "User ID Field", + "userIdFieldPlaceholder": "Default: sub", + "userInfoUrlPlaceholder": "Leave empty to auto-derive from Authorize URL" }, "dangerZone": "️ Danger Zone", "deleteCompanyDesc": "Permanently delete this company and all its data, including agents, models, tools, and skills. This action cannot be undone.", diff --git a/frontend/src/i18n/zh.json b/frontend/src/i18n/zh.json index 465cc21b6..956f784f7 100644 --- a/frontend/src/i18n/zh.json +++ b/frontend/src/i18n/zh.json @@ -1525,7 +1525,30 @@ "messagingTitle": "3. 
AI 主动发送一对一消息", "messagingDesc": "向企微成员主动发消息,需先获取有效的 API access_token,而获取 token 的请求 IP 必须在自建应用的「企业可信IP」白名单中。与飞书不同,企微对所有 API 调用均强制要求 IP 白名单,没有仅凭 token 认证的选项。", "footerText": "由于上述限制,目前大多数用户无法轻松完成企业微信集成配置。我们正在积极探索替代方案——包括申请企微服务商(ISV)资质及寻找限制更少的接入方式——同时也希望企微能够降低对 SaaS 平台的接入门槛。待可行路径明确后,配置入口将重新开启。" - } + }, + "avatarUrlField": "头像地址字段", + "avatarUrlFieldPlaceholder": "默认: picture", + "authorizeUrlPlaceholder": "https://sso.example.com/oauth2/authorize", + "clearAllMappings": "清空所有字段映射", + "clearField": "清空此字段", + "deleteConfirmProvider": "确定要删除此配置吗?", + "emailField": "邮箱字段", + "emailFieldPlaceholder": "默认:email", + "fieldMapping": "字段映射", + "fieldMappingHint": "可选,留空使用标准 OIDC 字段", + "hasCustomMapping": "当前已配置自定义字段映射", + "mobileField": "手机号字段", + "mobileFieldPlaceholder": "默认:phone_number", + "nameField": "姓名字段", + "nameFieldPlaceholder": "默认:name", + "oauth2": "OAuth2", + "oauth2Desc": "通用 OIDC 提供商", + "scope": "Scope", + "scopePlaceholder": "openid profile email", + "tokenUrlPlaceholder": "留空则从授权地址自动推导", + "userIdField": "用户 ID 字段", + "userIdFieldPlaceholder": "默认:sub", + "userInfoUrlPlaceholder": "留空则从授权地址自动推导" } }, "common": { diff --git a/frontend/src/pages/EnterpriseSettings.tsx b/frontend/src/pages/EnterpriseSettings.tsx index f9ade6332..7d5ec4956 100644 --- a/frontend/src/pages/EnterpriseSettings.tsx +++ b/frontend/src/pages/EnterpriseSettings.tsx @@ -419,7 +419,8 @@ function OrgTab({ tenant }: { tenant: any }) { authorize_url: '', token_url: '', user_info_url: '', - scope: 'openid profile email' + scope: 'openid profile email', + field_mapping: {} as Record, }); const currentTenantId = localStorage.getItem('current_tenant_id') || ''; @@ -522,7 +523,8 @@ function OrgTab({ tenant }: { tenant: any }) { authorize_url: config?.authorize_url || '', token_url: config?.token_url || '', user_info_url: config?.user_info_url || '', - scope: config?.scope || 'openid profile email' + scope: config?.scope || 'openid profile email', + field_mapping: config?.field_mapping || {}, 
}); const save = () => { @@ -565,7 +567,8 @@ function OrgTab({ tenant }: { tenant: any }) { name: nameMap[type] || type, config: defaults[type] || {}, app_id: '', app_secret: '', authorize_url: '', token_url: '', user_info_url: '', - scope: 'openid profile email' + scope: 'openid profile email', + field_mapping: {}, }); } setSelectedDept(null); @@ -650,8 +653,57 @@ function OrgTab({ tenant }: { tenant: any }) { setForm({ ...form, app_secret: e.target.value })} />
- - setForm({ ...form, authorize_url: e.target.value })} /> + + setForm({ ...form, authorize_url: e.target.value })} placeholder={t('enterprise.identity.authorizeUrlPlaceholder', 'https://sso.example.com/oauth2/authorize')} /> +
+
+ + setForm({ ...form, token_url: e.target.value })} placeholder={t('enterprise.identity.tokenUrlPlaceholder', 'Leave empty to auto-derive from Authorize URL')} /> +
+
+ + setForm({ ...form, user_info_url: e.target.value })} placeholder={t('enterprise.identity.userInfoUrlPlaceholder', 'Leave empty to auto-derive from Authorize URL')} /> +
+
+ + setForm({ ...form, scope: e.target.value })} placeholder={t('enterprise.identity.scopePlaceholder', 'openid profile email')} /> +
+
+
+ {t('enterprise.identity.fieldMapping', 'Field Mapping')} ({t('enterprise.identity.fieldMappingHint', 'Optional, leave empty to use standard OIDC fields')}) +
+
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, provider_user_id: e.target.value } })} + placeholder={t('enterprise.identity.userIdFieldPlaceholder', 'Default: sub')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, display_name: e.target.value } })} + placeholder={t('enterprise.identity.nameFieldPlaceholder', 'Default: name')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, email: e.target.value } })} + placeholder={t('enterprise.identity.emailFieldPlaceholder', 'Default: email')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, mobile: e.target.value } })} + placeholder={t('enterprise.identity.mobileFieldPlaceholder', 'Default: phone_number')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, avatar_url: e.target.value } })} + placeholder={t('enterprise.identity.avatarUrlFieldPlaceholder', 'Default: picture')} style={{ fontSize: '12px' }} /> +
+
) : type === 'wecom' ? ( From ee951298ac64acd7e6744a6a313b83f0435b1282 Mon Sep 17 00:00:00 2001 From: "nap.liu" Date: Wed, 8 Apr 2026 20:07:16 +0800 Subject: [PATCH 4/5] feat: inject platform base_url into agent system prompt and webhook URLs Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/services/agent_context.py | 20 ++++++++++++++++++++ backend/app/services/agent_tools.py | 16 +++++++++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/backend/app/services/agent_context.py b/backend/app/services/agent_context.py index 84ebaece2..62e0e61fc 100644 --- a/backend/app/services/agent_context.py +++ b/backend/app/services/agent_context.py @@ -564,4 +564,24 @@ async def build_agent_context(agent_id: uuid.UUID, agent_name: str, role_descrip if current_user_name: dynamic_parts.append(f"\n## Current Conversation\nYou are currently chatting with **{current_user_name}**. Address them by name when appropriate.") + # Inject platform base URL so agent knows where it is deployed + try: + from app.services.platform_service import platform_service + _platform_url = (await platform_service.get_public_base_url()).rstrip("/") + platform_lines = [ + "\n## Platform Base URLs", + "You are running on the Clawith platform. 
Always use these URLs exactly -- never guess or invent domain names.", + "", + "- **Platform base**: " + _platform_url, + "- **Webhook**: " + _platform_url + "/api/webhooks/t/ (replace with actual trigger token)", + "- **Public page**: " + _platform_url + "/p/ (replace with actual page id returned by publish_page)", + "- **File download**: " + _platform_url + "/api/agents//files/download?path=", + "- **Gateway poll**: " + _platform_url + "/api/gateway/poll (used by external agents to check inbox)", + "", + "Never use placeholder domains (clawith.com, try.clawith.ai, webhook.clawith.com, api.clawith.ai, etc.).", + ] + dynamic_parts.append("\n".join(platform_lines)) + except Exception: + pass + return "\n".join(static_parts), "\n".join(dynamic_parts) diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index f1880d172..5976494a5 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -5544,6 +5544,7 @@ async def _handle_cancel_trigger(agent_id: uuid.UUID, arguments: dict) -> str: async def _handle_list_triggers(agent_id: uuid.UUID) -> str: """List all active triggers for the agent.""" from app.models.trigger import AgentTrigger + from app.services.platform_service import platform_service try: async with async_session() as db: @@ -5554,15 +5555,24 @@ async def _handle_list_triggers(agent_id: uuid.UUID) -> str: ) triggers = result.scalars().all() + # Resolve base URL for webhook triggers + _base_url = (await platform_service.get_public_base_url(db)).rstrip("/") + if not triggers: return "No triggers found. Use set_trigger to create one." 
- lines = ["| Name | Type | Config | Reason | Status | Fires |", "|------|------|--------|--------|--------|-------|"] + lines = ["| Name | Type | Config | Webhook URL | Reason | Status | Fires |", "|------|------|--------|-------------|--------|--------|-------|"] for t in triggers: status = "✅ active" if t.is_enabled else "⏸ disabled" - config_str = str(t.config)[:50] + config = t.config or {} + if t.type == "webhook" and config.get("token"): + config_str = f"token: {config['token']}" + webhook_url = f"{_base_url}/api/webhooks/t/{config['token']}" + else: + config_str = str(config)[:50] + webhook_url = "-" reason_str = t.reason[:40] if t.reason else "" - lines.append(f"| {t.name} | {t.type} | {config_str} | {reason_str} | {status} | {t.fire_count} |") + lines.append(f"| {t.name} | {t.type} | {config_str} | {webhook_url} | {reason_str} | {status} | {t.fire_count} |") return "\n".join(lines) From 19137d24be63c18d20146ceed32e212aa5d117c5 Mon Sep 17 00:00:00 2001 From: "nap.liu" Date: Wed, 8 Apr 2026 19:57:44 +0800 Subject: [PATCH 5/5] refactor: migrate in-memory token caches to Redis with memory fallback Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/api/teams.py | 41 +++++++-------- backend/app/api/wecom.py | 44 ++++++++++++---- backend/app/core/token_cache.py | 64 ++++++++++++++++++++++++ backend/app/services/auth_provider.py | 24 +++++++-- backend/app/services/feishu_service.py | 30 ++++++++--- backend/app/services/org_sync_adapter.py | 40 ++++++++++++--- 6 files changed, 194 insertions(+), 49 deletions(-) create mode 100644 backend/app/core/token_cache.py diff --git a/backend/app/api/teams.py b/backend/app/api/teams.py index 09368ebe1..9c2b0e63f 100644 --- a/backend/app/api/teams.py +++ b/backend/app/api/teams.py @@ -37,42 +37,44 @@ TEAMS_MSG_LIMIT = 28000 # Teams message char limit (approx 28KB) -# In-memory cache for OAuth tokens -_teams_tokens: dict[str, dict] = {} # agent_id -> {access_token, expires_at} - async def 
_get_teams_access_token(config: ChannelConfig) -> str | None: """Get or refresh Microsoft Teams access token. - + Supports: - Client credentials (app_id + app_secret) - default - Managed Identity (when use_managed_identity is True in extra_config) + Token is cached in Redis (preferred) with in-memory fallback. + Key: clawith:token:teams:{agent_id} """ + from app.core.token_cache import get_cached_token, set_cached_token + agent_id = str(config.agent_id) - cached = _teams_tokens.get(agent_id) - if cached and cached["expires_at"] > time.time() + 60: # Refresh 60s before expiry + cache_key = f"clawith:token:teams:{agent_id}" + + cached = await get_cached_token(cache_key) + if cached: logger.debug(f"Teams: Using cached access token for agent {agent_id}") - return cached["access_token"] + return cached # Check if managed identity should be used use_managed_identity = config.extra_config.get("use_managed_identity", False) - + if use_managed_identity: # Use Azure Managed Identity try: from azure.identity.aio import DefaultAzureCredential from azure.core.credentials import AccessToken - + credential = DefaultAzureCredential() # For Bot Framework, we need the token for the Bot Framework API # Managed identity needs to be granted permissions to the Bot Framework API scope = "https://api.botframework.com/.default" token: AccessToken = await credential.get_token(scope) - - _teams_tokens[agent_id] = { - "access_token": token.token, - "expires_at": token.expires_on, - } + + # expires_on is a Unix timestamp; TTL = expires_on - now - 60s buffer + ttl = max(int(token.expires_on - time.time()) - 60, 60) + await set_cached_token(cache_key, token.token, ttl) logger.info(f"Teams: Successfully obtained access token via managed identity for agent {agent_id}, expires at {token.expires_on}") await credential.close() return token.token @@ -82,7 +84,7 @@ async def _get_teams_access_token(config: ChannelConfig) -> str | None: except Exception as e: logger.exception(f"Teams: Failed to get 
access token via managed identity for agent {agent_id}: {e}") return None - + # Use client credentials (app_id + app_secret) app_id = config.app_id app_secret = config.app_secret @@ -117,11 +119,10 @@ async def _get_teams_access_token(config: ChannelConfig) -> str | None: access_token = token_data["access_token"] expires_in = token_data["expires_in"] - _teams_tokens[agent_id] = { - "access_token": access_token, - "expires_at": time.time() + expires_in, - } - logger.info(f"Teams: Successfully obtained access token for agent {agent_id}, expires in {expires_in}s") + # TTL = expires_in - 60s buffer + ttl = max(expires_in - 60, 60) + await set_cached_token(cache_key, access_token, ttl) + logger.info(f"Teams: Successfully obtained access token for agent {agent_id}, expires in {expires_in}s, TTL={ttl}s") return access_token except httpx.HTTPStatusError as e: error_body = e.response.text if hasattr(e, 'response') and e.response else "No response body" diff --git a/backend/app/api/wecom.py b/backend/app/api/wecom.py index b9f4d0dc2..249b00d5a 100644 --- a/backend/app/api/wecom.py +++ b/backend/app/api/wecom.py @@ -27,6 +27,34 @@ router = APIRouter(tags=["wecom"]) +async def _get_wecom_token_cached(corp_id: str, corp_secret: str) -> str: + """Get WeCom access_token with Redis (preferred) + memory fallback caching. 
+ + Key: clawith:token:wecom:{corp_id} + TTL: 6900s (7200s validity - 5 min early refresh) + """ + from app.core.token_cache import get_cached_token, set_cached_token + import httpx as _httpx + + cache_key = f"clawith:token:wecom:{corp_id}" + cached = await get_cached_token(cache_key) + if cached: + return cached + + async with _httpx.AsyncClient(timeout=10) as _client: + _resp = await _client.get( + "https://qyapi.weixin.qq.com/cgi-bin/gettoken", + params={"corpid": corp_id, "corpsecret": corp_secret}, + ) + _data = _resp.json() + token = _data.get("access_token", "") + expires_in = int(_data.get("expires_in") or 7200) + if token: + ttl = max(expires_in - 300, 300) + await set_cached_token(cache_key, token, ttl) + return token + + # ─── WeCom AES Crypto ────────────────────────────────── def _pad(text: bytes) -> bytes: @@ -445,13 +473,11 @@ async def _process_wecom_kf_event(agent_id: uuid.UUID, config_obj: ChannelConfig if not config: return - async with httpx.AsyncClient(timeout=10) as client: - tok_resp = await client.get("https://qyapi.weixin.qq.com/cgi-bin/gettoken", params={"corpid": config.app_id, "corpsecret": config.app_secret}) - token_data = tok_resp.json() - access_token = token_data.get("access_token") - if not access_token: - return + access_token = await _get_wecom_token_cached(config.app_id, config.app_secret) + if not access_token: + return + async with httpx.AsyncClient(timeout=10) as client: current_cursor = token has_more = 1 current_ts = int(time.time()) @@ -596,12 +622,8 @@ async def _process_wecom_text( # Send reply via WeCom API wecom_agent_id = (config.extra_config or {}).get("wecom_agent_id", "") try: + access_token = await _get_wecom_token_cached(config.app_id, config.app_secret) async with httpx.AsyncClient(timeout=10) as client: - tok_resp = await client.get( - "https://qyapi.weixin.qq.com/cgi-bin/gettoken", - params={"corpid": config.app_id, "corpsecret": config.app_secret}, - ) - access_token = tok_resp.json().get("access_token", "") 
if access_token: if is_kf and open_kfid: # For KF messages, need to bridge/trans state first then send via kf/send_msg diff --git a/backend/app/core/token_cache.py b/backend/app/core/token_cache.py new file mode 100644 index 000000000..3dc09c993 --- /dev/null +++ b/backend/app/core/token_cache.py @@ -0,0 +1,64 @@ +""" +Unified Redis-backed token cache with in-memory fallback. + +Key naming convention: clawith:token:{type}:{identifier} +Examples: + clawith:token:dingtalk_corp:{app_key} + clawith:token:feishu_tenant:{app_id} + clawith:token:wecom:{corp_id} + clawith:token:teams:{agent_id} +""" +import time +from typing import Optional + +# In-memory fallback store: {key: (value, expire_at)} +_memory_cache: dict[str, tuple[str, float]] = {} + + +async def get_cached_token(key: str) -> Optional[str]: + """Get token from Redis (preferred) or memory fallback.""" + # Try Redis first + try: + from app.core.events import get_redis + redis = await get_redis() + if redis: + val = await redis.get(key) + if val: + return val if isinstance(val, str) else val.decode() + except Exception: + pass + + # Fallback to memory + if key in _memory_cache: + val, expire_at = _memory_cache[key] + if time.time() < expire_at: + return val + del _memory_cache[key] + return None + + +async def set_cached_token(key: str, value: str, ttl_seconds: int) -> None: + """Set token in Redis (preferred) and memory fallback.""" + # Try Redis first + try: + from app.core.events import get_redis + redis = await get_redis() + if redis: + await redis.setex(key, ttl_seconds, value) + except Exception: + pass + + # Always set in memory as fallback + _memory_cache[key] = (value, time.time() + ttl_seconds) + + +async def delete_cached_token(key: str) -> None: + """Delete token from both Redis and memory.""" + try: + from app.core.events import get_redis + redis = await get_redis() + if redis: + await redis.delete(key) + except Exception: + pass + _memory_cache.pop(key, None) diff --git 
a/backend/app/services/auth_provider.py b/backend/app/services/auth_provider.py index 0d2543098..c4fe559a2 100644 --- a/backend/app/services/auth_provider.py +++ b/backend/app/services/auth_provider.py @@ -294,8 +294,19 @@ async def get_authorization_url(self, redirect_uri: str, state: str) -> str: return f"{base_url}?{params}" async def get_app_access_token(self) -> str: - if self._app_access_token: - return self._app_access_token + """Get or refresh the Feishu app access token. + + Cached in Redis (preferred) with in-memory fallback. + Key: clawith:token:feishu_tenant:{app_id} + TTL: 6900s (7200s validity - 5 min early refresh) + """ + from app.core.token_cache import get_cached_token, set_cached_token + + cache_key = f"clawith:token:feishu_tenant:{self.app_id}" + cached = await get_cached_token(cache_key) + if cached: + self._app_access_token = cached + return cached async with httpx.AsyncClient() as client: resp = await client.post( @@ -303,8 +314,13 @@ async def get_app_access_token(self) -> str: json={"app_id": self.app_id, "app_secret": self.app_secret}, ) data = resp.json() - self._app_access_token = data.get("app_access_token", "") - return self._app_access_token + token = data.get("app_access_token", "") or data.get("tenant_access_token", "") + expire = data.get("expire", 7200) + if token: + ttl = max(expire - 300, 60) + await set_cached_token(cache_key, token, ttl) + self._app_access_token = token + return token async def exchange_code_for_token(self, code: str) -> dict: app_token = await self.get_app_access_token() diff --git a/backend/app/services/feishu_service.py b/backend/app/services/feishu_service.py index 629a95e90..33871e4c2 100644 --- a/backend/app/services/feishu_service.py +++ b/backend/app/services/feishu_service.py @@ -86,21 +86,39 @@ async def get_app_access_token(self) -> str: return await self.get_tenant_access_token(self.app_id, self.app_secret) async def get_tenant_access_token(self, app_id: str = None, app_secret: str = None) -> str: 
- """Get or refresh the app-level access token (tenant_access_token).""" + """Get or refresh the app-level access token (tenant_access_token). + + Cached in Redis (preferred) with in-memory fallback. + Key: clawith:token:feishu_tenant:{app_id} + TTL: 6900s (7200s validity - 5 min early refresh) + """ + from app.core.token_cache import get_cached_token, set_cached_token + target_app_id = app_id or self.app_id target_app_secret = app_secret or self.app_secret - + cache_key = f"clawith:token:feishu_tenant:{target_app_id}" + + cached = await get_cached_token(cache_key) + if cached: + if not app_id: + self._app_access_token = cached + return cached + async with httpx.AsyncClient() as client: resp = await client.post(FEISHU_APP_TOKEN_URL, json={ "app_id": target_app_id, "app_secret": target_app_secret, }) data = resp.json() - + token = data.get("tenant_access_token") or data.get("app_access_token", "") - if not app_id: # only cache default app token - self._app_access_token = token - + expire = data.get("expire", 7200) + if token: + ttl = max(expire - 300, 60) + await set_cached_token(cache_key, token, ttl) + if not app_id: # only update instance var for default app token + self._app_access_token = token + return token async def exchange_code_for_user(self, code: str) -> dict: diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py index 8a3525d13..9ebbaa681 100644 --- a/backend/app/services/org_sync_adapter.py +++ b/backend/app/services/org_sync_adapter.py @@ -781,12 +781,23 @@ def api_base_url(self) -> str: return self.DINGTALK_API_URL async def get_access_token(self) -> str: - if self._access_token and self._token_expires_at and datetime.now() < self._token_expires_at: - return self._access_token + """Get or refresh the DingTalk access token. + + Cached in Redis (preferred) with in-memory fallback. 
+ Key: clawith:token:dingtalk_corp:{app_key} + TTL: expires_in - 300s (5 min early refresh) + """ + from app.core.token_cache import get_cached_token, set_cached_token if not self.app_key or not self.app_secret: raise ValueError("DingTalk app_key/app_secret missing in provider config") + cache_key = f"clawith:token:dingtalk_corp:{self.app_key}" + cached = await get_cached_token(cache_key) + if cached: + self._access_token = cached + return cached + async with httpx.AsyncClient() as client: resp = await client.get( self.DINGTALK_TOKEN_URL, @@ -797,9 +808,11 @@ async def get_access_token(self) -> str: raise RuntimeError(f"DingTalk token error: {data.get('errmsg') or data}") token = data.get("access_token") or "" expires_in = int(data.get("expires_in") or 7200) - self._access_token = token - # refresh a bit earlier - self._token_expires_at = datetime.now() + timedelta(seconds=max(expires_in - 60, 60)) + if token: + ttl = max(expires_in - 300, 60) + await set_cached_token(cache_key, token, ttl) + self._access_token = token + self._token_expires_at = datetime.now() + timedelta(seconds=max(expires_in - 60, 60)) return token async def fetch_departments(self) -> list[ExternalDepartment]: @@ -990,19 +1003,30 @@ def api_base_url(self) -> str: async def get_access_token(self) -> str: """Get valid access token using the 通讯录同步 (contact-sync) secret. + Cached in Redis (preferred) with in-memory fallback. + Key: clawith:token:wecom:{corp_id} + TTL: 6900s (7200s validity - 5 min early refresh) + This token can call department/simplelist and user/list_id. It cannot call user/list or user/get (those raise errcode 48009). Full user profiles are obtained passively via SSO login instead. 
""" - if self._access_token and self._token_expires_at and datetime.now() < self._token_expires_at: - return self._access_token + from app.core.token_cache import get_cached_token, set_cached_token if not self.corp_id or not self.secret: raise ValueError("WeCom corp_id or secret missing in provider config") + cache_key = f"clawith:token:wecom:{self.corp_id}" + cached = await get_cached_token(cache_key) + if cached: + self._access_token = cached + return cached + token = await self._fetch_token(self.corp_id, self.secret) + if token: + ttl = max(7200 - 300, 300) + await set_cached_token(cache_key, token, ttl) self._access_token = token - # Refresh slightly before true expiry to avoid clock-skew issues self._token_expires_at = datetime.now() + timedelta(seconds=7200 - 300) return token