From e46d5e1b1ee43dcd11d7cacdcc3be50700a5260a Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Mon, 18 May 2026 16:13:10 -0700 Subject: [PATCH 1/7] feat: add per-model Vertex AI region support for judge panel litellm.drop_params=True (set by DeepEval) silently strips vertex_project and vertex_location from completion kwargs. This adds a _vertex_override context manager to litellm_patch.py that intercepts these params and temporarily sets them as litellm module-level attributes, which litellm checks as a fallback in its vertex_ai handler. Thread-safe and a no-op when no vertex params are present. Co-Authored-By: Claude Opus 4.6 --- config/system.yaml | 12 ++ .../core/llm/litellm_patch.py | 79 +++++-- tests/unit/core/llm/test_litellm_patch.py | 199 ++++++++++++++++++ 3 files changed, 271 insertions(+), 19 deletions(-) create mode 100644 tests/unit/core/llm/test_litellm_patch.py diff --git a/config/system.yaml b/config/system.yaml index 7187a018..de882cf3 100644 --- a/config/system.yaml +++ b/config/system.yaml @@ -46,6 +46,18 @@ llm_pool: # Add, remove or override model specific parameters temperature: null # Removes temperature from default max_completion_tokens: 2048 # Overrides default + # Vertex AI example with per-model region: + # judge_vertex_gemini: + # provider: vertex_ai + # model: gemini-2.0-flash + # parameters: + # vertex_location: us-central1 # Region for this model + # # vertex_project: my-gcp-project # Optional: override GCP project + # judge_vertex_llama: + # provider: vertex_ai + # model: meta/llama-3.3-70b-instruct-maas + # parameters: + # vertex_location: europe-west1 # Different region for this model # Judge Panel: multiple judges from the pool # Combine their scores. First judge in judges is the fallback when the full panel is not used for a metric. diff --git a/src/lightspeed_evaluation/core/llm/litellm_patch.py b/src/lightspeed_evaluation/core/llm/litellm_patch.py index 9eb13a70..7ecaec85 100644 --- a/src/lightspeed_evaluation/core/llm/litellm_patch.py +++ b/src/lightspeed_evaluation/core/llm/litellm_patch.py @@ -1,6 +1,6 @@ -"""LiteLLM configuration for token tracking and Ragas 0.4 compatibility. +"""LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support. -This module configures litellm for two purposes: +This module configures litellm for three purposes: 1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding, and litellm.aembedding to track token usage for all LLM and embedding calls. @@ -14,14 +14,21 @@ We replace the LoggingWorker with a no-op implementation to avoid this. This is safe because we don't use litellm's built-in observability features. + +3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by + DeepEval) silently strips vertex_project and vertex_location from + completion kwargs. The completion wrappers intercept these params and + temporarily set them as litellm module-level attributes, which litellm + checks as a fallback in its vertex_ai handler. """ import logging import os import threading import warnings +from contextlib import contextmanager from functools import wraps -from typing import Any +from typing import Any, Generator import litellm @@ -89,6 +96,50 @@ def clear_queue(self) -> None: litellm.suppress_debug_info = True +# ============================================================================= +# GLOBAL STATE LOCK +# ============================================================================= +# Single lock for ALL litellm global state mutations (cache, ssl_verify, +# vertex_project, vertex_location). Import this lock in any module that +# reads/writes litellm global state to prevent race conditions between +# concurrent pipelines. +litellm_state_lock = threading.Lock() + + +# ============================================================================= +# VERTEX AI PER-MODEL REGION SUPPORT +# ============================================================================= +# litellm.drop_params=True (set by DeepEval) silently strips vertex_project +# and vertex_location from completion kwargs. We intercept these params and +# temporarily set them as litellm module-level attributes, which litellm +# checks as a fallback in its vertex_ai handler. + + +@contextmanager +def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]: + """Pop vertex_project/vertex_location from kwargs and set as litellm module attrs. + + When neither key is present the context manager is a no-op (no lock acquired). + """ + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + yield + return + with litellm_state_lock: + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + try: + if vp is not None: + litellm.vertex_project = vp + if vl is not None: + litellm.vertex_location = vl + yield + finally: + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + + # ============================================================================= # TOKEN TRACKING: Wrap completion and embedding functions # ============================================================================= @@ -101,11 +152,11 @@ def clear_queue(self) -> None: _original_aembedding = litellm.aembedding -# Patch litellm's completion functions to include token tracking @wraps(_original_completion) def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: - """Wrapper around litellm.completion that tracks tokens.""" - response = _original_completion(*args, **kwargs) + """Wrapper around litellm.completion that tracks tokens and handles Vertex params.""" + with _vertex_override(kwargs): + response = _original_completion(*args, **kwargs) try: track_judge_tokens(response) except Exception as e: # pylint: disable=broad-exception-caught @@ -115,8 +166,9 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: @wraps(_original_acompletion) async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: - """Wrapper around litellm.acompletion that tracks tokens.""" - response = await _original_acompletion(*args, **kwargs) + """Wrapper around litellm.acompletion that tracks tokens and handles Vertex params.""" + with _vertex_override(kwargs): + response = await _original_acompletion(*args, **kwargs) try: track_judge_tokens(response) except Exception as e: # pylint: disable=broad-exception-caught @@ -124,7 +176,6 @@ async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: return response -# Patch litellm's embedding functions to include token tracking @wraps(_original_embedding) def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any: """Wrapper around litellm.embedding that tracks tokens.""" @@ -147,22 +198,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any: return response -# Patch litellm's completion and embedding functions to include token tracking litellm.completion = _completion_with_token_tracking litellm.acompletion = _acompletion_with_token_tracking litellm.embedding = _embedding_with_token_tracking litellm.aembedding = _aembedding_with_token_tracking -# ============================================================================= -# GLOBAL STATE LOCK -# ============================================================================= -# Single lock for ALL litellm global state mutations (cache, ssl_verify). -# Import this lock in any module that reads/writes litellm.cache or -# litellm.ssl_verify to prevent race conditions between concurrent pipelines. -litellm_state_lock = threading.Lock() - - # ============================================================================= # SSL CONFIGURATION UTILITY # ============================================================================= diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py new file mode 100644 index 00000000..29e8b747 --- /dev/null +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -0,0 +1,199 @@ +"""Unit tests for litellm_patch vertex override support.""" + +from typing import Any, Callable + +import pytest +from pytest_mock import MockerFixture +import litellm + +from lightspeed_evaluation.core.llm import litellm_patch +from lightspeed_evaluation.core.llm.litellm_patch import _vertex_override + + +class TestVertexOverrideContextManager: + """Tests for the _vertex_override context manager.""" + + def test_no_vertex_params_is_noop(self) -> None: + """Test that _vertex_override is a no-op when no vertex params present.""" + kwargs: dict[str, Any] = {"model": "gpt-4", "temperature": 0.5} + original_kwargs = dict(kwargs) + + with _vertex_override(kwargs): + pass + + assert kwargs == original_kwargs + + def test_vertex_location_set_and_restored(self) -> None: + """Test that vertex_location is set on litellm module and restored after.""" + old_value = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + with _vertex_override(kwargs): + assert litellm.vertex_location == "us-central1" + assert "vertex_location" not in kwargs + + assert getattr(litellm, "vertex_location", None) == old_value + + def test_vertex_project_set_and_restored(self) -> None: + """Test that vertex_project is set on litellm module and restored after.""" + old_value = getattr(litellm, "vertex_project", None) + kwargs: dict[str, Any] = {"vertex_project": "my-project"} + + with _vertex_override(kwargs): + assert litellm.vertex_project == "my-project" + assert "vertex_project" not in kwargs + + assert getattr(litellm, "vertex_project", None) == old_value + + def test_both_params_set_and_restored(self) -> None: + """Test that both vertex params are set and restored.""" + old_location = getattr(litellm, "vertex_location", None) + old_project = getattr(litellm, "vertex_project", None) + kwargs: dict[str, Any] = { + "vertex_location": "europe-west1", + "vertex_project": "my-project", + "temperature": 0.5, + } + + with _vertex_override(kwargs): + assert litellm.vertex_location == "europe-west1" + assert litellm.vertex_project == "my-project" + assert "vertex_location" not in kwargs + assert "vertex_project" not in kwargs + assert kwargs == {"temperature": 0.5} + + assert getattr(litellm, "vertex_location", None) == old_location + assert getattr(litellm, "vertex_project", None) == old_project + + def test_params_restored_on_exception(self) -> None: + """Test that vertex params are restored even when an exception occurs.""" + old_location = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-east1"} + + with pytest.raises(ValueError, match="test error"): + with _vertex_override(kwargs): + assert litellm.vertex_location == "us-east1" + raise ValueError("test error") + + assert getattr(litellm, "vertex_location", None) == old_location + + def test_no_lock_acquired_without_vertex_params( + self, mocker: MockerFixture + ) -> None: + """Test that the lock is not acquired when no vertex params are present.""" + mock_lock = mocker.patch.object(litellm_patch, "litellm_state_lock") + kwargs: dict[str, Any] = {"temperature": 0.5} + + with _vertex_override(kwargs): + pass + + mock_lock.__enter__.assert_not_called() + + def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None: + """Test that the lock is acquired when vertex params are present.""" + mock_lock = mocker.MagicMock() + mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + with _vertex_override(kwargs): + pass + + mock_lock.__enter__.assert_called_once() + + +class TestCompletionWithVertexOverride: + """Test litellm.completion integration with vertex override.""" + + def test_completion_with_vertex_location( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that vertex_location is handled during completion calls.""" + mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion") + mock_completion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + old_location = getattr(litellm, "vertex_location", None) + + litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_location="us-central1", + ) + + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args[1] + assert "vertex_location" not in call_kwargs + assert getattr(litellm, "vertex_location", None) == old_location + + def test_completion_without_vertex_params_unchanged( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that completion works normally without vertex params.""" + mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion") + mock_completion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + litellm.completion( + model="gpt-4", + messages=[{"role": "user", "content": "test"}], + temperature=0.5, + ) + + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["temperature"] == 0.5 + assert call_kwargs["model"] == "gpt-4" + + @pytest.mark.asyncio + async def test_acompletion_with_vertex_location( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that vertex_location is handled during async completion calls.""" + mock_acompletion = mocker.patch( + f"{litellm_patch.__name__}._original_acompletion" + ) + mock_acompletion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + old_location = getattr(litellm, "vertex_location", None) + + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_location="europe-west1", + ) + + mock_acompletion.assert_called_once() + call_kwargs = mock_acompletion.call_args[1] + assert "vertex_location" not in call_kwargs + assert getattr(litellm, "vertex_location", None) == old_location + + @pytest.mark.asyncio + async def test_acompletion_with_both_vertex_params( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that both vertex params are handled during async completion.""" + mock_acompletion = mocker.patch( + f"{litellm_patch.__name__}._original_acompletion" + ) + mock_acompletion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + old_location = getattr(litellm, "vertex_location", None) + old_project = getattr(litellm, "vertex_project", None) + + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_location="us-central1", + vertex_project="my-project", + ) + + call_kwargs = mock_acompletion.call_args[1] + assert "vertex_location" not in call_kwargs + assert "vertex_project" not in call_kwargs + assert getattr(litellm, "vertex_location", None) == old_location + assert getattr(litellm, "vertex_project", None) == old_project From 2369a0b0f2069fc5b1d02e67c5a2d8f307a74dfd Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Mon, 18 May 2026 18:53:14 -0700 Subject: [PATCH 2/7] fix: close race condition and event loop blocking in _vertex_override The sync context manager now always acquires litellm_state_lock, even when no vertex params are present so that non-vertex requests cannot read litellm.vertex_project/vertex_location while another thread is mutating them. An async variant (_vertex_override_async) backed by a new asyncio.Lock replaces the threading.Lock usage in the acompletion path, preventing the event loop from blocking during concurrent calls. --- .../core/llm/litellm_patch.py | 63 ++++++++--- tests/unit/core/llm/test_litellm_patch.py | 106 +++++++++++++++++- 2 files changed, 149 insertions(+), 20 deletions(-) diff --git a/src/lightspeed_evaluation/core/llm/litellm_patch.py b/src/lightspeed_evaluation/core/llm/litellm_patch.py index 7ecaec85..621843d1 100644 --- a/src/lightspeed_evaluation/core/llm/litellm_patch.py +++ b/src/lightspeed_evaluation/core/llm/litellm_patch.py @@ -22,13 +22,14 @@ checks as a fallback in its vertex_ai handler. """ +import asyncio import logging import os import threading import warnings -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from functools import wraps -from typing import Any, Generator +from typing import Any, AsyncGenerator, Generator import litellm @@ -97,13 +98,16 @@ def clear_queue(self) -> None: # ============================================================================= -# GLOBAL STATE LOCK +# GLOBAL STATE LOCKS # ============================================================================= -# Single lock for ALL litellm global state mutations (cache, ssl_verify, -# vertex_project, vertex_location). Import this lock in any module that -# reads/writes litellm global state to prevent race conditions between -# concurrent pipelines. +# Locks for ALL litellm global state mutations (cache, ssl_verify, +# vertex_project, vertex_location). Import the appropriate lock in any +# module that reads/writes litellm global state to prevent race conditions +# between concurrent pipelines. +# - litellm_state_lock: for synchronous code paths (threading.Lock) +# - litellm_state_async_lock: for asynchronous code paths (asyncio.Lock) litellm_state_lock = threading.Lock() +litellm_state_async_lock = asyncio.Lock() # ============================================================================= @@ -119,14 +123,45 @@ def clear_queue(self) -> None: def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]: """Pop vertex_project/vertex_location from kwargs and set as litellm module attrs. - When neither key is present the context manager is a no-op (no lock acquired). + Always acquires litellm_state_lock to prevent concurrent reads of partially + updated globals, even when no vertex params are present in kwargs. """ - vp = kwargs.pop("vertex_project", None) - vl = kwargs.pop("vertex_location", None) - if vp is None and vl is None: - yield - return with litellm_state_lock: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + yield + return + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + try: + if vp is not None: + litellm.vertex_project = vp + if vl is not None: + litellm.vertex_location = vl + yield + finally: + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + + +@asynccontextmanager +async def _vertex_override_async( + kwargs: dict[str, Any], +) -> AsyncGenerator[None, None]: + """Async version of _vertex_override using asyncio.Lock. + + Uses litellm_state_async_lock instead of threading.Lock to avoid blocking + the event loop. The lock is held across the yield (including any awaited + completion call) to ensure globals remain consistent for the duration of + the request. + """ + async with litellm_state_async_lock: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + yield + return old_vp = getattr(litellm, "vertex_project", None) old_vl = getattr(litellm, "vertex_location", None) try: @@ -167,7 +202,7 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: @wraps(_original_acompletion) async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: """Wrapper around litellm.acompletion that tracks tokens and handles Vertex params.""" - with _vertex_override(kwargs): + async with _vertex_override_async(kwargs): response = await _original_acompletion(*args, **kwargs) try: track_judge_tokens(response) diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py index 29e8b747..d5591f68 100644 --- a/tests/unit/core/llm/test_litellm_patch.py +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -7,7 +7,10 @@ import litellm from lightspeed_evaluation.core.llm import litellm_patch -from lightspeed_evaluation.core.llm.litellm_patch import _vertex_override +from lightspeed_evaluation.core.llm.litellm_patch import ( + _vertex_override, + _vertex_override_async, +) class TestVertexOverrideContextManager: @@ -77,17 +80,15 @@ def test_params_restored_on_exception(self) -> None: assert getattr(litellm, "vertex_location", None) == old_location - def test_no_lock_acquired_without_vertex_params( - self, mocker: MockerFixture - ) -> None: - """Test that the lock is not acquired when no vertex params are present.""" + def test_lock_acquired_without_vertex_params(self, mocker: MockerFixture) -> None: + """Test that the lock is acquired even when no vertex params are present.""" mock_lock = mocker.patch.object(litellm_patch, "litellm_state_lock") kwargs: dict[str, Any] = {"temperature": 0.5} with _vertex_override(kwargs): pass - mock_lock.__enter__.assert_not_called() + mock_lock.__enter__.assert_called_once() def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None: """Test that the lock is acquired when vertex params are present.""" @@ -101,6 +102,99 @@ def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None: mock_lock.__enter__.assert_called_once() +class TestVertexOverrideAsyncContextManager: + """Tests for the _vertex_override_async async context manager.""" + + @pytest.mark.asyncio + async def test_no_vertex_params_is_noop(self) -> None: + """Test that _vertex_override_async is a no-op when no vertex params present.""" + kwargs: dict[str, Any] = {"model": "gpt-4", "temperature": 0.5} + original_kwargs = dict(kwargs) + + async with _vertex_override_async(kwargs): + pass + + assert kwargs == original_kwargs + + @pytest.mark.asyncio + async def test_vertex_location_set_and_restored(self) -> None: + """Test that vertex_location is set on litellm module and restored after.""" + old_value = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + async with _vertex_override_async(kwargs): + assert litellm.vertex_location == "us-central1" + assert "vertex_location" not in kwargs + + assert getattr(litellm, "vertex_location", None) == old_value + + @pytest.mark.asyncio + async def test_both_params_set_and_restored(self) -> None: + """Test that both vertex params are set and restored.""" + old_location = getattr(litellm, "vertex_location", None) + old_project = getattr(litellm, "vertex_project", None) + kwargs: dict[str, Any] = { + "vertex_location": "europe-west1", + "vertex_project": "my-project", + "temperature": 0.5, + } + + async with _vertex_override_async(kwargs): + assert litellm.vertex_location == "europe-west1" + assert litellm.vertex_project == "my-project" + assert "vertex_location" not in kwargs + assert "vertex_project" not in kwargs + assert kwargs == {"temperature": 0.5} + + assert getattr(litellm, "vertex_location", None) == old_location + assert getattr(litellm, "vertex_project", None) == old_project + + @pytest.mark.asyncio + async def test_params_restored_on_exception(self) -> None: + """Test that vertex params are restored even when an exception occurs.""" + old_location = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-east1"} + + with pytest.raises(ValueError, match="test error"): + async with _vertex_override_async(kwargs): + assert litellm.vertex_location == "us-east1" + raise ValueError("test error") + + assert getattr(litellm, "vertex_location", None) == old_location + + @pytest.mark.asyncio + async def test_async_lock_acquired_with_vertex_params( + self, mocker: MockerFixture + ) -> None: + """Test that the async lock is acquired when vertex params are present.""" + mock_lock = mocker.MagicMock() + mock_lock.__aenter__ = mocker.AsyncMock(return_value=None) + mock_lock.__aexit__ = mocker.AsyncMock(return_value=False) + mocker.patch.object(litellm_patch, "litellm_state_async_lock", mock_lock) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + async with _vertex_override_async(kwargs): + pass + + mock_lock.__aenter__.assert_called_once() + + @pytest.mark.asyncio + async def test_async_lock_acquired_without_vertex_params( + self, mocker: MockerFixture + ) -> None: + """Test that the async lock is acquired even without vertex params.""" + mock_lock = mocker.MagicMock() + mock_lock.__aenter__ = mocker.AsyncMock(return_value=None) + mock_lock.__aexit__ = mocker.AsyncMock(return_value=False) + mocker.patch.object(litellm_patch, "litellm_state_async_lock", mock_lock) + kwargs: dict[str, Any] = {"temperature": 0.5} + + async with _vertex_override_async(kwargs): + pass + + mock_lock.__aenter__.assert_called_once() + + class TestCompletionWithVertexOverride: """Test litellm.completion integration with vertex override.""" From 0a37727b85d8148e2b66e2f48235f7fad0875ad4 Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Mon, 18 May 2026 20:55:24 -0700 Subject: [PATCH 3/7] Consolidate vertex override locks into single threading.Lock Replace the dual-lock scheme (threading.Lock + asyncio.Lock) with a single threading.Lock shared by both sync and async code paths. _vertex_override_async now uses asyncio.to_thread to acquire the lock in a thread-pool worker, avoiding event-loop blocking without needing a separate asyncio.Lock. --- .../core/llm/litellm_patch.py | 61 +++++---- tests/unit/core/llm/test_litellm_patch.py | 118 ++++++++++++++++-- 2 files changed, 141 insertions(+), 38 deletions(-) diff --git a/src/lightspeed_evaluation/core/llm/litellm_patch.py b/src/lightspeed_evaluation/core/llm/litellm_patch.py index 621843d1..d42b85bf 100644 --- a/src/lightspeed_evaluation/core/llm/litellm_patch.py +++ b/src/lightspeed_evaluation/core/llm/litellm_patch.py @@ -98,16 +98,14 @@ def clear_queue(self) -> None: # ============================================================================= -# GLOBAL STATE LOCKS +# GLOBAL STATE LOCK # ============================================================================= -# Locks for ALL litellm global state mutations (cache, ssl_verify, -# vertex_project, vertex_location). Import the appropriate lock in any -# module that reads/writes litellm global state to prevent race conditions -# between concurrent pipelines. -# - litellm_state_lock: for synchronous code paths (threading.Lock) -# - litellm_state_async_lock: for asynchronous code paths (asyncio.Lock) +# Single lock for ALL litellm global state mutations (cache, ssl_verify, +# vertex_project, vertex_location). Import this lock in any module that +# reads/writes litellm global state to prevent race conditions between +# concurrent pipelines. Both sync and async code paths share this lock; +# async callers use asyncio.to_thread so the event loop is never blocked. litellm_state_lock = threading.Lock() -litellm_state_async_lock = asyncio.Lock() # ============================================================================= @@ -149,30 +147,41 @@ def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]: async def _vertex_override_async( kwargs: dict[str, Any], ) -> AsyncGenerator[None, None]: - """Async version of _vertex_override using asyncio.Lock. + """Async version of _vertex_override using asyncio.to_thread. - Uses litellm_state_async_lock instead of threading.Lock to avoid blocking - the event loop. The lock is held across the yield (including any awaited - completion call) to ensure globals remain consistent for the duration of - the request. + Runs lock acquisition and litellm global-state mutation in a thread-pool + worker via asyncio.to_thread so the event loop is never blocked. Uses the + same litellm_state_lock as the synchronous path to prevent races between + sync and async callers. """ - async with litellm_state_async_lock: - vp = kwargs.pop("vertex_project", None) - vl = kwargs.pop("vertex_location", None) - if vp is None and vl is None: - yield - return - old_vp = getattr(litellm, "vertex_project", None) - old_vl = getattr(litellm, "vertex_location", None) - try: + + def _apply() -> tuple[Any, Any] | None: + with litellm_state_lock: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + return None + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) if vp is not None: litellm.vertex_project = vp if vl is not None: litellm.vertex_location = vl - yield - finally: - litellm.vertex_project = old_vp - litellm.vertex_location = old_vl + return (old_vp, old_vl) + + def _restore(old: tuple[Any, Any]) -> None: + with litellm_state_lock: + litellm.vertex_project = old[0] + litellm.vertex_location = old[1] + + old = await asyncio.to_thread(_apply) + if old is None: + yield + return + try: + yield + finally: + await asyncio.to_thread(_restore, old) # ============================================================================= diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py index d5591f68..640ef8f4 100644 --- a/tests/unit/core/llm/test_litellm_patch.py +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -1,5 +1,7 @@ """Unit tests for litellm_patch vertex override support.""" +import asyncio +import threading from typing import Any, Callable import pytest @@ -163,36 +165,32 @@ async def test_params_restored_on_exception(self) -> None: assert getattr(litellm, "vertex_location", None) == old_location @pytest.mark.asyncio - async def test_async_lock_acquired_with_vertex_params( + async def test_threading_lock_used_with_vertex_params( self, mocker: MockerFixture ) -> None: - """Test that the async lock is acquired when vertex params are present.""" + """Test that litellm_state_lock (threading.Lock) is acquired for vertex params.""" mock_lock = mocker.MagicMock() - mock_lock.__aenter__ = mocker.AsyncMock(return_value=None) - mock_lock.__aexit__ = mocker.AsyncMock(return_value=False) - mocker.patch.object(litellm_patch, "litellm_state_async_lock", mock_lock) + mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) kwargs: dict[str, Any] = {"vertex_location": "us-central1"} async with _vertex_override_async(kwargs): pass - mock_lock.__aenter__.assert_called_once() + assert mock_lock.__enter__.call_count >= 1 @pytest.mark.asyncio - async def test_async_lock_acquired_without_vertex_params( + async def test_threading_lock_used_without_vertex_params( self, mocker: MockerFixture ) -> None: - """Test that the async lock is acquired even without vertex params.""" + """Test that litellm_state_lock is acquired even without vertex params.""" mock_lock = mocker.MagicMock() - mock_lock.__aenter__ = mocker.AsyncMock(return_value=None) - mock_lock.__aexit__ = mocker.AsyncMock(return_value=False) - mocker.patch.object(litellm_patch, "litellm_state_async_lock", mock_lock) + mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) kwargs: dict[str, Any] = {"temperature": 0.5} async with _vertex_override_async(kwargs): pass - mock_lock.__aenter__.assert_called_once() + mock_lock.__enter__.assert_called_once() class TestCompletionWithVertexOverride: @@ -291,3 +289,99 @@ async def test_acompletion_with_both_vertex_params( assert "vertex_project" not in call_kwargs assert getattr(litellm, "vertex_location", None) == old_location assert getattr(litellm, "vertex_project", None) == old_project + + +class TestInterleavedSyncAsyncCompletions: + """Test that sync and async completions sharing one lock don't race.""" + + @pytest.mark.asyncio + async def test_interleaved_sync_async_no_deadlock( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Interleaved sync/async vertex completions complete without deadlock.""" + response = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + mocker.patch.object( + litellm_patch, "_original_completion", return_value=response + ) + mocker.patch.object( + litellm_patch, "_original_acompletion", return_value=response + ) + + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + completed: list[str] = [] + + def run_sync(project: str, location: str) -> None: + litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project=project, + vertex_location=location, + ) + completed.append(f"sync-{project}") + + async def run_async(project: str, location: str) -> None: + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project=project, + vertex_location=location, + ) + completed.append(f"async-{project}") + + try: + loop = asyncio.get_running_loop() + await asyncio.gather( + loop.run_in_executor(None, run_sync, "proj-s1", "loc-s1"), + loop.run_in_executor(None, run_sync, "proj-s2", "loc-s2"), + run_async("proj-a1", "loc-a1"), + run_async("proj-a2", "loc-a2"), + ) + + assert len(completed) == 4 + finally: + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + + @pytest.mark.asyncio + async def test_sequential_sync_async_restores_globals( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Sequential sync then async vertex call restores globals correctly.""" + response = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + mocker.patch.object( + litellm_patch, "_original_completion", return_value=response + ) + mocker.patch.object( + litellm_patch, "_original_acompletion", return_value=response + ) + + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + + litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project="proj-sync", + vertex_location="loc-sync", + ) + assert getattr(litellm, "vertex_project", None) == old_vp + assert getattr(litellm, "vertex_location", None) == old_vl + + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project="proj-async", + vertex_location="loc-async", + ) + assert getattr(litellm, "vertex_project", None) == old_vp + assert getattr(litellm, "vertex_location", None) == old_vl + + def test_sync_and_async_share_same_lock(self) -> None: + """Both _vertex_override and _vertex_override_async use litellm_state_lock.""" + assert isinstance(litellm_patch.litellm_state_lock, type(threading.Lock())) + assert not hasattr(litellm_patch, "litellm_state_async_lock") From b59d0a38b4462c52ae67c89755416458cd92beac Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Mon, 18 May 2026 21:41:17 -0700 Subject: [PATCH 4/7] fix: hold lock across yield in _vertex_override_async to close race window The async context manager acquired and released litellm_state_lock inside separate _apply/_restore helpers, leaving the lock unheld during the yield. A concurrent caller could overwrite vertex_project/ vertex_location globals in that window. Now acquire before mutating and release in finally, matching the sync version's lock lifetime. --- .../core/llm/litellm_patch.py | 46 ++++++++----------- tests/unit/core/llm/test_litellm_patch.py | 13 +++--- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/src/lightspeed_evaluation/core/llm/litellm_patch.py b/src/lightspeed_evaluation/core/llm/litellm_patch.py index d42b85bf..8c18f4a2 100644 --- a/src/lightspeed_evaluation/core/llm/litellm_patch.py +++ b/src/lightspeed_evaluation/core/llm/litellm_patch.py @@ -149,39 +149,31 @@ async def _vertex_override_async( ) -> AsyncGenerator[None, None]: """Async version of _vertex_override using asyncio.to_thread. - Runs lock acquisition and litellm global-state mutation in a thread-pool - worker via asyncio.to_thread so the event loop is never blocked. Uses the - same litellm_state_lock as the synchronous path to prevent races between - sync and async callers. + Acquires litellm_state_lock before mutating globals and holds it across the + yield so no concurrent caller can see partially-updated state. Lock + acquire/release use asyncio.to_thread to avoid blocking the event loop. + Uses the same lock as the synchronous path to prevent races between sync + and async callers. """ - - def _apply() -> tuple[Any, Any] | None: - with litellm_state_lock: - vp = kwargs.pop("vertex_project", None) - vl = kwargs.pop("vertex_location", None) - if vp is None and vl is None: - return None - old_vp = getattr(litellm, "vertex_project", None) - old_vl = getattr(litellm, "vertex_location", None) - if vp is not None: - litellm.vertex_project = vp - if vl is not None: - litellm.vertex_location = vl - return (old_vp, old_vl) - - def _restore(old: tuple[Any, Any]) -> None: - with litellm_state_lock: - litellm.vertex_project = old[0] - litellm.vertex_location = old[1] - - old = await asyncio.to_thread(_apply) - if old is None: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: yield return + + await asyncio.to_thread(litellm_state_lock.acquire) + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) try: + if vp is not None: + litellm.vertex_project = vp + if vl is not None: + litellm.vertex_location = vl yield finally: - await asyncio.to_thread(_restore, old) + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + await asyncio.to_thread(litellm_state_lock.release) # ============================================================================= diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py index 640ef8f4..d2385a37 100644 --- a/tests/unit/core/llm/test_litellm_patch.py +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -168,21 +168,22 @@ async def test_params_restored_on_exception(self) -> None: async def test_threading_lock_used_with_vertex_params( self, mocker: MockerFixture ) -> None: - """Test that litellm_state_lock (threading.Lock) is acquired for vertex params.""" + """Test that litellm_state_lock is acquired and held across yield.""" mock_lock = mocker.MagicMock() mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) kwargs: dict[str, Any] = {"vertex_location": "us-central1"} async with _vertex_override_async(kwargs): - pass + mock_lock.acquire.assert_called_once() + mock_lock.release.assert_not_called() - assert mock_lock.__enter__.call_count >= 1 + mock_lock.release.assert_called_once() @pytest.mark.asyncio - async def test_threading_lock_used_without_vertex_params( + async def test_threading_lock_not_used_without_vertex_params( self, mocker: MockerFixture ) -> None: - """Test that litellm_state_lock is acquired even without vertex params.""" + """Test that litellm_state_lock is not acquired when no vertex params.""" mock_lock = mocker.MagicMock() mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) kwargs: dict[str, Any] = {"temperature": 0.5} @@ -190,7 +191,7 @@ async def test_threading_lock_used_without_vertex_params( async with _vertex_override_async(kwargs): pass - mock_lock.__enter__.assert_called_once() + mock_lock.acquire.assert_not_called() class TestCompletionWithVertexOverride: From d9addea4d9a4a7a78cb34f806ef9388174493d91 Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Mon, 18 May 2026 21:47:23 -0700 Subject: [PATCH 5/7] test: add timeout to asyncio.gather in deadlock detection test Wrap the asyncio.gather call in asyncio.wait_for with a 10-second timeout so the test fails cleanly instead of hanging indefinitely if a deadlock occurs between sync and async vertex completions. --- tests/unit/core/llm/test_litellm_patch.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py index d2385a37..a8387d71 100644 --- a/tests/unit/core/llm/test_litellm_patch.py +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -334,11 +334,14 @@ async def run_async(project: str, location: str) -> None: try: loop = asyncio.get_running_loop() - await asyncio.gather( - loop.run_in_executor(None, run_sync, "proj-s1", "loc-s1"), - loop.run_in_executor(None, run_sync, "proj-s2", "loc-s2"), - run_async("proj-a1", "loc-a1"), - run_async("proj-a2", "loc-a2"), + await asyncio.wait_for( + asyncio.gather( + loop.run_in_executor(None, run_sync, "proj-s1", "loc-s1"), + loop.run_in_executor(None, run_sync, "proj-s2", "loc-s2"), + run_async("proj-a1", "loc-a1"), + run_async("proj-a2", "loc-a2"), + ), + timeout=10, ) assert len(completed) == 4 From 8d456a46f533f9dafbc601f52872a299f303f78c Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Tue, 19 May 2026 09:43:10 -0700 Subject: [PATCH 6/7] test: assert lock release in sync vertex override tests Add __exit__ assertions alongside existing __enter__ checks so the tests verify the lock is both acquired and released. --- tests/unit/core/llm/test_litellm_patch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py index a8387d71..ce92dc35 100644 --- a/tests/unit/core/llm/test_litellm_patch.py +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -91,6 +91,7 @@ def test_lock_acquired_without_vertex_params(self, mocker: MockerFixture) -> Non pass mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None: """Test that the lock is acquired when vertex params are present.""" @@ -102,6 +103,7 @@ def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None: pass mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() class TestVertexOverrideAsyncContextManager: From 5bb0acd259113a2fe92eb092f445faba447e2edd Mon Sep 17 00:00:00 2001 From: Natasha Vazquez Date: Wed, 20 May 2026 07:02:33 -0700 Subject: [PATCH 7/7] fix: acquire lock before popping kwargs in _vertex_override_async to close race The async _vertex_override_async popped vertex_project/vertex_location from kwargs and checked for None before acquiring litellm_state_lock, allowing concurrent callers to observe partially-updated litellm globals. Move lock acquisition before the pop/check so the async path matches the sync _vertex_override, which already holds the lock from the start. --- .../core/llm/litellm_patch.py | 20 +++++++++++-------- tests/unit/core/llm/test_litellm_patch.py | 7 ++++--- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/lightspeed_evaluation/core/llm/litellm_patch.py b/src/lightspeed_evaluation/core/llm/litellm_patch.py index 8c18f4a2..d41d51ec 100644 --- a/src/lightspeed_evaluation/core/llm/litellm_patch.py +++ b/src/lightspeed_evaluation/core/llm/litellm_patch.py @@ -155,20 +155,24 @@ async def _vertex_override_async( Uses the same lock as the synchronous path to prevent races between sync and async callers. """ - vp = kwargs.pop("vertex_project", None) - vl = kwargs.pop("vertex_location", None) - if vp is None and vl is None: - yield - return - await asyncio.to_thread(litellm_state_lock.acquire) - old_vp = getattr(litellm, "vertex_project", None) - old_vl = getattr(litellm, "vertex_location", None) try: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + await asyncio.to_thread(litellm_state_lock.release) + yield + return + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) if vp is not None: litellm.vertex_project = vp if vl is not None: litellm.vertex_location = vl + except BaseException: + await asyncio.to_thread(litellm_state_lock.release) + raise + try: yield finally: litellm.vertex_project = old_vp diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py index ce92dc35..b8fd843f 100644 --- a/tests/unit/core/llm/test_litellm_patch.py +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -182,10 +182,10 @@ async def test_threading_lock_used_with_vertex_params( mock_lock.release.assert_called_once() @pytest.mark.asyncio - async def test_threading_lock_not_used_without_vertex_params( + async def test_threading_lock_acquired_without_vertex_params( self, mocker: MockerFixture ) -> None: - """Test that litellm_state_lock is not acquired when no vertex params.""" + """Test that litellm_state_lock is acquired even when no vertex params.""" mock_lock = mocker.MagicMock() mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) kwargs: dict[str, Any] = {"temperature": 0.5} @@ -193,7 +193,8 @@ async def test_threading_lock_not_used_without_vertex_params( async with _vertex_override_async(kwargs): pass - mock_lock.acquire.assert_not_called() + mock_lock.acquire.assert_called_once() + mock_lock.release.assert_called_once() class TestCompletionWithVertexOverride: