@@ -129,7 +129,15 @@ def parse_args():
129129 )
130130 parser .add_argument ("--model_path" , type = str , help = "The path of model" )
131131
132+ # Routing replay (R3) arguments
133+ parser .add_argument ("--enable_routing_replay" , type = int , default = 0 , help = "Enable routing replay" )
134+ parser .add_argument ("--routing_num_moe_layers" , type = int , default = 0 , help = "Number of MoE layers for routing" )
135+ parser .add_argument ("--routing_moe_top_k" , type = int , default = 0 , help = "MoE top_k for routing" )
136+ parser .add_argument ("--routing_dtype" , type = str , default = "uint8" , help = "Routing data dtype" )
137+
132138 args = parser .parse_args ()
139+ # Convert int flag to bool
140+ args .enable_routing_replay = bool (args .enable_routing_replay )
133141 return args
134142
135143
@@ -241,6 +249,13 @@ def __init__(self, args):
241249 self ._init_cpu_cache ()
242250 if self .storage_backend_type is not None :
243251 self ._init_storage (args )
252+
253+ # Initialize auxiliary data specs (e.g., routing replay)
254+ self .aux_data_specs = {}
255+ self .routing_host_view = None
256+ self .routing_swap_buffer = None
257+ self ._init_routing_aux_data (args )
258+
244259 self ._init_control ()
245260
246261 cache_task_broadcast_data = np .zeros (shape = [1 ], dtype = np .int32 )
@@ -307,6 +322,162 @@ def __init__(self, args):
307322 )
308323 self .cache_transfer_inited_signal .value [self .rank ] = 1
309324
def _init_routing_aux_data(self, args):
    """Initialize routing-replay (R3) auxiliary data buffers for swap sync.

    Best-effort: returns silently when routing replay is disabled or not
    configured (zero layers / top_k), and logs a warning on any failure so
    cache transfer can proceed without routing replay.

    Args:
        args: parsed CLI namespace; reads ``enable_routing_replay``,
            ``routing_num_moe_layers``, ``routing_moe_top_k``,
            ``routing_dtype`` and ``engine_worker_queue_port``.
    """
    if not getattr(args, "enable_routing_replay", False):
        return

    try:
        # Imported lazily so deployments without R3 do not pay the import cost.
        from fastdeploy.cache_manager.cache_data import AuxBlockDataSpec
        from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
            RoutingHostBufferView,
            RoutingSwapBuffer,
        )

        num_moe_layers = getattr(args, "routing_num_moe_layers", 0)
        moe_top_k = getattr(args, "routing_moe_top_k", 0)
        routing_dtype = getattr(args, "routing_dtype", "uint8")

        if num_moe_layers == 0 or moe_top_k == 0:
            return

        spec = AuxBlockDataSpec(
            name="routing",
            num_layers=num_moe_layers,
            per_token_size=moe_top_k,
            block_size=self.block_size,
            dtype=routing_dtype,
        )

        # The queue port identifies this DP group's shared buffers; compute
        # the suffix once (it was previously derived twice from the same attr).
        dp_suffix = str(getattr(args, "engine_worker_queue_port", ""))

        # Create routing swap buffer (CPU-side mirror for swapped-out blocks)
        if self.num_cpu_blocks > 0:
            self.routing_swap_buffer = RoutingSwapBuffer(
                num_cpu_blocks=self.num_cpu_blocks,
                block_size=self.block_size,
                num_moe_layers=num_moe_layers,
                top_k=moe_top_k,
                dtype=routing_dtype,
                dp_suffix=dp_suffix,
            )
            spec.swap_buffer = self.routing_swap_buffer

        # Attach to routing host buffer (SharedMemory created by Engine)
        shm_name = f"routing_host_buffer.{dp_suffix}"
        max_num_kv_tokens = self.num_gpu_blocks * self.block_size
        shape = (max_num_kv_tokens, num_moe_layers, moe_top_k)
        try:
            self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=routing_dtype, shm_name=shm_name)
            logger.info(f"[R3] CTM attached to RoutingHostBuffer: {shm_name}")
        except FileNotFoundError:
            # Engine may not have created the segment; routing swaps will be
            # skipped until routing_host_view is available.
            logger.warning(f"[R3] CTM RoutingHostBuffer {shm_name} not found")

        self.aux_data_specs["routing"] = spec
        logger.info(f"[R3] CTM registered routing aux data: layers={num_moe_layers}, top_k={moe_top_k}")

    except Exception as e:
        # Routing replay is optional; never let its setup break cache transfer.
        logger.warning(f"[R3] CTM failed to init routing aux data: {e}")
