11from typing import Optional , Any , List
22import asyncio
3- import os
43from hypercorn import Config
54from hypercorn .utils import LifespanFailureError
65import hypercorn .asyncio as aio_hypercorn
2625from .model_manager import ModelManager
2726from dnet_p2p import DnetDeviceProperties
2827from .mcp_handler import create_mcp_server
28+ from .load_helpers import (
29+ _prepare_topology_core ,
30+ _load_model_core ,
31+ _unload_model_core ,
32+ )
2933
3034
3135class HTTPServer :
@@ -45,19 +49,16 @@ def __init__(
4549 self .http_server : Optional [asyncio .Task ] = None
4650
4751 # Create MCP server first to get lifespan
48- mcp = create_mcp_server (
49- inference_manager , model_manager , cluster_manager
50- )
52+ mcp = create_mcp_server (inference_manager , model_manager , cluster_manager )
5153 # Use path='/' since we're mounting at /mcp, so final path will be /mcp/
5254 mcp_app = mcp .http_app (path = "/" )
53-
55+
5456 # Create FastAPI app with MCP lifespan
5557 self .app = FastAPI (lifespan = mcp_app .lifespan )
56-
58+
5759 # Mount MCP server as ASGI app
5860 self .app .mount ("/mcp" , mcp_app )
5961
60-
6162 async def start (self , shutdown_trigger : Any = lambda : asyncio .Future ()) -> None :
6263 await self ._setup_routes ()
6364
@@ -166,59 +167,27 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse:
166167 ),
167168 )
168169
169- model_config = get_model_config_json (req .model )
170- embedding_size = int (model_config ["hidden_size" ])
171- num_layers = int (model_config ["num_hidden_layers" ])
172-
173- await self .cluster_manager .scan_devices ()
174- batch_sizes = [1 ]
175- profiles = await self .cluster_manager .profile_cluster (
176- req .model , embedding_size , 2 , batch_sizes
177- )
178- if not profiles :
179- return APILoadModelResponse (
180- model = req .model ,
181- success = False ,
182- shard_statuses = [],
183- message = "No profiles collected" ,
170+ try :
171+ topology = await _prepare_topology_core (
172+ self .cluster_manager , req .model , req .kv_bits , req .seq_len
184173 )
185-
186- model_profile_split = profile_model (
187- repo_id = req .model ,
188- batch_sizes = batch_sizes ,
189- sequence_length = req .seq_len ,
190- )
191- model_profile = model_profile_split .to_model_profile ()
192- topology = await self .cluster_manager .solve_topology (
193- profiles , model_profile , req .model , num_layers , req .kv_bits
194- )
174+ except RuntimeError as e :
175+ if "No profiles collected" in str (e ):
176+ return APILoadModelResponse (
177+ model = req .model ,
178+ success = False ,
179+ shard_statuses = [],
180+ message = "No profiles collected" ,
181+ )
182+ raise
195183 self .cluster_manager .current_topology = topology
196184
197- api_props = await self .cluster_manager .discovery .async_get_own_properties ()
198- grpc_port = int (self .inference_manager .grpc_port )
199-
200- # Callback address shards should use for SendToken.
201- # In static discovery / cloud setups, discovery may report 127.0.0.1 which is not usable.
202- api_callback_addr = (os .getenv ("DNET_API_CALLBACK_ADDR" ) or "" ).strip ()
203- if not api_callback_addr :
204- api_callback_addr = f"{ api_props .local_ip } :{ grpc_port } "
205- if api_props .local_ip in ("127.0.0.1" , "localhost" ):
206- logger .warning (
207- "API callback address is loopback (%s). Remote shards will fail to SendToken. "
208- "Set DNET_API_CALLBACK_ADDR to a reachable host:port." ,
209- api_callback_addr ,
210- )
211- response = await self .model_manager .load_model (
185+ response = await _load_model_core (
186+ self .cluster_manager ,
187+ self .model_manager ,
188+ self .inference_manager ,
212189 topology ,
213- api_props ,
214- self .inference_manager .grpc_port ,
215- api_callback_address = api_callback_addr ,
216190 )
217- if response .success :
218- first_shard = topology .devices [0 ]
219- await self .inference_manager .connect_to_ring (
220- first_shard .local_ip , first_shard .shard_port , api_callback_addr
221- )
222191 return response
223192
224193 except Exception as e :
@@ -231,12 +200,7 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse:
231200 )
232201
233202 async def unload_model (self ) -> UnloadModelResponse :
234- await self .cluster_manager .scan_devices ()
235- shards = self .cluster_manager .shards
236- response = await self .model_manager .unload_model (shards )
237- if response .success :
238- self .cluster_manager .current_topology = None
239- return response
203+ return await _unload_model_core (self .cluster_manager , self .model_manager )
240204
241205 async def get_devices (self ) -> JSONResponse :
242206 devices = await self .cluster_manager .discovery .async_get_properties ()
0 commit comments