From 2726349fb03228433dd83df204b0d15ef6ad5f2b Mon Sep 17 00:00:00 2001 From: Apoorv Darshan Date: Fri, 15 May 2026 23:53:53 +0530 Subject: [PATCH] fix(backend): tag BYOK LLM errors by source --- backend/routers/pusher.py | 3 +- backend/routers/sync.py | 4 +- backend/tests/unit/test_byok_llm_logging.py | 92 ++++++++++++++++++++ backend/utils/byok.py | 16 ++++ backend/utils/llm/byok_errors.py | 90 +++++++++++++++++++ backend/utils/llm/clients.py | 95 ++++++++++++++++++--- backend/utils/retrieval/agentic.py | 3 +- 7 files changed, 287 insertions(+), 16 deletions(-) create mode 100644 backend/tests/unit/test_byok_llm_logging.py create mode 100644 backend/utils/llm/byok_errors.py diff --git a/backend/routers/pusher.py b/backend/routers/pusher.py index 2e4625076df..f1967ea3fc2 100644 --- a/backend/routers/pusher.py +++ b/backend/routers/pusher.py @@ -25,7 +25,7 @@ trigger_external_integrations, ) from utils.conversations.location import async_get_google_maps_location -from utils.byok import set_byok_keys +from utils.byok import set_byok_keys, set_byok_uid from utils.conversations.process_conversation import process_conversation from utils.executors import storage_executor from utils.webhooks import ( @@ -79,6 +79,7 @@ async def _process_conversation_task( """ if byok_keys: set_byok_keys(byok_keys) + set_byok_uid(uid) try: conversation_data = conversations_db.get_conversation(uid, conversation_id) if not conversation_data: diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 671deba0881..ace4806c7f0 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -49,7 +49,7 @@ ) from utils import encryption -from utils.byok import get_byok_keys, set_byok_keys +from utils.byok import get_byok_keys, set_byok_keys, set_byok_uid from utils.log_sanitizer import sanitize from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words from utils.stt.vad import vad_is_empty @@ -1359,6 +1359,7 @@ def _run_full_pipeline_background( Moved ALL heavy processing here so the v2 endpoint returns 202 immediately. """ set_byok_keys(byok_keys or {}) + set_byok_uid(uid if byok_keys else None) segmented_paths = set() wav_paths = [] stage_timings = {} @@ -1580,6 +1581,7 @@ def _process_one_segment(path): pass finally: set_byok_keys({}) + set_byok_uid(None) _cleanup_files(list(segmented_paths)) _cleanup_files(wav_paths) try: diff --git a/backend/tests/unit/test_byok_llm_logging.py b/backend/tests/unit/test_byok_llm_logging.py new file mode 100644 index 00000000000..d8ca06ab68a --- /dev/null +++ b/backend/tests/unit/test_byok_llm_logging.py @@ -0,0 +1,92 @@ +import os +import sys +import types +from unittest.mock import MagicMock, patch + +os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests') +os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests') +os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') + +sys.modules.setdefault('database._client', MagicMock()) +llm_usage_stub = types.ModuleType('database.llm_usage') +llm_usage_stub.record_llm_usage = MagicMock() +sys.modules.setdefault('database.llm_usage', llm_usage_stub) + + +class _HTTPError(Exception): + def __init__(self, message: str, status_code: int): + super().__init__(message) + self.status_code = status_code + + +def test_classify_byok_llm_error_authentication(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("bad api key", 401)) == 'invalid' + + +def test_classify_byok_llm_error_permission(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("project denied", 403)) == 'permission' + + +def test_classify_byok_llm_error_insufficient_quota(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("insufficient_quota", 429)) == 'quota' + + +def test_classify_byok_llm_error_ignores_transient_rate_limit(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("rate limit reached, retry later", 429)) is None + + +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user') +def test_handle_llm_error_logs_byok_source(mock_get_key, mock_get_uid): + from utils.llm.byok_errors import handle_llm_error + + with patch('utils.llm.byok_errors.logger.error') as mock_log: + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + log_args = mock_log.call_args.args + assert 'LLM error source=%s' in log_args[0] + assert log_args[1] == 'byok' + assert log_args[2] == 'openai' + assert log_args[8] == 'quota' + + +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value=None) +def test_handle_llm_error_logs_platform_source(mock_get_key, mock_get_uid): + from utils.llm.byok_errors import handle_llm_error + + with patch('utils.llm.byok_errors.logger.error') as mock_log: + handle_llm_error(_HTTPError("server error", 500), 'openai', feature='memories', model='gpt-test') + + assert mock_log.call_args.args[1] == 'platform' + assert mock_log.call_args.args[8] == 'unknown' + + +def test_validate_byok_request_records_current_uid(): + from utils.byok import get_byok_uid, validate_byok_request + + with patch('utils.byok._check_byok_validity', return_value=None): + validate_byok_request('user-1') + + assert get_byok_uid() == 'user-1' + + +def test_llm_error_callback_uses_provider_context(): + from utils.llm.clients import _LLMErrorCallback + + callback = _LLMErrorCallback('openai', model='gpt-test', feature='memories') + error = _HTTPError('bad key', 401) + + with patch('utils.llm.clients.handle_llm_error') as mock_handle: + callback.on_llm_error(error) + + mock_handle.assert_called_once() + assert mock_handle.call_args.args[:2] == (error, 'openai') diff --git a/backend/utils/byok.py b/backend/utils/byok.py index 39d355273dd..6c3c9fdfa1d 100644 --- a/backend/utils/byok.py +++ b/backend/utils/byok.py @@ -73,6 +73,7 @@ def invalidate_byok_state_cache(uid: str) -> None: # Keys for the current request, if the client supplied them. # Default is None (not {}) to avoid sharing a mutable object across contexts. _byok_ctx: ContextVar[Optional[Dict[str, str]]] = ContextVar('byok_keys', default=None) +_byok_uid_ctx: ContextVar[Optional[str]] = ContextVar('byok_uid', default=None) def get_byok_keys() -> Dict[str, str]: @@ -87,6 +88,16 @@ def get_byok_key(provider: str) -> Optional[str]: return keys.get(provider) +def get_byok_uid() -> Optional[str]: + """Return the authenticated uid for the current request, when known.""" + return _byok_uid_ctx.get() + + +def set_byok_uid(uid: Optional[str]) -> None: + """Attach the authenticated uid to the current request context.""" + _byok_uid_ctx.set(uid) + + def has_byok_keys() -> bool: """True if the current request carries at least one BYOK header.""" keys = _byok_ctx.get() @@ -127,10 +138,12 @@ async def dispatch(self, request: Request, call_next): if value: keys[provider] = value token = _byok_ctx.set(keys) + uid_token = _byok_uid_ctx.set(None) try: return await call_next(request) finally: _byok_ctx.reset(token) + _byok_uid_ctx.reset(uid_token) # --------------------------------------------------------------------------- @@ -203,6 +216,7 @@ def validate_byok_request(uid: str) -> None: if error: logger.warning('BYOK validation failed uid=%s: %s', uid, error) raise HTTPException(status_code=403, detail=error) + set_byok_uid(uid) def validate_byok_websocket(uid: str) -> Optional[str]: @@ -215,4 +229,6 @@ def validate_byok_websocket(uid: str) -> Optional[str]: error = _check_byok_validity(uid) if error: logger.warning('BYOK WS validation failed uid=%s: %s', uid, error) + else: + set_byok_uid(uid) return error diff --git a/backend/utils/llm/byok_errors.py b/backend/utils/llm/byok_errors.py new file mode 100644 index 00000000000..3d4c6d35a54 --- /dev/null +++ b/backend/utils/llm/byok_errors.py @@ -0,0 +1,90 @@ +import asyncio +import logging +from typing import Optional + +from utils.byok import get_byok_key, get_byok_uid +from utils.executors import storage_executor, submit_with_context +from utils.log_sanitizer import sanitize + +logger = logging.getLogger(__name__) + +_QUOTA_ERROR_NAMES = frozenset({'RateLimitError'}) + + +def get_llm_error_source(provider: Optional[str]) -> str: + """Return platform/byok for the current request and provider.""" + if provider and get_byok_key(provider): + return 'byok' + return 'platform' + + +def classify_byok_llm_error(error: Exception) -> Optional[str]: + """Classify user-actionable BYOK failures for structured logging.""" + status_code = _get_status_code(error) + error_name = type(error).__name__ + error_text = sanitize(str(error)).lower() + + if status_code == 401 or error_name == 'AuthenticationError': + return 'invalid' + if status_code == 403 or error_name == 'PermissionDeniedError': + return 'permission' + if status_code == 429 or error_name in _QUOTA_ERROR_NAMES: + if 'insufficient_quota' in error_text or 'quota' in error_text: + return 'quota' + return None + + +def handle_llm_error( + error: Exception, + provider: Optional[str], + feature: Optional[str] = None, + model: Optional[str] = None, + operation: str = 'chat', +) -> None: + """Log LLM failures with source context.""" + source = get_llm_error_source(provider) + reason = classify_byok_llm_error(error) if source == 'byok' else None + uid = get_byok_uid() + status_code = _get_status_code(error) + + logger.error( + 'LLM error source=%s provider=%s feature=%s model=%s operation=%s uid=%s status_code=%s reason=%s ' + 'error_type=%s error=%s', + source, + provider or 'unknown', + feature or 'unknown', + model or 'unknown', + operation, + uid or 'unknown', + status_code or 'unknown', + reason or 'unknown', + type(error).__name__, + sanitize(str(error)), + ) + + +async def handle_llm_error_async( + error: Exception, + provider: Optional[str], + feature: Optional[str] = None, + model: Optional[str] = None, + operation: str = 'chat', +) -> None: + """Run LLM error handling off the event loop while preserving BYOK context.""" + future = submit_with_context(storage_executor, handle_llm_error, error, provider, feature, model, operation) + try: + await asyncio.wrap_future(future) + except Exception as e: + logger.error('Async LLM error handler failed provider=%s feature=%s: %s', provider, feature, e) + + +def _get_status_code(error: Exception) -> Optional[int]: + status_code = getattr(error, 'status_code', None) + if isinstance(status_code, int): + return status_code + + response = getattr(error, 'response', None) + response_status = getattr(response, 'status_code', None) + if isinstance(response_status, int): + return response_status + return None diff --git a/backend/utils/llm/clients.py b/backend/utils/llm/clients.py index 2d73a028f59..769bbee4c22 100644 --- a/backend/utils/llm/clients.py +++ b/backend/utils/llm/clients.py @@ -6,6 +6,7 @@ import anthropic import httpx from cachetools import TTLCache +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import PydanticOutputParser from langchain_google_genai import ChatGoogleGenerativeAI @@ -14,12 +15,49 @@ from models.structured import Structured from utils.byok import get_byok_key +from utils.llm.byok_errors import handle_llm_error from utils.llm.usage_tracker import get_usage_callback logger = logging.getLogger(__name__) _usage_callback = get_usage_callback() + +class _LLMErrorCallback(BaseCallbackHandler): + """LangChain callback that tags provider errors with platform/BYOK source.""" + + def __init__(self, provider: str, model: str = '', feature: str = ''): + self.provider = provider + self.model = model + self.feature = feature + + def on_llm_error(self, error: BaseException, **kwargs) -> None: + if isinstance(error, Exception): + handle_llm_error(error, self.provider, feature=self.feature, model=self.model) + + +_llm_error_callbacks: Dict[Tuple[str, str, str], _LLMErrorCallback] = {} + + +def _get_llm_error_callback(provider: str, model: str = '', feature: str = '') -> _LLMErrorCallback: + key = (provider, model, feature) + if key not in _llm_error_callbacks: + _llm_error_callbacks[key] = _LLMErrorCallback(provider, model=model, feature=feature) + return _llm_error_callbacks[key] + + +def _with_llm_callbacks(kwargs: Dict[str, Any], provider: str, model: str = '', feature: str = '') -> Dict[str, Any]: + result = dict(kwargs) + callbacks = list(result.get('callbacks') or []) + if _usage_callback not in callbacks: + callbacks.append(_usage_callback) + error_callback = _get_llm_error_callback(provider, model=model, feature=feature) + if error_callback not in callbacks: + callbacks.append(error_callback) + result['callbacks'] = callbacks + return result + + # --------------------------------------------------------------------------- # BYOK (Bring Your Own Key) # @@ -56,6 +94,7 @@ class _OpenAIEmbeddingsProxy: """Transparent proxy for OpenAIEmbeddings that uses BYOK OpenAI when set.""" __slots__ = ('_model', '_default', '_ctor_kwargs') + _METHODS_TO_WRAP = {'embed_documents', 'aembed_documents', 'embed_query', 'aembed_query'} def __init__(self, model: str, default: OpenAIEmbeddings, ctor_kwargs: Dict[str, Any]): object.__setattr__(self, '_model', model) @@ -74,7 +113,28 @@ def _resolve(self) -> OpenAIEmbeddings: return self._default def __getattr__(self, name: str): - return getattr(self._resolve(), name) + attr = getattr(self._resolve(), name) + if name not in self._METHODS_TO_WRAP or not callable(attr): + return attr + if name.startswith('a'): + + async def _wrapped_async(*args, **kwargs): + try: + return await attr(*args, **kwargs) + except Exception as e: + handle_llm_error(e, 'openai', feature='embeddings', model=self._model, operation=name) + raise + + return _wrapped_async + + def _wrapped(*args, **kwargs): + try: + return attr(*args, **kwargs) + except Exception as e: + handle_llm_error(e, 'openai', feature='embeddings', model=self._model, operation=name) + raise + + return _wrapped _BYOK_CACHE_MAX_SIZE = 256 @@ -111,7 +171,10 @@ def _create_byok_client( model: str, provider: str, byok_key: str, streaming: bool = False, feature: str = '' ) -> Optional[ChatOpenAI]: """Create a ChatOpenAI using the user's BYOK key. Returns None if BYOK not supported for this provider.""" - kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1} + callback_provider = _effective_byok_provider(model, provider) + kwargs: Dict[str, Any] = _with_llm_callbacks( + {'request_timeout': 120, 'max_retries': 1}, callback_provider, model=model, feature=feature + ) if model == 'gpt-5.1': kwargs['extra_body'] = {"prompt_cache_retention": "24h"} if streaming: @@ -148,6 +211,7 @@ def get_anthropic_client() -> anthropic.AsyncAnthropic: def get_openai_chat(model: str, **kwargs) -> ChatOpenAI: """Explicit factory; equivalent to using the module-level proxies.""" + kwargs = _with_llm_callbacks(kwargs, 'openai', model=model) byok = get_byok_key('openai') if byok: return _cached_openai_chat(model, byok, kwargs) @@ -417,11 +481,9 @@ def _get_or_create_openai_llm(model_name: str, streaming: bool = False) -> ChatO """Get or create a cached ChatOpenAI for an OpenAI model.""" key = (model_name, streaming, 'openai') if key not in _llm_cache: - kwargs: Dict[str, Any] = { - 'callbacks': [_usage_callback], - 'request_timeout': 120, - 'max_retries': 1, - } + kwargs: Dict[str, Any] = _with_llm_callbacks( + {'request_timeout': 120, 'max_retries': 1}, 'openai', model=model_name + ) if model_name == 'gpt-5.1': kwargs['extra_body'] = {"prompt_cache_retention": "24h"} if streaming: @@ -447,10 +509,10 @@ def _get_or_create_openrouter_llm( 'api_key': os.environ.get('OPENROUTER_API_KEY'), 'base_url': "https://openrouter.ai/api/v1", 'default_headers': {"X-Title": "Omi Chat"}, - 'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1, } + kwargs = _with_llm_callbacks(kwargs, 'openrouter', model=api_model) if temperature is not None: kwargs['temperature'] = temperature if streaming: @@ -478,7 +540,7 @@ def _get_or_create_gemini_llm(model_name: str, streaming: bool = False) -> BaseC use_vertex = os.environ.get('USE_VERTEX_AI', '').lower() == 'true' gcp_project = os.environ.get('GOOGLE_CLOUD_PROJECT', '') if use_vertex else '' gemini_key = os.environ.get('GEMINI_API_KEY', '') - kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'timeout': 120, 'max_retries': 1} + kwargs: Dict[str, Any] = _with_llm_callbacks({'timeout': 120, 'max_retries': 1}, 'gemini', model=model_name) if streaming: kwargs['streaming'] = True @@ -624,7 +686,10 @@ def get_qos_info() -> Dict[str, Dict[str, str]]: # Legacy module-level alias (kept for test compatibility). # Production code should use get_llm(feature) exclusively. # --------------------------------------------------------------------------- -llm_mini = ChatOpenAI(model='gpt-4.1-mini', callbacks=[_usage_callback], request_timeout=120, max_retries=1) +llm_mini = ChatOpenAI( + model='gpt-4.1-mini', + **_with_llm_callbacks({'request_timeout': 120, 'max_retries': 1}, 'openai', model='gpt-4.1-mini'), +) # --------------------------------------------------------------------------- # Embeddings, parser, utilities @@ -667,6 +732,10 @@ def gemini_embed_query(text: str) -> List[float]: 'taskType': 'RETRIEVAL_QUERY', } headers = {'x-goog-api-key': api_key, 'Content-Type': 'application/json'} - resp = httpx.post(url, json=payload, headers=headers, timeout=10) - resp.raise_for_status() - return resp.json()['embedding']['values'] + try: + resp = httpx.post(url, json=payload, headers=headers, timeout=10) + resp.raise_for_status() + return resp.json()['embedding']['values'] + except Exception as e: + handle_llm_error(e, 'gemini', feature='embeddings', model='embedding-001', operation='embed_query') + raise diff --git a/backend/utils/retrieval/agentic.py b/backend/utils/retrieval/agentic.py index 84c3697ffb4..e49a7eecce4 100644 --- a/backend/utils/retrieval/agentic.py +++ b/backend/utils/retrieval/agentic.py @@ -47,6 +47,7 @@ ) from utils.retrieval.tools.app_tools import load_app_tools, get_tool_status_message from utils.retrieval.safety import AgentSafetyGuard, SafetyGuardError +from utils.llm.byok_errors import handle_llm_error_async from utils.llm.clients import anthropic_client, ANTHROPIC_AGENT_MODEL from utils.llm.chat import _get_agentic_qa_prompt from utils.other.endpoints import timeit @@ -420,7 +421,7 @@ async def _run_anthropic_agent_stream( response = await stream.get_final_message() except Exception as e: - logger.error(f"Anthropic API error: {e}") + await handle_llm_error_async(e, 'anthropic', feature='chat_agent', model=ANTHROPIC_AGENT_MODEL) await callback.put_data(f"\n\nSorry, I encountered an error. Please try again.") await callback.end() return