Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions netra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ def shutdown(cls) -> None:
meter_provider.shutdown()
except Exception:
pass
# Close simulation HTTP client
if hasattr(cls, "simulation") and cls.simulation is not None:
try:
cls.simulation.close()
except Exception:
pass

@classmethod
def get_meter(cls, name: str = "netra", version: Optional[str] = None) -> otel_metrics.Meter:
Expand Down
4 changes: 4 additions & 0 deletions netra/simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from netra.simulation.models import (
ConversationResponse,
ConversationStatus,
FileData,
ProcessedFile,
SimulationItem,
TaskResult,
)
Expand All @@ -12,6 +14,8 @@
"BaseTask",
"ConversationResponse",
"ConversationStatus",
"FileData",
"ProcessedFile",
"SimulationItem",
"TaskResult",
]
113 changes: 69 additions & 44 deletions netra/simulation/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""Public API for running multi-turn conversation simulations."""

import asyncio
import concurrent.futures
import logging
Expand All @@ -8,7 +6,8 @@

from netra.config import Config
from netra.simulation.client import SimulationHttpClient
from netra.simulation.models import SimulationItem
from netra.simulation.constants import DEFAULT_MAX_TURNS, LOG_PREFIX, SPAN_NAME
from netra.simulation.models import ConversationStatus, FileData, SimulationItem
from netra.simulation.task import BaseTask
from netra.simulation.utils import (
execute_task,
Expand All @@ -20,9 +19,6 @@

logger = logging.getLogger(__name__)

_LOG_PREFIX = "netra.simulation"
_SPAN_NAME = "Netra.Simulation.TestRun"


class Simulation:
"""Public API for running multi-turn conversation simulations.
Expand All @@ -43,23 +39,29 @@ def __init__(self, config: Config) -> None:
self._config = config
self._client = SimulationHttpClient(config)

def close(self) -> None:
"""Release resources held by the simulation client."""
self._client.close()

def run_simulation(
self,
name: str,
dataset_id: str,
task: BaseTask,
context: Optional[dict[str, Any]] = None,
max_concurrency: int = 5,
max_turns: int = DEFAULT_MAX_TURNS,
) -> Optional[dict[str, Any]]:
"""Run a multi-turn conversation simulation.

Args:
name: Name of the simulation run.
dataset_id: Identifier of the dataset to simulate.
task: A BaseTask instance whose run() method receives (message, session_id)
task: A BaseTask instance whose run() method receives (message, session_id, files)
and returns TaskResult. Can be sync or async.
context: Optional context data for the simulation.
max_concurrency: Maximum parallel executions (default: 5).
max_turns: Maximum conversation turns per item before aborting (default: 50).

Returns:
Dictionary with simulation results, or None on failure.
Expand All @@ -77,72 +79,79 @@ def run_simulation(
return None

run_id = run_result.get("run_id")
run_items = run_result.get("simulation_items")
if not run_items:
logger.error("%s: No items returned from create_run", _LOG_PREFIX)
simulation_items = run_result.get("simulation_items")
if not simulation_items:
logger.error("%s: No items returned from create_run", LOG_PREFIX)
return None

logger.info("%s: Starting simulation with %d items", _LOG_PREFIX, len(run_items))
logger.info("%s: Starting simulation with %d items", LOG_PREFIX, len(simulation_items))
try:
result = run_async_safely(
self._run_simulation_async(run_id, run_items, task, max_concurrency) # type:ignore[arg-type]
self._run_simulation_async(
run_id, simulation_items, task, max_concurrency, max_turns # type:ignore[arg-type]
)
)

elapsed_time = time.time() - start_time
logger.info("%s: Simulation completed in %.2f seconds", _LOG_PREFIX, elapsed_time)
logger.info("%s: Simulation completed in %.2f seconds", LOG_PREFIX, elapsed_time)
self._client.post_run_status(run_id, "completed") # type:ignore[arg-type]
return result
except BaseException:
logger.error("%s: Run simulation failed", _LOG_PREFIX)
except Exception:
logger.error("%s: Run simulation failed", LOG_PREFIX, exc_info=True)
self._client.post_run_status(run_id, "failed") # type:ignore[arg-type]
return None

async def _run_simulation_async(
self,
run_id: str,
run_items: list[SimulationItem],
simulation_items: list[SimulationItem],
task: BaseTask,
max_concurrency: int,
max_turns: int,
) -> dict[str, Any]:
"""Async implementation of run_simulation with semaphore-based concurrency.
"""Orchestrate concurrent simulation execution.

Each simulation item is dispatched to a thread via ``run_in_executor``.
Inside each thread, ``run_async_safely`` creates a **new** event loop
so that async user tasks (``BaseTask.run``) work correctly without
nesting into the orchestrator's loop. This two-level design lets us
honour ``max_concurrency`` while supporting both sync and async tasks
transparently.

Args:
run_id: The simulation run identifier.
run_items: List of simulation items to process.
simulation_items: List of simulation items to process.
task: The BaseTask instance to execute (sync or async).
max_concurrency: Maximum concurrent executions.
max_turns: Maximum conversation turns per item.

Returns:
Dictionary with simulation results.
"""

