Skip to content

Commit 23e7f04

Browse files
Yuvraj SinghYuvraj Singh
authored and committed
Add MCP integration improvements and refactoring
1 parent ba77c62 commit 23e7f04

6 files changed

Lines changed: 550 additions & 294 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dependencies = [
3838
"dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py",
3939
"rich>=13.0.0",
4040
"psutil>=5.9.0",
41-
"fastmcp",
41+
"fastmcp==2.13.0",
4242
]
4343

4444
[project.optional-dependencies]
@@ -48,6 +48,7 @@ cuda = ["mlx[cuda]"]
4848
dev = [
4949
"openai>=2.6.0", # for OpenAI compatibility tests
5050
"pytest>=8.4.2",
51+
"pytest-asyncio>=0.24.0",
5152
"mypy>=1.3.0", # Type checking
5253
"ruff>=0.0.285",
5354
"types-psutil>=7.1.3",
@@ -68,6 +69,7 @@ python_files = ["test_*.py", "*_test.py"]
6869
testpaths = ["tests"]
6970
python_functions = ["test_"]
7071
log_cli = true
72+
asyncio_mode = "auto"
7173
markers = [
7274
"api: tests for API node components (HTTP, gRPC, managers)",
7375
"shard: tests for Shard node components (HTTP, gRPC, runtime, policies, ring)",
@@ -80,6 +82,7 @@ markers = [
8082
"core: tests for core memory/cache/utils not tied to api/shard",
8183
"e2e: integration tests requiring live servers or multiple components",
8284
"integration: model catalog integration tests for CI (manual trigger)",
85+
"mcp: tests for MCP handler tools and resources",
8386
]
8487

8588
[tool.ruff]

src/dnet/api/http_api.py

Lines changed: 25 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Optional, Any, List
22
import asyncio
3-
import os
43
from hypercorn import Config
54
from hypercorn.utils import LifespanFailureError
65
import hypercorn.asyncio as aio_hypercorn
@@ -26,6 +25,11 @@
2625
from .model_manager import ModelManager
2726
from dnet_p2p import DnetDeviceProperties
2827
from .mcp_handler import create_mcp_server
28+
from .load_helpers import (
29+
_prepare_topology_core,
30+
_load_model_core,
31+
_unload_model_core,
32+
)
2933

3034

3135
class HTTPServer:
@@ -45,19 +49,16 @@ def __init__(
4549
self.http_server: Optional[asyncio.Task] = None
4650

4751
# Create MCP server first to get lifespan
48-
mcp = create_mcp_server(
49-
inference_manager, model_manager, cluster_manager
50-
)
52+
mcp = create_mcp_server(inference_manager, model_manager, cluster_manager)
5153
# Use path='/' since we're mounting at /mcp, so final path will be /mcp/
5254
mcp_app = mcp.http_app(path="/")
53-
55+
5456
# Create FastAPI app with MCP lifespan
5557
self.app = FastAPI(lifespan=mcp_app.lifespan)
56-
58+
5759
# Mount MCP server as ASGI app
5860
self.app.mount("/mcp", mcp_app)
5961

60-
6162
async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None:
6263
await self._setup_routes()
6364

@@ -166,59 +167,27 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse:
166167
),
167168
)
168169

169-
model_config = get_model_config_json(req.model)
170-
embedding_size = int(model_config["hidden_size"])
171-
num_layers = int(model_config["num_hidden_layers"])
172-
173-
await self.cluster_manager.scan_devices()
174-
batch_sizes = [1]
175-
profiles = await self.cluster_manager.profile_cluster(
176-
req.model, embedding_size, 2, batch_sizes
177-
)
178-
if not profiles:
179-
return APILoadModelResponse(
180-
model=req.model,
181-
success=False,
182-
shard_statuses=[],
183-
message="No profiles collected",
170+
try:
171+
topology = await _prepare_topology_core(
172+
self.cluster_manager, req.model, req.kv_bits, req.seq_len
184173
)
185-
186-
model_profile_split = profile_model(
187-
repo_id=req.model,
188-
batch_sizes=batch_sizes,
189-
sequence_length=req.seq_len,
190-
)
191-
model_profile = model_profile_split.to_model_profile()
192-
topology = await self.cluster_manager.solve_topology(
193-
profiles, model_profile, req.model, num_layers, req.kv_bits
194-
)
174+
except RuntimeError as e:
175+
if "No profiles collected" in str(e):
176+
return APILoadModelResponse(
177+
model=req.model,
178+
success=False,
179+
shard_statuses=[],
180+
message="No profiles collected",
181+
)
182+
raise
195183
self.cluster_manager.current_topology = topology
196184

197-
api_props = await self.cluster_manager.discovery.async_get_own_properties()
198-
grpc_port = int(self.inference_manager.grpc_port)
199-
200-
# Callback address shards should use for SendToken.
201-
# In static discovery / cloud setups, discovery may report 127.0.0.1 which is not usable.
202-
api_callback_addr = (os.getenv("DNET_API_CALLBACK_ADDR") or "").strip()
203-
if not api_callback_addr:
204-
api_callback_addr = f"{api_props.local_ip}:{grpc_port}"
205-
if api_props.local_ip in ("127.0.0.1", "localhost"):
206-
logger.warning(
207-
"API callback address is loopback (%s). Remote shards will fail to SendToken. "
208-
"Set DNET_API_CALLBACK_ADDR to a reachable host:port.",
209-
api_callback_addr,
210-
)
211-
response = await self.model_manager.load_model(
185+
response = await _load_model_core(
186+
self.cluster_manager,
187+
self.model_manager,
188+
self.inference_manager,
212189
topology,
213-
api_props,
214-
self.inference_manager.grpc_port,
215-
api_callback_address=api_callback_addr,
216190
)
217-
if response.success:
218-
first_shard = topology.devices[0]
219-
await self.inference_manager.connect_to_ring(
220-
first_shard.local_ip, first_shard.shard_port, api_callback_addr
221-
)
222191
return response
223192

224193
except Exception as e:
@@ -231,12 +200,7 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse:
231200
)
232201

233202
async def unload_model(self) -> UnloadModelResponse:
    """Unload the currently loaded model from the cluster.

    Delegates to the shared ``_unload_model_core`` helper, which rescans
    devices, asks the model manager to unload from every shard, and clears
    the cached topology on success.
    """
    return await _unload_model_core(self.cluster_manager, self.model_manager)
240204

241205
async def get_devices(self) -> JSONResponse:
242206
devices = await self.cluster_manager.discovery.async_get_properties()

src/dnet/api/load_helpers.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import os
2+
from dnet.utils.logger import logger
3+
from dnet.utils.model import get_model_config_json
4+
from distilp.profiler import profile_model
5+
from dnet.core.types.topology import TopologyInfo
6+
from .models import APILoadModelResponse, UnloadModelResponse
7+
8+
9+
async def get_api_callback_address(
10+
cluster_manager,
11+
grpc_port: int | str,
12+
) -> str:
13+
api_props = await cluster_manager.discovery.async_get_own_properties()
14+
grpc_port_int = int(grpc_port)
15+
api_callback_addr = (os.getenv("DNET_API_CALLBACK_ADDR") or "").strip()
16+
if not api_callback_addr:
17+
api_callback_addr = f"{api_props.local_ip}:{grpc_port_int}"
18+
if api_props.local_ip in ("127.0.0.1", "localhost"):
19+
logger.warning(
20+
"API callback address is loopback (%s). Remote shards will fail to SendToken. "
21+
"Set DNET_API_CALLBACK_ADDR to a reachable host:port.",
22+
api_callback_addr,
23+
)
24+
return api_callback_addr
25+
26+
27+
async def _prepare_topology_core(
    cluster_manager,
    model: str,
    kv_bits: str,
    seq_len: int,
    progress_callback=None,
) -> TopologyInfo:
    """Profile the cluster and model, then solve a layer-placement topology.

    Scans devices, collects per-device performance profiles, profiles the
    model itself, and asks the cluster manager for an optimal layer
    distribution. ``progress_callback``, when given, is awaited with a short
    status string before each expensive phase.

    Raises:
        RuntimeError: if no device profiles could be collected.
    """
    config = get_model_config_json(model)
    hidden_size = int(config["hidden_size"])
    layer_count = int(config["num_hidden_layers"])

    await cluster_manager.scan_devices()

    if progress_callback:
        await progress_callback("Profiling cluster performance")
    batch_sizes = [1]
    device_profiles = await cluster_manager.profile_cluster(
        model, hidden_size, 2, batch_sizes
    )
    if not device_profiles:
        raise RuntimeError("No profiles collected")

    if progress_callback:
        await progress_callback("Computing optimal layer distribution")
    split_profile = profile_model(
        repo_id=model,
        batch_sizes=batch_sizes,
        sequence_length=seq_len,
    )

    return await cluster_manager.solve_topology(
        device_profiles,
        split_profile.to_model_profile(),
        model,
        layer_count,
        kv_bits,
    )
61+
62+
63+
async def _load_model_core(
    cluster_manager,
    model_manager,
    inference_manager,
    topology: TopologyInfo,
) -> APILoadModelResponse:
    """Load a model onto the cluster according to *topology*.

    Sends the load request to the model manager with this node's own
    properties and callback address; on success, connects the inference ring
    starting at the topology's first shard.
    """
    own_props = await cluster_manager.discovery.async_get_own_properties()
    # NOTE(review): get_api_callback_address queries discovery again for the
    # same properties — consider threading own_props through to avoid the
    # duplicate call.
    callback_addr = await get_api_callback_address(
        cluster_manager, inference_manager.grpc_port
    )

    result = await model_manager.load_model(
        topology,
        own_props,
        int(inference_manager.grpc_port),
        api_callback_address=callback_addr,
    )
    if not result.success:
        return result

    # Ring entry point is the first device in the solved topology.
    head = topology.devices[0]
    await inference_manager.connect_to_ring(
        head.local_ip, head.shard_port, callback_addr
    )
    return result
87+
88+
89+
async def _unload_model_core(
    cluster_manager,
    model_manager,
) -> UnloadModelResponse:
    """Unload the active model from every shard in the cluster.

    Rescans devices first so the shard list is current, then clears the
    cached topology when the unload succeeds.
    """
    await cluster_manager.scan_devices()
    result = await model_manager.unload_model(cluster_manager.shards)
    if result.success:
        cluster_manager.current_topology = None
    return result

0 commit comments

Comments
 (0)