Skip to content

Commit 73f2295

Browse files
authored
Merge pull request #735 from dsarno/fix/723-rebased
fix: preserve tool toggle state and filter tools in client listings (rebased #723)
2 parents e6d179f + eeaecba commit 73f2295

23 files changed

Lines changed: 1291 additions & 38 deletions

MCPForUnity/Editor/Services/ToolDiscoveryService.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -226,16 +226,6 @@ private void EnsurePreferenceInitialized(ToolMetadata metadata)
226226
{
227227
bool defaultValue = metadata.AutoRegister || metadata.IsBuiltIn;
228228
EditorPrefs.SetBool(key, defaultValue);
229-
return;
230-
}
231-
232-
if (metadata.IsBuiltIn && !metadata.AutoRegister)
233-
{
234-
bool currentValue = EditorPrefs.GetBool(key, metadata.AutoRegister);
235-
if (currentValue == metadata.AutoRegister)
236-
{
237-
EditorPrefs.SetBool(key, true);
238-
}
239229
}
240230
}
241231

MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ public interface IMcpTransportClient
1414
Task<bool> StartAsync();
1515
Task StopAsync();
1616
Task<bool> VerifyAsync();
17+
Task ReregisterToolsAsync();
1718
}
1819
}

MCPForUnity/Editor/Services/Transport/TransportManager.cs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,6 @@ private IMcpTransportClient GetOrCreateClient(TransportMode mode)
4242
};
4343
}
4444

45-
private IMcpTransportClient GetClient(TransportMode mode)
46-
{
47-
return mode switch
48-
{
49-
TransportMode.Http => _httpClient,
50-
TransportMode.Stdio => _stdioClient,
51-
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
52-
};
53-
}
54-
5545
public async Task<bool> StartAsync(TransportMode mode)
5646
{
5747
IMcpTransportClient client = GetOrCreateClient(mode);
@@ -163,6 +153,20 @@ public void ForceStop(TransportMode mode)
163153
}
164154
}
165155

156+
/// <summary>
157+
/// Gets the active transport client for the specified mode.
158+
/// Returns null if the client hasn't been created yet.
159+
/// </summary>
160+
public IMcpTransportClient GetClient(TransportMode mode)
161+
{
162+
return mode switch
163+
{
164+
TransportMode.Http => _httpClient,
165+
TransportMode.Stdio => _stdioClient,
166+
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
167+
};
168+
}
169+
166170
private void UpdateState(TransportMode mode, TransportState state)
167171
{
168172
switch (mode)

MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,12 @@ public Task<bool> VerifyAsync()
4646
return Task.FromResult(running);
4747
}
4848

49+
public Task ReregisterToolsAsync()
50+
{
51+
// Stdio transport doesn't support dynamic tool reregistration
52+
// Tools are registered at server startup
53+
return Task.CompletedTask;
54+
}
55+
4956
}
5057
}

MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,29 @@ private async Task SendRegisterToolsAsync(CancellationToken token)
562562
McpLog.Info($"[WebSocket] Sent {tools.Count} tools registration", false);
563563
}
564564

565+
public async Task ReregisterToolsAsync()
566+
{
567+
if (!IsConnected || _lifecycleCts == null)
568+
{
569+
McpLog.Warn("[WebSocket] Cannot reregister tools: not connected");
570+
return;
571+
}
572+
573+
try
574+
{
575+
await SendRegisterToolsAsync(_lifecycleCts.Token).ConfigureAwait(false);
576+
McpLog.Info("[WebSocket] Tool reregistration completed", false);
577+
}
578+
catch (System.OperationCanceledException)
579+
{
580+
McpLog.Warn("[WebSocket] Tool reregistration cancelled");
581+
}
582+
catch (System.Exception ex)
583+
{
584+
McpLog.Error($"[WebSocket] Tool reregistration failed: {ex.Message}");
585+
}
586+
}
587+
565588
private async Task HandleExecuteAsync(JObject payload, CancellationToken token)
566589
{
567590
string commandId = payload.Value<string>("id");

MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using System.Threading.Tasks;
45
using MCPForUnity.Editor.Constants;
56
using MCPForUnity.Editor.Helpers;
67
using MCPForUnity.Editor.Services;
8+
using MCPForUnity.Editor.Services.Transport;
79
using MCPForUnity.Editor.Tools;
810
using UnityEditor;
11+
using UnityEditor.UIElements;
912
using UnityEngine.UIElements;
1013

1114
namespace MCPForUnity.Editor.Windows.Components.Tools
@@ -228,23 +231,63 @@ private VisualElement CreateToolRow(ToolMetadata tool)
228231
return row;
229232
}
230233

