diff --git a/config/system.yaml b/config/system.yaml index 7187a018..de882cf3 100644 --- a/config/system.yaml +++ b/config/system.yaml @@ -46,6 +46,18 @@ llm_pool: # Add, remove or override model specific parameters temperature: null # Removes temperature from default max_completion_tokens: 2048 # Overrides default + # Vertex AI example with per-model region: + # judge_vertex_gemini: + # provider: vertex_ai + # model: gemini-2.0-flash + # parameters: + # vertex_location: us-central1 # Region for this model + # # vertex_project: my-gcp-project # Optional: override GCP project + # judge_vertex_llama: + # provider: vertex_ai + # model: meta/llama-3.3-70b-instruct-maas + # parameters: + # vertex_location: europe-west1 # Different region for this model # Judge Panel: multiple judges from the pool # Combine their scores. First judge in judges is the fallback when the full panel is not used for a metric. diff --git a/src/lightspeed_evaluation/core/llm/litellm_patch.py b/src/lightspeed_evaluation/core/llm/litellm_patch.py index 9eb13a70..d41d51ec 100644 --- a/src/lightspeed_evaluation/core/llm/litellm_patch.py +++ b/src/lightspeed_evaluation/core/llm/litellm_patch.py @@ -1,6 +1,6 @@ -"""LiteLLM configuration for token tracking and Ragas 0.4 compatibility. +"""LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support. -This module configures litellm for two purposes: +This module configures litellm for three purposes: 1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding, and litellm.aembedding to track token usage for all LLM and embedding calls. @@ -14,14 +14,22 @@ We replace the LoggingWorker with a no-op implementation to avoid this. This is safe because we don't use litellm's built-in observability features. + +3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by + DeepEval) silently strips vertex_project and vertex_location from + completion kwargs. The completion wrappers intercept these params and + temporarily set them as litellm module-level attributes, which litellm + checks as a fallback in its vertex_ai handler. """ +import asyncio import logging import os import threading import warnings +from contextlib import asynccontextmanager, contextmanager from functools import wraps -from typing import Any +from typing import Any, AsyncGenerator, Generator import litellm @@ -89,6 +97,89 @@ def clear_queue(self) -> None: litellm.suppress_debug_info = True +# ============================================================================= +# GLOBAL STATE LOCK +# ============================================================================= +# Single lock for ALL litellm global state mutations (cache, ssl_verify, +# vertex_project, vertex_location). Import this lock in any module that +# reads/writes litellm global state to prevent race conditions between +# concurrent pipelines. Both sync and async code paths share this lock; +# async callers use asyncio.to_thread so the event loop is never blocked. +litellm_state_lock = threading.Lock() + + +# ============================================================================= +# VERTEX AI PER-MODEL REGION SUPPORT +# ============================================================================= +# litellm.drop_params=True (set by DeepEval) silently strips vertex_project +# and vertex_location from completion kwargs. We intercept these params and +# temporarily set them as litellm module-level attributes, which litellm +# checks as a fallback in its vertex_ai handler. + + +@contextmanager +def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]: + """Pop vertex_project/vertex_location from kwargs and set as litellm module attrs. + + Always acquires litellm_state_lock to prevent concurrent reads of partially + updated globals, even when no vertex params are present in kwargs. + """ + with litellm_state_lock: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + yield + return + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + try: + if vp is not None: + litellm.vertex_project = vp + if vl is not None: + litellm.vertex_location = vl + yield + finally: + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + + +@asynccontextmanager +async def _vertex_override_async( + kwargs: dict[str, Any], +) -> AsyncGenerator[None, None]: + """Async version of _vertex_override using asyncio.to_thread. + + Acquires litellm_state_lock before mutating globals and holds it across the + yield so no concurrent caller can see partially-updated state. Lock + acquire/release use asyncio.to_thread to avoid blocking the event loop. + Uses the same lock as the synchronous path to prevent races between sync + and async callers. + """ + await asyncio.to_thread(litellm_state_lock.acquire) + try: + vp = kwargs.pop("vertex_project", None) + vl = kwargs.pop("vertex_location", None) + if vp is None and vl is None: + await asyncio.to_thread(litellm_state_lock.release) + yield + return + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + if vp is not None: + litellm.vertex_project = vp + if vl is not None: + litellm.vertex_location = vl + except BaseException: + await asyncio.to_thread(litellm_state_lock.release) + raise + try: + yield + finally: + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + await asyncio.to_thread(litellm_state_lock.release) + + # ============================================================================= # TOKEN TRACKING: Wrap completion and embedding functions # ============================================================================= @@ -101,11 +192,11 @@ def clear_queue(self) -> None: _original_aembedding = litellm.aembedding -# Patch litellm's completion functions to include token tracking @wraps(_original_completion) def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: - """Wrapper around litellm.completion that tracks tokens.""" - response = _original_completion(*args, **kwargs) + """Wrapper around litellm.completion that tracks tokens and handles Vertex params.""" + with _vertex_override(kwargs): + response = _original_completion(*args, **kwargs) try: track_judge_tokens(response) except Exception as e: # pylint: disable=broad-exception-caught @@ -115,8 +206,9 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: @wraps(_original_acompletion) async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: - """Wrapper around litellm.acompletion that tracks tokens.""" - response = await _original_acompletion(*args, **kwargs) + """Wrapper around litellm.acompletion that tracks tokens and handles Vertex params.""" + async with _vertex_override_async(kwargs): + response = await _original_acompletion(*args, **kwargs) try: track_judge_tokens(response) except Exception as e: # pylint: disable=broad-exception-caught @@ -124,7 +216,6 @@ async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any: return response -# Patch litellm's embedding functions to include token tracking @wraps(_original_embedding) def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any: """Wrapper around litellm.embedding that tracks tokens.""" @@ -147,22 +238,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any: return response -# Patch litellm's completion and embedding functions to include token tracking litellm.completion = _completion_with_token_tracking litellm.acompletion = _acompletion_with_token_tracking litellm.embedding = _embedding_with_token_tracking litellm.aembedding = _aembedding_with_token_tracking -# ============================================================================= -# GLOBAL STATE LOCK -# ============================================================================= -# Single lock for ALL litellm global state mutations (cache, ssl_verify). -# Import this lock in any module that reads/writes litellm.cache or -# litellm.ssl_verify to prevent race conditions between concurrent pipelines. -litellm_state_lock = threading.Lock() - - # ============================================================================= # SSL CONFIGURATION UTILITY # ============================================================================= diff --git a/tests/unit/core/llm/test_litellm_patch.py b/tests/unit/core/llm/test_litellm_patch.py new file mode 100644 index 00000000..b8fd843f --- /dev/null +++ b/tests/unit/core/llm/test_litellm_patch.py @@ -0,0 +1,394 @@ +"""Unit tests for litellm_patch vertex override support.""" + +import asyncio +import threading +from typing import Any, Callable + +import pytest +from pytest_mock import MockerFixture +import litellm + +from lightspeed_evaluation.core.llm import litellm_patch +from lightspeed_evaluation.core.llm.litellm_patch import ( + _vertex_override, + _vertex_override_async, +) + + +class TestVertexOverrideContextManager: + """Tests for the _vertex_override context manager.""" + + def test_no_vertex_params_is_noop(self) -> None: + """Test that _vertex_override is a no-op when no vertex params present.""" + kwargs: dict[str, Any] = {"model": "gpt-4", "temperature": 0.5} + original_kwargs = dict(kwargs) + + with _vertex_override(kwargs): + pass + + assert kwargs == original_kwargs + + def test_vertex_location_set_and_restored(self) -> None: + """Test that vertex_location is set on litellm module and restored after.""" + old_value = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + with _vertex_override(kwargs): + assert litellm.vertex_location == "us-central1" + assert "vertex_location" not in kwargs + + assert getattr(litellm, "vertex_location", None) == old_value + + def test_vertex_project_set_and_restored(self) -> None: + """Test that vertex_project is set on litellm module and restored after.""" + old_value = getattr(litellm, "vertex_project", None) + kwargs: dict[str, Any] = {"vertex_project": "my-project"} + + with _vertex_override(kwargs): + assert litellm.vertex_project == "my-project" + assert "vertex_project" not in kwargs + + assert getattr(litellm, "vertex_project", None) == old_value + + def test_both_params_set_and_restored(self) -> None: + """Test that both vertex params are set and restored.""" + old_location = getattr(litellm, "vertex_location", None) + old_project = getattr(litellm, "vertex_project", None) + kwargs: dict[str, Any] = { + "vertex_location": "europe-west1", + "vertex_project": "my-project", + "temperature": 0.5, + } + + with _vertex_override(kwargs): + assert litellm.vertex_location == "europe-west1" + assert litellm.vertex_project == "my-project" + assert "vertex_location" not in kwargs + assert "vertex_project" not in kwargs + assert kwargs == {"temperature": 0.5} + + assert getattr(litellm, "vertex_location", None) == old_location + assert getattr(litellm, "vertex_project", None) == old_project + + def test_params_restored_on_exception(self) -> None: + """Test that vertex params are restored even when an exception occurs.""" + old_location = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-east1"} + + with pytest.raises(ValueError, match="test error"): + with _vertex_override(kwargs): + assert litellm.vertex_location == "us-east1" + raise ValueError("test error") + + assert getattr(litellm, "vertex_location", None) == old_location + + def test_lock_acquired_without_vertex_params(self, mocker: MockerFixture) -> None: + """Test that the lock is acquired even when no vertex params are present.""" + mock_lock = mocker.patch.object(litellm_patch, "litellm_state_lock") + kwargs: dict[str, Any] = {"temperature": 0.5} + + with _vertex_override(kwargs): + pass + + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() + + def test_lock_acquired_with_vertex_params(self, mocker: MockerFixture) -> None: + """Test that the lock is acquired when vertex params are present.""" + mock_lock = mocker.MagicMock() + mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + with _vertex_override(kwargs): + pass + + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() + + +class TestVertexOverrideAsyncContextManager: + """Tests for the _vertex_override_async async context manager.""" + + @pytest.mark.asyncio + async def test_no_vertex_params_is_noop(self) -> None: + """Test that _vertex_override_async is a no-op when no vertex params present.""" + kwargs: dict[str, Any] = {"model": "gpt-4", "temperature": 0.5} + original_kwargs = dict(kwargs) + + async with _vertex_override_async(kwargs): + pass + + assert kwargs == original_kwargs + + @pytest.mark.asyncio + async def test_vertex_location_set_and_restored(self) -> None: + """Test that vertex_location is set on litellm module and restored after.""" + old_value = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + async with _vertex_override_async(kwargs): + assert litellm.vertex_location == "us-central1" + assert "vertex_location" not in kwargs + + assert getattr(litellm, "vertex_location", None) == old_value + + @pytest.mark.asyncio + async def test_both_params_set_and_restored(self) -> None: + """Test that both vertex params are set and restored.""" + old_location = getattr(litellm, "vertex_location", None) + old_project = getattr(litellm, "vertex_project", None) + kwargs: dict[str, Any] = { + "vertex_location": "europe-west1", + "vertex_project": "my-project", + "temperature": 0.5, + } + + async with _vertex_override_async(kwargs): + assert litellm.vertex_location == "europe-west1" + assert litellm.vertex_project == "my-project" + assert "vertex_location" not in kwargs + assert "vertex_project" not in kwargs + assert kwargs == {"temperature": 0.5} + + assert getattr(litellm, "vertex_location", None) == old_location + assert getattr(litellm, "vertex_project", None) == old_project + + @pytest.mark.asyncio + async def test_params_restored_on_exception(self) -> None: + """Test that vertex params are restored even when an exception occurs.""" + old_location = getattr(litellm, "vertex_location", None) + kwargs: dict[str, Any] = {"vertex_location": "us-east1"} + + with pytest.raises(ValueError, match="test error"): + async with _vertex_override_async(kwargs): + assert litellm.vertex_location == "us-east1" + raise ValueError("test error") + + assert getattr(litellm, "vertex_location", None) == old_location + + @pytest.mark.asyncio + async def test_threading_lock_used_with_vertex_params( + self, mocker: MockerFixture + ) -> None: + """Test that litellm_state_lock is acquired and held across yield.""" + mock_lock = mocker.MagicMock() + mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) + kwargs: dict[str, Any] = {"vertex_location": "us-central1"} + + async with _vertex_override_async(kwargs): + mock_lock.acquire.assert_called_once() + mock_lock.release.assert_not_called() + + mock_lock.release.assert_called_once() + + @pytest.mark.asyncio + async def test_threading_lock_acquired_without_vertex_params( + self, mocker: MockerFixture + ) -> None: + """Test that litellm_state_lock is acquired even when no vertex params.""" + mock_lock = mocker.MagicMock() + mocker.patch.object(litellm_patch, "litellm_state_lock", mock_lock) + kwargs: dict[str, Any] = {"temperature": 0.5} + + async with _vertex_override_async(kwargs): + pass + + mock_lock.acquire.assert_called_once() + mock_lock.release.assert_called_once() + + +class TestCompletionWithVertexOverride: + """Test litellm.completion integration with vertex override.""" + + def test_completion_with_vertex_location( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that vertex_location is handled during completion calls.""" + mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion") + mock_completion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + old_location = getattr(litellm, "vertex_location", None) + + litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_location="us-central1", + ) + + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args[1] + assert "vertex_location" not in call_kwargs + assert getattr(litellm, "vertex_location", None) == old_location + + def test_completion_without_vertex_params_unchanged( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that completion works normally without vertex params.""" + mock_completion = mocker.patch(f"{litellm_patch.__name__}._original_completion") + mock_completion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + litellm.completion( + model="gpt-4", + messages=[{"role": "user", "content": "test"}], + temperature=0.5, + ) + + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["temperature"] == 0.5 + assert call_kwargs["model"] == "gpt-4" + + @pytest.mark.asyncio + async def test_acompletion_with_vertex_location( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that vertex_location is handled during async completion calls.""" + mock_acompletion = mocker.patch( + f"{litellm_patch.__name__}._original_acompletion" + ) + mock_acompletion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + old_location = getattr(litellm, "vertex_location", None) + + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_location="europe-west1", + ) + + mock_acompletion.assert_called_once() + call_kwargs = mock_acompletion.call_args[1] + assert "vertex_location" not in call_kwargs + assert getattr(litellm, "vertex_location", None) == old_location + + @pytest.mark.asyncio + async def test_acompletion_with_both_vertex_params( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Test that both vertex params are handled during async completion.""" + mock_acompletion = mocker.patch( + f"{litellm_patch.__name__}._original_acompletion" + ) + mock_acompletion.return_value = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + + old_location = getattr(litellm, "vertex_location", None) + old_project = getattr(litellm, "vertex_project", None) + + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_location="us-central1", + vertex_project="my-project", + ) + + call_kwargs = mock_acompletion.call_args[1] + assert "vertex_location" not in call_kwargs + assert "vertex_project" not in call_kwargs + assert getattr(litellm, "vertex_location", None) == old_location + assert getattr(litellm, "vertex_project", None) == old_project + + +class TestInterleavedSyncAsyncCompletions: + """Test that sync and async completions sharing one lock don't race.""" + + @pytest.mark.asyncio + async def test_interleaved_sync_async_no_deadlock( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Interleaved sync/async vertex completions complete without deadlock.""" + response = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + mocker.patch.object( + litellm_patch, "_original_completion", return_value=response + ) + mocker.patch.object( + litellm_patch, "_original_acompletion", return_value=response + ) + + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + completed: list[str] = [] + + def run_sync(project: str, location: str) -> None: + litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project=project, + vertex_location=location, + ) + completed.append(f"sync-{project}") + + async def run_async(project: str, location: str) -> None: + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project=project, + vertex_location=location, + ) + completed.append(f"async-{project}") + + try: + loop = asyncio.get_running_loop() + await asyncio.wait_for( + asyncio.gather( + loop.run_in_executor(None, run_sync, "proj-s1", "loc-s1"), + loop.run_in_executor(None, run_sync, "proj-s2", "loc-s2"), + run_async("proj-a1", "loc-a1"), + run_async("proj-a2", "loc-a2"), + ), + timeout=10, + ) + + assert len(completed) == 4 + finally: + litellm.vertex_project = old_vp + litellm.vertex_location = old_vl + + @pytest.mark.asyncio + async def test_sequential_sync_async_restores_globals( + self, mocker: MockerFixture, mock_judge_llm_response: Callable[..., Any] + ) -> None: + """Sequential sync then async vertex call restores globals correctly.""" + response = mock_judge_llm_response( + prompt_tokens=10, completion_tokens=5, cache_hit=False, content="ok" + ) + mocker.patch.object( + litellm_patch, "_original_completion", return_value=response + ) + mocker.patch.object( + litellm_patch, "_original_acompletion", return_value=response + ) + + old_vp = getattr(litellm, "vertex_project", None) + old_vl = getattr(litellm, "vertex_location", None) + + litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project="proj-sync", + vertex_location="loc-sync", + ) + assert getattr(litellm, "vertex_project", None) == old_vp + assert getattr(litellm, "vertex_location", None) == old_vl + + await litellm.acompletion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "test"}], + vertex_project="proj-async", + vertex_location="loc-async", + ) + assert getattr(litellm, "vertex_project", None) == old_vp + assert getattr(litellm, "vertex_location", None) == old_vl + + def test_sync_and_async_share_same_lock(self) -> None: + """Both _vertex_override and _vertex_override_async use litellm_state_lock.""" + assert isinstance(litellm_patch.litellm_state_lock, type(threading.Lock())) + assert not hasattr(litellm_patch, "litellm_state_async_lock")