Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions backend/alembic/versions/add_subdomain_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Add subdomain_prefix to tenants table."""

from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect

revision: str = "add_subdomain_prefix"
down_revision = "d9cbd43b62e5"
branch_labels = None
depends_on = None

def upgrade() -> None:
conn = op.get_bind()
inspector = inspect(conn)
columns = [c["name"] for c in inspector.get_columns("tenants")]
if "subdomain_prefix" not in columns:
op.add_column("tenants", sa.Column("subdomain_prefix", sa.String(50), nullable=True))
indexes = [i["name"] for i in inspector.get_indexes("tenants")]
if "ix_tenants_subdomain_prefix" not in indexes:
op.create_index("ix_tenants_subdomain_prefix", "tenants", ["subdomain_prefix"], unique=True)

def downgrade() -> None:
op.drop_index("ix_tenants_subdomain_prefix", "tenants")
op.drop_column("tenants", "subdomain_prefix")
27 changes: 27 additions & 0 deletions backend/alembic/versions/add_tenant_is_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Add is_default field to tenants table."""

from alembic import op
import sqlalchemy as sa

revision = "add_tenant_is_default"
down_revision = "add_subdomain_prefix"
branch_labels = None
depends_on = None

def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
cols = [c['name'] for c in inspector.get_columns('tenants')]
if 'is_default' not in cols:
op.add_column('tenants', sa.Column('is_default', sa.Boolean(), nullable=False, server_default='false'))
conn.execute(sa.text("""
UPDATE tenants
SET is_default = true
WHERE id = (
SELECT id FROM tenants WHERE is_active = true ORDER BY created_at ASC LIMIT 1
)
AND NOT EXISTS (SELECT 1 FROM tenants WHERE is_default = true)
"""))

def downgrade() -> None:
op.drop_column('tenants', 'is_default')
34 changes: 26 additions & 8 deletions backend/app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,12 +512,18 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes
user = valid_users[0]
else:
# Specific tenant requested (Dedicated Link flow)
# Search for the user record in that tenant
user = next((u for u in valid_users if u.tenant_id == data.tenant_id), None)

# Search for the user record in that tenant.
# Also allow users with tenant_id=None (SSO-created but not yet
# assigned to a tenant) so they can proceed to company setup.
user = next(
(u for u in valid_users
if u.tenant_id == data.tenant_id or u.tenant_id is None),
None,
)

# Cross-tenant access check
if not user:
# Even platform admins must have a valid record in the targeted tenant
# Even platform admins must have a valid record in the targeted tenant
# when logging in via a dedicated tenant URL / tenant_id.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
Expand Down Expand Up @@ -910,14 +916,26 @@ async def oauth_callback(
raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported")

try:
# Exchange code for token
token_data = await auth_provider.exchange_code_for_token(data.code)
# Exchange code for token (pass redirect_uri for OAuth2 providers that require it)
if hasattr(auth_provider, 'exchange_code_for_token') and data.redirect_uri:
token_data = await auth_provider.exchange_code_for_token(data.code, redirect_uri=data.redirect_uri)
else:
token_data = await auth_provider.exchange_code_for_token(data.code)
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="Failed to get access token from provider")

# Get user info
user_info = await auth_provider.get_user_info(access_token)
# Get user info with fallback to token_data extraction
try:
user_info = await auth_provider.get_user_info(access_token)
except Exception:
if hasattr(auth_provider, 'get_user_info_from_token_data'):
user_info = await auth_provider.get_user_info_from_token_data(token_data)
else:
raise
if not user_info.provider_user_id and hasattr(auth_provider, 'get_user_info_from_token_data'):
# try token_data as last resort
user_info = await auth_provider.get_user_info_from_token_data(token_data)

# Find or create user
user, is_new = await auth_provider.find_or_create_user(db, user_info)
Expand Down
16 changes: 15 additions & 1 deletion backend/app/api/enterprise.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ class OAuth2Config(BaseModel):
token_url: str | None = None # OAuth2 token endpoint
user_info_url: str | None = None # OAuth2 user info endpoint
scope: str | None = "openid profile email"
field_mapping: dict | None = None # Custom field name mapping

def to_config_dict(self) -> dict:
"""Convert to config dict with both naming conventions for compatibility."""
Expand All @@ -811,6 +812,8 @@ def to_config_dict(self) -> dict:
config["user_info_url"] = self.user_info_url
if self.scope:
config["scope"] = self.scope
if self.field_mapping:
config["field_mapping"] = self.field_mapping
return config

@classmethod
Expand All @@ -823,6 +826,7 @@ def from_config_dict(cls, config: dict) -> "OAuth2Config":
token_url=config.get("token_url"),
user_info_url=config.get("user_info_url"),
scope=config.get("scope"),
field_mapping=config.get("field_mapping"),
)


Expand All @@ -837,6 +841,7 @@ class IdentityProviderOAuth2Create(BaseModel):
token_url: str
user_info_url: str
scope: str | None = "openid profile email"
field_mapping: dict | None = None
tenant_id: uuid.UUID | None = None


Expand Down Expand Up @@ -936,6 +941,7 @@ async def create_oauth2_provider(
token_url=data.token_url,
user_info_url=data.user_info_url,
scope=data.scope,
field_mapping=data.field_mapping,
)
config = oauth_config.to_config_dict()

Expand All @@ -962,6 +968,7 @@ async def create_oauth2_provider(
provider_type="oauth2",
name=data.name,
is_active=data.is_active,
sso_login_enabled=True,
config=config,
tenant_id=tid
)
Expand All @@ -981,6 +988,7 @@ class OAuth2ConfigUpdate(BaseModel):
token_url: str | None = None
user_info_url: str | None = None
scope: str | None = None
field_mapping: dict | None = None # Custom field name mapping


@router.patch("/identity-providers/{provider_id}/oauth2", response_model=IdentityProviderOut)
Expand Down Expand Up @@ -1009,7 +1017,7 @@ async def update_oauth2_provider(
provider.is_active = data.is_active

# Update config fields
if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope]):
if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope, data.field_mapping is not None]):
current_config = provider.config.copy()

if data.app_id is not None:
Expand All @@ -1031,6 +1039,12 @@ async def update_oauth2_provider(
current_config["user_info_url"] = data.user_info_url
if data.scope is not None:
current_config["scope"] = data.scope
if data.field_mapping is not None:
# Empty dict or explicit None clears the mapping
if data.field_mapping:
current_config["field_mapping"] = data.field_mapping
else:
current_config.pop("field_mapping", None)

# Validate the updated config
validate_provider_config("oauth2", current_config)
Expand Down
106 changes: 97 additions & 9 deletions backend/app/api/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from urllib.parse import quote

from fastapi import APIRouter, Depends, HTTPException, Request, status
from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

Expand Down Expand Up @@ -104,15 +105,12 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
result = await db.execute(query)
providers = result.scalars().all()

# Determine the base URL for OAuth callbacks using centralized platform service:
from app.services.platform_service import platform_service
if session.tenant_id:
from app.models.tenant import Tenant
tenant_result = await db.execute(select(Tenant).where(Tenant.id == session.tenant_id))
tenant_obj = tenant_result.scalar_one_or_none()
public_base = await platform_service.get_tenant_sso_base_url(db, tenant_obj, request)
else:
public_base = await platform_service.get_public_base_url(db, request)
# Determine the base URL for OAuth callbacks using unified domain resolution:
from app.core.domain import resolve_base_url
public_base = await resolve_base_url(
db, request=request,
tenant_id=str(session.tenant_id) if session.tenant_id else None
)

auth_urls = []
for p in providers:
Expand Down Expand Up @@ -141,5 +139,95 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
url = f"https://open.work.weixin.qq.com/wwopen/sso/qrConnect?appid={corp_id}&agentid={agent_id}&redirect_uri={quote(redir)}&state={sid}"
auth_urls.append({"provider_type": "wecom", "name": p.name, "url": url})

elif p.provider_type == "oauth2":
from app.services.auth_registry import auth_provider_registry
auth_provider = await auth_provider_registry.get_provider(
db, "oauth2", str(session.tenant_id) if session.tenant_id else None
)
if auth_provider:
redir = f"{public_base}/api/auth/oauth2/callback"
url = await auth_provider.get_authorization_url(redir, str(sid))
auth_urls.append({"provider_type": "oauth2", "name": p.name, "url": url})

return auth_urls


@router.get("/auth/oauth2/callback")
async def oauth2_callback(
code: str,
state: str = None,
db: AsyncSession = Depends(get_db),
):
"""Handle OAuth2 SSO callback -- exchange code for user session."""
from app.core.security import create_access_token
from fastapi.responses import HTMLResponse
from app.services.auth_registry import auth_provider_registry

# 1. Resolve tenant context from state (= session ID)
tenant_id = None
sid = None
if state:
try:
sid = uuid.UUID(state)
s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid))
session = s_res.scalar_one_or_none()
if session:
tenant_id = session.tenant_id
except (ValueError, AttributeError):
pass

# 2. Get OAuth2 provider
auth_provider = await auth_provider_registry.get_provider(
db, "oauth2", str(tenant_id) if tenant_id else None
)
if not auth_provider:
return HTMLResponse("Auth failed: OAuth2 provider not configured")

