From fb59d56fa4e176c19146a519e557851ee7d4deab Mon Sep 17 00:00:00 2001 From: David Sarno Date: Fri, 20 Feb 2026 13:58:26 -0800 Subject: [PATCH] HTTP transport: pass retry_on_reload through and add stale connection detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #795 — the HTTP/WebSocket transport path (PluginHub) now respects retry_on_reload=False and performs pre-send liveness checks, matching the stdio path behavior from PR #792. Co-Authored-By: Codex --- Server/src/transport/plugin_hub.py | 129 ++++++++++++++++-- Server/src/transport/unity_transport.py | 5 + .../test_domain_reload_resilience.py | 60 ++++++++ .../tests/integration/test_transport_smoke.py | 33 +++++ 4 files changed, 217 insertions(+), 10 deletions(-) diff --git a/Server/src/transport/plugin_hub.py b/Server/src/transport/plugin_hub.py index fa8aa3a98..805ec4ace 100644 --- a/Server/src/transport/plugin_hub.py +++ b/Server/src/transport/plugin_hub.py @@ -10,7 +10,7 @@ from typing import Any, ClassVar from starlette.endpoints import WebSocketEndpoint -from starlette.websockets import WebSocket +from starlette.websockets import WebSocket, WebSocketState from core.config import config from core.constants import API_KEY_HEADER @@ -546,11 +546,102 @@ async def _get_connection(cls, session_id: str) -> WebSocket: raise RuntimeError(f"Plugin session {session_id} not connected") return websocket + @classmethod + async def _evict_connection(cls, session_id: str, reason: str) -> None: + """Drop a stale session from in-memory maps and registry.""" + lock = cls._lock + if lock is None: + return + + websocket: WebSocket | None = None + ping_task: asyncio.Task | None = None + pending_futures: list[asyncio.Future] = [] + async with lock: + websocket = cls._connections.pop(session_id, None) + ping_task = cls._ping_tasks.pop(session_id, None) + cls._last_pong.pop(session_id, None) + keys_to_remove: list[object] = [] + for key, entry in list(cls._pending.items()): + if entry.get("session_id") == session_id: + future = entry.get("future") + if future and not future.done(): + pending_futures.append(future) + keys_to_remove.append(key) + for key in keys_to_remove: + cls._pending.pop(key, None) + + if ping_task is not None and not ping_task.done(): + ping_task.cancel() + + for future in pending_futures: + if not future.done(): + future.set_exception( + PluginDisconnectedError( + f"Unity plugin session {session_id} disconnected while awaiting command_result" + ) + ) + + if websocket is not None: + try: + await websocket.close(code=1001) + except Exception: + pass + + if cls._registry is not None: + try: + await cls._registry.unregister(session_id) + except Exception: + logger.debug( + "Failed to unregister evicted plugin session %s", + session_id, + exc_info=True, + ) + + logger.debug("Evicted plugin session %s (%s)", session_id, reason) + + @classmethod + async def _ensure_live_connection(cls, session_id: str) -> bool: + """Best-effort pre-send liveness check for a plugin WebSocket.""" + try: + websocket = await cls._get_connection(session_id) + except RuntimeError: + await cls._evict_connection(session_id, "missing_websocket") + return False + + if ( + websocket.client_state == WebSocketState.CONNECTED + and websocket.application_state == WebSocketState.CONNECTED + ): + return True + + logger.debug( + "Detected stale plugin connection before send: session=%s app_state=%s client_state=%s", + session_id, + websocket.application_state, + websocket.client_state, + ) + await cls._evict_connection(session_id, "stale_websocket_state") + return False + + @staticmethod + def _unavailable_retry_response(reason: str = "no_unity_session") -> dict[str, Any]: + return MCPResponse( + success=False, + error="Unity session not available; please retry", + hint="retry", + data={"reason": reason, "retry_after_ms": 250}, + ).model_dump() + # ------------------------------------------------------------------ # Session resolution helpers # ------------------------------------------------------------------ @classmethod - async def _resolve_session_id(cls, unity_instance: str | None, user_id: str | None = None) -> str: + async def _resolve_session_id( + cls, + unity_instance: str | None, + user_id: str | None = None, + retry_on_reload: bool = True, + ) -> str: """Resolve a project hash (Unity instance id) to an active plugin session. During Unity domain reloads the plugin's WebSocket session is torn down @@ -561,6 +652,7 @@ async def _resolve_session_id(cls, unity_instance: str | None, user_id: str | No Args: unity_instance: Target instance (Name@hash or hash) user_id: User ID from API key validation (for remote-hosted mode session isolation) + retry_on_reload: If False, do not wait for reconnects when no session is present. """ if cls._registry is None: raise RuntimeError("Plugin registry not configured") @@ -589,6 +681,8 @@ async def _resolve_session_id(cls, unity_instance: str | None, user_id: str | No max_wait_s = 20.0 # Clamp to [0, 20] to prevent misconfiguration from causing excessive waits max_wait_s = max(0.0, min(max_wait_s, 20.0)) + if not retry_on_reload: + max_wait_s = 0.0 retry_ms = float(getattr(config, "reload_retry_ms", 250)) sleep_seconds = max(0.05, min(0.25, retry_ms / 1000.0)) @@ -684,6 +778,7 @@ async def send_command_for_instance( command_type: str, params: dict[str, Any], user_id: str | None = None, + retry_on_reload: bool = True, ) -> dict[str, Any]: """Send a command to a Unity instance. @@ -692,28 +787,42 @@ async def send_command_for_instance( command_type: Command type to execute params: Command parameters user_id: User ID for session isolation in remote-hosted mode + retry_on_reload: If False, do not wait for session reconnect on reload. """ try: - session_id = await cls._resolve_session_id(unity_instance, user_id=user_id) + session_id = await cls._resolve_session_id( + unity_instance, + user_id=user_id, + retry_on_reload=retry_on_reload, + ) except NoUnitySessionError: logger.debug( "Unity session unavailable; returning retry: command=%s instance=%s", command_type, unity_instance or "default", ) - return MCPResponse( - success=False, - error="Unity session not available; please retry", - hint="retry", - data={"reason": "no_unity_session", "retry_after_ms": 250}, - ).model_dump() + return cls._unavailable_retry_response("no_unity_session") + + if not await cls._ensure_live_connection(session_id): + if not retry_on_reload: + return cls._unavailable_retry_response("stale_connection") + try: + session_id = await cls._resolve_session_id( + unity_instance, + user_id=user_id, + retry_on_reload=True, + ) + except NoUnitySessionError: + return cls._unavailable_retry_response("no_unity_session") + if not await cls._ensure_live_connection(session_id): + return cls._unavailable_retry_response("stale_connection") # During domain reload / immediate reconnect windows, the plugin may be connected but not yet # ready to process execute commands on the Unity main thread (which can be further delayed when # the Unity Editor is unfocused). For fast-path commands, we do a bounded readiness probe using # a main-thread ping command (handled by TransportCommandDispatcher) rather than waiting on # register_tools (which can be delayed by EditorApplication.delayCall). - if command_type in cls._FAST_FAIL_COMMANDS and command_type != "ping": + if retry_on_reload and command_type in cls._FAST_FAIL_COMMANDS and command_type != "ping": try: max_wait_s = float(os.environ.get( "UNITY_MCP_SESSION_READY_WAIT_SECONDS", "6")) diff --git a/Server/src/transport/unity_transport.py b/Server/src/transport/unity_transport.py index d55ab6fd6..f3b1d3d3f 100644 --- a/Server/src/transport/unity_transport.py +++ b/Server/src/transport/unity_transport.py @@ -71,12 +71,17 @@ async def send_with_unity_instance( ).model_dump() ) + retry_on_reload = kwargs.pop("retry_on_reload", True) + if not isinstance(retry_on_reload, bool): + retry_on_reload = True + try: raw = await PluginHub.send_command_for_instance( unity_instance, command_type, params, user_id=user_id, + retry_on_reload=retry_on_reload, ) return normalize_unity_response(raw) except Exception as exc: diff --git a/Server/tests/integration/test_domain_reload_resilience.py b/Server/tests/integration/test_domain_reload_resilience.py index b2efddac6..803eec7b0 100644 --- a/Server/tests/integration/test_domain_reload_resilience.py +++ b/Server/tests/integration/test_domain_reload_resilience.py @@ -103,6 +103,66 @@ async def mock_list_sessions(**kwargs): PluginHub._lock = original_lock +@pytest.mark.asyncio +async def test_plugin_hub_no_wait_when_retry_disabled(monkeypatch): + """retry_on_reload=False should skip reconnect wait loops.""" + from transport.plugin_hub import PluginHub, NoUnitySessionError + from transport.plugin_registry import PluginRegistry + + mock_registry = AsyncMock(spec=PluginRegistry) + mock_registry.get_session_id_by_hash = AsyncMock(return_value=None) + mock_registry.list_sessions = AsyncMock(return_value={}) + + original_registry = PluginHub._registry + original_lock = PluginHub._lock + PluginHub._registry = mock_registry + PluginHub._lock = asyncio.Lock() + + monkeypatch.setenv("UNITY_MCP_SESSION_RESOLVE_MAX_WAIT_S", "20.0") + + try: + with pytest.raises(NoUnitySessionError): + await PluginHub._resolve_session_id( + unity_instance="hash-missing", + retry_on_reload=False, + ) + + assert mock_registry.get_session_id_by_hash.await_count == 1 + assert mock_registry.list_sessions.await_count == 1 + finally: + PluginHub._registry = original_registry + PluginHub._lock = original_lock + + +@pytest.mark.asyncio +async def test_send_command_for_instance_fails_fast_on_stale_when_retry_disabled(monkeypatch): + """Stale HTTP session should not send command when retry_on_reload is disabled.""" + from transport.plugin_hub import PluginHub + + resolve_mock = AsyncMock(return_value="sess-stale") + ensure_mock = AsyncMock(return_value=False) + send_mock = AsyncMock() + + monkeypatch.setattr(PluginHub, "_resolve_session_id", resolve_mock) + monkeypatch.setattr(PluginHub, "_ensure_live_connection", ensure_mock) + monkeypatch.setattr(PluginHub, "send_command", send_mock) + + result = await PluginHub.send_command_for_instance( + unity_instance="Project@hash-stale", + command_type="manage_script", + params={"action": "edit"}, + retry_on_reload=False, + ) + + assert result["success"] is False + assert result["hint"] == "retry" + assert result.get("data", {}).get("reason") == "stale_connection" + assert resolve_mock.await_count == 1 + _, resolve_kwargs = resolve_mock.await_args + assert resolve_kwargs.get("retry_on_reload") is False + send_mock.assert_not_awaited() + + @pytest.mark.asyncio async def test_read_console_during_simulated_reload(monkeypatch): """ diff --git a/Server/tests/integration/test_transport_smoke.py b/Server/tests/integration/test_transport_smoke.py index b39bb0500..564919773 100644 --- a/Server/tests/integration/test_transport_smoke.py +++ b/Server/tests/integration/test_transport_smoke.py @@ -59,6 +59,39 @@ async def _unused_send_fn(*_args, **_kwargs): assert result["data"] == {"via": "http-remote"} +@pytest.mark.asyncio +async def test_http_forwards_retry_on_reload(monkeypatch): + """HTTP transport should pass retry_on_reload through to PluginHub.""" + monkeypatch.setattr(config, "transport_mode", "http") + monkeypatch.setattr(config, "http_remote_hosted", False) + + captured: dict[str, object] = {} + + async def fake_send_command_for_instance(_instance, _command, _params, **kwargs): + captured.update(kwargs) + return {"status": "success", "result": {"data": {"via": "http"}}} + + monkeypatch.setattr( + unity_transport.PluginHub, + "send_command_for_instance", + fake_send_command_for_instance, + ) + + async def _unused_send_fn(*_args, **_kwargs): + raise AssertionError("send_fn should not be used in HTTP mode") + + result = await unity_transport.send_with_unity_instance( + _unused_send_fn, + None, + "manage_script", + {"action": "edit"}, + retry_on_reload=False, + ) + + assert result["success"] is True + assert captured.get("retry_on_reload") is False + + @pytest.mark.asyncio async def test_stdio_smoke(monkeypatch): """Stdio transport should call the legacy send fn with instance_id."""