diff --git a/sdks/python/src/agent_control/control_decorators.py b/sdks/python/src/agent_control/control_decorators.py index f886793d..edd8a261 100644 --- a/sdks/python/src/agent_control/control_decorators.py +++ b/sdks/python/src/agent_control/control_decorators.py @@ -36,6 +36,7 @@ async def chat(message: str) -> str: from typing import Any, TypeVar from agent_control_models import Step, normalize_action +from agent_control_telemetry import get_trace_context_from_provider from agent_control import AgentControlClient from agent_control.evaluation import check_evaluation_with_local @@ -53,6 +54,25 @@ async def chat(message: str) -> str: F = TypeVar("F", bound=Callable[..., Any]) +def _resolve_control_trace_context() -> tuple[str, str]: + """Resolve trace/span IDs for a decorated control site. + + External providers, such as the Galileo bridge, are authoritative because + they may reserve the concrete span ID that the eventual LLM/tool call will + use. Without a provider, keep the existing behavior: share an active trace + but create a fresh function span for this decorated call. + """ + provider_context = get_trace_context_from_provider() + if provider_context is not None: + return provider_context["trace_id"], provider_context["span_id"] + + existing_trace_id = get_current_trace_id() + if existing_trace_id: + return existing_trace_id, _generate_span_id() + + return get_trace_and_span_ids() + + @dataclass class ControlContext: """ @@ -697,14 +717,7 @@ async def _execute_with_control( # Get cached controls for local evaluation support controls = _get_server_controls() - # Get trace context: inherit trace_id if set, always generate new span_id - # This allows multiple @control() calls to share the same trace but have unique spans - existing_trace_id = get_current_trace_id() - if existing_trace_id: - trace_id = existing_trace_id - span_id = _generate_span_id() # New span for this function - else: - trace_id, span_id = get_trace_and_span_ids() # New trace and span + trace_id, span_id = _resolve_control_trace_context() ctx = ControlContext( agent_name=agent.agent_name, diff --git a/sdks/python/tests/test_control_decorators.py b/sdks/python/tests/test_control_decorators.py index e7972348..00a5e215 100644 --- a/sdks/python/tests/test_control_decorators.py +++ b/sdks/python/tests/test_control_decorators.py @@ -3,9 +3,9 @@ from unittest.mock import MagicMock, patch import pytest +from agent_control_telemetry import clear_trace_context_provider, set_trace_context_provider -from agent_control.control_decorators import ControlViolationError, ControlSteerError, control - +from agent_control.control_decorators import ControlSteerError, ControlViolationError, control # ============================================================================= # FIXTURES @@ -255,6 +255,54 @@ async def chat(message: str) -> str: class TestPrePostExecution: """Tests for pre and post execution checks.""" + @pytest.mark.asyncio + async def test_uses_external_provider_trace_context(self, mock_agent, mock_safe_response): + """Test that an external provider supplies both trace and span IDs.""" + # Given: an external telemetry provider that owns the active trace/span IDs + provided_trace_id = "6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9d" + provided_span_id = "8d30272e-23f7-4a4c-80d8-2decb2f3f9f8" + captured_contexts = [] + + async def mock_evaluate( + agent_name, + step, + stage, + server_url, + trace_id=None, + span_id=None, + controls=None, + event_agent_name=None, + ): + captured_contexts.append((trace_id, span_id)) + return mock_safe_response + + set_trace_context_provider( + lambda: {"trace_id": provided_trace_id, "span_id": provided_span_id} + ) + try: + with ( + patch( + "agent_control.control_decorators._get_current_agent", + return_value=mock_agent, + ), + patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate), + ): + + @control() + async def chat(message: str) -> str: + return f"Response to: {message}" + + # When: a protected function runs pre and post checks + await chat("Hello!") + finally: + clear_trace_context_provider() + + # Then: Agent Control preserves the provider's concrete target span ID + assert captured_contexts == [ + (provided_trace_id, provided_span_id), + (provided_trace_id, provided_span_id), + ] + @pytest.mark.asyncio async def test_calls_pre_and_post(self, mock_agent, mock_safe_response): """Test that both pre and post checks are called."""