From b7929edb899b52d45e87d8be20de8b786f341809 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Tue, 17 Feb 2026 03:43:19 +0800 Subject: [PATCH 01/11] feat(stdio): add MCP tool toggle support for stdio transport Align stdio transport tool visibility with HTTP transport behavior: - Add enabled_tools field to stdio status heartbeat file - Filter tools/list based on Unity-side tool enable/disable state - Add manage_editor actions: set_mcp_tool_enabled, get_mcp_tool_enabled, list_mcp_tools - Immediately refresh status file on tool toggle (no heartbeat delay) - Fail-open: missing/invalid status files skip filtering (backward compatible) Unity changes: - ManageEditor.cs: add tool toggle actions with self-lock protection - StdioBridgeHost.cs: include enabled_tools in heartbeat payload Server changes: - manage_editor.py: add new action parameters and bool coercion - unity_instance_middleware.py: stdio tools/list filtering with multi-instance support Tests: - Python: stdio tool filter tests, param coercion tests - Unity: ManageEditorToolToggleTests, TransportCommandDispatcherToolToggleTests Co-Authored-By: Claude Opus 4.6 --- .../Transport/Transports/StdioBridgeHost.cs | 35 +++- MCPForUnity/Editor/Tools/ManageEditor.cs | 133 ++++++++++++- Server/README.md | 28 +++ Server/src/services/tools/manage_editor.py | 20 +- .../transport/unity_instance_middleware.py | 129 +++++++++++- .../test_manage_editor_param_coercion.py | 45 +++++ ...y_instance_middleware_stdio_tool_filter.py | 183 ++++++++++++++++++ ...ansportCommandDispatcherToolToggleTests.cs | 102 ++++++++++ ...rtCommandDispatcherToolToggleTests.cs.meta | 11 ++ .../Tools/ManageEditorToolToggleTests.cs | 117 +++++++++++ .../Tools/ManageEditorToolToggleTests.cs.meta | 11 ++ 11 files changed, 806 insertions(+), 8 deletions(-) create mode 100644 Server/tests/integration/test_manage_editor_param_coercion.py create mode 100644 Server/tests/test_unity_instance_middleware_stdio_tool_filter.py create mode 100644 TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs create mode 100644 TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs.meta create mode 100644 TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs create mode 100644 TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs.meta diff --git a/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs b/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs index 8a9037a7e..b9b6019d6 100644 --- a/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs +++ b/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs @@ -976,6 +976,19 @@ private static bool IsValidJson(string text) return false; } + public static void RefreshStatusFile(string reason = "manual_refresh") + { + try + { + heartbeatSeq++; + WriteHeartbeat(false, reason); + } + catch (Exception ex) + { + McpLog.Warn($"Failed to refresh stdio status file: {ex.Message}"); + } + } + public static void WriteHeartbeat(bool reloading, string reason = null) { @@ -987,7 +1000,8 @@ public static void WriteHeartbeat(bool reloading, string reason = null) dir = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), ".unity-mcp"); } Directory.CreateDirectory(dir); - string filePath = Path.Combine(dir, $"unity-mcp-status-{ComputeProjectHash(Application.dataPath)}.json"); + string projectHash = ComputeProjectHash(Application.dataPath); + string filePath = Path.Combine(dir, $"unity-mcp-status-{projectHash}.json"); string projectName = "Unknown"; try @@ -1009,14 +1023,33 @@ public static void WriteHeartbeat(bool reloading, string reason = null) } catch { } + string[] enabledTools = Array.Empty(); + try + { + var toolMetadata = MCPServiceLocator.ToolDiscovery.GetEnabledTools(); + enabledTools = toolMetadata + ?.Select(tool => tool?.Name) + .Where(name => !string.IsNullOrWhiteSpace(name)) + .Distinct(StringComparer.Ordinal) + .OrderBy(name => name, StringComparer.Ordinal) + .ToArray() + ?? Array.Empty(); + } + catch (Exception ex) + { + McpLog.Warn($"Failed to resolve enabled tools for stdio status file: {ex.Message}"); + } + var payload = new { unity_port = currentUnityPort, reloading, reason = reason ?? (reloading ? "reloading" : "ready"), seq = heartbeatSeq, + project_hash = projectHash, project_path = Application.dataPath, project_name = projectName, + enabled_tools = enabledTools, unity_version = Application.unityVersion, last_heartbeat = DateTime.UtcNow.ToString("O") }; diff --git a/MCPForUnity/Editor/Tools/ManageEditor.cs b/MCPForUnity/Editor/Tools/ManageEditor.cs index 27048031a..d204e4c34 100644 --- a/MCPForUnity/Editor/Tools/ManageEditor.cs +++ b/MCPForUnity/Editor/Tools/ManageEditor.cs @@ -1,5 +1,7 @@ using System; using MCPForUnity.Editor.Helpers; +using MCPForUnity.Editor.Services; +using MCPForUnity.Editor.Services.Transport.Transports; using Newtonsoft.Json.Linq; using UnityEditor; using UnityEditorInternal; // Required for tag management @@ -101,6 +103,38 @@ public static object HandleCommand(JObject @params) if (!toolNameResult.IsSuccess) return new ErrorResponse(toolNameResult.ErrorMessage); return SetActiveTool(toolNameResult.Value); + case "set_mcp_tool_enabled": + var setToolEnabledNameResult = p.GetRequired( + "toolName", + "'toolName' parameter required for set_mcp_tool_enabled."); + if (!setToolEnabledNameResult.IsSuccess) + { + return new ErrorResponse(setToolEnabledNameResult.ErrorMessage); + } + + if (!p.Has("enabled")) + { + return new ErrorResponse("'enabled' parameter required for set_mcp_tool_enabled."); + } + + bool? enabled = ParamCoercion.CoerceBoolNullable(p.GetRaw("enabled")); + if (!enabled.HasValue) + { + return new ErrorResponse("'enabled' parameter must be a boolean."); + } + + return SetMcpToolEnabled(setToolEnabledNameResult.Value, enabled.Value); + case "get_mcp_tool_enabled": + var getToolEnabledNameResult = p.GetRequired( + "toolName", + "'toolName' parameter required for get_mcp_tool_enabled."); + if (!getToolEnabledNameResult.IsSuccess) + { + return new ErrorResponse(getToolEnabledNameResult.ErrorMessage); + } + return GetMcpToolEnabled(getToolEnabledNameResult.Value); + case "list_mcp_tools": + return ListMcpTools(); // Tag Management case "add_tag": @@ -136,7 +170,7 @@ public static object HandleCommand(JObject @params) default: return new ErrorResponse( - $"Unknown action: '{action}'. Supported actions: play, pause, stop, set_active_tool, add_tag, remove_tag, add_layer, remove_layer. Use MCP resources for reading editor state, project info, tags, layers, selection, windows, prefab stage, and active tool." + $"Unknown action: '{action}'. Supported actions: play, pause, stop, set_active_tool, set_mcp_tool_enabled, get_mcp_tool_enabled, list_mcp_tools, add_tag, remove_tag, add_layer, remove_layer. Use MCP resources for reading editor state, project info, tags, layers, selection, windows, prefab stage, and active tool." ); } } @@ -178,6 +212,103 @@ private static object SetActiveTool(string toolName) } } + private static object SetMcpToolEnabled(string toolName, bool enabled) + { + if (string.IsNullOrWhiteSpace(toolName)) + { + return new ErrorResponse("Tool name cannot be empty."); + } + + if (string.Equals(toolName, "manage_editor", StringComparison.OrdinalIgnoreCase) && !enabled) + { + return new ErrorResponse("Tool 'manage_editor' cannot be disabled."); + } + + var metadata = MCPServiceLocator.ToolDiscovery.GetToolMetadata(toolName); + if (metadata == null) + { + return new ErrorResponse($"Unknown tool '{toolName}'."); + } + + MCPServiceLocator.ToolDiscovery.SetToolEnabled(metadata.Name, enabled); + RefreshStdioStatusFile(); + + return new SuccessResponse( + $"Tool '{metadata.Name}' {(enabled ? "enabled" : "disabled")} successfully.", + new + { + toolName = metadata.Name, + enabled + }); + } + + private static object GetMcpToolEnabled(string toolName) + { + if (string.IsNullOrWhiteSpace(toolName)) + { + return new ErrorResponse("Tool name cannot be empty."); + } + + var metadata = MCPServiceLocator.ToolDiscovery.GetToolMetadata(toolName); + if (metadata == null) + { + return new ErrorResponse($"Unknown tool '{toolName}'."); + } + + bool enabled = MCPServiceLocator.ToolDiscovery.IsToolEnabled(metadata.Name); + return new SuccessResponse( + $"Tool '{metadata.Name}' is {(enabled ? "enabled" : "disabled")}.", + new + { + toolName = metadata.Name, + enabled + }); + } + + private static object ListMcpTools() + { + try + { + var discoveredTools = MCPServiceLocator.ToolDiscovery.DiscoverAllTools(); + var toolStates = new JArray(); + + foreach (var tool in discoveredTools) + { + toolStates.Add(new JObject + { + ["name"] = tool.Name, + ["enabled"] = MCPServiceLocator.ToolDiscovery.IsToolEnabled(tool.Name), + ["autoRegister"] = tool.AutoRegister, + ["isBuiltIn"] = tool.IsBuiltIn + }); + } + + return new SuccessResponse( + $"Listed {toolStates.Count} MCP tools.", + new JObject + { + ["toolCount"] = toolStates.Count, + ["tools"] = toolStates + }); + } + catch (Exception e) + { + return new ErrorResponse($"Failed to list MCP tools: {e.Message}"); + } + } + + private static void RefreshStdioStatusFile() + { + try + { + StdioBridgeHost.RefreshStatusFile("tool_toggle"); + } + catch (Exception e) + { + McpLog.Warn($"Failed to refresh stdio status file after tool toggle: {e.Message}"); + } + } + // --- Tag Management Methods --- private static object AddTag(string tagName) diff --git a/Server/README.md b/Server/README.md index 5435360f0..c64490770 100644 --- a/Server/README.md +++ b/Server/README.md @@ -164,6 +164,34 @@ Telemetry: - `UNITY_MCP_TELEMETRY_ENDPOINT` - Override telemetry endpoint URL - `UNITY_MCP_TELEMETRY_TIMEOUT` - Override telemetry request timeout (seconds) +### MCP tool toggles in stdio + +The `manage_editor` tool exposes MCP tool enable/disable controls: + +- `set_mcp_tool_enabled` (`tool_name`, `enabled`) +- `get_mcp_tool_enabled` (`tool_name`) +- `list_mcp_tools` + +Example: + +```json +{ + "action": "set_mcp_tool_enabled", + "tool_name": "manage_scene", + "enabled": false +} +``` + +When running in `stdio`, `tools/list` is filtered by Unity's enabled tool state. +If all Unity-managed tools are disabled, `tools/list` will only show server-only tools. +The Unity status file (`~/.unity-mcp/unity-mcp-status-.json`) now includes: + +- `project_hash` +- `enabled_tools` + +Tool toggle changes trigger an immediate status-file refresh, so `tools/list` +updates do not depend on waiting for the next heartbeat. + ### Examples **Stdio (default):** diff --git a/Server/src/services/tools/manage_editor.py b/Server/src/services/tools/manage_editor.py index 51480b75d..0bfa58880 100644 --- a/Server/src/services/tools/manage_editor.py +++ b/Server/src/services/tools/manage_editor.py @@ -12,18 +12,20 @@ @mcp_for_unity_tool( - description="Controls and queries the Unity editor's state and settings. Tip: pass booleans as true/false; if your client only sends strings, 'true'/'false' are accepted. Read-only actions: telemetry_status, telemetry_ping. Modifying actions: play, pause, stop, set_active_tool, add_tag, remove_tag, add_layer, remove_layer.", + description="Controls and queries the Unity editor's state and settings. Tip: pass booleans as true/false; if your client only sends strings, 'true'/'false' are accepted. Read-only actions: telemetry_status, telemetry_ping, get_mcp_tool_enabled, list_mcp_tools. Modifying actions: play, pause, stop, set_active_tool, set_mcp_tool_enabled, add_tag, remove_tag, add_layer, remove_layer.", annotations=ToolAnnotations( title="Manage Editor", ), ) async def manage_editor( ctx: Context, - action: Annotated[Literal["telemetry_status", "telemetry_ping", "play", "pause", "stop", "set_active_tool", "add_tag", "remove_tag", "add_layer", "remove_layer"], "Get and update the Unity Editor state."], + action: Annotated[Literal["telemetry_status", "telemetry_ping", "play", "pause", "stop", "set_active_tool", "set_mcp_tool_enabled", "get_mcp_tool_enabled", "list_mcp_tools", "add_tag", "remove_tag", "add_layer", "remove_layer"], "Get and update the Unity Editor state."], wait_for_completion: Annotated[bool | str, "Optional. If True, waits for certain actions (accepts true/false or 'true'/'false')"] | None = None, tool_name: Annotated[str, - "Tool name when setting active tool"] | None = None, + "Tool name when setting active tool or updating/querying MCP tool enabled state"] | None = None, + enabled: Annotated[bool | str, + "Optional. Required for set_mcp_tool_enabled (accepts true/false or 'true'/'false')"] | None = None, tag_name: Annotated[str, "Tag name when adding and removing tags"] | None = None, layer_name: Annotated[str, @@ -32,7 +34,14 @@ async def manage_editor( # Get active instance from request state (injected by middleware) unity_instance = get_unity_instance_from_context(ctx) - wait_for_completion = coerce_bool(wait_for_completion) + wait_for_completion_value = coerce_bool(wait_for_completion, default=None) + enabled_value = coerce_bool(enabled, default=None) + + if enabled is not None and enabled_value is None: + return { + "success": False, + "message": "enabled must be a boolean value ('true'/'false').", + } try: # Diagnostics: quick telemetry checks @@ -45,8 +54,9 @@ async def manage_editor( # Prepare parameters, removing None values params = { "action": action, - "waitForCompletion": wait_for_completion, + "waitForCompletion": wait_for_completion_value, "toolName": tool_name, + "enabled": enabled_value, "tagName": tag_name, "layerName": layer_name, } diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index 41b4e8baf..d25fd28fa 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -5,8 +5,11 @@ into the request-scoped state, allowing tools to access it via ctx.get_state("unity_instance"). """ from threading import RLock +import json import logging +import os import time +from pathlib import Path from fastmcp.server.middleware import Middleware, MiddlewareContext @@ -294,13 +297,22 @@ async def on_list_tools(self, context: MiddlewareContext, call_next): def _should_filter_tool_listing(self) -> bool: transport = (config.transport_mode or "stdio").lower() - return transport == "http" and PluginHub.is_configured() + if transport == "http": + return PluginHub.is_configured() + + return transport == "stdio" async def _resolve_enabled_tool_names_for_context( self, context: MiddlewareContext, ) -> set[str] | None: ctx = context.fastmcp_context + transport = (config.transport_mode or "stdio").lower() + + if transport == "stdio": + active_instance = ctx.get_state("unity_instance") + return self._resolve_enabled_tool_names_for_stdio_context(active_instance) + user_id = ctx.get_state("user_id") if config.http_remote_hosted else None active_instance = ctx.get_state("unity_instance") project_hashes = self._resolve_candidate_project_hashes(active_instance) @@ -373,6 +385,121 @@ async def _resolve_enabled_tool_names_for_context( return enabled_tool_names + def _resolve_enabled_tool_names_for_stdio_context(self, active_instance: str | None) -> set[str] | None: + status_payloads = self._list_stdio_status_payloads() + if not status_payloads: + return None + + project_hashes = self._resolve_candidate_project_hashes(active_instance) + if project_hashes: + active_hash = project_hashes[0] + for payload in status_payloads: + if payload["project_hash"] == active_hash: + return payload["enabled_tools"] + + logger.debug( + "No stdio status payload matched active hash '%s'; skipping tools/list filtering.", + active_hash, + ) + return None + + # Multi-instance edge case (no active_instance selected): merge enabled tools from + # all discovered status files so tools/list does not "flicker" between instances. + # This intentionally favors stability over strict per-instance precision. + enabled_by_project_hash: dict[str, set[str]] = {} + for payload in status_payloads: + project_hash = payload["project_hash"] + if project_hash in enabled_by_project_hash: + continue + enabled_by_project_hash[project_hash] = payload["enabled_tools"] + + if len(enabled_by_project_hash) > 1: + union_enabled_tools: set[str] = set() + for enabled_tools in enabled_by_project_hash.values(): + union_enabled_tools.update(enabled_tools) + return union_enabled_tools + + # status_payloads is non-empty here, and every payload contributes a valid + # project_hash; after de-duplication this leaves exactly one project entry. + return next(iter(enabled_by_project_hash.values())) + + def _list_stdio_status_payloads(self) -> list[dict[str, object]]: + status_dir_env = os.getenv("UNITY_MCP_STATUS_DIR") + status_dir = Path(status_dir_env).expanduser() if status_dir_env else Path.home().joinpath(".unity-mcp") + + try: + status_files = sorted( + status_dir.glob("unity-mcp-status-*.json"), + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + except OSError as exc: + logger.debug( + "Failed to enumerate stdio status files from %s: %s", + status_dir, + exc, + exc_info=True, + ) + return [] + + payloads: list[dict[str, object]] = [] + for status_file in status_files: + file_hash = self._extract_project_hash_from_filename(status_file) + try: + with status_file.open("r", encoding="utf-8") as handle: + raw_payload = json.load(handle) + except (OSError, ValueError) as exc: + logger.debug( + "Failed to parse stdio status file %s: %s", + status_file, + exc, + exc_info=True, + ) + continue + + if not isinstance(raw_payload, dict): + logger.debug("Skipping stdio status file %s with non-object payload.", status_file) + continue + + enabled_tools_raw = raw_payload.get("enabled_tools") + if not isinstance(enabled_tools_raw, list): + # Missing enabled_tools means the status format is too old for safe filtering. + logger.debug("Skipping stdio status file %s without enabled_tools field.", status_file) + continue + + enabled_tools = { + tool_name + for tool_name in enabled_tools_raw + if isinstance(tool_name, str) and tool_name + } + + project_hash = raw_payload.get("project_hash") + if not isinstance(project_hash, str) or not project_hash: + project_hash = file_hash + + if not project_hash: + logger.debug("Skipping stdio status file %s without project hash.", status_file) + continue + + payloads.append( + { + "project_hash": project_hash, + "enabled_tools": enabled_tools, + } + ) + + return payloads + + @staticmethod + def _extract_project_hash_from_filename(status_file: Path) -> str | None: + prefix = "unity-mcp-status-" + stem = status_file.stem + if not stem.startswith(prefix): + return None + + suffix = stem[len(prefix):] + return suffix or None + def _refresh_tool_visibility_metadata_from_registry(self) -> None: now = time.monotonic() if now - self._last_tool_visibility_refresh < self._tool_visibility_refresh_interval_seconds: diff --git a/Server/tests/integration/test_manage_editor_param_coercion.py b/Server/tests/integration/test_manage_editor_param_coercion.py new file mode 100644 index 000000000..9d4c9a011 --- /dev/null +++ b/Server/tests/integration/test_manage_editor_param_coercion.py @@ -0,0 +1,45 @@ +import asyncio + +from .test_helpers import DummyContext +import services.tools.manage_editor as manage_editor_mod + + +def test_manage_editor_set_mcp_tool_enabled_string_coercion(monkeypatch): + captured = {} + + async def fake_send(_func, _instance, _tool_name, params, **_kwargs): + captured["params"] = params + return {"success": True, "message": "ok"} + + monkeypatch.setattr(manage_editor_mod, "send_with_unity_instance", fake_send) + + result = asyncio.run( + manage_editor_mod.manage_editor( + ctx=DummyContext(), + action="set_mcp_tool_enabled", + tool_name="manage_scene", + enabled="false", + ) + ) + + assert result["success"] is True + assert captured["params"]["enabled"] is False + + +def test_manage_editor_set_mcp_tool_enabled_invalid_boolean(monkeypatch): + async def fake_send(_func, _instance, _tool_name, params, **_kwargs): + return {"success": True, "message": "ok"} + + monkeypatch.setattr(manage_editor_mod, "send_with_unity_instance", fake_send) + + result = asyncio.run( + manage_editor_mod.manage_editor( + ctx=DummyContext(), + action="set_mcp_tool_enabled", + tool_name="manage_scene", + enabled="invalid-bool", + ) + ) + + assert result["success"] is False + assert "enabled" in result["message"] diff --git a/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py b/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py new file mode 100644 index 000000000..b45c7ba85 --- /dev/null +++ b/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py @@ -0,0 +1,183 @@ +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from core.config import config +from transport.unity_instance_middleware import UnityInstanceMiddleware + + +def _tool_registry_for_visibility_tests() -> list[dict]: + return [ + {"name": "manage_scene", "unity_target": "manage_scene"}, + {"name": "manage_script", "unity_target": "manage_script"}, + {"name": "manage_asset", "unity_target": "manage_asset"}, + {"name": "create_script", "unity_target": "manage_script"}, + {"name": "set_active_instance", "unity_target": None}, + ] + + +def _build_fastmcp_context(active_instance: str | None = None) -> Mock: + state = {} + if active_instance: + state["unity_instance"] = active_instance + + ctx = Mock() + ctx.client_id = "test-client" + ctx.set_state = Mock(side_effect=lambda key, value: state.__setitem__(key, value)) + ctx.get_state = Mock(side_effect=lambda key: state.get(key)) + return ctx + + +def _write_status_file(path, payload: dict) -> None: + path.write_text(json.dumps(payload), encoding="utf-8") + + +async def _filter_tool_names(middleware: UnityInstanceMiddleware, fastmcp_context: Mock) -> list[str]: + middleware_ctx = SimpleNamespace(fastmcp_context=fastmcp_context) + available_tools = [ + SimpleNamespace(name="manage_scene"), + SimpleNamespace(name="manage_asset"), + SimpleNamespace(name="create_script"), + SimpleNamespace(name="set_active_instance"), + SimpleNamespace(name="custom_server_tool"), + ] + + async def call_next(_ctx): + return available_tools + + with patch.object(middleware, "_inject_unity_instance", new=AsyncMock()): + with patch( + "transport.unity_instance_middleware.get_registered_tools", + return_value=_tool_registry_for_visibility_tests(), + ): + filtered = await middleware.on_list_tools(middleware_ctx, call_next) + + return [tool.name for tool in filtered] + + +@pytest.mark.asyncio +async def test_stdio_list_tools_filters_enabled_tools(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + { + "project_hash": "abc123", + "enabled_tools": ["manage_scene", "manage_script"], + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + assert "create_script" in names + assert "set_active_instance" in names + assert "custom_server_tool" in names + assert "manage_asset" not in names + + +@pytest.mark.asyncio +async def test_stdio_list_tools_skips_filter_when_status_file_missing(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + assert "manage_asset" in names + assert "create_script" in names + assert "set_active_instance" in names + assert "custom_server_tool" in names + + +@pytest.mark.asyncio +async def test_stdio_list_tools_skips_filter_when_status_file_is_invalid_json(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + (tmp_path / "unity-mcp-status-abc123.json").write_text("{invalid", encoding="utf-8") + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + assert "manage_asset" in names + assert "create_script" in names + + +@pytest.mark.asyncio +async def test_stdio_list_tools_prefers_active_instance_hash_when_multiple_files(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-first11.json", + { + "project_hash": "first11", + "enabled_tools": ["manage_asset"], + }, + ) + _write_status_file( + tmp_path / "unity-mcp-status-second22.json", + { + "project_hash": "second22", + "enabled_tools": ["manage_scene"], + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("AnyName@first11")) + + assert "manage_asset" in names + assert "manage_scene" not in names + + +@pytest.mark.asyncio +async def test_stdio_list_tools_uses_union_when_no_active_instance_and_multiple_hashes(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-first11.json", + { + "project_hash": "first11", + "enabled_tools": ["manage_scene"], + }, + ) + _write_status_file( + tmp_path / "unity-mcp-status-second22.json", + { + "project_hash": "second22", + "enabled_tools": ["manage_asset"], + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context(None)) + + assert "manage_scene" in names + assert "manage_asset" in names + assert "set_active_instance" in names + + +@pytest.mark.asyncio +async def test_stdio_list_tools_skips_filter_when_enabled_tools_field_missing(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + { + "project_hash": "abc123", + "unity_port": 6400, + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + assert "manage_asset" in names + assert "create_script" in names diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs new file mode 100644 index 000000000..6d771e89c --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs @@ -0,0 +1,102 @@ +using System; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using MCPForUnity.Editor.Constants; +using Newtonsoft.Json.Linq; +using NUnit.Framework; +using UnityEditor; + +namespace MCPForUnityTests.Editor.Services +{ + public class TransportCommandDispatcherToolToggleTests + { + private const string TargetTool = "manage_scene"; + private string _targetToolPrefKey; + private bool _hadTargetToolPref; + private bool _previousTargetToolEnabled; + + [SetUp] + public void SetUp() + { + _targetToolPrefKey = EditorPrefKeys.ToolEnabledPrefix + TargetTool; + _hadTargetToolPref = EditorPrefs.HasKey(_targetToolPrefKey); + _previousTargetToolEnabled = EditorPrefs.GetBool(_targetToolPrefKey, true); + } + + [TearDown] + public void TearDown() + { + if (_hadTargetToolPref) + { + EditorPrefs.SetBool(_targetToolPrefKey, _previousTargetToolEnabled); + } + else + { + EditorPrefs.DeleteKey(_targetToolPrefKey); + } + } + + [Test] + public void ExecuteCommandJsonAsync_WhenToolDisabled_ReturnsDisabledError() + { + EditorPrefs.SetBool(_targetToolPrefKey, false); + + string payload = new JObject + { + ["type"] = TargetTool, + ["params"] = new JObject + { + ["action"] = "ping", + }, + }.ToString(); + + string responseJson = ExecuteCommandJson(payload); + var response = JObject.Parse(responseJson); + string error = response["error"]?.ToString() ?? string.Empty; + + Assert.AreEqual("error", response["status"]?.ToString()); + StringAssert.Contains("disabled in the Unity Editor", error); + } + + [Test] + public void ExecuteCommandJsonAsync_WhenToolEnabled_DoesNotReturnDisabledError() + { + EditorPrefs.SetBool(_targetToolPrefKey, true); + + string payload = new JObject + { + ["type"] = TargetTool, + ["params"] = new JObject + { + ["action"] = "ping", + }, + }.ToString(); + + string responseJson = ExecuteCommandJson(payload); + var response = JObject.Parse(responseJson); + string error = response["error"]?.ToString() ?? string.Empty; + + Assert.Less(error.IndexOf("disabled in the Unity Editor", StringComparison.OrdinalIgnoreCase), 0); + } + + private static string ExecuteCommandJson(string commandJson) + { + Type dispatcherType = Type.GetType( + "MCPForUnity.Editor.Services.Transport.TransportCommandDispatcher, MCPForUnity.Editor"); + Assert.IsNotNull(dispatcherType, "Failed to resolve TransportCommandDispatcher type."); + + MethodInfo executeMethod = dispatcherType.GetMethod( + "ExecuteCommandJsonAsync", + BindingFlags.Public | BindingFlags.Static); + Assert.IsNotNull(executeMethod, "Failed to resolve ExecuteCommandJsonAsync."); + + var task = executeMethod.Invoke( + null, + new object[] { commandJson, CancellationToken.None }) as Task; + Assert.IsNotNull(task, "ExecuteCommandJsonAsync did not return Task."); + + return task.GetAwaiter().GetResult(); + } + } +} diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs.meta b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs.meta new file mode 100644 index 000000000..bc9541677 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 780f26d83984444d9f19c0398cc46d3f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs new file mode 100644 index 000000000..44ce789a1 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs @@ -0,0 +1,117 @@ +using System; +using System.Linq; +using MCPForUnity.Editor.Constants; +using MCPForUnity.Editor.Helpers; +using MCPForUnity.Editor.Tools; +using Newtonsoft.Json.Linq; +using NUnit.Framework; +using UnityEditor; + +namespace MCPForUnityTests.Editor.Tools +{ + public class ManageEditorToolToggleTests + { + private const string TargetTool = "manage_scene"; + private string _targetToolPrefKey; + private bool _hadTargetToolPref; + private bool _previousTargetToolEnabled; + + [SetUp] + public void SetUp() + { + _targetToolPrefKey = EditorPrefKeys.ToolEnabledPrefix + TargetTool; + _hadTargetToolPref = EditorPrefs.HasKey(_targetToolPrefKey); + _previousTargetToolEnabled = EditorPrefs.GetBool(_targetToolPrefKey, true); + } + + [TearDown] + public void TearDown() + { + if (_hadTargetToolPref) + { + EditorPrefs.SetBool(_targetToolPrefKey, _previousTargetToolEnabled); + } + else + { + EditorPrefs.DeleteKey(_targetToolPrefKey); + } + } + + [Test] + public void HandleCommand_SetMcpToolEnabled_UpdatesStoredToolPreference() + { + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = TargetTool, + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + Assert.AreEqual(false, EditorPrefs.GetBool(_targetToolPrefKey, true)); + } + + [Test] + public void HandleCommand_SetMcpToolEnabled_RejectsDisablingManageEditor() + { + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = "manage_editor", + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(false, response["success"]?.Value()); + StringAssert.Contains("cannot be disabled", response["error"]?.ToString()); + } + + [Test] + public void HandleCommand_GetMcpToolEnabled_ReturnsCurrentState() + { + EditorPrefs.SetBool(_targetToolPrefKey, false); + + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "get_mcp_tool_enabled", + ["toolName"] = TargetTool, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + + var data = response["data"] as JObject; + Assert.IsNotNull(data); + Assert.AreEqual(TargetTool, data["toolName"]?.ToString()); + Assert.AreEqual(false, data["enabled"]?.Value()); + } + + [Test] + public void HandleCommand_ListMcpTools_ReturnsToolStateShape() + { + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "list_mcp_tools", + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + + var data = response["data"] as JObject; + Assert.IsNotNull(data); + var tools = data["tools"] as JArray; + Assert.IsNotNull(tools); + Assert.Greater(tools.Count, 0); + + var sceneTool = tools + .OfType() + .FirstOrDefault(tool => string.Equals(tool["name"]?.ToString(), TargetTool, StringComparison.Ordinal)); + + Assert.IsNotNull(sceneTool); + Assert.IsNotNull(sceneTool["enabled"]); + Assert.IsNotNull(sceneTool["autoRegister"]); + Assert.IsNotNull(sceneTool["isBuiltIn"]); + } + } +} diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs.meta b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs.meta new file mode 100644 index 000000000..738526e8c --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e4a8f9a95a6c41d3b7ecf10d460879d8 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: From b88905f9f70db29be4c4390aa672fd9fcf4b5445 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Tue, 17 Feb 2026 04:27:15 +0800 Subject: [PATCH 02/11] feat(stdio): add tools/list_changed notifications for tool toggle Implement MCP tools/list_changed notifications to inform clients when the available tool list changes: - Send notification immediately after set_mcp_tool_enabled succeeds - Start background watcher to poll stdio status file changes - Track MCP sessions and notify on new connections (on_message, on_notification) - Add status file TTL (15s default) to ignore stale status files - Clean up stale sessions when notification send fails New environment variables: - UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS (default: 1.0, min: 0.2) - UNITY_MCP_STDIO_STATUS_TTL_SECONDS (default: 15.0) Tests: - test_manage_editor_set_mcp_tool_enabled_sends_tool_list_changed - test_stdio_list_tools_ignores_stale_status_files_before_union - test_unity_instance_middleware_tool_list_notifications.py (new) Co-Authored-By: Claude Opus 4.6 --- Server/src/main.py | 10 + Server/src/services/tools/manage_editor.py | 11 + .../transport/unity_instance_middleware.py | 223 ++++++++++++++++++ .../test_manage_editor_param_coercion.py | 28 +++ ...y_instance_middleware_stdio_tool_filter.py | 34 +++ ...ance_middleware_tool_list_notifications.py | 108 +++++++++ 6 files changed, 414 insertions(+) create mode 100644 Server/tests/test_unity_instance_middleware_tool_list_notifications.py diff --git a/Server/src/main.py b/Server/src/main.py index 2cf162c04..402452f5f 100644 --- a/Server/src/main.py +++ b/Server/src/main.py @@ -140,6 +140,7 @@ async def server_lifespan(server: FastMCP) -> AsyncIterator[dict[str, Any]]: global _unity_connection_pool, _server_version _server_version = get_package_version() logger.info(f"MCP for Unity Server v{_server_version} starting up") + unity_middleware = get_unity_instance_middleware() # Register custom tool management endpoints with FastMCP # Routes are declared globally below after FastMCP initialization @@ -160,6 +161,11 @@ async def server_lifespan(server: FastMCP) -> AsyncIterator[dict[str, Any]]: loop = asyncio.get_running_loop() PluginHub.configure(_plugin_registry, loop) + try: + await unity_middleware.start_stdio_tools_watcher() + except Exception: + logger.debug("Failed to start stdio tools watcher.", exc_info=True) + # Record server startup telemetry start_time = time.time() start_clk = time.perf_counter() @@ -244,6 +250,10 @@ def _emit_startup(): "plugin_registry": _plugin_registry, } finally: + try: + await unity_middleware.stop_stdio_tools_watcher() + except Exception: + logger.debug("Failed to stop stdio tools watcher.", exc_info=True) if _unity_connection_pool: _unity_connection_pool.disconnect_all() logger.info("MCP for Unity Server shut down") diff --git a/Server/src/services/tools/manage_editor.py b/Server/src/services/tools/manage_editor.py index 0bfa58880..4b9e418da 100644 --- a/Server/src/services/tools/manage_editor.py +++ b/Server/src/services/tools/manage_editor.py @@ -1,4 +1,5 @@ from typing import Annotated, Any, Literal +import logging from fastmcp import Context from mcp.types import ToolAnnotations @@ -10,6 +11,8 @@ from transport.legacy.unity_connection import async_send_command_with_retry from services.tools.utils import coerce_bool +logger = logging.getLogger("mcp-for-unity-server") + @mcp_for_unity_tool( description="Controls and queries the Unity editor's state and settings. Tip: pass booleans as true/false; if your client only sends strings, 'true'/'false' are accepted. Read-only actions: telemetry_status, telemetry_ping, get_mcp_tool_enabled, list_mcp_tools. Modifying actions: play, pause, stop, set_active_tool, set_mcp_tool_enabled, add_tag, remove_tag, add_layer, remove_layer.", @@ -67,6 +70,14 @@ async def manage_editor( # Preserve structured failure data; unwrap success into a friendlier shape if isinstance(response, dict) and response.get("success"): + if action == "set_mcp_tool_enabled": + try: + await ctx.send_tool_list_changed() + except Exception: + logger.debug( + "Failed to send tools/list_changed notification after set_mcp_tool_enabled.", + exc_info=True, + ) return {"success": True, "message": response.get("message", "Editor operation successful."), "data": response.get("data")} return response if isinstance(response, dict) else {"success": False, "message": str(response)} diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index d25fd28fa..a16da4f28 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -4,7 +4,9 @@ This middleware intercepts all tool calls and injects the active Unity instance into the request-scoped state, allowing tools to access it via ctx.get_state("unity_instance"). """ +import asyncio from threading import RLock +from datetime import datetime, timezone import json import logging import os @@ -61,6 +63,7 @@ def __init__(self): self._active_by_key: dict[str, str] = {} self._lock = RLock() self._metadata_lock = RLock() + self._session_lock = RLock() self._unity_managed_tool_names: set[str] = set() self._tool_alias_to_unity_target: dict[str, str] = {} self._server_only_tool_names: set[str] = set() @@ -68,6 +71,9 @@ def __init__(self): self._last_tool_visibility_refresh = 0.0 self._tool_visibility_refresh_interval_seconds = 0.5 self._has_logged_empty_registry_warning = False + self._tracked_sessions: dict[str, object] = {} + self._stdio_tools_watch_task: asyncio.Task | None = None + self._last_stdio_tools_state_signature: tuple[tuple[str, tuple[str, ...]], ...] | None = None def get_session_key(self, ctx) -> str: """ @@ -107,6 +113,159 @@ def clear_active_instance(self, ctx) -> None: with self._lock: self._active_by_key.pop(key, None) + @staticmethod + def _is_stdio_transport() -> bool: + return (config.transport_mode or "stdio").lower() == "stdio" + + def _track_session_from_context(self, fastmcp_context) -> bool: + if fastmcp_context is None or fastmcp_context.request_context is None: + return False + + try: + session_id = fastmcp_context.session_id + session = fastmcp_context.session + except RuntimeError: + return False + + if not isinstance(session_id, str) or not session_id: + return False + + with self._session_lock: + existing = self._tracked_sessions.get(session_id) + if existing is session: + return False + + self._tracked_sessions[session_id] = session + + return True + + async def _notify_tool_list_changed_to_sessions(self, reason: str) -> None: + with self._session_lock: + session_items = list(self._tracked_sessions.items()) + + if not session_items: + return + + stale_session_ids: list[str] = [] + sent_count = 0 + for session_id, session in session_items: + try: + await session.send_tool_list_changed() + sent_count += 1 + except Exception: + stale_session_ids.append(session_id) + logger.debug( + "Failed sending tools/list_changed to session %s (reason=%s); session will be removed.", + session_id, + reason, + exc_info=True, + ) + + if stale_session_ids: + with self._session_lock: + for session_id in stale_session_ids: + self._tracked_sessions.pop(session_id, None) + + if sent_count: + logger.debug( + "Sent tools/list_changed notification to %d tracked session(s) (reason=%s).", + sent_count, + reason, + ) + + def _build_stdio_tools_state_signature(self) -> tuple[tuple[str, tuple[str, ...]], ...]: + payloads = self._list_stdio_status_payloads() + enabled_by_hash: dict[str, tuple[str, ...]] = {} + for payload in payloads: + project_hash = payload.get("project_hash") + if not isinstance(project_hash, str) or not project_hash or project_hash in enabled_by_hash: + continue + + enabled_raw = payload.get("enabled_tools") + if isinstance(enabled_raw, set): + enabled_tools = tuple( + sorted( + tool_name + for tool_name in enabled_raw + if isinstance(tool_name, str) and tool_name + ) + ) + elif isinstance(enabled_raw, list): + enabled_tools = tuple( + sorted( + tool_name + for tool_name in enabled_raw + if isinstance(tool_name, str) and tool_name + ) + ) + else: + enabled_tools = () + + enabled_by_hash[project_hash] = enabled_tools + + return tuple(sorted(enabled_by_hash.items(), key=lambda item: item[0])) + + @staticmethod + def _get_stdio_tools_watch_interval_seconds() -> float: + raw_interval = os.getenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "1.0") + try: + parsed_interval = float(raw_interval) + if parsed_interval < 0.2: + return 0.2 + return parsed_interval + except (TypeError, ValueError): + return 1.0 + + async def _run_stdio_tools_watch_loop(self, interval_seconds: float) -> None: + while True: + try: + await asyncio.sleep(interval_seconds) + current_signature = self._build_stdio_tools_state_signature() + if self._last_stdio_tools_state_signature is None: + self._last_stdio_tools_state_signature = current_signature + continue + + if current_signature != self._last_stdio_tools_state_signature: + self._last_stdio_tools_state_signature = current_signature + await self._notify_tool_list_changed_to_sessions("stdio_state_changed") + except asyncio.CancelledError: + raise + except Exception: + logger.debug("stdio tools watcher iteration failed.", exc_info=True) + + async def start_stdio_tools_watcher(self) -> None: + if not self._is_stdio_transport(): + return + + task = self._stdio_tools_watch_task + if task is not None and not task.done(): + return + + self._last_stdio_tools_state_signature = self._build_stdio_tools_state_signature() + interval_seconds = self._get_stdio_tools_watch_interval_seconds() + self._stdio_tools_watch_task = asyncio.create_task( + self._run_stdio_tools_watch_loop(interval_seconds), + name="unity-mcp-stdio-tools-watcher", + ) + logger.debug("Started stdio tools watcher (interval=%ss).", interval_seconds) + + async def stop_stdio_tools_watcher(self) -> None: + task = self._stdio_tools_watch_task + self._stdio_tools_watch_task = None + self._last_stdio_tools_state_signature = None + + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + logger.debug("Error while stopping stdio tools watcher.", exc_info=True) + + with self._session_lock: + self._tracked_sessions.clear() + async def _maybe_autoselect_instance(self, ctx) -> str | None: """ Auto-select the sole Unity instance when no active instance is set. @@ -269,6 +428,22 @@ async def on_call_tool(self, context: MiddlewareContext, call_next): await self._inject_unity_instance(context) return await call_next(context) + async def on_message(self, context: MiddlewareContext, call_next): + if self._is_stdio_transport(): + is_new_session = self._track_session_from_context(context.fastmcp_context) + if is_new_session: + await self._notify_tool_list_changed_to_sessions("session_registered") + + return await call_next(context) + + async def on_notification(self, context: MiddlewareContext, call_next): + if self._is_stdio_transport(): + self._track_session_from_context(context.fastmcp_context) + if context.method == "notifications/initialized": + await self._notify_tool_list_changed_to_sessions("client_initialized") + + return await call_next(context) + async def on_read_resource(self, context: MiddlewareContext, call_next): """Inject active Unity instance into resource context if available.""" await self._inject_unity_instance(context) @@ -424,6 +599,8 @@ def _resolve_enabled_tool_names_for_stdio_context(self, active_instance: str | N return next(iter(enabled_by_project_hash.values())) def _list_stdio_status_payloads(self) -> list[dict[str, object]]: + status_ttl_seconds = self._get_stdio_status_ttl_seconds() + now_utc = datetime.now(timezone.utc) status_dir_env = os.getenv("UNITY_MCP_STATUS_DIR") status_dir = Path(status_dir_env).expanduser() if status_dir_env else Path.home().joinpath(".unity-mcp") @@ -473,6 +650,26 @@ def _list_stdio_status_payloads(self) -> list[dict[str, object]]: if isinstance(tool_name, str) and tool_name } + freshness = self._parse_heartbeat_datetime(raw_payload.get("last_heartbeat")) + if freshness is None: + try: + freshness = datetime.fromtimestamp(status_file.stat().st_mtime, tz=timezone.utc) + except OSError: + logger.debug( + "Failed to read mtime for stdio status file %s; skipping for safety.", + status_file, + exc_info=True, + ) + continue + + if (now_utc - freshness).total_seconds() > status_ttl_seconds: + logger.debug( + "Skipping stale stdio status file %s (age exceeds %ss).", + status_file, + status_ttl_seconds, + ) + continue + project_hash = raw_payload.get("project_hash") if not isinstance(project_hash, str) or not project_hash: project_hash = file_hash @@ -500,6 +697,32 @@ def _extract_project_hash_from_filename(status_file: Path) -> str | None: suffix = stem[len(prefix):] return suffix or None + @staticmethod + def _get_stdio_status_ttl_seconds() -> float: + raw_ttl = os.getenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "15") + try: + ttl = float(raw_ttl) + if ttl > 0: + return ttl + except (TypeError, ValueError): + pass + return 15.0 + + @staticmethod + def _parse_heartbeat_datetime(raw_heartbeat: object) -> datetime | None: + if not isinstance(raw_heartbeat, str) or not raw_heartbeat: + return None + + try: + parsed = datetime.fromisoformat(raw_heartbeat.replace("Z", "+00:00")) + except ValueError: + return None + + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + + return parsed.astimezone(timezone.utc) + def _refresh_tool_visibility_metadata_from_registry(self) -> None: now = time.monotonic() if now - self._last_tool_visibility_refresh < self._tool_visibility_refresh_interval_seconds: diff --git a/Server/tests/integration/test_manage_editor_param_coercion.py b/Server/tests/integration/test_manage_editor_param_coercion.py index 9d4c9a011..19ec77334 100644 --- a/Server/tests/integration/test_manage_editor_param_coercion.py +++ b/Server/tests/integration/test_manage_editor_param_coercion.py @@ -4,6 +4,15 @@ import services.tools.manage_editor as manage_editor_mod +class NotifyContext(DummyContext): + def __init__(self, **meta): + super().__init__(**meta) + self.tool_list_changed_calls = 0 + + async def send_tool_list_changed(self): + self.tool_list_changed_calls += 1 + + def test_manage_editor_set_mcp_tool_enabled_string_coercion(monkeypatch): captured = {} @@ -43,3 +52,22 @@ async def fake_send(_func, _instance, _tool_name, params, **_kwargs): assert result["success"] is False assert "enabled" in result["message"] + + +def test_manage_editor_set_mcp_tool_enabled_sends_tool_list_changed(monkeypatch): + async def fake_send(_func, _instance, _tool_name, params, **_kwargs): + return {"success": True, "message": "ok"} + + monkeypatch.setattr(manage_editor_mod, "send_with_unity_instance", fake_send) + ctx = NotifyContext() + result = asyncio.run( + manage_editor_mod.manage_editor( + ctx=ctx, + action="set_mcp_tool_enabled", + tool_name="manage_scene", + enabled=True, + ) + ) + + assert result["success"] is True + assert ctx.tool_list_changed_calls == 1 diff --git a/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py b/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py index b45c7ba85..d4ba27f5e 100644 --- a/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py +++ b/Server/tests/test_unity_instance_middleware_stdio_tool_filter.py @@ -1,4 +1,5 @@ import json +from datetime import datetime, timedelta, timezone from types import SimpleNamespace from unittest.mock import AsyncMock, Mock, patch @@ -163,6 +164,39 @@ async def test_stdio_list_tools_uses_union_when_no_active_instance_and_multiple_ assert "set_active_instance" in names +@pytest.mark.asyncio +async def test_stdio_list_tools_ignores_stale_status_files_before_union(monkeypatch, tmp_path): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "15") + + fresh_heartbeat = datetime.now(timezone.utc).isoformat() + stale_heartbeat = (datetime.now(timezone.utc) - timedelta(minutes=2)).isoformat() + + _write_status_file( + tmp_path / "unity-mcp-status-fresh11.json", + { + "project_hash": "fresh11", + "enabled_tools": ["manage_scene"], + "last_heartbeat": fresh_heartbeat, + }, + ) + _write_status_file( + tmp_path / "unity-mcp-status-stale22.json", + { + "project_hash": "stale22", + "enabled_tools": ["manage_asset"], + "last_heartbeat": stale_heartbeat, + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context(None)) + + assert "manage_scene" in names + assert "manage_asset" not in names + + @pytest.mark.asyncio async def test_stdio_list_tools_skips_filter_when_enabled_tools_field_missing(monkeypatch, tmp_path): monkeypatch.setattr(config, "transport_mode", "stdio") diff --git a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py new file mode 100644 index 000000000..d20b792fe --- /dev/null +++ b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py @@ -0,0 +1,108 @@ +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from core.config import config +from transport.unity_instance_middleware import UnityInstanceMiddleware + + +def _build_context(session_id: str, session_obj: object, method: str = "tools/list"): + fastmcp_context = SimpleNamespace( + request_context=object(), + session_id=session_id, + session=session_obj, + ) + return SimpleNamespace( + fastmcp_context=fastmcp_context, + method=method, + ) + + +@pytest.mark.asyncio +async def test_on_message_registers_new_session_and_notifies(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + session = SimpleNamespace(send_tool_list_changed=AsyncMock()) + + context = _build_context("session-1", session) + await middleware.on_message(context, AsyncMock(return_value=None)) + await middleware.on_message(context, AsyncMock(return_value=None)) + + # First message for a new session sends one immediate refresh notification. + assert session.send_tool_list_changed.await_count == 1 + + +@pytest.mark.asyncio +async def test_on_notification_initialized_triggers_tools_list_changed(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + session = SimpleNamespace(send_tool_list_changed=AsyncMock()) + + context = _build_context("session-init", session, method="notifications/initialized") + await middleware.on_notification(context, AsyncMock(return_value=None)) + + assert session.send_tool_list_changed.await_count == 1 + + +@pytest.mark.asyncio +async def test_notify_tool_list_changed_removes_stale_sessions(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + healthy_session = SimpleNamespace(send_tool_list_changed=AsyncMock(return_value=None)) + + async def _raise_send(): + raise RuntimeError("session closed") + + stale_session = SimpleNamespace(send_tool_list_changed=AsyncMock(side_effect=_raise_send)) + middleware._tracked_sessions["healthy"] = healthy_session + middleware._tracked_sessions["stale"] = stale_session + + await middleware._notify_tool_list_changed_to_sessions("test_reason") + + assert "healthy" in middleware._tracked_sessions + assert "stale" not in middleware._tracked_sessions + assert healthy_session.send_tool_list_changed.await_count == 1 + + +@pytest.mark.asyncio +async def test_start_stdio_tools_watcher_skips_when_transport_is_not_stdio(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "http") + middleware = UnityInstanceMiddleware() + + await middleware.start_stdio_tools_watcher() + assert middleware._stdio_tools_watch_task is None + + +@pytest.mark.asyncio +async def test_stdio_tools_watcher_notifies_on_signature_change(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "0.2") + middleware = UnityInstanceMiddleware() + middleware._notify_tool_list_changed_to_sessions = AsyncMock(return_value=None) + + signatures = [ + (("hash1", ("manage_scene",)),), + (("hash1", ("manage_scene", "manage_asset")),), + (("hash1", ("manage_scene", "manage_asset")),), + ] + signature_index = {"value": 0} + + def _fake_signature(): + index = signature_index["value"] + signature_index["value"] += 1 + if index >= len(signatures): + return signatures[-1] + return signatures[index] + + monkeypatch.setattr(middleware, "_build_stdio_tools_state_signature", _fake_signature) + + await middleware.start_stdio_tools_watcher() + try: + await asyncio.sleep(0.35) + finally: + await middleware.stop_stdio_tools_watcher() + + middleware._notify_tool_list_changed_to_sessions.assert_awaited_with("stdio_state_changed") From 640aef8ee481ecb905fbc53129bfba865f2c0758 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Tue, 17 Feb 2026 04:49:31 +0800 Subject: [PATCH 03/11] feat(stdio): add configuration options for stdio status freshness and tool watch interval --- Server/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/Server/README.md b/Server/README.md index c64490770..dd76989d5 100644 --- a/Server/README.md +++ b/Server/README.md @@ -147,6 +147,8 @@ These options apply to the `mcp-for-unity` command (whether run via `uvx`, Docke - `UNITY_MCP_HTTP_REMOTE_HOSTED` - Enable remote-hosted mode (`true`, `1`, or `yes`) - `UNITY_MCP_DEFAULT_INSTANCE` - Default Unity instance to target (project name, hash, or `Name@hash`) - `UNITY_MCP_SKIP_STARTUP_CONNECT=1` - Skip initial Unity connection attempt on startup +- `UNITY_MCP_STDIO_STATUS_TTL_SECONDS` - Freshness window for stdio status files used by tools/list filtering (default: `15`) +- `UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS` - Poll interval for stdio tool-list change watcher in seconds (default: `1.0`, minimum: `0.2`) API key authentication (remote-hosted mode): @@ -191,6 +193,17 @@ The Unity status file (`~/.unity-mcp/unity-mcp-status-.json`) now includes Tool toggle changes trigger an immediate status-file refresh, so `tools/list` updates do not depend on waiting for the next heartbeat. +When a client session initializes in `stdio`, the server sends +`notifications/tools/list_changed` to trigger an immediate tool-list refresh. +During runtime, a stdio watcher monitors status-file changes and emits the same +notification when enabled-tool state changes. +To avoid stale instance data, stdio filtering only uses recent status files +(default freshness window: 15s, configurable via `UNITY_MCP_STDIO_STATUS_TTL_SECONDS`). +Watcher interval defaults to 1.0s (minimum 0.2s), configurable via +`UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS`. +Compatibility note: if a client ignores `notifications/tools/list_changed`, +tool calls still enforce enabled/disabled state, but visible list updates may +still require reconnecting that client. ### Examples From b87baaf1b55ea0d940a27e854771aa2a898f5394 Mon Sep 17 00:00:00 2001 From: David Sarno Date: Tue, 17 Feb 2026 18:00:09 -0800 Subject: [PATCH 04/11] Add per-call unity_instance routing via middleware argument interception Any tool call can now include a unity_instance parameter to route that specific call to a target Unity instance without changing the session default and without requiring a set_active_instance call first. The middleware pops unity_instance from tool call arguments before Pydantic validation runs, resolves it (port number, hash prefix, or Name@hash), and injects it into request-scoped state for that call only. - Port numbers resolve to the matching Name@hash via status file lookup rather than synthetic direct:{port} IDs, so the transport layer can route them correctly - HTTP mode rejects port-based targeting with a clear error - set_active_instance now also accepts port numbers for consistency - Multi-instance scenarios log available instances with ports when auto-select cannot choose - _discover_instances() helper DRYs up transport-aware instance discovery previously duplicated across the codebase - Server instructions updated to document both routing approaches - 18 new tests covering pop behaviour, per-call vs session routing, port resolution, transport modes, and edge cases Closes #697 Co-Authored-By: Claude Sonnet 4.6 --- Server/src/main.py | 3 +- .../src/services/tools/set_active_instance.py | 37 +- .../transport/unity_instance_middleware.py | 148 ++++++- .../integration/test_inline_unity_instance.py | 390 ++++++++++++++++++ 4 files changed, 574 insertions(+), 4 deletions(-) create mode 100644 Server/tests/integration/test_inline_unity_instance.py diff --git a/Server/src/main.py b/Server/src/main.py index 2cf162c04..c0415ecc6 100644 --- a/Server/src/main.py +++ b/Server/src/main.py @@ -268,7 +268,8 @@ def _build_instructions(project_scoped_tools: bool) -> str: Targeting Unity instances: - Use the resource mcpforunity://instances to list active Unity sessions (Name@hash). -- When multiple instances are connected, call set_active_instance with the exact Name@hash before using tools/resources. The server will error if multiple are connected and no active instance is set. +- When multiple instances are connected, call set_active_instance with the exact Name@hash before using tools/resources to pin routing for the whole session. The server will error if multiple are connected and no active instance is set. +- Alternatively, pass unity_instance as a parameter on any individual tool call to route just that call (e.g. unity_instance="MyGame@abc123", unity_instance="abc" for a hash prefix, or unity_instance="6401" for a port number in stdio mode). This does not change the session default. Important Workflows: diff --git a/Server/src/services/tools/set_active_instance.py b/Server/src/services/tools/set_active_instance.py index 30582867f..ecdfb3a5c 100644 --- a/Server/src/services/tools/set_active_instance.py +++ b/Server/src/services/tools/set_active_instance.py @@ -13,17 +13,50 @@ @mcp_for_unity_tool( unity_target=None, - description="Set the active Unity instance for this client/session. Accepts Name@hash or hash.", + description="Set the active Unity instance for this client/session. Accepts Name@hash, hash prefix, or port number (stdio only).", annotations=ToolAnnotations( title="Set Active Instance", ), ) async def set_active_instance( ctx: Context, - instance: Annotated[str, "Target instance (Name@hash or hash prefix)"] + instance: Annotated[str, "Target instance (Name@hash, hash prefix, or port number in stdio mode)"] ) -> dict[str, Any]: transport = (config.transport_mode or "stdio").lower() + # Port number shorthand (stdio only) — resolve to Name@hash via pool discovery + value = (instance or "").strip() + if value.isdigit(): + if transport == "http": + return { + "success": False, + "error": f"Port-based targeting ('{value}') is not supported in HTTP transport mode. " + "Use Name@hash or a hash prefix. Read mcpforunity://instances for available instances." + } + port_int = int(value) + pool = get_unity_connection_pool() + instances = pool.discover_all_instances(force_refresh=True) + match = next((inst for inst in instances if getattr(inst, "port", None) == port_int), None) + if match is None: + available = ", ".join( + f"{inst.id} (port {getattr(inst, 'port', '?')})" for inst in instances + ) or "none" + return { + "success": False, + "error": f"No Unity instance found on port {value}. Available: {available}." + } + resolved_id = match.id + middleware = get_unity_instance_middleware() + middleware.set_active_instance(ctx, resolved_id) + return { + "success": True, + "message": f"Active instance set to {resolved_id}", + "data": { + "instance": resolved_id, + "session_key": middleware.get_session_key(ctx), + }, + } + # Discover running instances based on transport if transport == "http": # In remote-hosted mode, filter sessions by user_id diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index 41b4e8baf..e8aea6625 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -104,6 +104,124 @@ def clear_active_instance(self, ctx) -> None: with self._lock: self._active_by_key.pop(key, None) + async def _discover_instances(self, ctx) -> list: + """ + Return running Unity instances across both HTTP (PluginHub) and stdio transports. + + Returns a list of objects with .id (Name@hash) and .hash attributes. + """ + from types import SimpleNamespace + transport = (config.transport_mode or "stdio").lower() + results: list = [] + + if PluginHub.is_configured(): + try: + user_id = None + get_state_fn = getattr(ctx, "get_state", None) + if callable(get_state_fn) and config.http_remote_hosted: + user_id = get_state_fn("user_id") + sessions_data = await PluginHub.get_sessions(user_id=user_id) + sessions = sessions_data.sessions or {} + for session_info in sessions.values(): + project = getattr(session_info, "project", None) or "Unknown" + hash_value = getattr(session_info, "hash", None) + if hash_value: + results.append(SimpleNamespace( + id=f"{project}@{hash_value}", + hash=hash_value, + name=project, + )) + except Exception as exc: + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise + logger.debug("PluginHub instance discovery failed (%s)", type(exc).__name__, exc_info=True) + + if not results and transport != "http": + try: + from transport.legacy.unity_connection import get_unity_connection_pool + pool = get_unity_connection_pool() + results = pool.discover_all_instances(force_refresh=True) + except Exception as exc: + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise + logger.debug("Stdio instance discovery failed (%s)", type(exc).__name__, exc_info=True) + + return results + + async def _resolve_instance_value(self, value: str, ctx) -> str: + """ + Resolve a unity_instance string to a validated instance identifier. + + Accepts: + - Bare port number like "6401" (stdio only) -> resolved Name@hash + - "Name@hash" exact match + - Hash prefix (unique prefix match against running instances) + + Raises ValueError with a user-friendly message on failure. + """ + value = value.strip() + if not value: + raise ValueError("unity_instance value must not be empty.") + + transport = (config.transport_mode or "stdio").lower() + + # Port number (stdio only) — resolve to Name@hash via status file lookup + if value.isdigit(): + if transport == "http": + raise ValueError( + f"Port-based targeting ('{value}') is not supported in HTTP transport mode. " + "Use Name@hash or a hash prefix. Read mcpforunity://instances for available instances." + ) + port_int = int(value) + instances = await self._discover_instances(ctx) + for inst in instances: + if getattr(inst, "port", None) == port_int: + return inst.id + available = ", ".join( + f"{getattr(i, 'id', '?')} (port {getattr(i, 'port', '?')})" + for i in instances + ) or "none" + raise ValueError( + f"No Unity instance found on port {value}. Available: {available}." + ) + + instances = await self._discover_instances(ctx) + ids = { + getattr(inst, "id", None): inst + for inst in instances + if getattr(inst, "id", None) + } + + # Exact Name@hash match + if "@" in value: + if value in ids: + return value + available = ", ".join(ids) or "none" + raise ValueError( + f"Instance '{value}' not found. Available: {available}. " + "Read mcpforunity://instances for current sessions." + ) + + # Hash prefix match + lookup = value.lower() + matches = [ + inst for inst in instances + if getattr(inst, "hash", "") and getattr(inst, "hash", "").lower().startswith(lookup) + ] + if len(matches) == 1: + return matches[0].id + if len(matches) > 1: + ambiguous = ", ".join(getattr(m, "id", "?") for m in matches) + raise ValueError( + f"Hash prefix '{value}' is ambiguous ({ambiguous}). " + "Provide the full Name@hash from mcpforunity://instances." + ) + available = ", ".join(ids) or "none" + raise ValueError( + f"No running Unity instance matches '{value}'. Available: {available}. " + "Read mcpforunity://instances for current sessions." + ) + async def _maybe_autoselect_instance(self, ctx) -> str | None: """ Auto-select the sole Unity instance when no active instance is set. @@ -136,6 +254,12 @@ async def _maybe_autoselect_instance(self, ctx) -> str | None: chosen, ) return chosen + if len(ids) > 1: + logger.info( + "Multiple Unity instances found (%d). Pass unity_instance on any tool call " + "or call set_active_instance to choose one. Available: %s", + len(ids), ", ".join(ids), + ) except (ConnectionError, ValueError, KeyError, TimeoutError, AttributeError) as exc: logger.debug( "PluginHub auto-select probe failed (%s); falling back to stdio", @@ -168,6 +292,12 @@ async def _maybe_autoselect_instance(self, ctx) -> str | None: chosen, ) return chosen + if len(ids) > 1: + logger.info( + "Multiple Unity instances found (%d). Pass unity_instance on any tool call " + "or call set_active_instance to choose one. Available: %s", + len(ids), ", ".join(ids), + ) except (ConnectionError, ValueError, KeyError, TimeoutError, AttributeError) as exc: logger.debug( "Stdio auto-select probe failed (%s)", @@ -214,7 +344,23 @@ async def _inject_unity_instance(self, context: MiddlewareContext) -> None: if user_id: ctx.set_state("user_id", user_id) - active_instance = self.get_active_instance(ctx) + # Per-call routing: check if this tool call explicitly specifies unity_instance. + # context.message.arguments is a mutable dict on CallToolRequestParams; resource + # reads use ReadResourceRequestParams which has no .arguments, so this is a no-op for them. + # We pop the key here so Pydantic's type_adapter.validate_python() never sees it. + active_instance: str | None = None + msg_args = getattr(getattr(context, "message", None), "arguments", None) + if isinstance(msg_args, dict) and "unity_instance" in msg_args: + raw = msg_args.pop("unity_instance") + if raw is not None: + raw_str = str(raw).strip() + if raw_str: + # Raises ValueError with a user-friendly message on invalid input. + active_instance = await self._resolve_instance_value(raw_str, ctx) + logger.debug("Per-call unity_instance resolved to: %s", active_instance) + + if not active_instance: + active_instance = self.get_active_instance(ctx) if not active_instance: active_instance = await self._maybe_autoselect_instance(ctx) if active_instance: diff --git a/Server/tests/integration/test_inline_unity_instance.py b/Server/tests/integration/test_inline_unity_instance.py new file mode 100644 index 000000000..13b13c32b --- /dev/null +++ b/Server/tests/integration/test_inline_unity_instance.py @@ -0,0 +1,390 @@ +""" +Tests for per-call unity_instance routing via middleware argument interception. + +When a tool call includes unity_instance in its arguments, the middleware: + 1. Pops the key before Pydantic validation sees it + 2. Resolves it to a validated instance identifier + 3. Sets it in request-scoped state for that call only (does NOT persist to session) +""" +import asyncio +import sys +import types +from types import SimpleNamespace + +import pytest + +from .test_helpers import DummyContext +from core.config import config + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class DummyMiddlewareContext: + """Minimal MiddlewareContext stand-in with a mutable arguments dict.""" + + def __init__(self, ctx, arguments: dict | None = None): + self.fastmcp_context = ctx + self.message = SimpleNamespace(arguments=arguments if arguments is not None else {}) + + +def _make_middleware(monkeypatch, *, transport="stdio", plugin_hub_configured=False, sessions=None, pool_instances=None): + """ + Build a UnityInstanceMiddleware with patched transport dependencies. + + sessions: dict of session_id -> SimpleNamespace(project=..., hash=...) + pool_instances: list of SimpleNamespace(id=..., hash=...) + """ + plugin_hub_mod = types.ModuleType("transport.plugin_hub") + + _sessions = sessions or {} + _configured = plugin_hub_configured + + class FakePluginHub: + @classmethod + def is_configured(cls): + return _configured + + @classmethod + async def get_sessions(cls, user_id=None): + return SimpleNamespace(sessions=_sessions) + + @classmethod + async def _resolve_session_id(cls, instance, user_id=None): + return None + + plugin_hub_mod.PluginHub = FakePluginHub + monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub_mod) + monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False) + + from transport.unity_instance_middleware import UnityInstanceMiddleware + + middleware = UnityInstanceMiddleware() + monkeypatch.setattr(config, "transport_mode", transport) + monkeypatch.setattr(config, "http_remote_hosted", False) + + if pool_instances is not None: + async def fake_discover(ctx): + return pool_instances + monkeypatch.setattr(middleware, "_discover_instances", fake_discover) + + return middleware + + +# --------------------------------------------------------------------------- +# Pop behaviour +# --------------------------------------------------------------------------- + +def test_unity_instance_is_popped_from_arguments(monkeypatch): + """unity_instance key must be removed from arguments before the tool function sees them.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123")] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + args = {"action": "get_active", "unity_instance": "abc123"} + mw_ctx = DummyMiddlewareContext(ctx, arguments=args) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert "unity_instance" not in args + assert "action" in args # other keys untouched + + +def test_arguments_without_unity_instance_untouched(monkeypatch): + """When unity_instance is absent, arguments dict is left completely untouched.""" + mw = _make_middleware(monkeypatch, pool_instances=[SimpleNamespace(id="Proj@abc123", hash="abc123")]) + + ctx = DummyContext() + ctx.client_id = "client-1" + # Seed a persisted instance so auto-select isn't needed + mw.set_active_instance(ctx, "Proj@abc123") + + args = {"action": "get_active", "name": "Test"} + mw_ctx = DummyMiddlewareContext(ctx, arguments=args) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert args == {"action": "get_active", "name": "Test"} + + +# --------------------------------------------------------------------------- +# Per-call routing (no persistence) +# --------------------------------------------------------------------------- + +def test_inline_routes_to_specified_instance(monkeypatch): + """Per-call unity_instance sets request state to the resolved instance.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123")] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "abc123"}) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +def test_inline_does_not_persist_to_session(monkeypatch): + """Per-call unity_instance must not change the session-persisted instance.""" + instances = [ + SimpleNamespace(id="ProjA@aaa111", hash="aaa111"), + SimpleNamespace(id="ProjB@bbb222", hash="bbb222"), + ] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw.set_active_instance(ctx, "ProjA@aaa111") + + # Call 1: inline override to ProjB + mw_ctx1 = DummyMiddlewareContext(ctx, arguments={"unity_instance": "bbb222"}) + asyncio.run(mw._inject_unity_instance(mw_ctx1)) + assert ctx.get_state("unity_instance") == "ProjB@bbb222" + + # Call 2: no inline — must revert to session-persisted ProjA + mw_ctx2 = DummyMiddlewareContext(ctx, arguments={}) + asyncio.run(mw._inject_unity_instance(mw_ctx2)) + assert ctx.get_state("unity_instance") == "ProjA@aaa111" + + +def test_inline_overrides_session_persisted_instance(monkeypatch): + """Inline unity_instance takes precedence over session-persisted instance.""" + instances = [ + SimpleNamespace(id="ProjA@aaa111", hash="aaa111"), + SimpleNamespace(id="ProjB@bbb222", hash="bbb222"), + ] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw.set_active_instance(ctx, "ProjA@aaa111") + + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "ProjB@bbb222"}) + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "ProjB@bbb222" + # Session still pinned to ProjA + assert mw.get_active_instance(ctx) == "ProjA@aaa111" + + +# --------------------------------------------------------------------------- +# Port number resolution (stdio) +# --------------------------------------------------------------------------- + +def test_port_number_resolves_to_name_hash_stdio(monkeypatch): + """Bare port number resolves to the matching Name@hash in stdio mode.""" + instances = [ + SimpleNamespace(id="Proj@abc123", hash="abc123", port=6401), + SimpleNamespace(id="Other@def456", hash="def456", port=6402), + ] + mw = _make_middleware(monkeypatch, transport="stdio", pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "6401"}) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +def test_port_number_not_found_raises(monkeypatch): + """Port number with no matching instance raises ValueError.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123", port=6401)] + mw = _make_middleware(monkeypatch, transport="stdio", pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "9999"}) + + with pytest.raises(ValueError, match="No Unity instance found on port 9999"): + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + +def test_port_number_errors_in_http_mode(monkeypatch): + """Bare port number raises ValueError in HTTP transport mode.""" + mw = _make_middleware(monkeypatch, transport="http") + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "6401"}) + + with pytest.raises(ValueError, match="not supported in HTTP transport mode"): + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + +# --------------------------------------------------------------------------- +# Name@hash and hash prefix resolution +# --------------------------------------------------------------------------- + +def test_name_at_hash_resolves_exactly(monkeypatch): + """Full Name@hash resolves directly without discovery.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123")] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "Proj@abc123"}) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +def test_unknown_name_at_hash_raises(monkeypatch): + """Unknown Name@hash raises ValueError.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123")] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "Ghost@deadbeef"}) + + with pytest.raises(ValueError, match="not found"): + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + +def test_hash_prefix_resolves_unique(monkeypatch): + """Unique hash prefix resolves to the full Name@hash.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123")] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "abc"}) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +def test_ambiguous_hash_prefix_raises(monkeypatch): + """Ambiguous hash prefix raises ValueError.""" + instances = [ + SimpleNamespace(id="ProjA@abc111", hash="abc111"), + SimpleNamespace(id="ProjB@abc222", hash="abc222"), + ] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "abc"}) + + with pytest.raises(ValueError, match="ambiguous"): + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + +def test_no_match_raises(monkeypatch): + """Hash prefix matching nothing raises ValueError.""" + instances = [SimpleNamespace(id="Proj@abc123", hash="abc123")] + mw = _make_middleware(monkeypatch, pool_instances=instances) + + ctx = DummyContext() + ctx.client_id = "client-1" + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": "xyz"}) + + with pytest.raises(ValueError, match="No running Unity instance"): + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + +def test_none_unity_instance_falls_through_to_session(monkeypatch): + """None value for unity_instance falls through to session-persisted instance.""" + mw = _make_middleware(monkeypatch) + ctx = DummyContext() + ctx.client_id = "client-1" + mw.set_active_instance(ctx, "Proj@abc123") + + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": None, "action": "x"}) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +def test_empty_string_unity_instance_falls_through_to_session(monkeypatch): + """Empty string unity_instance falls through to session-persisted instance.""" + mw = _make_middleware(monkeypatch) + ctx = DummyContext() + ctx.client_id = "client-1" + mw.set_active_instance(ctx, "Proj@abc123") + + mw_ctx = DummyMiddlewareContext(ctx, arguments={"unity_instance": " "}) + + asyncio.run(mw._inject_unity_instance(mw_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +def test_resource_read_unaffected(monkeypatch): + """on_read_resource with no .arguments attribute routes via session state normally.""" + mw = _make_middleware(monkeypatch) + ctx = DummyContext() + ctx.client_id = "client-1" + mw.set_active_instance(ctx, "Proj@abc123") + + # ReadResourceRequestParams has .uri not .arguments + resource_ctx = SimpleNamespace( + fastmcp_context=ctx, + message=SimpleNamespace(uri="mcpforunity://scene/active"), + ) + + asyncio.run(mw._inject_unity_instance(resource_ctx)) + + assert ctx.get_state("unity_instance") == "Proj@abc123" + + +# --------------------------------------------------------------------------- +# set_active_instance tool: port number support +# --------------------------------------------------------------------------- + +def test_set_active_instance_port_stdio(monkeypatch): + """set_active_instance accepts a port number in stdio mode and resolves to Name@hash.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setattr(config, "http_remote_hosted", False) + + from transport.unity_instance_middleware import UnityInstanceMiddleware, set_unity_instance_middleware + mw = UnityInstanceMiddleware() + set_unity_instance_middleware(mw) + + pool_instance = SimpleNamespace(id="Proj@abc123", hash="abc123", port=6401) + + class FakePool: + def discover_all_instances(self, force_refresh=False): + return [pool_instance] + + import services.tools.set_active_instance as sat + monkeypatch.setattr(sat, "get_unity_connection_pool", lambda: FakePool()) + + from services.tools.set_active_instance import set_active_instance + + ctx = DummyContext() + ctx.client_id = "client-1" + + result = asyncio.run(set_active_instance(ctx, instance="6401")) + + assert result["success"] is True + assert result["data"]["instance"] == "Proj@abc123" + assert mw.get_active_instance(ctx) == "Proj@abc123" + + +def test_set_active_instance_port_http_errors(monkeypatch): + """set_active_instance rejects port numbers in HTTP mode.""" + monkeypatch.setattr(config, "transport_mode", "http") + monkeypatch.setattr(config, "http_remote_hosted", False) + + from services.tools.set_active_instance import set_active_instance + + ctx = DummyContext() + ctx.client_id = "client-1" + + result = asyncio.run(set_active_instance(ctx, instance="6401")) + + assert result["success"] is False + assert "not supported in HTTP transport mode" in result["error"] From 7036b4560abb789e8c20c932e53bac26ef944905 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 19 Feb 2026 00:22:46 +0800 Subject: [PATCH 05/11] feat: add stdio tool filtering tests and address code review feedback Add comprehensive edge case tests (34 tests) for stdio tool filtering: - enabled_tools parsing (empty list, null elements, wrong type) - project_hash handling (missing, empty, filename fallback) - heartbeat parsing (timezone formats, TTL boundary) - TTL and watch interval configuration validation - session tracking edge cases - watcher lifecycle (start/stop, error recovery) - signature building and deduplication - status directory handling - concurrent operations safety Code review fixes: - Guard RefreshStdioStatusFile to only run when StdioBridgeHost.IsRunning - Use asyncio.gather for concurrent notification dispatch - Collapse duplicate enabled_raw handling for set/list types - Replace timing-based test with deterministic Event synchronization - Use direct call instead of reflection in C# tests (InternalsVisibleTo) - Use StringAssert.DoesNotContain for clearer negative assertion Co-Authored-By: Claude Opus 4.6 --- MCPForUnity/Editor/Tools/ManageEditor.cs | 3 + .../transport/unity_instance_middleware.py | 26 +- ...ty_instance_middleware_stdio_edge_cases.py | 716 ++++++++++++++++++ ...ance_middleware_tool_list_notifications.py | 15 +- ...ansportCommandDispatcherToolToggleTests.cs | 29 +- 5 files changed, 748 insertions(+), 41 deletions(-) create mode 100644 Server/tests/test_unity_instance_middleware_stdio_edge_cases.py diff --git a/MCPForUnity/Editor/Tools/ManageEditor.cs b/MCPForUnity/Editor/Tools/ManageEditor.cs index d204e4c34..15e370481 100644 --- a/MCPForUnity/Editor/Tools/ManageEditor.cs +++ b/MCPForUnity/Editor/Tools/ManageEditor.cs @@ -299,6 +299,9 @@ private static object ListMcpTools() private static void RefreshStdioStatusFile() { + if (!StdioBridgeHost.IsRunning) + return; + try { StdioBridgeHost.RefreshStatusFile("tool_toggle"); diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index cb3992c69..9cf193720 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -151,20 +151,26 @@ async def _notify_tool_list_changed_to_sessions(self, reason: str) -> None: if not session_items: return - stale_session_ids: list[str] = [] - sent_count = 0 - for session_id, session in session_items: + async def _send_one(session_id: str, session): try: await session.send_tool_list_changed() - sent_count += 1 + return session_id, True except Exception: - stale_session_ids.append(session_id) logger.debug( "Failed sending tools/list_changed to session %s (reason=%s); session will be removed.", session_id, reason, exc_info=True, ) + return session_id, False + + results = await asyncio.gather( + *[_send_one(sid, sess) for sid, sess in session_items], + return_exceptions=False, + ) + + stale_session_ids = [sid for sid, ok in results if not ok] + sent_count = sum(1 for _, ok in results if ok) if stale_session_ids: with self._session_lock: @@ -187,15 +193,7 @@ def _build_stdio_tools_state_signature(self) -> tuple[tuple[str, tuple[str, ...] continue enabled_raw = payload.get("enabled_tools") - if isinstance(enabled_raw, set): - enabled_tools = tuple( - sorted( - tool_name - for tool_name in enabled_raw - if isinstance(tool_name, str) and tool_name - ) - ) - elif isinstance(enabled_raw, list): + if isinstance(enabled_raw, (set, list)): enabled_tools = tuple( sorted( tool_name diff --git a/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py b/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py new file mode 100644 index 000000000..1272dcf75 --- /dev/null +++ b/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py @@ -0,0 +1,716 @@ +""" +Edge case tests for stdio tool toggle functionality. + +Tests cover boundary conditions, error handling, and concurrent scenarios +that may not occur in normal operation but could cause subtle bugs. +""" +import asyncio +import json +import os +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from core.config import config +from transport.unity_instance_middleware import UnityInstanceMiddleware + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + +def _tool_registry() -> list[dict]: + return [ + {"name": "manage_scene", "unity_target": "manage_scene"}, + {"name": "manage_script", "unity_target": "manage_script"}, + {"name": "manage_asset", "unity_target": "manage_asset"}, + {"name": "server_only_tool", "unity_target": None}, + ] + + +def _build_fastmcp_context(active_instance: str | None = None) -> Mock: + state = {} + if active_instance: + state["unity_instance"] = active_instance + + ctx = Mock() + ctx.client_id = "test-client" + ctx.set_state = Mock(side_effect=lambda key, value: state.__setitem__(key, value)) + ctx.get_state = Mock(side_effect=lambda key: state.get(key)) + return ctx + + +def _write_status_file(path, payload: dict) -> None: + path.write_text(json.dumps(payload), encoding="utf-8") + + +async def _filter_tool_names(middleware: UnityInstanceMiddleware, fastmcp_context: Mock) -> list[str]: + middleware_ctx = SimpleNamespace(fastmcp_context=fastmcp_context) + available_tools = [ + SimpleNamespace(name="manage_scene"), + SimpleNamespace(name="manage_script"), + SimpleNamespace(name="manage_asset"), + SimpleNamespace(name="server_only_tool"), + ] + + async def call_next(_ctx): + return available_tools + + with patch.object(middleware, "_inject_unity_instance", new=AsyncMock()): + with patch( + "transport.unity_instance_middleware.get_registered_tools", + return_value=_tool_registry(), + ): + filtered = await middleware.on_list_tools(middleware_ctx, call_next) + + return [tool.name for tool in filtered] + + +def _build_context(session_id: str, session_obj: object, method: str = "tools/list"): + fastmcp_context = SimpleNamespace( + request_context=object(), + session_id=session_id, + session=session_obj, + ) + return SimpleNamespace( + fastmcp_context=fastmcp_context, + method=method, + ) + + +# --------------------------------------------------------------------------- +# Status file content edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_empty_enabled_tools_list_passes_filter(monkeypatch, tmp_path): + """Empty enabled_tools list should result in only server-only tools visible.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": []}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + # Only server-only tool should be visible + assert "server_only_tool" in names + assert "manage_scene" not in names + assert "manage_script" not in names + assert "manage_asset" not in names + + +@pytest.mark.asyncio +async def test_enabled_tools_with_null_elements_filtered(monkeypatch, tmp_path): + """Null elements in enabled_tools should be ignored.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene", None, "manage_script", 123, ""]}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + assert "manage_script" in names + assert "manage_asset" not in names + + +@pytest.mark.asyncio +async def test_enabled_tools_is_object_not_list_skipped(monkeypatch, tmp_path): + """If enabled_tools is an object instead of list, file should be skipped.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": {"manage_scene": True}}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + # Should fall through without filtering + assert "manage_scene" in names + assert "manage_asset" in names + + +@pytest.mark.asyncio +async def test_project_hash_missing_uses_filename_hash(monkeypatch, tmp_path): + """If project_hash is missing, should extract from filename.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-fromfilename.json", + {"enabled_tools": ["manage_scene"]}, # No project_hash + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@fromfilename")) + + assert "manage_scene" in names + assert "manage_asset" not in names + + +@pytest.mark.asyncio +async def test_project_hash_empty_string_uses_filename_hash(monkeypatch, tmp_path): + """If project_hash is empty string, should extract from filename.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-filename123.json", + {"project_hash": "", "enabled_tools": ["manage_scene"]}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@filename123")) + + assert "manage_scene" in names + + +@pytest.mark.asyncio +async def test_no_project_hash_and_no_filename_hash_skipped(monkeypatch, tmp_path): + """File without project_hash and no hash in filename should be skipped.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-.json", # Empty hash suffix + {"enabled_tools": ["manage_scene"]}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@anyhash")) + + # Should fall through without filtering since no valid payload + assert "manage_scene" in names + assert "manage_asset" in names + + +# --------------------------------------------------------------------------- +# Heartbeat and TTL edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_heartbeat_with_z_suffix_parsed_correctly(monkeypatch, tmp_path): + """Heartbeat with 'Z' suffix should be parsed correctly.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "60") + + # Recent heartbeat with Z suffix + heartbeat = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": heartbeat}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + + +@pytest.mark.asyncio +async def test_heartbeat_without_timezone_treated_as_utc(monkeypatch, tmp_path): + """Heartbeat without timezone should be treated as UTC.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "60") + + # Heartbeat without timezone + heartbeat = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S") + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": heartbeat}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + + +@pytest.mark.asyncio +async def test_heartbeat_invalid_format_falls_back_to_mtime(monkeypatch, tmp_path): + """Invalid heartbeat format should fall back to file mtime.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "60") + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": "not-a-date"}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + # File is fresh by mtime, should filter + assert "manage_scene" in names + assert "manage_asset" not in names + + +@pytest.mark.asyncio +async def test_heartbeat_exactly_at_ttl_boundary_is_stale(monkeypatch, tmp_path): + """Heartbeat exactly at TTL boundary should be considered stale (>).""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "10") + + # Heartbeat exactly TTL seconds ago (boundary case: > not >=) + boundary_heartbeat = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() + slightly_fresh = (datetime.now(timezone.utc) - timedelta(seconds=9.9)).isoformat() + + _write_status_file( + tmp_path / "unity-mcp-status-boundary.json", + {"project_hash": "boundary", "enabled_tools": ["manage_scene"], "last_heartbeat": boundary_heartbeat}, + ) + _write_status_file( + tmp_path / "unity-mcp-status-fresh.json", + {"project_hash": "fresh", "enabled_tools": ["manage_asset"], "last_heartbeat": slightly_fresh}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context(None)) + + # Boundary should be stale, fresh should be included + assert "manage_asset" in names + assert "manage_scene" not in names + + +@pytest.mark.asyncio +async def test_ttl_zero_uses_default(monkeypatch, tmp_path): + """TTL of 0 should fall back to default.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "0") + + recent_heartbeat = datetime.now(timezone.utc).isoformat() + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": recent_heartbeat}, + ) + + middleware = UnityInstanceMiddleware() + # TTL=0 should use default of 15, so recent file should be fresh + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + + +@pytest.mark.asyncio +async def test_ttl_negative_uses_default(monkeypatch, tmp_path): + """Negative TTL should fall back to default.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "-5") + + recent_heartbeat = datetime.now(timezone.utc).isoformat() + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": recent_heartbeat}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + + +# --------------------------------------------------------------------------- +# Watch interval edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_watch_interval_below_minimum_clamped(monkeypatch): + """Watch interval below 0.2 should be clamped to 0.2.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "0.1") + + middleware = UnityInstanceMiddleware() + assert middleware._get_stdio_tools_watch_interval_seconds() == 0.2 + + +@pytest.mark.asyncio +async def test_watch_interval_negative_clamped_to_minimum(monkeypatch): + """Negative watch interval should be clamped to minimum 0.2.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "-1") + + middleware = UnityInstanceMiddleware() + # Negative values are parsed then clamped to minimum + assert middleware._get_stdio_tools_watch_interval_seconds() == 0.2 + + +@pytest.mark.asyncio +async def test_watch_interval_invalid_string_uses_default(monkeypatch): + """Invalid watch interval string should use default.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "not-a-number") + + middleware = UnityInstanceMiddleware() + assert middleware._get_stdio_tools_watch_interval_seconds() == 1.0 + + +# --------------------------------------------------------------------------- +# Session tracking edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_track_session_with_none_context_returns_false(monkeypatch): + """None context should return False without error.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + result = middleware._track_session_from_context(None) + assert result is False + + +@pytest.mark.asyncio +async def test_track_session_with_none_request_context_returns_false(monkeypatch): + """Context with None request_context should return False.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + ctx = SimpleNamespace(request_context=None) + result = middleware._track_session_from_context(ctx) + assert result is False + + +@pytest.mark.asyncio +async def test_track_session_with_empty_session_id_returns_false(monkeypatch): + """Empty session_id should return False.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + ctx = SimpleNamespace( + request_context=object(), + session_id="", + session=object(), + ) + result = middleware._track_session_from_context(ctx) + assert result is False + + +@pytest.mark.asyncio +async def test_track_session_with_none_session_id_returns_false(monkeypatch): + """None session_id should return False.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + ctx = SimpleNamespace( + request_context=object(), + session_id=None, + session=object(), + ) + result = middleware._track_session_from_context(ctx) + assert result is False + + +@pytest.mark.asyncio +async def test_track_same_session_twice_returns_false_second_time(monkeypatch): + """Tracking the same session twice should return False on second call.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + session = object() + ctx = SimpleNamespace( + request_context=object(), + session_id="session-1", + session=session, + ) + + result1 = middleware._track_session_from_context(ctx) + result2 = middleware._track_session_from_context(ctx) + + assert result1 is True + assert result2 is False + + +@pytest.mark.asyncio +async def test_notify_with_no_tracked_sessions_returns_early(monkeypatch): + """Notify with no tracked sessions should return immediately without error.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + # Should not raise + await middleware._notify_tool_list_changed_to_sessions("test") + + +@pytest.mark.asyncio +async def test_all_sessions_fail_during_notify_clears_all(monkeypatch): + """If all sessions fail during notify, all should be cleared.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + async def _raise(): + raise RuntimeError("connection lost") + + failing_session = SimpleNamespace(send_tool_list_changed=AsyncMock(side_effect=_raise)) + middleware._tracked_sessions["fail1"] = failing_session + middleware._tracked_sessions["fail2"] = failing_session + + await middleware._notify_tool_list_changed_to_sessions("test") + + assert len(middleware._tracked_sessions) == 0 + + +# --------------------------------------------------------------------------- +# Watcher lifecycle edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_start_watcher_twice_only_creates_one_task(monkeypatch): + """Starting watcher twice should only create one task.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "1.0") + middleware = UnityInstanceMiddleware() + + await middleware.start_stdio_tools_watcher() + first_task = middleware._stdio_tools_watch_task + + await middleware.start_stdio_tools_watcher() + second_task = middleware._stdio_tools_watch_task + + assert first_task is second_task + assert not first_task.done() + + await middleware.stop_stdio_tools_watcher() + + +@pytest.mark.asyncio +async def test_stop_watcher_when_not_started_is_safe(monkeypatch): + """Stopping watcher when not started should not raise.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + # Should not raise + await middleware.stop_stdio_tools_watcher() + assert middleware._stdio_tools_watch_task is None + + +@pytest.mark.asyncio +async def test_stop_watcher_clears_tracked_sessions(monkeypatch): + """Stopping watcher should clear all tracked sessions.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + middleware._tracked_sessions["s1"] = object() + middleware._tracked_sessions["s2"] = object() + + await middleware.stop_stdio_tools_watcher() + + assert len(middleware._tracked_sessions) == 0 + + +@pytest.mark.asyncio +async def test_watcher_continues_after_iteration_error(monkeypatch): + """Watcher should continue after an iteration error.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "0.2") + middleware = UnityInstanceMiddleware() + middleware._notify_tool_list_changed_to_sessions = AsyncMock(return_value=None) + + call_count = {"value": 0} + error_on_call = 1 + + def _fake_signature(): + call_count["value"] += 1 + if call_count["value"] == error_on_call: + # First call is initial state, second call is first iteration + pass + if call_count["value"] == 2: + raise RuntimeError("simulated error") + if call_count["value"] >= 4: + # Stop after a few successful iterations + raise asyncio.CancelledError() + return (("hash1", ("tool1",)),) + + monkeypatch.setattr(middleware, "_build_stdio_tools_state_signature", _fake_signature) + + await middleware.start_stdio_tools_watcher() + try: + await asyncio.sleep(0.5) + except asyncio.CancelledError: + pass + finally: + await middleware.stop_stdio_tools_watcher() + + # Should have continued despite error on call 2 + assert call_count["value"] >= 3 + + +# --------------------------------------------------------------------------- +# Active instance resolution edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_active_instance_without_at_symbol_uses_full_string(monkeypatch, tmp_path): + """Active instance without @ should be treated as bare hash.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"]}, + ) + + middleware = UnityInstanceMiddleware() + # Bare hash without @ + names = await _filter_tool_names(middleware, _build_fastmcp_context("abc123")) + + assert "manage_scene" in names + assert "manage_asset" not in names + + +@pytest.mark.asyncio +async def test_active_instance_is_none_passes_all_tools(monkeypatch, tmp_path): + """None active instance should pass all tools when no status files.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context(None)) + + # All tools should pass + assert "manage_scene" in names + assert "manage_script" in names + assert "manage_asset" in names + assert "server_only_tool" in names + + +# --------------------------------------------------------------------------- +# Signature building edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_signature_empty_payloads_returns_empty(monkeypatch, tmp_path): + """Empty payloads should return empty signature.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + middleware = UnityInstanceMiddleware() + signature = middleware._build_stdio_tools_state_signature() + + assert signature == () + + +@pytest.mark.asyncio +async def test_signature_deduplicates_same_project_hash(monkeypatch, tmp_path): + """Multiple files with same project_hash should be deduplicated.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc111.json", + {"project_hash": "samehash", "enabled_tools": ["manage_scene"]}, + ) + _write_status_file( + tmp_path / "unity-mcp-status-abc222.json", + {"project_hash": "samehash", "enabled_tools": ["manage_asset"]}, + ) + + middleware = UnityInstanceMiddleware() + signature = middleware._build_stdio_tools_state_signature() + + # Only one entry for samehash (first by filename sort order) + assert len(signature) == 1 + assert signature[0][0] == "samehash" + + +@pytest.mark.asyncio +async def test_signature_enabled_tools_as_set_converted_to_tuple(monkeypatch, tmp_path): + """enabled_tools as set should be converted to sorted tuple.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["z_tool", "a_tool", "m_tool"]}, + ) + + middleware = UnityInstanceMiddleware() + signature = middleware._build_stdio_tools_state_signature() + + # Should be sorted + assert signature[0][1] == ("a_tool", "m_tool", "z_tool") + + +# --------------------------------------------------------------------------- +# Status directory edge cases +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_status_dir_does_not_exist_returns_empty(monkeypatch, tmp_path): + """Non-existent status directory should return empty payloads.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + non_existent = tmp_path / "does-not-exist" + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(non_existent)) + + middleware = UnityInstanceMiddleware() + payloads = middleware._list_stdio_status_payloads() + + assert payloads == [] + + +@pytest.mark.asyncio +async def test_status_dir_custom_path_expanded(monkeypatch, tmp_path): + """Custom status dir path with ~ should be expanded.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + # Use tmp_path as a substitute for home directory expansion + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"]}, + ) + + middleware = UnityInstanceMiddleware() + payloads = middleware._list_stdio_status_payloads() + + assert len(payloads) == 1 + assert payloads[0]["project_hash"] == "abc123" + + +# --------------------------------------------------------------------------- +# Concurrent operation tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_concurrent_start_stop_watcher_safe(monkeypatch): + """Concurrent start/stop operations should be safe.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "0.5") + middleware = UnityInstanceMiddleware() + + # Start multiple times concurrently + await asyncio.gather( + middleware.start_stdio_tools_watcher(), + middleware.start_stdio_tools_watcher(), + middleware.start_stdio_tools_watcher(), + ) + + assert middleware._stdio_tools_watch_task is not None + assert not middleware._stdio_tools_watch_task.done() + + # Stop multiple times concurrently + await asyncio.gather( + middleware.stop_stdio_tools_watcher(), + middleware.stop_stdio_tools_watcher(), + ) + + assert middleware._stdio_tools_watch_task is None diff --git a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py index d20b792fe..33236d62a 100644 --- a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py +++ b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py @@ -81,7 +81,15 @@ async def test_stdio_tools_watcher_notifies_on_signature_change(monkeypatch): monkeypatch.setattr(config, "transport_mode", "stdio") monkeypatch.setenv("UNITY_MCP_STDIO_TOOLS_WATCH_INTERVAL_SECONDS", "0.2") middleware = UnityInstanceMiddleware() - middleware._notify_tool_list_changed_to_sessions = AsyncMock(return_value=None) + + # Use an Event for deterministic synchronization instead of sleep + notification_event = asyncio.Event() + + async def _fake_notify(reason: str): + notification_event.set() + return None + + middleware._notify_tool_list_changed_to_sessions = AsyncMock(side_effect=_fake_notify) signatures = [ (("hash1", ("manage_scene",)),), @@ -101,7 +109,10 @@ def _fake_signature(): await middleware.start_stdio_tools_watcher() try: - await asyncio.sleep(0.35) + # Wait for the notification with a timeout instead of sleeping + await asyncio.wait_for(notification_event.wait(), timeout=1.0) + except asyncio.TimeoutError: + pass finally: await middleware.stop_stdio_tools_watcher() diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs index 6d771e89c..d296ecac6 100644 --- a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/TransportCommandDispatcherToolToggleTests.cs @@ -1,8 +1,6 @@ -using System; -using System.Reflection; using System.Threading; -using System.Threading.Tasks; using MCPForUnity.Editor.Constants; +using MCPForUnity.Editor.Services.Transport; using Newtonsoft.Json.Linq; using NUnit.Framework; using UnityEditor; @@ -51,7 +49,7 @@ public void ExecuteCommandJsonAsync_WhenToolDisabled_ReturnsDisabledError() }, }.ToString(); - string responseJson = ExecuteCommandJson(payload); + string responseJson = TransportCommandDispatcher.ExecuteCommandJsonAsync(payload, CancellationToken.None).GetAwaiter().GetResult(); var response = JObject.Parse(responseJson); string error = response["error"]?.ToString() ?? string.Empty; @@ -73,30 +71,11 @@ public void ExecuteCommandJsonAsync_WhenToolEnabled_DoesNotReturnDisabledError() }, }.ToString(); - string responseJson = ExecuteCommandJson(payload); + string responseJson = TransportCommandDispatcher.ExecuteCommandJsonAsync(payload, CancellationToken.None).GetAwaiter().GetResult(); var response = JObject.Parse(responseJson); string error = response["error"]?.ToString() ?? string.Empty; - Assert.Less(error.IndexOf("disabled in the Unity Editor", StringComparison.OrdinalIgnoreCase), 0); - } - - private static string ExecuteCommandJson(string commandJson) - { - Type dispatcherType = Type.GetType( - "MCPForUnity.Editor.Services.Transport.TransportCommandDispatcher, MCPForUnity.Editor"); - Assert.IsNotNull(dispatcherType, "Failed to resolve TransportCommandDispatcher type."); - - MethodInfo executeMethod = dispatcherType.GetMethod( - "ExecuteCommandJsonAsync", - BindingFlags.Public | BindingFlags.Static); - Assert.IsNotNull(executeMethod, "Failed to resolve ExecuteCommandJsonAsync."); - - var task = executeMethod.Invoke( - null, - new object[] { commandJson, CancellationToken.None }) as Task; - Assert.IsNotNull(task, "ExecuteCommandJsonAsync did not return Task."); - - return task.GetAwaiter().GetResult(); + StringAssert.DoesNotContain("disabled in the Unity Editor", error); } } } From 5b475a1de74d55317ae036423f1fe26f3ce22371 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 19 Feb 2026 01:03:12 +0800 Subject: [PATCH 06/11] test(middleware): add status file schema compatibility tests Add tests for handling future/partial schema without enabled_tools field and unknown extra fields in status files to ensure backward/forward compatibility. --- ...ty_instance_middleware_stdio_edge_cases.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py b/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py index 1272dcf75..e55cfa5bc 100644 --- a/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py +++ b/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py @@ -143,6 +143,57 @@ async def test_enabled_tools_is_object_not_list_skipped(monkeypatch, tmp_path): assert "manage_asset" in names +@pytest.mark.asyncio +async def test_status_file_with_future_state_field_but_no_enabled_tools_skipped(monkeypatch, tmp_path): + """Future/partial schema without enabled_tools should safely skip filtering.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + { + "project_hash": "abc123", + "enabled_tools_state": "unavailable", + "last_heartbeat": datetime.now(timezone.utc).isoformat(), + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + # Should fall through without filtering for backward/forward compatibility. + assert "manage_scene" in names + assert "manage_script" in names + assert "manage_asset" in names + assert "server_only_tool" in names + + +@pytest.mark.asyncio +async def test_status_file_with_unknown_extra_fields_does_not_break_filtering(monkeypatch, tmp_path): + """Unknown extra fields should be ignored and normal filtering should still apply.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + { + "project_hash": "abc123", + "enabled_tools": ["manage_scene"], + "enabled_tools_state": "ok", + "schema_version": 2, + "some_future_field": {"nested": True}, + }, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + assert "manage_script" not in names + assert "manage_asset" not in names + assert "server_only_tool" in names + + @pytest.mark.asyncio async def test_project_hash_missing_uses_filename_hash(monkeypatch, tmp_path): """If project_hash is missing, should extract from filename.""" From 01920d26dac1af507787d6b988d7c23d9939286e Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 19 Feb 2026 02:31:05 +0800 Subject: [PATCH 07/11] feat(editor): add HTTP tool reregistration on tool toggle When set_mcp_tool_enabled is called, now reregisters tools with HTTP transport (PluginHub) in addition to refreshing stdio status file. This ensures MCP clients receive tools/list_changed notifications for both transport modes. Co-Authored-By: Claude Opus 4.6 --- MCPForUnity/Editor/Tools/ManageEditor.cs | 32 +++ .../Tools/ManageEditorToolToggleTests.cs | 185 ++++++++++++++++++ 2 files changed, 217 insertions(+) diff --git a/MCPForUnity/Editor/Tools/ManageEditor.cs b/MCPForUnity/Editor/Tools/ManageEditor.cs index 15e370481..3fb9be9cf 100644 --- a/MCPForUnity/Editor/Tools/ManageEditor.cs +++ b/MCPForUnity/Editor/Tools/ManageEditor.cs @@ -1,6 +1,8 @@ using System; +using System.Threading.Tasks; using MCPForUnity.Editor.Helpers; using MCPForUnity.Editor.Services; +using MCPForUnity.Editor.Services.Transport; using MCPForUnity.Editor.Services.Transport.Transports; using Newtonsoft.Json.Linq; using UnityEditor; @@ -232,6 +234,7 @@ private static object SetMcpToolEnabled(string toolName, bool enabled) MCPServiceLocator.ToolDiscovery.SetToolEnabled(metadata.Name, enabled); RefreshStdioStatusFile(); + RefreshHttpToolRegistration(); return new SuccessResponse( $"Tool '{metadata.Name}' {(enabled ? "enabled" : "disabled")} successfully.", @@ -312,6 +315,35 @@ private static void RefreshStdioStatusFile() } } + private static void RefreshHttpToolRegistration() + { + try + { + var transportManager = MCPServiceLocator.TransportManager; + var client = transportManager.GetClient(TransportMode.Http); + if (client == null || !client.IsConnected) + { + return; + } + + _ = Task.Run(async () => + { + try + { + await client.ReregisterToolsAsync().ConfigureAwait(false); + } + catch (Exception e) + { + McpLog.Warn($"Failed to reregister HTTP tools after tool toggle: {e.Message}"); + } + }); + } + catch (Exception e) + { + McpLog.Warn($"Failed to schedule HTTP tool reregistration after tool toggle: {e.Message}"); + } + } + // --- Tag Management Methods --- private static object AddTag(string tagName) diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs index 44ce789a1..dcaf0a59d 100644 --- a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs @@ -1,7 +1,11 @@ using System; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using MCPForUnity.Editor.Constants; using MCPForUnity.Editor.Helpers; +using MCPForUnity.Editor.Services; +using MCPForUnity.Editor.Services.Transport; using MCPForUnity.Editor.Tools; using Newtonsoft.Json.Linq; using NUnit.Framework; @@ -15,6 +19,7 @@ public class ManageEditorToolToggleTests private string _targetToolPrefKey; private bool _hadTargetToolPref; private bool _previousTargetToolEnabled; + private TransportManager _originalTransportManager; [SetUp] public void SetUp() @@ -22,11 +27,17 @@ public void SetUp() _targetToolPrefKey = EditorPrefKeys.ToolEnabledPrefix + TargetTool; _hadTargetToolPref = EditorPrefs.HasKey(_targetToolPrefKey); _previousTargetToolEnabled = EditorPrefs.GetBool(_targetToolPrefKey, true); + _originalTransportManager = MCPServiceLocator.TransportManager; } [TearDown] public void TearDown() { + if (_originalTransportManager != null) + { + MCPServiceLocator.Register(_originalTransportManager); + } + if (_hadTargetToolPref) { EditorPrefs.SetBool(_targetToolPrefKey, _previousTargetToolEnabled); @@ -113,5 +124,179 @@ public void HandleCommand_ListMcpTools_ReturnsToolStateShape() Assert.IsNotNull(sceneTool["autoRegister"]); Assert.IsNotNull(sceneTool["isBuiltIn"]); } + + [Test] + public void HandleCommand_SetMcpToolEnabled_ReregistersHttpTools_WhenHttpClientConnected() + { + var httpClient = new FakeTransportClient(isConnected: true, "http"); + var stdioClient = new FakeTransportClient(isConnected: false, "stdio"); + var transportManager = new TransportManager(); + transportManager.Configure( + () => httpClient, + () => stdioClient); + + bool started = transportManager.StartAsync(TransportMode.Http).GetAwaiter().GetResult(); + Assert.IsTrue(started); + + MCPServiceLocator.Register(transportManager); + + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = TargetTool, + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + Assert.IsTrue(httpClient.WaitForReregister(TimeSpan.FromSeconds(1)), "Expected HTTP tools to be reregistered."); + Assert.AreEqual(1, httpClient.ReregisterCalls); + } + + [Test] + public void HandleCommand_SetMcpToolEnabled_DoesNotFail_WhenHttpClientMissing() + { + var transportManager = new TransportManager(); + transportManager.Configure( + () => new FakeTransportClient(isConnected: true, "http"), + () => new FakeTransportClient(isConnected: false, "stdio")); + + MCPServiceLocator.Register(transportManager); + + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = TargetTool, + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + } + + [Test] + public void HandleCommand_SetMcpToolEnabled_DoesNotReregister_WhenHttpClientDisconnected() + { + var httpClient = new FakeTransportClient(isConnected: false, "http"); + var stdioClient = new FakeTransportClient(isConnected: false, "stdio"); + var transportManager = new TransportManager(); + transportManager.Configure( + () => httpClient, + () => stdioClient); + + bool started = transportManager.StartAsync(TransportMode.Http).GetAwaiter().GetResult(); + Assert.IsFalse(started); + + MCPServiceLocator.Register(transportManager); + + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = TargetTool, + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + Assert.IsFalse(httpClient.WaitForReregister(TimeSpan.FromMilliseconds(100))); + Assert.AreEqual(0, httpClient.ReregisterCalls); + } + + [Test] + public void HandleCommand_SetMcpToolEnabled_SwallowsReregisterErrors() + { + var httpClient = new FakeTransportClient(isConnected: true, "http", throwOnReregister: true); + var stdioClient = new FakeTransportClient(isConnected: false, "stdio"); + var transportManager = new TransportManager(); + transportManager.Configure( + () => httpClient, + () => stdioClient); + + bool started = transportManager.StartAsync(TransportMode.Http).GetAwaiter().GetResult(); + Assert.IsTrue(started); + + MCPServiceLocator.Register(transportManager); + + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = TargetTool, + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(true, response["success"]?.Value()); + Assert.IsTrue(httpClient.WaitForReregister(TimeSpan.FromSeconds(1)), "Expected HTTP reregister to be attempted."); + Assert.AreEqual(1, httpClient.ReregisterCalls); + } + + [Test] + public void HandleCommand_SetMcpToolEnabled_DoesNotReregister_WhenValidationFails() + { + var httpClient = new FakeTransportClient(isConnected: true, "http"); + var stdioClient = new FakeTransportClient(isConnected: false, "stdio"); + var transportManager = new TransportManager(); + transportManager.Configure( + () => httpClient, + () => stdioClient); + + bool started = transportManager.StartAsync(TransportMode.Http).GetAwaiter().GetResult(); + Assert.IsTrue(started); + + MCPServiceLocator.Register(transportManager); + + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = "manage_editor", + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(false, response["success"]?.Value()); + Assert.IsFalse(httpClient.WaitForReregister(TimeSpan.FromMilliseconds(100))); + Assert.AreEqual(0, httpClient.ReregisterCalls); + } + + private sealed class FakeTransportClient : IMcpTransportClient + { + private readonly bool _isConnected; + private readonly bool _throwOnReregister; + private readonly ManualResetEventSlim _reregisterSignal = new ManualResetEventSlim(false); + private int _reregisterCalls; + + public FakeTransportClient(bool isConnected, string name, bool throwOnReregister = false) + { + _isConnected = isConnected; + _throwOnReregister = throwOnReregister; + TransportName = name; + } + + public bool IsConnected => _isConnected; + public string TransportName { get; } + public TransportState State => _isConnected + ? TransportState.Connected(TransportName) + : TransportState.Disconnected(TransportName); + public int ReregisterCalls => _reregisterCalls; + + public Task StartAsync() => Task.FromResult(_isConnected); + + public Task StopAsync() => Task.CompletedTask; + + public Task VerifyAsync() => Task.FromResult(_isConnected); + + public Task ReregisterToolsAsync() + { + Interlocked.Increment(ref _reregisterCalls); + _reregisterSignal.Set(); + if (_throwOnReregister) + { + throw new InvalidOperationException("simulated reregister failure"); + } + return Task.CompletedTask; + } + + public bool WaitForReregister(TimeSpan timeout) => _reregisterSignal.Wait(timeout); + } } } From 8fe196aeb4d445d22918cc21528f5885bb0c48b8 Mon Sep 17 00:00:00 2001 From: whatevertogo <1879483647@qq.com> Date: Thu, 19 Feb 2026 17:32:03 +0800 Subject: [PATCH 08/11] Refactor instance resolution and simplify status refresh handling --- MCPForUnity/Editor/Tools/ManageEditor.cs | 9 +- .../src/services/tools/set_active_instance.py | 120 ++---------------- .../transport/unity_instance_middleware.py | 14 +- 3 files changed, 12 insertions(+), 131 deletions(-) diff --git a/MCPForUnity/Editor/Tools/ManageEditor.cs b/MCPForUnity/Editor/Tools/ManageEditor.cs index d204e4c34..dfb3a2999 100644 --- a/MCPForUnity/Editor/Tools/ManageEditor.cs +++ b/MCPForUnity/Editor/Tools/ManageEditor.cs @@ -299,14 +299,7 @@ private static object ListMcpTools() private static void RefreshStdioStatusFile() { - try - { - StdioBridgeHost.RefreshStatusFile("tool_toggle"); - } - catch (Exception e) - { - McpLog.Warn($"Failed to refresh stdio status file after tool toggle: {e.Message}"); - } + StdioBridgeHost.RefreshStatusFile("tool_toggle"); } // --- Tag Management Methods --- diff --git a/Server/src/services/tools/set_active_instance.py b/Server/src/services/tools/set_active_instance.py index ecdfb3a5c..640f84ca8 100644 --- a/Server/src/services/tools/set_active_instance.py +++ b/Server/src/services/tools/set_active_instance.py @@ -1,14 +1,10 @@ from typing import Annotated, Any -from types import SimpleNamespace from fastmcp import Context from mcp.types import ToolAnnotations from services.registry import mcp_for_unity_tool -from transport.legacy.unity_connection import get_unity_connection_pool from transport.unity_instance_middleware import get_unity_instance_middleware -from transport.plugin_hub import PluginHub -from core.config import config @mcp_for_unity_tool( @@ -22,72 +18,6 @@ async def set_active_instance( ctx: Context, instance: Annotated[str, "Target instance (Name@hash, hash prefix, or port number in stdio mode)"] ) -> dict[str, Any]: - transport = (config.transport_mode or "stdio").lower() - - # Port number shorthand (stdio only) — resolve to Name@hash via pool discovery - value = (instance or "").strip() - if value.isdigit(): - if transport == "http": - return { - "success": False, - "error": f"Port-based targeting ('{value}') is not supported in HTTP transport mode. " - "Use Name@hash or a hash prefix. Read mcpforunity://instances for available instances." - } - port_int = int(value) - pool = get_unity_connection_pool() - instances = pool.discover_all_instances(force_refresh=True) - match = next((inst for inst in instances if getattr(inst, "port", None) == port_int), None) - if match is None: - available = ", ".join( - f"{inst.id} (port {getattr(inst, 'port', '?')})" for inst in instances - ) or "none" - return { - "success": False, - "error": f"No Unity instance found on port {value}. Available: {available}." - } - resolved_id = match.id - middleware = get_unity_instance_middleware() - middleware.set_active_instance(ctx, resolved_id) - return { - "success": True, - "message": f"Active instance set to {resolved_id}", - "data": { - "instance": resolved_id, - "session_key": middleware.get_session_key(ctx), - }, - } - - # Discover running instances based on transport - if transport == "http": - # In remote-hosted mode, filter sessions by user_id - user_id = ctx.get_state( - "user_id") if config.http_remote_hosted else None - sessions_data = await PluginHub.get_sessions(user_id=user_id) - sessions = sessions_data.sessions - instances = [] - for session_id, session in sessions.items(): - project = session.project or "Unknown" - hash_value = session.hash - if not hash_value: - continue - inst_id = f"{project}@{hash_value}" - instances.append(SimpleNamespace( - id=inst_id, - hash=hash_value, - name=project, - session_id=session_id, - )) - else: - pool = get_unity_connection_pool() - instances = pool.discover_all_instances(force_refresh=True) - - if not instances: - return { - "success": False, - "error": "No Unity instances are currently connected. Start Unity and press 'Start Session'." - } - ids = {inst.id: inst for inst in instances if getattr(inst, "id", None)} - value = (instance or "").strip() if not value: return { @@ -95,60 +25,26 @@ async def set_active_instance( "error": "Instance identifier is required. " "Use mcpforunity://instances to copy a Name@hash or provide a hash prefix." } - resolved = None - if "@" in value: - resolved = ids.get(value) - if resolved is None: - return { - "success": False, - "error": f"Instance '{value}' not found. " - "Use mcpforunity://instances to copy an exact Name@hash." - } - else: - lookup = value.lower() - matches = [] - for inst in instances: - if not getattr(inst, "id", None): - continue - inst_hash = getattr(inst, "hash", "") - if inst_hash and inst_hash.lower().startswith(lookup): - matches.append(inst) - if not matches: - return { - "success": False, - "error": f"Instance hash '{value}' does not match any running Unity editors. " - "Use mcpforunity://instances to confirm the available hashes." - } - if len(matches) > 1: - matching_ids = ", ".join( - inst.id for inst in matches if getattr(inst, "id", None) - ) or "multiple instances" - return { - "success": False, - "error": f"Instance hash '{value}' is ambiguous ({matching_ids}). " - "Provide the full Name@hash from mcpforunity://instances." - } - resolved = matches[0] - - if resolved is None: - # Should be unreachable due to logic above, but satisfies static analysis + middleware = get_unity_instance_middleware() + try: + resolved_id = await middleware._resolve_instance_value(value, ctx) + except ValueError as exc: return { "success": False, - "error": "Internal error: Instance resolution failed." + "error": str(exc), } # Store selection in middleware (session-scoped) - middleware = get_unity_instance_middleware() # We use middleware.set_active_instance to persist the selection. # The session key is an internal detail but useful for debugging response. - middleware.set_active_instance(ctx, resolved.id) + middleware.set_active_instance(ctx, resolved_id) session_key = middleware.get_session_key(ctx) return { "success": True, - "message": f"Active instance set to {resolved.id}", + "message": f"Active instance set to {resolved_id}", "data": { - "instance": resolved.id, + "instance": resolved_id, "session_key": session_key, }, } diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index cb3992c69..c8e1fe9fb 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -5,7 +5,7 @@ into the request-scoped state, allowing tools to access it via ctx.get_state("unity_instance"). """ import asyncio -from threading import RLock +from threading import Lock, RLock from datetime import datetime, timezone import json import logging @@ -64,7 +64,7 @@ def __init__(self): self._active_by_key: dict[str, str] = {} self._lock = RLock() self._metadata_lock = RLock() - self._session_lock = RLock() + self._session_lock = Lock() self._unity_managed_tool_names: set[str] = set() self._tool_alias_to_unity_target: dict[str, str] = {} self._server_only_tool_names: set[str] = set() @@ -187,15 +187,7 @@ def _build_stdio_tools_state_signature(self) -> tuple[tuple[str, tuple[str, ...] continue enabled_raw = payload.get("enabled_tools") - if isinstance(enabled_raw, set): - enabled_tools = tuple( - sorted( - tool_name - for tool_name in enabled_raw - if isinstance(tool_name, str) and tool_name - ) - ) - elif isinstance(enabled_raw, list): + if isinstance(enabled_raw, (set, list)): enabled_tools = tuple( sorted( tool_name From 5feeb3375a5c9e377c10b6af3fe1dcd365695a3e Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 19 Feb 2026 17:44:59 +0800 Subject: [PATCH 09/11] fix: harden stdio tool refresh and instance selection middleware --- MCPForUnity/Editor/Tools/ManageEditor.cs | 21 +-- .../transport/unity_instance_middleware.py | 175 +++++++++++------- ...ance_middleware_tool_list_notifications.py | 19 ++ 3 files changed, 138 insertions(+), 77 deletions(-) diff --git a/MCPForUnity/Editor/Tools/ManageEditor.cs b/MCPForUnity/Editor/Tools/ManageEditor.cs index 3fb9be9cf..03a9b883f 100644 --- a/MCPForUnity/Editor/Tools/ManageEditor.cs +++ b/MCPForUnity/Editor/Tools/ManageEditor.cs @@ -221,17 +221,19 @@ private static object SetMcpToolEnabled(string toolName, bool enabled) return new ErrorResponse("Tool name cannot be empty."); } - if (string.Equals(toolName, "manage_editor", StringComparison.OrdinalIgnoreCase) && !enabled) - { - return new ErrorResponse("Tool 'manage_editor' cannot be disabled."); - } - var metadata = MCPServiceLocator.ToolDiscovery.GetToolMetadata(toolName); if (metadata == null) { return new ErrorResponse($"Unknown tool '{toolName}'."); } + if (!enabled && ( + string.Equals(metadata.Name, "manage_editor", StringComparison.OrdinalIgnoreCase) || + string.Equals(metadata.Name, "set_active_instance", StringComparison.OrdinalIgnoreCase))) + { + return new ErrorResponse($"Tool '{metadata.Name}' cannot be disabled."); + } + MCPServiceLocator.ToolDiscovery.SetToolEnabled(metadata.Name, enabled); RefreshStdioStatusFile(); RefreshHttpToolRegistration(); @@ -305,14 +307,7 @@ private static void RefreshStdioStatusFile() if (!StdioBridgeHost.IsRunning) return; - try - { - StdioBridgeHost.RefreshStatusFile("tool_toggle"); - } - catch (Exception e) - { - McpLog.Warn($"Failed to refresh stdio status file after tool toggle: {e.Message}"); - } + StdioBridgeHost.RefreshStatusFile("tool_toggle"); } private static void RefreshHttpToolRegistration() diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index 3c5bd5d49..b52852cae 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -73,6 +73,7 @@ def __init__(self): self._tool_visibility_refresh_interval_seconds = 0.5 self._has_logged_empty_registry_warning = False self._tracked_sessions: dict[str, object] = {} + self._initial_refresh_notified_sessions: set[str] = set() self._stdio_tools_watch_task: asyncio.Task | None = None self._last_stdio_tools_state_signature: tuple[tuple[str, tuple[str, ...]], ...] | None = None @@ -144,6 +145,31 @@ def _track_session_from_context(self, fastmcp_context) -> bool: return True + @staticmethod + def _try_get_session_id_from_context(fastmcp_context) -> str | None: + if fastmcp_context is None or fastmcp_context.request_context is None: + return None + + try: + session_id = fastmcp_context.session_id + except RuntimeError: + return None + + if not isinstance(session_id, str) or not session_id: + return None + + return session_id + + def _should_send_initial_tool_list_refresh(self, session_id: str | None) -> bool: + if session_id is None: + return True + + with self._session_lock: + if session_id in self._initial_refresh_notified_sessions: + return False + self._initial_refresh_notified_sessions.add(session_id) + return True + async def _notify_tool_list_changed_to_sessions(self, reason: str) -> None: with self._session_lock: session_items = list(self._tracked_sessions.items()) @@ -176,6 +202,7 @@ async def _send_one(session_id: str, session): with self._session_lock: for session_id in stale_session_ids: self._tracked_sessions.pop(session_id, None) + self._initial_refresh_notified_sessions.discard(session_id) if sent_count: logger.debug( @@ -268,6 +295,7 @@ async def stop_stdio_tools_watcher(self) -> None: with self._session_lock: self._tracked_sessions.clear() + self._initial_refresh_notified_sessions.clear() # ========================================================================= # Per-call instance routing (PR #772) @@ -591,7 +619,10 @@ async def on_call_tool(self, context: MiddlewareContext, call_next): async def on_message(self, context: MiddlewareContext, call_next): if self._is_stdio_transport(): is_new_session = self._track_session_from_context(context.fastmcp_context) - if is_new_session: + session_id = self._try_get_session_id_from_context(context.fastmcp_context) + # A new stdio session needs one immediate tools/list refresh so the + # client does not wait for background polling. + if is_new_session and self._should_send_initial_tool_list_refresh(session_id): await self._notify_tool_list_changed_to_sessions("session_registered") return await call_next(context) @@ -600,7 +631,11 @@ async def on_notification(self, context: MiddlewareContext, call_next): if self._is_stdio_transport(): self._track_session_from_context(context.fastmcp_context) if context.method == "notifications/initialized": - await self._notify_tool_list_changed_to_sessions("client_initialized") + session_id = self._try_get_session_id_from_context(context.fastmcp_context) + # Some clients send a normal message before notifications/initialized. + # Deduplicate the initial push to avoid back-to-back list_changed noise. + if self._should_send_initial_tool_list_refresh(session_id): + await self._notify_tool_list_changed_to_sessions("client_initialized") return await call_next(context) @@ -762,6 +797,74 @@ def _resolve_enabled_tool_names_for_stdio_context(self, active_instance: str | N # project_hash; after de-duplication this leaves exactly one project entry. return next(iter(enabled_by_project_hash.values())) + def _parse_single_status_file( + self, + status_file: Path, + now_utc: datetime, + status_ttl_seconds: float, + ) -> dict[str, object] | None: + file_hash = self._extract_project_hash_from_filename(status_file) + try: + with status_file.open("r", encoding="utf-8") as handle: + raw_payload = json.load(handle) + except (OSError, ValueError) as exc: + logger.debug( + "Failed to parse stdio status file %s: %s", + status_file, + exc, + exc_info=True, + ) + return None + + if not isinstance(raw_payload, dict): + logger.debug("Skipping stdio status file %s with non-object payload.", status_file) + return None + + enabled_tools_raw = raw_payload.get("enabled_tools") + if not isinstance(enabled_tools_raw, list): + # Missing enabled_tools means the status format is too old for safe filtering. + logger.debug("Skipping stdio status file %s without enabled_tools field.", status_file) + return None + + enabled_tools = { + tool_name + for tool_name in enabled_tools_raw + if isinstance(tool_name, str) and tool_name + } + + freshness = self._parse_heartbeat_datetime(raw_payload.get("last_heartbeat")) + if freshness is None: + try: + freshness = datetime.fromtimestamp(status_file.stat().st_mtime, tz=timezone.utc) + except OSError: + logger.debug( + "Failed to read mtime for stdio status file %s; skipping for safety.", + status_file, + exc_info=True, + ) + return None + + if (now_utc - freshness).total_seconds() > status_ttl_seconds: + logger.debug( + "Skipping stale stdio status file %s (age exceeds %ss).", + status_file, + status_ttl_seconds, + ) + return None + + project_hash = raw_payload.get("project_hash") + if not isinstance(project_hash, str) or not project_hash: + project_hash = file_hash + + if not project_hash: + logger.debug("Skipping stdio status file %s without project hash.", status_file) + return None + + return { + "project_hash": project_hash, + "enabled_tools": enabled_tools, + } + def _list_stdio_status_payloads(self) -> list[dict[str, object]]: status_ttl_seconds = self._get_stdio_status_ttl_seconds() now_utc = datetime.now(timezone.utc) @@ -785,69 +888,13 @@ def _list_stdio_status_payloads(self) -> list[dict[str, object]]: payloads: list[dict[str, object]] = [] for status_file in status_files: - file_hash = self._extract_project_hash_from_filename(status_file) - try: - with status_file.open("r", encoding="utf-8") as handle: - raw_payload = json.load(handle) - except (OSError, ValueError) as exc: - logger.debug( - "Failed to parse stdio status file %s: %s", - status_file, - exc, - exc_info=True, - ) - continue - - if not isinstance(raw_payload, dict): - logger.debug("Skipping stdio status file %s with non-object payload.", status_file) - continue - - enabled_tools_raw = raw_payload.get("enabled_tools") - if not isinstance(enabled_tools_raw, list): - # Missing enabled_tools means the status format is too old for safe filtering. - logger.debug("Skipping stdio status file %s without enabled_tools field.", status_file) - continue - - enabled_tools = { - tool_name - for tool_name in enabled_tools_raw - if isinstance(tool_name, str) and tool_name - } - - freshness = self._parse_heartbeat_datetime(raw_payload.get("last_heartbeat")) - if freshness is None: - try: - freshness = datetime.fromtimestamp(status_file.stat().st_mtime, tz=timezone.utc) - except OSError: - logger.debug( - "Failed to read mtime for stdio status file %s; skipping for safety.", - status_file, - exc_info=True, - ) - continue - - if (now_utc - freshness).total_seconds() > status_ttl_seconds: - logger.debug( - "Skipping stale stdio status file %s (age exceeds %ss).", - status_file, - status_ttl_seconds, - ) - continue - - project_hash = raw_payload.get("project_hash") - if not isinstance(project_hash, str) or not project_hash: - project_hash = file_hash - - if not project_hash: - logger.debug("Skipping stdio status file %s without project hash.", status_file) - continue - - payloads.append( - { - "project_hash": project_hash, - "enabled_tools": enabled_tools, - } + parsed_payload = self._parse_single_status_file( + status_file=status_file, + now_utc=now_utc, + status_ttl_seconds=status_ttl_seconds, ) + if parsed_payload is not None: + payloads.append(parsed_payload) return payloads diff --git a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py index 33236d62a..c9a8b8dd2 100644 --- a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py +++ b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py @@ -46,6 +46,25 @@ async def test_on_notification_initialized_triggers_tools_list_changed(monkeypat assert session.send_tool_list_changed.await_count == 1 +@pytest.mark.asyncio +async def test_initialized_notification_does_not_duplicate_session_registration_refresh(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + session = SimpleNamespace(send_tool_list_changed=AsyncMock()) + + message_context = _build_context("session-init-dedupe", session) + initialized_context = _build_context( + "session-init-dedupe", + session, + method="notifications/initialized", + ) + + await middleware.on_message(message_context, AsyncMock(return_value=None)) + await middleware.on_notification(initialized_context, AsyncMock(return_value=None)) + + assert session.send_tool_list_changed.await_count == 1 + + @pytest.mark.asyncio async def test_notify_tool_list_changed_removes_stale_sessions(monkeypatch): monkeypatch.setattr(config, "transport_mode", "stdio") From 683a6f5d4a476d9f44b25b7d0426202fa83e5aaf Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 19 Feb 2026 18:19:02 +0800 Subject: [PATCH 10/11] test: update set_active_instance integration mocks for middleware resolver --- .../integration/test_inline_unity_instance.py | 9 +-- .../test_instance_routing_comprehensive.py | 60 +++++++------------ 2 files changed, 23 insertions(+), 46 deletions(-) diff --git a/Server/tests/integration/test_inline_unity_instance.py b/Server/tests/integration/test_inline_unity_instance.py index 13b13c32b..7d3eaec5b 100644 --- a/Server/tests/integration/test_inline_unity_instance.py +++ b/Server/tests/integration/test_inline_unity_instance.py @@ -355,12 +355,9 @@ def test_set_active_instance_port_stdio(monkeypatch): pool_instance = SimpleNamespace(id="Proj@abc123", hash="abc123", port=6401) - class FakePool: - def discover_all_instances(self, force_refresh=False): - return [pool_instance] - - import services.tools.set_active_instance as sat - monkeypatch.setattr(sat, "get_unity_connection_pool", lambda: FakePool()) + async def fake_discover(_ctx): + return [pool_instance] + monkeypatch.setattr(mw, "_discover_instances", fake_discover) from services.tools.set_active_instance import set_active_instance diff --git a/Server/tests/integration/test_instance_routing_comprehensive.py b/Server/tests/integration/test_instance_routing_comprehensive.py index e96e69685..ee0870085 100644 --- a/Server/tests/integration/test_instance_routing_comprehensive.py +++ b/Server/tests/integration/test_instance_routing_comprehensive.py @@ -12,13 +12,13 @@ """ import pytest from unittest.mock import AsyncMock, Mock, MagicMock, patch +from types import SimpleNamespace from fastmcp import Context from core.config import config from transport.unity_instance_middleware import UnityInstanceMiddleware from services.tools import get_unity_instance_from_context from services.tools.set_active_instance import set_active_instance as set_active_instance_tool -from transport.models import SessionList, SessionDetails class TestInstanceRoutingBasics: @@ -164,7 +164,7 @@ class TestInstanceRoutingHTTP: @pytest.mark.asyncio async def test_set_active_instance_http_transport(self, monkeypatch): - """set_active_instance should enumerate PluginHub sessions under HTTP.""" + """set_active_instance should resolve Name@hash under HTTP transport.""" middleware = UnityInstanceMiddleware() ctx = Mock(spec=Context) ctx.session_id = "http-session" @@ -174,19 +174,10 @@ async def test_set_active_instance_http_transport(self, monkeypatch): ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k)) monkeypatch.setattr(config, "transport_mode", "http") - fake_sessions = SessionList( - sessions={ - "sess-1": SessionDetails( - project="Ramble", - hash="8e29de57", - unity_version="6000.2.10f1", - connected_at="2025-11-21T03:30:03.682353+00:00", - ) - } - ) monkeypatch.setattr( - "services.tools.set_active_instance.PluginHub.get_sessions", - AsyncMock(return_value=fake_sessions), + middleware, + "_discover_instances", + AsyncMock(return_value=[SimpleNamespace(id="Ramble@8e29de57", hash="8e29de57")]), ) monkeypatch.setattr( "services.tools.set_active_instance.get_unity_instance_middleware", @@ -200,7 +191,7 @@ async def test_set_active_instance_http_transport(self, monkeypatch): @pytest.mark.asyncio async def test_set_active_instance_http_hash_only(self, monkeypatch): - """Hash-only selection should resolve via PluginHub registry.""" + """Hash-only selection should resolve to the matching Name@hash.""" middleware = UnityInstanceMiddleware() ctx = Mock(spec=Context) ctx.session_id = "http-session-2" @@ -210,26 +201,17 @@ async def test_set_active_instance_http_hash_only(self, monkeypatch): ctx.get_state = Mock(side_effect=lambda k: state_storage.get(k)) monkeypatch.setattr(config, "transport_mode", "http") - fake_sessions = SessionList( - sessions={ - "sess-99": SessionDetails( - project="UnityMCPTests", - hash="cc8756d4", - unity_version="2021.3.45f2", - connected_at="2025-11-21T03:37:01.501022+00:00", - ) - } - ) monkeypatch.setattr( - "services.tools.set_active_instance.PluginHub.get_sessions", - AsyncMock(return_value=fake_sessions), + middleware, + "_discover_instances", + AsyncMock(return_value=[SimpleNamespace(id="UnityMCPTests@cc8756d4", hash="cc8756d4")]), ) monkeypatch.setattr( "services.tools.set_active_instance.get_unity_instance_middleware", lambda: middleware, ) - result = await set_active_instance_tool(ctx, "UnityMCPTests@cc8756d4") + result = await set_active_instance_tool(ctx, "cc8756d4") assert result["success"] is True assert middleware.get_active_instance(ctx) == "UnityMCPTests@cc8756d4" @@ -242,10 +224,10 @@ async def test_set_active_instance_http_hash_missing(self, monkeypatch): ctx.session_id = "http-session-3" monkeypatch.setattr(config, "transport_mode", "http") - fake_sessions = SessionList(sessions={}) monkeypatch.setattr( - "services.tools.set_active_instance.PluginHub.get_sessions", - AsyncMock(return_value=fake_sessions), + middleware, + "_discover_instances", + AsyncMock(return_value=[]), ) monkeypatch.setattr( "services.tools.set_active_instance.get_unity_instance_middleware", @@ -255,7 +237,7 @@ async def test_set_active_instance_http_hash_missing(self, monkeypatch): result = await set_active_instance_tool(ctx, "Unknown@deadbeef") assert result["success"] is False - assert "No Unity instances" in result["error"] + assert "not found" in result["error"] @pytest.mark.asyncio async def test_set_active_instance_http_hash_ambiguous(self, monkeypatch): @@ -265,15 +247,13 @@ async def test_set_active_instance_http_hash_ambiguous(self, monkeypatch): ctx.session_id = "http-session-4" monkeypatch.setattr(config, "transport_mode", "http") - fake_sessions = SessionList( - sessions={ - "sess-a": SessionDetails(project="ProjA", hash="abc12345", unity_version="2022", connected_at="now"), - "sess-b": SessionDetails(project="ProjB", hash="abc98765", unity_version="2022", connected_at="now"), - } - ) monkeypatch.setattr( - "services.tools.set_active_instance.PluginHub.get_sessions", - AsyncMock(return_value=fake_sessions), + middleware, + "_discover_instances", + AsyncMock(return_value=[ + SimpleNamespace(id="ProjA@abc12345", hash="abc12345"), + SimpleNamespace(id="ProjB@abc98765", hash="abc98765"), + ]), ) monkeypatch.setattr( "services.tools.set_active_instance.get_unity_instance_middleware", From b52dcad6157a39c738598e58e92a468fe36891dd Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Fri, 20 Feb 2026 22:58:44 +0800 Subject: [PATCH 11/11] fix: streamline tool disabling logic and enhance tool visibility checks --- MCPForUnity/Editor/Tools/ManageEditor.cs | 4 +- .../transport/unity_instance_middleware.py | 81 ++++++++--- ...ty_instance_middleware_stdio_edge_cases.py | 136 ++++++++++++++---- ...ance_middleware_tool_list_notifications.py | 51 +++++++ .../Tools/ManageEditorToolToggleTests.cs | 16 +++ 5 files changed, 241 insertions(+), 47 deletions(-) diff --git a/MCPForUnity/Editor/Tools/ManageEditor.cs b/MCPForUnity/Editor/Tools/ManageEditor.cs index 03a9b883f..3e4e07330 100644 --- a/MCPForUnity/Editor/Tools/ManageEditor.cs +++ b/MCPForUnity/Editor/Tools/ManageEditor.cs @@ -227,9 +227,7 @@ private static object SetMcpToolEnabled(string toolName, bool enabled) return new ErrorResponse($"Unknown tool '{toolName}'."); } - if (!enabled && ( - string.Equals(metadata.Name, "manage_editor", StringComparison.OrdinalIgnoreCase) || - string.Equals(metadata.Name, "set_active_instance", StringComparison.OrdinalIgnoreCase))) + if (!enabled && string.Equals(metadata.Name, "manage_editor", StringComparison.OrdinalIgnoreCase)) { return new ErrorResponse($"Tool '{metadata.Name}' cannot be disabled."); } diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index 53443cd96..d4736c327 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -181,6 +181,14 @@ async def _send_one(session_id: str, session): try: await session.send_tool_list_changed() return session_id, True + except asyncio.CancelledError: + logger.debug( + "tools/list_changed send cancelled for session %s (reason=%s); session will be removed.", + session_id, + reason, + exc_info=True, + ) + return session_id, False except Exception: logger.debug( "Failed sending tools/list_changed to session %s (reason=%s); session will be removed.", @@ -192,11 +200,27 @@ async def _send_one(session_id: str, session): results = await asyncio.gather( *[_send_one(sid, sess) for sid, sess in session_items], - return_exceptions=False, + return_exceptions=True, ) - stale_session_ids = [sid for sid, ok in results if not ok] - sent_count = sum(1 for _, ok in results if ok) + normalized_results: list[tuple[str, bool]] = [] + for result in results: + if isinstance(result, BaseException): + logger.debug( + "Unexpected exception while broadcasting tools/list_changed (reason=%s): %s", + reason, + type(result).__name__, + exc_info=( + type(result), + result, + result.__traceback__, + ), + ) + continue + normalized_results.append(result) + + stale_session_ids = [sid for sid, ok in normalized_results if not ok] + sent_count = sum(1 for _, ok in normalized_results if ok) if stale_session_ids: with self._session_lock: @@ -655,10 +679,11 @@ async def on_list_tools(self, context: MiddlewareContext, call_next): if enabled_tool_names is None: return tools + tool_visibility_snapshot = self._get_tool_visibility_snapshot() filtered = [] for tool in tools: tool_name = getattr(tool, "name", None) - if self._is_tool_visible(tool_name, enabled_tool_names): + if self._is_tool_visible(tool_name, enabled_tool_names, tool_visibility_snapshot): filtered.append(tool) return filtered @@ -683,7 +708,7 @@ async def _resolve_enabled_tool_names_for_context( if transport == "stdio": active_instance = ctx.get_state("unity_instance") - return self._resolve_enabled_tool_names_for_stdio_context(active_instance) + return await self._resolve_enabled_tool_names_for_stdio_context(active_instance) user_id = ctx.get_state("user_id") if config.http_remote_hosted else None active_instance = ctx.get_state("unity_instance") @@ -757,8 +782,8 @@ async def _resolve_enabled_tool_names_for_context( return enabled_tool_names - def _resolve_enabled_tool_names_for_stdio_context(self, active_instance: str | None) -> set[str] | None: - status_payloads = self._list_stdio_status_payloads() + async def _resolve_enabled_tool_names_for_stdio_context(self, active_instance: str | None) -> set[str] | None: + status_payloads = await asyncio.to_thread(self._list_stdio_status_payloads) if not status_payloads: return None @@ -870,11 +895,18 @@ def _list_stdio_status_payloads(self) -> list[dict[str, object]]: status_dir = Path(status_dir_env).expanduser() if status_dir_env else Path.home().joinpath(".unity-mcp") try: - status_files = sorted( - status_dir.glob("unity-mcp-status-*.json"), - key=lambda path: path.stat().st_mtime, - reverse=True, - ) + status_file_entries: list[tuple[float, Path]] = [] + for status_file in status_dir.glob("unity-mcp-status-*.json"): + try: + mtime = status_file.stat().st_mtime + except OSError: + logger.debug( + "Skipping stdio status file %s because stat() failed.", + status_file, + exc_info=True, + ) + continue + status_file_entries.append((mtime, status_file)) except OSError as exc: logger.debug( "Failed to enumerate stdio status files from %s: %s", @@ -884,6 +916,9 @@ def _list_stdio_status_payloads(self) -> list[dict[str, object]]: ) return [] + status_file_entries.sort(key=lambda item: item[0], reverse=True) + status_files = [path for _, path in status_file_entries] + payloads: list[dict[str, object]] = [] for status_file in status_files: parsed_payload = self._parse_single_status_file( @@ -1003,6 +1038,14 @@ def _refresh_tool_visibility_metadata_from_registry(self) -> None: self._tool_visibility_signature = signature self._last_tool_visibility_refresh = now + def _get_tool_visibility_snapshot(self) -> tuple[set[str], dict[str, str], set[str]]: + with self._metadata_lock: + return ( + set(self._unity_managed_tool_names), + dict(self._tool_alias_to_unity_target), + set(self._server_only_tool_names), + ) + @staticmethod def _resolve_candidate_project_hashes(active_instance: str | None) -> list[str]: if not active_instance: @@ -1014,22 +1057,28 @@ def _resolve_candidate_project_hashes(active_instance: str | None) -> list[str]: return [active_instance] - def _is_tool_visible(self, tool_name: str | None, enabled_tool_names: set[str]) -> bool: + def _is_tool_visible( + self, + tool_name: str | None, + enabled_tool_names: set[str], + tool_visibility_snapshot: tuple[set[str], dict[str, str], set[str]], + ) -> bool: + unity_managed_tool_names, tool_alias_to_unity_target, server_only_tool_names = tool_visibility_snapshot if not isinstance(tool_name, str) or not tool_name: return True - if tool_name in self._server_only_tool_names: + if tool_name in server_only_tool_names: return True if tool_name in enabled_tool_names: return True - unity_target = self._tool_alias_to_unity_target.get(tool_name) + unity_target = tool_alias_to_unity_target.get(tool_name) if unity_target: return unity_target in enabled_tool_names # Keep unknown tools visible for forward compatibility. - if tool_name not in self._unity_managed_tool_names: + if tool_name not in unity_managed_tool_names: return True return False diff --git a/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py b/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py index e55cfa5bc..f3cdb991c 100644 --- a/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py +++ b/Server/tests/test_unity_instance_middleware_stdio_edge_cases.py @@ -6,8 +6,8 @@ """ import asyncio import json -import os from datetime import datetime, timedelta, timezone +from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock, Mock, patch @@ -68,18 +68,6 @@ async def call_next(_ctx): return [tool.name for tool in filtered] -def _build_context(session_id: str, session_obj: object, method: str = "tools/list"): - fastmcp_context = SimpleNamespace( - request_context=object(), - session_id=session_id, - session=session_obj, - ) - return SimpleNamespace( - fastmcp_context=fastmcp_context, - method=method, - ) - - # --------------------------------------------------------------------------- # Status file content edge cases # --------------------------------------------------------------------------- @@ -294,6 +282,48 @@ async def test_heartbeat_without_timezone_treated_as_utc(monkeypatch, tmp_path): assert "manage_scene" in names +@pytest.mark.asyncio +async def test_heartbeat_with_positive_timezone_offset_parsed_correctly(monkeypatch, tmp_path): + """Heartbeat with positive timezone offset should be parsed and accepted.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "60") + + ist = timezone(timedelta(hours=5, minutes=30)) + heartbeat = datetime.now(timezone.utc).astimezone(ist).isoformat() + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": heartbeat}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + + +@pytest.mark.asyncio +async def test_heartbeat_with_negative_timezone_offset_parsed_correctly(monkeypatch, tmp_path): + """Heartbeat with negative timezone offset should be parsed and accepted.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "60") + + pst = timezone(timedelta(hours=-8)) + heartbeat = datetime.now(timezone.utc).astimezone(pst).isoformat() + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"], "last_heartbeat": heartbeat}, + ) + + middleware = UnityInstanceMiddleware() + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert "manage_scene" in names + + @pytest.mark.asyncio async def test_heartbeat_invalid_format_falls_back_to_mtime(monkeypatch, tmp_path): """Invalid heartbeat format should fall back to file mtime.""" @@ -315,19 +345,18 @@ async def test_heartbeat_invalid_format_falls_back_to_mtime(monkeypatch, tmp_pat @pytest.mark.asyncio -async def test_heartbeat_exactly_at_ttl_boundary_is_stale(monkeypatch, tmp_path): - """Heartbeat exactly at TTL boundary should be considered stale (>).""" +async def test_heartbeat_near_ttl_boundary_filters_only_clearly_stale_payload(monkeypatch, tmp_path): + """Near TTL boundary, clearly stale payloads are excluded while fresh payloads remain.""" monkeypatch.setattr(config, "transport_mode", "stdio") monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) monkeypatch.setenv("UNITY_MCP_STDIO_STATUS_TTL_SECONDS", "10") - # Heartbeat exactly TTL seconds ago (boundary case: > not >=) - boundary_heartbeat = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() - slightly_fresh = (datetime.now(timezone.utc) - timedelta(seconds=9.9)).isoformat() + stale_heartbeat = (datetime.now(timezone.utc) - timedelta(seconds=10.2)).isoformat() + slightly_fresh = (datetime.now(timezone.utc) - timedelta(seconds=9.8)).isoformat() _write_status_file( tmp_path / "unity-mcp-status-boundary.json", - {"project_hash": "boundary", "enabled_tools": ["manage_scene"], "last_heartbeat": boundary_heartbeat}, + {"project_hash": "boundary", "enabled_tools": ["manage_scene"], "last_heartbeat": stale_heartbeat}, ) _write_status_file( tmp_path / "unity-mcp-status-fresh.json", @@ -337,7 +366,7 @@ async def test_heartbeat_exactly_at_ttl_boundary_is_stale(monkeypatch, tmp_path) middleware = UnityInstanceMiddleware() names = await _filter_tool_names(middleware, _build_fastmcp_context(None)) - # Boundary should be stale, fresh should be included + # Stale should be excluded, fresh should be included. assert "manage_asset" in names assert "manage_scene" not in names @@ -418,6 +447,34 @@ async def test_watch_interval_invalid_string_uses_default(monkeypatch): assert middleware._get_stdio_tools_watch_interval_seconds() == 1.0 +@pytest.mark.asyncio +async def test_stdio_list_tools_reads_status_files_via_to_thread(monkeypatch, tmp_path): + """Stdio tools/list filtering should offload status-file reads with asyncio.to_thread.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + _write_status_file( + tmp_path / "unity-mcp-status-abc123.json", + {"project_hash": "abc123", "enabled_tools": ["manage_scene"]}, + ) + + middleware = UnityInstanceMiddleware() + to_thread_call_count = {"value": 0} + + async def _fake_to_thread(func, *args, **kwargs): + to_thread_call_count["value"] += 1 + return func(*args, **kwargs) + + with patch( + "transport.unity_instance_middleware.asyncio.to_thread", + new=AsyncMock(side_effect=_fake_to_thread), + ): + names = await _filter_tool_names(middleware, _build_fastmcp_context("Project@abc123")) + + assert to_thread_call_count["value"] >= 1 + assert "manage_scene" in names + + # --------------------------------------------------------------------------- # Session tracking edge cases # --------------------------------------------------------------------------- @@ -576,18 +633,16 @@ async def test_watcher_continues_after_iteration_error(monkeypatch): middleware = UnityInstanceMiddleware() middleware._notify_tool_list_changed_to_sessions = AsyncMock(return_value=None) + progressed_after_error = asyncio.Event() call_count = {"value": 0} - error_on_call = 1 def _fake_signature(): call_count["value"] += 1 - if call_count["value"] == error_on_call: - # First call is initial state, second call is first iteration - pass if call_count["value"] == 2: raise RuntimeError("simulated error") + if call_count["value"] >= 3: + progressed_after_error.set() if call_count["value"] >= 4: - # Stop after a few successful iterations raise asyncio.CancelledError() return (("hash1", ("tool1",)),) @@ -595,9 +650,7 @@ def _fake_signature(): await middleware.start_stdio_tools_watcher() try: - await asyncio.sleep(0.5) - except asyncio.CancelledError: - pass + await asyncio.wait_for(progressed_after_error.wait(), timeout=1.0) finally: await middleware.stop_stdio_tools_watcher() @@ -737,6 +790,33 @@ async def test_status_dir_custom_path_expanded(monkeypatch, tmp_path): assert payloads[0]["project_hash"] == "abc123" +@pytest.mark.asyncio +async def test_status_dir_skips_files_that_fail_stat_without_failing_enumeration(monkeypatch, tmp_path): + """A single stat() failure should be skipped instead of failing the full status scan.""" + monkeypatch.setattr(config, "transport_mode", "stdio") + monkeypatch.setenv("UNITY_MCP_STATUS_DIR", str(tmp_path)) + + flaky_path = tmp_path / "unity-mcp-status-flaky.json" + good_path = tmp_path / "unity-mcp-status-good.json" + + _write_status_file(flaky_path, {"project_hash": "flaky", "enabled_tools": ["manage_asset"]}) + _write_status_file(good_path, {"project_hash": "good", "enabled_tools": ["manage_scene"]}) + + original_stat = Path.stat + + def _fake_stat(path: Path, *args, **kwargs): + if path == flaky_path: + raise FileNotFoundError("simulated race: file disappeared between glob and stat") + return original_stat(path, *args, **kwargs) + + middleware = UnityInstanceMiddleware() + with patch("pathlib.Path.stat", new=_fake_stat): + payloads = middleware._list_stdio_status_payloads() + + assert len(payloads) == 1 + assert payloads[0]["project_hash"] == "good" + + # --------------------------------------------------------------------------- # Concurrent operation tests # --------------------------------------------------------------------------- diff --git a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py index c9a8b8dd2..0d59eb4f7 100644 --- a/Server/tests/test_unity_instance_middleware_tool_list_notifications.py +++ b/Server/tests/test_unity_instance_middleware_tool_list_notifications.py @@ -86,6 +86,57 @@ async def _raise_send(): assert healthy_session.send_tool_list_changed.await_count == 1 +@pytest.mark.asyncio +async def test_notify_tool_list_changed_handles_cancelled_session_send(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "stdio") + middleware = UnityInstanceMiddleware() + + healthy_session = SimpleNamespace(send_tool_list_changed=AsyncMock(return_value=None)) + + async def _raise_cancelled(): + raise asyncio.CancelledError() + + cancelled_session = SimpleNamespace(send_tool_list_changed=AsyncMock(side_effect=_raise_cancelled)) + middleware._tracked_sessions["healthy"] = healthy_session + middleware._tracked_sessions["cancelled"] = cancelled_session + + await middleware._notify_tool_list_changed_to_sessions("test_reason") + + assert "healthy" in middleware._tracked_sessions + assert "cancelled" not in middleware._tracked_sessions + assert healthy_session.send_tool_list_changed.await_count == 1 + + +@pytest.mark.asyncio +async def test_on_message_is_noop_for_http_transport(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "http") + middleware = UnityInstanceMiddleware() + session = SimpleNamespace(send_tool_list_changed=AsyncMock()) + context = _build_context("session-http-message", session) + call_next = AsyncMock(return_value="ok") + + result = await middleware.on_message(context, call_next) + + assert result == "ok" + call_next.assert_awaited_once() + assert session.send_tool_list_changed.await_count == 0 + + +@pytest.mark.asyncio +async def test_on_notification_is_noop_for_http_transport(monkeypatch): + monkeypatch.setattr(config, "transport_mode", "http") + middleware = UnityInstanceMiddleware() + session = SimpleNamespace(send_tool_list_changed=AsyncMock()) + context = _build_context("session-http-notification", session, method="notifications/initialized") + call_next = AsyncMock(return_value="ok") + + result = await middleware.on_notification(context, call_next) + + assert result == "ok" + call_next.assert_awaited_once() + assert session.send_tool_list_changed.await_count == 0 + + @pytest.mark.asyncio async def test_start_stdio_tools_watcher_skips_when_transport_is_not_stdio(monkeypatch): monkeypatch.setattr(config, "transport_mode", "http") diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs index dcaf0a59d..0a03a76c3 100644 --- a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Tools/ManageEditorToolToggleTests.cs @@ -78,6 +78,22 @@ public void HandleCommand_SetMcpToolEnabled_RejectsDisablingManageEditor() StringAssert.Contains("cannot be disabled", response["error"]?.ToString()); } + [Test] + public void HandleCommand_SetMcpToolEnabled_SetActiveInstanceIsUnknownTool() + { + var result = ManageEditor.HandleCommand(new JObject + { + ["action"] = "set_mcp_tool_enabled", + ["toolName"] = "set_active_instance", + ["enabled"] = false, + }); + + var response = JObject.FromObject(result); + Assert.AreEqual(false, response["success"]?.Value()); + StringAssert.Contains("Unknown tool", response["error"]?.ToString()); + StringAssert.Contains("set_active_instance", response["error"]?.ToString()); + } + [Test] public void HandleCommand_GetMcpToolEnabled_ReturnsCurrentState() {