From 797cca86ac407c73ddf6811ade7f435b34275e6c Mon Sep 17 00:00:00 2001 From: Den Delimarsky Date: Thu, 19 Feb 2026 05:18:03 +0000 Subject: [PATCH 1/6] Fix the session binding logic for tasks. --- .../server/experimental/request_context.py | 3 +- .../experimental/task_result_handler.py | 6 +- src/mcp/server/lowlevel/experimental.py | 12 +- src/mcp/shared/experimental/tasks/helpers.py | 8 +- .../tasks/in_memory_task_store.py | 52 ++++-- src/mcp/shared/experimental/tasks/store.py | 30 +++- tests/experimental/tasks/server/test_store.py | 154 ++++++++++++++++++ 7 files changed, 233 insertions(+), 32 deletions(-) diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 91aa9a645..71d517c57 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -187,7 +187,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Access task_group via TaskSupport - raises if not in run() context task_group = support.task_group - task = await support.store.create_task(self.task_metadata, task_id) + session_id = str(id(self._session)) + task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id) task_ctx = ServerTaskContext( task=task, diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index b2268bc1c..20e943741 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -80,6 +80,7 @@ async def handle( request: GetTaskPayloadRequest, session: ServerSession, request_id: RequestId, + session_id: str | None = None, ) -> GetTaskPayloadResult: """Handle a tasks/result request. @@ -94,6 +95,7 @@ async def handle( request: The GetTaskPayloadRequest session: The server session for sending messages request_id: The request ID for relatedRequestId routing + session_id: Optional session identifier for access control. Returns: GetTaskPayloadResult with the task's final payload @@ -101,7 +103,7 @@ async def handle( task_id = request.params.task_id while True: - task = await self._store.get_task(task_id) + task = await self._store.get_task(task_id, session_id=session_id) if task is None: raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") @@ -109,7 +111,7 @@ async def handle( # If task is terminal, return result if is_terminal(task.status): - result = await self._store.get_result(task_id) + result = await self._store.get_result(task_id, session_id=session_id) # GetTaskPayloadResult is a Result with extra="allow" # The stored result contains the actual payload data # Per spec: tasks/result MUST include _meta with related-task metadata diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 8ac268728..efd4145ea 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -153,7 +153,8 @@ def enable_tasks( async def _default_get_task( ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams ) -> GetTaskResult: - task = await task_support.store.get_task(params.task_id) + session_id = str(id(ctx.session)) + task = await task_support.store.get_task(params.task_id, session_id=session_id) if task is None: raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( @@ -174,8 +175,9 @@ async def _default_get_task_result( ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: assert ctx.request_id is not None + session_id = str(id(ctx.session)) req = GetTaskPayloadRequest(params=params) - result = await task_support.handler.handle(req, ctx.session, ctx.request_id) + result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id) return result self._add_request_handler("tasks/result", _default_get_task_result) @@ -186,7 +188,8 @@ async def _default_list_tasks( ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListTasksResult: cursor = params.cursor if params else None - tasks, next_cursor = await task_support.store.list_tasks(cursor) + session_id = str(id(ctx.session)) + tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) self._add_request_handler("tasks/list", _default_list_tasks) @@ -196,7 +199,8 @@ async def _default_list_tasks( async def _default_cancel_task( ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams ) -> CancelTaskResult: - result = await cancel_task(task_support.store, params.task_id) + session_id = str(id(ctx.session)) + result = await cancel_task(task_support.store, params.task_id, session_id=session_id) return result self._add_request_handler("tasks/cancel", _default_cancel_task) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index bd1781cb5..461ae55fe 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -50,6 +50,7 @@ def is_terminal(status: TaskStatus) -> bool: async def cancel_task( store: TaskStore, task_id: str, + session_id: str | None = None, ) -> CancelTaskResult: """Cancel a task with spec-compliant validation. @@ -62,20 +63,21 @@ async def cancel_task( Args: store: The task store task_id: The task identifier to cancel + session_id: Optional session identifier for access control. Returns: CancelTaskResult with the cancelled task state Raises: MCPError: With INVALID_PARAMS (-32602) if: - - Task does not exist + - Task does not exist or is not accessible by this session - Task is already in a terminal state (completed, failed, cancelled) Example: async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: return await cancel_task(store, params.task_id) """ - task = await store.get_task(task_id) + task = await store.get_task(task_id, session_id=session_id) if task is None: raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") @@ -83,7 +85,7 @@ async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResul raise MCPError(code=INVALID_PARAMS, message=f"Cannot cancel task in terminal state '{task.status}'") # Update task to cancelled status - cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED) + cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED, session_id=session_id) return CancelTaskResult(**cancelled_task.model_dump()) diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py index 42f4fb703..73060222b 100644 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -22,6 +22,7 @@ class StoredTask: """Internal storage representation of a task.""" task: Task + session_id: str | None = None result: Result | None = None # Time when this task should be removed (None = never) expires_at: datetime | None = field(default=None) @@ -32,6 +33,7 @@ class InMemoryTaskStore(TaskStore): Features: - Automatic TTL-based cleanup (lazy expiration) + - Session isolation (tasks are scoped to their creating session) - Thread-safe for single-process async use - Pagination support for list_tasks @@ -66,10 +68,25 @@ def _cleanup_expired(self) -> None: for task_id in expired_ids: del self._tasks[task_id] + def _get_stored_task(self, task_id: str, session_id: str | None = None) -> StoredTask | None: + """Retrieve a stored task, enforcing session ownership when a session_id is provided. + + Returns None if the task does not exist or belongs to a different session. + When either the caller's session_id or the stored task's session_id is None, + no filtering occurs (backward compatibility). + """ + stored = self._tasks.get(task_id) + if stored is None: + return None + if session_id is not None and stored.session_id is not None and stored.session_id != session_id: + return None + return stored + async def create_task( self, metadata: TaskMetadata, task_id: str | None = None, + session_id: str | None = None, ) -> Task: """Create a new task with the given metadata.""" # Cleanup expired tasks on access @@ -82,6 +99,7 @@ async def create_task( stored = StoredTask( task=task, + session_id=session_id, expires_at=self._calculate_expiry(metadata.ttl), ) self._tasks[task.task_id] = stored @@ -89,12 +107,12 @@ async def create_task( # Return a copy to prevent external modification return Task(**task.model_dump()) - async def get_task(self, task_id: str) -> Task | None: + async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None: """Get a task by ID.""" # Cleanup expired tasks on access self._cleanup_expired() - stored = self._tasks.get(task_id) + stored = self._get_stored_task(task_id, session_id) if stored is None: return None @@ -106,9 +124,10 @@ async def update_task( task_id: str, status: TaskStatus | None = None, status_message: str | None = None, + session_id: str | None = None, ) -> Task: """Update a task's status and/or message.""" - stored = self._tasks.get(task_id) + stored = self._get_stored_task(task_id, session_id) if stored is None: raise ValueError(f"Task with ID {task_id} not found") @@ -137,17 +156,17 @@ async def update_task( return Task(**stored.task.model_dump()) - async def store_result(self, task_id: str, result: Result) -> None: + async def store_result(self, task_id: str, result: Result, session_id: str | None = None) -> None: """Store the result for a task.""" - stored = self._tasks.get(task_id) + stored = self._get_stored_task(task_id, session_id) if stored is None: raise ValueError(f"Task with ID {task_id} not found") stored.result = result - async def get_result(self, task_id: str) -> Result | None: + async def get_result(self, task_id: str, session_id: str | None = None) -> Result | None: """Get the stored result for a task.""" - stored = self._tasks.get(task_id) + stored = self._get_stored_task(task_id, session_id) if stored is None: return None @@ -156,34 +175,41 @@ async def get_result(self, task_id: str) -> Result | None: async def list_tasks( self, cursor: str | None = None, + session_id: str | None = None, ) -> tuple[list[Task], str | None]: """List tasks with pagination.""" # Cleanup expired tasks on access self._cleanup_expired() - all_task_ids = list(self._tasks.keys()) + # Filter tasks by session ownership before pagination + filtered_task_ids = [ + task_id + for task_id, stored in self._tasks.items() + if session_id is None or stored.session_id is None or stored.session_id == session_id + ] start_index = 0 if cursor is not None: try: - cursor_index = all_task_ids.index(cursor) + cursor_index = filtered_task_ids.index(cursor) start_index = cursor_index + 1 except ValueError: raise ValueError(f"Invalid cursor: {cursor}") - page_task_ids = all_task_ids[start_index : start_index + self._page_size] + page_task_ids = filtered_task_ids[start_index : start_index + self._page_size] tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] # Determine next cursor next_cursor = None - if start_index + self._page_size < len(all_task_ids) and page_task_ids: + if start_index + self._page_size < len(filtered_task_ids) and page_task_ids: next_cursor = page_task_ids[-1] return tasks, next_cursor - async def delete_task(self, task_id: str) -> bool: + async def delete_task(self, task_id: str, session_id: str | None = None) -> bool: """Delete a task.""" - if task_id not in self._tasks: + stored = self._get_stored_task(task_id, session_id) + if stored is None: return False del self._tasks[task_id] diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py index 7de97d40c..aceb54214 100644 --- a/src/mcp/shared/experimental/tasks/store.py +++ b/src/mcp/shared/experimental/tasks/store.py @@ -19,12 +19,15 @@ async def create_task( self, metadata: TaskMetadata, task_id: str | None = None, + session_id: str | None = None, ) -> Task: """Create a new task. Args: metadata: Task metadata (ttl, etc.) task_id: Optional task ID. If None, implementation should generate one. + session_id: Optional session identifier. When provided, the task is + bound to this session for isolation purposes. Returns: The created Task with status="working" @@ -34,14 +37,15 @@ async def create_task( """ @abstractmethod - async def get_task(self, task_id: str) -> Task | None: + async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None: """Get a task by ID. Args: task_id: The task identifier + session_id: Optional session identifier for access control. Returns: - The Task, or None if not found + The Task, or None if not found or not accessible by this session. """ @abstractmethod @@ -50,6 +54,7 @@ async def update_task( task_id: str, status: TaskStatus | None = None, status_message: str | None = None, + session_id: str | None = None, ) -> Task: """Update a task's status and/or message. @@ -57,63 +62,70 @@ async def update_task( task_id: The task identifier status: New status (if changing) status_message: New status message (if changing) + session_id: Optional session identifier for access control. Returns: The updated Task Raises: - ValueError: If task not found + ValueError: If task not found or not accessible by this session. ValueError: If attempting to transition from a terminal status (completed, failed, cancelled). Per spec, terminal states MUST NOT transition to any other status. """ @abstractmethod - async def store_result(self, task_id: str, result: Result) -> None: + async def store_result(self, task_id: str, result: Result, session_id: str | None = None) -> None: """Store the result for a task. Args: task_id: The task identifier result: The result to store + session_id: Optional session identifier for access control. Raises: - ValueError: If task not found + ValueError: If task not found or not accessible by this session. """ @abstractmethod - async def get_result(self, task_id: str) -> Result | None: + async def get_result(self, task_id: str, session_id: str | None = None) -> Result | None: """Get the stored result for a task. Args: task_id: The task identifier + session_id: Optional session identifier for access control. Returns: - The stored Result, or None if not available + The stored Result, or None if not available. """ @abstractmethod async def list_tasks( self, cursor: str | None = None, + session_id: str | None = None, ) -> tuple[list[Task], str | None]: """List tasks with pagination. Args: cursor: Optional cursor for pagination + session_id: Optional session identifier. When provided, only tasks + belonging to this session are returned. Returns: Tuple of (tasks, next_cursor). next_cursor is None if no more pages. """ @abstractmethod - async def delete_task(self, task_id: str) -> bool: + async def delete_task(self, task_id: str, session_id: str | None = None) -> bool: """Delete a task. Args: task_id: The task identifier + session_id: Optional session identifier for access control. Returns: - True if deleted, False if not found + True if deleted, False if not found or not accessible by this session. """ @abstractmethod diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py index 0d431899c..d996ec69f 100644 --- a/tests/experimental/tasks/server/test_store.py +++ b/tests/experimental/tasks/server/test_store.py @@ -404,3 +404,157 @@ async def test_cancel_task_succeeds_for_input_required_task(store: InMemoryTaskS assert result.task_id == task.task_id assert result.status == "cancelled" + + +# --- Session isolation tests --- + + +@pytest.mark.anyio +async def test_session_b_cannot_list_tasks_created_by_session_a(store: InMemoryTaskStore) -> None: + """Test that session-b cannot list tasks created by session-a.""" + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + tasks, _ = await store.list_tasks(session_id="session-b") + assert len(tasks) == 0 + + +@pytest.mark.anyio +async def test_session_b_cannot_read_task_created_by_session_a(store: InMemoryTaskStore) -> None: + """Test that session-b cannot read a task created by session-a.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + result = await store.get_task(task.task_id, session_id="session-b") + assert result is None + + +@pytest.mark.anyio +async def test_session_b_cannot_update_task_created_by_session_a(store: InMemoryTaskStore) -> None: + """Test that session-b cannot update a task created by session-a.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + with pytest.raises(ValueError, match="not found"): + await store.update_task(task.task_id, status="cancelled", session_id="session-b") + + +@pytest.mark.anyio +async def test_session_b_cannot_store_result_on_session_a_task(store: InMemoryTaskStore) -> None: + """Test that session-b cannot store a result on session-a's task.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + result = CallToolResult(content=[TextContent(type="text", text="secret")]) + + with pytest.raises(ValueError, match="not found"): + await store.store_result(task.task_id, result, session_id="session-b") + + +@pytest.mark.anyio +async def test_session_b_cannot_get_result_of_session_a_task(store: InMemoryTaskStore) -> None: + """Test that session-b cannot get the result of session-a's task.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + result = CallToolResult(content=[TextContent(type="text", text="secret")]) + await store.store_result(task.task_id, result, session_id="session-a") + + retrieved = await store.get_result(task.task_id, session_id="session-b") + assert retrieved is None + + +@pytest.mark.anyio +async def test_session_b_cannot_delete_task_created_by_session_a(store: InMemoryTaskStore) -> None: + """Test that session-b cannot delete a task created by session-a.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + deleted = await store.delete_task(task.task_id, session_id="session-b") + assert deleted is False + + # Task should still exist for session-a + retrieved = await store.get_task(task.task_id, session_id="session-a") + assert retrieved is not None + + +@pytest.mark.anyio +async def test_owning_session_can_access_its_own_tasks(store: InMemoryTaskStore) -> None: + """Test that the owning session can access its own tasks.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + retrieved = await store.get_task(task.task_id, session_id="session-a") + assert retrieved is not None + assert retrieved.task_id == task.task_id + + +@pytest.mark.anyio +async def test_list_only_tasks_belonging_to_requesting_session(store: InMemoryTaskStore) -> None: + """Test that list_tasks returns only tasks belonging to the requesting session.""" + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-b") + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + tasks_a, _ = await store.list_tasks(session_id="session-a") + assert len(tasks_a) == 2 + + tasks_b, _ = await store.list_tasks(session_id="session-b") + assert len(tasks_b) == 1 + + +@pytest.mark.anyio +async def test_no_session_id_allows_access_backward_compat(store: InMemoryTaskStore) -> None: + """Test backward compatibility: no session_id on read allows access to all tasks.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + # No session_id on read = no filtering + retrieved = await store.get_task(task.task_id) + assert retrieved is not None + + +@pytest.mark.anyio +async def test_task_created_without_session_id_accessible_by_any_session(store: InMemoryTaskStore) -> None: + """Test that tasks created without session_id are accessible by any session.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Any session_id on read should still see the task + retrieved = await store.get_task(task.task_id, session_id="session-b") + assert retrieved is not None + + +@pytest.mark.anyio +async def test_session_isolation_pagination() -> None: + """Test that pagination works correctly within a session.""" + store = InMemoryTaskStore(page_size=10) + + # Create 15 tasks for session-a, 5 for session-b + for _ in range(15): + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + for _ in range(5): + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-b") + + # First page for session-a should have 10 + page1, next_cursor = await store.list_tasks(session_id="session-a") + assert len(page1) == 10 + assert next_cursor is not None + + # Second page for session-a should have 5 + page2, next_cursor = await store.list_tasks(cursor=next_cursor, session_id="session-a") + assert len(page2) == 5 + assert next_cursor is None + + # session-b should only see its 5 + tasks_b, next_cursor = await store.list_tasks(session_id="session-b") + assert len(tasks_b) == 5 + assert next_cursor is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_cancel_task_with_session_isolation(store: InMemoryTaskStore) -> None: + """Test that cancel_task respects session isolation.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") + + # session-b should not be able to cancel session-a's task + with pytest.raises(MCPError) as exc_info: + await cancel_task(store, task.task_id, session_id="session-b") + assert exc_info.value.error.code == INVALID_PARAMS + assert "not found" in exc_info.value.error.message + + # session-a should be able to cancel its own task + result = await cancel_task(store, task.task_id, session_id="session-a") + assert result.status == "cancelled" From 51372a39a821dae84a8a322e67f1a8ec4df93d7a Mon Sep 17 00:00:00 2001 From: Den Delimarsky Date: Thu, 19 Feb 2026 07:17:53 +0000 Subject: [PATCH 2/6] Pass through the actual session ID --- src/mcp/server/experimental/request_context.py | 2 +- src/mcp/server/lowlevel/experimental.py | 8 ++++---- src/mcp/server/lowlevel/server.py | 4 ++++ src/mcp/server/session.py | 2 ++ src/mcp/server/streamable_http_manager.py | 2 ++ 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 71d517c57..ac6e4f902 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -187,7 +187,7 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Access task_group via TaskSupport - raises if not in run() context task_group = support.task_group - session_id = str(id(self._session)) + session_id = self._session.session_id task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id) task_ctx = ServerTaskContext( diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index efd4145ea..50a8bcd26 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -153,7 +153,7 @@ def enable_tasks( async def _default_get_task( ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams ) -> GetTaskResult: - session_id = str(id(ctx.session)) + session_id = ctx.session.session_id task = await task_support.store.get_task(params.task_id, session_id=session_id) if task is None: raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") @@ -175,7 +175,7 @@ async def _default_get_task_result( ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: assert ctx.request_id is not None - session_id = str(id(ctx.session)) + session_id = ctx.session.session_id req = GetTaskPayloadRequest(params=params) result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id) return result @@ -188,7 +188,7 @@ async def _default_list_tasks( ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListTasksResult: cursor = params.cursor if params else None - session_id = str(id(ctx.session)) + session_id = ctx.session.session_id tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) @@ -199,7 +199,7 @@ async def _default_list_tasks( async def _default_cancel_task( ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams ) -> CancelTaskResult: - session_id = str(id(ctx.session)) + session_id = ctx.session.session_id result = await cancel_task(task_support.store, params.task_id, session_id=session_id) return result diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9ca5ac4fc..7808eee25 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -374,6 +374,9 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, + # Optional session identifier for task isolation. When provided (e.g., + # from the transport's mcp_session_id), tasks are bound to this ID. + session_id: str | None = None, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -383,6 +386,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + session_id=session_id, ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6925aa556..79c51e69e 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -83,9 +83,11 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + session_id: str | None = None, ) -> None: super().__init__(read_stream, write_stream) self._stateless = stateless + self.session_id = session_id self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized ) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 50bcd5e79..1e416fbd8 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -172,6 +172,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA write_stream, self.app.create_initialization_options(), stateless=True, + session_id=http_transport.mcp_session_id, ) except Exception: # pragma: no cover logger.exception("Stateless session crashed") @@ -240,6 +241,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE write_stream, self.app.create_initialization_options(), stateless=False, + session_id=http_transport.mcp_session_id, ) if idle_scope.cancelled_caught: From 0e504fa886f1d6d2b9f9c70069d828826bb82783 Mon Sep 17 00:00:00 2001 From: Den Delimarsky Date: Fri, 20 Feb 2026 08:25:09 +0000 Subject: [PATCH 3/6] Address feedback :house: Remote-Dev: homespace --- src/mcp/client/_memory.py | 3 + .../server/experimental/request_context.py | 2 + src/mcp/server/experimental/task_context.py | 36 +++-- .../experimental/task_result_handler.py | 4 +- src/mcp/server/lowlevel/experimental.py | 17 +- src/mcp/server/lowlevel/server.py | 2 - src/mcp/server/session.py | 1 + src/mcp/shared/experimental/tasks/context.py | 14 +- src/mcp/shared/experimental/tasks/helpers.py | 18 ++- .../tasks/in_memory_task_store.py | 43 +++-- src/mcp/shared/experimental/tasks/store.py | 35 ++-- .../tasks/client/test_handlers.py | 45 +++--- tests/experimental/tasks/client/test_tasks.py | 20 +-- .../experimental/tasks/server/test_context.py | 46 +++--- .../tasks/server/test_integration.py | 14 +- .../experimental/tasks/server/test_server.py | 11 +- .../tasks/server/test_server_task_context.py | 74 +++++---- tests/experimental/tasks/server/test_store.py | 153 ++++++++---------- .../tasks/server/test_task_result_handler.py | 42 ++--- .../tasks/test_elicitation_scenarios.py | 26 +-- 20 files changed, 322 insertions(+), 284 deletions(-) diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e938673..337deeb75 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -2,6 +2,7 @@ from __future__ import annotations +import uuid from collections.abc import AsyncIterator from contextlib import AbstractAsyncContextManager, asynccontextmanager from types import TracebackType @@ -50,12 +51,14 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: async with anyio.create_task_group() as tg: # Start server in background + memory_session_id = uuid.uuid4().hex tg.start_soon( lambda: actual_server.run( server_read, server_write, actual_server.create_initialization_options(), raise_exceptions=self._raise_exceptions, + session_id=memory_session_id, ) ) diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index ac6e4f902..bba5ad81b 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -188,6 +188,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: task_group = support.task_group session_id = self._session.session_id + if session_id is None: + raise RuntimeError("Session ID is required for task operations but session has no ID.") task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id) task_ctx = ServerTaskContext( diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 9b626c986..08ad30f68 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -88,7 +88,11 @@ def __init__( queue: The message queue for elicitation/sampling handler: The result handler for response routing (required for elicit/create_message) """ - self._ctx = TaskContext(task=task, store=store) + session_id = session.session_id + if session_id is None: + raise RuntimeError("Session ID is required for task operations but session has no ID.") + self._session_id = session_id + self._ctx = TaskContext(task=task, store=store, session_id=session_id) self._session = session self._queue = queue self._handler = handler @@ -210,7 +214,7 @@ async def elicit( raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.") # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id) # Build the request using session's helper request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] @@ -234,12 +238,12 @@ async def elicit( try: # Wait for response (routed back via TaskResultHandler) response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) return ElicitResult.model_validate(response_data) except anyio.get_cancelled_exc_class(): # This path is tested in test_elicit_restores_status_on_cancellation # which verifies status is restored to "working" after cancellation. - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) raise async def elicit_url( @@ -279,7 +283,7 @@ async def elicit_url( raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.") # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id) # Build the request using session's helper request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage] @@ -304,10 +308,10 @@ async def elicit_url( try: # Wait for response (routed back via TaskResultHandler) response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) return ElicitResult.model_validate(response_data) except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) raise async def create_message( @@ -362,7 +366,7 @@ async def create_message( raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.") # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id) # Build the request using session's helper request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] @@ -394,12 +398,12 @@ async def create_message( try: # Wait for response (routed back via TaskResultHandler) response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) return CreateMessageResult.model_validate(response_data) except anyio.get_cancelled_exc_class(): # This path is tested in test_create_message_restores_status_on_cancellation # which verifies status is restored to "working" after cancellation. - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) raise async def elicit_as_task( @@ -435,7 +439,7 @@ async def elicit_as_task( raise RuntimeError("handler is required for elicit_as_task()") # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id) request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] message=message, @@ -472,11 +476,11 @@ async def elicit_as_task( ElicitResult, ) - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) return result except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) raise async def create_message_as_task( @@ -531,7 +535,7 @@ async def create_message_as_task( raise RuntimeError("handler is required for create_message_as_task()") # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id) # Build request WITH task field for task-augmented sampling request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] @@ -577,9 +581,9 @@ async def create_message_as_task( CreateMessageResult, ) - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) return result except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id) raise diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 20e943741..8d922acd0 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -80,7 +80,7 @@ async def handle( request: GetTaskPayloadRequest, session: ServerSession, request_id: RequestId, - session_id: str | None = None, + session_id: str, ) -> GetTaskPayloadResult: """Handle a tasks/result request. @@ -95,7 +95,7 @@ async def handle( request: The GetTaskPayloadRequest session: The server session for sending messages request_id: The request ID for relatedRequestId routing - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Returns: GetTaskPayloadResult with the task's final payload diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 50a8bcd26..33758aa2c 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -147,13 +147,22 @@ def enable_tasks( if on_cancel_task is not None: self._add_request_handler("tasks/cancel", on_cancel_task) + def _require_session_id(ctx: ServerRequestContext[LifespanResultT]) -> str: + session_id = ctx.session.session_id + if session_id is None: + raise MCPError( + code=INVALID_PARAMS, + message="Session ID is required for task operations.", + ) + return session_id + # Fill in defaults for any not provided if not self._has_handler("tasks/get"): async def _default_get_task( ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams ) -> GetTaskResult: - session_id = ctx.session.session_id + session_id = _require_session_id(ctx) task = await task_support.store.get_task(params.task_id, session_id=session_id) if task is None: raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") @@ -175,7 +184,7 @@ async def _default_get_task_result( ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: assert ctx.request_id is not None - session_id = ctx.session.session_id + session_id = _require_session_id(ctx) req = GetTaskPayloadRequest(params=params) result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id) return result @@ -188,7 +197,7 @@ async def _default_list_tasks( ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListTasksResult: cursor = params.cursor if params else None - session_id = ctx.session.session_id + session_id = _require_session_id(ctx) tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) @@ -199,7 +208,7 @@ async def _default_list_tasks( async def _default_cancel_task( ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams ) -> CancelTaskResult: - session_id = ctx.session.session_id + session_id = _require_session_id(ctx) result = await cancel_task(task_support.store, params.task_id, session_id=session_id) return result diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7808eee25..a54aef08e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -374,8 +374,6 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, - # Optional session identifier for task isolation. When provided (e.g., - # from the transport's mcp_session_id), tasks are bound to this ID. session_id: str | None = None, ): async with AsyncExitStack() as stack: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 79c51e69e..89c6cca39 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -83,6 +83,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + *, session_id: str | None = None, ) -> None: super().__init__(read_stream, write_stream) diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py index ed0d2b91b..5eae0477c 100644 --- a/src/mcp/shared/experimental/tasks/context.py +++ b/src/mcp/shared/experimental/tasks/context.py @@ -21,19 +21,20 @@ class TaskContext: use ServerTaskContext from mcp.server.experimental. Example (distributed worker): - async def worker_job(task_id: str): + async def worker_job(task_id: str, session_id: str): store = RedisTaskStore(redis_url) - task = await store.get_task(task_id) - ctx = TaskContext(task=task, store=store) + task = await store.get_task(task_id, session_id=session_id) + ctx = TaskContext(task=task, store=store, session_id=session_id) await ctx.update_status("Working...") result = await do_work() await ctx.complete(result) """ - def __init__(self, task: Task, store: TaskStore): + def __init__(self, task: Task, store: TaskStore, *, session_id: str): self._task = task self._store = store + self._session_id = session_id self._cancelled = False @property @@ -68,6 +69,7 @@ async def update_status(self, message: str) -> None: self._task = await self._store.update_task( self.task_id, status_message=message, + session_id=self._session_id, ) async def complete(self, result: Result) -> None: @@ -76,10 +78,11 @@ async def complete(self, result: Result) -> None: Args: result: The task result """ - await self._store.store_result(self.task_id, result) + await self._store.store_result(self.task_id, result, session_id=self._session_id) self._task = await self._store.update_task( self.task_id, status=TASK_STATUS_COMPLETED, + session_id=self._session_id, ) async def fail(self, error: str) -> None: @@ -92,4 +95,5 @@ async def fail(self, error: str) -> None: self.task_id, status=TASK_STATUS_FAILED, status_message=error, + session_id=self._session_id, ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index 461ae55fe..588eb7aae 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -50,7 +50,8 @@ def is_terminal(status: TaskStatus) -> bool: async def cancel_task( store: TaskStore, task_id: str, - session_id: str | None = None, + *, + session_id: str, ) -> CancelTaskResult: """Cancel a task with spec-compliant validation. @@ -63,7 +64,7 @@ async def cancel_task( Args: store: The task store task_id: The task identifier to cancel - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Returns: CancelTaskResult with the cancelled task state @@ -75,7 +76,7 @@ async def cancel_task( Example: async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: - return await cancel_task(store, params.task_id) + return await cancel_task(store, params.task_id, session_id=ctx.session.session_id) """ task = await store.get_task(task_id, session_id=session_id) if task is None: @@ -124,6 +125,8 @@ def create_task_state( async def task_execution( task_id: str, store: TaskStore, + *, + session_id: str, ) -> AsyncIterator[TaskContext]: """Context manager for safe task execution (pure, no server dependencies). @@ -136,6 +139,7 @@ async def task_execution( Args: task_id: The task identifier to execute store: The task store (must be accessible by the worker) + session_id: Session identifier for access control. Yields: TaskContext for updating status and completing/failing the task @@ -144,18 +148,18 @@ async def task_execution( ValueError: If the task is not found in the store Example (distributed worker): - async def worker_process(task_id: str): + async def worker_process(task_id: str, session_id: str): store = RedisTaskStore(redis_url) - async with task_execution(task_id, store) as ctx: + async with task_execution(task_id, store, session_id=session_id) as ctx: await ctx.update_status("Working...") result = await do_work() await ctx.complete(result) """ - task = await store.get_task(task_id) + task = await store.get_task(task_id, session_id=session_id) if task is None: raise ValueError(f"Task {task_id} not found") - ctx = TaskContext(task, store) + ctx = TaskContext(task, store, session_id=session_id) try: yield ctx except Exception as e: diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py index 73060222b..46c7775e1 100644 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -22,7 +22,7 @@ class StoredTask: """Internal storage representation of a task.""" task: Task - session_id: str | None = None + session_id: str result: Result | None = None # Time when this task should be removed (None = never) expires_at: datetime | None = field(default=None) @@ -68,17 +68,15 @@ def _cleanup_expired(self) -> None: for task_id in expired_ids: del self._tasks[task_id] - def _get_stored_task(self, task_id: str, session_id: str | None = None) -> StoredTask | None: - """Retrieve a stored task, enforcing session ownership when a session_id is provided. + def _get_stored_task(self, task_id: str, *, session_id: str) -> StoredTask | None: + """Retrieve a stored task, enforcing session ownership. Returns None if the task does not exist or belongs to a different session. - When either the caller's session_id or the stored task's session_id is None, - no filtering occurs (backward compatibility). """ stored = self._tasks.get(task_id) if stored is None: return None - if session_id is not None and stored.session_id is not None and stored.session_id != session_id: + if stored.session_id != session_id: return None return stored @@ -86,7 +84,8 @@ async def create_task( self, metadata: TaskMetadata, task_id: str | None = None, - session_id: str | None = None, + *, + session_id: str, ) -> Task: """Create a new task with the given metadata.""" # Cleanup expired tasks on access @@ -107,12 +106,12 @@ async def create_task( # Return a copy to prevent external modification return Task(**task.model_dump()) - async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None: + async def get_task(self, task_id: str, *, session_id: str) -> Task | None: """Get a task by ID.""" # Cleanup expired tasks on access self._cleanup_expired() - stored = self._get_stored_task(task_id, session_id) + stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: return None @@ -124,10 +123,11 @@ async def update_task( task_id: str, status: TaskStatus | None = None, status_message: str | None = None, - session_id: str | None = None, + *, + session_id: str, ) -> Task: """Update a task's status and/or message.""" - stored = self._get_stored_task(task_id, session_id) + stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: raise ValueError(f"Task with ID {task_id} not found") @@ -156,17 +156,17 @@ async def update_task( return Task(**stored.task.model_dump()) - async def store_result(self, task_id: str, result: Result, session_id: str | None = None) -> None: + async def store_result(self, task_id: str, result: Result, *, session_id: str) -> None: """Store the result for a task.""" - stored = self._get_stored_task(task_id, session_id) + stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: raise ValueError(f"Task with ID {task_id} not found") stored.result = result - async def get_result(self, task_id: str, session_id: str | None = None) -> Result | None: + async def get_result(self, task_id: str, *, session_id: str) -> Result | None: """Get the stored result for a task.""" - stored = self._get_stored_task(task_id, session_id) + stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: return None @@ -175,18 +175,15 @@ async def get_result(self, task_id: str, session_id: str | None = None) -> Resul async def list_tasks( self, cursor: str | None = None, - session_id: str | None = None, + *, + session_id: str, ) -> tuple[list[Task], str | None]: """List tasks with pagination.""" # Cleanup expired tasks on access self._cleanup_expired() # Filter tasks by session ownership before pagination - filtered_task_ids = [ - task_id - for task_id, stored in self._tasks.items() - if session_id is None or stored.session_id is None or stored.session_id == session_id - ] + filtered_task_ids = [task_id for task_id, stored in self._tasks.items() if stored.session_id == session_id] start_index = 0 if cursor is not None: @@ -206,9 +203,9 @@ async def list_tasks( return tasks, next_cursor - async def delete_task(self, task_id: str, session_id: str | None = None) -> bool: + async def delete_task(self, task_id: str, *, session_id: str) -> bool: """Delete a task.""" - stored = self._get_stored_task(task_id, session_id) + stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: return False diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py index aceb54214..6845c1b1b 100644 --- a/src/mcp/shared/experimental/tasks/store.py +++ b/src/mcp/shared/experimental/tasks/store.py @@ -19,15 +19,16 @@ async def create_task( self, metadata: TaskMetadata, task_id: str | None = None, - session_id: str | None = None, + *, + session_id: str, ) -> Task: """Create a new task. Args: metadata: Task metadata (ttl, etc.) task_id: Optional task ID. If None, implementation should generate one. - session_id: Optional session identifier. When provided, the task is - bound to this session for isolation purposes. + session_id: Session identifier. The task is bound to this session + for isolation purposes. Returns: The created Task with status="working" @@ -37,12 +38,12 @@ async def create_task( """ @abstractmethod - async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None: + async def get_task(self, task_id: str, *, session_id: str) -> Task | None: """Get a task by ID. Args: task_id: The task identifier - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Returns: The Task, or None if not found or not accessible by this session. @@ -54,7 +55,8 @@ async def update_task( task_id: str, status: TaskStatus | None = None, status_message: str | None = None, - session_id: str | None = None, + *, + session_id: str, ) -> Task: """Update a task's status and/or message. @@ -62,7 +64,7 @@ async def update_task( task_id: The task identifier status: New status (if changing) status_message: New status message (if changing) - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Returns: The updated Task @@ -75,25 +77,25 @@ async def update_task( """ @abstractmethod - async def store_result(self, task_id: str, result: Result, session_id: str | None = None) -> None: + async def store_result(self, task_id: str, result: Result, *, session_id: str) -> None: """Store the result for a task. Args: task_id: The task identifier result: The result to store - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Raises: ValueError: If task not found or not accessible by this session. """ @abstractmethod - async def get_result(self, task_id: str, session_id: str | None = None) -> Result | None: + async def get_result(self, task_id: str, *, session_id: str) -> Result | None: """Get the stored result for a task. Args: task_id: The task identifier - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Returns: The stored Result, or None if not available. @@ -103,26 +105,27 @@ async def get_result(self, task_id: str, session_id: str | None = None) -> Resul async def list_tasks( self, cursor: str | None = None, - session_id: str | None = None, + *, + session_id: str, ) -> tuple[list[Task], str | None]: """List tasks with pagination. Args: cursor: Optional cursor for pagination - session_id: Optional session identifier. When provided, only tasks - belonging to this session are returned. + session_id: Session identifier. Only tasks belonging to this + session are returned. Returns: Tuple of (tasks, next_cursor). next_cursor is None if no more pages. """ @abstractmethod - async def delete_task(self, task_id: str, session_id: str | None = None) -> bool: + async def delete_task(self, task_id: str, *, session_id: str) -> bool: """Delete a task. Args: task_id: The task identifier - session_id: Optional session identifier for access control. + session_id: Session identifier for access control. Returns: True if deleted, False if not found or not accessible by this session. diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py index 137ff8010..49fab5d92 100644 --- a/tests/experimental/tasks/client/test_handlers.py +++ b/tests/experimental/tasks/client/test_handlers.py @@ -118,7 +118,7 @@ async def get_task_handler( ) -> GetTaskResult | ErrorData: nonlocal received_task_id received_task_id = params.task_id - task = await store.get_task(params.task_id) + task = await store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, @@ -130,7 +130,7 @@ async def get_task_handler( poll_interval=task.poll_interval, ) - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") + await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123", session_id="test-session") task_handlers = ExperimentalTaskHandlers(get_task=get_task_handler) client_ready = anyio.Event() @@ -179,17 +179,18 @@ async def get_task_result_handler( context: RequestContext[ClientSession], params: GetTaskPayloadRequestParams, ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) + result = await store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, types.CallToolResult) return GetTaskPayloadResult(**result.model_dump()) - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456") + await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456", session_id="test-session") await store.store_result( "test-task-456", types.CallToolResult(content=[TextContent(type="text", text="Task completed successfully!")]), + session_id="test-session", ) - await store.update_task("test-task-456", status="completed") + await store.update_task("test-task-456", status="completed", session_id="test-session") task_handlers = ExperimentalTaskHandlers(get_task_result=get_task_result_handler) client_ready = anyio.Event() @@ -243,11 +244,11 @@ async def list_tasks_handler( params: types.PaginatedRequestParams | None, ) -> ListTasksResult | ErrorData: cursor = params.cursor if params else None - tasks_list, next_cursor = await store.list_tasks(cursor=cursor) + tasks_list, next_cursor = await store.list_tasks(cursor=cursor, session_id="test-session") return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") - await store.create_task(TaskMetadata(ttl=60000), task_id="task-2") + await store.create_task(TaskMetadata(ttl=60000), task_id="task-1", session_id="test-session") + await store.create_task(TaskMetadata(ttl=60000), task_id="task-2", session_id="test-session") task_handlers = ExperimentalTaskHandlers(list_tasks=list_tasks_handler) client_ready = anyio.Event() @@ -297,10 +298,10 @@ async def cancel_task_handler( context: RequestContext[ClientSession], params: CancelTaskRequestParams, ) -> CancelTaskResult | ErrorData: - task = await store.get_task(params.task_id) + task = await store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" - await store.update_task(params.task_id, status="cancelled") - updated = await store.get_task(params.task_id) + await store.update_task(params.task_id, status="cancelled", session_id="test-session") + updated = await store.get_task(params.task_id, session_id="test-session") assert updated is not None return CancelTaskResult( task_id=updated.task_id, @@ -310,7 +311,7 @@ async def cancel_task_handler( ttl=updated.ttl, ) - await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel") + await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel", session_id="test-session") task_handlers = ExperimentalTaskHandlers(cancel_task=cancel_task_handler) client_ready = anyio.Event() @@ -365,7 +366,7 @@ async def task_augmented_sampling_callback( params: CreateMessageRequestParams, task_metadata: TaskMetadata, ) -> CreateTaskResult: - task = await store.create_task(task_metadata) + task = await store.create_task(task_metadata, session_id="test-session") created_task_id[0] = task.task_id async def do_sampling() -> None: @@ -375,8 +376,8 @@ async def do_sampling() -> None: model="test-model", stop_reason="endTurn", ) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") + await store.store_result(task.task_id, result, session_id="test-session") + await store.update_task(task.task_id, status="completed", session_id="test-session") sampling_completed.set() assert background_tg[0] is not None @@ -387,7 +388,7 @@ async def get_task_handler( context: RequestContext[ClientSession], params: GetTaskRequestParams, ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) + task = await store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, @@ -403,7 +404,7 @@ async def get_task_result_handler( context: RequestContext[ClientSession], params: GetTaskPayloadRequestParams, ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) + result = await store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CreateMessageResult) return GetTaskPayloadResult(**result.model_dump()) @@ -509,14 +510,14 @@ async def task_augmented_elicitation_callback( params: ElicitRequestParams, task_metadata: TaskMetadata, ) -> CreateTaskResult | ErrorData: - task = await store.create_task(task_metadata) + task = await store.create_task(task_metadata, session_id="test-session") created_task_id[0] = task.task_id async def do_elicitation() -> None: # Simulate user providing elicitation response result = ElicitResult(action="accept", content={"name": "Test User"}) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") + await store.store_result(task.task_id, result, session_id="test-session") + await store.update_task(task.task_id, status="completed", session_id="test-session") elicitation_completed.set() assert background_tg[0] is not None @@ -527,7 +528,7 @@ async def get_task_handler( context: RequestContext[ClientSession], params: GetTaskRequestParams, ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) + task = await store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, @@ -543,7 +544,7 @@ async def get_task_result_handler( context: RequestContext[ClientSession], params: GetTaskPayloadRequestParams, ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) + result = await store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, ElicitResult) return GetTaskPayloadResult(**result.model_dump()) diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index 613c794eb..ef47e03e6 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -56,13 +56,13 @@ async def _handle_call_tool_with_done_event( if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None - task = await app.store.create_task(task_metadata) + task = await app.store.create_task(task_metadata, session_id="test-session") done_event = Event() app.task_done_events[task.task_id] = done_event async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: + async with task_execution(task.task_id, app.store, session_id="test-session") as task_ctx: await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) done_event.set() @@ -88,7 +88,7 @@ async def test_session_experimental_get_task() -> None: async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) + task = await app.store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, @@ -146,7 +146,7 @@ async def handle_get_task_result( ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) + result = await app.store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) @@ -195,7 +195,7 @@ async def handle_list_tasks( ) -> ListTasksResult: app = ctx.lifespan_context cursor = params.cursor if params else None - tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) + tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor, session_id="test-session") return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) server: Server[AppContext] = Server( @@ -239,14 +239,14 @@ async def handle_call_tool_no_work( if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None - task = await app.store.create_task(task_metadata) + task = await app.store.create_task(task_metadata, session_id="test-session") # Don't start any work - task stays in "working" status return CreateTaskResult(task=task) raise NotImplementedError async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) + task = await app.store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, @@ -262,10 +262,10 @@ async def handle_cancel_task( ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams ) -> CancelTaskResult: app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) + task = await app.store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" - await app.store.update_task(params.task_id, status="cancelled") - updated_task = await app.store.get_task(params.task_id) + await app.store.update_task(params.task_id, status="cancelled", session_id="test-session") + updated_task = await app.store.get_task(params.task_id, session_id="test-session") assert updated_task is not None return CancelTaskResult( task_id=updated_task.task_id, diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py index a0f1a190d..dcf26dcc2 100644 --- a/tests/experimental/tasks/server/test_context.py +++ b/tests/experimental/tasks/server/test_context.py @@ -12,8 +12,8 @@ async def test_task_context_properties() -> None: """Test TaskContext basic properties.""" store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + ctx = TaskContext(task, store, session_id="test-session") assert ctx.task_id == task.task_id assert ctx.task.task_id == task.task_id @@ -27,13 +27,13 @@ async def test_task_context_properties() -> None: async def test_task_context_update_status() -> None: """Test TaskContext.update_status.""" store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + ctx = TaskContext(task, store, session_id="test-session") await ctx.update_status("Processing step 1...") # Check status message was updated - updated = await store.get_task(task.task_id) + updated = await store.get_task(task.task_id, session_id="test-session") assert updated is not None assert updated.status_message == "Processing step 1..." @@ -44,19 +44,19 @@ async def test_task_context_update_status() -> None: async def test_task_context_complete() -> None: """Test TaskContext.complete.""" store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + ctx = TaskContext(task, store, session_id="test-session") result = CallToolResult(content=[TextContent(type="text", text="Done!")]) await ctx.complete(result) # Check task status - updated = await store.get_task(task.task_id) + updated = await store.get_task(task.task_id, session_id="test-session") assert updated is not None assert updated.status == "completed" # Check result is stored - stored_result = await store.get_result(task.task_id) + stored_result = await store.get_result(task.task_id, session_id="test-session") assert stored_result is not None store.cleanup() @@ -66,13 +66,13 @@ async def test_task_context_complete() -> None: async def test_task_context_fail() -> None: """Test TaskContext.fail.""" store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + ctx = TaskContext(task, store, session_id="test-session") await ctx.fail("Something went wrong!") # Check task status - updated = await store.get_task(task.task_id) + updated = await store.get_task(task.task_id, session_id="test-session") assert updated is not None assert updated.status == "failed" assert updated.status_message == "Something went wrong!" @@ -84,8 +84,8 @@ async def test_task_context_fail() -> None: async def test_task_context_cancellation() -> None: """Test TaskContext cancellation request.""" store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + ctx = TaskContext(task, store, session_id="test-session") assert ctx.is_cancelled is False @@ -126,9 +126,9 @@ def test_create_task_state_has_created_at() -> None: async def test_task_execution_provides_context() -> None: """task_execution provides a TaskContext for the task.""" store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-test-1") + await store.create_task(TaskMetadata(ttl=60000), task_id="exec-test-1", session_id="test-session") - async with task_execution("exec-test-1", store) as ctx: + async with task_execution("exec-test-1", store, session_id="test-session") as ctx: assert ctx.task_id == "exec-test-1" assert ctx.task.status == "working" @@ -139,13 +139,13 @@ async def test_task_execution_provides_context() -> None: async def test_task_execution_auto_fails_on_exception() -> None: """task_execution automatically fails task on unhandled exception.""" store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-fail-1") + await store.create_task(TaskMetadata(ttl=60000), task_id="exec-fail-1", session_id="test-session") - async with task_execution("exec-fail-1", store): + async with task_execution("exec-fail-1", store, session_id="test-session"): raise RuntimeError("Oops!") # Task should be failed - failed_task = await store.get_task("exec-fail-1") + failed_task = await store.get_task("exec-fail-1", session_id="test-session") assert failed_task is not None assert failed_task.status == "failed" assert "Oops!" in (failed_task.status_message or "") @@ -157,16 +157,16 @@ async def test_task_execution_auto_fails_on_exception() -> None: async def test_task_execution_doesnt_fail_if_already_terminal() -> None: """task_execution doesn't re-fail if task already terminal.""" store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-term-1") + await store.create_task(TaskMetadata(ttl=60000), task_id="exec-term-1", session_id="test-session") - async with task_execution("exec-term-1", store) as ctx: + async with task_execution("exec-term-1", store, session_id="test-session") as ctx: # Complete the task first await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) # Then raise - shouldn't change status raise RuntimeError("This shouldn't matter") # Task should remain completed - final_task = await store.get_task("exec-term-1") + final_task = await store.get_task("exec-term-1", session_id="test-session") assert final_task is not None assert final_task.status == "completed" @@ -179,5 +179,5 @@ async def test_task_execution_not_found() -> None: store = InMemoryTaskStore() with pytest.raises(ValueError, match="not found"): - async with task_execution("nonexistent", store): + async with task_execution("nonexistent", store, session_id="test-session"): ... diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index b5b79033d..5bb4836af 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -75,13 +75,13 @@ async def handle_call_tool( if params.name == "process_data" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None - task = await app.store.create_task(task_metadata) + task = await app.store.create_task(task_metadata, session_id="test-session") done_event = Event() app.task_done_events[task.task_id] = done_event async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: + async with task_execution(task.task_id, app.store, session_id="test-session") as task_ctx: await task_ctx.update_status("Processing input...") input_value = (params.arguments or {}).get("input", "") result_text = f"Processed: {input_value.upper()}" @@ -95,7 +95,7 @@ async def do_work() -> None: async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) + task = await app.store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, @@ -111,7 +111,7 @@ async def handle_get_task_result( ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) + result = await app.store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) @@ -183,13 +183,13 @@ async def handle_call_tool( if params.name == "failing_task" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None - task = await app.store.create_task(task_metadata) + task = await app.store.create_task(task_metadata, session_id="test-session") done_event = Event() app.task_done_events[task.task_id] = done_event async def do_failing_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: + async with task_execution(task.task_id, app.store, session_id="test-session") as task_ctx: await task_ctx.update_status("About to fail...") raise RuntimeError("Something went wrong!") # This line is reached because task_execution suppresses the exception @@ -202,7 +202,7 @@ async def do_failing_work() -> None: async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) + task = await app.store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 6a28b274e..708b5f085 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -340,6 +340,7 @@ async def run_server() -> None: experimental_capabilities={}, ), ), + session_id="test-session", ) as server_session: task_support.configure_session(server_session) async for message in server_session.incoming_messages: @@ -356,7 +357,7 @@ async def run_server() -> None: await client_session.initialize() # Create a task directly in the store for testing - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") # Test list_tasks (default handler) list_result = await client_session.experimental.list_tasks() @@ -373,11 +374,13 @@ async def run_server() -> None: await client_session.experimental.get_task("nonexistent-task") # Create a completed task to test get_task_result - completed_task = await store.create_task(TaskMetadata(ttl=60000)) + completed_task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") await store.store_result( - completed_task.task_id, CallToolResult(content=[TextContent(type="text", text="Test result")]) + completed_task.task_id, + CallToolResult(content=[TextContent(type="text", text="Test result")]), + session_id="test-session", ) - await store.update_task(completed_task.task_id, status="completed") + await store.update_task(completed_task.task_id, status="completed", session_id="test-session") # Test get_task_result (default handler) payload_result = await client_session.send_request( diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py index e23299698..a8132b5b5 100644 --- a/tests/experimental/tasks/server/test_server_task_context.py +++ b/tests/experimental/tasks/server/test_server_task_context.py @@ -34,8 +34,9 @@ async def test_server_task_context_properties() -> None: """Test ServerTaskContext property accessors.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-123") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-123", session_id="test-session") ctx = ServerTaskContext( task=task, @@ -56,8 +57,9 @@ async def test_server_task_context_request_cancellation() -> None: """Test ServerTaskContext.request_cancellation().""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -78,9 +80,10 @@ async def test_server_task_context_update_status_with_notify() -> None: """Test update_status sends notification when notify=True.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.send_notification = AsyncMock() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -100,9 +103,10 @@ async def test_server_task_context_update_status_without_notify() -> None: """Test update_status skips notification when notify=False.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.send_notification = AsyncMock() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -122,9 +126,10 @@ async def test_server_task_context_complete_with_notify() -> None: """Test complete sends notification when notify=True.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.send_notification = AsyncMock() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -145,9 +150,10 @@ async def test_server_task_context_fail_with_notify() -> None: """Test fail sends notification when notify=True.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.send_notification = AsyncMock() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -167,10 +173,11 @@ async def test_elicit_raises_when_client_lacks_capability() -> None: """Test that elicit() raises MCPError when client doesn't support elicitation.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=False) queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -193,10 +200,11 @@ async def test_create_message_raises_when_client_lacks_capability() -> None: """Test that create_message() raises MCPError when client doesn't support sampling.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=False) queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -219,9 +227,10 @@ async def test_elicit_raises_without_handler() -> None: """Test that elicit() raises when handler is not provided.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -242,9 +251,10 @@ async def test_elicit_url_raises_without_handler() -> None: """Test that elicit_url() raises when handler is not provided.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -269,9 +279,10 @@ async def test_create_message_raises_without_handler() -> None: """Test that create_message() raises when handler is not provided.""" store = InMemoryTaskStore() mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") ctx = ServerTaskContext( task=task, @@ -293,9 +304,10 @@ async def test_elicit_queues_request_and_waits_for_response() -> None: store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) mock_session._build_elicit_form_request = Mock( return_value=JSONRPCRequest( @@ -330,7 +342,7 @@ async def run_elicit() -> None: await queue.wait_for_message(task.task_id) # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) + updated_task = await store.get_task(task.task_id, session_id="test-session") assert updated_task is not None assert updated_task.status == "input_required" @@ -348,7 +360,7 @@ async def run_elicit() -> None: assert elicit_result.content == {"name": "Alice"} # Verify task is back to working - final_task = await store.get_task(task.task_id) + final_task = await store.get_task(task.task_id, session_id="test-session") assert final_task is not None assert final_task.status == "working" @@ -361,9 +373,10 @@ async def test_elicit_url_queues_request_and_waits_for_response() -> None: store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) mock_session._build_elicit_url_request = Mock( return_value=JSONRPCRequest( @@ -399,7 +412,7 @@ async def run_elicit_url() -> None: await queue.wait_for_message(task.task_id) # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) + updated_task = await store.get_task(task.task_id, session_id="test-session") assert updated_task is not None assert updated_task.status == "input_required" @@ -416,7 +429,7 @@ async def run_elicit_url() -> None: assert elicit_result.action == "accept" # Verify task is back to working - final_task = await store.get_task(task.task_id) + final_task = await store.get_task(task.task_id, session_id="test-session") assert final_task is not None assert final_task.status == "working" @@ -429,9 +442,10 @@ async def test_create_message_queues_request_and_waits_for_response() -> None: store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) mock_session._build_create_message_request = Mock( return_value=JSONRPCRequest( @@ -466,7 +480,7 @@ async def run_sampling() -> None: await queue.wait_for_message(task.task_id) # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) + updated_task = await store.get_task(task.task_id, session_id="test-session") assert updated_task is not None assert updated_task.status == "input_required" @@ -491,7 +505,7 @@ async def run_sampling() -> None: assert sampling_result.model == "test-model" # Verify task is back to working - final_task = await store.get_task(task.task_id) + final_task = await store.get_task(task.task_id, session_id="test-session") assert final_task is not None assert final_task.status == "working" @@ -504,9 +518,10 @@ async def test_elicit_restores_status_on_cancellation() -> None: store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) mock_session._build_elicit_form_request = Mock( return_value=JSONRPCRequest( @@ -546,7 +561,7 @@ async def do_elicit() -> None: await queue.wait_for_message(task.task_id) # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) + updated_task = await store.get_task(task.task_id, session_id="test-session") assert updated_task is not None assert updated_task.status == "input_required" @@ -559,7 +574,7 @@ async def do_elicit() -> None: msg.resolver.set_exception(asyncio.CancelledError()) # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) + final_task = await store.get_task(task.task_id, session_id="test-session") assert final_task is not None assert final_task.status == "working" assert cancelled_error_raised @@ -573,9 +588,10 @@ async def test_create_message_restores_status_on_cancellation() -> None: store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") mock_session = Mock() + mock_session.session_id = "test-session" mock_session.check_client_capability = Mock(return_value=True) mock_session._build_create_message_request = Mock( return_value=JSONRPCRequest( @@ -615,7 +631,7 @@ async def do_sampling() -> None: await queue.wait_for_message(task.task_id) # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) + updated_task = await store.get_task(task.task_id, session_id="test-session") assert updated_task is not None assert updated_task.status == "input_required" @@ -628,7 +644,7 @@ async def do_sampling() -> None: msg.resolver.set_exception(asyncio.CancelledError()) # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) + final_task = await store.get_task(task.task_id, session_id="test-session") assert final_task is not None assert final_task.status == "working" assert cancelled_error_raised @@ -641,10 +657,11 @@ async def test_elicit_as_task_raises_without_handler() -> None: """Test that elicit_as_task() raises when handler is not provided.""" store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") # Create mock session with proper client capabilities mock_session = Mock() + mock_session.session_id = "test-session" mock_session.client_params = InitializeRequestParams( protocol_version="2025-01-01", capabilities=ClientCapabilities( @@ -676,10 +693,11 @@ async def test_create_message_as_task_raises_without_handler() -> None: """Test that create_message_as_task() raises when handler is not provided.""" store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") # Create mock session with proper client capabilities mock_session = Mock() + mock_session.session_id = "test-session" mock_session.client_params = InitializeRequestParams( protocol_version="2025-01-01", capabilities=ClientCapabilities( diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py index d996ec69f..e5dac940a 100644 --- a/tests/experimental/tasks/server/test_store.py +++ b/tests/experimental/tasks/server/test_store.py @@ -22,13 +22,13 @@ async def store() -> AsyncIterator[InMemoryTaskStore]: @pytest.mark.anyio async def test_create_and_get(store: InMemoryTaskStore) -> None: """Test InMemoryTaskStore create and get operations.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") assert task.task_id is not None assert task.status == "working" assert task.ttl == 60000 - retrieved = await store.get_task(task.task_id) + retrieved = await store.get_task(task.task_id, session_id="test-session") assert retrieved is not None assert retrieved.task_id == task.task_id assert retrieved.status == "working" @@ -40,12 +40,13 @@ async def test_create_with_custom_id(store: InMemoryTaskStore) -> None: task = await store.create_task( metadata=TaskMetadata(ttl=60000), task_id="my-custom-id", + session_id="test-session", ) assert task.task_id == "my-custom-id" assert task.status == "working" - retrieved = await store.get_task("my-custom-id") + retrieved = await store.get_task("my-custom-id", session_id="test-session") assert retrieved is not None assert retrieved.task_id == "my-custom-id" @@ -53,30 +54,32 @@ async def test_create_with_custom_id(store: InMemoryTaskStore) -> None: @pytest.mark.anyio async def test_create_duplicate_id_raises(store: InMemoryTaskStore) -> None: """Test that creating a task with duplicate ID raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate", session_id="test-session") with pytest.raises(ValueError, match="already exists"): - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate", session_id="test-session") @pytest.mark.anyio async def test_get_nonexistent_returns_none(store: InMemoryTaskStore) -> None: """Test that getting a nonexistent task returns None.""" - retrieved = await store.get_task("nonexistent") + retrieved = await store.get_task("nonexistent", session_id="test-session") assert retrieved is None @pytest.mark.anyio async def test_update_status(store: InMemoryTaskStore) -> None: """Test InMemoryTaskStore status updates.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") - updated = await store.update_task(task.task_id, status="completed", status_message="All done!") + updated = await store.update_task( + task.task_id, status="completed", status_message="All done!", session_id="test-session" + ) assert updated.status == "completed" assert updated.status_message == "All done!" - retrieved = await store.get_task(task.task_id) + retrieved = await store.get_task(task.task_id, session_id="test-session") assert retrieved is not None assert retrieved.status == "completed" assert retrieved.status_message == "All done!" @@ -86,35 +89,35 @@ async def test_update_status(store: InMemoryTaskStore) -> None: async def test_update_nonexistent_raises(store: InMemoryTaskStore) -> None: """Test that updating a nonexistent task raises.""" with pytest.raises(ValueError, match="not found"): - await store.update_task("nonexistent", status="completed") + await store.update_task("nonexistent", status="completed", session_id="test-session") @pytest.mark.anyio async def test_store_and_get_result(store: InMemoryTaskStore) -> None: """Test InMemoryTaskStore result storage and retrieval.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") # Store result result = CallToolResult(content=[TextContent(type="text", text="Result data")]) - await store.store_result(task.task_id, result) + await store.store_result(task.task_id, result, session_id="test-session") # Retrieve result - retrieved_result = await store.get_result(task.task_id) + retrieved_result = await store.get_result(task.task_id, session_id="test-session") assert retrieved_result == result @pytest.mark.anyio async def test_get_result_nonexistent_returns_none(store: InMemoryTaskStore) -> None: """Test that getting result for nonexistent task returns None.""" - result = await store.get_result("nonexistent") + result = await store.get_result("nonexistent", session_id="test-session") assert result is None @pytest.mark.anyio async def test_get_result_no_result_returns_none(store: InMemoryTaskStore) -> None: """Test that getting result when none stored returns None.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - result = await store.get_result(task.task_id) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + result = await store.get_result(task.task_id, session_id="test-session") assert result is None @@ -123,9 +126,9 @@ async def test_list_tasks(store: InMemoryTaskStore) -> None: """Test InMemoryTaskStore list operation.""" # Create multiple tasks for _ in range(3): - await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") - tasks, next_cursor = await store.list_tasks() + tasks, next_cursor = await store.list_tasks(session_id="test-session") assert len(tasks) == 3 assert next_cursor is None # Less than page size @@ -138,20 +141,20 @@ async def test_list_tasks_pagination() -> None: # Create 5 tasks for _ in range(5): - await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") # First page - tasks, next_cursor = await store.list_tasks() + tasks, next_cursor = await store.list_tasks(session_id="test-session") assert len(tasks) == 2 assert next_cursor is not None # Second page - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + tasks, next_cursor = await store.list_tasks(cursor=next_cursor, session_id="test-session") assert len(tasks) == 2 assert next_cursor is not None # Third page (last) - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + tasks, next_cursor = await store.list_tasks(cursor=next_cursor, session_id="test-session") assert len(tasks) == 1 assert next_cursor is None @@ -161,33 +164,33 @@ async def test_list_tasks_pagination() -> None: @pytest.mark.anyio async def test_list_tasks_invalid_cursor(store: InMemoryTaskStore) -> None: """Test that invalid cursor raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") with pytest.raises(ValueError, match="Invalid cursor"): - await store.list_tasks(cursor="invalid-cursor") + await store.list_tasks(cursor="invalid-cursor", session_id="test-session") @pytest.mark.anyio async def test_delete_task(store: InMemoryTaskStore) -> None: """Test InMemoryTaskStore delete operation.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") - deleted = await store.delete_task(task.task_id) + deleted = await store.delete_task(task.task_id, session_id="test-session") assert deleted is True - retrieved = await store.get_task(task.task_id) + retrieved = await store.get_task(task.task_id, session_id="test-session") assert retrieved is None # Delete non-existent - deleted = await store.delete_task(task.task_id) + deleted = await store.delete_task(task.task_id, session_id="test-session") assert deleted is False @pytest.mark.anyio async def test_get_all_tasks_helper(store: InMemoryTaskStore) -> None: """Test the get_all_tasks debugging helper.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") all_tasks = store.get_all_tasks() assert len(all_tasks) == 2 @@ -199,18 +202,18 @@ async def test_store_result_nonexistent_raises(store: InMemoryTaskStore) -> None result = CallToolResult(content=[TextContent(type="text", text="Result")]) with pytest.raises(ValueError, match="not found"): - await store.store_result("nonexistent-id", result) + await store.store_result("nonexistent-id", result, session_id="test-session") @pytest.mark.anyio async def test_create_task_with_null_ttl(store: InMemoryTaskStore) -> None: """Test creating task with null TTL (never expires).""" - task = await store.create_task(metadata=TaskMetadata(ttl=None)) + task = await store.create_task(metadata=TaskMetadata(ttl=None), session_id="test-session") assert task.ttl is None # Task should persist (not expire) - retrieved = await store.get_task(task.task_id) + retrieved = await store.get_task(task.task_id, session_id="test-session") assert retrieved is not None @@ -218,7 +221,7 @@ async def test_create_task_with_null_ttl(store: InMemoryTaskStore) -> None: async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: """Test that expired tasks are cleaned up lazily.""" # Create a task with very short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=1)) # 1ms TTL + task = await store.create_task(metadata=TaskMetadata(ttl=1), session_id="test-session") # 1ms TTL # Manually force the expiry to be in the past stored = store._tasks.get(task.task_id) @@ -230,7 +233,7 @@ async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: # Any access operation should clean up expired tasks # list_tasks triggers cleanup - tasks, _ = await store.list_tasks() + tasks, _ = await store.list_tasks(session_id="test-session") # Expired task should be cleaned up assert task.task_id not in store._tasks @@ -241,7 +244,7 @@ async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> None: """Test that tasks with null TTL never expire during cleanup.""" # Create task with null TTL - task = await store.create_task(metadata=TaskMetadata(ttl=None)) + task = await store.create_task(metadata=TaskMetadata(ttl=None), session_id="test-session") # Verify internal storage has no expiry stored = store._tasks.get(task.task_id) @@ -249,12 +252,12 @@ async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> Non assert stored.expires_at is None # Access operations should NOT remove this task - await store.list_tasks() - await store.get_task(task.task_id) + await store.list_tasks(session_id="test-session") + await store.get_task(task.task_id, session_id="test-session") # Task should still exist assert task.task_id in store._tasks - retrieved = await store.get_task(task.task_id) + retrieved = await store.get_task(task.task_id, session_id="test-session") assert retrieved is not None @@ -262,7 +265,7 @@ async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> Non async def test_terminal_task_ttl_reset(store: InMemoryTaskStore) -> None: """Test that TTL is reset when task enters terminal state.""" # Create task with short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) # 60s + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") # 60s # Get the initial expiry stored = store._tasks.get(task.task_id) @@ -271,7 +274,7 @@ async def test_terminal_task_ttl_reset(store: InMemoryTaskStore) -> None: assert initial_expiry is not None # Update to terminal state (completed) - await store.update_task(task.task_id, status="completed") + await store.update_task(task.task_id, status="completed", session_id="test-session") # Expiry should be reset to a new time (from now + TTL) new_expiry = stored.expires_at @@ -288,19 +291,19 @@ async def test_terminal_status_transition_rejected(store: InMemoryTaskStore) -> """ # Test each terminal status for terminal_status in ("completed", "failed", "cancelled"): - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") # Move to terminal state - await store.update_task(task.task_id, status=terminal_status) + await store.update_task(task.task_id, status=terminal_status, session_id="test-session") # Attempting to transition to any other status should raise with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status="working") + await store.update_task(task.task_id, status="working", session_id="test-session") # Also test transitioning to another terminal state other_terminal = "failed" if terminal_status != "failed" else "completed" with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status=other_terminal) + await store.update_task(task.task_id, status=other_terminal, session_id="test-session") @pytest.mark.anyio @@ -309,15 +312,15 @@ async def test_terminal_status_allows_same_status(store: InMemoryTaskStore) -> N This is not a transition, so it should be allowed (no-op). """ - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + await store.update_task(task.task_id, status="completed", session_id="test-session") # Setting the same status should not raise - updated = await store.update_task(task.task_id, status="completed") + updated = await store.update_task(task.task_id, status="completed", session_id="test-session") assert updated.status == "completed" # Updating just the message should also work - updated = await store.update_task(task.task_id, status_message="Updated message") + updated = await store.update_task(task.task_id, status_message="Updated message", session_id="test-session") assert updated.status_message == "Updated message" @@ -331,16 +334,16 @@ async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> N @pytest.mark.anyio async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None: """Test cancel_task helper succeeds for a working task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") assert task.status == "working" - result = await cancel_task(store, task.task_id) + result = await cancel_task(store, task.task_id, session_id="test-session") assert result.task_id == task.task_id assert result.status == "cancelled" # Verify store is updated - retrieved = await store.get_task(task.task_id) + retrieved = await store.get_task(task.task_id, session_id="test-session") assert retrieved is not None assert retrieved.status == "cancelled" @@ -349,7 +352,7 @@ async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) - async def test_cancel_task_rejects_nonexistent_task(store: InMemoryTaskStore) -> None: """Test cancel_task raises MCPError with INVALID_PARAMS for nonexistent task.""" with pytest.raises(MCPError) as exc_info: - await cancel_task(store, "nonexistent-task-id") + await cancel_task(store, "nonexistent-task-id", session_id="test-session") assert exc_info.value.error.code == INVALID_PARAMS assert "not found" in exc_info.value.error.message @@ -358,11 +361,11 @@ async def test_cancel_task_rejects_nonexistent_task(store: InMemoryTaskStore) -> @pytest.mark.anyio async def test_cancel_task_rejects_completed_task(store: InMemoryTaskStore) -> None: """Test cancel_task raises MCPError with INVALID_PARAMS for completed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + await store.update_task(task.task_id, status="completed", session_id="test-session") with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) + await cancel_task(store, task.task_id, session_id="test-session") assert exc_info.value.error.code == INVALID_PARAMS assert "terminal state 'completed'" in exc_info.value.error.message @@ -371,11 +374,11 @@ async def test_cancel_task_rejects_completed_task(store: InMemoryTaskStore) -> N @pytest.mark.anyio async def test_cancel_task_rejects_failed_task(store: InMemoryTaskStore) -> None: """Test cancel_task raises MCPError with INVALID_PARAMS for failed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="failed") + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + await store.update_task(task.task_id, status="failed", session_id="test-session") with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) + await cancel_task(store, task.task_id, session_id="test-session") assert exc_info.value.error.code == INVALID_PARAMS assert "terminal state 'failed'" in exc_info.value.error.message @@ -384,11 +387,11 @@ async def test_cancel_task_rejects_failed_task(store: InMemoryTaskStore) -> None @pytest.mark.anyio async def test_cancel_task_rejects_already_cancelled_task(store: InMemoryTaskStore) -> None: """Test cancel_task raises MCPError with INVALID_PARAMS for already cancelled task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="cancelled") + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + await store.update_task(task.task_id, status="cancelled", session_id="test-session") with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) + await cancel_task(store, task.task_id, session_id="test-session") assert exc_info.value.error.code == INVALID_PARAMS assert "terminal state 'cancelled'" in exc_info.value.error.message @@ -397,10 +400,10 @@ async def test_cancel_task_rejects_already_cancelled_task(store: InMemoryTaskSto @pytest.mark.anyio async def test_cancel_task_succeeds_for_input_required_task(store: InMemoryTaskStore) -> None: """Test cancel_task helper succeeds for a task in input_required status.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="input_required") + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="test-session") + await store.update_task(task.task_id, status="input_required", session_id="test-session") - result = await cancel_task(store, task.task_id) + result = await cancel_task(store, task.task_id, session_id="test-session") assert result.task_id == task.task_id assert result.status == "cancelled" @@ -495,26 +498,6 @@ async def test_list_only_tasks_belonging_to_requesting_session(store: InMemoryTa assert len(tasks_b) == 1 -@pytest.mark.anyio -async def test_no_session_id_allows_access_backward_compat(store: InMemoryTaskStore) -> None: - """Test backward compatibility: no session_id on read allows access to all tasks.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="session-a") - - # No session_id on read = no filtering - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_task_created_without_session_id_accessible_by_any_session(store: InMemoryTaskStore) -> None: - """Test that tasks created without session_id are accessible by any session.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Any session_id on read should still see the task - retrieved = await store.get_task(task.task_id, session_id="session-b") - assert retrieved is not None - - @pytest.mark.anyio async def test_session_isolation_pagination() -> None: """Test that pagination works correctly within a session.""" diff --git a/tests/experimental/tasks/server/test_task_result_handler.py b/tests/experimental/tasks/server/test_task_result_handler.py index 8b5a03ce2..52995f392 100644 --- a/tests/experimental/tasks/server/test_task_result_handler.py +++ b/tests/experimental/tasks/server/test_task_result_handler.py @@ -51,16 +51,16 @@ async def test_handle_returns_result_for_completed_task( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that handle() returns the stored result for a completed task.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") + await store.store_result(task.task_id, result, session_id="test-session") + await store.update_task(task.task_id, status="completed", session_id="test-session") mock_session = Mock() mock_session.send_message = AsyncMock() request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") + response = await handler.handle(request, mock_session, "req-1", "test-session") assert response is not None assert response.meta is not None @@ -76,7 +76,7 @@ async def test_handle_raises_for_nonexistent_task( request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) with pytest.raises(MCPError) as exc_info: - await handler.handle(request, mock_session, "req-1") + await handler.handle(request, mock_session, "req-1", "test-session") assert "not found" in exc_info.value.error.message @@ -86,14 +86,14 @@ async def test_handle_returns_empty_result_when_no_result_stored( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that handle() returns minimal result when task completed without stored result.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - await store.update_task(task.task_id, status="completed") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") + await store.update_task(task.task_id, status="completed", session_id="test-session") mock_session = Mock() mock_session.send_message = AsyncMock() request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") + response = await handler.handle(request, mock_session, "req-1", "test-session") assert response is not None assert response.meta is not None @@ -105,7 +105,7 @@ async def test_handle_delivers_queued_messages( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that handle() delivers queued messages before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") queued_msg = QueuedMessage( type="notification", @@ -117,7 +117,7 @@ async def test_handle_delivers_queued_messages( ), ) await queue.enqueue(task.task_id, queued_msg) - await store.update_task(task.task_id, status="completed") + await store.update_task(task.task_id, status="completed", session_id="test-session") sent_messages: list[SessionMessage] = [] @@ -128,7 +128,7 @@ async def track_send(msg: SessionMessage) -> None: mock_session.send_message = track_send request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - await handler.handle(request, mock_session, "req-1") + await handler.handle(request, mock_session, "req-1", "test-session") assert len(sent_messages) == 1 @@ -138,7 +138,7 @@ async def test_handle_waits_for_task_completion( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that handle() waits for task to complete before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") mock_session = Mock() mock_session.send_message = AsyncMock() @@ -147,7 +147,7 @@ async def test_handle_waits_for_task_completion( result_holder: list[GetTaskPayloadResult | None] = [None] async def run_handle() -> None: - result_holder[0] = await handler.handle(request, mock_session, "req-1") + result_holder[0] = await handler.handle(request, mock_session, "req-1", "test-session") async with anyio.create_task_group() as tg: tg.start_soon(run_handle) @@ -156,8 +156,10 @@ async def run_handle() -> None: while task.task_id not in store._update_events: await anyio.sleep(0) - await store.store_result(task.task_id, CallToolResult(content=[TextContent(type="text", text="Done")])) - await store.update_task(task.task_id, status="completed") + await store.store_result( + task.task_id, CallToolResult(content=[TextContent(type="text", text="Done")]), session_id="test-session" + ) + await store.update_task(task.task_id, status="completed", session_id="test-session") assert result_holder[0] is not None @@ -234,7 +236,7 @@ async def test_deliver_registers_resolver_for_request_messages( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that _deliver_queued_messages registers resolvers for request messages.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") resolver: Resolver[dict[str, Any]] = Resolver() queued_msg = QueuedMessage( @@ -264,7 +266,7 @@ async def test_deliver_skips_resolver_registration_when_no_original_id( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that _deliver_queued_messages skips resolver registration when original_request_id is None.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") resolver: Resolver[dict[str, Any]] = Resolver() queued_msg = QueuedMessage( @@ -296,7 +298,7 @@ async def test_wait_for_task_update_handles_store_exception( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that _wait_for_task_update handles store exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") # Make wait_for_update raise an exception async def failing_wait(task_id: str) -> None: @@ -333,7 +335,7 @@ async def test_wait_for_task_update_handles_queue_exception( store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler ) -> None: """Test that _wait_for_task_update handles queue exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task", session_id="test-session") # Make wait_for_message raise an exception async def failing_wait(task_id: str) -> None: @@ -346,7 +348,7 @@ async def update_later() -> None: # Wait for store to start waiting (event gets created when wait starts) while task.task_id not in store._update_events: await anyio.sleep(0) - await store.update_task(task.task_id, status="completed") + await store.update_task(task.task_id, status="completed", session_id="test-session") async with anyio.create_task_group() as tg: tg.start_soon(update_later) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 2d0378a9c..d6002641c 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -61,13 +61,13 @@ async def handle_augmented_elicitation( ) -> CreateTaskResult: """Handle task-augmented elicitation by creating a client-side task.""" elicit_received.set() - task = await client_task_store.create_task(task_metadata) + task = await client_task_store.create_task(task_metadata, session_id="test-session") task_complete_events[task.task_id] = Event() async def complete_task() -> None: # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, elicit_response) - await client_task_store.update_task(task.task_id, status="completed") + await client_task_store.store_result(task.task_id, elicit_response, session_id="test-session") + await client_task_store.update_task(task.task_id, status="completed", session_id="test-session") task_complete_events[task.task_id].set() context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] @@ -78,7 +78,7 @@ async def handle_get_task( params: Any, ) -> GetTaskResult: """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) + task = await client_task_store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Task not found: {params.task_id}" return GetTaskResult( task_id=task.task_id, @@ -98,7 +98,7 @@ async def handle_get_task_result( event = task_complete_events.get(params.task_id) assert event is not None, f"No completion event for task: {params.task_id}" await event.wait() - result = await client_task_store.get_result(params.task_id) + result = await client_task_store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Result not found for task: {params.task_id}" return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) @@ -129,13 +129,13 @@ async def handle_augmented_sampling( ) -> CreateTaskResult: """Handle task-augmented sampling by creating a client-side task.""" sampling_received.set() - task = await client_task_store.create_task(task_metadata) + task = await client_task_store.create_task(task_metadata, session_id="test-session") task_complete_events[task.task_id] = Event() async def complete_task() -> None: # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, sampling_response) - await client_task_store.update_task(task.task_id, status="completed") + await client_task_store.store_result(task.task_id, sampling_response, session_id="test-session") + await client_task_store.update_task(task.task_id, status="completed", session_id="test-session") task_complete_events[task.task_id].set() context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] @@ -146,7 +146,7 @@ async def handle_get_task( params: Any, ) -> GetTaskResult: """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) + task = await client_task_store.get_task(params.task_id, session_id="test-session") assert task is not None, f"Task not found: {params.task_id}" return GetTaskResult( task_id=task.task_id, @@ -166,7 +166,7 @@ async def handle_get_task_result( event = task_complete_events.get(params.task_id) assert event is not None, f"No completion event for task: {params.task_id}" await event.wait() - result = await client_task_store.get_result(params.task_id) + result = await client_task_store.get_result(params.task_id, session_id="test-session") assert result is not None, f"Result not found for task: {params.task_id}" return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) @@ -230,6 +230,7 @@ async def run_server() -> None: notification_options=NotificationOptions(), experimental_capabilities={}, ), + session_id="test-session", ) async def run_client() -> None: @@ -307,6 +308,7 @@ async def run_server() -> None: notification_options=NotificationOptions(), experimental_capabilities={}, ), + session_id="test-session", ) async def run_client() -> None: @@ -386,6 +388,7 @@ async def run_server() -> None: notification_options=NotificationOptions(), experimental_capabilities={}, ), + session_id="test-session", ) async def run_client() -> None: @@ -483,6 +486,7 @@ async def run_server() -> None: notification_options=NotificationOptions(), experimental_capabilities={}, ), + session_id="test-session", ) async def run_client() -> None: @@ -577,6 +581,7 @@ async def run_server() -> None: notification_options=NotificationOptions(), experimental_capabilities={}, ), + session_id="test-session", ) async def run_client() -> None: @@ -656,6 +661,7 @@ async def run_server() -> None: notification_options=NotificationOptions(), experimental_capabilities={}, ), + session_id="test-session", ) async def run_client() -> None: From e4172291e7f046c294b090c0a8ef75b101d129ce Mon Sep 17 00:00:00 2001 From: Den Delimarsky Date: Fri, 20 Feb 2026 08:35:58 +0000 Subject: [PATCH 4/6] Fix tests :house: Remote-Dev: homespace --- .../tasks/server/test_run_task_flow.py | 24 ++++++++++ .../experimental/tasks/server/test_server.py | 47 +++++++++++++++++++ .../tasks/server/test_server_task_context.py | 16 +++++++ 3 files changed, 87 insertions(+) diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py index 027382e69..26ae6ed9a 100644 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -31,6 +31,7 @@ GetTaskResult, ListToolsResult, PaginatedRequestParams, + TaskMetadata, TextContent, ) @@ -365,3 +366,26 @@ async def work(task: ServerTaskContext) -> CallToolResult: break assert status.status_message == "Manually failed" + + +@pytest.mark.anyio +async def test_run_task_without_session_id_raises() -> None: + """Test that run_task raises when session has no session_id.""" + task_support = TaskSupport.in_memory() + + mock_session = Mock() + mock_session.session_id = None + + experimental = Experimental( + task_metadata=TaskMetadata(ttl=60000), + _client_capabilities=None, + _session=mock_session, + _task_support=task_support, + ) + + async def work(task: ServerTaskContext) -> CallToolResult: + raise NotImplementedError + + async with task_support.run(): + with pytest.raises(RuntimeError, match="Session ID is required"): + await experimental.run_task(work) diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 708b5f085..8450684ad 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -399,6 +399,53 @@ async def run_server() -> None: tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_default_task_handlers_require_session_id() -> None: + """Test that default task handlers reject requests when session has no session_id.""" + server = Server("test-no-session-id") + server.experimental.enable_tasks() + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server() -> None: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + # session_id is None (no session_id passed) + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # All default task handlers should fail with "Session ID is required" + with pytest.raises(MCPError, match="Session ID is required"): + await client_session.experimental.list_tasks() + + tg.cancel_scope.cancel() + + @pytest.mark.anyio async def test_build_elicit_form_request() -> None: """Test that _build_elicit_form_request builds a proper elicitation request.""" diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py index a8132b5b5..1856fbd19 100644 --- a/tests/experimental/tasks/server/test_server_task_context.py +++ b/tests/experimental/tasks/server/test_server_task_context.py @@ -725,3 +725,19 @@ async def test_create_message_as_task_raises_without_handler() -> None: ) store.cleanup() + + +@pytest.mark.anyio +async def test_server_task_context_requires_session_id() -> None: + """Test that ServerTaskContext raises when session has no session_id.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") + + mock_session = Mock() + mock_session.session_id = None + + with pytest.raises(RuntimeError, match="Session ID is required for task operations"): + ServerTaskContext(task=task, store=store, session=mock_session, queue=queue) + + store.cleanup() From a4f5ade51eb2d4e8c79b5ad27d40a38b902256ae Mon Sep 17 00:00:00 2001 From: Den Delimarsky Date: Fri, 20 Feb 2026 08:38:29 +0000 Subject: [PATCH 5/6] Fix feedback item :house: Remote-Dev: homespace --- src/mcp/server/streamable_http_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 1e416fbd8..2e737e934 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -172,7 +172,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA write_stream, self.app.create_initialization_options(), stateless=True, - session_id=http_transport.mcp_session_id, + session_id=None, # No session in stateless mode ) except Exception: # pragma: no cover logger.exception("Stateless session crashed") From 113f35c726563fb4cd6f1f60d92c0927cfac5361 Mon Sep 17 00:00:00 2001 From: Den Delimarsky Date: Mon, 9 Mar 2026 07:52:57 +0000 Subject: [PATCH 6/6] Make session_id optional in TaskStore; reject stateless mode instead Previously, session_id was required (str) throughout the task store interface, which forced sessionless transports like stdio and memory to fabricate UUIDs before running tasks. This was awkward because those transports are single-client by architecture and have no session concept. The policy is now enforced at the handler layer instead of the store layer: default task handlers reject requests when the server is in stateless mode (where tasks cannot survive across requests), and pass session_id through as-is otherwise. A None session_id simply means no session-scoped isolation, which is correct for single-client transports. Isolation in InMemoryTaskStore uses strict equality: None only matches None, so tasks created by a sessionless transport are not visible to session-scoped transports and vice versa, preventing cross-transport leaks when a process serves multiple transports from one store. Changes: - TaskStore, InMemoryTaskStore, TaskContext, helpers: session_id is now str | None - ServerSession: expose stateless property - Default task handlers: reject stateless mode, pass session_id as-is - run_task() / ServerTaskContext: accept None session_id, reject stateless mode - Memory transport: revert fabricated UUID session_id - New tests for None-session isolation (strict equality, no cross-transport leaks) :house: Remote-Dev: homespace --- src/mcp/client/_memory.py | 3 - .../server/experimental/request_context.py | 6 +- src/mcp/server/experimental/task_context.py | 7 +- .../experimental/task_result_handler.py | 5 +- src/mcp/server/lowlevel/experimental.py | 29 ++++---- src/mcp/server/session.py | 5 ++ src/mcp/shared/experimental/tasks/context.py | 4 +- src/mcp/shared/experimental/tasks/helpers.py | 12 +-- .../tasks/in_memory_task_store.py | 24 +++--- src/mcp/shared/experimental/tasks/store.py | 39 ++++++---- .../tasks/server/test_run_task_flow.py | 7 +- .../experimental/tasks/server/test_server.py | 15 ++-- .../tasks/server/test_server_task_context.py | 15 ++-- tests/experimental/tasks/server/test_store.py | 73 +++++++++++++++++++ 14 files changed, 172 insertions(+), 72 deletions(-) diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index 337deeb75..e6e938673 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -2,7 +2,6 @@ from __future__ import annotations -import uuid from collections.abc import AsyncIterator from contextlib import AbstractAsyncContextManager, asynccontextmanager from types import TracebackType @@ -51,14 +50,12 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: async with anyio.create_task_group() as tg: # Start server in background - memory_session_id = uuid.uuid4().hex tg.start_soon( lambda: actual_server.run( server_read, server_write, actual_server.create_initialization_options(), raise_exceptions=self._raise_exceptions, - session_id=memory_session_id, ) ) diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 0b7a86184..81d1534a3 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -189,9 +189,11 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Access task_group via TaskSupport - raises if not in run() context task_group = support.task_group + if self._session.stateless: + raise RuntimeError( + "run_task() does not support stateless mode. Tasks require a persistent session for result retrieval." + ) session_id = self._session.session_id - if session_id is None: - raise RuntimeError("Session ID is required for task operations but session has no ID.") task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id) task_ctx = ServerTaskContext( diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 9f26d5fed..98381ce0d 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -90,11 +90,8 @@ def __init__( queue: The message queue for elicitation/sampling handler: The result handler for response routing (required for elicit/create_message) """ - session_id = session.session_id - if session_id is None: - raise RuntimeError("Session ID is required for task operations but session has no ID.") - self._session_id = session_id - self._ctx = TaskContext(task=task, store=store, session_id=session_id) + self._session_id = session.session_id + self._ctx = TaskContext(task=task, store=store, session_id=self._session_id) self._session = session self._queue = queue self._handler = handler diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 8d922acd0..9c4d3b4d4 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -80,7 +80,7 @@ async def handle( request: GetTaskPayloadRequest, session: ServerSession, request_id: RequestId, - session_id: str, + session_id: str | None, ) -> GetTaskPayloadResult: """Handle a tasks/result request. @@ -95,7 +95,8 @@ async def handle( request: The GetTaskPayloadRequest session: The server session for sending messages request_id: The request ID for relatedRequestId routing - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly + match the session_id the task was created with (including None). Returns: GetTaskPayloadResult with the task's final payload diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 5fc4b493e..c58b276cc 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -153,14 +153,15 @@ def enable_tasks( if on_cancel_task is not None: self._add_request_handler("tasks/cancel", on_cancel_task) - def _require_session_id(ctx: ServerRequestContext[LifespanResultT]) -> str: - session_id = ctx.session.session_id - if session_id is None: + def _check_stateless(ctx: ServerRequestContext[LifespanResultT]) -> None: + if ctx.session.stateless: raise MCPError( code=INVALID_PARAMS, - message="Session ID is required for task operations.", + message=( + "Default task handlers do not support stateless mode. " + "Provide custom task handlers if you need stateless task support." + ), ) - return session_id # Fill in defaults for any not provided if not self._has_handler("tasks/get"): @@ -168,8 +169,8 @@ def _require_session_id(ctx: ServerRequestContext[LifespanResultT]) -> str: async def _default_get_task( ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams ) -> GetTaskResult: - session_id = _require_session_id(ctx) - task = await task_support.store.get_task(params.task_id, session_id=session_id) + _check_stateless(ctx) + task = await task_support.store.get_task(params.task_id, session_id=ctx.session.session_id) if task is None: raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( @@ -190,9 +191,11 @@ async def _default_get_task_result( ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: assert ctx.request_id is not None - session_id = _require_session_id(ctx) + _check_stateless(ctx) req = GetTaskPayloadRequest(params=params) - result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id) + result = await task_support.handler.handle( + req, ctx.session, ctx.request_id, session_id=ctx.session.session_id + ) return result self._add_request_handler("tasks/result", _default_get_task_result) @@ -202,9 +205,9 @@ async def _default_get_task_result( async def _default_list_tasks( ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None ) -> ListTasksResult: + _check_stateless(ctx) cursor = params.cursor if params else None - session_id = _require_session_id(ctx) - tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id) + tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=ctx.session.session_id) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) self._add_request_handler("tasks/list", _default_list_tasks) @@ -214,8 +217,8 @@ async def _default_list_tasks( async def _default_cancel_task( ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams ) -> CancelTaskResult: - session_id = _require_session_id(ctx) - result = await cancel_task(task_support.store, params.task_id, session_id=session_id) + _check_stateless(ctx) + result = await cancel_task(task_support.store, params.task_id, session_id=ctx.session.session_id) return result self._add_request_handler("tasks/cancel", _default_cancel_task) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 1495e9d11..a708082d5 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -111,6 +111,11 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification] def client_params(self) -> types.InitializeRequestParams | None: return self._client_params + @property + def stateless(self) -> bool: + """Whether this session is in stateless mode (no persistent server-side state).""" + return self._stateless + @property def experimental(self) -> ExperimentalServerSessionFeatures: """Experimental APIs for server→client task operations. diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py index 5eae0477c..c967641d7 100644 --- a/src/mcp/shared/experimental/tasks/context.py +++ b/src/mcp/shared/experimental/tasks/context.py @@ -21,7 +21,7 @@ class TaskContext: use ServerTaskContext from mcp.server.experimental. Example (distributed worker): - async def worker_job(task_id: str, session_id: str): + async def worker_job(task_id: str, session_id: str | None): store = RedisTaskStore(redis_url) task = await store.get_task(task_id, session_id=session_id) ctx = TaskContext(task=task, store=store, session_id=session_id) @@ -31,7 +31,7 @@ async def worker_job(task_id: str, session_id: str): await ctx.complete(result) """ - def __init__(self, task: Task, store: TaskStore, *, session_id: str): + def __init__(self, task: Task, store: TaskStore, *, session_id: str | None): self._task = task self._store = store self._session_id = session_id diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index e8a777016..bbbb8cca7 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -51,7 +51,7 @@ async def cancel_task( store: TaskStore, task_id: str, *, - session_id: str, + session_id: str | None, ) -> CancelTaskResult: """Cancel a task with spec-compliant validation. @@ -64,7 +64,8 @@ async def cancel_task( Args: store: The task store task_id: The task identifier to cancel - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly match + the session_id the task was created with (including None). Returns: CancelTaskResult with the cancelled task state @@ -128,7 +129,7 @@ async def task_execution( task_id: str, store: TaskStore, *, - session_id: str, + session_id: str | None, ) -> AsyncIterator[TaskContext]: """Context manager for safe task execution (pure, no server dependencies). @@ -141,7 +142,8 @@ async def task_execution( Args: task_id: The task identifier to execute store: The task store (must be accessible by the worker) - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly match + the session_id the task was created with (including None). Yields: TaskContext for updating status and completing/failing the task @@ -150,7 +152,7 @@ async def task_execution( ValueError: If the task is not found in the store Example (distributed worker): - async def worker_process(task_id: str, session_id: str): + async def worker_process(task_id: str, session_id: str | None): store = RedisTaskStore(redis_url) async with task_execution(task_id, store, session_id=session_id) as ctx: await ctx.update_status("Working...") diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py index 46c7775e1..298ad9f17 100644 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -22,7 +22,7 @@ class StoredTask: """Internal storage representation of a task.""" task: Task - session_id: str + session_id: str | None result: Result | None = None # Time when this task should be removed (None = never) expires_at: datetime | None = field(default=None) @@ -68,10 +68,13 @@ def _cleanup_expired(self) -> None: for task_id in expired_ids: del self._tasks[task_id] - def _get_stored_task(self, task_id: str, *, session_id: str) -> StoredTask | None: + def _get_stored_task(self, task_id: str, *, session_id: str | None) -> StoredTask | None: """Retrieve a stored task, enforcing session ownership. Returns None if the task does not exist or belongs to a different session. + Isolation uses strict equality: None only matches None, so tasks created + by sessionless transports (stdio) are not visible to session-scoped + transports (HTTP) and vice versa. """ stored = self._tasks.get(task_id) if stored is None: @@ -85,7 +88,7 @@ async def create_task( metadata: TaskMetadata, task_id: str | None = None, *, - session_id: str, + session_id: str | None, ) -> Task: """Create a new task with the given metadata.""" # Cleanup expired tasks on access @@ -106,7 +109,7 @@ async def create_task( # Return a copy to prevent external modification return Task(**task.model_dump()) - async def get_task(self, task_id: str, *, session_id: str) -> Task | None: + async def get_task(self, task_id: str, *, session_id: str | None) -> Task | None: """Get a task by ID.""" # Cleanup expired tasks on access self._cleanup_expired() @@ -124,7 +127,7 @@ async def update_task( status: TaskStatus | None = None, status_message: str | None = None, *, - session_id: str, + session_id: str | None, ) -> Task: """Update a task's status and/or message.""" stored = self._get_stored_task(task_id, session_id=session_id) @@ -156,7 +159,7 @@ async def update_task( return Task(**stored.task.model_dump()) - async def store_result(self, task_id: str, result: Result, *, session_id: str) -> None: + async def store_result(self, task_id: str, result: Result, *, session_id: str | None) -> None: """Store the result for a task.""" stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: @@ -164,7 +167,7 @@ async def store_result(self, task_id: str, result: Result, *, session_id: str) - stored.result = result - async def get_result(self, task_id: str, *, session_id: str) -> Result | None: + async def get_result(self, task_id: str, *, session_id: str | None) -> Result | None: """Get the stored result for a task.""" stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: @@ -176,13 +179,14 @@ async def list_tasks( self, cursor: str | None = None, *, - session_id: str, + session_id: str | None, ) -> tuple[list[Task], str | None]: """List tasks with pagination.""" # Cleanup expired tasks on access self._cleanup_expired() - # Filter tasks by session ownership before pagination + # Filter tasks by session ownership before pagination. + # Strict equality: None only matches None (sessionless transports). filtered_task_ids = [task_id for task_id, stored in self._tasks.items() if stored.session_id == session_id] start_index = 0 @@ -203,7 +207,7 @@ async def list_tasks( return tasks, next_cursor - async def delete_task(self, task_id: str, *, session_id: str) -> bool: + async def delete_task(self, task_id: str, *, session_id: str | None) -> bool: """Delete a task.""" stored = self._get_stored_task(task_id, session_id=session_id) if stored is None: diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py index 6845c1b1b..068b0d5c0 100644 --- a/src/mcp/shared/experimental/tasks/store.py +++ b/src/mcp/shared/experimental/tasks/store.py @@ -20,15 +20,17 @@ async def create_task( metadata: TaskMetadata, task_id: str | None = None, *, - session_id: str, + session_id: str | None, ) -> Task: """Create a new task. Args: metadata: Task metadata (ttl, etc.) task_id: Optional task ID. If None, implementation should generate one. - session_id: Session identifier. The task is bound to this session - for isolation purposes. + session_id: Session identifier for isolation. If None, the task is + not bound to a session (single-client transports like stdio). + If a string, the task is scoped to that session and only + accessible with the same session_id. Returns: The created Task with status="working" @@ -38,12 +40,13 @@ async def create_task( """ @abstractmethod - async def get_task(self, task_id: str, *, session_id: str) -> Task | None: + async def get_task(self, task_id: str, *, session_id: str | None) -> Task | None: """Get a task by ID. Args: task_id: The task identifier - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly + match the session_id the task was created with (including None). Returns: The Task, or None if not found or not accessible by this session. @@ -56,7 +59,7 @@ async def update_task( status: TaskStatus | None = None, status_message: str | None = None, *, - session_id: str, + session_id: str | None, ) -> Task: """Update a task's status and/or message. @@ -64,7 +67,8 @@ async def update_task( task_id: The task identifier status: New status (if changing) status_message: New status message (if changing) - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly + match the session_id the task was created with (including None). Returns: The updated Task @@ -77,25 +81,27 @@ async def update_task( """ @abstractmethod - async def store_result(self, task_id: str, result: Result, *, session_id: str) -> None: + async def store_result(self, task_id: str, result: Result, *, session_id: str | None) -> None: """Store the result for a task. Args: task_id: The task identifier result: The result to store - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly + match the session_id the task was created with (including None). Raises: ValueError: If task not found or not accessible by this session. """ @abstractmethod - async def get_result(self, task_id: str, *, session_id: str) -> Result | None: + async def get_result(self, task_id: str, *, session_id: str | None) -> Result | None: """Get the stored result for a task. Args: task_id: The task identifier - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly + match the session_id the task was created with (including None). Returns: The stored Result, or None if not available. @@ -106,26 +112,27 @@ async def list_tasks( self, cursor: str | None = None, *, - session_id: str, + session_id: str | None, ) -> tuple[list[Task], str | None]: """List tasks with pagination. Args: cursor: Optional cursor for pagination - session_id: Session identifier. Only tasks belonging to this - session are returned. + session_id: Session identifier. Only tasks with an exactly matching + session_id are returned (None only matches tasks created with None). Returns: Tuple of (tasks, next_cursor). next_cursor is None if no more pages. """ @abstractmethod - async def delete_task(self, task_id: str, *, session_id: str) -> bool: + async def delete_task(self, task_id: str, *, session_id: str | None) -> bool: """Delete a task. Args: task_id: The task identifier - session_id: Session identifier for access control. + session_id: Session identifier for access control. Must exactly + match the session_id the task was created with (including None). Returns: True if deleted, False if not found or not accessible by this session. diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py index 26ae6ed9a..03ef623e1 100644 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -369,12 +369,13 @@ async def work(task: ServerTaskContext) -> CallToolResult: @pytest.mark.anyio -async def test_run_task_without_session_id_raises() -> None: - """Test that run_task raises when session has no session_id.""" +async def test_run_task_in_stateless_mode_raises() -> None: + """Test that run_task raises when the session is in stateless mode.""" task_support = TaskSupport.in_memory() mock_session = Mock() mock_session.session_id = None + mock_session.stateless = True experimental = Experimental( task_metadata=TaskMetadata(ttl=60000), @@ -387,5 +388,5 @@ async def work(task: ServerTaskContext) -> CallToolResult: raise NotImplementedError async with task_support.run(): - with pytest.raises(RuntimeError, match="Session ID is required"): + with pytest.raises(RuntimeError, match="does not support stateless mode"): await experimental.run_task(work) diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 8450684ad..bb76c8f1a 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -400,9 +400,13 @@ async def run_server() -> None: @pytest.mark.anyio -async def test_default_task_handlers_require_session_id() -> None: - """Test that default task handlers reject requests when session has no session_id.""" - server = Server("test-no-session-id") +async def test_default_task_handlers_reject_stateless_mode() -> None: + """Test that default task handlers reject requests in stateless mode. + + Task operations require a persistent session for result retrieval; stateless + mode creates a fresh session per request, so tasks cannot survive across requests. + """ + server = Server("test-stateless") server.experimental.enable_tasks() server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -424,8 +428,8 @@ async def run_server() -> None: experimental_capabilities={}, ), ), + stateless=True, ) as server_session: - # session_id is None (no session_id passed) async for message in server_session.incoming_messages: await server._handle_message(message, server_session, {}, False) @@ -439,8 +443,7 @@ async def run_server() -> None: ) as client_session: await client_session.initialize() - # All default task handlers should fail with "Session ID is required" - with pytest.raises(MCPError, match="Session ID is required"): + with pytest.raises(MCPError, match="do not support stateless mode"): await client_session.experimental.list_tasks() tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py index 1856fbd19..651f71069 100644 --- a/tests/experimental/tasks/server/test_server_task_context.py +++ b/tests/experimental/tasks/server/test_server_task_context.py @@ -728,16 +728,21 @@ async def test_create_message_as_task_raises_without_handler() -> None: @pytest.mark.anyio -async def test_server_task_context_requires_session_id() -> None: - """Test that ServerTaskContext raises when session has no session_id.""" +async def test_server_task_context_accepts_none_session_id() -> None: + """Test that ServerTaskContext works with session_id=None (sessionless transports like stdio).""" store = InMemoryTaskStore() queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000), session_id="test-session") + task = await store.create_task(TaskMetadata(ttl=60000), session_id=None) mock_session = Mock() mock_session.session_id = None + mock_session.send_notification = AsyncMock() + + ctx = ServerTaskContext(task=task, store=store, session=mock_session, queue=queue) + await ctx.update_status("Working...") - with pytest.raises(RuntimeError, match="Session ID is required for task operations"): - ServerTaskContext(task=task, store=store, session=mock_session, queue=queue) + retrieved = await store.get_task(task.task_id, session_id=None) + assert retrieved is not None + assert retrieved.status_message == "Working..." store.cleanup() diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py index e5dac940a..92a70c860 100644 --- a/tests/experimental/tasks/server/test_store.py +++ b/tests/experimental/tasks/server/test_store.py @@ -541,3 +541,76 @@ async def test_cancel_task_with_session_isolation(store: InMemoryTaskStore) -> N # session-a should be able to cancel its own task result = await cancel_task(store, task.task_id, session_id="session-a") assert result.status == "cancelled" + + +# --- None session_id (sessionless transports like stdio) --- + + +@pytest.mark.anyio +async def test_none_session_id_can_access_own_tasks(store: InMemoryTaskStore) -> None: + """Test that a None session_id (sessionless transport) can access tasks it created. + + This verifies stdio/memory transports work: they have no session concept, so + they create and retrieve tasks with session_id=None. + """ + task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id=None) + + retrieved = await store.get_task(task.task_id, session_id=None) + assert retrieved is not None + assert retrieved.task_id == task.task_id + + tasks, _ = await store.list_tasks(session_id=None) + assert len(tasks) == 1 + assert tasks[0].task_id == task.task_id + + +@pytest.mark.anyio +async def test_none_session_cannot_read_session_scoped_task(store: InMemoryTaskStore) -> None: + """Test that a None session_id cannot read tasks created with a real session_id. + + Strict equality isolation: a sessionless client (stdio) cannot see tasks + created by a session-scoped client (HTTP), even when sharing a store. + """ + http_task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="http-session-1") + + assert await store.get_task(http_task.task_id, session_id=None) is None + assert await store.get_result(http_task.task_id, session_id=None) is None + assert await store.delete_task(http_task.task_id, session_id=None) is False + + tasks, _ = await store.list_tasks(session_id=None) + assert len(tasks) == 0 + + +@pytest.mark.anyio +async def test_session_cannot_read_none_session_task(store: InMemoryTaskStore) -> None: + """Test that a real session_id cannot read tasks created with session_id=None. + + Strict equality isolation: an HTTP session cannot see tasks created by a + sessionless client (stdio), closing the gap present in the TypeScript SDK. + """ + stdio_task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id=None) + + assert await store.get_task(stdio_task.task_id, session_id="http-session-1") is None + assert await store.get_result(stdio_task.task_id, session_id="http-session-1") is None + assert await store.delete_task(stdio_task.task_id, session_id="http-session-1") is False + + tasks, _ = await store.list_tasks(session_id="http-session-1") + assert len(tasks) == 0 + + +@pytest.mark.anyio +async def test_none_and_session_scoped_tasks_coexist(store: InMemoryTaskStore) -> None: + """Test that None-session and session-scoped tasks coexist without leaking.""" + stdio_task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id=None) + http_a_task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="http-a") + http_b_task = await store.create_task(metadata=TaskMetadata(ttl=60000), session_id="http-b") + + # Each scope sees exactly its own task + stdio_tasks, _ = await store.list_tasks(session_id=None) + assert [t.task_id for t in stdio_tasks] == [stdio_task.task_id] + + http_a_tasks, _ = await store.list_tasks(session_id="http-a") + assert [t.task_id for t in http_a_tasks] == [http_a_task.task_id] + + http_b_tasks, _ = await store.list_tasks(session_id="http-b") + assert [t.task_id for t in http_b_tasks] == [http_b_task.task_id]