# 3. Exchange code -> token -> user info -> find/create user
try:
token_data = await auth_provider.exchange_code_for_token(code)
access_token = token_data.get("access_token")
if not access_token:
logger.error("OAuth2 token exchange returned no access_token")
return HTMLResponse("Auth failed: token exchange error")

user_info = await auth_provider.get_user_info(access_token)
if not user_info.provider_user_id:
logger.error("OAuth2 user info missing user ID")
return HTMLResponse("Auth failed: no user ID returned")

user, is_new = await auth_provider.find_or_create_user(
db, user_info, tenant_id=str(tenant_id) if tenant_id else None
)
if not user:
return HTMLResponse("Auth failed: user resolution failed")

except Exception as e:
logger.error("OAuth2 login error: %s", e)
return HTMLResponse(f"Auth failed: {e!s}")

# 4. Generate JWT, update SSO session
token = create_access_token(str(user.id), user.role)

if sid:
try:
s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid))
session = s_res.scalar_one_or_none()
if session:
session.status = "authorized"
session.provider_type = "oauth2"
session.user_id = user.id
session.access_token = token
session.error_msg = None
await db.commit()
return HTMLResponse(
'<html><head><meta charset="utf-8" /></head>'
'<body><div>SSO login successful. Redirecting...</div>'
f'<script>window.location.href = "/sso/entry?sid={sid}&complete=1";</script>'
'</body></html>'
)
except Exception as e:
logger.exception("Failed to update SSO session (oauth2): %s", e)

return HTMLResponse("Logged in successfully.")

41 changes: 21 additions & 20 deletions backend/app/api/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,42 +37,44 @@

TEAMS_MSG_LIMIT = 28000 # Teams message char limit (approx 28KB)

# In-memory cache for OAuth tokens
_teams_tokens: dict[str, dict] = {} # agent_id -> {access_token, expires_at}


async def _get_teams_access_token(config: ChannelConfig) -> str | None:
"""Get or refresh Microsoft Teams access token.

Supports:
- Client credentials (app_id + app_secret) - default
- Managed Identity (when use_managed_identity is True in extra_config)
Token is cached in Redis (preferred) with in-memory fallback.
Key: clawith:token:teams:{agent_id}
"""
from app.core.token_cache import get_cached_token, set_cached_token

agent_id = str(config.agent_id)
cached = _teams_tokens.get(agent_id)
if cached and cached["expires_at"] > time.time() + 60: # Refresh 60s before expiry
cache_key = f"clawith:token:teams:{agent_id}"

cached = await get_cached_token(cache_key)
if cached:
logger.debug(f"Teams: Using cached access token for agent {agent_id}")
return cached["access_token"]
return cached

# Check if managed identity should be used
use_managed_identity = config.extra_config.get("use_managed_identity", False)

if use_managed_identity:
# Use Azure Managed Identity
try:
from azure.identity.aio import DefaultAzureCredential
from azure.core.credentials import AccessToken

credential = DefaultAzureCredential()
# For Bot Framework, we need the token for the Bot Framework API
# Managed identity needs to be granted permissions to the Bot Framework API
scope = "https://api.botframework.com/.default"
token: AccessToken = await credential.get_token(scope)

_teams_tokens[agent_id] = {
"access_token": token.token,
"expires_at": token.expires_on,
}

# expires_on is a Unix timestamp; TTL = expires_on - now - 60s buffer
ttl = max(int(token.expires_on - time.time()) - 60, 60)
await set_cached_token(cache_key, token.token, ttl)
logger.info(f"Teams: Successfully obtained access token via managed identity for agent {agent_id}, expires at {token.expires_on}")
await credential.close()
return token.token
Expand All @@ -82,7 +84,7 @@ async def _get_teams_access_token(config: ChannelConfig) -> str | None:
except Exception as e:
logger.exception(f"Teams: Failed to get access token via managed identity for agent {agent_id}: {e}")
return None

# Use client credentials (app_id + app_secret)
app_id = config.app_id
app_secret = config.app_secret
Expand Down Expand Up @@ -117,11 +119,10 @@ async def _get_teams_access_token(config: ChannelConfig) -> str | None:
access_token = token_data["access_token"]
expires_in = token_data["expires_in"]

_teams_tokens[agent_id] = {
"access_token": access_token,
"expires_at": time.time() + expires_in,
}
logger.info(f"Teams: Successfully obtained access token for agent {agent_id}, expires in {expires_in}s")
# TTL = expires_in - 60s buffer
ttl = max(expires_in - 60, 60)
await set_cached_token(cache_key, access_token, ttl)
logger.info(f"Teams: Successfully obtained access token for agent {agent_id}, expires in {expires_in}s, TTL={ttl}s")
return access_token
except httpx.HTTPStatusError as e:
error_body = e.response.text if hasattr(e, 'response') and e.response else "No response body"
Expand Down
Loading