382+
def _swap_routing(self, gpu_block_ids, cpu_block_ids, direction):
    """Copy routing data between routing_host_buffer and routing_swap_buffer.

    A pure host-side numpy memcpy per block pair — no GPU DMA is involved.
    ``direction == "to_cpu"`` copies host buffer -> swap buffer; any other
    value copies swap buffer -> host buffer.
    """
    host_view = self.routing_host_view
    swap_buf = self.routing_swap_buffer
    if host_view is None or swap_buf is None:
        # Routing replay not initialized in this process; nothing to move.
        return

    tokens_per_block = self.block_size
    copy_to_cpu = direction == "to_cpu"
    for gpu_block, cpu_block in zip(gpu_block_ids, cpu_block_ids):
        gpu_rows = slice(gpu_block * tokens_per_block, (gpu_block + 1) * tokens_per_block)
        cpu_rows = slice(cpu_block * tokens_per_block, (cpu_block + 1) * tokens_per_block)
        if copy_to_cpu:
            swap_buf.buffer[cpu_rows] = host_view.buffer[gpu_rows]
        else:  # to_gpu
            host_view.buffer[gpu_rows] = swap_buf.buffer[cpu_rows]
400+
def _write_routing_to_storage(self, task_keys, gpu_block_ids):
    """Write routing data from routing_host_buffer to the storage backend.

    Only for mooncake/file backends, and only tp_rank=0 writes routing.
    Best-effort: failures are logged, never raised.

    Args:
        task_keys: per-block hash keys, aligned with ``gpu_block_ids``.
        gpu_block_ids: GPU block ids whose routing rows should be persisted.
    """
    if self.routing_host_view is None or self.rank != 0:
        return
    if self.storage_backend_type not in ("mooncake", "file"):
        return

    try:
        spec = self.aux_data_specs.get("routing")
        if spec is None or not spec.enabled:
            return

        bs = self.block_size
        per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize

        routing_keys = []
        routing_ptrs = []
        routing_sizes = []
        # BUGFIX: keep each (possibly copied) block array alive until
        # batch_set returns. Appending only ``ctypes.data`` would let a
        # temporary contiguous copy be garbage-collected on the next loop
        # iteration, handing the backend a dangling pointer.
        keepalive = []

        for block_hash, gpu_bid in zip(task_keys, gpu_block_ids):
            key = spec.get_storage_key(self.key_prefix, block_hash, self.rank)
            start = gpu_bid * bs
            end = start + bs
            block_data = self.routing_host_view.buffer[start:end]
            if not block_data.flags["C_CONTIGUOUS"]:
                # ctypes pointer requires a contiguous layout.
                block_data = np.ascontiguousarray(block_data)
            keepalive.append(block_data)
            routing_keys.append(key)
            routing_ptrs.append(block_data.ctypes.data)
            routing_sizes.append(per_block_bytes)

        if routing_keys:
            self.storage_backend.batch_set(
                keys=routing_keys, target_locations=routing_ptrs, target_sizes=routing_sizes
            )
            logger.debug(f"[R3] Wrote {len(routing_keys)} routing blocks to storage")
    except Exception as e:
        logger.warning(f"[R3] Failed to write routing to storage: {e}")