231-
private void HandleToggleChange(ToolMetadata tool, bool enabled, bool updateSummary = true)
234+
private void HandleToggleChange(
235+
ToolMetadata tool,
236+
bool enabled,
237+
bool updateSummary = true,
238+
bool reregisterTools = true)
232239
{
233240
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
234241

235242
if (updateSummary)
236243
{
237244
UpdateSummary();
238245
}
246+
247+
if (reregisterTools)
248+
{
249+
// Trigger tool reregistration with connected MCP server
250+
ReregisterToolsAsync();
251+
}
252+
}
253+
254+
private void ReregisterToolsAsync()
255+
{
256+
// Fire and forget - don't block UI thread
257+
var transportManager = MCPServiceLocator.TransportManager;
258+
var client = transportManager.GetClient(TransportMode.Http);
259+
if (client == null || !client.IsConnected)
260+
{
261+
return;
262+
}
263+
264+
_ = Task.Run(async () =>
265+
{
266+
try
267+
{
268+
await client.ReregisterToolsAsync().ConfigureAwait(false);
269+
}
270+
catch (Exception ex)
271+
{
272+
McpLog.Warn($"Failed to reregister tools: {ex}");
273+
}
274+
});
239275
}
240276

241277
private void SetAllToolsState(bool enabled)
242278
{
279+
bool hasChanges = false;
280+
243281
foreach (var tool in allTools)
244282
{
245283
if (!toolToggleMap.TryGetValue(tool.Name, out var toggle))
246284
{
247-
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
285+
bool currentEnabled = MCPServiceLocator.ToolDiscovery.IsToolEnabled(tool.Name);
286+
if (currentEnabled != enabled)
287+
{
288+
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
289+
hasChanges = true;
290+
}
248291
continue;
249292
}
250293

@@ -254,10 +297,17 @@ private void SetAllToolsState(bool enabled)
254297
}
255298

256299
toggle.SetValueWithoutNotify(enabled);
257-
HandleToggleChange(tool, enabled, updateSummary: false);
300+
HandleToggleChange(tool, enabled, updateSummary: false, reregisterTools: false);
301+
hasChanges = true;
258302
}
259303

260304
UpdateSummary();
305+
306+
if (hasChanges)
307+
{
308+
// Trigger a single reregistration after bulk change
309+
ReregisterToolsAsync();
310+
}
261311
}
262312

263313
private void UpdateSummary()

Server/src/services/custom_tool_service.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from starlette.requests import Request
1111
from starlette.responses import JSONResponse
1212

13+
from core.config import config
1314
from models.models import MCPResponse, ToolDefinitionModel, ToolParameterModel
1415
from core.logging_decorator import log_execution
1516
from core.telemetry_decorator import telemetry_tool
@@ -28,6 +29,23 @@
2829
_MAX_POLL_SECONDS = 600
2930

3031

32+
def get_user_id_from_context(ctx: Context) -> str | None:
33+
"""Read user_id from request-scoped context in remote-hosted mode."""
34+
if not config.http_remote_hosted:
35+
return None
36+
37+
get_state = getattr(ctx, "get_state", None)
38+
if not callable(get_state):
39+
return None
40+
41+
try:
42+
user_id = get_state("user_id")
43+
except Exception:
44+
return None
45+
46+
return user_id if isinstance(user_id, str) and user_id else None
47+
48+
3149
class RegisterToolsPayload(BaseModel):
3250
project_id: str
3351
project_hash: str | None = None
@@ -84,30 +102,40 @@ async def register_tools(request: Request) -> JSONResponse:
84102
return JSONResponse(response.model_dump())
85103

86104
# --- Public API for MCP tools ---------------------------------------
87-
async def list_registered_tools(self, project_id: str) -> list[ToolDefinitionModel]:
105+
async def list_registered_tools(
106+
self,
107+
project_id: str,
108+
user_id: str | None = None,
109+
) -> list[ToolDefinitionModel]:
88110
legacy = list(self._project_tools.get(project_id, {}).values())
89-
hub_tools = await PluginHub.get_tools_for_project(project_id)
111+
hub_tools = await PluginHub.get_tools_for_project(project_id, user_id=user_id)
90112
return legacy + hub_tools
91113

92-
async def get_tool_definition(self, project_id: str, tool_name: str) -> ToolDefinitionModel | None:
114+
async def get_tool_definition(
115+
self,
116+
project_id: str,
117+
tool_name: str,
118+
user_id: str | None = None,
119+
) -> ToolDefinitionModel | None:
93120
tool = self._project_tools.get(project_id, {}).get(tool_name)
94121
if tool:
95122
return tool
96-
return await PluginHub.get_tool_definition(project_id, tool_name)
123+
return await PluginHub.get_tool_definition(project_id, tool_name, user_id=user_id)
97124

