
R3 PR: Rollout Routing Replay #1101

Open

devpatelio wants to merge 2 commits into NovaSky-AI:main from devpatelio:devpatel/r3

Conversation

@devpatelio (Collaborator) commented Feb 13, 2026

Rollout Routing Replay Design Doc

This PR introduces the R3 Feature (See Paper). To use the R3 feature, we need to upgrade vLLM to use the latest PR that returns expert indices from the router.
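
For context, a minimal sketch of how the returned router decisions would be consumed, assuming the per-response routed_experts attribute added by the linked vLLM PR; the [layer][token][top_k] layout is an assumption, not confirmed by that PR:

from typing import List, Optional

# Assumed layout: routed[layer][token] -> the top-k expert ids the router chose.
RoutedExperts = List[List[List[int]]]

def extract_routed_experts(resp) -> Optional[RoutedExperts]:
    # routed_experts may be a tensor; convert to plain lists before it crosses
    # process boundaries. Returns None if the engine was started without
    # enable_return_routed_experts.
    routed = getattr(resp, "routed_experts", None)
    if routed is None:
        return None
    return routed.tolist() if hasattr(routed, "tolist") else routed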

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for 'Routing Replay' in Mixture-of-Experts models, which involves significant changes across the configuration, data processing pipeline, and training workers. The changes are mostly well-implemented, but there are a few key areas for improvement. I've identified a critical bug where routing information is dropped in the generator, a couple of high-severity issues related to resource cleanup in the Megatron worker that could lead to state leakage, and some medium-severity concerns regarding dependency management and temporary debug code. Addressing these points will improve the robustness and maintainability of the new feature.

output_logprobs: Optional[List[float]]
new_obs: ConversationType
obs_ids: List[int]
output_rollout_inference_indices: Optional[List[List[List[List[int]]]]]
critical

The output_rollout_inference_indices field is added to TurnOutput, but it is never propagated to the final GeneratorOutput: the value is fetched from the inference engine in agent_loop and then dropped. It needs to be collected across turns, returned from agent_loop, and packaged into the GeneratorOutput in the generate method; as written, the routing information is silently lost.
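
A minimal sketch of the missing collection step, assuming a per-turn loop inside agent_loop; everything here except TurnOutput.output_rollout_inference_indices is hypothetical naming:

# Hypothetical collection inside agent_loop(); turn_outputs and the result
# variable are illustrative, output_rollout_inference_indices is the field
# this PR adds to TurnOutput.
rollout_inference_indices = []
for turn_output in turn_outputs:
    if turn_output.output_rollout_inference_indices is not None:
        rollout_inference_indices.extend(turn_output.output_rollout_inference_indices)
# Return this alongside the other per-trajectory fields so generate() can
# package it into GeneratorOutput.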

Comment on lines +101 to +103
vllm = [
{ url = "https://wheels.vllm.ai/49e6b86c91cddc4a4b2152c6d8834780bbd50deb/vllm-0.14.0rc1.dev505%2Bg49e6b86c9-cp38-abi3-manylinux_2_31_x86_64.whl", marker = "sys_platform == 'linux'" }
]
medium

Pinning a dependency to a direct URL, especially one with a commit hash, can be risky for long-term maintainability and security. If this URL becomes unavailable, the build will break. It also makes it harder to track what specific changes are included. For production or long-term projects, consider forking the repository and building from a specific tag/branch, or hosting the wheel in a private package repository for better stability and control.
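
As a sketch of the fork-and-pin alternative, assuming uv's [tool.uv.sources] git support; the fork URL and tag name below are placeholders, not real artifacts:

# pyproject.toml -- pin to a tag on a fork you control rather than a raw wheel URL.
# NovaSky-AI/vllm and the tag are hypothetical.
[tool.uv.sources]
vllm = { git = "https://github.com/NovaSky-AI/vllm", tag = "r3-routed-experts" }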

Comment on lines +314 to +318
# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
    logger.info(f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs")
else:
    logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
medium

This debug logging seems temporary. It's good for development, but it should be removed before merging to avoid cluttering the logs. If this check is intended to be permanent, consider using a lower log level like debug instead of info and warning.
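
If the check is kept, a sketch of the demoted form (same condition as above, logger.debug in both branches):

if "enable_return_routed_experts" in kwargs:
    logger.debug(
        "enable_return_routed_experts=%s passed to AsyncEngineArgs",
        kwargs["enable_return_routed_experts"],
    )
else:
    logger.debug("enable_return_routed_experts not in kwargs")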

@devin-ai-integration bot (Contributor) left a comment
Devin Review found 3 potential issues.

View 6 additional findings in Devin Review.


🔴 SkyRLGymGenerator.generate() never propagates rollout_inference_indices to GeneratorOutput

The rollout_inference_indices data collected from the inference engine is stored in TurnOutput.output_rollout_inference_indices (line 362) during the agent loop, but it is never extracted from TrajectoryOutput/StepWiseOutput and never included in the final GeneratorOutput dict constructed at line 775.

Root Cause and Impact

The generate() method builds the final GeneratorOutput at skyrl_train/generators/skyrl_gym_generator.py:775-785 but omits the "rollout_inference_indices" key:

generator_output: GeneratorOutput = {
    "prompt_token_ids": prompt_token_ids,
    "response_ids": responses,
    "rewards": rewards,
    "loss_masks": loss_masks,
    "stop_reasons": stop_reasons,
    "rollout_metrics": rollout_metrics,
    "rollout_logprobs": rollout_logprobs,
    "trajectory_ids": out_trajectory_ids,
    "is_last_step": is_last_step,
    # Missing: "rollout_inference_indices"
}

Downstream in trainer.py:600, the trainer does generator_output.get("rollout_inference_indices", None) which will always return None. This means rollout_inference_indices_tensor will always be None in the TrainingInputBatch, and the entire routing replay feature (moe_enable_routing_replay) will silently receive no data, rendering the Rollout Routing Replay feature non-functional.

Additionally, the same key is missing from generate_batched() at skyrl_train/generators/skyrl_gym_generator.py:657-665.

(Refers to lines 775-785)

Prompt for agents
The rollout_inference_indices data is collected in TurnOutput.output_rollout_inference_indices during the agent loop but never propagated to the final GeneratorOutput.

To fix this:
1. In the generate() method around line 738-746, collect rollout_inference_indices from all_outputs similarly to how rollout_logprobs is collected (lines 755-764). For the non-step-wise path, extract from TrajectoryOutput; for step-wise, extract from StepWiseOutput.step_outputs.
2. Add "rollout_inference_indices": <collected_value> to the GeneratorOutput dict at line 775-785.
3. Similarly, the generate_batched() method at line 657-665 should also include "rollout_inference_indices" in its output dict.
4. The TrajectoryOutput dataclass (line 34-43) likely also needs a rollout_inference_indices field so that agent_loop() can propagate it from TurnOutput to TrajectoryOutput (around line 467-475 where the final TrajectoryOutput is constructed).
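
A minimal sketch of step 2, reusing the dict from the finding above; the rollout_inference_indices variable is assumed to have been collected in step 1:

generator_output: GeneratorOutput = {
    "prompt_token_ids": prompt_token_ids,
    "response_ids": responses,
    "rewards": rewards,
    "loss_masks": loss_masks,
    "stop_reasons": stop_reasons,
    "rollout_metrics": rollout_metrics,
    "rollout_logprobs": rollout_logprobs,
    "rollout_inference_indices": rollout_inference_indices,  # previously missing
    "trajectory_ids": out_trajectory_ids,
    "is_last_step": is_last_step,
}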


Comment on lines +152 to +163
if resp.routed_experts is not None:
    if hasattr(resp.routed_experts, "tolist"):
        routed_experts_list = resp.routed_experts.tolist()
    else:
        routed_experts_list = resp.routed_experts
    rollout_inference_indices.append(routed_experts_list)

if len(response_logprobs) and response_logprobs[0] is None:
    response_logprobs = None  # hack: assume uniform sampling params

if len(rollout_inference_indices) == 0:
    rollout_inference_indices = None
🔴 Misaligned rollout_inference_indices list when some outputs lack routed_experts

In _postprocess_outputs, rollout_inference_indices is only appended when resp.routed_experts is not None (lines 152-157); unlike response_logprobs, which always appends (even None), the list can therefore end up with fewer elements than the number of responses.

Root Cause and Impact

Compare the handling of response_logprobs vs rollout_inference_indices:

# response_logprobs: always appends, even None
response_logprobs.append(_logprobs)  # line 150

# rollout_inference_indices: only appends when non-None
if resp.routed_experts is not None:  # line 152
    rollout_inference_indices.append(routed_experts_list)  # line 157

If enable_return_routed_experts is enabled but some responses have routed_experts=None (e.g., due to an error or edge case in vLLM), rollout_inference_indices will have fewer elements than responses. Downstream code at skyrl_train/generators/skyrl_gym_generator.py:317-318 does rollout_inference_indices[0] to extract for batch index 0, which would retrieve the wrong sample's data when the lists are misaligned.

Additionally, the check at line 162 (if len(rollout_inference_indices) == 0) only handles the all-None case but not the partial case, so it would return a non-None list with misaligned indices.

Suggested change

Before:

if resp.routed_experts is not None:
    if hasattr(resp.routed_experts, "tolist"):
        routed_experts_list = resp.routed_experts.tolist()
    else:
        routed_experts_list = resp.routed_experts
    rollout_inference_indices.append(routed_experts_list)

if len(response_logprobs) and response_logprobs[0] is None:
    response_logprobs = None  # hack: assume uniform sampling params

if len(rollout_inference_indices) == 0:
    rollout_inference_indices = None

After:

if hasattr(resp, "routed_experts") and resp.routed_experts is not None:
    if hasattr(resp.routed_experts, "tolist"):
        routed_experts_list = resp.routed_experts.tolist()
    else:
        routed_experts_list = resp.routed_experts
    rollout_inference_indices.append(routed_experts_list)
else:
    rollout_inference_indices.append(None)

if len(response_logprobs) and response_logprobs[0] is None:
    response_logprobs = None  # hack: assume uniform sampling params

if len(rollout_inference_indices) and rollout_inference_indices[0] is None:
    rollout_inference_indices = None


"prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
"rollout_inference_indices": None,
🟡 Extra key rollout_inference_indices added to InferenceEngineInput dict not in TypedDict

In the sample() method of InferenceEngineInterface, the input_batch dict includes "rollout_inference_indices": None at line 75, but this key is not defined in the InferenceEngineInput TypedDict (lines 12-17).

Root Cause and Impact

The InferenceEngineInput TypedDict at skyrl_train/inference_engines/base.py:12-17 only has four keys: prompts, prompt_token_ids, sampling_params, session_ids. Adding rollout_inference_indices to the dict at line 75 introduces an undeclared key. While Python TypedDicts don't strictly enforce this at runtime, this key could cause unexpected behavior if any downstream code iterates over the dict keys or performs strict validation. It also signals a misunderstanding of the data flow since rollout_inference_indices is an output, not an input.

Suggested change (remove the line):

- "rollout_inference_indices": None,


@CharlieFRuan self-assigned this Feb 13, 2026
@devin-ai-integration bot (Contributor) left a comment

Devin Review found 1 new potential issue.

View 8 additional findings in Devin Review.


Comment on lines +538 to +540
from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay

setup_router_replay_forward(data, self.enable_router_replay)
🟡 forward_backward calls setup_router_replay_forward instead of setup_router_replay_backward

The forward_backward method at megatron_worker.py:538-540 imports and calls setup_router_replay_forward instead of the intended setup_router_replay_backward. The setup_router_replay_backward function was specifically created for the training forward/backward pass.

Root Cause

At skyrl_train/workers/megatron/megatron_worker.py:538-540:

from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay
setup_router_replay_forward(data, self.enable_router_replay)

This should be:

from skyrl_train.utils.replay_utils import setup_router_replay_backward, clear_router_replay
setup_router_replay_backward(data, self.enable_router_replay)

Currently both setup_router_replay_forward and setup_router_replay_backward in skyrl_train/utils/replay_utils.py have identical implementations (both set REPLAY_FORWARD), so there is no functional difference today. However, this is semantically incorrect and will become a real bug if/when setup_router_replay_backward is updated to set a different replay action (e.g., a dedicated backward replay mode).

Suggested change

Before:

from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay
setup_router_replay_forward(data, self.enable_router_replay)

After:

from skyrl_train.utils.replay_utils import setup_router_replay_backward, clear_router_replay
setup_router_replay_backward(data, self.enable_router_replay)
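
A sketch of why the rename matters, assuming replay_utils eventually gives the two helpers distinct actions; _set_replay_action and REPLAY_BACKWARD are hypothetical:

# Hypothetical future shape of skyrl_train/utils/replay_utils.py. Today both
# helpers set REPLAY_FORWARD; once they diverge, a caller still using the
# forward variant in forward_backward silently trains in the wrong mode.

def _set_replay_action(action, data):
    # Placeholder for however replay_utils records the active replay mode.
    data["replay_action"] = action

def setup_router_replay_forward(data, enabled):
    if enabled:
        _set_replay_action("REPLAY_FORWARD", data)

def setup_router_replay_backward(data, enabled):
    if enabled:
        _set_replay_action("REPLAY_BACKWARD", data)  # hypothetical divergence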

