Skip to content

Commit 7a443e5

Browse files
committed
Test stateful session
1 parent 7f4227e commit 7a443e5

1 file changed

Lines changed: 136 additions & 0 deletions

File tree

tests/test_stateful_session.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import asyncio
2+
import uuid
3+
from dataclasses import dataclass
4+
5+
import pytest
6+
from mcp.types import CallToolResult, ListToolsResult
7+
from nexusrpc import Operation, service
8+
from nexusrpc.handler import StartOperationContext, service_handler, sync_operation
9+
from pydantic import BaseModel
10+
from temporalio import nexus, workflow
11+
from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget
12+
from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest
13+
from temporalio.client import WithStartWorkflowOperation
14+
from temporalio.common import WorkflowIDConflictPolicy
15+
from temporalio.contrib.pydantic import pydantic_data_converter
16+
from temporalio.testing import WorkflowEnvironment
17+
from temporalio.worker import Worker
18+
19+
import nexusmcp.workflow
20+
from nexusmcp import MCPServiceHandler
21+
22+
mcp_service = MCPServiceHandler()
23+
24+
25+
@dataclass
26+
class AppendInput:
27+
session_id: str
28+
value: int
29+
30+
31+
class AppendOutput(BaseModel):
32+
data: list[int]
33+
34+
35+
@service
36+
class TestService:
37+
append: Operation[AppendInput, AppendOutput]
38+
39+
40+
@workflow.defn(sandboxed=False)
41+
class AppendWorkflow:
42+
def __init__(self):
43+
self.data = []
44+
45+
@workflow.run
46+
async def run(self) -> None:
47+
await asyncio.Event().wait()
48+
49+
@workflow.update
50+
async def append(self, input: int) -> AppendOutput:
51+
self.data.append(input)
52+
return AppendOutput(data=self.data)
53+
54+
55+
@mcp_service.register
56+
@service_handler(service=TestService)
57+
class TestServiceHandler:
58+
@sync_operation
59+
async def append(self, ctx: StartOperationContext, input: AppendInput) -> AppendOutput:
60+
# TODO: Should we use a custom ClientSession and start the workflow in initialize()?
61+
# wf = nexus.client().get_workflow_handle_for(AppendWorkflow.run, input.session_id)
62+
with_start = WithStartWorkflowOperation(
63+
AppendWorkflow.run,
64+
id=input.session_id,
65+
task_queue=nexus.info().task_queue,
66+
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
67+
)
68+
return await nexus.client().execute_update_with_start_workflow(
69+
AppendWorkflow.append,
70+
input.value,
71+
start_workflow_operation=with_start,
72+
)
73+
74+
75+
@dataclass
76+
class MCPCallerWorkflowInput:
77+
endpoint: str
78+
79+
80+
class MCPCallerWorkflowOutput(BaseModel):
81+
list_tools_result: ListToolsResult
82+
call_tool_results: list[CallToolResult]
83+
84+
85+
# sandbox disabled due to use of ThreadLocal by sniffio
86+
# TODO: make this unnecessary
87+
@workflow.defn(sandboxed=False)
88+
class MCPCallerWorkflow:
89+
@workflow.run
90+
async def run(self, input: MCPCallerWorkflowInput) -> MCPCallerWorkflowOutput:
91+
async with nexusmcp.workflow.MCPClient(input.endpoint).connect() as session:
92+
list_tools_result = await session.list_tools()
93+
call_tool_result_1 = await session.call_tool("TestService_append", {"session_id": "123", "value": 1})
94+
call_tool_result_2 = await session.call_tool("TestService_append", {"session_id": "123", "value": 2})
95+
return MCPCallerWorkflowOutput(
96+
list_tools_result=list_tools_result,
97+
call_tool_results=[call_tool_result_1, call_tool_result_2],
98+
)
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_workflow_caller() -> None:
103+
endpoint_name = "endpoint"
104+
task_queue = "handler-queue"
105+
106+
async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env:
107+
await env.client.operator_service.create_nexus_endpoint(
108+
CreateNexusEndpointRequest(
109+
spec=EndpointSpec(
110+
name=endpoint_name,
111+
target=EndpointTarget(
112+
worker=EndpointTarget.Worker(
113+
namespace=env.client.namespace,
114+
task_queue=task_queue,
115+
)
116+
),
117+
)
118+
)
119+
)
120+
121+
async with Worker(
122+
env.client,
123+
task_queue=task_queue,
124+
workflows=[MCPCallerWorkflow, AppendWorkflow],
125+
nexus_service_handlers=[mcp_service, TestServiceHandler()],
126+
):
127+
result = await env.client.execute_workflow(
128+
MCPCallerWorkflow.run,
129+
arg=MCPCallerWorkflowInput(endpoint=endpoint_name),
130+
id=str(uuid.uuid4()),
131+
task_queue=task_queue,
132+
)
133+
assert len(result.list_tools_result.tools) == 1
134+
assert result.list_tools_result.tools[0].name == "TestService_append"
135+
assert result.call_tool_results[0].structuredContent == {"data": [1]}
136+
assert result.call_tool_results[1].structuredContent == {"data": [1, 2]}

0 commit comments

Comments
 (0)