Skip to content

Commit 0c68b77

Browse files
gongshaotian authored and claude committed
[RL][Feature] R3 Phase 2: routing data follows KVCache block lifecycle (swap/storage/PD)
Implement dual-buffer architecture for routing replay:
- GPU transient buffer [max_num_batched_tokens, L, K] with Triton v2 kernel
- SharedMemory routing_host_buffer for cross-process Engine/Worker/CTM sharing
- Lazy SharedMemory attach in Worker and TokenProcessor (Engine creates after profiling)
- CTM routing write/read for swap and storage backends
- PD disaggregation: P gathers routing via send_first_token, D writes to host buffer
- Local store persistence verified end-to-end

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2efab46 commit 0c68b77

13 files changed

Lines changed: 730 additions & 28 deletions

File tree

fastdeploy/cache_manager/cache_data.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,35 @@
1414
# limitations under the License.
1515
"""
1616

17+
from dataclasses import dataclass
1718
from enum import Enum
19+
from typing import Any, Optional
1820

1921
from fastdeploy.utils import get_logger
2022

2123
logger = get_logger("prefix_cache_manager", "cache_manager.log")
2224

2325

26+
@dataclass
class AuxBlockDataSpec:
    """Specification for one kind of auxiliary data bound to KVCache blocks.

    CacheTransferManager walks its registered specs during swap and storage
    operations and moves the described data alongside the KVCache blocks
    themselves.
    """

    # Identifier of the aux data stream (e.g. "routing").
    name: str
    # Number of layers the data spans.
    num_layers: int
    # Elements stored per token (0 when not applicable).
    per_token_size: int = 0
    # Tokens per KVCache block (0 when not applicable).
    block_size: int = 0
    # Numpy dtype name of the stored elements.
    dtype: str = "uint8"
    # Optional CPU-side staging buffer used during swaps.
    swap_buffer: Optional[Any] = None
    # Whether this spec participates in transfers.
    enabled: bool = True

    def get_storage_key(self, key_prefix: str, block_hash: str, rank: int) -> str:
        """Return the storage-backend key for one block of this aux data."""
        tail = "_".join((str(block_hash), str(rank), self.name))
        return f"prefix{key_prefix}_{tail}"
44+
45+
2446
class CacheStatus(Enum):
2547
"""
2648
cache status enum class
@@ -56,6 +78,7 @@ def __init__(
5678
cache_status=CacheStatus.GPU,
5779
is_persistent=False,
5880
persistent_shared_count=0,
81+
aux_data_names=None,
5982
):
6083
"""
6184
Args:
@@ -89,6 +112,7 @@ def __init__(
89112
self.cache_status = cache_status
90113
self.is_persistent = is_persistent
91114
self.persistent_shared_count = persistent_shared_count
115+
self.aux_data_names = aux_data_names or []
92116
self.req_id_set = set()
93117

94118
def __lt__(self, other):
@@ -102,7 +126,7 @@ def __lt__(self, other):
102126
else:
103127
return self.depth > other.depth
104128

105-
def __str__(self):
129+
def __str__(self) -> str:
106130
"""
107131
return node info
108132
"""

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,15 @@ def parse_args():
129129
)
130130
parser.add_argument("--model_path", type=str, help="The path of model")
131131

132+
# Routing replay (R3) arguments
133+
parser.add_argument("--enable_routing_replay", type=int, default=0, help="Enable routing replay")
134+
parser.add_argument("--routing_num_moe_layers", type=int, default=0, help="Number of MoE layers for routing")
135+
parser.add_argument("--routing_moe_top_k", type=int, default=0, help="MoE top_k for routing")
136+
parser.add_argument("--routing_dtype", type=str, default="uint8", help="Routing data dtype")
137+
132138
args = parser.parse_args()
139+
# Convert int flag to bool
140+
args.enable_routing_replay = bool(args.enable_routing_replay)
133141
return args
134142

135143

@@ -241,6 +249,13 @@ def __init__(self, args):
241249
self._init_cpu_cache()
242250
if self.storage_backend_type is not None:
243251
self._init_storage(args)
252+
253+
# Initialize auxiliary data specs (e.g., routing replay)
254+
self.aux_data_specs = {}
255+
self.routing_host_view = None
256+
self.routing_swap_buffer = None
257+
self._init_routing_aux_data(args)
258+
244259
self._init_control()
245260

246261
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
@@ -307,6 +322,162 @@ def __init__(self, args):
307322
)
308323
self.cache_transfer_inited_signal.value[self.rank] = 1
309324

325+
def _init_routing_aux_data(self, args):
    """Register routing-replay (R3) auxiliary data buffers for swap/storage sync.

    Builds an ``AuxBlockDataSpec`` describing per-token routing indices,
    allocates a CPU-side ``RoutingSwapBuffer`` when CPU cache blocks exist,
    and attaches to the Engine-created ``RoutingHostBuffer`` SharedMemory
    segment. Best effort: any failure is logged and routing support is
    simply left disabled.

    Args:
        args: parsed CLI namespace; reads ``enable_routing_replay``,
            ``routing_num_moe_layers``, ``routing_moe_top_k``,
            ``routing_dtype`` and ``engine_worker_queue_port``.
    """
    if not getattr(args, "enable_routing_replay", False):
        return

    try:
        from fastdeploy.cache_manager.cache_data import AuxBlockDataSpec
        from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
            RoutingHostBufferView,
            RoutingSwapBuffer,
        )

        num_moe_layers = getattr(args, "routing_num_moe_layers", 0)
        moe_top_k = getattr(args, "routing_moe_top_k", 0)
        routing_dtype = getattr(args, "routing_dtype", "uint8")

        # Routing shape not configured -> nothing to register.
        if num_moe_layers == 0 or moe_top_k == 0:
            return

        spec = AuxBlockDataSpec(
            name="routing",
            num_layers=num_moe_layers,
            per_token_size=moe_top_k,
            block_size=self.block_size,
            dtype=routing_dtype,
        )

        # Single suffix keys both the swap buffer and the host SHM segment
        # (was previously computed twice from the same args attribute).
        dp_suffix = str(getattr(args, "engine_worker_queue_port", ""))

        # CPU-side staging buffer, only needed when CPU cache blocks exist.
        if self.num_cpu_blocks > 0:
            self.routing_swap_buffer = RoutingSwapBuffer(
                num_cpu_blocks=self.num_cpu_blocks,
                block_size=self.block_size,
                num_moe_layers=num_moe_layers,
                top_k=moe_top_k,
                dtype=routing_dtype,
                dp_suffix=dp_suffix,
            )
            spec.swap_buffer = self.routing_swap_buffer

        # Attach to the routing host buffer (SharedMemory created by Engine).
        shm_name = f"routing_host_buffer.{dp_suffix}"
        max_num_kv_tokens = self.num_gpu_blocks * self.block_size
        shape = (max_num_kv_tokens, num_moe_layers, moe_top_k)
        try:
            self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=routing_dtype, shm_name=shm_name)
            logger.info(f"[R3] CTM attached to RoutingHostBuffer: {shm_name}")
        except FileNotFoundError:
            # Engine may not have created the segment (yet); degrade gracefully.
            logger.warning(f"[R3] CTM RoutingHostBuffer {shm_name} not found")

        self.aux_data_specs["routing"] = spec
        logger.info(f"[R3] CTM registered routing aux data: layers={num_moe_layers}, top_k={moe_top_k}")

    except Exception as e:
        # Routing replay is optional; never take down the CTM over it.
        logger.warning(f"[R3] CTM failed to init routing aux data: {e}")
382+
383+
def _swap_routing(self, gpu_block_ids, cpu_block_ids, direction):
384+
"""
385+
Swap routing data between routing_host_buffer and routing_swap_buffer.
386+
Pure CPU-to-CPU numpy memcpy, no GPU DMA.
387+
"""
388+
if self.routing_host_view is None or self.routing_swap_buffer is None:
389+
return
390+
bs = self.block_size
391+
for gpu_bid, cpu_bid in zip(gpu_block_ids, cpu_block_ids):
392+
gpu_start = gpu_bid * bs
393+
gpu_end = gpu_start + bs
394+
cpu_start = cpu_bid * bs
395+
cpu_end = cpu_start + bs
396+
if direction == "to_cpu":
397+
self.routing_swap_buffer.buffer[cpu_start:cpu_end] = self.routing_host_view.buffer[gpu_start:gpu_end]
398+
else: # to_gpu
399+
self.routing_host_view.buffer[gpu_start:gpu_end] = self.routing_swap_buffer.buffer[cpu_start:cpu_end]
400+
401+
def _write_routing_to_storage(self, task_keys, gpu_block_ids):
402+
"""
403+
Write routing data from routing_host_buffer to storage backend.
404+
Only for mooncake/file backends; only tp_rank=0 writes routing.
405+
"""
406+
if self.routing_host_view is None or self.rank != 0:
407+
return
408+
if self.storage_backend_type not in ("mooncake", "file"):
409+
return
410+
411+
try:
412+
spec = self.aux_data_specs.get("routing")
413+
if spec is None or not spec.enabled:
414+
return
415+
416+
bs = self.block_size
417+
routing_keys = []
418+
routing_ptrs = []
419+
routing_sizes = []
420+
per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize
421+
422+
for block_hash, gpu_bid in zip(task_keys, gpu_block_ids):
423+
key = spec.get_storage_key(self.key_prefix, block_hash, self.rank)
424+
start = gpu_bid * bs
425+
end = start + bs
426+
block_data = self.routing_host_view.buffer[start:end]
427+
if not block_data.flags["C_CONTIGUOUS"]:
428+
block_data = np.ascontiguousarray(block_data)
429+
routing_keys.append(key)
430+
routing_ptrs.append(block_data.ctypes.data)
431+
routing_sizes.append(per_block_bytes)
432+
433+
if routing_keys:
434+
self.storage_backend.batch_set(
435+
keys=routing_keys, target_locations=routing_ptrs, target_sizes=routing_sizes
436+
)
437+
logger.debug(f"[R3] Wrote {len(routing_keys)} routing blocks to storage")
438+
except Exception as e:
439+
logger.warning(f"[R3] Failed to write routing to storage: {e}")
440+
441+
def _read_routing_from_storage(self, task_keys, gpu_block_ids):
    """
    Read routing data from storage backend into routing_host_buffer.
    Only for mooncake/file backends; only tp_rank=0 reads routing.

    Args:
        task_keys: block hashes, aligned one-to-one with gpu_block_ids.
        gpu_block_ids: GPU block ids whose token rows in the routing host
            buffer are overwritten with data fetched from storage.

    Best effort: any failure is logged as a warning and swallowed so a
    routing miss never breaks the KVCache read path.
    """
    if self.routing_host_view is None or self.rank != 0:
        return
    if self.storage_backend_type not in ("mooncake", "file"):
        return

    try:
        spec = self.aux_data_specs.get("routing")
        if spec is None or not spec.enabled:
            return

        bs = self.block_size
        # Bytes per KVCache block of routing data: tokens * layers * top_k elements.
        per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize

        for block_hash, gpu_bid in zip(task_keys, gpu_block_ids):
            key = spec.get_storage_key(self.key_prefix, block_hash, self.rank)
            start = gpu_bid * bs
            end = start + bs
            target_slice = self.routing_host_view.buffer[start:end]
            if not target_slice.flags["C_CONTIGUOUS"]:
                # Need contiguous target for ctypes pointer
                tmp = np.ascontiguousarray(target_slice)
                result = self.storage_backend.get(
                    key=key, target_location=tmp.ctypes.data, target_size=per_block_bytes
                )
                # NOTE(review): assumes get() returns a non-negative status
                # on success and None/negative on failure — confirm against
                # the storage backend's API. On failure the host buffer row
                # is left unchanged (stale).
                if result is not None and result >= 0:
                    self.routing_host_view.buffer[start:end] = tmp
            else:
                # Contiguous target: backend writes directly into the host
                # buffer. NOTE(review): return status is ignored here, so a
                # failed get may leave stale data in place — presumably
                # intentional best-effort; verify.
                self.storage_backend.get(
                    key=key, target_location=target_slice.ctypes.data, target_size=per_block_bytes
                )

        logger.debug(f"[R3] Read {len(task_keys)} routing blocks from storage")
    except Exception as e:
        logger.warning(f"[R3] Failed to read routing from storage: {e}")
480+
310481
def _init_control(self):
311482
dp_rank = self.local_data_parallel_id
312483
tp_rank = self.rank
@@ -809,6 +980,9 @@ def read_storage_task(self, task: ReadStorageTask):
809980
logger.info(
810981
f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
811982
)
983+
# Read routing data from storage for matched blocks
984+
matched_keys = task.keys[: len(valid_gpu_block_ids)]
985+
self._read_routing_from_storage(matched_keys, valid_gpu_block_ids)
812986
except Exception as e:
813987
logger.error(
814988
f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}"
@@ -1000,6 +1174,9 @@ def write_back_storage_task(self, task: WriteStorageTask):
10001174
logger.info(
10011175
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
10021176
)
1177+
# Write routing data to storage (shares dedup with KVCache)
1178+
remaining_keys = task.keys[match_block_num:]
1179+
self._write_routing_to_storage(remaining_keys, gpu_block_ids)
10031180
except Exception as e:
10041181
logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}")
10051182
gpu_block_ids = []
@@ -1375,6 +1552,10 @@ def _transfer_data(
13751552
0,
13761553
)
13771554

1555+
# Routing: routing_host_buffer → routing_swap_buffer
1556+
if "routing" in self.aux_data_specs:
1557+
self._swap_routing(gpu_block_ids, cpu_block_ids, "to_cpu")
1558+
13781559
elif event_type.value == CacheStatus.SWAP2GPU.value:
13791560
swap_cache_all_layers(
13801561
self.gpu_cache_k_tensors,
@@ -1413,6 +1594,11 @@ def _transfer_data(
14131594
self.device,
14141595
1,
14151596
)
1597+
1598+
# Routing: routing_swap_buffer → routing_host_buffer
1599+
if "routing" in self.aux_data_specs:
1600+
self._swap_routing(gpu_block_ids, cpu_block_ids, "to_gpu")
1601+
14161602
else:
14171603
logger.warning(
14181604
f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,25 @@ def launch_cache_manager(
293293
else:
294294
storage_arg_str = " "
295295

296+
# Compute routing replay args for CTM
297+
routing_arg_str = ""
298+
routing_replay_config = getattr(self.config, "routing_replay_config", None)
299+
if routing_replay_config is not None and routing_replay_config.enable_routing_replay:
300+
model_config = self.config.model_config
301+
num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index
302+
if model_config.architectures[0] == "Glm4MoeForCausalLM":
303+
moe_top_k = model_config.num_experts_per_tok
304+
else:
305+
moe_top_k = model_config.moe_k
306+
num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts
307+
routing_dtype = "uint8" if num_experts + 1 <= 255 else ("uint16" if num_experts + 1 <= 65535 else "uint32")
308+
routing_arg_str = (
309+
f" --enable_routing_replay 1"
310+
f" --routing_num_moe_layers {num_moe_layers}"
311+
f" --routing_moe_top_k {moe_top_k}"
312+
f" --routing_dtype {routing_dtype}"
313+
)
314+
296315
if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend:
297316
for i in range(tensor_parallel_size):
298317
launch_cmd = (
@@ -324,6 +343,7 @@ def launch_cache_manager(
324343
+ f" --write_policy {cache_config.write_policy}"
325344
+ f" --max_model_len {self.config.model_config.max_model_len}"
326345
+ f" --model_path {self.config.model_config.model}"
346+
+ routing_arg_str
327347
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
328348
)
329349
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")

fastdeploy/engine/common_engine.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2274,10 +2274,51 @@ def _stop_profile(self):
22742274
num_gpu_blocks = self.get_profile_block_num_signal.value[0]
22752275
self.cfg.cache_config.reset(num_gpu_blocks)
22762276
self.resource_manager.reset_cache_config(self.cfg.cache_config)
2277+
2278+
# Create RoutingHostBuffer (SharedMemory) after num_gpu_blocks is known
2279+
if self.cfg.routing_replay_config.enable_routing_replay:
2280+
self._init_routing_host_buffer(num_gpu_blocks)
2281+
22772282
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
22782283
device_ids = self.cfg.parallel_config.device_ids.split(",")
22792284
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
22802285

2286+
def _init_routing_host_buffer(self, num_gpu_blocks: int):
    """Allocate the RoutingHostBuffer SharedMemory segment.

    Called once profiling has fixed ``num_gpu_blocks``; workers and the
    cache transfer manager attach to the segment later by name.
    """
    from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
        RoutingHostBuffer,
        RoutingHostBufferView,
    )

    mcfg = self.cfg.model_config
    moe_layer_count = mcfg.num_hidden_layers - mcfg.moe_layer_start_index
    # GLM-4 MoE exposes top_k under a different config attribute name.
    if mcfg.architectures[0] == "Glm4MoeForCausalLM":
        top_k = mcfg.num_experts_per_tok
    else:
        top_k = mcfg.moe_k

    # Smallest unsigned dtype able to hold every expert id plus one sentinel.
    expert_id_count = mcfg.moe_num_experts + mcfg.moe_num_shared_experts + 1
    if expert_id_count <= 255:
        dtype = "uint8"
    elif expert_id_count <= 65535:
        dtype = "uint16"
    else:
        dtype = "uint32"

    dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port)
    block_tokens = self.cfg.cache_config.block_size
    self.routing_host_buffer = RoutingHostBuffer(
        num_gpu_blocks=num_gpu_blocks,
        block_size=block_tokens,
        num_moe_layers=moe_layer_count,
        top_k=top_k,
        dtype=dtype,
        dp_suffix=dp_suffix,
    )

    # Hand a view to the resource manager for PD disaggregation (D side),
    # when it declares support for one.
    if hasattr(self, "resource_manager") and hasattr(self.resource_manager, "routing_host_view"):
        shm_name = f"routing_host_buffer.{dp_suffix}"
        token_capacity = num_gpu_blocks * block_tokens
        self.resource_manager.routing_host_view = RoutingHostBufferView(
            shape=(token_capacity, moe_layer_count, top_k),
            dtype=dtype,
            shm_name=shm_name,
        )
2321+
22812322
def check_health(self, time_interval_threashold=30):
22822323
"""
22832324
Check the health of the model server by checking whether all workers are alive.

fastdeploy/engine/engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,11 @@ def _stop_profile(self):
725725
num_gpu_blocks = self.get_profile_block_num_signal.value[0]
726726
self.cfg.cache_config.reset(num_gpu_blocks)
727727
self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
728+
729+
# Create RoutingHostBuffer (SharedMemory) before starting cache service
730+
if self.cfg.routing_replay_config.enable_routing_replay:
731+
self.engine._init_routing_host_buffer(num_gpu_blocks)
732+
728733
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
729734
if not current_platform.is_intel_hpu():
730735
device_ids = self.cfg.parallel_config.device_ids.split(",")

0 commit comments

Comments (0)