max_workers = min(5, max_concurrency)
results: dict[str, Any] = {
"success": True,
"completed": [],
"failed": [],
"total_items": len(run_items),
"total_items": len(simulation_items),
}
processed_count = 0
lock = asyncio.Lock()

loop = asyncio.get_running_loop()

def run_item_in_thread(run_item: SimulationItem) -> dict[str, Any]:
"""
Run a single simulation item in a thread.
"""Run a single simulation item in a dedicated thread/event-loop.

Args:
run_item: The simulation item to run.

Returns:
Dictionary with simulation result.
"""
return run_async_safely(self._execute_conversation(run_id, run_item, task))
return run_async_safely(self._execute_conversation(run_id, run_item, task, max_turns))

async def process_item(run_item: SimulationItem) -> None:
"""
Process a single simulation item and handle its completion.
"""Process a single simulation item and record its outcome.

Args:
run_item: The simulation item to process.
Expand All @@ -155,14 +164,14 @@ async def process_item(run_item: SimulationItem) -> None:
processed_count += 1
logger.info(
"%s: %d/%d processed (run_item_id=%s)",
_LOG_PREFIX,
LOG_PREFIX,
processed_count,
len(run_items),
len(simulation_items),
run_item.run_item_id,
)

with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
tasks = [asyncio.create_task(process_item(run_item)) for run_item in run_items]
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
tasks = [asyncio.create_task(process_item(item)) for item in simulation_items]
try:
await asyncio.gather(*tasks)
except (asyncio.CancelledError, KeyboardInterrupt):
Expand All @@ -172,7 +181,7 @@ async def process_item(run_item: SimulationItem) -> None:
executor.shutdown(wait=False, cancel_futures=True)
logger.info(
"%s: Completed=%d, Failed=%d",
_LOG_PREFIX,
LOG_PREFIX,
len(results["completed"]),
len(results["failed"]),
)
Expand All @@ -183,41 +192,46 @@ async def _execute_conversation(
run_id: str,
run_item: SimulationItem,
task: BaseTask,
) -> Any:
max_turns: int,
) -> dict[str, Any]:
"""Execute a multi-turn conversation for a single simulation item.

Args:
run_id: The simulation run identifier.
run_item: The simulation item to process.
task: The BaseTask instance to execute (sync or async).
max_turns: Safety limit on the number of conversation turns.

Returns:
Dictionary with execution result including success status.
"""
run_item_id = run_item.run_item_id
message = run_item.message
turn_id = run_item.turn_id
raw_files: list[FileData] = run_item.files
session_id: Optional[str] = None

while True:
for turn_number in range(1, max_turns + 1):
try:
with SpanWrapper(_SPAN_NAME, module_name=_LOG_PREFIX) as span:
with SpanWrapper(SPAN_NAME, module_name=LOG_PREFIX) as span:
trace_id = ""
otel_span = span.get_current_span()
if otel_span:
span_context = otel_span.get_span_context()
trace_id = format_trace_id(span_context.trace_id)

response_message, task_session_id = await execute_task(task, message, session_id)
response_message, task_session_id = await execute_task(
task, message, session_id, raw_files=raw_files
)
if task_session_id:
session_id = task_session_id

response = self._client.trigger_conversation(
message=response_message,
turn_id=turn_id,
session_id=session_id or "",
trace_id=trace_id,
)
response = self._client.trigger_conversation(
message=response_message,
turn_id=turn_id,
session_id=session_id or "",
trace_id=trace_id,
)

if response is None:
error_msg = "Failed to get conversation response"
Expand All @@ -228,10 +242,10 @@ async def _execute_conversation(
"turn_id": turn_id,
}

if response.decision == "stop":
if response.decision == ConversationStatus.STOP:
logger.info(
"%s: Completed run_item_id=%s reason=%s",
_LOG_PREFIX,
LOG_PREFIX,
run_item_id,
response.reason,
)
Expand All @@ -243,12 +257,13 @@ async def _execute_conversation(

message = response.next_user_message # type:ignore[assignment]
turn_id = response.next_turn_id # type:ignore[assignment]
raw_files = response.next_files

except Exception as exc:
error_msg = str(exc)
logger.error(
"%s: Task failed run_item_id=%s, turn_id=%s: %s",
_LOG_PREFIX,
LOG_PREFIX,
run_item_id,
turn_id,
error_msg,
Expand All @@ -260,3 +275,13 @@ async def _execute_conversation(
"error": error_msg,
"turn_id": turn_id,
}

error_msg = f"Exceeded maximum turns ({max_turns}) for run_item_id={run_item_id}"
logger.error("%s: %s", LOG_PREFIX, error_msg)
self._client.report_failure(run_id=run_id, run_item_id=run_item_id, error=error_msg)
return {
"run_item_id": run_item_id,
"success": False,
"error": error_msg,
"turn_id": turn_id,
}
Loading