Conversation
Code Review
This pull request introduces support for 'Routing Replay' in Mixture-of-Experts models, which involves significant changes across the configuration, data processing pipeline, and training workers. The changes are mostly well-implemented, but there are a few key areas for improvement. I've identified a critical bug where routing information is dropped in the generator, a couple of high-severity issues related to resource cleanup in the Megatron worker that could lead to state leakage, and some medium-severity concerns regarding dependency management and temporary debug code. Addressing these points will improve the robustness and maintainability of the new feature.
```python
output_logprobs: Optional[List[float]]
new_obs: ConversationType
obs_ids: List[int]
output_rollout_inference_indices: Optional[List[List[List[List[int]]]]]
```
The output_rollout_inference_indices field is added to TurnOutput, but it is never propagated to the final GeneratorOutput: the value is fetched from the inference engine in agent_loop and then dropped. It needs to be collected across turns, returned from agent_loop, and packaged into the GeneratorOutput in the generate method. As it stands, this is a bug: the routing information is lost.
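A minimal sketch of the intended per-turn accumulation (the class and field names mirror the review, but the simplified dataclass and the helper below are assumptions, not the PR's actual code):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TurnOutput:  # simplified stand-in for the real dataclass
    output_rollout_inference_indices: Optional[List[list]] = None

def collect_turn_indices(turns: List[TurnOutput]) -> Optional[list]:
    """Gather per-turn routing indices so agent_loop can return them."""
    collected = [t.output_rollout_inference_indices for t in turns]
    # If no turn produced routing data, signal "nothing to replay"
    return None if all(c is None for c in collected) else collected

turns = [TurnOutput([[0, 1]]), TurnOutput(None), TurnOutput([[2, 3]])]
assert collect_turn_indices(turns) == [[[0, 1]], None, [[2, 3]]]
assert collect_turn_indices([TurnOutput(), TurnOutput()]) is None
```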
```toml
vllm = [
    { url = "https://wheels.vllm.ai/49e6b86c91cddc4a4b2152c6d8834780bbd50deb/vllm-0.14.0rc1.dev505%2Bg49e6b86c9-cp38-abi3-manylinux_2_31_x86_64.whl", marker = "sys_platform == 'linux'" },
]
```
Pinning a dependency to a direct URL, especially one with a commit hash, can be risky for long-term maintainability and security. If this URL becomes unavailable, the build will break. It also makes it harder to track what specific changes are included. For production or long-term projects, consider forking the repository and building from a specific tag/branch, or hosting the wheel in a private package repository for better stability and control.
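As a hedged sketch of the alternative, one option (assuming the project uses uv for dependency resolution; the org name and tag below are placeholders, not real artifacts) is to pin to a tagged fork via a git source instead of a raw wheel URL:

```toml
# Hypothetical pyproject.toml fragment: pin vLLM to a tagged fork
# ("your-org" and the tag name are placeholders)
[tool.uv.sources]
vllm = { git = "https://github.com/your-org/vllm", tag = "v0.14.0-routed-experts" }
```

A git tag is immutable-by-convention and kept under your control, so the build does not depend on a third-party wheel host staying online.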
```python
# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
    logger.info(f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs")
else:
    logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
```
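The snippet above reads like temporary diagnostics: it logs at INFO/WARNING with a `DEBUG:` prefix, so it will pollute production logs. A quieter variant (a sketch, assuming a module-level logger; the helper name is hypothetical) logs the same information at DEBUG level:

```python
import logging

logger = logging.getLogger(__name__)

def describe_routed_experts_flag(kwargs: dict) -> str:
    """Build the diagnostic message once, then emit it at DEBUG level."""
    if "enable_return_routed_experts" in kwargs:
        msg = f"enable_return_routed_experts={kwargs['enable_return_routed_experts']} passed to AsyncEngineArgs"
    else:
        msg = "enable_return_routed_experts not in kwargs"
    logger.debug(msg)  # DEBUG, not INFO/WARNING, so normal runs stay quiet
    return msg

assert "True" in describe_routed_experts_flag({"enable_return_routed_experts": True})
```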
🔴 SkyRLGymGenerator.generate() never propagates rollout_inference_indices to GeneratorOutput
The rollout_inference_indices data collected from the inference engine is stored in TurnOutput.output_rollout_inference_indices (line 362) during the agent loop, but it is never extracted from TrajectoryOutput/StepWiseOutput and never included in the final GeneratorOutput dict constructed at line 775.
Root Cause and Impact
The generate() method builds the final GeneratorOutput at skyrl_train/generators/skyrl_gym_generator.py:775-785 but omits the "rollout_inference_indices" key:
```python
generator_output: GeneratorOutput = {
    "prompt_token_ids": prompt_token_ids,
    "response_ids": responses,
    "rewards": rewards,
    "loss_masks": loss_masks,
    "stop_reasons": stop_reasons,
    "rollout_metrics": rollout_metrics,
    "rollout_logprobs": rollout_logprobs,
    "trajectory_ids": out_trajectory_ids,
    "is_last_step": is_last_step,
    # Missing: "rollout_inference_indices"
}
```

Downstream in trainer.py:600, the trainer does generator_output.get("rollout_inference_indices", None), which will always return None. This means rollout_inference_indices_tensor will always be None in the TrainingInputBatch, so the entire routing replay path (moe_enable_routing_replay) silently receives no data, rendering the Rollout Routing Replay feature non-functional.
Additionally, the same key is missing from generate_batched() at skyrl_train/generators/skyrl_gym_generator.py:657-665.
(Refers to lines 775-785)
Prompt for agents
The rollout_inference_indices data is collected in TurnOutput.output_rollout_inference_indices during the agent loop but never propagated to the final GeneratorOutput.
To fix this:
1. In the generate() method around line 738-746, collect rollout_inference_indices from all_outputs similarly to how rollout_logprobs is collected (lines 755-764). For the non-step-wise path, extract from TrajectoryOutput; for step-wise, extract from StepWiseOutput.step_outputs.
2. Add "rollout_inference_indices": <collected_value> to the GeneratorOutput dict at line 775-785.
3. Similarly, the generate_batched() method at line 657-665 should also include "rollout_inference_indices" in its output dict.
4. The TrajectoryOutput dataclass (line 34-43) likely also needs a rollout_inference_indices field so that agent_loop() can propagate it from TurnOutput to TrajectoryOutput (around line 467-475 where the final TrajectoryOutput is constructed).
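The packaging step in point 2 can be sketched as follows (the dict keys match the review; the helper name, parameter list, and simplified key subset are hypothetical, not the PR's actual code):

```python
from typing import Any, Dict, List, Optional

def build_generator_output(
    prompt_token_ids: List[List[int]],
    responses: List[List[int]],
    rewards: List[float],
    rollout_logprobs: Optional[List[List[float]]],
    rollout_inference_indices: Optional[List[Any]],
) -> Dict[str, Any]:
    """Sketch of generate()'s final packaging step with the missing key added."""
    return {
        "prompt_token_ids": prompt_token_ids,
        "response_ids": responses,
        "rewards": rewards,
        "rollout_logprobs": rollout_logprobs,
        # The fix: carry routing data through to the trainer
        "rollout_inference_indices": rollout_inference_indices,
    }

out = build_generator_output([[1]], [[2]], [1.0], None, [[[0, 3]]])
assert out.get("rollout_inference_indices") == [[[0, 3]]]
```

With the key present, trainer.py's generator_output.get("rollout_inference_indices", None) can return real data instead of always None.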
```python
if resp.routed_experts is not None:
    if hasattr(resp.routed_experts, "tolist"):
        routed_experts_list = resp.routed_experts.tolist()
    else:
        routed_experts_list = resp.routed_experts
    rollout_inference_indices.append(routed_experts_list)

if len(response_logprobs) and response_logprobs[0] is None:
    response_logprobs = None  # hack: assume uniform sampling params

if len(rollout_inference_indices) == 0:
    rollout_inference_indices = None
```
🔴 Misaligned rollout_inference_indices list when some outputs lack routed_experts
In _postprocess_outputs, rollout_inference_indices is only appended when resp.routed_experts is not None (line 152-157), but unlike response_logprobs which always appends (even None), this means the list can have fewer elements than the number of responses.
Root Cause and Impact
Compare the handling of response_logprobs vs rollout_inference_indices:
```python
# response_logprobs: always appends, even None
response_logprobs.append(_logprobs)  # line 150

# rollout_inference_indices: only appends when non-None
if resp.routed_experts is not None:  # line 152
    rollout_inference_indices.append(routed_experts_list)  # line 157
```

If enable_return_routed_experts is enabled but some responses have routed_experts=None (e.g., due to an error or edge case in vLLM), rollout_inference_indices will have fewer elements than responses. Downstream code at skyrl_train/generators/skyrl_gym_generator.py:317-318 does rollout_inference_indices[0] to extract for batch index 0, which would retrieve the wrong sample's data when the lists are misaligned.
Additionally, the check at line 162 (if len(rollout_inference_indices) == 0) only handles the all-None case but not the partial case, so it would return a non-None list with misaligned indices.
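The misalignment can be reproduced with a toy example (assuming three responses where only the middle one lacks routed_experts):

```python
responses = ["resp_a", "resp_b", "resp_c"]
routed = [[1, 2], None, [3, 4]]  # resp_b has routed_experts=None

# Current behavior: skip None entries, so later entries shift left
misaligned = [r for r in routed if r is not None]
assert misaligned[1] == [3, 4]  # index 1 now holds resp_c's data, not resp_b's

# Always-append behavior keeps index i paired with response i
aligned = list(routed)
assert len(aligned) == len(responses)
assert aligned[1] is None
```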
Suggested change:

```diff
-if resp.routed_experts is not None:
-    if hasattr(resp.routed_experts, "tolist"):
-        routed_experts_list = resp.routed_experts.tolist()
-    else:
-        routed_experts_list = resp.routed_experts
-    rollout_inference_indices.append(routed_experts_list)
+if hasattr(resp, 'routed_experts') and resp.routed_experts is not None:
+    if hasattr(resp.routed_experts, "tolist"):
+        routed_experts_list = resp.routed_experts.tolist()
+    else:
+        routed_experts_list = resp.routed_experts
+    rollout_inference_indices.append(routed_experts_list)
+else:
+    rollout_inference_indices.append(None)
 if len(response_logprobs) and response_logprobs[0] is None:
     response_logprobs = None  # hack: assume uniform sampling params
-if len(rollout_inference_indices) == 0:
-    rollout_inference_indices = None
+if len(rollout_inference_indices) and rollout_inference_indices[0] is None:
+    rollout_inference_indices = None
```
```python
"prompt_token_ids": [prompt_token_ids],  # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
"rollout_inference_indices": None,
```
🟡 Extra key rollout_inference_indices added to InferenceEngineInput dict not in TypedDict
In the sample() method of InferenceEngineInterface, the input_batch dict includes "rollout_inference_indices": None at line 75, but this key is not defined in the InferenceEngineInput TypedDict (lines 12-17).
Root Cause and Impact
The InferenceEngineInput TypedDict at skyrl_train/inference_engines/base.py:12-17 only has four keys: prompts, prompt_token_ids, sampling_params, session_ids. Adding rollout_inference_indices to the dict at line 75 introduces an undeclared key. While Python TypedDicts don't strictly enforce this at runtime, this key could cause unexpected behavior if any downstream code iterates over the dict keys or performs strict validation. It also signals a misunderstanding of the data flow since rollout_inference_indices is an output, not an input.
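A sketch of the declared schema (the four key names come from the review; the value types here are assumptions), showing that the extra key is not part of it:

```python
from typing import Any, Dict, List, Optional, TypedDict

class InferenceEngineInput(TypedDict):  # mirrors the four keys the review cites
    prompts: Optional[List[str]]
    prompt_token_ids: Optional[List[List[int]]]
    sampling_params: Optional[Dict[str, Any]]
    session_ids: Optional[List[str]]

# The extra key is absent from the declared schema, so a static checker
# would flag a dict literal that includes it.
assert "rollout_inference_indices" not in InferenceEngineInput.__annotations__
```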
Suggested change (remove the undeclared key):

```diff
-    "rollout_inference_indices": None,
```
```python
from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay

setup_router_replay_forward(data, self.enable_router_replay)
```
🟡 forward_backward calls setup_router_replay_forward instead of setup_router_replay_backward
The forward_backward method at megatron_worker.py:538-540 imports and calls setup_router_replay_forward instead of the intended setup_router_replay_backward. The setup_router_replay_backward function was specifically created for the training forward/backward pass.
Root Cause
At skyrl_train/workers/megatron/megatron_worker.py:538-540:
```python
from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay
setup_router_replay_forward(data, self.enable_router_replay)
```

This should be:

```python
from skyrl_train.utils.replay_utils import setup_router_replay_backward, clear_router_replay
setup_router_replay_backward(data, self.enable_router_replay)
```

Currently both setup_router_replay_forward and setup_router_replay_backward in skyrl_train/utils/replay_utils.py have identical implementations (both set REPLAY_FORWARD), so there is no functional difference today. However, the call is semantically incorrect and will become a real bug if/when setup_router_replay_backward is updated to set a different replay action (e.g., a dedicated backward replay mode).
Suggested change:

```diff
-from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay
-setup_router_replay_forward(data, self.enable_router_replay)
+from skyrl_train.utils.replay_utils import setup_router_replay_backward, clear_router_replay
+setup_router_replay_backward(data, self.enable_router_replay)
```
Rollout Routing Replay Design Doc
This PR introduces the R3 feature (See Paper). To use R3, vLLM must be upgraded to a build that includes the upstream PR returning expert indices from the router.