Skip to content

Commit a955f72

Browse files
committed
Test stateful session
1 parent 04379e0 commit a955f72

1 file changed

Lines changed: 127 additions & 0 deletions

File tree

tests/test_stateful_session.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import asyncio
2+
import uuid
3+
from dataclasses import dataclass
4+
5+
import pytest
6+
from mcp import ClientSession
7+
from nexusrpc import Operation, service
8+
from nexusrpc.handler import StartOperationContext, service_handler, sync_operation
9+
from temporalio import nexus, workflow
10+
from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget
11+
from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest
12+
from temporalio.client import WithStartWorkflowOperation
13+
from temporalio.common import WorkflowIDConflictPolicy
14+
from temporalio.contrib.pydantic import pydantic_data_converter
15+
from temporalio.testing import WorkflowEnvironment
16+
from temporalio.worker import Worker
17+
18+
from nexusmcp import MCPServiceHandler, WorkflowTransport
19+
20+
mcp_service = MCPServiceHandler()
21+
22+
23+
@dataclass
24+
class AppendInput:
25+
session_id: str
26+
value: int
27+
28+
29+
@service
30+
class TestService:
31+
append: Operation[AppendInput, list[int]]
32+
33+
34+
@workflow.defn(sandboxed=False)
35+
class AppendWorkflow:
36+
def __init__(self):
37+
self.data = []
38+
39+
@workflow.run
40+
async def run(self) -> None:
41+
await asyncio.Event().wait()
42+
43+
@workflow.update
44+
async def append(self, input: int) -> list[int]:
45+
self.data.append(input)
46+
return self.data
47+
48+
49+
@mcp_service.register
50+
@service_handler(service=TestService)
51+
class TestServiceHandler:
52+
@sync_operation
53+
async def append(self, ctx: StartOperationContext, input: AppendInput) -> list[int]:
54+
# TODO: start workflow in a call made during initialize()?
55+
# wf = nexus.client().get_workflow_handle_for(AppendWorkflow.run, input.session_id)
56+
with_start = WithStartWorkflowOperation(
57+
AppendWorkflow.run,
58+
id=input.session_id,
59+
task_queue=nexus.info().task_queue,
60+
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
61+
)
62+
return await nexus.client().execute_update_with_start_workflow(
63+
AppendWorkflow.append,
64+
input.value,
65+
start_workflow_operation=with_start,
66+
)
67+
68+
69+
@dataclass
70+
class MCPCallerWorkflowInput:
71+
endpoint: str
72+
73+
74+
# sandbox disabled due to use of ThreadLocal by sniffio
75+
# TODO: make this unnecessary
76+
@workflow.defn(sandboxed=False)
77+
class MCPCallerWorkflow:
78+
@workflow.run
79+
async def run(self, input: MCPCallerWorkflowInput) -> None:
80+
transport = WorkflowTransport(input.endpoint)
81+
async with transport.connect() as (read_stream, write_stream):
82+
async with ClientSession(read_stream, write_stream) as session:
83+
await session.initialize()
84+
85+
list_tools_result = await session.list_tools()
86+
assert len(list_tools_result.tools) == 1
87+
assert list_tools_result.tools[0].name == "TestService/append"
88+
89+
call_result = await session.call_tool("TestService/append", {"session_id": "123", "value": 1})
90+
assert call_result.content[0].text == "[1]"
91+
92+
call_result = await session.call_tool("TestService/append", {"session_id": "123", "value": 2})
93+
assert call_result.content[0].text == "[1, 2]"
94+
95+
96+
@pytest.mark.asyncio
97+
async def test_workflow_caller() -> None:
98+
endpoint_name = "endpoint"
99+
task_queue = "handler-queue"
100+
101+
async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env:
102+
await env.client.operator_service.create_nexus_endpoint(
103+
CreateNexusEndpointRequest(
104+
spec=EndpointSpec(
105+
name=endpoint_name,
106+
target=EndpointTarget(
107+
worker=EndpointTarget.Worker(
108+
namespace=env.client.namespace,
109+
task_queue=task_queue,
110+
)
111+
),
112+
)
113+
)
114+
)
115+
116+
async with Worker(
117+
env.client,
118+
task_queue=task_queue,
119+
workflows=[MCPCallerWorkflow, AppendWorkflow],
120+
nexus_service_handlers=[mcp_service, TestServiceHandler()],
121+
):
122+
await env.client.execute_workflow(
123+
MCPCallerWorkflow.run,
124+
arg=MCPCallerWorkflowInput(endpoint=endpoint_name),
125+
id=str(uuid.uuid4()),
126+
task_queue=task_queue,
127+
)

0 commit comments

Comments
 (0)