From 0c68b77f9fb25e15b47acaddf58f8cfb0fb1bc75 Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Tue, 31 Mar 2026 17:35:42 +0800 Subject: [PATCH] [RL][Feature] R3 Phase 2: routing data follows KVCache block lifecycle (swap/storage/PD) Implement dual-buffer architecture for routing replay: - GPU transient buffer [max_num_batched_tokens, L, K] with Triton v2 kernel - SharedMemory routing_host_buffer for cross-process Engine/Worker/CTM sharing - Lazy SharedMemory attach in Worker and TokenProcessor (Engine creates after profiling) - CTM routing write/read for swap and storage backends - PD disaggregation: P gathers routing via send_first_token, D writes to host buffer - Local store persistence verified end-to-end Co-Authored-By: Claude Opus 4.6 --- fastdeploy/cache_manager/cache_data.py | 26 +- .../cache_manager/cache_transfer_manager.py | 186 ++++++++++++ .../cache_manager/prefix_cache_manager.py | 20 ++ fastdeploy/engine/common_engine.py | 41 +++ fastdeploy/engine/engine.py | 5 + fastdeploy/engine/request.py | 11 +- .../engine/sched/resource_manager_v1.py | 31 ++ fastdeploy/model_executor/forward_meta.py | 2 + fastdeploy/model_executor/layers/moe/moe.py | 13 +- .../layers/moe/routing_indices_cache.py | 272 +++++++++++++++++- .../model_executor/pre_and_post_process.py | 58 +++- fastdeploy/output/token_processor.py | 90 ++++++ fastdeploy/worker/gpu_model_runner.py | 3 + 13 files changed, 730 insertions(+), 28 deletions(-) diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 82911eccfa3..9fd48cec2ce 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -14,13 +14,35 @@ # limitations under the License. """ +from dataclasses import dataclass from enum import Enum +from typing import Any, Optional from fastdeploy.utils import get_logger logger = get_logger("prefix_cache_manager", "cache_manager.log") +@dataclass +class AuxBlockDataSpec: + """ + Describes a type of auxiliary data bound to KVCache blocks. + CacheTransferManager iterates registered specs during swap/storage + to perform corresponding data transfers. 
+ """ + + name: str + num_layers: int + per_token_size: int = 0 + block_size: int = 0 + dtype: str = "uint8" + swap_buffer: Optional[Any] = None + enabled: bool = True + + def get_storage_key(self, key_prefix: str, block_hash: str, rank: int) -> str: + return f"prefix{key_prefix}_{block_hash}_{rank}_{self.name}" + + class CacheStatus(Enum): """ cache status enum class @@ -56,6 +78,7 @@ def __init__( cache_status=CacheStatus.GPU, is_persistent=False, persistent_shared_count=0, + aux_data_names=None, ): """ Args: @@ -89,6 +112,7 @@ def __init__( self.cache_status = cache_status self.is_persistent = is_persistent self.persistent_shared_count = persistent_shared_count + self.aux_data_names = aux_data_names or [] self.req_id_set = set() def __lt__(self, other): @@ -102,7 +126,7 @@ def __lt__(self, other): else: return self.depth > other.depth - def __str__(self): + def __str__(self) -> str: """ return node info """ diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index b264f03b753..d9ebc7340a2 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -129,7 +129,15 @@ def parse_args(): ) parser.add_argument("--model_path", type=str, help="The path of model") + # Routing replay (R3) arguments + parser.add_argument("--enable_routing_replay", type=int, default=0, help="Enable routing replay") + parser.add_argument("--routing_num_moe_layers", type=int, default=0, help="Number of MoE layers for routing") + parser.add_argument("--routing_moe_top_k", type=int, default=0, help="MoE top_k for routing") + parser.add_argument("--routing_dtype", type=str, default="uint8", help="Routing data dtype") + args = parser.parse_args() + # Convert int flag to bool + args.enable_routing_replay = bool(args.enable_routing_replay) return args @@ -241,6 +249,13 @@ def __init__(self, args): self._init_cpu_cache() if self.storage_backend_type is not None: self._init_storage(args) + + # Initialize auxiliary data specs (e.g., routing replay) + self.aux_data_specs = {} + self.routing_host_view = None + self.routing_swap_buffer = None + self._init_routing_aux_data(args) + self._init_control() cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) @@ -307,6 +322,162 @@ def __init__(self, args): ) self.cache_transfer_inited_signal.value[self.rank] = 1 + def _init_routing_aux_data(self, args): + """Initialize routing auxiliary data buffers for swap sync.""" + enable_routing_replay = getattr(args, "enable_routing_replay", False) + if not enable_routing_replay: + return + + try: + from fastdeploy.cache_manager.cache_data import AuxBlockDataSpec + from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( + RoutingHostBufferView, + RoutingSwapBuffer, + ) + + num_moe_layers = getattr(args, "routing_num_moe_layers", 0) + moe_top_k = getattr(args, "routing_moe_top_k", 0) + routing_dtype = getattr(args, "routing_dtype", "uint8") + + if num_moe_layers == 0 or moe_top_k == 0: + return + + spec = AuxBlockDataSpec( + name="routing", + num_layers=num_moe_layers, + per_token_size=moe_top_k, + block_size=self.block_size, + dtype=routing_dtype, + ) + + # Create routing swap buffer (for CPU blocks) + if self.num_cpu_blocks > 0: + dp_suffix = str(getattr(args, "engine_worker_queue_port", "")) + self.routing_swap_buffer = RoutingSwapBuffer( + num_cpu_blocks=self.num_cpu_blocks, + block_size=self.block_size, + num_moe_layers=num_moe_layers, + top_k=moe_top_k, + dtype=routing_dtype, + 
dp_suffix=dp_suffix, + ) + spec.swap_buffer = self.routing_swap_buffer + + # Attach to routing host buffer (SharedMemory created by Engine) + dp_suffix = str(getattr(args, "engine_worker_queue_port", "")) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = self.num_gpu_blocks * self.block_size + shape = (max_num_kv_tokens, num_moe_layers, moe_top_k) + try: + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=routing_dtype, shm_name=shm_name) + logger.info(f"[R3] CTM attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + logger.warning(f"[R3] CTM RoutingHostBuffer {shm_name} not found") + + self.aux_data_specs["routing"] = spec + logger.info(f"[R3] CTM registered routing aux data: layers={num_moe_layers}, top_k={moe_top_k}") + + except Exception as e: + logger.warning(f"[R3] CTM failed to init routing aux data: {e}") + + def _swap_routing(self, gpu_block_ids, cpu_block_ids, direction): + """ + Swap routing data between routing_host_buffer and routing_swap_buffer. + Pure CPU-to-CPU numpy memcpy, no GPU DMA. + """ + if self.routing_host_view is None or self.routing_swap_buffer is None: + return + bs = self.block_size + for gpu_bid, cpu_bid in zip(gpu_block_ids, cpu_block_ids): + gpu_start = gpu_bid * bs + gpu_end = gpu_start + bs + cpu_start = cpu_bid * bs + cpu_end = cpu_start + bs + if direction == "to_cpu": + self.routing_swap_buffer.buffer[cpu_start:cpu_end] = self.routing_host_view.buffer[gpu_start:gpu_end] + else: # to_gpu + self.routing_host_view.buffer[gpu_start:gpu_end] = self.routing_swap_buffer.buffer[cpu_start:cpu_end] + + def _write_routing_to_storage(self, task_keys, gpu_block_ids): + """ + Write routing data from routing_host_buffer to storage backend. + Only for mooncake/file backends; only tp_rank=0 writes routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + routing_keys = [] + routing_ptrs = [] + routing_sizes = [] + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + block_data = self.routing_host_view.buffer[start:end] + if not block_data.flags["C_CONTIGUOUS"]: + block_data = np.ascontiguousarray(block_data) + routing_keys.append(key) + routing_ptrs.append(block_data.ctypes.data) + routing_sizes.append(per_block_bytes) + + if routing_keys: + self.storage_backend.batch_set( + keys=routing_keys, target_locations=routing_ptrs, target_sizes=routing_sizes + ) + logger.debug(f"[R3] Wrote {len(routing_keys)} routing blocks to storage") + except Exception as e: + logger.warning(f"[R3] Failed to write routing to storage: {e}") + + def _read_routing_from_storage(self, task_keys, gpu_block_ids): + """ + Read routing data from storage backend into routing_host_buffer. + Only for mooncake/file backends; only tp_rank=0 reads routing. 
+ """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + target_slice = self.routing_host_view.buffer[start:end] + if not target_slice.flags["C_CONTIGUOUS"]: + # Need contiguous target for ctypes pointer + tmp = np.ascontiguousarray(target_slice) + result = self.storage_backend.get( + key=key, target_location=tmp.ctypes.data, target_size=per_block_bytes + ) + if result is not None and result >= 0: + self.routing_host_view.buffer[start:end] = tmp + else: + self.storage_backend.get( + key=key, target_location=target_slice.ctypes.data, target_size=per_block_bytes + ) + + logger.debug(f"[R3] Read {len(task_keys)} routing blocks from storage") + except Exception as e: + logger.warning(f"[R3] Failed to read routing from storage: {e}") + def _init_control(self): dp_rank = self.local_data_parallel_id tp_rank = self.rank @@ -809,6 +980,9 @@ def read_storage_task(self, task: ReadStorageTask): logger.info( f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" ) + # Read routing data from storage for matched blocks + matched_keys = task.keys[: len(valid_gpu_block_ids)] + self._read_routing_from_storage(matched_keys, valid_gpu_block_ids) except Exception as e: logger.error( f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}" @@ -1000,6 +1174,9 @@ def write_back_storage_task(self, task: WriteStorageTask): logger.info( f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) + # Write routing data to storage (shares dedup with KVCache) + remaining_keys = task.keys[match_block_num:] + self._write_routing_to_storage(remaining_keys, gpu_block_ids) except Exception as e: logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") gpu_block_ids = [] @@ -1375,6 +1552,10 @@ def _transfer_data( 0, ) + # Routing: routing_host_buffer → routing_swap_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_cpu") + elif event_type.value == CacheStatus.SWAP2GPU.value: swap_cache_all_layers( self.gpu_cache_k_tensors, @@ -1413,6 +1594,11 @@ def _transfer_data( self.device, 1, ) + + # Routing: routing_swap_buffer → routing_host_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_gpu") + else: logger.warning( f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported" diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 0f2e7869971..502f19f5b3b 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -293,6 +293,25 @@ def launch_cache_manager( else: storage_arg_str = " " + # Compute routing replay args for CTM + routing_arg_str = "" + routing_replay_config = getattr(self.config, "routing_replay_config", None) + if routing_replay_config is not None and routing_replay_config.enable_routing_replay: + model_config = self.config.model_config + 
num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index + if model_config.architectures[0] == "Glm4MoeForCausalLM": + moe_top_k = model_config.num_experts_per_tok + else: + moe_top_k = model_config.moe_k + num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts + routing_dtype = "uint8" if num_experts + 1 <= 255 else ("uint16" if num_experts + 1 <= 65535 else "uint32") + routing_arg_str = ( + f" --enable_routing_replay 1" + f" --routing_num_moe_layers {num_moe_layers}" + f" --routing_moe_top_k {moe_top_k}" + f" --routing_dtype {routing_dtype}" + ) + if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend: for i in range(tensor_parallel_size): launch_cmd = ( @@ -324,6 +343,7 @@ def launch_cache_manager( + f" --write_policy {cache_config.write_policy}" + f" --max_model_len {self.config.model_config.max_model_len}" + f" --model_path {self.config.model_config.model}" + + routing_arg_str + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1" ) logger.info(f"Launch cache transfer manager, command:{launch_cmd}") diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 1c6452408ba..a6b67a9b41a 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2274,10 +2274,51 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingHostBuffer (SharedMemory) after num_gpu_blocks is known + if self.cfg.routing_replay_config.enable_routing_replay: + self._init_routing_host_buffer(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + def _init_routing_host_buffer(self, num_gpu_blocks: int): + """Create RoutingHostBuffer SharedMemory after profiling determines num_gpu_blocks.""" + from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( + RoutingHostBuffer, + RoutingHostBufferView, + ) + + model_config = self.cfg.model_config + num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index + if model_config.architectures[0] == "Glm4MoeForCausalLM": + moe_top_k = model_config.num_experts_per_tok + else: + moe_top_k = model_config.moe_k + + num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts + dtype = "uint8" if num_experts + 1 <= 255 else ("uint16" if num_experts + 1 <= 65535 else "uint32") + + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + self.routing_host_buffer = RoutingHostBuffer( + num_gpu_blocks=num_gpu_blocks, + block_size=self.cfg.cache_config.block_size, + num_moe_layers=num_moe_layers, + top_k=moe_top_k, + dtype=dtype, + dp_suffix=dp_suffix, + ) + + # Set routing_host_view on resource_manager for PD disaggregation (D side) + if hasattr(self, "resource_manager") and hasattr(self.resource_manager, "routing_host_view"): + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = num_gpu_blocks * self.cfg.cache_config.block_size + shape = (max_num_kv_tokens, num_moe_layers, moe_top_k) + self.resource_manager.routing_host_view = RoutingHostBufferView( + shape=shape, dtype=dtype, shm_name=shm_name + ) + def check_health(self, time_interval_threashold=30): """ 
Check the health of the model server by checking whether all workers are alive. diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 41461710aa3..22e5d501c03 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -725,6 +725,11 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingHostBuffer (SharedMemory) before starting cache service + if self.cfg.routing_replay_config.enable_routing_replay: + self.engine._init_routing_host_buffer(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): device_ids = self.cfg.parallel_config.device_ids.split(",") diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 597bd23eea8..8c7b0110344 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -1012,6 +1012,7 @@ def __init__( self.ic_req_data = ic_req_data self.prompt_token_ids_len = prompt_token_ids_len self.trace_carrier = trace_carrier + self.routing_data = None # Optional[np.ndarray], [seq_len, num_moe_layers, top_k] if prompt_token_ids is None: self.prompt_token_ids = [] @@ -1107,12 +1108,15 @@ def from_dict(cls, d: dict): d.pop("metrics", None) metrics = None trace_carrier = d.pop("trace_carrier", {}) - return RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + routing_data = d.pop("routing_data", None) + obj = RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + obj.routing_data = routing_data + return obj def to_dict(self): """convert RequestOutput into a serializable dict""" - return { + d = { "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, @@ -1130,6 +1134,9 @@ def to_dict(self): "prompt_token_ids_len": self.prompt_token_ids_len, "trace_carrier": self.trace_carrier, } + if self.routing_data is not None: + d["routing_data"] = self.routing_data + return d def get(self, key: str, default_value=None): if hasattr(self, key): diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 6cc9363bf6b..dbcf54175cd 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -208,6 +208,7 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) self.processor_cache = None + self.routing_host_view = None # Set by Engine after RoutingHostBuffer creation if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0: max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) @@ -1355,8 +1356,38 @@ def add_prefilled_request(self, request_output: RequestOutput): request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time request.metrics = request_output.metrics + + # [R3] Write P's prefill routing data into D's routing_host_buffer + if ( + self.routing_host_view is not None + and hasattr(request_output, "routing_data") + and request_output.routing_data is not 
None + ): + try: + self._write_prefill_routing_to_host_buffer(request, request_output.routing_data) + except Exception as e: + llm_logger.warning(f"[R3] Failed to write prefill routing for {request_output.request_id}: {e}") + self.running.append(request) + def _write_prefill_routing_to_host_buffer(self, request, routing_data): + """ + Write P's prefill routing data into D's routing_host_buffer. + Uses D's block_tables to compute slot_mapping. + """ + import math + + seq_len = routing_data.shape[0] + block_size = self.config.cache_config.block_size + num_blocks_needed = math.ceil(seq_len / block_size) + block_ids = request.block_tables[:num_blocks_needed] + + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + self.routing_host_view.scatter(slot_mapping, routing_data) + def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching: self.cache_manager.release_block_ids(request) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index d625326c06f..4d1119d33c1 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -148,6 +148,8 @@ class ForwardMeta: is_dummy_or_profile_run: bool = False # Routing Replay table buffer routing_replay_table: Optional[paddle.Tensor] = None + # Phase 2: GPU transient routing buffer [max_num_batched_tokens, num_moe_layers, top_k] + gpu_routing_buffer: Optional[paddle.Tensor] = None # chunked MoE related moe_num_chunk: int = 1 diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 12964ef25e0..50236e23664 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -29,6 +29,7 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( save_routing_to_buffer, + save_routing_to_buffer_v2, ) from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import h2d_copy, slice_fn @@ -697,7 +698,17 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = topk_ids_hookfunc = None if self.enable_routing_replay: # When execute empty_input_forward forward_meta is None. When execute mtp layer routing_replay_table is None. 
- if forward_meta is not None and forward_meta.routing_replay_table is not None: + if forward_meta is not None and forward_meta.gpu_routing_buffer is not None: + moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index + topk_ids_hookfunc = partial( + save_routing_to_buffer_v2, + gpu_routing_buffer=forward_meta.gpu_routing_buffer, + layer_idx=moe_layer_idx, + tp_size=self.fd_config.parallel_config.tensor_parallel_size, + ep_size=self.fd_config.parallel_config.expert_parallel_size, + tp_group=self.fd_config.parallel_config.tp_group, + ) + elif forward_meta is not None and forward_meta.routing_replay_table is not None: moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index topk_ids_hookfunc = partial( save_routing_to_buffer, diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index da423bd3da2..894e7dfcd06 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -18,6 +18,7 @@ import atexit import functools import multiprocessing +import multiprocessing.shared_memory import os import shutil import threading @@ -155,6 +156,72 @@ def save_routing_to_buffer( ) +@enable_compat_on_triton_kernel +@triton.jit +def _save_routing_kernel_v2( + GPU_ROUTING_BUFFER_PTR, + TOPK_IDS_PTR, + LAYER_IDX, + TOKEN_NUM, + TOP_K, + NUM_MOE_LAYERS, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + token_mask = token_offsets < TOKEN_NUM + k_offsets = tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < TOP_K + + load_mask = token_mask[:, None] & k_mask[None, :] + topk_vals = tl.load( + TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :], + mask=load_mask, + ) + + STRIDE_TOKEN = NUM_MOE_LAYERS * TOP_K + STRIDE_LAYER = TOP_K + output_ptrs = ( + GPU_ROUTING_BUFFER_PTR + token_offsets[:, None] * STRIDE_TOKEN + LAYER_IDX * STRIDE_LAYER + k_offsets[None, :] + ) + tl.store(output_ptrs, topk_vals, mask=load_mask) + + +def save_routing_to_buffer_v2( + gpu_routing_buffer: paddle.Tensor, + topk_ids: paddle.Tensor, + layer_idx: int, + tp_size: int, + ep_size: int, + tp_group: dist.communication.group.Group, +): + token_num_per_rank = topk_ids.shape[0] + if token_num_per_rank == 0: + return + if tp_size > 1 and ep_size > 1: + topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) + paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) + topk_ids = topk_ids_all[:token_num_per_rank, :] + + token_num, top_k = topk_ids.shape + num_moe_layers = gpu_routing_buffer.shape[1] + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_K = triton.next_power_of_2(top_k) + grid = (triton.cdiv(token_num, BLOCK_SIZE_M),) + _save_routing_kernel_v2[grid]( + gpu_routing_buffer, + topk_ids, + LAYER_IDX=layer_idx, + TOKEN_NUM=token_num, + TOP_K=top_k, + NUM_MOE_LAYERS=num_moe_layers, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + class RoutingReplayManager: """Request level routing replay table manager""" @@ -203,17 +270,38 @@ def _init_routing_cache(self, dtype: str, total_block_num: int): max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size + # Legacy host cache (kept during transition, will be replaced by SharedMemory routing_host_buffer) self._host_cache = paddle.full( shape=[max_num_kv_tokens, self.num_moe_layers, 
self.moe_top_k], fill_value=-1, dtype=dtype, device="cpu" ) - self.routing_replay_table = paddle.full( - shape=[self.max_num_seqs, self.max_model_len, self.num_moe_layers, self.moe_top_k], + # Phase 2: Small GPU transient buffer (replaces the old routing_replay_table) + max_num_batched_tokens = self.fd_config.scheduler_config.max_num_batched_tokens + self.gpu_routing_buffer = paddle.full( + shape=[max_num_batched_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, dtype=dtype, ) + + # Legacy routing_replay_table kept as alias for backward compatibility during transition + self.routing_replay_table = self.gpu_routing_buffer + + # Lazy attach to SharedMemory routing_host_buffer (created by Engine in _stop_profile) + # Engine creates SharedMemory after profiling completes, which is after Worker init. + # So we defer attachment to the first save_captured_routing() call. + self.routing_host_view = None + self._routing_host_view_attach_attempted = False + self._routing_host_view_shm_name = ( + f"routing_host_buffer.{str(self.fd_config.parallel_config.local_engine_worker_queue_port)}" + ) + self._routing_host_view_shape = (max_num_kv_tokens, self.num_moe_layers, self.moe_top_k) + self._routing_host_view_dtype = dtype + + gpu_buffer_bytes = int(np.prod(self.gpu_routing_buffer.shape)) * np.dtype(dtype).itemsize logger.info( - f"[R3] The host cache size is:{self._host_cache.shape}, device cache size is: {self.routing_replay_table.shape}" + f"[R3] GPU transient routing buffer: {self.gpu_routing_buffer.shape} " + f"({gpu_buffer_bytes / 1024:.1f} KB), " + f"host cache: {self._host_cache.shape}" ) def get_routing_dtype(self, num_experts: int, reserved_fill_value: int = 1) -> str: @@ -236,7 +324,7 @@ def get_routing_dtype(self, num_experts: int, reserved_fill_value: int = 1) -> s return dtype def update_host_cache(self, positions: paddle.Tensor, slot_mapping: paddle.Tensor): - """Update the host cache with new tokens""" + """Update the host cache with new tokens (legacy v1 path)""" for batch_id, position in enumerate(positions): if len(position) > 0 and len(slot_mapping[batch_id]) > 0: routing_ids = self.routing_replay_table[batch_id, position, :, :].contiguous() @@ -244,6 +332,69 @@ def update_host_cache(self, positions: paddle.Tensor, slot_mapping: paddle.Tenso self._host_cache[slot_mapping[batch_id], :, :] = routing_ids + def _try_attach_routing_host_view(self): + """Lazily attach to SharedMemory routing_host_buffer on first use.""" + if self._routing_host_view_attach_attempted: + return + self._routing_host_view_attach_attempted = True + try: + self.routing_host_view = RoutingHostBufferView( + shape=self._routing_host_view_shape, + dtype=self._routing_host_view_dtype, + shm_name=self._routing_host_view_shm_name, + ) + logger.info(f"[R3] Attached to RoutingHostBuffer SharedMemory: {self._routing_host_view_shm_name}") + except FileNotFoundError: + logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {self._routing_host_view_shm_name} not found. " + "Falling back to legacy _host_cache (no swap sync)." + ) + + def save_captured_routing(self, num_tokens: int, slot_mapping: np.ndarray): + """ + After forward, scatter GPU buffer routing data to routing_host_buffer. + Called in step gap (post_process), not during forward. CUDAGraph compatible. 
+ + Args: + num_tokens: Number of tokens processed in the current step + slot_mapping: [num_tokens], each token's routing_host_buffer slot index + """ + if num_tokens == 0: + return + + # Lazy attach to SharedMemory (Engine creates it after profiling completes) + if self.routing_host_view is None and not self._routing_host_view_attach_attempted: + self._try_attach_routing_host_view() + + # D2H copy: GPU → CPU numpy + data = self.gpu_routing_buffer[:num_tokens].cpu().numpy() + + if self.routing_host_view is not None: + # Phase 2: scatter to SharedMemory routing_host_buffer + self.routing_host_view.scatter(slot_mapping, data) + else: + # Fallback: scatter to legacy _host_cache + self._host_cache[slot_mapping, :, :] = paddle.to_tensor(data, place="cpu") + + def compute_slot_mapping_flat(self, positions) -> np.ndarray: + """ + Compute flat slot_mapping for all tokens in the step. + Returns a 1D numpy array of slot indices. + """ + all_slots = [] + block_size = self.fd_config.cache_config.block_size + for batch_id, position in enumerate(positions): + if len(position) == 0: + continue + block_table_indices = position // block_size + token_block_ids = self.block_table[batch_id, block_table_indices] + block_offset = position % block_size + token_cache_ids = np.array(token_block_ids) * block_size + block_offset + all_slots.append(token_cache_ids) + if all_slots: + return np.concatenate(all_slots) + return np.array([], dtype=np.int64) + def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): """Get token position of each sequence in a batch.""" starts = seq_lens_decoder.numpy()[:, 0] @@ -386,15 +537,19 @@ def _put_request_to_store( def clear_request(self, batch_id: int): """Clear the routing indices of the request""" - self._clear_table_slot(batch_id) + # With gpu_routing_buffer (v2), no per-batch-id slot to clear — + # buffer is reused each step. Just remove from tracking dict. self.routing_batch_to_request.pop(batch_id, None) def _clear_table_slot(self, batch_id: int): - assert 0 <= batch_id < self.max_num_seqs - self.routing_replay_table[batch_id].fill_(-1) + # No-op with gpu_routing_buffer (v2): buffer is linear, reused each step + pass def get_routing_table(self) -> paddle.Tensor: - return self.routing_replay_table + return self.gpu_routing_buffer + + def get_gpu_routing_buffer(self) -> paddle.Tensor: + return self.gpu_routing_buffer def split_request_id(self, request_id: str): """ @@ -415,10 +570,109 @@ def split_request_id(self, request_id: str): def clear_all_request(self): """Clear all requests""" - self.routing_replay_table.fill_(-1) + self.gpu_routing_buffer.fill_(-1) self.routing_batch_to_request = {} +class RoutingHostBuffer: + """ + Manages routing_host_buffer (corresponds to KVCache GPU cache). + Indexed by gpu_block_id * block_size + offset. + Shared across processes via POSIX SharedMemory. + Each DP rank creates its own instance; name includes dp_suffix. 
+ """ + + def __init__( + self, num_gpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_gpu_tokens = num_gpu_blocks * block_size + self.shape = (max_num_gpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_host_buffer.{dp_suffix}" + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = 0xFF if dtype == "uint8" else 0 # -1 for uint8 + + logger.info( + f"[R3] Created RoutingHostBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + self.shm.close() + self.shm.unlink() + + +class RoutingHostBufferView: + """Read/write view of routing_host_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def scatter(self, slot_mapping: np.ndarray, data: np.ndarray): + """Scatter GPU buffer data to corresponding slots (Worker calls this).""" + self.buffer[slot_mapping] = data + + def gather(self, slot_mapping: np.ndarray) -> np.ndarray: + """Gather data from specified slots (TokenProcessor calls this).""" + return self.buffer[slot_mapping].copy() + + def close(self): + self.shm.close() + + +class RoutingSwapBuffer: + """ + Manages routing_swap_buffer (corresponds to KVCache CPU cache). + Indexed by cpu_block_id * block_size + offset. + CacheTransferManager creates this; shared via SharedMemory. 
+ """ + + def __init__( + self, num_cpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_cpu_tokens = num_cpu_blocks * block_size + self.shape = (max_num_cpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_swap_buffer.{dp_suffix}" + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = 0xFF if dtype == "uint8" else 0 + + logger.info( + f"[R3] Created RoutingSwapBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + self.shm.close() + self.shm.unlink() + + +class RoutingSwapBufferView: + """Read/write view of routing_swap_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def close(self): + self.shm.close() + + class StoreWrapper(object): def __init__(self, fd_config: False) -> None: super().__init__() diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 9c0a3eb6973..cb12c102de5 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -307,15 +307,29 @@ def post_process_normal( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( - positions=routing_replay_manager.pending_update_positions - ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) + # Trigger lazy SharedMemory attach if not yet attempted + routing_replay_manager._try_attach_routing_host_view() + if routing_replay_manager.routing_host_view is not None: + # Phase 2 path: GPU transient buffer → SharedMemory routing_host_buffer + slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( + positions=routing_replay_manager.pending_update_positions + ) + num_tokens = len(slot_mapping_flat) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.save_captured_routing( + num_tokens=num_tokens, + slot_mapping=slot_mapping_flat, + ) + else: + # Legacy v1 path: batch_id-indexed GPU table → CPU _host_cache + slot_mapping = routing_replay_manager.compute_slot_mapping( + positions=routing_replay_manager.pending_update_positions + ) + routing_replay_manager.update_host_cache( + positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping + ) - # Put routing of finished requests to store + # Put routing of finished requests to store (legacy path, still needed until Step 4 fully migrates to Engine) finished_batch_ids = paddle.flatten(paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)) context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens) @@ -452,13 +466,27 @@ def post_process_specualate( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = 
routing_replay_manager.compute_slot_mapping( - positions=routing_replay_manager.pending_update_positions - ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) + # Trigger lazy SharedMemory attach if not yet attempted + routing_replay_manager._try_attach_routing_host_view() + if routing_replay_manager.routing_host_view is not None: + # Phase 2 path: GPU transient buffer → SharedMemory routing_host_buffer + slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( + positions=routing_replay_manager.pending_update_positions + ) + num_tokens = len(slot_mapping_flat) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.save_captured_routing( + num_tokens=num_tokens, + slot_mapping=slot_mapping_flat, + ) + else: + # Legacy v1 path + slot_mapping = routing_replay_manager.compute_slot_mapping( + positions=routing_replay_manager.pending_update_positions + ) + routing_replay_manager.update_host_cache( + positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping + ) # Put routing of finished requests to store last_accept_token = paddle.full_like(model_output.accept_tokens, -1) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 646c04854ef..43213b6d0c8 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -131,6 +131,73 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.health_lock = threading.Lock() self.engine_output_token_hang = False + # Routing replay: attach to SharedMemory routing_host_buffer (lazy init after profiling) + self.routing_host_view = None + self._routing_host_view_init_attempted = False + + def _init_routing_host_view(self): + """Attach to SharedMemory routing_host_buffer created by Engine. 
Called lazily.""" + self._routing_host_view_init_attempted = True + if not self.cfg.routing_replay_config.enable_routing_replay: + return + try: + from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( + RoutingHostBufferView, + ) + + model_config = self.cfg.model_config + cache_config = self.cfg.cache_config + num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index + if model_config.architectures[0] == "Glm4MoeForCausalLM": + moe_top_k = model_config.num_experts_per_tok + else: + moe_top_k = model_config.moe_k + num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts + dtype = "uint8" if num_experts + 1 <= 255 else ("uint16" if num_experts + 1 <= 65535 else "uint32") + + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + num_gpu_blocks = cache_config.total_block_num + max_num_kv_tokens = num_gpu_blocks * cache_config.block_size + shape = (max_num_kv_tokens, num_moe_layers, moe_top_k) + + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=dtype, shm_name=shm_name) + self._routing_block_size = cache_config.block_size + self._routing_num_moe_layers = num_moe_layers + self._routing_moe_top_k = moe_top_k + llm_logger.info(f"[R3] TokenProcessor attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + llm_logger.warning("[R3] RoutingHostBuffer SharedMemory not found, routing gather disabled.") + except Exception as e: + llm_logger.warning(f"[R3] Failed to attach to RoutingHostBuffer: {e}") + + def _gather_routing_for_finished_request(self, task, seq_len: int): + """ + Gather complete routing data for a finished request from routing_host_buffer. + + Args: + task: Request task with block_tables + seq_len: Total sequence length + + Returns: + numpy array [seq_len, num_moe_layers, top_k] or None + """ + if self.routing_host_view is None and not self._routing_host_view_init_attempted: + self._init_routing_host_view() + if self.routing_host_view is None: + return None + + import math + + block_size = self._routing_block_size + block_ids = task.block_tables[: math.ceil(seq_len / block_size)] + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + + return self.routing_host_view.gather(slot_mapping) + def healthy(self): """ whether token processor is healthy @@ -516,6 +583,20 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False """ recycle resources """ + # Gather routing data before blocks are recycled (blocks will be freed below) + if result is not None and result.finished and self.cfg.routing_replay_config.enable_routing_replay: + try: + seq_len = ( + task.prompt_token_ids_len + len(task.output_token_ids) + if hasattr(task, "output_token_ids") + else task.prompt_token_ids_len + ) + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + except Exception as e: + llm_logger.warning(f"[R3] Failed to gather routing for {task_id}: {e}") + if is_prefill: start_time = time.time() result.metrics.wait_for_sending_cache_time = time.time() @@ -533,6 +614,15 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" ) result.metrics.send_request_output_to_decode_time = time.time() 
+ # [R3] Gather prefill routing data before sending to D + if self.cfg.routing_replay_config.enable_routing_replay and result.error_code == 200: + try: + seq_len = task.prompt_token_ids_len + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + except Exception as e: + llm_logger.warning(f"[R3] Failed to gather prefill routing for {task_id}: {e}") self.split_connector.send_first_token(task.disaggregate_info, [result]) if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.finish_requests_async(task_id) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index efcefa5978d..abb1c5b3667 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1455,8 +1455,10 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): """ # Initialize forward meta routing_replay_table = None + gpu_routing_buffer = None if self.routing_replay_manager is not None: routing_replay_table = self.routing_replay_manager.get_routing_table() + gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1484,6 +1486,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], routing_replay_table=routing_replay_table, + gpu_routing_buffer=gpu_routing_buffer, ) dist_status = self.collect_distributed_status()