440+
def _read_routing_from_storage(self, task_keys, gpu_block_ids):
    """Read routing data from the storage backend into routing_host_buffer.

    Only for mooncake/file backends, and only tp_rank=0 reads routing.
    Best-effort: failures are logged, never raised.
    """
    if self.routing_host_view is None or self.rank != 0:
        return
    if self.storage_backend_type not in ("mooncake", "file"):
        return

    try:
        spec = self.aux_data_specs.get("routing")
        if spec is None or not spec.enabled:
            return

        tokens_per_block = self.block_size
        per_block_bytes = (
            tokens_per_block * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize
        )

        for block_hash, gpu_bid in zip(task_keys, gpu_block_ids):
            storage_key = spec.get_storage_key(self.key_prefix, block_hash, self.rank)
            row_start = gpu_bid * tokens_per_block
            row_end = row_start + tokens_per_block
            target_slice = self.routing_host_view.buffer[row_start:row_end]
            if target_slice.flags["C_CONTIGUOUS"]:
                # Backend can write straight into the host buffer.
                self.storage_backend.get(
                    key=storage_key, target_location=target_slice.ctypes.data, target_size=per_block_bytes
                )
            else:
                # ctypes needs a contiguous target; stage through a copy and
                # commit only on a successful read.
                staging = np.ascontiguousarray(target_slice)
                result = self.storage_backend.get(
                    key=storage_key, target_location=staging.ctypes.data, target_size=per_block_bytes
                )
                if result is not None and result >= 0:
                    self.routing_host_view.buffer[row_start:row_end] = staging

        logger.debug(f"[R3] Read {len(task_keys)} routing blocks from storage")
    except Exception as e:
        logger.warning(f"[R3] Failed to read routing from storage: {e}")
480+
310481 def _init_control (self ):
311482 dp_rank = self .local_data_parallel_id
312483 tp_rank = self .rank
@@ -809,6 +980,9 @@ def read_storage_task(self, task: ReadStorageTask):
809980 logger .info (
810981 f"Successfully read { len (valid_gpu_block_ids )} blocks from cache storage for task { task .task_id } "
811982 )
983+ # Read routing data from storage for matched blocks
984+ matched_keys = task .keys [: len (valid_gpu_block_ids )]
985+ self ._read_routing_from_storage (matched_keys , valid_gpu_block_ids )
812986 except Exception as e :
813987 logger .error (
814988 f"Failed to read cache for task { task .task_id } , error: { e } , traceback: { traceback .format_exc ()} "
@@ -1000,6 +1174,9 @@ def write_back_storage_task(self, task: WriteStorageTask):
10001174 logger .info (
10011175 f"Successfully wrote { write_block_num } blocks to cache storage for task { task .task_id } "
10021176 )
1177+ # Write routing data to storage (shares dedup with KVCache)
1178+ remaining_keys = task .keys [match_block_num :]
1179+ self ._write_routing_to_storage (remaining_keys , gpu_block_ids )
10031180 except Exception as e :
10041181 logger .error (f"Error in write back storage task: { e } , traceback:{ traceback .format_exc ()} " )
10051182 gpu_block_ids = []
@@ -1375,6 +1552,10 @@ def _transfer_data(
13751552 0 ,
13761553 )
13771554
1555+ # Routing: routing_host_buffer → routing_swap_buffer
1556+ if "routing" in self .aux_data_specs :
1557+ self ._swap_routing (gpu_block_ids , cpu_block_ids , "to_cpu" )
1558+
13781559 elif event_type .value == CacheStatus .SWAP2GPU .value :
13791560 swap_cache_all_layers (
13801561 self .gpu_cache_k_tensors ,
@@ -1413,6 +1594,11 @@ def _transfer_data(
14131594 self .device ,
14141595 1 ,
14151596 )
1597+
1598+ # Routing: routing_swap_buffer → routing_host_buffer
1599+ if "routing" in self .aux_data_specs :
1600+ self ._swap_routing (gpu_block_ids , cpu_block_ids , "to_gpu" )
1601+
14161602 else :
14171603 logger .warning (
14181604 f"transfer data: Get unexpected event type { event_type } , only SWAP2CPU and SWAP2GPU supported"
0 commit comments