diff --git a/backend/alembic/versions/add_subdomain_prefix.py b/backend/alembic/versions/add_subdomain_prefix.py
new file mode 100644
index 000000000..5fecb85c0
--- /dev/null
+++ b/backend/alembic/versions/add_subdomain_prefix.py
@@ -0,0 +1,24 @@
+"""Add subdomain_prefix to tenants table."""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy import inspect
+
+revision: str = "add_subdomain_prefix"
+down_revision = "d9cbd43b62e5"
+branch_labels = None
+depends_on = None
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ inspector = inspect(conn)
+ columns = [c["name"] for c in inspector.get_columns("tenants")]
+ if "subdomain_prefix" not in columns:
+ op.add_column("tenants", sa.Column("subdomain_prefix", sa.String(50), nullable=True))
+ indexes = [i["name"] for i in inspector.get_indexes("tenants")]
+ if "ix_tenants_subdomain_prefix" not in indexes:
+ op.create_index("ix_tenants_subdomain_prefix", "tenants", ["subdomain_prefix"], unique=True)
+
+def downgrade() -> None:
+ op.drop_index("ix_tenants_subdomain_prefix", "tenants")
+ op.drop_column("tenants", "subdomain_prefix")
diff --git a/backend/alembic/versions/add_tenant_is_default.py b/backend/alembic/versions/add_tenant_is_default.py
new file mode 100644
index 000000000..00e164615
--- /dev/null
+++ b/backend/alembic/versions/add_tenant_is_default.py
@@ -0,0 +1,27 @@
+"""Add is_default field to tenants table."""
+
+from alembic import op
+import sqlalchemy as sa
+
+revision = "add_tenant_is_default"
+down_revision = "add_subdomain_prefix"
+branch_labels = None
+depends_on = None
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ inspector = sa.inspect(conn)
+ cols = [c['name'] for c in inspector.get_columns('tenants')]
+ if 'is_default' not in cols:
+ op.add_column('tenants', sa.Column('is_default', sa.Boolean(), nullable=False, server_default='false'))
+ conn.execute(sa.text("""
+ UPDATE tenants
+ SET is_default = true
+ WHERE id = (
+ SELECT id FROM tenants WHERE is_active = true ORDER BY created_at ASC LIMIT 1
+ )
+ AND NOT EXISTS (SELECT 1 FROM tenants WHERE is_default = true)
+ """))
+
+def downgrade() -> None:
+ op.drop_column('tenants', 'is_default')
diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py
index ef76429b5..26ed39cc6 100644
--- a/backend/app/api/auth.py
+++ b/backend/app/api/auth.py
@@ -512,12 +512,18 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes
user = valid_users[0]
else:
# Specific tenant requested (Dedicated Link flow)
- # Search for the user record in that tenant
- user = next((u for u in valid_users if u.tenant_id == data.tenant_id), None)
-
+ # Search for the user record in that tenant.
+ # Also allow users with tenant_id=None (SSO-created but not yet
+ # assigned to a tenant) so they can proceed to company setup.
+ user = next(
+ (u for u in valid_users
+ if u.tenant_id == data.tenant_id or u.tenant_id is None),
+ None,
+ )
+
# Cross-tenant access check
if not user:
- # Even platform admins must have a valid record in the targeted tenant
+ # Even platform admins must have a valid record in the targeted tenant
# when logging in via a dedicated tenant URL / tenant_id.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -910,14 +916,26 @@ async def oauth_callback(
raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported")
try:
- # Exchange code for token
- token_data = await auth_provider.exchange_code_for_token(data.code)
+ # Exchange code for token (pass redirect_uri for OAuth2 providers that require it)
+ if hasattr(auth_provider, 'exchange_code_for_token') and data.redirect_uri:
+ token_data = await auth_provider.exchange_code_for_token(data.code, redirect_uri=data.redirect_uri)
+ else:
+ token_data = await auth_provider.exchange_code_for_token(data.code)
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="Failed to get access token from provider")
- # Get user info
- user_info = await auth_provider.get_user_info(access_token)
+ # Get user info with fallback to token_data extraction
+ try:
+ user_info = await auth_provider.get_user_info(access_token)
+ except Exception:
+ if hasattr(auth_provider, 'get_user_info_from_token_data'):
+ user_info = await auth_provider.get_user_info_from_token_data(token_data)
+ else:
+ raise
+ if not user_info.provider_user_id and hasattr(auth_provider, 'get_user_info_from_token_data'):
+ # try token_data as last resort
+ user_info = await auth_provider.get_user_info_from_token_data(token_data)
# Find or create user
user, is_new = await auth_provider.find_or_create_user(db, user_info)
diff --git a/backend/app/api/enterprise.py b/backend/app/api/enterprise.py
index 25d98c802..05f29b1de 100644
--- a/backend/app/api/enterprise.py
+++ b/backend/app/api/enterprise.py
@@ -793,6 +793,7 @@ class OAuth2Config(BaseModel):
token_url: str | None = None # OAuth2 token endpoint
user_info_url: str | None = None # OAuth2 user info endpoint
scope: str | None = "openid profile email"
+ field_mapping: dict | None = None # Custom field name mapping
def to_config_dict(self) -> dict:
"""Convert to config dict with both naming conventions for compatibility."""
@@ -811,6 +812,8 @@ def to_config_dict(self) -> dict:
config["user_info_url"] = self.user_info_url
if self.scope:
config["scope"] = self.scope
+ if self.field_mapping:
+ config["field_mapping"] = self.field_mapping
return config
@classmethod
@@ -823,6 +826,7 @@ def from_config_dict(cls, config: dict) -> "OAuth2Config":
token_url=config.get("token_url"),
user_info_url=config.get("user_info_url"),
scope=config.get("scope"),
+ field_mapping=config.get("field_mapping"),
)
@@ -837,6 +841,7 @@ class IdentityProviderOAuth2Create(BaseModel):
token_url: str
user_info_url: str
scope: str | None = "openid profile email"
+ field_mapping: dict | None = None
tenant_id: uuid.UUID | None = None
@@ -936,6 +941,7 @@ async def create_oauth2_provider(
token_url=data.token_url,
user_info_url=data.user_info_url,
scope=data.scope,
+ field_mapping=data.field_mapping,
)
config = oauth_config.to_config_dict()
@@ -962,6 +968,7 @@ async def create_oauth2_provider(
provider_type="oauth2",
name=data.name,
is_active=data.is_active,
+ sso_login_enabled=True,
config=config,
tenant_id=tid
)
@@ -981,6 +988,7 @@ class OAuth2ConfigUpdate(BaseModel):
token_url: str | None = None
user_info_url: str | None = None
scope: str | None = None
+ field_mapping: dict | None = None # Custom field name mapping
@router.patch("/identity-providers/{provider_id}/oauth2", response_model=IdentityProviderOut)
@@ -1009,7 +1017,7 @@ async def update_oauth2_provider(
provider.is_active = data.is_active
# Update config fields
- if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope]):
+ if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope, data.field_mapping is not None]):
current_config = provider.config.copy()
if data.app_id is not None:
@@ -1031,6 +1039,12 @@ async def update_oauth2_provider(
current_config["user_info_url"] = data.user_info_url
if data.scope is not None:
current_config["scope"] = data.scope
+ if data.field_mapping is not None:
+ # Empty dict or explicit None clears the mapping
+ if data.field_mapping:
+ current_config["field_mapping"] = data.field_mapping
+ else:
+ current_config.pop("field_mapping", None)
# Validate the updated config
validate_provider_config("oauth2", current_config)
diff --git a/backend/app/api/sso.py b/backend/app/api/sso.py
index 1c5210247..a1f07e441 100644
--- a/backend/app/api/sso.py
+++ b/backend/app/api/sso.py
@@ -4,6 +4,7 @@
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, Request, status
+from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -104,15 +105,12 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
result = await db.execute(query)
providers = result.scalars().all()
- # Determine the base URL for OAuth callbacks using centralized platform service:
- from app.services.platform_service import platform_service
- if session.tenant_id:
- from app.models.tenant import Tenant
- tenant_result = await db.execute(select(Tenant).where(Tenant.id == session.tenant_id))
- tenant_obj = tenant_result.scalar_one_or_none()
- public_base = await platform_service.get_tenant_sso_base_url(db, tenant_obj, request)
- else:
- public_base = await platform_service.get_public_base_url(db, request)
+ # Determine the base URL for OAuth callbacks using unified domain resolution:
+ from app.core.domain import resolve_base_url
+ public_base = await resolve_base_url(
+ db, request=request,
+ tenant_id=str(session.tenant_id) if session.tenant_id else None
+ )
auth_urls = []
for p in providers:
@@ -141,5 +139,95 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
url = f"https://open.work.weixin.qq.com/wwopen/sso/qrConnect?appid={corp_id}&agentid={agent_id}&redirect_uri={quote(redir)}&state={sid}"
auth_urls.append({"provider_type": "wecom", "name": p.name, "url": url})
+ elif p.provider_type == "oauth2":
+ from app.services.auth_registry import auth_provider_registry
+ auth_provider = await auth_provider_registry.get_provider(
+ db, "oauth2", str(session.tenant_id) if session.tenant_id else None
+ )
+ if auth_provider:
+ redir = f"{public_base}/api/auth/oauth2/callback"
+ url = await auth_provider.get_authorization_url(redir, str(sid))
+ auth_urls.append({"provider_type": "oauth2", "name": p.name, "url": url})
+
return auth_urls
+
+@router.get("/auth/oauth2/callback")
+async def oauth2_callback(
+ code: str,
+ state: str = None,
+ db: AsyncSession = Depends(get_db),
+):
+ """Handle OAuth2 SSO callback -- exchange code for user session."""
+ from app.core.security import create_access_token
+ from fastapi.responses import HTMLResponse
+ from app.services.auth_registry import auth_provider_registry
+
+ # 1. Resolve tenant context from state (= session ID)
+ tenant_id = None
+ sid = None
+ if state:
+ try:
+ sid = uuid.UUID(state)
+ s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid))
+ session = s_res.scalar_one_or_none()
+ if session:
+ tenant_id = session.tenant_id
+ except (ValueError, AttributeError):
+ pass
+
+ # 2. Get OAuth2 provider
+ auth_provider = await auth_provider_registry.get_provider(
+ db, "oauth2", str(tenant_id) if tenant_id else None
+ )
+ if not auth_provider:
+ return HTMLResponse("Auth failed: OAuth2 provider not configured")
+
+ # 3. Exchange code -> token -> user info -> find/create user
+ try:
+ token_data = await auth_provider.exchange_code_for_token(code)
+ access_token = token_data.get("access_token")
+ if not access_token:
+ logger.error("OAuth2 token exchange returned no access_token")
+ return HTMLResponse("Auth failed: token exchange error")
+
+ user_info = await auth_provider.get_user_info(access_token)
+ if not user_info.provider_user_id:
+ logger.error("OAuth2 user info missing user ID")
+ return HTMLResponse("Auth failed: no user ID returned")
+
+ user, is_new = await auth_provider.find_or_create_user(
+ db, user_info, tenant_id=str(tenant_id) if tenant_id else None
+ )
+ if not user:
+ return HTMLResponse("Auth failed: user resolution failed")
+
+ except Exception as e:
+ logger.error("OAuth2 login error: %s", e)
+ return HTMLResponse(f"Auth failed: {e!s}")
+
+ # 4. Generate JWT, update SSO session
+ token = create_access_token(str(user.id), user.role)
+
+ if sid:
+ try:
+ s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid))
+ session = s_res.scalar_one_or_none()
+ if session:
+ session.status = "authorized"
+ session.provider_type = "oauth2"
+ session.user_id = user.id
+ session.access_token = token
+ session.error_msg = None
+ await db.commit()
+ return HTMLResponse(
+ '
'
+ '
SSO login successful. Redirecting...
'
+ f''
+ ''
+ )
+ except Exception as e:
+ logger.exception("Failed to update SSO session (oauth2): %s", e)
+
+ return HTMLResponse("Logged in successfully.")
+
diff --git a/backend/app/api/teams.py b/backend/app/api/teams.py
index 09368ebe1..9c2b0e63f 100644
--- a/backend/app/api/teams.py
+++ b/backend/app/api/teams.py
@@ -37,42 +37,44 @@
TEAMS_MSG_LIMIT = 28000 # Teams message char limit (approx 28KB)
-# In-memory cache for OAuth tokens
-_teams_tokens: dict[str, dict] = {} # agent_id -> {access_token, expires_at}
-
async def _get_teams_access_token(config: ChannelConfig) -> str | None:
"""Get or refresh Microsoft Teams access token.
-
+
Supports:
- Client credentials (app_id + app_secret) - default
- Managed Identity (when use_managed_identity is True in extra_config)
+ Token is cached in Redis (preferred) with in-memory fallback.
+ Key: clawith:token:teams:{agent_id}
"""
+ from app.core.token_cache import get_cached_token, set_cached_token
+
agent_id = str(config.agent_id)
- cached = _teams_tokens.get(agent_id)
- if cached and cached["expires_at"] > time.time() + 60: # Refresh 60s before expiry
+ cache_key = f"clawith:token:teams:{agent_id}"
+
+ cached = await get_cached_token(cache_key)
+ if cached:
logger.debug(f"Teams: Using cached access token for agent {agent_id}")
- return cached["access_token"]
+ return cached
# Check if managed identity should be used
use_managed_identity = config.extra_config.get("use_managed_identity", False)
-
+
if use_managed_identity:
# Use Azure Managed Identity
try:
from azure.identity.aio import DefaultAzureCredential
from azure.core.credentials import AccessToken
-
+
credential = DefaultAzureCredential()
# For Bot Framework, we need the token for the Bot Framework API
# Managed identity needs to be granted permissions to the Bot Framework API
scope = "https://api.botframework.com/.default"
token: AccessToken = await credential.get_token(scope)
-
- _teams_tokens[agent_id] = {
- "access_token": token.token,
- "expires_at": token.expires_on,
- }
+
+ # expires_on is a Unix timestamp; TTL = expires_on - now - 60s buffer
+ ttl = max(int(token.expires_on - time.time()) - 60, 60)
+ await set_cached_token(cache_key, token.token, ttl)
logger.info(f"Teams: Successfully obtained access token via managed identity for agent {agent_id}, expires at {token.expires_on}")
await credential.close()
return token.token
@@ -82,7 +84,7 @@ async def _get_teams_access_token(config: ChannelConfig) -> str | None:
except Exception as e:
logger.exception(f"Teams: Failed to get access token via managed identity for agent {agent_id}: {e}")
return None
-
+
# Use client credentials (app_id + app_secret)
app_id = config.app_id
app_secret = config.app_secret
@@ -117,11 +119,10 @@ async def _get_teams_access_token(config: ChannelConfig) -> str | None:
access_token = token_data["access_token"]
expires_in = token_data["expires_in"]
- _teams_tokens[agent_id] = {
- "access_token": access_token,
- "expires_at": time.time() + expires_in,
- }
- logger.info(f"Teams: Successfully obtained access token for agent {agent_id}, expires in {expires_in}s")
+ # TTL = expires_in - 60s buffer
+ ttl = max(expires_in - 60, 60)
+ await set_cached_token(cache_key, access_token, ttl)
+ logger.info(f"Teams: Successfully obtained access token for agent {agent_id}, expires in {expires_in}s, TTL={ttl}s")
return access_token
except httpx.HTTPStatusError as e:
error_body = e.response.text if hasattr(e, 'response') and e.response else "No response body"
diff --git a/backend/app/api/tenants.py b/backend/app/api/tenants.py
index 1e8c26b14..85793949f 100644
--- a/backend/app/api/tenants.py
+++ b/backend/app/api/tenants.py
@@ -37,6 +37,9 @@ class TenantOut(BaseModel):
is_active: bool
sso_enabled: bool = False
sso_domain: str | None = None
+ is_default: bool = False
+ subdomain_prefix: str | None = None
+ effective_base_url: str | None = None
created_at: datetime | None = None
model_config = {"from_attributes": True}
@@ -49,6 +52,8 @@ class TenantUpdate(BaseModel):
is_active: bool | None = None
sso_enabled: bool | None = None
sso_domain: str | None = None
+ subdomain_prefix: str | None = None
+ is_default: bool | None = None
# ─── Helpers ────────────────────────────────────────────
@@ -338,66 +343,80 @@ async def get_registration_config(db: AsyncSession = Depends(get_db)):
return {"allow_self_create_company": allowed}
-# ─── Public: Resolve Tenant by Domain ───────────────────
-
-@router.get("/resolve-by-domain")
-async def resolve_tenant_by_domain(
- domain: str,
- db: AsyncSession = Depends(get_db),
-):
- """Resolve a tenant by its sso_domain or subdomain slug.
+# ─── Public: Check Subdomain Prefix Availability ─────────
- sso_domain is stored as a full URL (e.g. "https://acme.clawith.ai" or "http://1.2.3.4:3009").
- The incoming `domain` parameter is the host (without protocol).
-
- Lookup precedence:
- 1. Exact match on tenant.sso_domain ending with the host (strips protocol)
- 2. Extract slug from "{slug}.clawith.ai" and match tenant.slug
- """
- tenant = None
-
- # 1. Match by stripping protocol from stored sso_domain
- # sso_domain = "https://acme.clawith.ai" → compare against "acme.clawith.ai"
- for proto in ("https://", "http://"):
- result = await db.execute(
- select(Tenant).where(Tenant.sso_domain == f"{proto}{domain}")
- )
- tenant = result.scalar_one_or_none()
- if tenant:
- break
-
- # 2. Try without port (e.g. domain = "1.2.3.4:3009" → try "1.2.3.4")
- if not tenant and ":" in domain:
- domain_no_port = domain.split(":")[0]
- for proto in ("https://", "http://"):
- result = await db.execute(
- select(Tenant).where(Tenant.sso_domain.like(f"{proto}{domain_no_port}%"))
- )
- tenant = result.scalar_one_or_none()
- if tenant:
- break
+@router.get("/check-prefix")
+async def check_prefix(prefix: str, db: AsyncSession = Depends(get_db)):
+ """Check if a subdomain prefix is available."""
+ import re as _re
+ if not _re.match(r'^[a-z0-9]([a-z0-9\-]{0,48}[a-z0-9])?$', prefix):
+ return {"available": False, "reason": "Invalid format"}
+ reserved = {"www", "api", "app", "admin", "mail", "smtp", "ftp", "ns1", "ns2", "cdn", "static", "assets"}
+ if prefix in reserved:
+ return {"available": False, "reason": "Reserved"}
+ result = await db.execute(select(Tenant).where(Tenant.subdomain_prefix == prefix))
+ if result.scalar_one_or_none():
+ return {"available": False, "reason": "Already taken"}
+ return {"available": True}
- # 3. Fallback: extract slug from subdomain pattern
- if not tenant:
- import re
- m = re.match(r"^([a-z0-9][a-z0-9\-]*[a-z0-9])\.clawith\.ai$", domain.lower())
- if m:
- slug = m.group(1)
- result = await db.execute(select(Tenant).where(Tenant.slug == slug))
- tenant = result.scalar_one_or_none()
- if not tenant or not tenant.is_active or not tenant.sso_enabled:
- raise HTTPException(status_code=404, detail="Tenant not found or not active or SSO not enabled")
+# ─── Public: Resolve Tenant by Domain ───────────────────
+def _tenant_response(tenant):
return {
- "id": tenant.id,
+ "id": str(tenant.id),
"name": tenant.name,
"slug": tenant.slug,
"sso_enabled": tenant.sso_enabled,
"sso_domain": tenant.sso_domain,
- "is_active": tenant.is_active,
+ "subdomain_prefix": tenant.subdomain_prefix,
+ "is_default": getattr(tenant, 'is_default', False),
}
+
+@router.get("/resolve-by-domain")
+async def resolve_by_domain(domain: str, db: AsyncSession = Depends(get_db)):
+ """Resolve a tenant by the incoming domain/hostname."""
+ from app.core.domain import resolve_base_url, _get_global_base_url
+ from urllib.parse import urlparse
+
+ hostname = domain.split(":")[0]
+
+ # 1. Match sso_domain exactly
+ result = await db.execute(select(Tenant).where(Tenant.sso_domain == hostname))
+ tenant = result.scalar_one_or_none()
+ if tenant:
+ return _tenant_response(tenant)
+
+ # 2. Match subdomain_prefix
+ global_url = await _get_global_base_url(db)
+ if global_url:
+ parsed = urlparse(global_url)
+ global_host = parsed.hostname or ""
+ if hostname.endswith(f".{global_host}"):
+ prefix = hostname[: -(len(global_host) + 1)]
+ result = await db.execute(select(Tenant).where(Tenant.subdomain_prefix == prefix))
+ tenant = result.scalar_one_or_none()
+ if tenant:
+ return _tenant_response(tenant)
+
+ # 3. Match slug as subdomain
+ parts = hostname.split(".")
+ if len(parts) >= 2:
+ slug_candidate = parts[0]
+ result = await db.execute(select(Tenant).where(Tenant.slug == slug_candidate))
+ tenant = result.scalar_one_or_none()
+ if tenant:
+ return _tenant_response(tenant)
+
+ # 4. Return default tenant
+ result = await db.execute(select(Tenant).where(Tenant.is_default == True))
+ tenant = result.scalar_one_or_none()
+ if tenant:
+ return _tenant_response(tenant)
+
+ raise HTTPException(status_code=404, detail="No tenant found for this domain")
+
# ─── Authenticated: List / Get ──────────────────────────
@router.get("/", response_model=list[TenantOut])
@@ -425,7 +444,10 @@ async def get_tenant(
tenant = result.scalar_one_or_none()
if not tenant:
raise HTTPException(status_code=404, detail="Tenant not found")
- return TenantOut.model_validate(tenant)
+ from app.core.domain import resolve_base_url
+ out = TenantOut.model_validate(tenant)
+ out.effective_base_url = await resolve_base_url(db, tenant_id=str(tenant.id))
+ return out
@router.put("/{tenant_id}", response_model=TenantOut)
@@ -444,13 +466,36 @@ async def update_tenant(
raise HTTPException(status_code=404, detail="Tenant not found")
update_data = data.model_dump(exclude_unset=True)
-
+
# SSO configuration is managed exclusively by the company's own org_admin
# via the Enterprise Settings page. Platform admins should not override it here.
if current_user.role == "platform_admin":
update_data.pop("sso_enabled", None)
update_data.pop("sso_domain", None)
+ # Validate subdomain_prefix format if provided
+ if "subdomain_prefix" in update_data and update_data["subdomain_prefix"] is not None:
+ import re as _re
+ prefix = update_data["subdomain_prefix"]
+ if not _re.match(r'^[a-z0-9]([a-z0-9\-]{0,48}[a-z0-9])?$', prefix):
+ raise HTTPException(status_code=400, detail="Invalid subdomain prefix format")
+ reserved = {"www", "api", "app", "admin", "mail", "smtp", "ftp", "ns1", "ns2", "cdn", "static", "assets"}
+ if prefix in reserved:
+ raise HTTPException(status_code=400, detail="Subdomain prefix is reserved")
+ # Check uniqueness
+ existing = await db.execute(
+ select(Tenant).where(Tenant.subdomain_prefix == prefix, Tenant.id != tenant_id)
+ )
+ if existing.scalar_one_or_none():
+ raise HTTPException(status_code=400, detail="Subdomain prefix already taken")
+
+ # If setting is_default=True, clear other defaults
+ if update_data.get("is_default") is True:
+ from sqlalchemy import update as sql_update
+ await db.execute(
+ sql_update(Tenant).where(Tenant.id != tenant_id).values(is_default=False)
+ )
+
for field, value in update_data.items():
setattr(tenant, field, value)
await db.flush()
diff --git a/backend/app/api/wecom.py b/backend/app/api/wecom.py
index b9f4d0dc2..249b00d5a 100644
--- a/backend/app/api/wecom.py
+++ b/backend/app/api/wecom.py
@@ -27,6 +27,34 @@
router = APIRouter(tags=["wecom"])
+async def _get_wecom_token_cached(corp_id: str, corp_secret: str) -> str:
+ """Get WeCom access_token with Redis (preferred) + memory fallback caching.
+
+ Key: clawith:token:wecom:{corp_id}
+ TTL: 6900s (7200s validity - 5 min early refresh)
+ """
+ from app.core.token_cache import get_cached_token, set_cached_token
+ import httpx as _httpx
+
+ cache_key = f"clawith:token:wecom:{corp_id}"
+ cached = await get_cached_token(cache_key)
+ if cached:
+ return cached
+
+ async with _httpx.AsyncClient(timeout=10) as _client:
+ _resp = await _client.get(
+ "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
+ params={"corpid": corp_id, "corpsecret": corp_secret},
+ )
+ _data = _resp.json()
+ token = _data.get("access_token", "")
+ expires_in = int(_data.get("expires_in") or 7200)
+ if token:
+ ttl = max(expires_in - 300, 300)
+ await set_cached_token(cache_key, token, ttl)
+ return token
+
+
# ─── WeCom AES Crypto ──────────────────────────────────
def _pad(text: bytes) -> bytes:
@@ -445,13 +473,11 @@ async def _process_wecom_kf_event(agent_id: uuid.UUID, config_obj: ChannelConfig
if not config:
return
- async with httpx.AsyncClient(timeout=10) as client:
- tok_resp = await client.get("https://qyapi.weixin.qq.com/cgi-bin/gettoken", params={"corpid": config.app_id, "corpsecret": config.app_secret})
- token_data = tok_resp.json()
- access_token = token_data.get("access_token")
- if not access_token:
- return
+ access_token = await _get_wecom_token_cached(config.app_id, config.app_secret)
+ if not access_token:
+ return
+ async with httpx.AsyncClient(timeout=10) as client:
current_cursor = token
has_more = 1
current_ts = int(time.time())
@@ -596,12 +622,8 @@ async def _process_wecom_text(
# Send reply via WeCom API
wecom_agent_id = (config.extra_config or {}).get("wecom_agent_id", "")
try:
+ access_token = await _get_wecom_token_cached(config.app_id, config.app_secret)
async with httpx.AsyncClient(timeout=10) as client:
- tok_resp = await client.get(
- "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
- params={"corpid": config.app_id, "corpsecret": config.app_secret},
- )
- access_token = tok_resp.json().get("access_token", "")
if access_token:
if is_kf and open_kfid:
# For KF messages, need to bridge/trans state first then send via kf/send_msg
diff --git a/backend/app/core/domain.py b/backend/app/core/domain.py
new file mode 100644
index 000000000..6e2a166bb
--- /dev/null
+++ b/backend/app/core/domain.py
@@ -0,0 +1,77 @@
+"""Domain resolution with fallback chain."""
+
+import os
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from fastapi import Request
+
+from app.models.system_settings import SystemSetting
+
+
+async def _get_global_base_url(db: AsyncSession):
+ """Helper: read platform public_base_url from system_settings."""
+ # Try DB first
+ result = await db.execute(
+ select(SystemSetting).where(SystemSetting.key == "platform")
+ )
+ setting = result.scalar_one_or_none()
+ if setting and setting.value.get("public_base_url"):
+ return setting.value["public_base_url"].rstrip("/")
+ # Fallback to ENV
+ env_url = os.environ.get("PUBLIC_BASE_URL")
+ if env_url:
+ return env_url.rstrip("/")
+ return None
+
+
+async def resolve_base_url(
+ db: AsyncSession,
+ request: Request | None = None,
+ tenant_id: str | None = None,
+) -> str:
+ """Resolve the effective base URL using the fallback chain:
+
+ 1. Tenant-specific sso_domain (if tenant_id provided and tenant has sso_domain)
+ 2. Tenant subdomain_prefix + global hostname
+ 3. Platform global public_base_url (from system_settings)
+ 4. Request origin (from request.base_url)
+ 5. Hardcoded fallback
+
+ Returns a full URL like "https://acme.example.com" or "http://localhost:3008"
+ """
+ # Level 1 & 2: Tenant-specific
+ if tenant_id:
+ from app.models.tenant import Tenant
+ result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
+ tenant = result.scalar_one_or_none()
+ if tenant:
+ # Level 1: complete custom domain
+ if tenant.sso_domain:
+ domain = tenant.sso_domain.rstrip("/")
+ if domain.startswith("http://") or domain.startswith("https://"):
+ return domain
+ return f"https://{domain}"
+
+ # Level 2: subdomain prefix + global hostname (skip for default tenant)
+ if tenant.subdomain_prefix and not getattr(tenant, 'is_default', False):
+ global_url = await _get_global_base_url(db)
+ if global_url:
+ from urllib.parse import urlparse
+ parsed = urlparse(global_url)
+ host = f"{tenant.subdomain_prefix}.{parsed.hostname}"
+ if parsed.port and parsed.port not in (80, 443):
+ host = f"{host}:{parsed.port}"
+ return f"{parsed.scheme}://{host}"
+
+ # Level 3: Platform global setting
+ global_url = await _get_global_base_url(db)
+ if global_url:
+ return global_url
+
+ # Level 4: Request origin
+ if request:
+ return str(request.base_url).rstrip("/")
+
+ # Level 5: Hardcoded fallback
+ return "http://localhost:8000"
diff --git a/backend/app/core/token_cache.py b/backend/app/core/token_cache.py
new file mode 100644
index 000000000..3dc09c993
--- /dev/null
+++ b/backend/app/core/token_cache.py
@@ -0,0 +1,64 @@
+"""
+Unified Redis-backed token cache with in-memory fallback.
+
+Key naming convention: clawith:token:{type}:{identifier}
+Examples:
+ clawith:token:dingtalk_corp:{app_key}
+ clawith:token:feishu_tenant:{app_id}
+ clawith:token:wecom:{corp_id}
+ clawith:token:teams:{agent_id}
+"""
+import time
+from typing import Optional
+
+# In-memory fallback store: {key: (value, expire_at)}
+_memory_cache: dict[str, tuple[str, float]] = {}
+
+
+async def get_cached_token(key: str) -> Optional[str]:
+ """Get token from Redis (preferred) or memory fallback."""
+ # Try Redis first
+ try:
+ from app.core.events import get_redis
+ redis = await get_redis()
+ if redis:
+ val = await redis.get(key)
+ if val:
+ return val if isinstance(val, str) else val.decode()
+ except Exception:
+ pass
+
+ # Fallback to memory
+ if key in _memory_cache:
+ val, expire_at = _memory_cache[key]
+ if time.time() < expire_at:
+ return val
+ del _memory_cache[key]
+ return None
+
+
+async def set_cached_token(key: str, value: str, ttl_seconds: int) -> None:
+ """Set token in Redis (preferred) and memory fallback."""
+ # Try Redis first
+ try:
+ from app.core.events import get_redis
+ redis = await get_redis()
+ if redis:
+ await redis.setex(key, ttl_seconds, value)
+ except Exception:
+ pass
+
+ # Always set in memory as fallback
+ _memory_cache[key] = (value, time.time() + ttl_seconds)
+
+
+async def delete_cached_token(key: str) -> None:
+ """Delete token from both Redis and memory."""
+ try:
+ from app.core.events import get_redis
+ redis = await get_redis()
+ if redis:
+ await redis.delete(key)
+ except Exception:
+ pass
+ _memory_cache.pop(key, None)
diff --git a/backend/app/models/tenant.py b/backend/app/models/tenant.py
index 113238c62..d122bdc9f 100644
--- a/backend/app/models/tenant.py
+++ b/backend/app/models/tenant.py
@@ -44,6 +44,11 @@ class Tenant(Base):
sso_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
sso_domain: Mapped[str | None] = mapped_column(String(255), unique=True, index=True, nullable=True)
+ # Subdomain prefix for auto-generated tenant URLs (e.g. "acme" → acme.clawith.com)
+ subdomain_prefix: Mapped[str | None] = mapped_column(String(50), unique=True, index=True, nullable=True)
+ # Whether this is the platform's default tenant
+ is_default: Mapped[bool] = mapped_column(Boolean, default=False)
+
# Trigger limits — defaults for new agents & floor values
default_max_triggers: Mapped[int] = mapped_column(Integer, default=20)
min_poll_interval_floor: Mapped[int] = mapped_column(Integer, default=5)
diff --git a/backend/app/schemas/schemas.py b/backend/app/schemas/schemas.py
index 3870392b9..f0da50b62 100644
--- a/backend/app/schemas/schemas.py
+++ b/backend/app/schemas/schemas.py
@@ -180,6 +180,7 @@ class OAuthAuthorizeResponse(BaseModel):
class OAuthCallbackRequest(BaseModel):
code: str
state: str
+ redirect_uri: str = ""
class IdentityBindRequest(BaseModel):
diff --git a/backend/app/services/agent_context.py b/backend/app/services/agent_context.py
index 84ebaece2..62e0e61fc 100644
--- a/backend/app/services/agent_context.py
+++ b/backend/app/services/agent_context.py
@@ -564,4 +564,24 @@ async def build_agent_context(agent_id: uuid.UUID, agent_name: str, role_descrip
if current_user_name:
dynamic_parts.append(f"\n## Current Conversation\nYou are currently chatting with **{current_user_name}**. Address them by name when appropriate.")
+ # Inject platform base URL so agent knows where it is deployed
+ try:
+ from app.services.platform_service import platform_service
+ _platform_url = (await platform_service.get_public_base_url()).rstrip("/")
+ platform_lines = [
+ "\n## Platform Base URLs",
+ "You are running on the Clawith platform. Always use these URLs exactly -- never guess or invent domain names.",
+ "",
+ "- **Platform base**: " + _platform_url,
+ "- **Webhook**: " + _platform_url + "/api/webhooks/t/ (replace with actual trigger token)",
+ "- **Public page**: " + _platform_url + "/p/ (replace with actual page id returned by publish_page)",
+ "- **File download**: " + _platform_url + "/api/agents//files/download?path=",
+ "- **Gateway poll**: " + _platform_url + "/api/gateway/poll (used by external agents to check inbox)",
+ "",
+ "Never use placeholder domains (clawith.com, try.clawith.ai, webhook.clawith.com, api.clawith.ai, etc.).",
+ ]
+ dynamic_parts.append("\n".join(platform_lines))
+ except Exception:
+ pass
+
return "\n".join(static_parts), "\n".join(dynamic_parts)
diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py
index f1880d172..5976494a5 100644
--- a/backend/app/services/agent_tools.py
+++ b/backend/app/services/agent_tools.py
@@ -5544,6 +5544,7 @@ async def _handle_cancel_trigger(agent_id: uuid.UUID, arguments: dict) -> str:
async def _handle_list_triggers(agent_id: uuid.UUID) -> str:
"""List all active triggers for the agent."""
from app.models.trigger import AgentTrigger
+ from app.services.platform_service import platform_service
try:
async with async_session() as db:
@@ -5554,15 +5555,24 @@ async def _handle_list_triggers(agent_id: uuid.UUID) -> str:
)
triggers = result.scalars().all()
+ # Resolve base URL for webhook triggers
+ _base_url = (await platform_service.get_public_base_url(db)).rstrip("/")
+
if not triggers:
return "No triggers found. Use set_trigger to create one."
- lines = ["| Name | Type | Config | Reason | Status | Fires |", "|------|------|--------|--------|--------|-------|"]
+ lines = ["| Name | Type | Config | Webhook URL | Reason | Status | Fires |", "|------|------|--------|-------------|--------|--------|-------|"]
for t in triggers:
status = "✅ active" if t.is_enabled else "⏸ disabled"
- config_str = str(t.config)[:50]
+ config = t.config or {}
+ if t.type == "webhook" and config.get("token"):
+ config_str = f"token: {config['token']}"
+ webhook_url = f"{_base_url}/api/webhooks/t/{config['token']}"
+ else:
+ config_str = str(config)[:50]
+ webhook_url = "-"
reason_str = t.reason[:40] if t.reason else ""
- lines.append(f"| {t.name} | {t.type} | {config_str} | {reason_str} | {status} | {t.fire_count} |")
+ lines.append(f"| {t.name} | {t.type} | {config_str} | {webhook_url} | {reason_str} | {status} | {t.fire_count} |")
return "\n".join(lines)
diff --git a/backend/app/services/auth_provider.py b/backend/app/services/auth_provider.py
index d40cd583b..c4fe559a2 100644
--- a/backend/app/services/auth_provider.py
+++ b/backend/app/services/auth_provider.py
@@ -294,8 +294,19 @@ async def get_authorization_url(self, redirect_uri: str, state: str) -> str:
return f"{base_url}?{params}"
async def get_app_access_token(self) -> str:
- if self._app_access_token:
- return self._app_access_token
+ """Get or refresh the Feishu app access token.
+
+ Cached in Redis (preferred) with in-memory fallback.
+ Key: clawith:token:feishu_tenant:{app_id}
+ TTL: 6900s (7200s validity - 5 min early refresh)
+ """
+ from app.core.token_cache import get_cached_token, set_cached_token
+
+ cache_key = f"clawith:token:feishu_tenant:{self.app_id}"
+ cached = await get_cached_token(cache_key)
+ if cached:
+ self._app_access_token = cached
+ return cached
async with httpx.AsyncClient() as client:
resp = await client.post(
@@ -303,8 +314,13 @@ async def get_app_access_token(self) -> str:
json={"app_id": self.app_id, "app_secret": self.app_secret},
)
data = resp.json()
- self._app_access_token = data.get("app_access_token", "")
- return self._app_access_token
+ token = data.get("app_access_token", "") or data.get("tenant_access_token", "")
+ expire = data.get("expire", 7200)
+ if token:
+ ttl = max(expire - 300, 60)
+ await set_cached_token(cache_key, token, ttl)
+ self._app_access_token = token
+ return token
async def exchange_code_for_token(self, code: str) -> dict:
app_token = await self.get_app_access_token()
@@ -637,6 +653,154 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo:
)
+class OAuth2AuthProvider(BaseAuthProvider):
+ """Generic OAuth2 provider (RFC 6749 Authorization Code flow).
+
+ Works with any OAuth2-compliant identity provider (Google, Azure AD,
+ Keycloak, Auth0, custom corporate OAuth2 servers, etc.).
+
+ Config keys:
+ client_id, client_secret, authorize_url, token_url, userinfo_url,
+ scope, field_mapping
+ """
+
+ provider_type = "oauth2"
+
+ def __init__(self, provider=None, config=None):
+ super().__init__(provider, config)
+ self.client_id = self.config.get("client_id") or self.config.get("app_id", "")
+ self.client_secret = self.config.get("client_secret") or self.config.get("app_secret", "")
+ self.authorize_url = self.config.get("authorize_url", "")
+ self.token_url = self.config.get("token_url", "")
+ self.userinfo_url = self.config.get("userinfo_url") or self.config.get("user_info_url", "")
+ self.scope = self.config.get("scope", "openid profile email")
+
+ # Derive token_url / userinfo_url from authorize_url if not provided
+ if self.authorize_url and not self.token_url:
+ base = self.authorize_url.rsplit("/", 1)[0]
+ self.token_url = f"{base}/token"
+ if self.authorize_url and not self.userinfo_url:
+ base = self.authorize_url.rsplit("/", 1)[0]
+ self.userinfo_url = f"{base}/userinfo"
+
+ # Configurable field mapping: provider response field -> Clawith field
+ self.field_mapping = self.config.get("field_mapping") or {}
+
+ async def get_authorization_url(self, redirect_uri: str, state: str) -> str:
+ from urllib.parse import quote, urlencode
+ params = urlencode({
+ "response_type": "code",
+ "client_id": self.client_id,
+ "redirect_uri": redirect_uri,
+ "scope": self.scope,
+ "state": state,
+ })
+ return f"{self.authorize_url}?{params}"
+
+ async def exchange_code_for_token(self, code: str, redirect_uri: str = "") -> dict:
+ """Exchange authorization code for access token.
+
+ Uses application/x-www-form-urlencoded per RFC 6749 Section 4.1.3.
+ """
+ async with httpx.AsyncClient() as client:
+ data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "client_id": self.client_id,
+ "client_secret": self.client_secret,
+ }
+ if redirect_uri:
+ data["redirect_uri"] = redirect_uri
+ resp = await client.post(
+ self.token_url,
+ data=data,
+ )
+ if resp.status_code != 200:
+ logger.error(
+ "OAuth2 token exchange failed (HTTP %s): %s",
+ resp.status_code,
+ resp.text[:500],
+ )
+ return {}
+ return resp.json()
+
+ def _resolve_field(self, data: dict, clawith_field: str, default_keys: list[str]) -> str:
+ """Resolve a field using user-configured mapping first, then standard fallbacks."""
+ custom_key = self.field_mapping.get(clawith_field)
+ if custom_key and data.get(custom_key):
+ return str(data[custom_key])
+ for key in default_keys:
+ if data.get(key):
+ return str(data[key])
+ return ""
+
+ async def get_user_info(self, access_token: str) -> ExternalUserInfo:
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ self.userinfo_url,
+ headers={"Authorization": f"Bearer {access_token}"},
+ )
+
+ # Gracefully handle non-200 responses
+ if resp.status_code != 200:
+ logger.error("OAuth2 userinfo returned HTTP %s", resp.status_code)
+ return ExternalUserInfo(
+ provider_type=self.provider_type,
+ provider_user_id="",
+ raw_data={"error": f"userinfo HTTP {resp.status_code}"},
+ )
+
+ try:
+ resp_data = resp.json()
+ except Exception:
+ logger.error("OAuth2 userinfo response is not valid JSON")
+ return ExternalUserInfo(
+ provider_type=self.provider_type,
+ provider_user_id="",
+ raw_data={"error": "userinfo JSON parse error"},
+ )
+
+ # Some providers wrap payload in {"data": {...}}; standard OIDC is flat
+ if "data" in resp_data and isinstance(resp_data["data"], dict):
+ info = resp_data["data"]
+ else:
+ info = resp_data
+
+ user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId"])
+ name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname", "userName"])
+ email = self._resolve_field(info, "email", ["email"])
+ mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"])
+ avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"])
+
+ return ExternalUserInfo(
+ provider_type=self.provider_type,
+ provider_user_id=str(user_id),
+ name=name,
+ email=email,
+ mobile=mobile,
+ avatar_url=avatar,
+ raw_data=info,
+ )
+
+ async def get_user_info_from_token_data(self, token_data: dict) -> ExternalUserInfo:
+ """Fallback: extract user info directly from token response data."""
+ info = token_data
+ user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId", "openid"])
+ name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname"])
+ email = self._resolve_field(info, "email", ["email"])
+ mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"])
+ avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"])
+ return ExternalUserInfo(
+ provider_type=self.provider_type,
+ provider_user_id=str(user_id),
+ name=name,
+ email=email,
+ mobile=mobile,
+ avatar_url=avatar,
+ raw_data=info,
+ )
+
+
class MicrosoftTeamsAuthProvider(BaseAuthProvider):
"""Microsoft Teams OAuth provider implementation."""
@@ -658,5 +822,6 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo:
"feishu": FeishuAuthProvider,
"dingtalk": DingTalkAuthProvider,
"wecom": WeComAuthProvider,
+ "oauth2": OAuth2AuthProvider,
"microsoft_teams": MicrosoftTeamsAuthProvider,
}
diff --git a/backend/app/services/auth_registry.py b/backend/app/services/auth_registry.py
index c3e9b0035..3c655bade 100644
--- a/backend/app/services/auth_registry.py
+++ b/backend/app/services/auth_registry.py
@@ -15,6 +15,7 @@
DingTalkAuthProvider,
FeishuAuthProvider,
MicrosoftTeamsAuthProvider,
+ OAuth2AuthProvider,
WeComAuthProvider,
)
diff --git a/backend/app/services/dingtalk_stream.py b/backend/app/services/dingtalk_stream.py
index 28a8ba8e2..73cf80b20 100644
--- a/backend/app/services/dingtalk_stream.py
+++ b/backend/app/services/dingtalk_stream.py
@@ -5,15 +5,401 @@
"""
import asyncio
+import base64
+import json
import threading
import uuid
-from typing import Dict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+import httpx
from loguru import logger
from sqlalchemy import select
+from app.config import get_settings
from app.database import async_session
from app.models.channel_config import ChannelConfig
+from app.services.dingtalk_token import dingtalk_token_manager
+
+
+# ─── DingTalk Media Helpers ─────────────────────────────
+
+
+async def _get_media_download_url(
+ access_token: str, download_code: str, robot_code: str
+) -> Optional[str]:
+ """Get media file download URL from DingTalk API."""
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ resp = await client.post(
+ "https://api.dingtalk.com/v1.0/robot/messageFiles/download",
+ headers={"x-acs-dingtalk-access-token": access_token},
+ json={"downloadCode": download_code, "robotCode": robot_code},
+ )
+ data = resp.json()
+ url = data.get("downloadUrl")
+ if url:
+ return url
+ logger.error(f"[DingTalk] Failed to get download URL: {data}")
+ return None
+ except Exception as e:
+ logger.error(f"[DingTalk] Error getting download URL: {e}")
+ return None
+
+
+async def _download_file(url: str) -> Optional[bytes]:
+ """Download a file from a URL and return its bytes."""
+ try:
+ async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
+ resp = await client.get(url)
+ resp.raise_for_status()
+ return resp.content
+ except Exception as e:
+ logger.error(f"[DingTalk] Error downloading file: {e}")
+ return None
+
+
+async def download_dingtalk_media(
+ app_key: str, app_secret: str, download_code: str
+) -> Optional[bytes]:
+ """Download a media file from DingTalk using downloadCode.
+
+ Steps: get access_token -> get download URL -> download file bytes.
+ """
+ access_token = await dingtalk_token_manager.get_token(app_key, app_secret)
+ if not access_token:
+ return None
+
+ download_url = await _get_media_download_url(access_token, download_code, app_key)
+ if not download_url:
+ return None
+
+ return await _download_file(download_url)
+
+
+def _resolve_upload_dir(agent_id: uuid.UUID) -> Path:
+ """Get the uploads directory for an agent, creating it if needed."""
+ settings = get_settings()
+ upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads"
+ upload_dir.mkdir(parents=True, exist_ok=True)
+ return upload_dir
+
+
+async def _process_media_message(
+ msg_data: dict,
+ app_key: str,
+ app_secret: str,
+ agent_id: uuid.UUID,
+) -> Tuple[str, Optional[List[str]], Optional[List[str]]]:
+ """Process a DingTalk message and extract text + media info.
+
+ Returns:
+ (user_text, image_base64_list, saved_file_paths)
+ - user_text: text content for the LLM (may include markers)
+ - image_base64_list: list of base64-encoded image data URIs, or None
+ - saved_file_paths: list of saved file paths, or None
+ """
+ msgtype = msg_data.get("msgtype", "text")
+ logger.info(f"[DingTalk] Processing message type: {msgtype}")
+
+ image_base64_list: List[str] = []
+ saved_file_paths: List[str] = []
+
+ if msgtype == "text":
+ text_content = msg_data.get("text", {}).get("content", "").strip()
+ return text_content, None, None
+
+ elif msgtype == "picture":
+ download_code = msg_data.get("content", {}).get("downloadCode", "")
+ if not download_code:
+ download_code = msg_data.get("downloadCode", "")
+ if not download_code:
+ logger.warning("[DingTalk] Picture message without downloadCode")
+ return "[User sent an image, but it could not be downloaded]", None, None
+
+ file_bytes = await download_dingtalk_media(app_key, app_secret, download_code)
+ if not file_bytes:
+ return "[User sent an image, but download failed]", None, None
+
+ upload_dir = _resolve_upload_dir(agent_id)
+ filename = f"dingtalk_img_{uuid.uuid4().hex[:8]}.jpg"
+ save_path = upload_dir / filename
+ save_path.write_bytes(file_bytes)
+ logger.info(f"[DingTalk] Saved image to {save_path} ({len(file_bytes)} bytes)")
+
+ b64_data = base64.b64encode(file_bytes).decode("ascii")
+ image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]"
+ return (
+ f"[User sent an image]\n{image_marker}",
+ [f"data:image/jpeg;base64,{b64_data}"],
+ [str(save_path)],
+ )
+
+ elif msgtype == "richText":
+ rich_text = msg_data.get("content", {}).get("richText", [])
+ text_parts: List[str] = []
+
+ for section in rich_text:
+ for item in section if isinstance(section, list) else [section]:
+ if "text" in item:
+ text_parts.append(item["text"])
+ elif "downloadCode" in item:
+ file_bytes = await download_dingtalk_media(
+ app_key, app_secret, item["downloadCode"]
+ )
+ if file_bytes:
+ upload_dir = _resolve_upload_dir(agent_id)
+ filename = f"dingtalk_richimg_{uuid.uuid4().hex[:8]}.jpg"
+ save_path = upload_dir / filename
+ save_path.write_bytes(file_bytes)
+ logger.info(f"[DingTalk] Saved rich text image to {save_path}")
+
+ b64_data = base64.b64encode(file_bytes).decode("ascii")
+ image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]"
+ text_parts.append(image_marker)
+ image_base64_list.append(f"data:image/jpeg;base64,{b64_data}")
+ saved_file_paths.append(str(save_path))
+
+ combined_text = "\n".join(text_parts).strip()
+ if not combined_text:
+ combined_text = "[User sent a rich text message]"
+
+ return (
+ combined_text,
+ image_base64_list if image_base64_list else None,
+ saved_file_paths if saved_file_paths else None,
+ )
+
+ elif msgtype == "audio":
+ content = msg_data.get("content", {})
+ recognition = content.get("recognition", "")
+ if recognition:
+ logger.info(f"[DingTalk] Audio with recognition: {recognition[:80]}")
+ return f"[Voice message] {recognition}", None, None
+
+ download_code = content.get("downloadCode", "")
+ if download_code:
+ file_bytes = await download_dingtalk_media(app_key, app_secret, download_code)
+ if file_bytes:
+ upload_dir = _resolve_upload_dir(agent_id)
+ duration = content.get("duration", "unknown")
+ filename = f"dingtalk_audio_{uuid.uuid4().hex[:8]}.amr"
+ save_path = upload_dir / filename
+ save_path.write_bytes(file_bytes)
+ logger.info(f"[DingTalk] Saved audio to {save_path} ({len(file_bytes)} bytes)")
+ return (
+ f"[User sent a voice message, duration {duration}ms, saved to {filename}]",
+ None,
+ [str(save_path)],
+ )
+ return "[User sent a voice message, but it could not be processed]", None, None
+
+ elif msgtype == "video":
+ content = msg_data.get("content", {})
+ download_code = content.get("downloadCode", "")
+ if download_code:
+ file_bytes = await download_dingtalk_media(app_key, app_secret, download_code)
+ if file_bytes:
+ upload_dir = _resolve_upload_dir(agent_id)
+ duration = content.get("duration", "unknown")
+ filename = f"dingtalk_video_{uuid.uuid4().hex[:8]}.mp4"
+ save_path = upload_dir / filename
+ save_path.write_bytes(file_bytes)
+ logger.info(f"[DingTalk] Saved video to {save_path} ({len(file_bytes)} bytes)")
+ return (
+ f"[User sent a video, duration {duration}ms, saved to {filename}]",
+ None,
+ [str(save_path)],
+ )
+ return "[User sent a video, but it could not be downloaded]", None, None
+
+ elif msgtype == "file":
+ content = msg_data.get("content", {})
+ download_code = content.get("downloadCode", "")
+ original_filename = content.get("fileName", "unknown_file")
+ if download_code:
+ file_bytes = await download_dingtalk_media(app_key, app_secret, download_code)
+ if file_bytes:
+ upload_dir = _resolve_upload_dir(agent_id)
+ safe_name = f"dingtalk_{uuid.uuid4().hex[:8]}_{original_filename}"
+ save_path = upload_dir / safe_name
+ save_path.write_bytes(file_bytes)
+ logger.info(
+ f"[DingTalk] Saved file '{original_filename}' to {save_path} "
+ f"({len(file_bytes)} bytes)"
+ )
+ return (
+ f"[file:{original_filename}]",
+ None,
+ [str(save_path)],
+ )
+ return f"[User sent file {original_filename}, but it could not be downloaded]", None, None
+
+ else:
+ logger.warning(f"[DingTalk] Unsupported message type: {msgtype}")
+ return f"[User sent a {msgtype} message, which is not yet supported]", None, None
+
+
+# ─── DingTalk Media Upload & Send ───────────────────────
+
+async def _upload_dingtalk_media(
+ app_key: str,
+ app_secret: str,
+ file_path: str,
+ media_type: str = "file",
+) -> Optional[str]:
+ """Upload a media file to DingTalk and return the mediaId.
+
+ Args:
+ app_key: DingTalk app key (robotCode).
+ app_secret: DingTalk app secret.
+ file_path: Local file path to upload.
+ media_type: One of 'image', 'voice', 'video', 'file'.
+
+ Returns:
+ mediaId string on success, None on failure.
+ """
+ access_token = await dingtalk_token_manager.get_token(app_key, app_secret)
+ if not access_token:
+ return None
+
+ file_p = Path(file_path)
+ if not file_p.exists():
+ logger.error(f"[DingTalk] Upload failed: file not found: {file_path}")
+ return None
+
+ try:
+ file_bytes = file_p.read_bytes()
+ async with httpx.AsyncClient(timeout=60) as client:
+ # Use the legacy oapi endpoint which is more reliable and widely supported.
+ upload_url = (
+ f"https://oapi.dingtalk.com/media/upload"
+ f"?access_token={access_token}&type={media_type}"
+ )
+ resp = await client.post(
+ upload_url,
+ files={"media": (file_p.name, file_bytes)},
+ )
+ data = resp.json()
+ # Legacy API returns media_id (snake_case), new API returns mediaId
+ media_id = data.get("media_id") or data.get("mediaId")
+ if media_id and data.get("errcode", 0) == 0:
+ logger.info(
+ f"[DingTalk] Uploaded {media_type} '{file_p.name}' -> mediaId={media_id[:20]}..."
+ )
+ return media_id
+ logger.error(f"[DingTalk] Upload failed: {data}")
+ return None
+ except Exception as e:
+ logger.error(f"[DingTalk] Upload error: {e}")
+ return None
+
+
+async def _send_dingtalk_media_message(
+ app_key: str,
+ app_secret: str,
+ target_id: str,
+ media_id: str,
+ media_type: str,
+ conversation_type: str,
+ filename: Optional[str] = None,
+) -> bool:
+ """Send a media message via DingTalk proactive message API.
+
+ Args:
+ app_key: DingTalk app key (robotCode).
+ app_secret: DingTalk app secret.
+ target_id: For P2P: sender_staff_id; For group: openConversationId.
+ media_id: The mediaId from upload.
+ media_type: One of 'image', 'voice', 'video', 'file'.
+ conversation_type: '1' for P2P, '2' for group.
+ filename: Original filename (used for file/video types).
+
+ Returns:
+ True on success, False on failure.
+ """
+ access_token = await dingtalk_token_manager.get_token(app_key, app_secret)
+ if not access_token:
+ return False
+
+ headers = {"x-acs-dingtalk-access-token": access_token}
+
+ # Build msgKey and msgParam based on media_type
+ if media_type == "image":
+ msg_key = "sampleImageMsg"
+ msg_param = json.dumps({"photoURL": media_id})
+ elif media_type == "voice":
+ msg_key = "sampleAudio"
+ msg_param = json.dumps({"mediaId": media_id, "duration": "3000"})
+ elif media_type == "video":
+ safe_name = filename or "video.mp4"
+ ext = Path(safe_name).suffix.lstrip(".") or "mp4"
+ msg_key = "sampleFile"
+ msg_param = json.dumps({
+ "mediaId": media_id,
+ "fileName": safe_name,
+ "fileType": ext,
+ })
+ else:
+ # file
+ safe_name = filename or "file"
+ ext = Path(safe_name).suffix.lstrip(".") or "bin"
+ msg_key = "sampleFile"
+ msg_param = json.dumps({
+ "mediaId": media_id,
+ "fileName": safe_name,
+ "fileType": ext,
+ })
+
+ try:
+ async with httpx.AsyncClient(timeout=15) as client:
+ if conversation_type == "2":
+ # Group chat
+ resp = await client.post(
+ "https://api.dingtalk.com/v1.0/robot/groupMessages/send",
+ headers=headers,
+ json={
+ "robotCode": app_key,
+ "openConversationId": target_id,
+ "msgKey": msg_key,
+ "msgParam": msg_param,
+ },
+ )
+ else:
+ # P2P chat
+ resp = await client.post(
+ "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend",
+ headers=headers,
+ json={
+ "robotCode": app_key,
+ "userIds": [target_id],
+ "msgKey": msg_key,
+ "msgParam": msg_param,
+ },
+ )
+
+ data = resp.json()
+ if resp.status_code >= 400 or data.get("errcode"):
+ logger.error(f"[DingTalk] Send media failed: {data}")
+ return False
+
+ logger.info(
+ f"[DingTalk] Sent {media_type} message to {target_id[:16]}... "
+ f"(conv_type={conversation_type})"
+ )
+ return True
+ except Exception as e:
+ logger.error(f"[DingTalk] Send media error: {e}")
+ return False
+
+
+# ─── Stream Manager ─────────────────────────────────────
+
+
+def _fire_and_forget(loop, coro):
+ """Schedule a coroutine on the main loop and log any unhandled exception."""
+ future = asyncio.run_coroutine_threadsafe(coro, loop)
+ future.add_done_callback(lambda f: f.exception() if not f.cancelled() else None)
class DingTalkStreamManager:
@@ -67,47 +453,72 @@ def _run_client_thread(
app_secret: str,
stop_event: threading.Event,
):
- """Run the DingTalk Stream client in a blocking thread."""
+ """Run the DingTalk Stream client with auto-reconnect."""
try:
import dingtalk_stream
+ except ImportError:
+ logger.warning(
+ "[DingTalk Stream] dingtalk-stream package not installed. "
+ "Install with: pip install dingtalk-stream"
+ )
+ self._threads.pop(agent_id, None)
+ self._stop_events.pop(agent_id, None)
+ return
- # Reference to manager's main loop for async dispatch
- main_loop = self._main_loop
-
- class ClawithChatbotHandler(dingtalk_stream.ChatbotHandler):
- """Custom handler that dispatches messages to the Clawith LLM pipeline."""
-
- async def process(self, callback: dingtalk_stream.CallbackMessage):
- """Handle incoming bot message from DingTalk Stream.
-
- NOTE: The SDK invokes this method in the thread's own asyncio loop,
- so we must dispatch to the main FastAPI loop for DB + LLM work.
- """
- try:
- # Parse the raw data into a ChatbotMessage via class method
- incoming = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
-
- # Extract text content
+ MAX_RETRIES = 5
+ RETRY_DELAYS = [2, 5, 15, 30, 60] # exponential backoff, seconds
+
+ # Reference to manager's main loop for async dispatch
+ main_loop = self._main_loop
+ retries = 0
+ manager_self = self
+
+ class ClawithChatbotHandler(dingtalk_stream.ChatbotHandler):
+ """Custom handler that dispatches messages to the Clawith LLM pipeline."""
+
+ async def process(self, callback: dingtalk_stream.CallbackMessage):
+ """Handle incoming bot message from DingTalk Stream.
+
+ NOTE: The SDK invokes this method in the thread's own asyncio loop,
+ so we must dispatch to the main FastAPI loop for DB + LLM work.
+ """
+ try:
+ # Parse the raw data
+ incoming = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
+ msg_data = callback.data if isinstance(callback.data, dict) else json.loads(callback.data)
+
+ msgtype = msg_data.get("msgtype", "text")
+ sender_staff_id = incoming.sender_staff_id or incoming.sender_id or ""
+ sender_nick = incoming.sender_nick or ""
+ message_id = incoming.message_id or ""
+ conversation_id = incoming.conversation_id or ""
+ conversation_type = incoming.conversation_type or "1"
+ session_webhook = incoming.session_webhook or ""
+
+ logger.info(
+ f"[DingTalk Stream] Received {msgtype} message from {sender_staff_id}"
+ )
+
+ if msgtype == "text":
+ # Plain text: use existing logic
text_list = incoming.get_text_list()
user_text = " ".join(text_list).strip() if text_list else ""
-
if not user_text:
return dingtalk_stream.AckMessage.STATUS_OK, "empty message"
- sender_staff_id = incoming.sender_staff_id or incoming.sender_id or ""
- conversation_id = incoming.conversation_id or ""
- conversation_type = incoming.conversation_type or "1"
- session_webhook = incoming.session_webhook or ""
-
logger.info(
- f"[DingTalk Stream] Message from [{incoming.sender_nick}]{sender_staff_id}: {user_text[:80]}"
+ f"[DingTalk Stream] Text from {sender_staff_id}: {user_text[:80]}"
)
- # Dispatch to the main FastAPI event loop for DB + LLM processing
from app.api.dingtalk import process_dingtalk_message
if main_loop and main_loop.is_running():
- future = asyncio.run_coroutine_threadsafe(
+ # Add thinking reaction immediately
+ from app.services.dingtalk_reaction import add_thinking_reaction
+ _fire_and_forget(main_loop,
+ add_thinking_reaction(app_key, app_secret, message_id, conversation_id))
+
+ _fire_and_forget(main_loop,
process_dingtalk_message(
agent_id=agent_id,
sender_staff_id=sender_staff_id,
@@ -115,50 +526,134 @@ async def process(self, callback: dingtalk_stream.CallbackMessage):
conversation_id=conversation_id,
conversation_type=conversation_type,
session_webhook=session_webhook,
- ),
- main_loop,
- )
- # Wait for result (with timeout)
- try:
- future.result(timeout=120)
- except Exception as e:
- logger.error(f"[DingTalk Stream] LLM processing error: {e}")
- import traceback
- traceback.print_exc()
+ sender_nick=sender_nick,
+ message_id=message_id,
+ ))
else:
- logger.warning("[DingTalk Stream] Main loop not available for dispatch")
-
- return dingtalk_stream.AckMessage.STATUS_OK, "ok"
- except Exception as e:
- logger.error(f"[DingTalk Stream] Error in message handler: {e}")
- import traceback
- traceback.print_exc()
- return dingtalk_stream.AckMessage.STATUS_SYSTEM_EXCEPTION, str(e)
-
- credential = dingtalk_stream.Credential(client_id=app_key, client_secret=app_secret)
- client = dingtalk_stream.DingTalkStreamClient(credential=credential)
- client.register_callback_handler(
- dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
- ClawithChatbotHandler(),
- )
+ logger.warning("[DingTalk Stream] Main loop not available")
- logger.info(f"[DingTalk Stream] Connecting for agent {agent_id}...")
- # start_forever() blocks until disconnected
- client.start_forever()
+ else:
+ # Non-text message: process media in the main loop
+ if main_loop and main_loop.is_running():
+ # Add thinking reaction immediately
+ from app.services.dingtalk_reaction import add_thinking_reaction
+ _fire_and_forget(main_loop,
+ add_thinking_reaction(app_key, app_secret, message_id, conversation_id))
+
+ _fire_and_forget(main_loop,
+ manager_self._handle_media_and_dispatch(
+ msg_data=msg_data,
+ app_key=app_key,
+ app_secret=app_secret,
+ agent_id=agent_id,
+ sender_staff_id=sender_staff_id,
+ conversation_id=conversation_id,
+ conversation_type=conversation_type,
+ session_webhook=session_webhook,
+ sender_nick=sender_nick,
+ message_id=message_id,
+ ))
+ else:
+ logger.warning("[DingTalk Stream] Main loop not available")
+
+ return dingtalk_stream.AckMessage.STATUS_OK, "ok"
+ except Exception as e:
+ logger.error(f"[DingTalk Stream] Error in message handler: {e}")
+ import traceback
+ traceback.print_exc()
+ return dingtalk_stream.AckMessage.STATUS_SYSTEM_EXCEPTION, str(e)
+
+ while not stop_event.is_set() and retries <= MAX_RETRIES:
+ try:
+ credential = dingtalk_stream.Credential(client_id=app_key, client_secret=app_secret)
+ client = dingtalk_stream.DingTalkStreamClient(credential=credential)
+ client.register_callback_handler(
+ dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
+ ClawithChatbotHandler(),
+ )
- except ImportError:
- logger.warning(
- "[DingTalk Stream] dingtalk-stream package not installed. "
- "Install with: pip install dingtalk-stream"
+ logger.info(
+ f"[DingTalk Stream] Connecting for agent {agent_id}... "
+ f"(attempt {retries + 1}/{MAX_RETRIES + 1})"
+ )
+ # start_forever() blocks until disconnected
+ client.start_forever()
+
+ # start_forever returned: connection dropped
+ if stop_event.is_set():
+ break # intentional stop, no retry
+
+ # Reset retries on successful connection (ran for a while then disconnected)
+ retries = 0
+ retries += 1
+ logger.warning(
+ f"[DingTalk Stream] Connection lost for agent {agent_id}, will retry..."
+ )
+
+ except Exception as e:
+ retries += 1
+ logger.error(
+ f"[DingTalk Stream] Connection error for {agent_id} "
+ f"(attempt {retries}/{MAX_RETRIES + 1}): {e}"
+ )
+
+ if retries > MAX_RETRIES:
+ logger.error(
+ f"[DingTalk Stream] Agent {agent_id} exhausted all {MAX_RETRIES} retries, giving up"
+ )
+ break
+
+ delay = RETRY_DELAYS[min(retries - 1, len(RETRY_DELAYS) - 1)]
+ logger.info(
+ f"[DingTalk Stream] Retrying in {delay}s for agent {agent_id}..."
)
- except Exception as e:
- logger.error(f"[DingTalk Stream] Client error for {agent_id}: {e}")
- import traceback
- traceback.print_exc()
- finally:
- self._threads.pop(agent_id, None)
- self._stop_events.pop(agent_id, None)
- logger.info(f"[DingTalk Stream] Client stopped for agent {agent_id}")
+ # Use stop_event.wait so we exit immediately if stopped
+ if stop_event.wait(timeout=delay):
+ break # stop was requested during wait
+
+ self._threads.pop(agent_id, None)
+ self._stop_events.pop(agent_id, None)
+ logger.info(f"[DingTalk Stream] Client stopped for agent {agent_id}")
+
+ @staticmethod
+ async def _handle_media_and_dispatch(
+ msg_data: dict,
+ app_key: str,
+ app_secret: str,
+ agent_id: uuid.UUID,
+ sender_staff_id: str,
+ conversation_id: str,
+ conversation_type: str,
+ session_webhook: str,
+ sender_nick: str = "",
+ message_id: str = "",
+ ):
+ """Download media, then dispatch to process_dingtalk_message."""
+ from app.api.dingtalk import process_dingtalk_message
+
+ user_text, image_base64_list, saved_file_paths = await _process_media_message(
+ msg_data=msg_data,
+ app_key=app_key,
+ app_secret=app_secret,
+ agent_id=agent_id,
+ )
+
+ if not user_text:
+ logger.info("[DingTalk Stream] Empty content after media processing, skipping")
+ return
+
+ await process_dingtalk_message(
+ agent_id=agent_id,
+ sender_staff_id=sender_staff_id,
+ user_text=user_text,
+ conversation_id=conversation_id,
+ conversation_type=conversation_type,
+ session_webhook=session_webhook,
+ image_base64_list=image_base64_list,
+ saved_file_paths=saved_file_paths,
+ sender_nick=sender_nick,
+ message_id=message_id,
+ )
async def stop_client(self, agent_id: uuid.UUID):
"""Stop a running Stream client for an agent."""
@@ -167,7 +662,10 @@ async def stop_client(self, agent_id: uuid.UUID):
stop_event.set()
thread = self._threads.pop(agent_id, None)
if thread and thread.is_alive():
- logger.info(f"[DingTalk Stream] Stopping client for agent {agent_id}")
+ logger.info(f"[DingTalk Stream] Stopping client for agent {agent_id}, waiting for thread...")
+ thread.join(timeout=5)
+ if thread.is_alive():
+ logger.warning(f"[DingTalk Stream] Thread for {agent_id} did not exit within 5s")
async def start_all(self):
"""Start Stream clients for all configured DingTalk agents."""
diff --git a/backend/app/services/feishu_service.py b/backend/app/services/feishu_service.py
index 629a95e90..33871e4c2 100644
--- a/backend/app/services/feishu_service.py
+++ b/backend/app/services/feishu_service.py
@@ -86,21 +86,39 @@ async def get_app_access_token(self) -> str:
return await self.get_tenant_access_token(self.app_id, self.app_secret)
async def get_tenant_access_token(self, app_id: str = None, app_secret: str = None) -> str:
- """Get or refresh the app-level access token (tenant_access_token)."""
+ """Get or refresh the app-level access token (tenant_access_token).
+
+ Cached in Redis (preferred) with in-memory fallback.
+ Key: clawith:token:feishu_tenant:{app_id}
+ TTL: 6900s (7200s validity - 5 min early refresh)
+ """
+ from app.core.token_cache import get_cached_token, set_cached_token
+
target_app_id = app_id or self.app_id
target_app_secret = app_secret or self.app_secret
-
+ cache_key = f"clawith:token:feishu_tenant:{target_app_id}"
+
+ cached = await get_cached_token(cache_key)
+ if cached:
+ if not app_id:
+ self._app_access_token = cached
+ return cached
+
async with httpx.AsyncClient() as client:
resp = await client.post(FEISHU_APP_TOKEN_URL, json={
"app_id": target_app_id,
"app_secret": target_app_secret,
})
data = resp.json()
-
+
token = data.get("tenant_access_token") or data.get("app_access_token", "")
- if not app_id: # only cache default app token
- self._app_access_token = token
-
+ expire = data.get("expire", 7200)
+ if token:
+ ttl = max(expire - 300, 60)
+ await set_cached_token(cache_key, token, ttl)
+ if not app_id: # only update instance var for default app token
+ self._app_access_token = token
+
return token
async def exchange_code_for_user(self, code: str) -> dict:
diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py
index 8a3525d13..9ebbaa681 100644
--- a/backend/app/services/org_sync_adapter.py
+++ b/backend/app/services/org_sync_adapter.py
@@ -781,12 +781,23 @@ def api_base_url(self) -> str:
return self.DINGTALK_API_URL
async def get_access_token(self) -> str:
- if self._access_token and self._token_expires_at and datetime.now() < self._token_expires_at:
- return self._access_token
+ """Get or refresh the DingTalk access token.
+
+ Cached in Redis (preferred) with in-memory fallback.
+ Key: clawith:token:dingtalk_corp:{app_key}
+ TTL: expires_in - 300s (5 min early refresh)
+ """
+ from app.core.token_cache import get_cached_token, set_cached_token
if not self.app_key or not self.app_secret:
raise ValueError("DingTalk app_key/app_secret missing in provider config")
+ cache_key = f"clawith:token:dingtalk_corp:{self.app_key}"
+ cached = await get_cached_token(cache_key)
+ if cached:
+ self._access_token = cached
+ return cached
+
async with httpx.AsyncClient() as client:
resp = await client.get(
self.DINGTALK_TOKEN_URL,
@@ -797,9 +808,11 @@ async def get_access_token(self) -> str:
raise RuntimeError(f"DingTalk token error: {data.get('errmsg') or data}")
token = data.get("access_token") or ""
expires_in = int(data.get("expires_in") or 7200)
- self._access_token = token
- # refresh a bit earlier
- self._token_expires_at = datetime.now() + timedelta(seconds=max(expires_in - 60, 60))
+ if token:
+ ttl = max(expires_in - 300, 60)
+ await set_cached_token(cache_key, token, ttl)
+ self._access_token = token
+ self._token_expires_at = datetime.now() + timedelta(seconds=max(expires_in - 60, 60))
return token
async def fetch_departments(self) -> list[ExternalDepartment]:
@@ -990,19 +1003,30 @@ def api_base_url(self) -> str:
async def get_access_token(self) -> str:
"""Get valid access token using the 通讯录同步 (contact-sync) secret.
+ Cached in Redis (preferred) with in-memory fallback.
+ Key: clawith:token:wecom:{corp_id}
+ TTL: 6900s (7200s validity - 5 min early refresh)
+
This token can call department/simplelist and user/list_id.
It cannot call user/list or user/get (those raise errcode 48009).
Full user profiles are obtained passively via SSO login instead.
"""
- if self._access_token and self._token_expires_at and datetime.now() < self._token_expires_at:
- return self._access_token
+ from app.core.token_cache import get_cached_token, set_cached_token
if not self.corp_id or not self.secret:
raise ValueError("WeCom corp_id or secret missing in provider config")
+ cache_key = f"clawith:token:wecom:{self.corp_id}"
+ cached = await get_cached_token(cache_key)
+ if cached:
+ self._access_token = cached
+ return cached
+
token = await self._fetch_token(self.corp_id, self.secret)
+ if token:
+ ttl = max(7200 - 300, 300)
+ await set_cached_token(cache_key, token, ttl)
self._access_token = token
- # Refresh slightly before true expiry to avoid clock-skew issues
self._token_expires_at = datetime.now() + timedelta(seconds=7200 - 300)
return token
diff --git a/backend/app/services/platform_service.py b/backend/app/services/platform_service.py
index 7bdc61faf..3f73aa8d7 100644
--- a/backend/app/services/platform_service.py
+++ b/backend/app/services/platform_service.py
@@ -20,23 +20,37 @@ def is_ip_address(self, host: str) -> bool:
async def get_public_base_url(self, db: AsyncSession | None = None, request: Request | None = None) -> str:
"""Resolve the platform's public base URL with priority lookup.
-
+
Priority:
1. Environment variable (PUBLIC_BASE_URL) - from .env or docker
- 2. Incoming request's base URL (browser address)
- 3. Hardcoded fallback (https://try.clawith.ai)
+ 2. Database system_settings (platform.public_base_url)
+ 3. Incoming request's base URL (browser address)
+ 4. Hardcoded fallback (https://try.clawith.ai)
"""
# 1. Try environment variable
env_url = os.environ.get("PUBLIC_BASE_URL")
if env_url:
return env_url.rstrip("/")
- # 2. Fallback to request (browser address)
+ # 2. Try database system_settings
+ if db:
+ try:
+ from app.models.system_settings import SystemSetting
+ result = await db.execute(
+ select(SystemSetting).where(SystemSetting.key == "platform")
+ )
+ setting = result.scalar_one_or_none()
+ if setting and setting.value and setting.value.get("public_base_url"):
+ return setting.value["public_base_url"].rstrip("/")
+ except Exception:
+ pass
+
+ # 3. Fallback to request (browser address)
if request:
# Note: request.base_url might include trailing slash
return str(request.base_url).rstrip("/")
- # 3. Absolute fallback
+ # 4. Absolute fallback
return "https://try.clawith.ai"
diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json
index a5be3fd99..7a417ae2c 100644
--- a/frontend/src/i18n/en.json
+++ b/frontend/src/i18n/en.json
@@ -1359,7 +1359,30 @@
"messagingTitle": "3. Proactive 1-to-1 Messaging (AI-initiated)",
"messagingDesc": "Sending messages to individual users requires a valid WeCom API access token, which can only be obtained from a server IP that has been whitelisted in the self-built app's settings. Unlike Feishu, WeCom mandates IP-level restrictions on all API calls — there is no token-only authentication option.",
"footerText": "Due to the above constraints, WeCom integration currently cannot be easily set up by most users. We are actively exploring alternative approaches — including WeCom ISV (service provider) registration and lower-friction API options — or we may advocate for WeCom to relax these restrictions for SaaS platforms. Configuration will be re-enabled once a viable path is available."
- }
+ },
+ "avatarUrlField": "Avatar URL Field",
+ "avatarUrlFieldPlaceholder": "Default: picture",
+ "authorizeUrlPlaceholder": "https://sso.example.com/oauth2/authorize",
+ "clearAllMappings": "Clear All Field Mappings",
+ "clearField": "Clear this field",
+ "deleteConfirmProvider": "Are you sure you want to delete this configuration?",
+ "emailField": "Email Field",
+ "emailFieldPlaceholder": "Default: email",
+ "fieldMapping": "Field Mapping",
+ "fieldMappingHint": "Optional, leave empty to use standard OIDC fields",
+ "hasCustomMapping": "Custom field mapping configured",
+ "mobileField": "Mobile Field",
+ "mobileFieldPlaceholder": "Default: phone_number",
+ "nameField": "Name Field",
+ "nameFieldPlaceholder": "Default: name",
+ "oauth2": "OAuth2",
+ "oauth2Desc": "Generic OIDC Provider",
+ "scope": "Scope",
+ "scopePlaceholder": "openid profile email",
+ "tokenUrlPlaceholder": "Leave empty to auto-derive from Authorize URL",
+ "userIdField": "User ID Field",
+ "userIdFieldPlaceholder": "Default: sub",
+ "userInfoUrlPlaceholder": "Leave empty to auto-derive from Authorize URL"
},
"dangerZone": "️ Danger Zone",
"deleteCompanyDesc": "Permanently delete this company and all its data, including agents, models, tools, and skills. This action cannot be undone.",
@@ -1724,6 +1747,10 @@
"ssoConfigDesc": "Configure SSO and custom domain for this company.",
"ssoEnabled": "Enable SSO",
"ssoDomain": "Custom Access Domain",
- "ssoDomainPlaceholder": "e.g. acme.clawith.com"
+ "ssoDomainPlaceholder": "e.g. acme.clawith.com",
+ "publicUrl": {
+ "title": "Public URL",
+ "desc": "The external URL used for SSO callbacks, subdomain generation, and email links. Include the protocol (e.g. https://example.com)."
+ }
}
}
diff --git a/frontend/src/i18n/zh.json b/frontend/src/i18n/zh.json
index ee020b0e1..956f784f7 100644
--- a/frontend/src/i18n/zh.json
+++ b/frontend/src/i18n/zh.json
@@ -1525,7 +1525,30 @@
"messagingTitle": "3. AI 主动发送一对一消息",
"messagingDesc": "向企微成员主动发消息,需先获取有效的 API access_token,而获取 token 的请求 IP 必须在自建应用的「企业可信IP」白名单中。与飞书不同,企微对所有 API 调用均强制要求 IP 白名单,没有仅凭 token 认证的选项。",
"footerText": "由于上述限制,目前大多数用户无法轻松完成企业微信集成配置。我们正在积极探索替代方案——包括申请企微服务商(ISV)资质及寻找限制更少的接入方式——同时也希望企微能够降低对 SaaS 平台的接入门槛。待可行路径明确后,配置入口将重新开启。"
- }
+ },
+ "avatarUrlField": "头像地址字段",
+ "avatarUrlFieldPlaceholder": "默认: picture",
+ "authorizeUrlPlaceholder": "https://sso.example.com/oauth2/authorize",
+ "clearAllMappings": "清空所有字段映射",
+ "clearField": "清空此字段",
+ "deleteConfirmProvider": "确定要删除此配置吗?",
+ "emailField": "邮箱字段",
+ "emailFieldPlaceholder": "默认:email",
+ "fieldMapping": "字段映射",
+ "fieldMappingHint": "可选,留空使用标准 OIDC 字段",
+ "hasCustomMapping": "当前已配置自定义字段映射",
+ "mobileField": "手机号字段",
+ "mobileFieldPlaceholder": "默认:phone_number",
+ "nameField": "姓名字段",
+ "nameFieldPlaceholder": "默认:name",
+ "oauth2": "OAuth2",
+ "oauth2Desc": "通用 OIDC 提供商",
+ "scope": "Scope",
+ "scopePlaceholder": "openid profile email",
+ "tokenUrlPlaceholder": "留空则从授权地址自动推导",
+ "userIdField": "用户 ID 字段",
+ "userIdFieldPlaceholder": "默认:sub",
+ "userInfoUrlPlaceholder": "留空则从授权地址自动推导"
}
},
"common": {
@@ -1619,7 +1642,11 @@
"all": "全部",
"filterStatus": "按状态筛选",
"enable": "启用",
- "noFilterResults": "没有符合当前筛选条件的公司。"
+ "noFilterResults": "没有符合当前筛选条件的公司。",
+ "publicUrl": {
+ "title": "公开访问地址",
+ "desc": "用于 SSO 回调、子域名生成和邮件链接的外部 URL。请包含协议(如 https://example.com)。"
+ }
},
"companySetup": {
"title": "设置你的公司",
diff --git a/frontend/src/pages/AdminCompanies.tsx b/frontend/src/pages/AdminCompanies.tsx
index 884b0f62d..bd0b66713 100644
--- a/frontend/src/pages/AdminCompanies.tsx
+++ b/frontend/src/pages/AdminCompanies.tsx
@@ -109,6 +109,10 @@ function PlatformTab() {
const [nbSaving, setNbSaving] = useState(false);
const [nbSaved, setNbSaved] = useState(false);
+ // Public Base URL
+ const [publicBaseUrl, setPublicBaseUrl] = useState('');
+ const [publicBaseUrlSaving, setPublicBaseUrlSaving] = useState(false);
+ const [publicBaseUrlSaved, setPublicBaseUrlSaved] = useState(false);
// System email configuration
const [systemEmailConfig, setSystemEmailConfig] = useState({
@@ -163,6 +167,16 @@ function PlatformTab() {
}
}).catch(() => { });
+ // Load Public Base URL
+ const token2 = localStorage.getItem('token');
+ fetch('/api/enterprise/system-settings/platform', {
+ headers: { 'Content-Type': 'application/json', ...(token2 ? { Authorization: `Bearer ${token2}` } : {}) },
+ }).then(r => r.json()).then(d => {
+ if (d?.value?.public_base_url) {
+ setPublicBaseUrl(d.value.public_base_url);
+ }
+ }).catch(() => { });
+
// Load System Email
fetchJson('/enterprise/system-settings/system_email_platform')
.then(d => {
@@ -219,6 +233,24 @@ function PlatformTab() {
};
+ const savePublicBaseUrl = async () => {
+ setPublicBaseUrlSaving(true);
+ try {
+ const token = localStorage.getItem('token');
+ await fetch('/api/enterprise/system-settings/platform', {
+ method: 'PUT',
+ headers: { 'Content-Type': 'application/json', ...(token ? { Authorization: `Bearer ${token}` } : {}) },
+ body: JSON.stringify({ value: { public_base_url: publicBaseUrl.trim() || null } }),
+ });
+ setPublicBaseUrlSaved(true);
+ setTimeout(() => setPublicBaseUrlSaved(false), 2000);
+ showToast(t('enterprise.config.saved', 'Saved'));
+ } catch (e: any) {
+ showToast(e.message || t('common.error', 'Failed to save'), 'error');
+ }
+ setPublicBaseUrlSaving(false);
+ };
+
const saveEmailConfig = async () => {
setEmailConfigSaving(true);
try {
@@ -334,6 +366,29 @@ function PlatformTab() {
+ {/* Public Base URL */}
+
+
+ {t('admin.publicUrl.title', 'Public URL')}
+
+
+ {t('admin.publicUrl.desc', 'The external URL used for webhook callbacks and published page links. Include the protocol (e.g. https://example.com).')}
+
+ {t('enterprise.identity.fieldMapping', 'Field Mapping')} ({t('enterprise.identity.fieldMappingHint', 'Optional, leave empty to use standard OIDC fields')})
+