Skip to content
Open

fix: #211

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions sdks/python/src/agent_control/control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def chat(message: str) -> str:
from typing import Any, TypeVar

from agent_control_models import Step, normalize_action
from agent_control_telemetry import get_trace_context_from_provider

from agent_control import AgentControlClient
from agent_control.evaluation import check_evaluation_with_local
Expand All @@ -53,6 +54,25 @@ async def chat(message: str) -> str:
F = TypeVar("F", bound=Callable[..., Any])


def _resolve_control_trace_context() -> tuple[str, str]:
"""Resolve trace/span IDs for a decorated control site.

External providers, such as the Galileo bridge, are authoritative because
they may reserve the concrete span ID that the eventual LLM/tool call will
use. Without a provider, keep the existing behavior: share an active trace
but create a fresh function span for this decorated call.
"""
provider_context = get_trace_context_from_provider()
if provider_context is not None:
return provider_context["trace_id"], provider_context["span_id"]

existing_trace_id = get_current_trace_id()
if existing_trace_id:
return existing_trace_id, _generate_span_id()

return get_trace_and_span_ids()


@dataclass
class ControlContext:
"""
Expand Down Expand Up @@ -697,14 +717,7 @@ async def _execute_with_control(
# Get cached controls for local evaluation support
controls = _get_server_controls()

# Get trace context: inherit trace_id if set, always generate new span_id
# This allows multiple @control() calls to share the same trace but have unique spans
existing_trace_id = get_current_trace_id()
if existing_trace_id:
trace_id = existing_trace_id
span_id = _generate_span_id() # New span for this function
else:
trace_id, span_id = get_trace_and_span_ids() # New trace and span
trace_id, span_id = _resolve_control_trace_context()

ctx = ControlContext(
agent_name=agent.agent_name,
Expand Down
52 changes: 50 additions & 2 deletions sdks/python/tests/test_control_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from unittest.mock import MagicMock, patch

import pytest
from agent_control_telemetry import clear_trace_context_provider, set_trace_context_provider

from agent_control.control_decorators import ControlViolationError, ControlSteerError, control

from agent_control.control_decorators import ControlSteerError, ControlViolationError, control

# =============================================================================
# FIXTURES
Expand Down Expand Up @@ -255,6 +255,54 @@ async def chat(message: str) -> str:
class TestPrePostExecution:
"""Tests for pre and post execution checks."""

@pytest.mark.asyncio
async def test_uses_external_provider_trace_context(self, mock_agent, mock_safe_response):
"""Test that an external provider supplies both trace and span IDs."""
# Given: an external telemetry provider that owns the active trace/span IDs
provided_trace_id = "6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9d"
provided_span_id = "8d30272e-23f7-4a4c-80d8-2decb2f3f9f8"
captured_contexts = []

async def mock_evaluate(
agent_name,
step,
stage,
server_url,
trace_id=None,
span_id=None,
controls=None,
event_agent_name=None,
):
captured_contexts.append((trace_id, span_id))
return mock_safe_response

set_trace_context_provider(
lambda: {"trace_id": provided_trace_id, "span_id": provided_span_id}
)
try:
with (
patch(
"agent_control.control_decorators._get_current_agent",
return_value=mock_agent,
),
patch("agent_control.control_decorators._evaluate", side_effect=mock_evaluate),
):

@control()
async def chat(message: str) -> str:
return f"Response to: {message}"

# When: a protected function runs pre and post checks
await chat("Hello!")
finally:
clear_trace_context_provider()

# Then: Agent Control preserves the provider's concrete target span ID
assert captured_contexts == [
(provided_trace_id, provided_span_id),
(provided_trace_id, provided_span_id),
]

@pytest.mark.asyncio
async def test_calls_pre_and_post(self, mock_agent, mock_safe_response):
"""Test that both pre and post checks are called."""
Expand Down
Loading