98125
async def execute_tool(
99126
self,
100127
project_id: str,
101128
tool_name: str,
102129
unity_instance: str | None,
103130
params: dict[str, object] | None = None,
131+
user_id: str | None = None,
104132
) -> MCPResponse:
105133
params = params or {}
106134
logger.info(
107135
f"Executing tool '{tool_name}' for project '{project_id}' (instance={unity_instance}) with params: {params}"
108136
)
109137

110-
definition = await self.get_tool_definition(project_id, tool_name)
138+
definition = await self.get_tool_definition(project_id, tool_name, user_id=user_id)
111139
if definition is None:
112140
return MCPResponse(
113141
success=False,
@@ -119,6 +147,7 @@ async def execute_tool(
119147
unity_instance,
120148
tool_name,
121149
params,
150+
user_id=user_id,
122151
)
123152

124153
if not definition.requires_polling:
@@ -132,6 +161,7 @@ async def execute_tool(
132161
params,
133162
response,
134163
definition.poll_action or "status",
164+
user_id=user_id,
135165
)
136166
logger.info(f"Tool '{tool_name}' polled response: {result}")
137167
return result
@@ -156,6 +186,7 @@ async def _poll_until_complete(
156186
initial_params: dict[str, object],
157187
initial_response,
158188
poll_action: str,
189+
user_id: str | None = None,
159190
) -> MCPResponse:
160191
poll_params = dict(initial_params)
161192
poll_params["action"] = poll_action or "status"
@@ -180,7 +211,11 @@ async def _poll_until_complete(
180211

181212
try:
182213
response = await send_with_unity_instance(
183-
async_send_command_with_retry, unity_instance, tool_name, poll_params
214+
async_send_command_with_retry,
215+
unity_instance,
216+
tool_name,
217+
poll_params,
218+
user_id=user_id,
184219
)
185220
except Exception as exc: # pragma: no cover - network/domain reload variability
186221
logger.debug(f"Polling {tool_name} failed, will retry: {exc}")
@@ -347,8 +382,15 @@ async def _handler(ctx: Context, **kwargs) -> MCPResponse:
347382
)
348383

349384
params = {k: v for k, v in kwargs.items() if v is not None}
385+
user_id = get_user_id_from_context(ctx)
350386
service = CustomToolService.get_instance()
351-
return await service.execute_tool(project_id, definition.name, unity_instance, params)
387+
return await service.execute_tool(
388+
project_id,
389+
definition.name,
390+
unity_instance,
391+
params,
392+
user_id=user_id,
393+
)
352394

353395
_handler.__name__ = f"custom_tool_{definition.name}"
354396
_handler.__doc__ = definition.description or ""

Server/src/services/registry/tool_registry.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
def mcp_for_unity_tool(
1111
name: str | None = None,
1212
description: str | None = None,
13+
unity_target: str | None = "self",
1314
**kwargs
1415
) -> Callable:
1516
"""
@@ -20,6 +21,10 @@ def mcp_for_unity_tool(
2021
Args:
2122
name: Tool name (defaults to function name)
2223
description: Tool description
24+
unity_target: Visibility target used by middleware filtering.
25+
- "self" (default): tool follows its own enabled state.
26+
- None: server-only tool, always visible in tool listing.
27+
- "<tool_name>": alias tool that follows another Unity tool state.
2328
**kwargs: Additional arguments passed to @mcp.tool()
2429
2530
Example:
@@ -29,11 +34,29 @@ async def my_custom_tool(ctx: Context, ...):
2934
"""
3035
def decorator(func: Callable) -> Callable:
3136
tool_name = name if name is not None else func.__name__
37+
# Safety guard: unity_target is internal metadata and must never leak into mcp.tool kwargs.
38+
tool_kwargs = dict(kwargs) # Create a copy to avoid side effects
39+
if "unity_target" in tool_kwargs:
40+
del tool_kwargs["unity_target"]
41+
42+
if unity_target is None:
43+
normalized_unity_target: str | None = None
44+
elif isinstance(unity_target, str) and unity_target.strip():
45+
normalized_unity_target = (
46+
tool_name if unity_target == "self" else unity_target.strip()
47+
)
48+
else:
49+
raise ValueError(
50+
f"Invalid unity_target for tool '{tool_name}': {unity_target!r}. "
51+
"Expected None or a non-empty string."
52+
)
53+
3254
_tool_registry.append({
3355
'func': func,
3456
'name': tool_name,
3557
'description': description,
36-
'kwargs': kwargs
58+
'unity_target': normalized_unity_target,
59+
'kwargs': tool_kwargs,
3760
})
3861

3962
return func

0 commit comments

Comments
 (0)