From 461e4e645a8481e756723dd75c06a168c0603fab Mon Sep 17 00:00:00 2001 From: cc <1716911340@qq.com> Date: Mon, 8 Sep 2025 18:38:19 +0800 Subject: [PATCH 01/38] Merge pull request #1 from anHappyDog/feature/weight_convertor feat(weight): refactor and add qwen2.5-vl mg2hf convertor --- rlinf/utils/convertor/__init__.py | 13 + rlinf/utils/convertor/utils.py | 457 ++++++++++++++++++ .../utils/resharding/mcore_weight_reshard.py | 2 +- rlinf/utils/resharding/reshard_config.py | 9 +- rlinf/utils/resharding/utils.py | 218 --------- 5 files changed, 477 insertions(+), 222 deletions(-) create mode 100644 rlinf/utils/convertor/__init__.py create mode 100644 rlinf/utils/convertor/utils.py diff --git a/rlinf/utils/convertor/__init__.py b/rlinf/utils/convertor/__init__.py new file mode 100644 index 000000000..5b365ea1e --- /dev/null +++ b/rlinf/utils/convertor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/rlinf/utils/convertor/utils.py b/rlinf/utils/convertor/utils.py new file mode 100644 index 000000000..5d15e3639 --- /dev/null +++ b/rlinf/utils/convertor/utils.py @@ -0,0 +1,457 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, List, Optional, Tuple + +import torch + + +class TransformType(Enum): + SPLIT_QKV = "split_qkv" + SPLIT_QKV_BIAS = "split_qkv_bias" + SPLIT_FC1 = "split_fc1" + SPLIT_NONE = "split_none" + + +class TransformFunc: + @staticmethod + def _split_gqa_tensor( + tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config + ) -> None: + """ + Private helper to split a GQA-combined tensor (weight or bias). 
+ """ + hidden_size = config.model_config.hidden_size + num_attention_heads = config.model_config.num_attention_heads + num_key_value_heads = ( + config.model_config.num_query_groups or num_attention_heads + ) + head_dim = hidden_size // num_attention_heads + + tp_size = config.model_config.tensor_model_parallel_size + + assert num_key_value_heads % tp_size == 0, ( + "num_key_value_heads must be divisible by tensor parallel size" + ) + + q_heads_per_rank = num_attention_heads // tp_size + kv_heads_per_rank = num_key_value_heads // tp_size + + q_shard_size = q_heads_per_rank * head_dim + k_shard_size = kv_heads_per_rank * head_dim + v_shard_size = kv_heads_per_rank * head_dim + + shard_size = q_shard_size + k_shard_size + v_shard_size + + q_shards, k_shards, v_shards = [], [], [] + + # [Qi,Ki,Vi] + for shard in tensor.split(shard_size, dim=0): + # Qi, Ki, Vi + q_shard, k_shard, v_shard = shard.split( + [q_shard_size, k_shard_size, v_shard_size], dim=0 + ) + q_shards.append(q_shard) + k_shards.append(k_shard) + v_shards.append(v_shard) + + # cat + q_full = torch.cat(q_shards, dim=0) + k_full = torch.cat(k_shards, dim=0) + v_full = torch.cat(v_shards, dim=0) + + # saved + new_statedict[weight_names[0]] = q_full.clone() + new_statedict[weight_names[1]] = k_full.clone() + new_statedict[weight_names[2]] = v_full.clone() + + @staticmethod + def split_fc1( + linear_fc1: torch.Tensor, new_statedict: dict, weight_names: List[str], config + ) -> None: + assert weight_names is not None and len(weight_names) == 2, ( + f"split_fc1 transform expects two weight names, got {weight_names}" + ) + + tp_size = config.model_config.tensor_model_parallel_size + target_tp = config.reshard_tp_size + split_size = linear_fc1.shape[0] // (tp_size // target_tp) + linear_fc1_slice = torch.split(linear_fc1, split_size, dim=0) + + gate_proj_shards = [] + up_proj_shards = [] + for weight in linear_fc1_slice: + assert weight.shape[0] % 2 == 0, ( + f"linear_fc1 weight shape {weight.shape} is not even along dim 0" + ) + weight_chunk = torch.chunk(weight, 2, dim=0) + gate_proj_shards.append(weight_chunk[0]) + up_proj_shards.append(weight_chunk[1]) + gate_proj = torch.cat(gate_proj_shards, dim=0) + up_proj = torch.cat(up_proj_shards, dim=0) + + new_statedict[weight_names[0]] = gate_proj.clone() + new_statedict[weight_names[1]] = up_proj.clone() + + @staticmethod + def split_none( + tensor: torch.Tensor, new_statedict: dict, weight_names: List[str] + ) -> None: + assert weight_names is not None and len(weight_names) == 1, ( + f"split_none transform expects one weight name, got {weight_names}" + ) + new_statedict[weight_names[0]] = tensor.clone() + + +@dataclass +class ConvertorRule: + pattern: re.Pattern + transform: TransformType + targets: List[str] + post: Optional[Callable] = None + + +class BaseConvertor: + def __init__(self, config, strict: bool = False): + self.cfg = config + self.strict = strict + self.rules = self.build_rules() + + def map_name(self, name: str) -> Optional[Tuple[TransformType, List[str]]]: + def _get_targets_from_match(templates: list[str], m: re.Match) -> list[str]: + gd = m.groupdict() + out = [] + for t in templates: + if "{" in t and "}" in t: + out.append(t.format(**gd)) + else: + out.append(m.expand(t)) + return out + + for r in self.rules: + m = r.pattern.fullmatch(name) + if not m: + continue + targets = r.targets + if r.post: + targets = r.post(targets, m) + full_names = _get_targets_from_match(targets, m) + return r.transform, full_names + return None + + def convert(self, state_dict: Dict) -> 
Dict: + converted = {} + for k, v in state_dict.items(): + mapped = self.map_name(k) + if mapped is None: + if self.strict: + raise KeyError(f"Unmapped key {k}") + continue + transform, targets = mapped + if transform in (TransformType.SPLIT_QKV, TransformType.SPLIT_QKV_BIAS): + TransformFunc._split_gqa_tensor(v, converted, targets, self.cfg) + elif transform == TransformType.SPLIT_FC1: + TransformFunc.split_fc1(v, converted, targets, self.cfg) + elif transform == TransformType.SPLIT_NONE: + TransformFunc.split_none(v, converted, targets) + else: + raise ValueError(f"Unknown transform type {transform}") + return converted + + def build_rules(self) -> List[ConvertorRule]: + """ + Should be implemented in subclass to build the conversion rules. + """ + raise NotImplementedError + + +class Qwen2_5Convertor(BaseConvertor): + def build_rules(self) -> List[ConvertorRule]: + LID = r"(?P\d+)" + WB = r"(?Pweight|bias)" + + return [ + # embeddings + ConvertorRule( + re.compile(r"embedding\.word_embeddings\.weight$"), + TransformType.SPLIT_NONE, + [r"model.embed_tokens.weight"], + ), + # final_layernorm + ConvertorRule( + re.compile(r"decoder\.final_layernorm\.weight$"), + TransformType.SPLIT_NONE, + [r"model.norm.weight"], + ), + # lm_head + ConvertorRule( + re.compile(r"output_layer\.weight$"), + TransformType.SPLIT_NONE, + [r"lm_head.weight"], + ), + # attn qkv norm + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.self_attention\.linear_qkv\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [r"model.layers.\g.input_layernorm.weight"], + ), + # attn qkv weights/bias + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.self_attention\.linear_qkv\.{WB}$" + ), + TransformType.SPLIT_QKV, + [ + r"model.layers.\g.self_attn.q_proj.\g", + r"model.layers.\g.self_attn.k_proj.\g", + r"model.layers.\g.self_attn.v_proj.\g", + ], + ), + # attn o proj + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.self_attention\.linear_proj\.{WB}$" + ), + TransformType.SPLIT_NONE, + [r"model.layers.\g.self_attn.o_proj.\g"], + ), + # mlp fc1 + ConvertorRule( + re.compile(rf"decoder\.layers\.{LID}\.mlp\.linear_fc1\.{WB}$"), + TransformType.SPLIT_FC1, + [ + r"model.layers.\g.mlp.gate_proj.\g", + r"model.layers.\g.mlp.up_proj.\g", + ], + ), + # mlp fc2 + ConvertorRule( + re.compile(rf"decoder\.layers\.{LID}\.mlp\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [r"model.layers.\g.mlp.down_proj.\g"], + ), + # mlp norms + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.mlp\.linear_fc1\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [r"model.layers.\g.post_attention_layernorm.weight"], + ), + ] + + +class Qwen2_5VLConvertor(BaseConvertor): + def _build_vision_rules(self) -> List[ConvertorRule]: + B = r"(?P\d+)" + WB = r"(?Pweight|bias)" + HF_V_PREFIX = "model.visual" + HF_V_DECODER_PREFIX = f"{HF_V_PREFIX}.blocks" + MG_V_PREFIX = "vision_model" + MG_V_DECODER_PREFIX = rf"{MG_V_PREFIX}\.decoder\.layers" + + vision_rules = [ + # vision patch embed + ConvertorRule( + re.compile(rf"^{MG_V_PREFIX}\.patch_embed\.proj\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_V_PREFIX}.patch_embed.proj.weight"], + ), + # final layer norm + ConvertorRule( + re.compile(rf"^{MG_V_PREFIX}\.decoder\.final_layernorm\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_V_PREFIX}.merger.ln_q.weight"], + ), + # attn norm + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.norm1.weight"], + ), + 
# attn qkv + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.linear_qkv\.{WB}$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.attn.qkv.\g"], + ), + # attn proj + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.linear_proj\.{WB}$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.attn.proj.\g"], + ), + # mlp fc1 + ConvertorRule( + re.compile(rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.{WB}$"), + TransformType.SPLIT_FC1, + [ + f"{HF_V_DECODER_PREFIX}" + r".\g.mlp.gate_proj.\g", + f"{HF_V_DECODER_PREFIX}" + r".\g.mlp.up_proj.\g", + ], + ), + # mlp fc2 + ConvertorRule( + re.compile(rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.mlp.down_proj.\g"], + ), + # mlp norm + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.norm2.weight"], + ), + ] + return vision_rules + + def _build_llm_rules(self) -> List[ConvertorRule]: + B = r"(?P\d+)" + WB = r"(?Pweight|bias)" + HF_LLM_PREFIX = "model.language_model" + MG_LLM_PREFIX = "language_model" + MG_LLM_DECODER_PREFIX = rf"{MG_LLM_PREFIX}\.decoder\.layers" + + llm_rules = [ + # embeddings + ConvertorRule( + re.compile(rf"^{MG_LLM_PREFIX}\.embed_tokens\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}.embedding.weight"], + ), + # final_layernorm + ConvertorRule( + re.compile(rf"^{MG_LLM_PREFIX}\.final_layernorm\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}.norm.weight"], + ), + # attn norm + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.input_layernorm.weight"], + ), + # attn qkv + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.linear_qkv\.{WB}$" + ), + TransformType.SPLIT_QKV, + [ + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.self_attn.q_proj.\g", + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.self_attn.k_proj.\g", + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.self_attn.v_proj.\g", + ], + ), + # attn proj + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.linear_proj\.{WB}$" + ), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.self_attn.o_proj.\g"], + ), + # mlp fc1 + ConvertorRule( + re.compile(rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.{WB}$"), + TransformType.SPLIT_FC1, + [ + f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.mlp.gate_proj.\g", + f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.mlp.up_proj.\g", + ], + ), + # mlp fc2 + ConvertorRule( + re.compile(rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.mlp.down_proj.\g"], + ), + # mlp norm + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [ + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.post_attention_layernorm.weight" + ], + ), + ] + return llm_rules + + def _build_projector_rules(self) -> List[ConvertorRule]: + HF_PROJECTOR_PREFIX = "model.visual.merger" + MG_PROJECTOR_PREFIX = "vision_model.protection.encoder" + WB = r"(?Pweight|bias)" + + projector_rules = [ + # projector fc1 + ConvertorRule( + re.compile(rf"^{MG_PROJECTOR_PREFIX}\.linear_fc1\.{WB}$"), + TransformType.SPLIT_NONE, + 
[f"{HF_PROJECTOR_PREFIX}" + r".mlp.0.\g"], + ), + # projector fc2 + ConvertorRule( + re.compile(rf"^{MG_PROJECTOR_PREFIX}\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_PROJECTOR_PREFIX}" + r".mlp.2.\g"], + ), + ] + return projector_rules + + def build_rules(self) -> List[ConvertorRule]: + rules = [] + rules.extend(self._build_vision_rules()) + rules.extend(self._build_llm_rules()) + rules.extend(self._build_projector_rules()) + return rules + + +_MG2HF_CONVERTOR_REGISTRY = {} + + +def register_mg2hf_convertor(model_arch: str, convertor_cls: Callable) -> None: + if model_arch in _MG2HF_CONVERTOR_REGISTRY: + raise ValueError(f"Convertor for {model_arch} already registered") + _MG2HF_CONVERTOR_REGISTRY[model_arch] = convertor_cls + + +register_mg2hf_convertor("qwen2.5", Qwen2_5Convertor) +register_mg2hf_convertor("qwen2.5-vl", Qwen2_5VLConvertor) + + +def get_mg2hf_convertor(model_arch: str, config, strict: bool = False) -> BaseConvertor: + if model_arch not in _MG2HF_CONVERTOR_REGISTRY: + raise ValueError(f"No convertor registered for {model_arch}") + convertor_cls = _MG2HF_CONVERTOR_REGISTRY[model_arch] + return convertor_cls(config=config, strict=strict) diff --git a/rlinf/utils/resharding/mcore_weight_reshard.py b/rlinf/utils/resharding/mcore_weight_reshard.py index 9ec44eba0..90d8277fa 100644 --- a/rlinf/utils/resharding/mcore_weight_reshard.py +++ b/rlinf/utils/resharding/mcore_weight_reshard.py @@ -183,6 +183,6 @@ def get_layer_num(param_name): ) if self.config.convert_fn is not None: - model_state_dict = self.config.convert_fn(model_state_dict, self.config) + model_state_dict = self.config.convert_fn(model_state_dict) return model_state_dict diff --git a/rlinf/utils/resharding/reshard_config.py b/rlinf/utils/resharding/reshard_config.py index 2b493a839..79e089bcd 100644 --- a/rlinf/utils/resharding/reshard_config.py +++ b/rlinf/utils/resharding/reshard_config.py @@ -17,7 +17,9 @@ from megatron.core.transformer import TransformerConfig -from .utils import get_convert_fn, get_pp_reshard_fn, get_tp_reshard_fn +from rlinf.utils.convertor.utils import get_mg2hf_convertor + +from .utils import get_pp_reshard_fn, get_tp_reshard_fn @dataclass @@ -37,7 +39,7 @@ class ReshardConfig: """Resharding pp size.""" convert_fn: Callable = None - """Convert function to use for converting the model parameters' weight and name from training engine to rollout engine.""" + """Function to convert the model weights from megatron format to HuggingFace format.""" tp_reshard_fn: Callable = None """Resharding function to use for resharding the model parallelism from tensor_model_parallel_size to reshard_tp_size.""" @@ -59,7 +61,8 @@ def __post_init__(self): ) if self.convert_fn is None and self.reshard_weights_format != "mcore": - self.convert_fn = get_convert_fn(self.model_arch) + self._convertor = get_mg2hf_convertor(self.model_arch, self, strict=True) + self.convert_fn = self._convertor.convert if self.tp_reshard_fn is None: self.tp_reshard_fn = get_tp_reshard_fn(self.model_arch) diff --git a/rlinf/utils/resharding/utils.py b/rlinf/utils/resharding/utils.py index 1fae2b05a..82ca3eadf 100644 --- a/rlinf/utils/resharding/utils.py +++ b/rlinf/utils/resharding/utils.py @@ -12,23 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import re -from enum import Enum -from typing import List, Tuple import torch from megatron.core import parallel_state -def get_convert_fn(model_arch: str): - if model_arch == "qwen2.5": - return TransformFunc.convert_mega_qwen2_5_to_hf - else: - raise NotImplementedError( - f"get_convert_fn for model_arch {model_arch} is not implemented" - ) - - def get_tp_reshard_fn(model_arch: str): if model_arch == "qwen2.5": return tp_reshard_fn_qwen2_5 @@ -47,212 +35,6 @@ def get_pp_reshard_fn(model_arch: str): ) -########################### -# convert fn implementation -########################### - - -class TransformType(Enum): - SPLIT_QKV = "split_qkv" - SPLIT_QKV_BIAS = "split_qkv_bias" - SPLIT_FC1 = "split_fc1" - SPLIT_NONE = "split_none" - - -class TransformFunc: - @staticmethod - def _split_gqa_tensor( - tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config - ) -> None: - hidden_size = config.model_config.hidden_size - num_attention_heads = config.model_config.num_attention_heads - num_query_groups = config.model_config.num_query_groups or num_attention_heads - head_dim = hidden_size // num_attention_heads - - target_tp = config.reshard_tp_size - assert num_query_groups % target_tp == 0, ( - "num_query_groups must be divisible by reshard_tp_size" - ) - local_num_query_groups = num_query_groups // target_tp - - # heads per query group - assert num_attention_heads % num_query_groups == 0, ( - "num_attention_heads must be divisible by num_query_groups" - ) - q_heads_per_group = num_attention_heads // num_query_groups - - num_channel_qkv = q_heads_per_group + 2 - - if tensor.ndim == 2: - # Weight: [out_features, in_features] - out_features, in_features = tensor.shape - expected_out = local_num_query_groups * num_channel_qkv * head_dim - assert out_features == expected_out, ( - f"Unexpected fused QKV weight shape {tensor.shape}, expect " - f"[{expected_out}, {in_features}] (local groups={local_num_query_groups})" - ) - - qkv = tensor.view( - local_num_query_groups, num_channel_qkv, head_dim, in_features - ) - q, k, v = torch.split( - qkv, [q_heads_per_group, 1, 1], dim=1 - ) # shapes: [G, qh, D, In], [G,1,D,In], [G,1,D,In] - q_full = q.reshape(-1, in_features).contiguous() - k_full = k.reshape(-1, in_features).contiguous() - v_full = v.reshape(-1, in_features).contiguous() - else: - # Bias: [out_features] - out_features = tensor.shape[0] - expected_out = local_num_query_groups * num_channel_qkv * head_dim - assert out_features == expected_out, ( - f"Unexpected fused QKV bias shape {tensor.shape}, expect " - f"[{expected_out}] (local groups={local_num_query_groups})" - ) - - qkv = tensor.view(local_num_query_groups, num_channel_qkv, head_dim) - q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1) - q_full = q.reshape(-1).contiguous() - k_full = k.reshape(-1).contiguous() - v_full = v.reshape(-1).contiguous() - - # Save to target names - new_statedict[weight_names[0]] = q_full.clone() - new_statedict[weight_names[1]] = k_full.clone() - new_statedict[weight_names[2]] = v_full.clone() - - @staticmethod - def split_fc1( - linear_fc1: torch.Tensor, new_statedict: dict, weight_names: List[str], config - ) -> None: - assert weight_names is not None and len(weight_names) == 2, ( - f"split_fc1 transform expects two weight names, got {weight_names}" - ) - - tp_size = config.model_config.tensor_model_parallel_size - target_tp = config.reshard_tp_size - split_size = linear_fc1.shape[0] // (tp_size // target_tp) - linear_fc1_slice = torch.split(linear_fc1, split_size, dim=0) - - 
gate_proj_shards = [] - up_proj_shards = [] - for weight in linear_fc1_slice: - assert weight.shape[0] % 2 == 0, ( - f"linear_fc1 weight shape {weight.shape} is not even along dim 0" - ) - weight_chunk = torch.chunk(weight, 2, dim=0) - gate_proj_shards.append(weight_chunk[0]) - up_proj_shards.append(weight_chunk[1]) - gate_proj = torch.cat(gate_proj_shards, dim=0) - up_proj = torch.cat(up_proj_shards, dim=0) - - new_statedict[weight_names[0]] = gate_proj.clone() - new_statedict[weight_names[1]] = up_proj.clone() - - @staticmethod - def split_none( - tensor: torch.Tensor, new_statedict: dict, weight_names: List[str] - ) -> None: - assert weight_names is not None and len(weight_names) == 1, ( - f"split_none transform expects one weight name, got {weight_names}" - ) - new_statedict[weight_names[0]] = tensor.clone() - - @staticmethod - def mega_name_qwen2_5_to_hf(name: str) -> Tuple[TransformType, List[str]]: - """ - Convert qwen2_5 model weight megatron name to hf name and do shape transform if needed. - - Args: - name (str): megatron model weight name - - Returns: - (TransformType, List[str]): transform type and the corresponding hf model weight name - """ - if "embedding.word_embeddings.weight" in name: - return (TransformType.SPLIT_NONE, ["model.embed_tokens.weight"]) - if "decoder.final_layernorm.weight" in name: - return (TransformType.SPLIT_NONE, ["model.norm.weight"]) - if "output_layer.weight" in name: - return (TransformType.SPLIT_NONE, ["lm_head.weight"]) - layer_id, suffix = TransformFunc.extract_layer_info(name) - assert layer_id is not None, f"Cannot extract layer info from {name}" - result_pattern = "model.layers.{}.{}" - nmap = { - "self_attention.linear_proj.weight": ( - TransformType.SPLIT_NONE, - ["self_attn.o_proj.weight"], - ), - "self_attention.linear_qkv.layer_norm_weight": ( - TransformType.SPLIT_NONE, - ["input_layernorm.weight"], - ), - "self_attention.linear_qkv.weight": ( - TransformType.SPLIT_QKV, - [ - "self_attn.q_proj.weight", - "self_attn.k_proj.weight", - "self_attn.v_proj.weight", - ], - ), - "self_attention.linear_qkv.bias": ( - TransformType.SPLIT_QKV_BIAS, - [ - "self_attn.q_proj.bias", - "self_attn.k_proj.bias", - "self_attn.v_proj.bias", - ], - ), - "mlp.linear_fc1.layer_norm_weight": ( - TransformType.SPLIT_NONE, - ["post_attention_layernorm.weight"], - ), - "mlp.linear_fc1.weight": ( - TransformType.SPLIT_FC1, - ["mlp.gate_proj.weight", "mlp.up_proj.weight"], - ), - "mlp.linear_fc2.weight": ( - TransformType.SPLIT_NONE, - ["mlp.down_proj.weight"], - ), - } - - assert suffix in nmap, f"Cannot find mapping for {suffix}" - - transform_type, suffixes = nmap[suffix] - return ( - transform_type, - [result_pattern.format(layer_id, suffix) for suffix in suffixes], - ) - - @staticmethod - def convert_mega_qwen2_5_to_hf(model_state_dict: dict, config) -> dict: - new_statedict = {} - for name, param in model_state_dict.items(): - transform_type, hf_names = TransformFunc.mega_name_qwen2_5_to_hf(name) - if transform_type == TransformType.SPLIT_QKV: - TransformFunc._split_gqa_tensor(param, new_statedict, hf_names, config) - elif transform_type == TransformType.SPLIT_QKV_BIAS: - TransformFunc._split_gqa_tensor(param, new_statedict, hf_names, config) - elif transform_type == TransformType.SPLIT_FC1: - TransformFunc.split_fc1(param, new_statedict, hf_names, config) - elif transform_type == TransformType.SPLIT_NONE: - TransformFunc.split_none(param, new_statedict, hf_names) - else: - raise NotImplementedError( - f"Transform type {transform_type} not implemented" - ) - 
return new_statedict - - @staticmethod - def extract_layer_info(s): - pattern = r"layers\.(\d+)\.(.+)" - match = re.search(pattern, s) - if match: - return match.group(1), match.group(2) - return None, None - - ############################## # tp reshard fn implementation ############################## From 5db6fab7125716285b2dbf17fe6abccaf27acf79 Mon Sep 17 00:00:00 2001 From: cc <1716911340@qq.com> Date: Fri, 12 Sep 2025 15:40:01 +0800 Subject: [PATCH 02/38] feat(rollout_mm): add multimodal input/output for rollout backend (#2) * Merge pull request #1 from anHappyDog/feature/weight_convertor feat(weight): refactor and add qwen2.5-vl mg2hf convertor * feat(mm_input): add basic vision-language dataset processor and yaml config Signed-off-by: Bo Dai * feat(mm_input): add vLLM multimodal support Signed-off-by: Bo Dai --------- Signed-off-by: Bo Dai --- .../config/qwen2.5-vl-3b-grpo-megatron.yaml | 262 ++++++++++++++++++ examples/vlm/main_vlm.py | 99 +++++++ examples/vlm/run_main_vlm_grpo_megatron.sh | 21 ++ rlinf/config.py | 12 +- rlinf/data/datasets.py | 245 +++++++++++++--- rlinf/data/io_struct.py | 28 +- rlinf/runners/math_runner.py | 7 +- rlinf/workers/rollout/sglang/sglang_worker.py | 4 +- 8 files changed, 629 insertions(+), 49 deletions(-) create mode 100644 examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml create mode 100644 examples/vlm/main_vlm.py create mode 100644 examples/vlm/run_main_vlm_grpo_megatron.sh diff --git a/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml b/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml new file mode 100644 index 000000000..e04fe0118 --- /dev/null +++ b/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml @@ -0,0 +1,262 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + num_gpus_per_node: 8 + component_placement: + actor,rollout: all + +runner: + task_type: math + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: qwen2.5-vl-3b-grpo + output_dir: /mnt/public/daibo/results + +algorithm: + group_size: 16 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
+ + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: True + shuffle_rollout: False + + # GRPO loss params + loss_type: ppo + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + + adv_type: grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /mnt/public/hf_models/qwen2.5-VL-3B/ + model_arch: qwen2.5-vl + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + attention_backend: triton + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + sglang_decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. 
+ +data: + type: vision_language + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 512 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/mnt/public/daibo/dataset/science_qa/data/train-00000-of-00001-1028f23e353fbe3e.parquet"] + val_data_paths: ["/mnt/public/daibo/dataset/science_qa/data/validation-00000-of-00001-6c7328ff6c84284c.parquet"] + +actor: + group_name: "ActorGroup" + training_backend: megatron + mcore_gpt: True + spec_name: decoder_gpt + + checkpoint_load_path: /mnt/public/mg_ckpts/qwen2.5-VL-3B-tp2-pp1/ + + offload_optimizer: True + offload_weight: True + offload_grad: True + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: fp16 + add_bias_linear: False + + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + + activation: swiglu + sequence_parallel: True + # recompute_method: block + # recompute_granularity: selective + + recompute_method: block + recompute_granularity: full + recompute_num_layers: 20 + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + + normalization: rmsnorm + + position_embedding_type: rope + + apply_rope_fusion: True + bias_dropout_fusion: False + persist_layer_norm: False + bias_activation_fusion: False + attention_softmax_in_fp32: True + batch_p2p_comm: False + variable_seq_lengths: True + gradient_accumulation_fusion: False + moe_token_dispatcher_type: alltoall + use_cpu_initialization: False + + optim: + optimizer: adam + bf16: False + fp16: True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /mnt/public/hf_models/qwen2.5-VL-3B/ + use_fast: False + trust_remote_code: True + padding_side: 'right' + + megatron: + ddp_bucket_size: null + distributed_backend: nccl # Support 'nccl' and 'gloo' + distributed_timeout_minutes: 30 + ckpt_format: torch + use_dist_ckpt: False + tp_comm_bootstrap_backend: nccl + tp_comm_overlap_cfg: null # tp_comm_overlap_cfg.yaml + use_hf_ckpt: False # if true, will transfer hf model to generate megatron checkpoint and use it for training. 
+ use_profiler: False # if true, will enable torch profiler when training, pay attention it has influence on performance + + ckpt_convertor: # config for ckpt convertor + model: DeepSeek-R1-Distill-Qwen-1.5B + model_type: null # will be set by hf model's config if null + hf_model_path: ${rollout.model_dir} # path to the hf model + save_path: ${runner.output_dir}/${runner.experiment_name}/converted_ckpts/actor + use_gpu_num : 0 + use_gpu_index: null + process_num: 16 # number of processes to use for checkpointing + tensor_model_parallel_size: ${actor.model.tensor_model_parallel_size} + pipeline_model_parallel_size: ${actor.model.pipeline_model_parallel_size} + + profiler: # profile megatron when inference and traning + output_dir: ${runner.output_dir}/${runner.experiment_name}/profiler + activities: ["cpu", "cuda"] + record_shapes: False + profile_memory: False + with_stack: False + with_flops: False + with_modules: True + export_tensorboard: True + export_chrome_trace: False + chrome_filename_prefix: "chrome_trace" + schedule_warmup: 2 + schedule_active: 1 + schedule_repeat: 1 # inference and training will repeat such times + # schedule_wait: it will be set at runtime + + +reward: + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + +critic: + use_critic_model: false \ No newline at end of file diff --git a/examples/vlm/main_vlm.py b/examples/vlm/main_vlm.py new file mode 100644 index 000000000..32466fa12 --- /dev/null +++ b/examples/vlm/main_vlm.py @@ -0,0 +1,99 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +import hydra +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from rlinf.config import validate_cfg +from rlinf.data.datasets import create_rl_dataset +from rlinf.data.tokenizers import hf_tokenizer +from rlinf.runners.math_runner import MathRunner +from rlinf.scheduler import Cluster +from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode +from rlinf.utils.utils import output_redirector +from rlinf.workers.actor.megatron_actor_worker import MegatronActor +from rlinf.workers.inference.megatron_inference_worker import MegatronInference +from rlinf.workers.rollout.sglang.sglang_worker import AsyncSGLangWorker, SGLangWorker + +"""Script to start GRPO training""" +mp.set_start_method("spawn", force=True) + + +@hydra.main(version_base="1.1") +@output_redirector +def main(cfg) -> None: + cfg = validate_cfg(cfg) + print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2)) + + cluster = Cluster( + num_nodes=cfg.cluster.num_nodes, num_gpus_per_node=cfg.cluster.num_gpus_per_node + ) + component_placement = ModelParallelComponentPlacement(cfg) + + # Rollout group + rollout_placement_strategy = component_placement.get_strategy("rollout") + SGLangWorkerCls = ( + SGLangWorker + if component_placement.placement_mode == PlacementMode.COLLOCATED + else AsyncSGLangWorker + ) + rollout_group = SGLangWorkerCls.create_group(cfg, component_placement).launch( + cluster, + name=cfg.rollout.group_name, + placement_strategy=rollout_placement_strategy, + ) + + # Inference group + inference_group = None + if ( + component_placement.placement_mode == PlacementMode.DISAGGREGATED + and cfg.algorithm.recompute_logprobs + ): + inference_placement_strategy = component_placement.get_strategy("inference") + inference_group = MegatronInference.create_group( + cfg, component_placement + ).launch( + cluster, + name=cfg.inference.group_name, + placement_strategy=inference_placement_strategy, + ) + + # GRPO Actor group + actor_placement_strategy = component_placement.get_strategy("actor") + actor_group = MegatronActor.create_group(cfg, component_placement).launch( + cluster, name=cfg.actor.group_name, placement_strategy=actor_placement_strategy + ) + + tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model) + train_ds, val_ds = create_rl_dataset(cfg.data, tokenizer) + + runner = MathRunner( + cfg=cfg, + placement=component_placement, + train_dataset=train_ds, + val_dataset=val_ds, + rollout=rollout_group, + inference=inference_group, + actor=actor_group, + ) + + runner.init_workers() + runner.run() + + +if __name__ == "__main__": + main() diff --git a/examples/vlm/run_main_vlm_grpo_megatron.sh b/examples/vlm/run_main_vlm_grpo_megatron.sh new file mode 100644 index 000000000..2e5a75e3a --- /dev/null +++ b/examples/vlm/run_main_vlm_grpo_megatron.sh @@ -0,0 +1,21 @@ +#! 
/bin/bash +set -x + +tabs 4 +export VLLM_ATTENTION_BACKEND=XFORMERS +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM=false +export RAY_DEDUP_LOGS=0 + +CONFIG_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +REPO_PATH=$(dirname $(dirname "$CONFIG_PATH")) +MEGATRON_PATH=/opt/Megatron-LM +export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH + +if [ -z "$1" ]; then + CONFIG_NAME="qwen2.5-vl-3b-grpo-megatron" +else + CONFIG_NAME=$1 +fi + +python ${REPO_PATH}/examples/vlm/main_vlm.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/rlinf/config.py b/rlinf/config.py index 86198334a..f3b2984cd 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -199,8 +199,16 @@ def validate_model_cfg_by_hf_config(cfg, hf_model_path): qkv_bias = getattr(hf_config, "attention_bias", False) with open_dict(cfg): - if hf_config.rope_scaling is not None: - cfg.model.seq_len_interpolation_factor = hf_config.rope_scaling["factor"] + rs = getattr(hf_config, "rope_scaling", None) + if isinstance(rs, dict): + rtype = rs.get("type", "") + if rtype in {"linear", "dynamic", "ntk", "yarn"}: + f = rs.get("factor") + if f is not None: + cfg.model.seq_len_interpolation_factor = float(f) + else: + # mrope + cfg.model.seq_len_interpolation_factor = None cfg.model.override_vocab_size = hf_config.vocab_size cfg.model.max_position_embeddings = hf_config.max_position_embeddings cfg.model.rotary_base = hf_config.rope_theta diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index fcce53f47..e33667703 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -15,11 +15,12 @@ import json import logging import os -from collections import defaultdict -from typing import List +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np +import pandas as pd import torch +from omegaconf import DictConfig from torch.utils.data import Dataset @@ -60,6 +61,18 @@ def batch_pad_to_fixed_len( return batch_pad +@dataclass +class DatasetItem: + prompt: torch.Tensor + length: int + answer: str + idx: int + solution: Optional[str] = None + image_data: Optional[List[Union[bytes, str]]] = None + prompt_text: Optional[str] = None + meta: Optional[Dict[str, Any]] = None + + class MathDataset(Dataset): def __init__(self, data_paths, config, tokenizer): super().__init__() @@ -151,16 +164,160 @@ def __getitem__(self, idx): self.tokenizer.eos_token_id, left_pad=True, )[0] - - output = { - "prompt": prompt_tokens_tensor, - "length": prompt_length, - "answer": answer, - "idx": idx, - } + output = DatasetItem( + prompt=prompt_tokens_tensor, + length=prompt_length, + answer=answer, + idx=idx, + image_data=[], + ) return output +class VisionLanguageDataset(Dataset): + def __init__( + self, data_paths: Union[List[str], str], config: DictConfig, tokenizer + ): + super().__init__() + self.data_paths = data_paths + self.use_chat_template = config.use_chat_template + + self.image_keys = config.image_keys + self.prompt_key = config.prompt_key + self.choice_key = config.choice_key + self.answer_key = config.answer_key + self.solution_key = config.solution_key + + if isinstance(self.data_paths, str): + self.data_paths = [self.data_paths] + + self.max_prompt_length = config.max_prompt_length + self.tokenizer = tokenizer + self.data = self._load_data() + self.post_process() + + def post_process(self) -> None: + def get_image_list( + dataitem: Dict, image_keys: Optional[List[str]] + ) -> List[Union[bytes, str]]: + 
image_list: List[Union[bytes, str]] = [] + if image_keys: + for key in image_keys: + image_content = dataitem.get(key, None) + if image_content is None: + continue + if isinstance(image_content, dict) and "bytes" in image_content: + image_content = image_content["bytes"] + assert isinstance(image_content, bytes), ( + f"image content should be bytes, but got {type(image_content)} , content is {image_content}" + ) + image_list.append(image_content) + return image_list + + def process_prompt( + data_item: Dict, image_count: int + ) -> Tuple[str, List[int], int]: + question = data_item.get(self.prompt_key, "") + options = data_item.get(self.choice_key, []) + if not isinstance(options, list): + options = [options] + prompt_text = question + if options: + prompt_text += f"{options}\n" + if self.use_chat_template: + message_content: List = [] + for i in range(image_count): + message_content.append({"type": "image"}) + message_content.append({"type": "text", "text": prompt_text}) + messages = [{"role": "user", "content": message_content}] + prompt_text = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_ids = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + ) + if isinstance(prompt_ids, torch.Tensor): + if prompt_ids.dim() == 2 and prompt_ids.size(0) == 1: + prompt_ids = prompt_ids.squeeze(0) # [L] + prompt_ids = prompt_ids.to(dtype=torch.long) + else: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.long) + prompt_length = len(prompt_ids) + + return prompt_text, prompt_ids, prompt_length + else: + raise NotImplementedError("Non-chat template not implemented yet.") + + processed_data: List[DatasetItem] = [] + for idx, item in enumerate(self.data): + image_list: List[Union[bytes, str]] = get_image_list(item, self.image_keys) + prompt_text, prompt_ids, prompt_length = process_prompt( + item, len(image_list) + ) + + if prompt_length > self.max_prompt_length: + print( + f"prompt_ids length {prompt_length} exceeds the max_prompt_length {self.max_prompt_length}", + ) + prompt_ids = prompt_ids[: self.max_prompt_length] + prompt_length = self.max_prompt_length + prompt_ids = batch_pad_to_fixed_len( + [prompt_ids], + self.max_prompt_length, + self.tokenizer.eos_token_id, + left_pad=True, + )[0] + answer = item.get(self.answer_key, None) + solution = item.get(self.solution_key, None) + + data_item = DatasetItem( + prompt_text=prompt_text, + prompt=prompt_ids, + length=prompt_length, + image_data=image_list, + answer=answer, + solution=solution, + idx=idx, + ) + processed_data.append(data_item) + self.data = processed_data + + def _load_data(self) -> List: + merged_data = [] + for path in self.data_paths: + _, file_extension = os.path.splitext(path) + try: + pass + if file_extension == ".parquet": + loaded_data: List = pd.read_parquet(path).to_dict(orient="records") + merged_data.extend(loaded_data) + elif file_extension == ".jsonl": + with open(path, "r", encoding="utf-8") as file: + loaded_data = [json.loads(line.strip()) for line in file] + merged_data.extend(loaded_data) + elif file_extension == ".json": + with open(path, "r", encoding="utf-8") as file: + content = json.load(file) + if isinstance(content, list): + merged_data.extend(content) + else: + merged_data.append(content) + else: + print(f"Unsupport {file_extension}, skip: {path}") + except Exception as e: + raise RuntimeError(f"Load data error: {e}") + return merged_data + + def __len__(self) -> int: + return 
len(self.data) + + def __getitem__(self, index): + return self.data[index] + + def create_rl_dataset(data_config, tokenizer): """Create rl datasets. @@ -176,6 +333,8 @@ def create_rl_dataset(data_config, tokenizer): if data_config.type == "math": dataset_cls = MathDataset + elif data_config.type == "vision_language": + dataset_cls = VisionLanguageDataset else: return None, None @@ -197,32 +356,46 @@ def create_rl_dataset(data_config, tokenizer): return train_dataset, val_dataset -def collate_fn(data_list: list[dict]) -> dict: - r""" - Collate a batch of sample dicts into batched tensors and arrays. - - Args: - data_list: List of dicts mapping feature names to torch.Tensor or other values. - - Returns: - Dict where tensor entries are stacked into a torch.Tensor of shape - (batch_size, \*dims) and non-tensor entries are converted to - np.ndarray of dtype object with shape (batch_size,). - """ - tensors = defaultdict(list) - non_tensors = defaultdict(list) - - for data in data_list: - for key, val in data.items(): - if isinstance(val, torch.Tensor): - tensors[key].append(val) - else: - non_tensors[key].append(val) +def collate_fn(data_list: List["DatasetItem"]) -> Dict[str, Any]: + prompts = [] + lens = [] + for it in data_list: + p = ( + it.prompt + if isinstance(it.prompt, torch.Tensor) + else torch.as_tensor(it.prompt, dtype=torch.long) + ) + if p.dim() == 2 and p.size(0) == 1: + p = p.squeeze(0) + assert p.dim() == 1, ( + f"DatasetItem.prompt must be 1-D tensor, current shape is: {p.shape}" + ) + prompts.append(p) + lens.append(p.numel()) - for key, val in tensors.items(): - tensors[key] = torch.stack(val, dim=0) + if len(set(lens)) == 1: + target_len = lens[0] + else: + target_len = min(lens) + prompts = [p[-target_len:] if p.numel() > target_len else p for p in prompts] - for key, val in non_tensors.items(): - non_tensors[key] = np.array(val, dtype=object) + batch_prompt = torch.stack(prompts, dim=0) # [B, L] + batch_length = torch.tensor( + [min(int(it.length), target_len) for it in data_list], dtype=torch.long + ) - return {**tensors, **non_tensors} + batch_idx = torch.tensor([int(it.idx) for it in data_list], dtype=torch.long) + + batch: Dict[str, Any] = { + "prompt": batch_prompt, # [B, L] + "length": batch_length, # [B] + "answer": [it.answer for it in data_list], # List[str] + "idx": batch_idx, # [B] + "solution": [it.solution for it in data_list], # List[Optional[str]] + "image_data": [ + it.image_data for it in data_list + ], # List[Optional[List[bytes|str]]] + "prompt_text": [it.prompt_text for it in data_list], # List[Optional[str]] + "meta": [it.meta for it in data_list], # List[Optional[dict]] + } + return batch diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 1c455fa70..510acb292 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from omegaconf import DictConfig @@ -49,12 +49,14 @@ class RolloutRequest: n: Number of completions to generate for each input idx: List of unique identifiers for the requests, used for tracking input_lengths: List of lengths of the input sequences, corresponding to input_ids + image_data: list of image data (bytes or URLs) for multimodal inputs answers: Optional list of answers for the requests, if available """ n: int input_ids: List[List[int]] answers: List[str] + image_data: Union[List[List[bytes]], List[List[str]]] def repeat(self) -> "RolloutRequest": """Repeat each input in the RolloutRequest a specified number of times. @@ -113,14 +115,20 @@ def split(self, num_splits: int) -> List["RolloutRequest"]: def repeat_and_split( self, rollout_batch_size: Optional[int] = None ) -> List["RolloutRequest"]: - input_ids, answers = zip( + input_ids, answers, image_data = zip( *[ - (input_id, answer) - for input_id, answer in zip(self.input_ids, self.answers) + (input_id, answer, image_data) + for input_id, answer, image_data in zip( + self.input_ids, self.answers, self.image_data + ) for _ in range(self.n) ] ) - input_ids, answers = (list(input_ids), list(answers)) + input_ids, answers, image_data = ( + list(input_ids), + list(answers), + list(image_data), + ) # Split input ids based on rollout_batch_size_per_gpu if rollout_batch_size is None: @@ -134,14 +142,16 @@ def repeat_and_split( splitted_requests = [] input_ids_split_list = split_list(input_ids, num_batches) answers_split_list = split_list(answers, num_batches) + image_data_split_list = split_list(image_data, num_batches) - for input_ids_batch, answers_batch in zip( - input_ids_split_list, answers_split_list + for input_ids_batch, answers_batch, image_data_batch in zip( + input_ids_split_list, answers_split_list, image_data_split_list ): request = RolloutRequest( n=self.n, input_ids=input_ids_batch, answers=answers_batch, + image_data=image_data_batch, ) splitted_requests.append(request) @@ -257,7 +267,7 @@ class RolloutResult: prompt_texts: Optional[List[str]] = None response_texts: Optional[List[str]] = None answers: Optional[List[str]] = None - + image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None # Inference # Only set when recompute_logprobs is False rollout_logprobs: Optional[List[List[float]]] = None @@ -380,6 +390,7 @@ def from_sglang_results( group_size: int, input_ids: List[List[int]], answers: Optional[List[List[int]]] = None, + image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None, return_logprobs: bool = False, ) -> "RolloutResult": """Create a MathRolloutResult from the given results and input IDs. 
@@ -406,6 +417,7 @@ def from_sglang_results( response_lengths=[len(res["output_ids"]) for res in results], response_ids=[res["output_ids"] for res in results], answers=answers, + image_data=image_data, is_end=[ res["meta_info"]["finish_reason"]["type"] == "stop" for res in results ], diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/math_runner.py index e3f5b3750..a88826f5a 100644 --- a/rlinf/runners/math_runner.py +++ b/rlinf/runners/math_runner.py @@ -274,18 +274,21 @@ def epoch(self): def _put_batch(self, batch: Dict[str, torch.Tensor]): prompt_ids = batch["prompt"].tolist() lengths = batch["length"].tolist() - answers = batch["answer"].tolist() + answers = batch["answer"] + image_data = batch["image_data"] prompts = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)] rollout_dp_size = self.component_placement.rollout_dp_size - for input_ids, answers in zip( + for input_ids, answers, image_data in zip( split_list(prompts, rollout_dp_size, enforce_divisible_batch=False), split_list(answers, rollout_dp_size, enforce_divisible_batch=False), + split_list(image_data, rollout_dp_size, enforce_divisible_batch=False), ): request = RolloutRequest( n=self.cfg.algorithm.group_size, input_ids=input_ids, answers=answers, + image_data=image_data, ) self.dataloader_channel.put(request, async_op=True) diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 8d4a15cb7..6a921c0bd 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -169,7 +169,6 @@ def sync_model_from_actor(self): def rollout(self, input_channel: Channel, output_channel: Channel): request: RolloutRequest = input_channel.get() - # Repeat prompts based on the group_size config requests = request.repeat_and_split(self._rollout_batch_size) @@ -181,6 +180,8 @@ def rollout(self, input_channel: Channel, output_channel: Channel): with self.worker_timer(): results = self._engine.generate( input_ids=request.input_ids, + # 0.4.4 has modality bug,can't pass non-None image_data + image_data=request.image_data if any(request.image_data) else None, sampling_params=self._sampling_params, return_logprob=self._return_logprobs, ) @@ -191,6 +192,7 @@ def rollout(self, input_channel: Channel, output_channel: Channel): request.n, request.input_ids, request.answers, + request.image_data, self._return_logprobs, ) rollout_results.append(rollout_result) From 90a13089eed7b521baba5a24812c5ab77e7e5930 Mon Sep 17 00:00:00 2001 From: guozhen <37097045+guozhen1997@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:13:33 +0800 Subject: [PATCH 03/38] feat(vlm): support VLM sglang rollout and fsdp training (#6) Signed-off-by: guozhen1997 <2997871698@qq.com> Signed-off-by: Bo Dai Co-authored-by: Bo Dai --- .../math/config/qwen2.5-1.5b-grpo-fsdp.yaml | 203 ++++++++ examples/math/main_math.py | 7 +- ...grpo_megatron.sh => run_main_math_grpo.sh} | 2 +- .../run_main_math_pipeline_grpo_megatron.sh | 21 - .../config/qwen2.5-vl-3b-grpo-megatron.yaml | 132 ++--- examples/vlm/main_vlm.py | 18 +- rlinf/algorithms/losses.py | 17 +- rlinf/config.py | 16 +- rlinf/data/datasets.py | 62 ++- rlinf/data/io_struct.py | 146 ++++++ .../hybrid_engines/fsdp/fsdp_model_manager.py | 46 +- rlinf/hybrid_engines/fsdp/utils.py | 8 +- .../sglang/sglang_0_4_4/sgl_scheduler.py | 5 +- .../sglang/sglang_0_4_6/sgl_scheduler.py | 6 +- .../sglang/sglang_0_4_9/sgl_scheduler.py | 5 +- rlinf/models/__init__.py | 3 +- rlinf/models/embodiment/model_utils.py | 25 +- 
rlinf/runners/math_runner.py | 2 +- rlinf/utils/convertor/utils.py | 2 +- rlinf/utils/distributed.py | 27 +- rlinf/utils/placement.py | 36 +- rlinf/utils/resharding/utils.py | 3 +- rlinf/utils/utils.py | 52 ++ rlinf/workers/actor/__init__.py | 17 + rlinf/workers/actor/fsdp_actor_worker.py | 455 +++++++++++++++++- rlinf/workers/rollout/sglang/sglang_worker.py | 5 +- rlinf/workers/rollout/utils.py | 6 + 27 files changed, 1068 insertions(+), 259 deletions(-) create mode 100644 examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml rename examples/math/{run_main_math_grpo_megatron.sh => run_main_math_grpo.sh} (92%) delete mode 100644 examples/math/run_main_math_pipeline_grpo_megatron.sh diff --git a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml new file mode 100644 index 000000000..d8a1e8c3f --- /dev/null +++ b/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml @@ -0,0 +1,203 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + num_gpus_per_node: 8 + component_placement: + actor,rollout: all + +runner: + task_type: math + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: ../results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. 
+ + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: math + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/mnt/public/guozhen/data/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/mnt/public/guozhen/data/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: fp16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: False + fp16: True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + +critic: + use_critic_model: false \ No newline at end of file diff --git a/examples/math/main_math.py b/examples/math/main_math.py index 7150566fb..80f408bd0 100644 --- a/examples/math/main_math.py +++ b/examples/math/main_math.py @@ -25,7 +25,7 @@ from rlinf.scheduler import Cluster from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode from rlinf.utils.utils 
import output_redirector -from rlinf.workers.actor.megatron_actor_worker import MegatronActor +from rlinf.workers.actor import get_actor_worker from rlinf.workers.inference.megatron_inference_worker import MegatronInference from rlinf.workers.rollout.utils import get_rollout_backend_worker @@ -68,13 +68,14 @@ def main(cfg) -> None: ) # GRPO Actor group + actor_worker_cls = get_actor_worker(cfg) actor_placement_strategy = component_placement.get_strategy("actor") - actor_group = MegatronActor.create_group(cfg, component_placement).launch( + actor_group = actor_worker_cls.create_group(cfg, component_placement).launch( cluster, name=cfg.actor.group_name, placement_strategy=actor_placement_strategy ) tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model) - train_ds, val_ds = create_rl_dataset(cfg.data, tokenizer) + train_ds, val_ds = create_rl_dataset(cfg, tokenizer) runner = MathRunner( cfg=cfg, diff --git a/examples/math/run_main_math_grpo_megatron.sh b/examples/math/run_main_math_grpo.sh similarity index 92% rename from examples/math/run_main_math_grpo_megatron.sh rename to examples/math/run_main_math_grpo.sh index f826f882f..dc2f75ee0 100644 --- a/examples/math/run_main_math_grpo_megatron.sh +++ b/examples/math/run_main_math_grpo.sh @@ -13,7 +13,7 @@ MEGATRON_PATH=/opt/Megatron-LM export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-megatron" + CONFIG_NAME="qwen2.5-1.5b-grpo-fsdp" else CONFIG_NAME=$1 fi diff --git a/examples/math/run_main_math_pipeline_grpo_megatron.sh b/examples/math/run_main_math_pipeline_grpo_megatron.sh deleted file mode 100644 index 7deb96519..000000000 --- a/examples/math/run_main_math_pipeline_grpo_megatron.sh +++ /dev/null @@ -1,21 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false -export RAY_DEDUP_LOGS=0 - -CONFIG_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -REPO_PATH=$(dirname $(dirname "$CONFIG_PATH")) -MEGATRON_PATH=/opt/Megatron-LM -export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-megatron-pipeline" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/math/main_math.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml b/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml index e04fe0118..cfe4febe7 100644 --- a/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml +++ b/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml @@ -32,11 +32,11 @@ runner: max_tokens_per_mbs: 28672 resume_dir: null - experiment_name: qwen2.5-vl-3b-grpo - output_dir: /mnt/public/daibo/results + experiment_name: grpo-1.5b + output_dir: ../results algorithm: - group_size: 16 + group_size: 8 n_minibatches: 4 training_batch_size_per_gpu: 1 # micro batch size @@ -50,11 +50,11 @@ algorithm: # val rollout mbs val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} - recompute_logprobs: True + recompute_logprobs: False shuffle_rollout: False # GRPO loss params - loss_type: ppo + loss_type: math_ppo_actor loss_agg_func: "token-mean" kl_beta: 0.0 # 0.001 kl_penalty_type: low_var_kl @@ -62,8 +62,10 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null - adv_type: grpo + adv_type: math_grpo normalize_advantages: True early_stop_imp_ratio: 5.0 
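Note on the new `clip_ratio_low` / `clip_ratio_high` keys added above: they enable asymmetric PPO clipping, and leaving them `null` falls back to the symmetric `ratio_clip_eps`, matching the corresponding change to `compute_math_ppo_actor_loss` later in this patch. A toy illustration with made-up values:

import torch

ratio_clip_eps = 0.2
clip_ratio_low = None    # null in the YAML -> fall back to ratio_clip_eps
clip_ratio_high = None

low = clip_ratio_low if clip_ratio_low is not None else ratio_clip_eps
high = clip_ratio_high if clip_ratio_high is not None else ratio_clip_eps

ratio = torch.tensor([0.5, 1.0, 1.5])               # exp(logprobs - old_logprobs)
clipped = torch.clamp(ratio, 1.0 - low, 1.0 + high)
print(clipped)                                       # tensor([0.8000, 1.0000, 1.2000])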
use_valid_token_scale: False @@ -83,38 +85,46 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /mnt/public/hf_models/qwen2.5-VL-3B/ - model_arch: qwen2.5-vl + model_dir: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp disable_log_stats: False detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] - attention_backend: triton + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. return_logprobs: ${not:${algorithm.recompute_logprobs}} - tensor_parallel_size: 1 + tensor_parallel_size: 2 pipeline_parallel_size: 1 validate_weight: False # whether to send all weights at first for weight comparison. validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. - sglang_decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. max_running_requests: 64 # the maximum number of running requests in the rollout engine. cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. - use_torch_compile: False # enable torch_compile in SGLang for rollout. - torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. 
- data: type: vision_language max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 512 + rollout_batch_size: 8 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt @@ -126,20 +136,20 @@ data: shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/mnt/public/daibo/dataset/science_qa/data/train-00000-of-00001-1028f23e353fbe3e.parquet"] - val_data_paths: ["/mnt/public/daibo/dataset/science_qa/data/validation-00000-of-00001-6c7328ff6c84284c.parquet"] + train_data_paths: ["/mnt/public/guozhen/data/science_qa/train-00000-of-00001-1028f23e353fbe3e.parquet"] + val_data_paths: ["/mnt/public/guozhen/data/science_qa/test-00000-of-00001-f0e719df791966ff.parquet"] actor: group_name: "ActorGroup" - training_backend: megatron + training_backend: fsdp mcore_gpt: True spec_name: decoder_gpt - checkpoint_load_path: /mnt/public/mg_ckpts/qwen2.5-VL-3B-tp2-pp1/ + enable_offload: True + checkpoint_load_path: null - offload_optimizer: True - offload_weight: True - offload_grad: True + global_batch_size: 8 + micro_batch_size: 1 enable_dp_load_balance: False @@ -148,43 +158,20 @@ actor: seed: 1234 model: - precision: fp16 - add_bias_linear: False - - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - - activation: swiglu - sequence_parallel: True - # recompute_method: block - # recompute_granularity: selective - - recompute_method: block - recompute_granularity: full - recompute_num_layers: 20 + precision: bf16 + sharding_strategy: full_shard + is_lora: False seq_length: ${runner.seq_length} encoder_seq_length: ${runner.seq_length} + model_path: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct/ - normalization: rmsnorm - - position_embedding_type: rope - - apply_rope_fusion: True - bias_dropout_fusion: False - persist_layer_norm: False - bias_activation_fusion: False - attention_softmax_in_fp32: True - batch_p2p_comm: False - variable_seq_lengths: True - gradient_accumulation_fusion: False - moe_token_dispatcher_type: alltoall - use_cpu_initialization: False + model_arch: ${rollout.model_arch} optim: optimizer: adam - bf16: False - fp16: True + bf16: True #False + fp16: False #True lr: 2e-05 adam_beta1: 0.9 adam_beta2: 0.95 @@ -209,50 +196,11 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /mnt/public/hf_models/qwen2.5-VL-3B/ + tokenizer_model: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct use_fast: False trust_remote_code: True padding_side: 'right' - megatron: - ddp_bucket_size: null - distributed_backend: nccl # Support 'nccl' and 'gloo' - distributed_timeout_minutes: 30 - ckpt_format: torch - use_dist_ckpt: False - tp_comm_bootstrap_backend: nccl - tp_comm_overlap_cfg: null # tp_comm_overlap_cfg.yaml - use_hf_ckpt: False # if true, will transfer hf model to generate megatron checkpoint and use it for training. 
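Note on the actor changes above: with `training_backend: fsdp`, the `precision: bf16` and `sharding_strategy: full_shard` keys map onto an FSDP setup roughly like the sketch below (parameters held in bf16, gradient reduction and buffers kept in fp32, full ZeRO-3-style sharding). This is a simplified restatement; the actual construction lives in rlinf/hybrid_engines/fsdp/fsdp_model_manager.py, modified later in this patch.

import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,   # precision: bf16
    reduce_dtype=torch.float32,   # keep gradient all-reduce in fp32 for stability
    buffer_dtype=torch.float32,
)
sharding_strategy = ShardingStrategy.FULL_SHARD  # sharding_strategy: full_shard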
- use_profiler: False # if true, will enable torch profiler when training, pay attention it has influence on performance - - ckpt_convertor: # config for ckpt convertor - model: DeepSeek-R1-Distill-Qwen-1.5B - model_type: null # will be set by hf model's config if null - hf_model_path: ${rollout.model_dir} # path to the hf model - save_path: ${runner.output_dir}/${runner.experiment_name}/converted_ckpts/actor - use_gpu_num : 0 - use_gpu_index: null - process_num: 16 # number of processes to use for checkpointing - tensor_model_parallel_size: ${actor.model.tensor_model_parallel_size} - pipeline_model_parallel_size: ${actor.model.pipeline_model_parallel_size} - - profiler: # profile megatron when inference and traning - output_dir: ${runner.output_dir}/${runner.experiment_name}/profiler - activities: ["cpu", "cuda"] - record_shapes: False - profile_memory: False - with_stack: False - with_flops: False - with_modules: True - export_tensorboard: True - export_chrome_trace: False - chrome_filename_prefix: "chrome_trace" - schedule_warmup: 2 - schedule_active: 1 - schedule_repeat: 1 # inference and training will repeat such times - # schedule_wait: it will be set at runtime - - reward: use_reward_model: false reward_type: 'math' diff --git a/examples/vlm/main_vlm.py b/examples/vlm/main_vlm.py index 32466fa12..605577fba 100644 --- a/examples/vlm/main_vlm.py +++ b/examples/vlm/main_vlm.py @@ -25,9 +25,9 @@ from rlinf.scheduler import Cluster from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode from rlinf.utils.utils import output_redirector -from rlinf.workers.actor.megatron_actor_worker import MegatronActor +from rlinf.workers.actor import get_actor_worker from rlinf.workers.inference.megatron_inference_worker import MegatronInference -from rlinf.workers.rollout.sglang.sglang_worker import AsyncSGLangWorker, SGLangWorker +from rlinf.workers.rollout.utils import get_rollout_backend_worker """Script to start GRPO training""" mp.set_start_method("spawn", force=True) @@ -44,14 +44,11 @@ def main(cfg) -> None: ) component_placement = ModelParallelComponentPlacement(cfg) + rollout_worker_cls = get_rollout_backend_worker(cfg, component_placement) + # Rollout group rollout_placement_strategy = component_placement.get_strategy("rollout") - SGLangWorkerCls = ( - SGLangWorker - if component_placement.placement_mode == PlacementMode.COLLOCATED - else AsyncSGLangWorker - ) - rollout_group = SGLangWorkerCls.create_group(cfg, component_placement).launch( + rollout_group = rollout_worker_cls.create_group(cfg, component_placement).launch( cluster, name=cfg.rollout.group_name, placement_strategy=rollout_placement_strategy, @@ -73,13 +70,14 @@ def main(cfg) -> None: ) # GRPO Actor group + actor_worker_cls = get_actor_worker(cfg) actor_placement_strategy = component_placement.get_strategy("actor") - actor_group = MegatronActor.create_group(cfg, component_placement).launch( + actor_group = actor_worker_cls.create_group(cfg, component_placement).launch( cluster, name=cfg.actor.group_name, placement_strategy=actor_placement_strategy ) tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model) - train_ds, val_ds = create_rl_dataset(cfg.data, tokenizer) + train_ds, val_ds = create_rl_dataset(cfg, tokenizer) runner = MathRunner( cfg=cfg, diff --git a/rlinf/algorithms/losses.py b/rlinf/algorithms/losses.py index 3980f9136..798e5330f 100644 --- a/rlinf/algorithms/losses.py +++ b/rlinf/algorithms/losses.py @@ -196,7 +196,8 @@ def compute_math_ppo_actor_loss(**kwargs): loss_agg_func = 
kwargs["loss_agg_func"] logprobs = kwargs["logprobs"] old_logprobs = kwargs["old_logprobs"] - eps_clip = kwargs["eps_clip"] + clip_ratio_low = kwargs["clip_ratio_low"] + clip_ratio_high = kwargs["clip_ratio_high"] advantages = kwargs["advantages"] loss_mask = kwargs.get("loss_mask", None) c_clip = kwargs.get("c_clip", None) @@ -212,7 +213,7 @@ def compute_math_ppo_actor_loss(**kwargs): ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0) approx_kl = torch.where(loss_mask, (logprobs - old_logprobs).detach(), 0.0) - clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) + clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high) policy_loss1 = -advantages * ratio policy_loss2 = -advantages * clipped_ratio @@ -239,12 +240,12 @@ def compute_math_ppo_actor_loss(**kwargs): # Compile metrics for logging metrics_data = { - "policy_loss": masked_mean(policy_loss.detach(), loss_mask), - "ratio": masked_mean(ratio.detach(), loss_mask), - "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask), - "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask), - "approx_kl": approx_kl.detach(), - "clip_fraction": clip_fraction.detach(), + "policy_loss": masked_mean(policy_loss.detach(), loss_mask).cpu(), + "ratio": masked_mean(ratio.detach(), loss_mask).cpu(), + "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask).cpu(), + "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask).cpu(), + "approx_kl": approx_kl.detach().cpu(), + "clip_fraction": clip_fraction.detach().cpu(), } return policy_loss, metrics_data diff --git a/rlinf/config.py b/rlinf/config.py index f3b2984cd..0f3a21903 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses +import importlib.util import logging import os from dataclasses import asdict @@ -33,15 +34,7 @@ logging.getLogger().setLevel(logging.INFO) -try: - import transformer_engine - - HAVE_TE = True -except ImportError: - transformer_engine = None - HAVE_TE = False - -SUPPORTED_MODEL_ARCHS = ["qwen2.5", "openvla", "openvla_oft"] +SUPPORTED_MODEL_ARCHS = ["qwen2.5", "qwen2.5_vl", "openvla", "openvla_oft"] SUPPORTED_ROLLOUT_BACKENDS = ["sglang", "vllm"] __all__ = ["build_config"] @@ -764,7 +757,10 @@ def build_transformer_config(cfg) -> "TransformerConfig": tp_only_amax_red = cfg.get("tp_only_amax_red", False) if cfg.get("enable_cuda_graph", False): - assert HAVE_TE, "Transformer Engine is required for cudagraphs." + if importlib.util.find_spec("transformer_engine") is None: + raise ImportError( + "Can not import transformer_engine, which is required for cudagraphs." + ) assert cfg.get("use_te_rng_tracker", False), ( "Transformer engine's RNG tracker is required for cudagraphs, this can be enabled with \ 'use_te_rng_tracker=True'." 
diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index e33667703..162d7b94d 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -21,7 +21,9 @@ import pandas as pd import torch from omegaconf import DictConfig +from PIL.Image import Image from torch.utils.data import Dataset +from transformers import AutoProcessor def batch_pad_to_fixed_len( @@ -80,12 +82,12 @@ def __init__(self, data_paths, config, tokenizer): if isinstance(self.data_paths, str): self.data_paths = [self.data_paths] - self.max_prompt_length = config.max_prompt_length + self.max_prompt_length = config.data.max_prompt_length self.tokenizer = tokenizer - self.prompt_key = config.prompt_key + self.prompt_key = config.data.prompt_key self.data = self._load_data() - if config.get("filter_prompt_by_length", False): + if config.data.get("filter_prompt_by_length", False): total = len(self.data) filtered = [] failed = 0 @@ -180,19 +182,20 @@ def __init__( ): super().__init__() self.data_paths = data_paths - self.use_chat_template = config.use_chat_template + self.use_chat_template = config.data.use_chat_template - self.image_keys = config.image_keys - self.prompt_key = config.prompt_key - self.choice_key = config.choice_key - self.answer_key = config.answer_key - self.solution_key = config.solution_key + self.image_keys = config.data.image_keys + self.prompt_key = config.data.prompt_key + self.choice_key = config.data.choice_key + self.answer_key = config.data.answer_key + self.solution_key = config.data.solution_key if isinstance(self.data_paths, str): self.data_paths = [self.data_paths] - self.max_prompt_length = config.max_prompt_length + self.max_prompt_length = config.data.max_prompt_length self.tokenizer = tokenizer + self.processor = AutoProcessor.from_pretrained(config.actor.model.model_path) self.data = self._load_data() self.post_process() @@ -206,17 +209,25 @@ def get_image_list( image_content = dataitem.get(key, None) if image_content is None: continue + if isinstance(image_content, Image): + image_content.append(image_content) if isinstance(image_content, dict) and "bytes" in image_content: image_content = image_content["bytes"] assert isinstance(image_content, bytes), ( f"image content should be bytes, but got {type(image_content)} , content is {image_content}" ) image_list.append(image_content) + if image_list == []: + return [None] return image_list def process_prompt( data_item: Dict, image_count: int - ) -> Tuple[str, List[int], int]: + ) -> Tuple[ + str, + List[int], + int, + ]: question = data_item.get(self.prompt_key, "") options = data_item.get(self.choice_key, []) if not isinstance(options, list): @@ -230,15 +241,14 @@ def process_prompt( message_content.append({"type": "image"}) message_content.append({"type": "text", "text": prompt_text}) messages = [{"role": "user", "content": message_content}] - prompt_text = self.tokenizer.apply_chat_template( + prompt_text = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - prompt_ids = self.tokenizer.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=True, + prompt_ids = self.processor( + text=[prompt_text], + padding=True, return_tensors="pt", - ) + )["input_ids"] if isinstance(prompt_ids, torch.Tensor): if prompt_ids.dim() == 2 and prompt_ids.size(0) == 1: prompt_ids = prompt_ids.squeeze(0) # [L] @@ -278,7 +288,7 @@ def process_prompt( prompt=prompt_ids, length=prompt_length, image_data=image_list, - answer=answer, + answer=str(answer), solution=solution, idx=idx, ) @@ -318,11 
+328,11 @@ def __getitem__(self, index): return self.data[index] -def create_rl_dataset(data_config, tokenizer): +def create_rl_dataset(config: DictConfig, tokenizer): """Create rl datasets. Arguments: - data_config: The data config. + config: The RLinf config. tokenizer (Tokenizer): The tokenizer. Returns: @@ -331,9 +341,9 @@ def create_rl_dataset(data_config, tokenizer): val_dataset (Dataset): The validation dataset. """ - if data_config.type == "math": + if config.data.type == "math": dataset_cls = MathDataset - elif data_config.type == "vision_language": + elif config.data.type == "vision_language": dataset_cls = VisionLanguageDataset else: return None, None @@ -342,14 +352,14 @@ def create_rl_dataset(data_config, tokenizer): # Instantiate the dataset using the determined dataset class train_dataset = dataset_cls( - data_paths=data_config.train_data_paths, - config=data_config, + data_paths=config.data.train_data_paths, + config=config, tokenizer=tokenizer, ) val_dataset = dataset_cls( - data_paths=data_config.val_data_paths, - config=data_config, + data_paths=config.data.val_data_paths, + config=config, tokenizer=tokenizer, ) diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 510acb292..e40fed973 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -519,6 +519,152 @@ def merge_list(dst_list: List, src_list: List): return merged_result + @staticmethod + def split_result_list_by_group( + rollout_results: List["RolloutResult"], + ) -> List["RolloutResult"]: + """ + Split RolloutResult objects by group_size. + + If input has only one RolloutResult, split it into multiple RolloutResult objects by group_size. + If input has multiple RolloutResult objects, split each one and merge the results. + + Args: + rollout_results: List of input RolloutResult objects + + Returns: + List of RolloutResult objects grouped by group_size + """ + assert len(rollout_results) > 0, "No rollout results to split." + + all_split_results = [] + + for rollout_result in rollout_results: + split_results = RolloutResult._split_single_result_by_group(rollout_result) + all_split_results.extend(split_results) + + return all_split_results + + @staticmethod + def _split_single_result_by_group( + rollout_result: "RolloutResult", + ) -> List["RolloutResult"]: + """ + Split a single RolloutResult into multiple RolloutResult objects by group_size. 
+ + Args: + rollout_result: The RolloutResult to be split + + Returns: + List of split RolloutResult objects + """ + group_size = rollout_result.group_size + num_sequence = rollout_result.num_sequence + + assert num_sequence % group_size == 0, ( + f"num_sequence ({num_sequence}) must be divisible by group_size ({group_size})" + ) + + num_groups = num_sequence // group_size + split_results = [] + + # Split list fields + prompt_lengths_split = split_list(rollout_result.prompt_lengths, num_groups) + prompt_ids_split = split_list(rollout_result.prompt_ids, num_groups) + response_lengths_split = split_list(rollout_result.response_lengths, num_groups) + response_ids_split = split_list(rollout_result.response_ids, num_groups) + is_end_split = split_list(rollout_result.is_end, num_groups) + + # Handle optional fields + answers_split = None + if rollout_result.answers is not None: + answers_split = split_list(rollout_result.answers, num_groups) + + image_data_split = None + if rollout_result.image_data is not None: + image_data_split = split_list(rollout_result.image_data, num_groups) + + prompt_texts_split = None + if rollout_result.prompt_texts is not None: + prompt_texts_split = split_list(rollout_result.prompt_texts, num_groups) + + response_texts_split = None + if rollout_result.response_texts is not None: + response_texts_split = split_list(rollout_result.response_texts, num_groups) + + rollout_logprobs_split = None + if rollout_result.rollout_logprobs is not None: + rollout_logprobs_split = split_list( + rollout_result.rollout_logprobs, num_groups + ) + + # Handle tensor fields + rewards_split = None + if rollout_result.rewards is not None: + if isinstance(rollout_result.rewards, torch.Tensor): + rewards_split = torch.chunk(rollout_result.rewards, num_groups, dim=0) + else: + rewards_split = split_list(rollout_result.rewards, num_groups) + + advantages_split = None + if rollout_result.advantages is not None: + if isinstance(rollout_result.advantages, torch.Tensor): + advantages_split = torch.chunk( + rollout_result.advantages, num_groups, dim=0 + ) + else: + advantages_split = split_list(rollout_result.advantages, num_groups) + + prev_logprobs_split = None + if rollout_result.prev_logprobs is not None: + prev_logprobs_split = torch.chunk( + rollout_result.prev_logprobs, num_groups, dim=0 + ) + + ref_logprobs_split = None + if rollout_result.ref_logprobs is not None: + ref_logprobs_split = torch.chunk( + rollout_result.ref_logprobs, num_groups, dim=0 + ) + + # Create split RolloutResult objects + for i in range(num_groups): + split_result = RolloutResult( + num_sequence=group_size, + group_size=group_size, + prompt_lengths=prompt_lengths_split[i], + prompt_ids=prompt_ids_split[i], + response_lengths=response_lengths_split[i], + response_ids=response_ids_split[i], + is_end=is_end_split[i], + answers=answers_split[i] if answers_split is not None else None, + image_data=image_data_split[i] + if image_data_split is not None + else None, + prompt_texts=prompt_texts_split[i] + if prompt_texts_split is not None + else None, + response_texts=response_texts_split[i] + if response_texts_split is not None + else None, + rollout_logprobs=rollout_logprobs_split[i] + if rollout_logprobs_split is not None + else None, + rewards=rewards_split[i] if rewards_split is not None else None, + advantages=advantages_split[i] + if advantages_split is not None + else None, + prev_logprobs=prev_logprobs_split[i] + if prev_logprobs_split is not None + else None, + ref_logprobs=ref_logprobs_split[i] + if 
ref_logprobs_split is not None + else None, + ) + split_results.append(split_result) + + return split_results + def to_actor_batch( self, data_seq_length: int, diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py index 8e3b9f16c..5f04d633f 100644 --- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py @@ -19,9 +19,10 @@ from omegaconf import DictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq from rlinf.config import torch_dtype_from_precision +from rlinf.data.tokenizers import hf_tokenizer from rlinf.hybrid_engines.fsdp.utils import ( get_fsdp_wrap_policy, init_fn, @@ -40,13 +41,15 @@ def __init__(self, cfg: DictConfig): self.logger = get_logger() self.torch_dtype = torch_dtype_from_precision(self._cfg.model.precision) - assert ( - self.torch_dtype == torch.float16 or self.torch_dtype == torch.bfloat16 - ), ( - f"Precision {self._cfg.model.precision} is not supported, only support bf16 and fp16." - ) + self.tokenizer = hf_tokenizer(cfg.tokenizer.tokenizer_model) def model_provider_func(self) -> torch.nn.Module: + model_config = AutoConfig.from_pretrained( + self._cfg.model.model_path, + trust_remote_code=True, + attn_implementation="flash_attention_2", + ) + if self._cfg.model.get("gptq_model", False): from auto_gptq import AutoGPTQForCausalLM @@ -57,23 +60,31 @@ def model_provider_func(self) -> torch.nn.Module: elif self._cfg.model.get("load_in_8bit", False): model = AutoModelForCausalLM.from_pretrained( self._cfg.model.model_path, - device_map=self._cfg.model.get("device_map", "auto"), load_in_8bit=True, ) else: + if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): + auto_model_class = AutoModelForVision2Seq + else: + auto_model_class = AutoModelForCausalLM + + # TODO: fix this, load model in float16/bfloat16 may cause optimizer in bf16, which is incorrect # default load in float16 - model = AutoModelForCausalLM.from_pretrained( + model = auto_model_class.from_pretrained( self._cfg.model.model_path, torch_dtype=self.torch_dtype, - device_map=self._cfg.model.get("device_map", "auto"), + config=model_config, trust_remote_code=True, - use_safetensors=self._cfg.model.get("use_safetensors", False), ) - if torch.cuda.is_available(): - model = model.cuda() - if self.torch_dtype == torch.float16: - model = model.half() + model.to(self.torch_dtype) + + if torch.cuda.is_available(): + model = model.cuda() + if self.torch_dtype == torch.float16: + model = model.half() + + torch.distributed.barrier() return model def setup_model_and_optimizer(self): @@ -89,8 +100,8 @@ def setup_model_and_optimizer(self): mixed_precision = MixedPrecision( param_dtype=self.torch_dtype, - reduce_dtype=self.torch_dtype, - buffer_dtype=self.torch_dtype, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, ) if self._cfg.model.sharding_strategy == "full_shard": @@ -108,7 +119,6 @@ def setup_model_and_optimizer(self): self.model = FSDP( module, param_init_fn=init_fn, - use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=int(os.environ["LOCAL_RANK"]), sharding_strategy=sharding_strategy, # zero3 @@ -130,7 +140,7 @@ def setup_model_and_optimizer(self): }, ] - if self._cfg.model.vh_mode in ["a", "a0", "a6"]: + if 
self._cfg.model.get("vh_mode", None) in ["a", "a0", "a6"]: param_groups.append( { "params": [ diff --git a/rlinf/hybrid_engines/fsdp/utils.py b/rlinf/hybrid_engines/fsdp/utils.py index 2334b1fa7..2461f7006 100644 --- a/rlinf/hybrid_engines/fsdp/utils.py +++ b/rlinf/hybrid_engines/fsdp/utils.py @@ -58,7 +58,7 @@ def cpu_init_weights(): return init_context -def get_fsdp_wrap_policy(module, config=None, is_lora=False): +def get_fsdp_wrap_policy(module, config=None, is_lora=False, is_vla_model=False): """ FSDP wrap policy that handles both standard transformer models and VLA models. @@ -76,11 +76,8 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False): if config.get("disable", False): return None - # Check if this is a VLA model by looking for language_model attribute - is_vla_model = hasattr(module, "language_model") - # Get transformer layer classes to wrap - if is_vla_model: + if hasattr(module, "language_model"): # For VLA models, get transformer classes from language_model submodule default_transformer_cls_names_to_wrap = getattr( module.language_model, "_no_split_modules", None @@ -100,6 +97,7 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False): # Add vision transformer policies for VLA models if is_vla_model: + from prismatic.extern.hf.modeling_prismatic import PrismaticProjector from timm.models.vision_transformer import VisionTransformer from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py index 509c56bab..87f736798 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py @@ -108,10 +108,13 @@ def __init__( placement )[(self.get_parent_rank(), self._rank)] + use_presharded_weights = ( + False if self.cfg.actor.training_backend == "fsdp" else True + ) # it's important to use load_weight to load resharded weight from megatron for _, module in self.tp_worker.worker.model_runner.model.named_modules(): if hasattr(module, "use_presharded_weights"): - module.use_presharded_weights = True + module.use_presharded_weights = use_presharded_weights self._logger.info( f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py index ef503527b..684f8d333 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py @@ -29,6 +29,7 @@ ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, ) +from sglang.srt.managers.mm_utils import init_embedding_cache from sglang.srt.managers.scheduler import Scheduler as _Scheduler from sglang.srt.managers.scheduler import logger from sglang.srt.server_args import PortArgs, ServerArgs @@ -110,10 +111,13 @@ def __init__( self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map( placement )[(self.get_parent_rank(), self._rank)] + use_presharded_weights = ( + False if self.cfg.actor.training_backend == "fsdp" else True + ) # it's important to use load_weight to load resharded weight from megatron for _, module in self.tp_worker.worker.model_runner.model.named_modules(): if hasattr(module, "use_presharded_weights"): - module.use_presharded_weights = True + module.use_presharded_weights = use_presharded_weights 
self._logger.info( f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py index a7057a161..1f9beb409 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py @@ -112,10 +112,13 @@ def __init__( self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map( placement )[(self.get_parent_rank(), self._rank)] + use_presharded_weights = ( + False if self.cfg.actor.training_backend == "fsdp" else True + ) # it's important to use load_weight to load resharded weight from megatron for _, module in self.tp_worker.worker.model_runner.model.named_modules(): if hasattr(module, "use_presharded_weights"): - module.use_presharded_weights = True + module.use_presharded_weights = use_presharded_weights self._logger.info( f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" diff --git a/rlinf/models/__init__.py b/rlinf/models/__init__.py index 08207900a..617ac7467 100644 --- a/rlinf/models/__init__.py +++ b/rlinf/models/__init__.py @@ -17,7 +17,6 @@ import torch from omegaconf import DictConfig -from peft import LoraConfig, PeftModel, get_peft_model from transformers import ( AutoConfig, AutoImageProcessor, @@ -172,6 +171,8 @@ def get_model(model_path, cfg: DictConfig, override_config_kwargs=None): model = model.cuda() if cfg.is_lora: + from peft import LoraConfig, PeftModel, get_peft_model + if not hasattr(cfg, "lora_path") or cfg.lora_path is None: lora_config = LoraConfig( r=cfg.lora_rank, diff --git a/rlinf/models/embodiment/model_utils.py b/rlinf/models/embodiment/model_utils.py index 04425cfdc..8e7aebeb1 100644 --- a/rlinf/models/embodiment/model_utils.py +++ b/rlinf/models/embodiment/model_utils.py @@ -15,9 +15,10 @@ from typing import Any, Optional import torch -import torch.nn.functional as F from transformers.generation import TopKLogitsWarper +from rlinf.utils.utils import compute_entropy_from_logits, compute_logprobs_from_logits + def default_logits_processor(logits, action_tokens, vocab_size, n_action_bins): logits = logits.permute(0, 2, 1) # [B, vocab-size, action-dim] @@ -34,28 +35,6 @@ def default_logits_processor(logits, action_tokens, vocab_size, n_action_bins): return ret -def compute_logprobs_from_logits(logits, target): - logprobs = -F.cross_entropy( - logits, target=target, reduction="none" - ) # [B, action-dim] - return logprobs - - -def compute_entropy_from_logits(logits, epsilon=1e-10): - """ - Compute entropy by logits. 
- - Args: - logits: [B, vocab-size, seq-len] - Returns: - entropy: [B, seq-len] - """ - all_probs = F.softmax(logits, dim=1) # [B, vocab-size, seq-len] - all_log_probs = torch.log(all_probs + epsilon) - entropy = -torch.sum(all_probs * all_log_probs, dim=1) # [B, seq-len] - return entropy - - def custom_forward( model, input_ids, diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/math_runner.py index a88826f5a..be2cfac1f 100644 --- a/rlinf/runners/math_runner.py +++ b/rlinf/runners/math_runner.py @@ -424,7 +424,7 @@ def run(self): } self.metric_logger.log(training_metrics, logging_steps + i) - logging_metrics = time_metrics + logging_metrics = {f"{k}_time": v for k, v in time_metrics.items()} if self.cfg.actor.get("calculate_flops", False): flops_metrics = self._compute_flops_metrics( diff --git a/rlinf/utils/convertor/utils.py b/rlinf/utils/convertor/utils.py index 5d15e3639..e187bb919 100644 --- a/rlinf/utils/convertor/utils.py +++ b/rlinf/utils/convertor/utils.py @@ -447,7 +447,7 @@ def register_mg2hf_convertor(model_arch: str, convertor_cls: Callable) -> None: register_mg2hf_convertor("qwen2.5", Qwen2_5Convertor) -register_mg2hf_convertor("qwen2.5-vl", Qwen2_5VLConvertor) +register_mg2hf_convertor("qwen2.5_vl", Qwen2_5VLConvertor) def get_mg2hf_convertor(model_arch: str, config, strict: bool = False) -> BaseConvertor: diff --git a/rlinf/utils/distributed.py b/rlinf/utils/distributed.py index e9da8d6da..a54f444d8 100644 --- a/rlinf/utils/distributed.py +++ b/rlinf/utils/distributed.py @@ -31,7 +31,12 @@ def compute_rollout_metrics( - rollout_batch, max_prompt_len, response_len, use_critic=False + rollout_batch, + max_prompt_len, + response_len, + dp_world_size, + dp_group=None, + use_critic=False, ): device = torch.device(f"cuda:{torch.cuda.current_device()}") advantages = rollout_batch["advantages"].to(device=device) @@ -41,8 +46,6 @@ def compute_rollout_metrics( reward_scores = rollout_batch["rewards"].clone().to(device=device) is_end = rollout_batch["is_end"].clone().float().to(device=device) - dp_world_size = parallel_state.get_data_parallel_world_size() - prompt_lengths_list = [ torch.empty_like(prompt_lengths) for _ in range(dp_world_size) ] @@ -52,12 +55,12 @@ def compute_rollout_metrics( torch.distributed.all_gather( prompt_lengths_list, prompt_lengths, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_gather( decode_lengths_list, response_lengths, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) total_prompt_lengths = torch.cat(prompt_lengths_list, dim=0) @@ -66,22 +69,22 @@ def compute_rollout_metrics( torch.distributed.all_reduce( prompt_lengths, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( response_lengths, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( reward_scores, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( is_end, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) valid_adv = torch.masked_select(advantages, mask) @@ -90,12 +93,12 @@ def compute_rollout_metrics( torch.distributed.all_reduce( n_valid_token, op=torch.distributed.ReduceOp.SUM, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( adv_sum, op=torch.distributed.ReduceOp.SUM, - 
group=parallel_state.get_data_parallel_group(), + group=dp_group, ) adv_mean = adv_sum / n_valid_token @@ -107,7 +110,7 @@ def compute_rollout_metrics( torch.distributed.all_reduce( reduce_tensor, torch.distributed.ReduceOp.MAX, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) adv_min, adv_max = reduce_tensor.tolist() diff --git a/rlinf/utils/placement.py b/rlinf/utils/placement.py index 7b707894e..10738ce9e 100644 --- a/rlinf/utils/placement.py +++ b/rlinf/utils/placement.py @@ -226,9 +226,6 @@ def __init__(self, config: DictConfig, cluster: Cluster): self._rollout_num_gpus = len(self._rollout_gpus) if self._is_collocated(): - assert self.actor_tp_size >= self.rollout_tp_size, ( - f"Actor TP size {self.actor_tp_size} must be greater or equal to Rollout TP size {self.rollout_tp_size}." - ) assert self._inference_gpus is None, ( "Inference GPUs must not be specified in collocated mode." ) @@ -282,11 +279,14 @@ def _generate_placements(self): actor_tp_size = self._config.actor.model.tensor_model_parallel_size rollout_tp_size = self._config.rollout.tensor_parallel_size - assert actor_tp_size >= rollout_tp_size, ( - f"Actor TP size ({actor_tp_size}) must be greater or equal to Rollout TP size ({rollout_tp_size})" - ) - assert actor_tp_size % rollout_tp_size == 0, ( - f"Actor TP size ({actor_tp_size}) must be divisible by Rollout TP size ({rollout_tp_size})" + if actor_tp_size > rollout_tp_size: + assert actor_tp_size % rollout_tp_size == 0, ( + f"Actor TP size ({actor_tp_size}) must be divisible by Rollout TP size ({rollout_tp_size})" + ) + stride = ( + self.actor_tp_size // self.rollout_tp_size + if self.actor_tp_size > self.rollout_tp_size + else 1 ) stride = actor_tp_size // rollout_tp_size self._placements["rollout"] = PackedPlacementStrategy( @@ -325,18 +325,18 @@ def has_dedicated_inference(self): @property def actor_dp_size(self) -> int: return self._actor_num_gpus // ( - self._config.actor.model.tensor_model_parallel_size - * self._config.actor.model.context_parallel_size - * self._config.actor.model.pipeline_model_parallel_size + self._config.actor.model.get("tensor_model_parallel_size", 1) + * self._config.actor.model.get("context_parallel_size", 1) + * self._config.actor.model.get("pipeline_model_parallel_size", 1) ) @property def actor_tp_size(self) -> int: - return self._config.actor.model.tensor_model_parallel_size + return self._config.actor.model.get("tensor_model_parallel_size", 1) @property def actor_pp_size(self) -> int: - return self._config.actor.model.pipeline_model_parallel_size + return self._config.actor.model.get("pipeline_model_parallel_size", 1) @property def actor_world_size(self) -> int: @@ -349,7 +349,7 @@ def inference_tp_size(self) -> int: and hasattr(self._config.inference, "model") and hasattr(self._config.inference.model, "tensor_model_parallel_size") ): - return self._config.inference.model.tensor_model_parallel_size + return self._config.inference.model.get("tensor_model_parallel_size", 1) else: return self.actor_tp_size @@ -360,7 +360,7 @@ def inference_pp_size(self) -> int: and hasattr(self._config.inference, "model") and hasattr(self._config.inference.model, "pipeline_model_parallel_size") ): - return self._config.inference.model.pipeline_model_parallel_size + return self._config.inference.model.get("pipeline_model_parallel_size", 1) else: return self.actor_pp_size @@ -377,13 +377,13 @@ def inference_world_size(self) -> int: @property def rollout_dp_size(self) -> int: return self._rollout_num_gpus // ( - 
self._config.rollout.tensor_parallel_size - * self._config.rollout.pipeline_parallel_size + self._config.rollout.get("tensor_parallel_size", 1) + * self._config.rollout.get("pipeline_parallel_size", 1) ) @property def rollout_tp_size(self) -> int: - return self._config.rollout.tensor_parallel_size + return self._config.rollout.get("tensor_parallel_size", 1) @property def rollout_world_size(self) -> int: diff --git a/rlinf/utils/resharding/utils.py b/rlinf/utils/resharding/utils.py index 82ca3eadf..d7a4af231 100644 --- a/rlinf/utils/resharding/utils.py +++ b/rlinf/utils/resharding/utils.py @@ -14,7 +14,6 @@ import torch -from megatron.core import parallel_state def get_tp_reshard_fn(model_arch: str): @@ -96,6 +95,8 @@ def _gather_pp_group_tensor_and_reshard( def pp_reshard_fn_qwen2_5(model_state_dict, pp_group, dtype): + from megatron.core import parallel_state + pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank() pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank() diff --git a/rlinf/utils/utils.py b/rlinf/utils/utils.py index 449412f8b..d117f4dc4 100644 --- a/rlinf/utils/utils.py +++ b/rlinf/utils/utils.py @@ -20,6 +20,7 @@ from functools import partial, wraps import torch +import torch.nn.functional as F def clear_memory(sync=True): @@ -124,6 +125,57 @@ def seq_mean_token_mean(values, mask): return loss +def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): + from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + + output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) + assert isinstance(output, tuple), ( + "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + ) + return -output[0] + + +def compute_logprobs_from_logits(logits, target, task_type="embodied"): + if task_type == "embodied": + logprobs = -F.cross_entropy( + logits, target=target, reduction="none" + ) # [B, action-dim] + return logprobs + batch_dim = logits.shape[:-1] + last_dim = logits.shape[-1] + logits = logits.reshape(-1, last_dim) + labels = target.reshape(-1) + logprobs = logprobs_from_logits_flash_attn( + logits, labels=labels, inplace_backward=False + ) + logprobs = logprobs.view(*batch_dim) + return logprobs + + +def entropy_from_logits(logits: torch.Tensor): + """Calculate entropy from logits.""" + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + return entropy + + +def compute_entropy_from_logits(logits, epsilon=1e-10, task_type="embodied"): + """ + Compute entropy by logits. + + Args: + logits: [B, vocab-size, seq-len] + Returns: + entropy: [B, seq-len] + """ + if task_type == "embodied": + all_probs = F.softmax(logits, dim=1) # [B, vocab-size, seq-len] + all_log_probs = torch.log(all_probs + epsilon) + entropy = -torch.sum(all_probs * all_log_probs, dim=1) # [B, seq-len] + return entropy + return entropy_from_logits(logits=logits) + + class DualOutput: def __init__(self, file, terminal): self.file = file diff --git a/rlinf/workers/actor/__init__.py b/rlinf/workers/actor/__init__.py index 5b365ea1e..2d315469e 100644 --- a/rlinf/workers/actor/__init__.py +++ b/rlinf/workers/actor/__init__.py @@ -11,3 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
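Note on the helpers added to rlinf/utils/utils.py above: they now branch on the task type. The embodied path keeps the original `[B, vocab, action-dim]` cross-entropy, while any other task (e.g. math) flattens the logits to `[N, vocab]` and uses flash-attn's fused cross entropy (flash-attn >= 2.4.3 and a GPU are required). A usage sketch with illustrative shapes:

import torch

from rlinf.utils.utils import compute_logprobs_from_logits, entropy_from_logits

logits = torch.randn(2, 8, 32000, device="cuda")         # [batch, response_len, vocab]
tokens = torch.randint(0, 32000, (2, 8), device="cuda")  # sampled response token ids

logprobs = compute_logprobs_from_logits(logits, tokens, task_type="math")  # [2, 8]
entropy = entropy_from_logits(logits)                                      # [2, 8]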
+ +from omegaconf import DictConfig + +from rlinf.scheduler.worker.worker import Worker + + +def get_actor_worker(cfg: DictConfig) -> Worker: + if cfg.actor.training_backend == "fsdp": + from .fsdp_actor_worker import FSDPActor + + return FSDPActor + elif cfg.actor.training_backend == "megatron": + from .megatron_actor_worker import MegatronActor + + return MegatronActor + else: + raise ValueError(f"Unsupported training backend: {cfg.actor.training_backend}") diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 61b05c9b4..37dd1d9ba 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -14,31 +14,480 @@ import gc import os +from typing import Dict, List, Tuple import numpy as np import torch from omegaconf import DictConfig from torch.distributed.device_mesh import init_device_mesh +from torch.multiprocessing.reductions import reduce_tensor from tqdm import tqdm import rlinf.algorithms # noqa: F401 from rlinf.algorithms.registry import actor_loss, calculate_adv_and_returns -from rlinf.algorithms.utils import preprocess_advantages_inputs, preprocess_loss_inputs +from rlinf.algorithms.utils import ( + kl_penalty, + preprocess_advantages_inputs, + preprocess_loss_inputs, +) +from rlinf.data.io_struct import RolloutResult from rlinf.hybrid_engines.fsdp.fsdp_model_manager import ( FSDPModelManager, ) from rlinf.models import get_model from rlinf.models.embodiment.model_utils import custom_forward -from rlinf.scheduler import Cluster, Worker +from rlinf.scheduler import Channel, Cluster, Worker from rlinf.utils.data_iter_utils import get_iterator_k_split from rlinf.utils.distributed import all_reduce_dict +from rlinf.utils.distributed import ( + compute_rollout_metrics as compute_math_rollout_metrics, +) from rlinf.utils.metric_utils import ( append_to_dict, compute_loss_mask, compute_rollout_metrics, compute_split_num, ) -from rlinf.utils.placement import HybridComponentPlacement +from rlinf.utils.placement import ( + HybridComponentPlacement, + ModelParallelComponentPlacement, +) +from rlinf.utils.utils import ( + compute_entropy_from_logits, + compute_logprobs_from_logits, + masked_mean, + seq_mean_token_mean, + seq_mean_token_sum, +) +from rlinf.workers.rollout.utils import RankMapper +from toolkits.math_verifier.verify import math_verify_call + + +class FSDPActor(FSDPModelManager, Worker): + def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): + Worker.__init__(self) + super().__init__(cfg.actor) + + self.cfg = cfg + + self.response_len = ( + cfg.actor.model.encoder_seq_length - cfg.data.max_prompt_length + ) + self.calculate_entropy = self.cfg.algorithm.calculate_entropy + self.calculate_entropy_loss = ( + self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy + ) + self.kl_beta = self.cfg.algorithm.kl_beta + self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type + + self.total_batch_size_per_dp = ( + self.cfg.data.rollout_batch_size + * self.cfg.algorithm.group_size + // self._world_size + ) + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.cuda.current_device() + world_size = self._world_size + self.device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) + + self._rollout_group_name = cfg.rollout.group_name + self._component_placement = placement + self.is_data_io_rank = True + + if self.cfg.algorithm.loss_agg_func == "token-mean": + self.loss_agg_func = masked_mean + elif 
self.cfg.algorithm.loss_agg_func == "seq-mean-token-sum": + self.loss_agg_func = seq_mean_token_sum + elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-mean": + self.loss_agg_func = seq_mean_token_mean + else: + raise NotImplementedError( + f"algorithm.loss_agg_func={self.cfg.algorithm.loss_agg_func} is not supported!" + ) + + # Reward configurations + if not self.cfg.reward.use_reward_model: + assert self.cfg.reward.reward_type == "math", "only support math" + self.reward_fn = math_verify_call + + def init_worker(self): + self.setup_model_and_optimizer() + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + self.offload_fsdp_optimizer() + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + self._setup_rollout_weight_dst_ranks() + + def _setup_rollout_weight_dst_ranks(self): + """Setup destination ranks for token and weight communication.""" + rank_map = RankMapper.get_actor_rank_to_rollout_rank_map( + self._component_placement + ) + self._weight_dst_rank_in_rollout = rank_map[self._rank] + self.log_info( + f"Actor rank {self._rank} will send weights to {self._weight_dst_rank_in_rollout}" + ) + + def del_reshard_state_dict(self): + if hasattr(self, "rollou_state_dict"): + del self.rollou_state_dict + + def sync_model_to_rollout(self): + if next(self.model.parameters()).is_cpu: + self.load_fsdp_param_and_grad(self.device) + + self.rollou_state_dict = self.get_model_state_dict() + + if self._weight_dst_rank_in_rollout is not None: + + def transform_key(k): + if k.startswith("model.language_model."): + return "model." + k[21:] + elif k.startswith("model."): + return k[6:] + else: + return k + + handle = { + transform_key(k): reduce_tensor(v) + for k, v in self.rollou_state_dict.items() + } + + self.send( + handle, self._rollout_group_name, self._weight_dst_rank_in_rollout + ) + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + def compute_logprobs(self): + self.model.eval() + self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"] + + def get_batch( + self, channel: Channel + ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: + result: RolloutResult = channel.get() + + batch = result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + return batch, result + + def put_result(self, result: RolloutResult, channel: Channel): + if channel.is_local: + # Local channel, every process will put its own data locally + # No need to broadcast + channel.put(result) + else: + if self.is_data_io_rank: + channel.put(result) + + def _load_weight_and_optimizer(self, channel: Channel): + # Acquire the GPUs to ensure that no one is using them before loading models + # Otherwise, it may lead to OOM + with channel.gpu_lock: + if self.cfg.actor.get("enable_offload", False): + self.load_fsdp_param_and_grad(self.device) + self.load_fsdp_optimizer(self.device) + + def run_training(self, input_channel: Channel): + # Get all batches for this DP + batches = [] + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + batches.append(batch) + recv_batch_size += rollout_result.num_sequence + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + batch = RolloutResult.merge_batches(batches) + 
+ # Must be called after batch is retrieved, which is when rollout has stopped + # Otherwise, loading model might cause OOM + self._load_weight_and_optimizer(input_channel) + + global_batches = get_iterator_k_split( + batch, + num_splits=self.cfg.algorithm.n_minibatches, + shuffle=self.cfg.algorithm.get("shuffle_rollout", True), + shuffle_seed=self.cfg.actor.seed, + ) + + self.model.train() + assert ( + self.cfg.actor.global_batch_size + % (self.cfg.actor.micro_batch_size * self._world_size) + == 0 + ) + + self.gradient_accumulation = ( + self.cfg.actor.global_batch_size + // self.cfg.actor.micro_batch_size + // self._world_size + ) + + training_metrics_list = [] + # Global batch iterations + with self.worker_timer(): + for global_batch in global_batches: + train_global_batch_size = global_batch["input_ids"].shape[0] + assert ( + train_global_batch_size + == self.cfg.actor.global_batch_size + // torch.distributed.get_world_size() + ) + assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, ( + f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size}" + ) + + self.gradient_accumulation = ( + self.cfg.actor.global_batch_size + // self.cfg.actor.micro_batch_size + // self._world_size + ) + # split batch into micro_batches + train_micro_batches = get_iterator_k_split( + global_batch, + train_global_batch_size // self.cfg.actor.micro_batch_size, + ) + + self.optimizer.zero_grad() + metrics = {} + for _, m_batch in enumerate(train_micro_batches): + for k, v in m_batch.items(): + m_batch[k] = v.to(f"cuda:{int(os.environ['LOCAL_RANK'])}") + + multi_modal_inputs = {} + if "multi_modal_inputs" in m_batch.keys(): + if ( + "image_bound" in m_batch["multi_modal_inputs"][0] + ): # minicpm-o logic + for key in m_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = [ + inputs[key] + for inputs in m_batch["multi_modal_inputs"] + ] + else: + for key in m_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [ + inputs[key] + for inputs in m_batch["multi_modal_inputs"] + ], + dim=0, + ) + + input_ids = m_batch["input_ids"] + attention_mask = m_batch["attention_mask"] + position_ids = m_batch["position_ids"] + prev_logprobs = m_batch["prev_logprobs"] + advantages = m_batch["advantages"] + ref_logprobs = None + if "ref_logprobs" in m_batch: + ref_logprobs = m_batch["ref_logprobs"] + + loss_mask = m_batch["attention_mask"][:, -self.response_len :] + + output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating + + logits = output.logits + + logits.div_(self.cfg.algorithm.sampling_params.temperature) + + responses = input_ids[:, -self.response_len :] + logits = logits[ + :, -self.response_len - 1 : -1, : + ] # (bsz, response_length, vocab_size) + logprobs = compute_logprobs_from_logits( + logits, responses, task_type=self.cfg.runner.task_type + ) + if self.calculate_entropy: + entropy = compute_entropy_from_logits( + logits, task_type=self.cfg.runner.task_type + ) # (bsz, response_length) + + clip_ratio = self.cfg.algorithm.ratio_clip_eps + clip_ratio_low = ( + self.cfg.algorithm.clip_ratio_low + if self.cfg.algorithm.clip_ratio_low is not None + else clip_ratio + ) + clip_ratio_high = ( + self.cfg.algorithm.clip_ratio_high + if self.cfg.algorithm.clip_ratio_high is not None + else clip_ratio + ) + clip_ratio_c = self.cfg.algorithm.get("clip_ratio_c", 3.0) + + loss, mbs_metrics_data = actor_loss( + 
loss_type=self.cfg.algorithm.loss_type, + loss_agg_func=self.loss_agg_func, + logprobs=logprobs, + old_logprobs=prev_logprobs, + advantages=advantages, + clip_ratio_low=clip_ratio_low, + clip_ratio_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_mask=loss_mask, + ) + + entropy_loss = torch.tensor(0.0, device=torch.cuda.current_device()) + if self.calculate_entropy: + entropy = output["entropy"][ + :, -self.response_len - 1 : -1 + ].contiguous() + entropy_loss = self.loss_agg_func(entropy, mask=loss_mask) + if self.calculate_entropy_loss: + loss = ( + loss - self.cfg.algorithm.entropy_bonus * entropy_loss + ) + + kl_loss = torch.tensor(0.0, device=torch.cuda.current_device()) + if self.kl_beta > 0 and ref_logprobs is not None: + kld = kl_penalty(ref_logprobs, logprobs, self.kl_penalty_type) + kl_loss = self.loss_agg_func(kld, loss_mask) + loss = loss + kl_loss * self.kl_beta + + # add to log + mbs_metrics_data.update( + { + "final_loss": loss.detach().cpu(), + "entropy_loss": entropy_loss.detach().cpu(), + "kl_loss": kl_loss.detach().cpu(), + } + ) + + append_to_dict(metrics, mbs_metrics_data) + + mean_metric_dict = { + key: np.mean(value) for key, value in metrics.items() + } + mean_metric_dict = all_reduce_dict( + mean_metric_dict, op=torch.distributed.ReduceOp.AVG + ) + training_metrics_list.append(mean_metric_dict) + + # Rollout metrics + rollout_metrics, _, _ = compute_math_rollout_metrics( + batch, self.cfg.data.max_prompt_length, self.response_len, self._world_size + ) + + return rollout_metrics, training_metrics_list + + def save_checkpoint(self, save_base_path, step): + torch.distributed.barrier() + model_state = self.get_model_state_dict() + optim_state = self.get_optimizer_state_dict() + if self._rank == 0: + os.makedirs(save_base_path, exist_ok=True) + torch.save(model_state, os.path.join(save_base_path, "model.pt")) + torch.save(optim_state, os.path.join(save_base_path, "optim.pt")) + torch.distributed.barrier() + + def _compute_batch_rewards( + self, batch: Dict[str, torch.Tensor], answers: List[str] + ): + """Reward computation using non-model based reward.""" + texts = [] + for response, response_len in zip( + batch["input_ids"], + batch["response_lengths"], + ): + response = response[ + self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length + + response_len + ] + texts.append( + self.tokenizer.decode(response.tolist(), skip_special_tokens=True) + ) + rewards = self.reward_fn(texts, answers) + reward_scores = [ + self.cfg.reward.reward_scale + if reward == 1 + else -self.cfg.reward.reward_scale + for reward in rewards + ] + all_reward_scores = torch.as_tensor( + reward_scores, + dtype=torch.float, + device=torch.device("cpu"), + ).view(-1, 1) + return all_reward_scores.flatten() + + # Rewards + def compute_rewards(self, input_channel: Channel, output_channel: Channel): + """Compute rewards. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. 
+ """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + + # Compute rule-based reward + with self.worker_timer(): + if rollout_result.rewards is None: + rollout_result.rewards = self._compute_batch_rewards( + batch, rollout_result.answers + ) + + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + # Advantages and returns + def compute_advantages_and_returns( + self, input_channel: Channel, output_channel: Channel + ): + """Compute the advantages and returns. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + + with self.worker_timer(): + if rollout_result.advantages is None: + mask = batch["attention_mask"][:, -self.response_len :] + advantages, returns = calculate_adv_and_returns( + adv_type=self.cfg.algorithm.adv_type, + reward_scores=batch["rewards"].cuda(), + mask=mask.cuda(), + num_responses=self.cfg.algorithm.group_size, + ) + rollout_result.advantages = advantages.cpu() + + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) class EmbodiedFSDPActor(FSDPModelManager, Worker): diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 6a921c0bd..41f026879 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -207,8 +207,9 @@ def rollout(self, input_channel: Channel, output_channel: Channel): self._stop() # Release the GPUs once the engine has offloaded output_channel.device_lock.release() - rollout_result = RolloutResult.merge_result_list(rollout_results) - output_channel.put(rollout_result) + rollout_result_list = RolloutResult.split_result_list_by_group(rollout_results) + for rollout_result in rollout_result_list: + output_channel.put(rollout_result) def all_floats_equal(float_list: list[float], epsilon: float = 1e-9) -> bool: diff --git a/rlinf/workers/rollout/utils.py b/rlinf/workers/rollout/utils.py index f3845ef2f..f92e48caa 100644 --- a/rlinf/workers/rollout/utils.py +++ b/rlinf/workers/rollout/utils.py @@ -376,6 +376,12 @@ def get_actor_rank_to_rollout_rank_map( """ Get the global mapping from actor 1D rank to rollout 2D rank as dict. 
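+
+    When actor_tp_size == 1 the mapping is purely modular: actor rank r is mapped
+    to rollout (dp, tp) = (r // rollout_tp_size, r % rollout_tp_size), so
+    consecutive actor ranks fill one rollout DP replica before moving to the next.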
""" + # rank -> (dp, tp) + if actor_tp_size == 1: + return { + rank: (rank // rollout_tp_size, rank % rollout_tp_size) + for rank in range(actor_world_size) + } rank_map = {} for actor_rank in range(actor_world_size): rank_map[actor_rank] = cls._get_actor_rank_to_rollout_rank( From 7c382c8114912217117d5335e3e1a26bfe2194b2 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Fri, 19 Sep 2025 08:50:10 +0000 Subject: [PATCH 04/38] feat(dataset): refactor and add lazy loader process Signed-off-by: Bo Dai --- rlinf/data/datasets.py | 562 +++++++++++++++++++++++++++++++---------- 1 file changed, 424 insertions(+), 138 deletions(-) diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index 162d7b94d..2be41ce3b 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -16,14 +16,14 @@ import logging import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import pandas as pd import torch from omegaconf import DictConfig from PIL.Image import Image from torch.utils.data import Dataset -from transformers import AutoProcessor +from transformers import AutoProcessor, AutoTokenizer def batch_pad_to_fixed_len( @@ -176,156 +176,424 @@ def __getitem__(self, idx): return output -class VisionLanguageDataset(Dataset): +class VLMBaseDataset(Dataset): + def __init__( - self, data_paths: Union[List[str], str], config: DictConfig, tokenizer - ): + self, + data_paths: Union[List[str], str], + config: DictConfig, + tokenizer: AutoTokenizer, + *, + lazy_loading: bool = False, + ) -> None: super().__init__() - self.data_paths = data_paths - self.use_chat_template = config.data.use_chat_template + self.cfg = config + raw_paths = [data_paths] if isinstance(data_paths, str) else list(data_paths) + # Expand directories into file lists recursively (json/jsonl/parquet) + self.data_paths = self._expand_data_paths(raw_paths) + self.tokenizer = tokenizer + # Delay processor creation; only needed when use_chat_template is True + self._processor = None - self.image_keys = config.data.image_keys + self.use_chat_template = bool(config.data.use_chat_template) + self.image_keys = list(config.data.image_keys or []) self.prompt_key = config.data.prompt_key - self.choice_key = config.data.choice_key - self.answer_key = config.data.answer_key - self.solution_key = config.data.solution_key + self.choice_key = config.data.get("choice_key", None) + self.answer_key = config.data.get("answer_key", None) + self.solution_key = config.data.get("solution_key", None) + self.max_prompt_length = int(config.data.max_prompt_length) + self.eos_id = int(self.tokenizer.eos_token_id) - if isinstance(self.data_paths, str): - self.data_paths = [self.data_paths] + # Loading mode + self.lazy_loading = bool(getattr(config.data, "lazy_loading", lazy_loading)) - self.max_prompt_length = config.data.max_prompt_length - self.tokenizer = tokenizer - self.processor = AutoProcessor.from_pretrained(config.actor.model.model_path) - self.data = self._load_data() - self.post_process() - - def post_process(self) -> None: - def get_image_list( - dataitem: Dict, image_keys: Optional[List[str]] - ) -> List[Union[bytes, str]]: - image_list: List[Union[bytes, str]] = [] - if image_keys: - for key in image_keys: - image_content = dataitem.get(key, None) - if image_content is None: - continue - if isinstance(image_content, Image): - image_content.append(image_content) - if isinstance(image_content, dict) and "bytes" in image_content: - image_content = 
image_content["bytes"] - assert isinstance(image_content, bytes), ( - f"image content should be bytes, but got {type(image_content)} , content is {image_content}" - ) - image_list.append(image_content) - if image_list == []: - return [None] - return image_list - - def process_prompt( - data_item: Dict, image_count: int - ) -> Tuple[ - str, - List[int], - int, - ]: - question = data_item.get(self.prompt_key, "") - options = data_item.get(self.choice_key, []) - if not isinstance(options, list): - options = [options] - prompt_text = question - if options: - prompt_text += f"{options}\n" - if self.use_chat_template: - message_content: List = [] - for i in range(image_count): - message_content.append({"type": "image"}) - message_content.append({"type": "text", "text": prompt_text}) - messages = [{"role": "user", "content": message_content}] - prompt_text = self.processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - prompt_ids = self.processor( - text=[prompt_text], - padding=True, - return_tensors="pt", - )["input_ids"] - if isinstance(prompt_ids, torch.Tensor): - if prompt_ids.dim() == 2 and prompt_ids.size(0) == 1: - prompt_ids = prompt_ids.squeeze(0) # [L] - prompt_ids = prompt_ids.to(dtype=torch.long) - else: - prompt_ids = torch.tensor(prompt_ids, dtype=torch.long) - prompt_length = len(prompt_ids) - - return prompt_text, prompt_ids, prompt_length - else: - raise NotImplementedError("Non-chat template not implemented yet.") + self._records = [] + self._indices = [] # (path, fmt, row_index_or_offset) - processed_data: List[DatasetItem] = [] - for idx, item in enumerate(self.data): - image_list: List[Union[bytes, str]] = get_image_list(item, self.image_keys) - prompt_text, prompt_ids, prompt_length = process_prompt( - item, len(image_list) - ) + if self.lazy_loading: + self._build_lazy_indices() + else: + self._eager_load_all() - if prompt_length > self.max_prompt_length: - print( - f"prompt_ids length {prompt_length} exceeds the max_prompt_length {self.max_prompt_length}", + def __len__(self) -> int: + return len(self._indices) if self.lazy_loading else len(self._records) + + def __getitem__(self, idx: int) -> DatasetItem: + if self.lazy_loading: + path, fmt, key = self._indices[idx] + raw = self._load_single_lazy(path, fmt, key) + return self._process_raw_record(raw, idx) + else: + raw = self._records[idx] + return self._process_raw_record(raw, idx) + + # Ensure dataset is picklable for multi-process DataLoader by removing + # unpicklable cache objects like pyarrow.ParquetFile from state. 
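+    # Only the parquet handle/dataframe caches are cleared; the row index built at
+    # init is kept, so a pickled copy in a DataLoader worker simply re-opens files
+    # on first access through _load_single_lazy instead of shipping open handles
+    # across processes.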
+ def __getstate__(self): + state = self.__dict__.copy() + # Drop heavy/unpicklable caches; they will be rebuilt on-demand in workers + for k in ("_parquet_cache", "_parquet_df_cache"): + if k in state: + state[k] = {} + return state + + def __setstate__(self, state): + # Restore state and ensure cache dicts exist + self.__dict__.update(state) + self._parquet_cache = getattr(self, "_parquet_cache", {}) + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + + def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]: + images: List[Union[bytes, str, None]] = [] + for k in self.image_keys: + v = dataitem.get(k, None) + if v is None: + continue + if isinstance(v, Image): + images.append(v) + elif isinstance(v, dict) and "bytes" in v: + images.append(v["bytes"]) + else: + images.append(v) # path or url + if not images: + images = [None] + return images + + def build_prompt_text(self, data_item: Dict[str, Any]) -> str: + # Default: prompt + optional choices rendered inline + q = data_item.get(self.prompt_key, "") + choices = data_item.get(self.choice_key, []) if self.choice_key else [] + if not isinstance(choices, list): + choices = [choices] + if choices: + return f"{q}{choices}\n" + return str(q) + + def encode_prompt( + self, prompt_text: str, image_count: int + ) -> Tuple[torch.Tensor, int, Optional[str]]: + """ + Return (token_ids[L], length, prompt_text_used). If using chat template, encode with processor. + Subclasses may override to support alternative prompting. + """ + if self.use_chat_template: + if self._processor is None: + self._processor = AutoProcessor.from_pretrained( + self.cfg.actor.model.model_path ) - prompt_ids = prompt_ids[: self.max_prompt_length] - prompt_length = self.max_prompt_length - prompt_ids = batch_pad_to_fixed_len( - [prompt_ids], - self.max_prompt_length, - self.tokenizer.eos_token_id, - left_pad=True, - )[0] - answer = item.get(self.answer_key, None) - solution = item.get(self.solution_key, None) - - data_item = DatasetItem( - prompt_text=prompt_text, - prompt=prompt_ids, - length=prompt_length, - image_data=image_list, - answer=str(answer), - solution=solution, - idx=idx, + content: List[Dict[str, Any]] = [] + for _ in range(max(0, image_count)): + content.append({"type": "image"}) + content.append({"type": "text", "text": prompt_text}) + messages = [{"role": "user", "content": content}] + rendered = self._processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True ) - processed_data.append(data_item) - self.data = processed_data + ids = self._processor(text=[rendered], padding=True, return_tensors="pt")[ + "input_ids" + ] + if isinstance(ids, torch.Tensor): + if ids.dim() == 2 and ids.size(0) == 1: + ids = ids.squeeze(0) + ids = ids.to(dtype=torch.long) + else: + ids = torch.tensor(ids, dtype=torch.long) + return ids, int(ids.numel()), rendered + else: + # fallback: tokenizer only + ids_list = self.tokenizer.encode(prompt_text) + ids = torch.as_tensor(ids_list, dtype=torch.long) + return ids, int(ids.numel()), prompt_text + + def postprocess_dataset_item( + self, item: DatasetItem, raw: Dict[str, Any] + ) -> DatasetItem: + return item + + def _expand_data_paths(self, inputs: List[str]) -> List[str]: + exts = {".jsonl", ".json", ".parquet"} + files: List[str] = [] + for p in inputs: + if os.path.isdir(p): + for root, _, fnames in os.walk(p): + for fn in fnames: + ext = os.path.splitext(fn)[1].lower() + if ext in exts: + files.append(os.path.join(root, fn)) + else: + files.append(p) + files = 
sorted(set(files)) + return files - def _load_data(self) -> List: - merged_data = [] + def _eager_load_all(self) -> None: + merged: List[Dict[str, Any]] = [] for path in self.data_paths: - _, file_extension = os.path.splitext(path) + fmt = os.path.splitext(path)[1].lower() + if fmt == ".jsonl": + with open(path, "r", encoding="utf-8") as f: + merged.extend(json.loads(l) for l in f) + elif fmt == ".json": + with open(path, "r", encoding="utf-8") as f: + content = json.load(f) + if isinstance(content, list): + merged.extend(content) + else: + merged.append(content) + elif fmt == ".parquet": + try: + merged.extend(pd.read_parquet(path).to_dict(orient="records")) + except Exception as e: + raise RuntimeError(f"Failed to load parquet eagerly: {path}: {e}") + else: + logging.warning(f"Unsupported format {fmt} for path {path}, skipping.") + self._records = merged + # Build indices for consistency + self._indices = [("", "eager", i) for i in range(len(self._records))] + + def _build_lazy_indices(self) -> None: + self._indices.clear() + for path in self.data_paths: + fmt = os.path.splitext(path)[1].lower() + if fmt == ".jsonl": + # index by byte offsets for each line + offsets: List[int] = [] + with open(path, "rb") as fb: + pos = 0 + for line in fb: + offsets.append(pos) + pos += len(line) + self._indices.extend((path, "jsonl", off) for off in offsets) + elif fmt == ".json": + try: + with open(path, "r", encoding="utf-8") as f: + content = json.load(f) + if not isinstance(content, list): + content = [content] + # store the content to avoid re-reading + # keep perfile cache + self._json_cache = getattr(self, "_json_cache", {}) + self._json_cache[path] = content + self._indices.extend((path, "json", i) for i in range(len(content))) + except Exception as e: + raise RuntimeError(f"Failed to index json lazily: {path}: {e}") + elif fmt == ".parquet": + try: + import pyarrow.parquet as pq # type: ignore + + pf = pq.ParquetFile(path) + num_rows = pf.metadata.num_rows + # file handle cache + self._parquet_cache = getattr(self, "_parquet_cache", {}) + self._parquet_cache[path] = pf + self._indices.extend((path, "parquet", i) for i in range(num_rows)) + except Exception: + df = pd.read_parquet(path) + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + self._parquet_df_cache[path] = df + self._indices.extend( + (path, "parquet_pd", i) for i in range(len(df)) + ) + else: + logging.warning(f"Unsupported format {fmt} for path {path}, skipping.") + + def _load_single_lazy(self, path: str, fmt: str, key: Any) -> Dict[str, Any]: + if fmt == "eager": + return self._records[int(key)] + if fmt == "jsonl": + with open(path, "rb") as fb: + fb.seek(int(key)) + line = fb.readline() + return json.loads(line.decode("utf-8").strip()) + if fmt == "json": + return self._json_cache[path][int(key)] # type: ignore[attr-defined] + if fmt == "parquet": + # Try to use pyarrow lazily; rebuild cache if missing + self._parquet_cache = getattr(self, "_parquet_cache", {}) + pf = self._parquet_cache.get(path) + if pf is None: + try: + import pyarrow.parquet as pq # type: ignore + + pf = pq.ParquetFile(path) + self._parquet_cache[path] = pf + except Exception: + # Fall back to pandas-based cache + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + df = self._parquet_df_cache.get(path) + if df is None: + df = pd.read_parquet(path) + self._parquet_df_cache[path] = df + return df.iloc[int(key)].to_dict() + table = pf.read_row_group(key // max(1, pf.metadata.num_rows), columns=None) try: - pass - if file_extension 
== ".parquet": - loaded_data: List = pd.read_parquet(path).to_dict(orient="records") - merged_data.extend(loaded_data) - elif file_extension == ".jsonl": - with open(path, "r", encoding="utf-8") as file: - loaded_data = [json.loads(line.strip()) for line in file] - merged_data.extend(loaded_data) - elif file_extension == ".json": - with open(path, "r", encoding="utf-8") as file: - content = json.load(file) - if isinstance(content, list): - merged_data.extend(content) - else: - merged_data.append(content) - else: - print(f"Unsupport {file_extension}, skip: {path}") - except Exception as e: - raise RuntimeError(f"Load data error: {e}") - return merged_data + df = table.to_pandas() + return df.iloc[int(key) % len(df)].to_dict() + except Exception: + df_all = pf.read().to_pandas() + return df_all.iloc[int(key)].to_dict() + if fmt == "parquet_pd": + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + df = self._parquet_df_cache.get(path) + if df is None: + df = pd.read_parquet(path) + self._parquet_df_cache[path] = df + return df.iloc[int(key)].to_dict() + raise RuntimeError(f"Unknown lazy fmt {fmt}") + + def _process_raw_record(self, raw: Dict[str, Any], idx: int) -> DatasetItem: + images = self.get_image_list(raw) + prompt_text = self.build_prompt_text(raw) + prompt_ids, plen, rendered_text = self.encode_prompt(prompt_text, len(images)) + + if plen > self.max_prompt_length: + prompt_ids = prompt_ids[: self.max_prompt_length] + plen = self.max_prompt_length + prompt_ids = batch_pad_to_fixed_len( + [prompt_ids], self.max_prompt_length, self.eos_id, left_pad=True + )[0] - def __len__(self) -> int: - return len(self.data) + answer_val = raw.get(self.answer_key, None) if self.answer_key else None + solution_val = raw.get(self.solution_key, None) if self.solution_key else None + item = DatasetItem( + prompt=prompt_ids, + length=plen, + answer=str(answer_val) if answer_val is not None else None, + idx=idx, + image_data=images, + prompt_text=rendered_text or prompt_text, + solution=solution_val, + meta=None, + ) + return self.postprocess_dataset_item(item, raw) + + +class VLMDatasetRegistry: + registry: Dict[str, Callable[..., VLMBaseDataset]] = {} + + @classmethod + def register( + cls, name: str + ) -> Callable[[Callable[..., VLMBaseDataset]], Callable[..., VLMBaseDataset]]: + def decorator(klass: Callable[..., VLMBaseDataset]): + cls.registry[name] = klass + return klass + + return decorator + + @classmethod + def create( + cls, + dataset_name: Optional[str], + *, + data_paths: Union[List[str], str], + config: DictConfig, + tokenizer: AutoTokenizer, + ) -> VLMBaseDataset: + key = dataset_name.lower() + klass = cls.registry.get(key) + return klass(data_paths=data_paths, config=config, tokenizer=tokenizer) + + +@VLMDatasetRegistry.register("robo2vlm") +class Robo2VLMDataset(VLMBaseDataset): + def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]: + # Prefer common robo2vlm fields if present, else fallback to configured keys + images: List[Any] = [] + if "images" in dataitem: + v = dataitem.get("images") + if isinstance(v, list): + images = list(v) + elif v is not None: + images = [v] + else: + images = [None] + elif "image" in dataitem: + v = dataitem.get("image") + if v is not None: + images = [v] + else: + images = [None] + else: + # fallback to base behavior using configured image_keys + return super().get_image_list(dataitem) + + # Normalize each element similar to base behavior + normed: List[Union[bytes, str, None]] = [] + for v in images: + if v is 
None: + continue + if isinstance(v, Image): + normed.append(v) + elif isinstance(v, dict) and "bytes" in v: + normed.append(v["bytes"]) # raw bytes + else: + normed.append(v) # path/uri/string + if not normed: + normed = [None] + return normed + + def build_prompt_text(self, data_item: Dict[str, Any]) -> str: + # Use 'question' and 'choices' if present; else fallback to base using configured prompt/choice keys + question = data_item.get("question", None) + choices = data_item.get("choices", None) + if question is None: + return super().build_prompt_text(data_item) + # normalize choices + if isinstance(choices, str): + try: + import ast - def __getitem__(self, index): - return self.data[index] + choices = ast.literal_eval(choices) + except Exception: + choices = [choices] + if not isinstance(choices, list): + choices = [choices] if choices is not None else [] + + text = f"{question}\n" + if choices: + text += "Choices:\n" + for i, c in enumerate(choices): + text += f"{chr(65 + i)}. {c}\n" + return text + + def postprocess_dataset_item( + self, item: DatasetItem, raw: Dict[str, Any] + ) -> DatasetItem: + # Derive answer from 'correct_answer' and 'choices' if not provided + if not item.answer or str(item.answer).lower() in {"none", "", "null"}: + choices = raw.get("choices") + ca = raw.get("correct_answer") + try: + # Normalize choices + if isinstance(choices, str): + import ast + + choices = ast.literal_eval(choices) + if not isinstance(choices, list): + choices = [choices] if choices is not None else [] + + ans_val: Optional[str] = None + if isinstance(ca, int) and 0 <= ca < len(choices): + ans_val = str(choices[ca]) + elif isinstance(ca, str): + cstr = ca.strip() + # Letter index like 'A', 'B', ... + if len(cstr) == 1 and "A" <= cstr <= "Z": + idx = ord(cstr) - ord("A") + if 0 <= idx < len(choices): + ans_val = str(choices[idx]) + # Direct match to a choice value + if ans_val is None and choices: + for ch in choices: + if str(ch) == cstr: + ans_val = cstr + break + if ans_val is not None: + item.answer = ans_val + except Exception: + # Keep original if any + pass + return item def create_rl_dataset(config: DictConfig, tokenizer): @@ -344,7 +612,25 @@ def create_rl_dataset(config: DictConfig, tokenizer): if config.data.type == "math": dataset_cls = MathDataset elif config.data.type == "vision_language": - dataset_cls = VisionLanguageDataset + # Prefer new factory-based VLM datasets; fallback to legacy if requested + dataset_name = getattr(config.data, "dataset_name", None) + lazy_loading = bool(getattr(config.data, "lazy_loading", False)) + + print(f"Using VLM dataset: name={dataset_name}, lazy_loading={lazy_loading}") + + train_dataset = VLMDatasetRegistry.create( + dataset_name, + data_paths=config.data.train_data_paths, + config=config, + tokenizer=tokenizer, + ) + val_dataset = VLMDatasetRegistry.create( + dataset_name, + data_paths=config.data.val_data_paths, + config=config, + tokenizer=tokenizer, + ) + return train_dataset, val_dataset else: return None, None From fa0fc753d30a2c7a386b6a42df23063f7ba4608a Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Fri, 19 Sep 2025 11:13:33 +0000 Subject: [PATCH 05/38] fix(vllm): fix wrong image_data param when running vlm in vllm Signed-off-by: Bo Dai --- .../math/config/qwen2.5-1.5b-grpo-fsdp.yaml | 28 +++++++++++-------- rlinf/data/datasets.py | 1 - .../hybrid_engines/vllm/vllm_0_8_5/worker.py | 5 +++- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml 
b/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml index d8a1e8c3f..55285baa8 100644 --- a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml @@ -8,7 +8,7 @@ hydra: cluster: num_nodes: 1 - num_gpus_per_node: 8 + num_gpus_per_node: 4 component_placement: actor,rollout: all @@ -33,8 +33,7 @@ runner: resume_dir: null experiment_name: grpo-1.5b - output_dir: ../results - + output_dir: /mnt/public/daibo/results algorithm: group_size: 8 @@ -85,7 +84,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B + model_dir: /mnt/public/hf_models/qwen2.5-VL-3B/ model_arch: qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -94,7 +93,7 @@ rollout: padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. - rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] sglang: attention_backend: triton # [flashinfer, triton] for more, see sglang's doc @@ -121,18 +120,25 @@ rollout: cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. data: - type: math + type: vision_language + dataset_name: robo2vlm max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 8 val_rollout_batch_size: null num_workers: 2 - prompt_key: prompt shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/mnt/public/guozhen/data/boba_106k_0319_prompt_1024.jsonl"] - val_data_paths: ["/mnt/public/guozhen/data/boba_106k_0319_prompt_1024.jsonl"] + train_data_paths: ["/mnt/public/daibo/dataset/robo2vlm-1/data/"] + val_data_paths: ["/mnt/public/daibo/dataset/robo2vlm-1/data/"] + prompt_key: question + image_keys: [image] + answer_key: answer + choice_key: choices + solution_key: null + use_chat_template: True + lazy_loading: True actor: group_name: "ActorGroup" @@ -159,7 +165,7 @@ actor: seq_length: ${runner.seq_length} encoder_seq_length: ${runner.seq_length} - model_path: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B + model_path: /mnt/public/hf_models/qwen2.5-VL-3B/ optim: optimizer: adam @@ -189,7 +195,7 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B + tokenizer_model: /mnt/public/hf_models/qwen2.5-VL-3B/ use_fast: False trust_remote_code: True padding_side: 'right' diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index 2be41ce3b..107e89d55 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -177,7 +177,6 @@ def __getitem__(self, idx): class VLMBaseDataset(Dataset): - def __init__( self, data_paths: Union[List[str], str], diff --git a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py index 3e021e922..519895e49 100644 --- a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +++ b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py @@ -48,6 +48,9 @@ def __init__( ) # rlinf specific self.rlinf_config = rlinf_config + self.using_sharded_weight = ( + False if self.rlinf_config.actor.training_backend == "fsdp" else True + ) self._rlinf_worker = _RLinfWorker( parent_address=parent_address, world_size=vllm_config.parallel_config.world_size, @@ -103,7 +106,7 @@ def 
sync_hf_weight(self) -> None: def use_sharded_weights(self) -> None: model = self.model_runner.model for _, param in model.named_parameters(): - setattr(param, "is_sharded_weight", True) + setattr(param, "is_sharded_weight", self.using_sharded_weight) def get_dp_rank(self) -> int: return self._rlinf_worker.get_parent_rank() From 6a7d4bc3b8d7cf69ed3732ca24c56c069af26899 Mon Sep 17 00:00:00 2001 From: guozhen1997 <2997871698@qq.com> Date: Mon, 22 Sep 2025 21:22:50 +0800 Subject: [PATCH 06/38] feat: add vqa reward function, unify math and vqa reward Signed-off-by: guozhen1997 <2997871698@qq.com> --- rlinf/algorithms/rewards/__init__.py | 15 +++++ rlinf/algorithms/rewards/math/__init__.py | 24 +++++++ rlinf/algorithms/rewards/vqa/__init__.py | 42 +++++++++++++ .../algorithms/rewards/vqa/format_rewards.py | 53 ++++++++++++++++ rlinf/algorithms/rewards/vqa/qa_rewards.py | 63 +++++++++++++++++++ rlinf/data/datasets.py | 4 +- .../hybrid_engines/fsdp/fsdp_model_manager.py | 1 - rlinf/workers/actor/fsdp_actor_worker.py | 16 ++--- 8 files changed, 205 insertions(+), 13 deletions(-) create mode 100644 rlinf/algorithms/rewards/__init__.py create mode 100644 rlinf/algorithms/rewards/math/__init__.py create mode 100644 rlinf/algorithms/rewards/vqa/__init__.py create mode 100644 rlinf/algorithms/rewards/vqa/format_rewards.py create mode 100644 rlinf/algorithms/rewards/vqa/qa_rewards.py diff --git a/rlinf/algorithms/rewards/__init__.py b/rlinf/algorithms/rewards/__init__.py new file mode 100644 index 000000000..3a48b84b6 --- /dev/null +++ b/rlinf/algorithms/rewards/__init__.py @@ -0,0 +1,15 @@ +from .math import MathReward +from .vqa import VQAReward + +def register_reward(name: str, reward_class: type): + assert name not in reward_registry, f"Reward {name} already registered" + reward_registry[name] = reward_class + +def get_reward_class(name: str): + assert name in reward_registry, f"Reward {name} not found" + return reward_registry[name] + +reward_registry = {} + +register_reward("math", MathReward) +register_reward("vqa", VQAReward) \ No newline at end of file diff --git a/rlinf/algorithms/rewards/math/__init__.py b/rlinf/algorithms/rewards/math/__init__.py new file mode 100644 index 000000000..a94ff2dc4 --- /dev/null +++ b/rlinf/algorithms/rewards/math/__init__.py @@ -0,0 +1,24 @@ +from typing import List +from omegaconf import DictConfig +from toolkits.math_verifier.verify import math_verify_call + + +class MathReward: + def __init__(self, config: DictConfig): + self.scale = config.get("scale", 1.0) + + def get_reward( + self, response: List[str], reference: List[List[str]] + ) -> List[float]: + """ + Calculates reward scores for a list of responses compared to corresponding lists of reference answers. + For each response, the function checks if it matches any of the provided references using the `process_results` function. + The reward for each response is computed as the first element of the result (converted to float) multiplied by `self.scale`. + Args: + response (List[str]): A list of response strings to be evaluated. + reference (List[List[str]]): A list where each element is a list of reference strings corresponding to each response. + Returns: + List[float]: A list of reward scores, one for each response. 
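+
+        Example (illustrative only; the actual scores are whatever the verifier
+        in toolkits.math_verifier.verify returns, scaled by `scale`):
+
+            reward = MathReward(cfg.reward)   # cfg.reward carries the "scale" field
+            scores = reward.get_reward(["4"], [["4"]])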
+ """ + + return math_verify_call(response, reference) * self.scale \ No newline at end of file diff --git a/rlinf/algorithms/rewards/vqa/__init__.py b/rlinf/algorithms/rewards/vqa/__init__.py new file mode 100644 index 000000000..d4dd55f20 --- /dev/null +++ b/rlinf/algorithms/rewards/vqa/__init__.py @@ -0,0 +1,42 @@ +import torch +from typing import List +from omegaconf import DictConfig +from .qa_rewards import qa_accuracy_reward +from .format_rewards import think_format_reward, answer_format_reward + + +class VQAReward: + def __init__(self, config: DictConfig): + self.reward_weights = config.get("reward_weights", { + "qa_accuracy": 1.0, + "think_format": 0.0, + "answer_format": 0.0, + }) + for reward_name, reward_weight in self.reward_weights.items(): + assert reward_name in ["qa_accuracy", "think_format", "answer_format"], f"Reward {reward_name} not supported" + assert reward_weight >= 0, f"Reward weight {reward_weight} must be non-negative" + self.reward_weights = [reward_weight["qa_accuracy"], reward_weight["think_format"], reward_weight["answer_format"]] + + self.reward_functions = [qa_accuracy_reward, think_format_reward, answer_format_reward] + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def get_reward(self, completions: List[str], answers: List[str]) -> List[float]: + rewards = [] + for i, reward_function in enumerate(self.reward_functions): + if self.reward_weights[i] > 0: + rewards.append(reward_function(completions, answers)) + else: + rewards.append([0.0] * len(completions)) + + # Apply weights to each reward function's output and sum + + # rewards [num_reward_functions, len(completions)] + rewards_tensor = torch.tensor(rewards, device=self.device) + weights_tensor = torch.tensor(self.reward_weights, device=self.device) + + # [num_reward_functions, num_completions] * [num_reward_functions, 1] -> [num_completions] + final_rewards = (rewards_tensor * weights_tensor.unsqueeze(1)).sum(dim=0) + + return final_rewards.tolist() + \ No newline at end of file diff --git a/rlinf/algorithms/rewards/vqa/format_rewards.py b/rlinf/algorithms/rewards/vqa/format_rewards.py new file mode 100644 index 000000000..2f926e733 --- /dev/null +++ b/rlinf/algorithms/rewards/vqa/format_rewards.py @@ -0,0 +1,53 @@ +import re +from typing import List + + +def think_format_reward(completions, answers) -> List[float]: + """ + Think format reward function compatible with GRPO training. + + Reward function that checks if reasoning is enclosed within tags. + + Args: + completions: List of model completions (text strings) + + Returns: + List of reward scores (1.0 for correct format, 0.0 otherwise) + """ + pattern = r"^(?!.*)(.*?).*$" + rewards = [] + + for completion in completions: + completion_text = str(completion).strip() + match = re.match(pattern, completion_text, re.DOTALL | re.MULTILINE) + rewards.append(1.0 if match else 0.0) + + return rewards + + +def answer_format_reward(completions, answers) -> List[float]: + """ + Reward function that checks for proper answer formatting. + + Expected format: X. content where X is a choice letter. + + Args: + completions: List of model completions (text strings) + + Returns: + List of reward scores (1.0 for correct format, 0.0 otherwise) + """ + rewards = [] + + for completion in completions: + completion_text = str(completion).strip() + + # Check for proper answer format: X. 
content + answer_pattern = r'\s*[A-E]\.\s*.+?\s*' + has_proper_answer = bool(re.search( + answer_pattern, completion_text, re.DOTALL | re.IGNORECASE + )) + + rewards.append(1.0 if has_proper_answer else 0.0) + + return rewards \ No newline at end of file diff --git a/rlinf/algorithms/rewards/vqa/qa_rewards.py b/rlinf/algorithms/rewards/vqa/qa_rewards.py new file mode 100644 index 000000000..ce7a3443b --- /dev/null +++ b/rlinf/algorithms/rewards/vqa/qa_rewards.py @@ -0,0 +1,63 @@ +import re +from typing import List + + +def qa_accuracy_reward(completions, answers) -> List[float]: + """ + Reward function that evaluates question-answering accuracy for VQA tasks. + + Based on TRL's accuracy_reward pattern but adapted for multiple choice VQA. + + Args: + completions: List of model completions (text strings) + answers: List of correct answers (text strings) + + Returns: + List of reward scores (1.0 for correct, 0.0 for incorrect) + """ + rewards = [] + + for completion, answer in zip(completions, answers): + completion_text = str(completion).strip() + + # Extract answer from completion - look for X. content + patterns = [ + r'\s*[A-E]\.\s*(.*?)\s*', + r'\s*[A-E]\s*(.*?)\s*', + r'\s*(.*?)\s*', + ] + + answer_match = None + for pattern in patterns: + answer_match = re.search(pattern, completion_text, re.DOTALL | re.IGNORECASE) + if answer_match: + break + + if not answer_match: + rewards.append(0.0) + continue + + predicted_content = answer_match.group(1).strip() + + content_match = _compare_choice_content(predicted_content, answer) + + rewards.append(1.0 if content_match else 0.0) + + return rewards + + +def _compare_choice_content(predicted: str, correct: str) -> bool: + """Compare predicted choice content with correct content.""" + # Simple normalized comparison + pred_normalized = predicted.lower().strip() + correct_normalized = correct.lower().strip() + + # Direct match + if pred_normalized == correct_normalized: + return True + + # Partial match for more flexibility + if pred_normalized in correct_normalized or correct_normalized in pred_normalized: + return True + + return False \ No newline at end of file diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index 107e89d55..2669b14e0 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -489,8 +489,8 @@ def create( tokenizer: AutoTokenizer, ) -> VLMBaseDataset: key = dataset_name.lower() - klass = cls.registry.get(key) - return klass(data_paths=data_paths, config=config, tokenizer=tokenizer) + dataset_class = cls.registry.get(key) + return dataset_class(data_paths=data_paths, config=config, tokenizer=tokenizer) @VLMDatasetRegistry.register("robo2vlm") diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py index 5f04d633f..c3bd9475a 100644 --- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py @@ -68,7 +68,6 @@ def model_provider_func(self) -> torch.nn.Module: else: auto_model_class = AutoModelForCausalLM - # TODO: fix this, load model in float16/bfloat16 may cause optimizer in bf16, which is incorrect # default load in float16 model = auto_model_class.from_pretrained( self._cfg.model.model_path, diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 37dd1d9ba..e5432efe9 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -17,6 +17,7 @@ from typing import Dict, List, Tuple import numpy as np +from 
RLinf.rlinf.algorithms.rewards import get_reward_class import torch from omegaconf import DictConfig from torch.distributed.device_mesh import init_device_mesh @@ -60,7 +61,6 @@ seq_mean_token_sum, ) from rlinf.workers.rollout.utils import RankMapper -from toolkits.math_verifier.verify import math_verify_call class FSDPActor(FSDPModelManager, Worker): @@ -110,8 +110,9 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): # Reward configurations if not self.cfg.reward.use_reward_model: - assert self.cfg.reward.reward_type == "math", "only support math" - self.reward_fn = math_verify_call + assert self.cfg.reward.reward_type in ["math", "vqa"], "only support math and vqa reward!" + reward_cls = get_reward_class(self.cfg.reward.reward_type) + self.reward = reward_cls(self.cfg.reward) def init_worker(self): self.setup_model_and_optimizer() @@ -417,13 +418,8 @@ def _compute_batch_rewards( texts.append( self.tokenizer.decode(response.tolist(), skip_special_tokens=True) ) - rewards = self.reward_fn(texts, answers) - reward_scores = [ - self.cfg.reward.reward_scale - if reward == 1 - else -self.cfg.reward.reward_scale - for reward in rewards - ] + reward_scores = self.reward.get_reward(texts, answers) + all_reward_scores = torch.as_tensor( reward_scores, dtype=torch.float, From 7100e6b7eb36555810fe1861268ee20cf2ddf7b9 Mon Sep 17 00:00:00 2001 From: guozhen1997 <2997871698@qq.com> Date: Mon, 22 Sep 2025 22:03:39 +0800 Subject: [PATCH 07/38] feat: add reward worker Signed-off-by: guozhen1997 <2997871698@qq.com> --- .../math/config/qwen2.5-1.5b-grpo-fsdp.yaml | 9 +- examples/math/main_math.py | 8 ++ rlinf/runners/math_runner.py | 3 +- rlinf/utils/placement.py | 8 ++ rlinf/workers/reward/reward_worker.py | 101 ++++++++++++++++++ 5 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 rlinf/workers/reward/reward_worker.py diff --git a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml index 55285baa8..822008f04 100644 --- a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml @@ -201,9 +201,14 @@ actor: padding_side: 'right' reward: + group_name: "ActorGroup" use_reward_model: false - reward_type: 'math' - reward_scale: 5.0 + reward_type: 'vqa' + # reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 critic: use_critic_model: false \ No newline at end of file diff --git a/examples/math/main_math.py b/examples/math/main_math.py index 80f408bd0..37c75aca8 100644 --- a/examples/math/main_math.py +++ b/examples/math/main_math.py @@ -28,6 +28,7 @@ from rlinf.workers.actor import get_actor_worker from rlinf.workers.inference.megatron_inference_worker import MegatronInference from rlinf.workers.rollout.utils import get_rollout_backend_worker +from rlinf.workers.reward.reward_worker import RewardWorker """Script to start GRPO training""" mp.set_start_method("spawn", force=True) @@ -66,6 +67,12 @@ def main(cfg) -> None: name=cfg.inference.group_name, placement_strategy=inference_placement_strategy, ) + + # Reward group + reward_placement_strategy = component_placement.get_strategy("reward") + reward_group = RewardWorker.create_group(cfg, component_placement).launch( + cluster, name=cfg.reward.group_name, placement_strategy=reward_placement_strategy + ) # GRPO Actor group actor_worker_cls = get_actor_worker(cfg) @@ -85,6 +92,7 @@ def main(cfg) -> None: rollout=rollout_group, inference=inference_group, actor=actor_group, + 
reward=reward_group, ) runner.init_workers() diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/math_runner.py index be2cfac1f..c0326a8f7 100644 --- a/rlinf/runners/math_runner.py +++ b/rlinf/runners/math_runner.py @@ -35,6 +35,7 @@ from rlinf.utils.timers import Timer from rlinf.workers.actor.megatron_actor_worker import MegatronActor from rlinf.workers.inference.megatron_inference_worker import MegatronInference +from rlinf.workers.reward.reward_worker import RewardWorker if typing.TYPE_CHECKING: from rlinf.workers.rollout.sglang.sglang_worker import SGLangWorker @@ -55,7 +56,7 @@ def __init__( rollout: Union["SGLangWorker", "VLLMWorker"], inference: Optional[MegatronInference], actor: MegatronActor, - reward: Optional[Worker] = None, + reward: Optional[RewardWorker] = None, ): """""" self.cfg = cfg diff --git a/rlinf/utils/placement.py b/rlinf/utils/placement.py index 10738ce9e..d6fa798c8 100644 --- a/rlinf/utils/placement.py +++ b/rlinf/utils/placement.py @@ -202,6 +202,7 @@ def __init__(self, config: DictConfig, cluster: Cluster): self._actor_gpus = self._component_gpu_map.get("actor", None) self._inference_gpus = self._component_gpu_map.get("inference", None) self._rollout_gpus = self._component_gpu_map.get("rollout", None) + self._reward_gpus = self._component_gpu_map.get("reward", None) assert self._actor_gpus is not None, ( "Actor GPUs must be specified in the component_placement config." ) @@ -224,6 +225,7 @@ def __init__(self, config: DictConfig, cluster: Cluster): len(self._inference_gpus) if self._inference_gpus else 0 ) self._rollout_num_gpus = len(self._rollout_gpus) + self._reward_num_gpus = len(self._reward_gpus) if self._is_collocated(): assert self._inference_gpus is None, ( @@ -295,6 +297,9 @@ def _generate_placements(self): num_accelerators_per_process=rollout_tp_size, stride=stride, ) + self._placements["reward"] = PackedPlacementStrategy( + self._reward_gpus[0], self._reward_gpus[-1] + ) elif self._placement_mode == PlacementMode.DISAGGREGATED: # Generate continuous placement strategies for components in a cluster. 
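+            # Each rollout DP replica gets its own share of
+            # len(self._rollout_gpus) // self.rollout_dp_size GPUs; actor and reward
+            # are then packed onto their configured GPU ranges below.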
num_gpus_per_rollout_dp = len(self._rollout_gpus) // self.rollout_dp_size @@ -310,6 +315,9 @@ def _generate_placements(self): self._placements["actor"] = PackedPlacementStrategy( self._actor_gpus[0], self._actor_gpus[-1] ) + self._placements["reward"] = PackedPlacementStrategy( + self._reward_gpus[0], self._reward_gpus[-1] + ) @property def is_disaggregated(self): diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py new file mode 100644 index 000000000..bc98d5a6a --- /dev/null +++ b/rlinf/workers/reward/reward_worker.py @@ -0,0 +1,101 @@ +import torch +from typing import Dict, Tuple, List +from omegaconf import DictConfig +from rlinf.hybrid_engines.fsdp.fsdp_model_manager import FSDPModelManager +from rlinf.scheduler import Worker, Channel +from rlinf.algorithms.rewards import get_reward_class +from rlinf.data.io_struct import RolloutResult + + +class RewardWorker(Worker, FSDPModelManager): + def __init__(self, cfg: DictConfig): + Worker.__init__(self) + super().__init__(cfg.reward) + self.cfg = cfg + + self.total_batch_size_per_dp = ( + self.cfg.data.rollout_batch_size + * self.cfg.algorithm.get("group_size", 1) + // self._world_size + ) + + def init_worker(self): + if self.cfg.reward.use_reward_model: + self.setup_model_and_optimizer() + self.offload_fsdp_param_and_grad() + self.offload_fsdp_optimizer() + else: + self.reward = get_reward_class(self.cfg.reward.name)(self.cfg.reward) + + def get_batch( + self, channel: Channel + ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: + result: RolloutResult = channel.get() + + batch = result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + return batch, result + + def compute_rewards(self, input_channel: Channel, output_channel: Channel): + """Compute rewards. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. 
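+
+        When cfg.reward.use_reward_model is set, scores come from the reward model
+        loaded in init_worker; otherwise the rule-based reward class instantiated
+        there is applied to the decoded response texts.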
+ """ + + with self.worker_timer(): + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + + # Compute rule-based reward + if rollout_result.rewards is None: + rollout_result.rewards = self._compute_batch_rewards( + batch, rollout_result.answers + ) + output_channel.put(rollout_result) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + def _compute_batch_rewards( + self, batch: Dict[str, torch.Tensor], answers: List[str] + ): + """Reward computation using non-model based reward.""" + + if self.cfg.reward.use_reward_model: + return self.compute_batch_rewards_with_model(batch) + + texts = [] + for response, response_len in zip( + batch["input_ids"], + batch["response_lengths"], + ): + response = response[ + self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length + + response_len + ] + texts.append( + self.tokenizer.decode(response.tolist(), skip_special_tokens=True) + ) + reward_scores = self.reward.get_reward(texts, answers) + + all_reward_scores = torch.as_tensor( + reward_scores, + dtype=torch.float, + device=torch.device("cpu"), + ).view(-1, 1) + return all_reward_scores.flatten() + + def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]): + self.model.eval() + with torch.no_grad(): + # TODO: fix this + rewards = self.model(batch["input_ids"], batch["attention_mask"]) + return rewards \ No newline at end of file From f7c2fbae79743a8b8a990850293700dce80df549 Mon Sep 17 00:00:00 2001 From: guozhen1997 <2997871698@qq.com> Date: Tue, 23 Sep 2025 16:47:31 +0800 Subject: [PATCH 08/38] fix: fix vqa reward bugs and ruff format Signed-off-by: guozhen1997 <2997871698@qq.com> --- examples/math/main_math.py | 8 +- ...tron.yaml => qwen2.5-vl-3b-grpo-fsdp.yaml} | 25 ++++- examples/vlm/main_vlm.py | 10 ++ examples/vlm/run_main_vlm_grpo_megatron.sh | 2 +- rlinf/algorithms/rewards/__init__.py | 19 +++- rlinf/algorithms/rewards/math/__init__.py | 18 ++- rlinf/algorithms/rewards/vqa/__init__.py | 64 ++++++++--- .../algorithms/rewards/vqa/format_rewards.py | 50 ++++++--- rlinf/algorithms/rewards/vqa/qa_rewards.py | 105 +++++++++++++----- rlinf/data/datasets.py | 72 ++++++------ rlinf/data/io_struct.py | 2 +- rlinf/runners/math_runner.py | 2 +- rlinf/workers/actor/fsdp_actor_worker.py | 6 +- rlinf/workers/reward/reward_worker.py | 34 ++++-- 14 files changed, 290 insertions(+), 127 deletions(-) rename examples/vlm/config/{qwen2.5-vl-3b-grpo-megatron.yaml => qwen2.5-vl-3b-grpo-fsdp.yaml} (90%) diff --git a/examples/math/main_math.py b/examples/math/main_math.py index 37c75aca8..5b9ee61e1 100644 --- a/examples/math/main_math.py +++ b/examples/math/main_math.py @@ -27,8 +27,8 @@ from rlinf.utils.utils import output_redirector from rlinf.workers.actor import get_actor_worker from rlinf.workers.inference.megatron_inference_worker import MegatronInference -from rlinf.workers.rollout.utils import get_rollout_backend_worker from rlinf.workers.reward.reward_worker import RewardWorker +from rlinf.workers.rollout.utils import get_rollout_backend_worker """Script to start GRPO training""" mp.set_start_method("spawn", force=True) @@ -67,11 +67,13 @@ def main(cfg) -> None: name=cfg.inference.group_name, placement_strategy=inference_placement_strategy, ) - + # Reward group reward_placement_strategy = component_placement.get_strategy("reward") reward_group 
= RewardWorker.create_group(cfg, component_placement).launch( - cluster, name=cfg.reward.group_name, placement_strategy=reward_placement_strategy + cluster, + name=cfg.reward.group_name, + placement_strategy=reward_placement_strategy, ) # GRPO Actor group diff --git a/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml b/examples/vlm/config/qwen2.5-vl-3b-grpo-fsdp.yaml similarity index 90% rename from examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml rename to examples/vlm/config/qwen2.5-vl-3b-grpo-fsdp.yaml index cfe4febe7..02b5d8ea4 100644 --- a/examples/vlm/config/qwen2.5-vl-3b-grpo-megatron.yaml +++ b/examples/vlm/config/qwen2.5-vl-3b-grpo-fsdp.yaml @@ -10,7 +10,7 @@ cluster: num_nodes: 1 num_gpus_per_node: 8 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: task_type: math @@ -94,7 +94,7 @@ rollout: padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. - rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] sglang: attention_backend: triton # [flashinfer, triton] for more, see sglang's doc @@ -122,6 +122,7 @@ rollout: data: type: vision_language + dataset_name: robo2vlm max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 8 @@ -133,11 +134,12 @@ data: answer_key: "answer" solution_key: "solution" use_chat_template: True + lazy_loading: True shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/mnt/public/guozhen/data/science_qa/train-00000-of-00001-1028f23e353fbe3e.parquet"] - val_data_paths: ["/mnt/public/guozhen/data/science_qa/test-00000-of-00001-f0e719df791966ff.parquet"] + train_data_paths: ["/mnt/public/guozhen/data/robo2vlm/train/"] + val_data_paths: ["/mnt/public/guozhen/data/robo2vlm/test/"] actor: group_name: "ActorGroup" @@ -202,9 +204,20 @@ actor: padding_side: 'right' reward: + group_name: "RewardGroup" use_reward_model: false - reward_type: 'math' - reward_scale: 5.0 + reward_type: 'vqa' + # reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' critic: use_critic_model: false \ No newline at end of file diff --git a/examples/vlm/main_vlm.py b/examples/vlm/main_vlm.py index 605577fba..6ed11dd75 100644 --- a/examples/vlm/main_vlm.py +++ b/examples/vlm/main_vlm.py @@ -27,6 +27,7 @@ from rlinf.utils.utils import output_redirector from rlinf.workers.actor import get_actor_worker from rlinf.workers.inference.megatron_inference_worker import MegatronInference +from rlinf.workers.reward.reward_worker import RewardWorker from rlinf.workers.rollout.utils import get_rollout_backend_worker """Script to start GRPO training""" @@ -69,6 +70,14 @@ def main(cfg) -> None: placement_strategy=inference_placement_strategy, ) + # Reward group + reward_placement_strategy = component_placement.get_strategy("reward") + reward_group = RewardWorker.create_group(cfg, component_placement).launch( + cluster, + name=cfg.reward.group_name, + placement_strategy=reward_placement_strategy, + ) + # GRPO Actor group actor_worker_cls = get_actor_worker(cfg) actor_placement_strategy = component_placement.get_strategy("actor") @@ -87,6 +96,7 @@ def main(cfg) -> None: rollout=rollout_group, inference=inference_group, 
actor=actor_group, + reward=reward_group, ) runner.init_workers() diff --git a/examples/vlm/run_main_vlm_grpo_megatron.sh b/examples/vlm/run_main_vlm_grpo_megatron.sh index 2e5a75e3a..99165babb 100644 --- a/examples/vlm/run_main_vlm_grpo_megatron.sh +++ b/examples/vlm/run_main_vlm_grpo_megatron.sh @@ -13,7 +13,7 @@ MEGATRON_PATH=/opt/Megatron-LM export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-vl-3b-grpo-megatron" + CONFIG_NAME="qwen2.5-vl-3b-grpo-fsdp" else CONFIG_NAME=$1 fi diff --git a/rlinf/algorithms/rewards/__init__.py b/rlinf/algorithms/rewards/__init__.py index 3a48b84b6..3d354437b 100644 --- a/rlinf/algorithms/rewards/__init__.py +++ b/rlinf/algorithms/rewards/__init__.py @@ -1,15 +1,32 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .math import MathReward from .vqa import VQAReward + def register_reward(name: str, reward_class: type): assert name not in reward_registry, f"Reward {name} already registered" reward_registry[name] = reward_class + def get_reward_class(name: str): assert name in reward_registry, f"Reward {name} not found" return reward_registry[name] + reward_registry = {} register_reward("math", MathReward) -register_reward("vqa", VQAReward) \ No newline at end of file +register_reward("vqa", VQAReward) diff --git a/rlinf/algorithms/rewards/math/__init__.py b/rlinf/algorithms/rewards/math/__init__.py index a94ff2dc4..7eb6401a8 100644 --- a/rlinf/algorithms/rewards/math/__init__.py +++ b/rlinf/algorithms/rewards/math/__init__.py @@ -1,5 +1,21 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import List + from omegaconf import DictConfig + from toolkits.math_verifier.verify import math_verify_call @@ -21,4 +37,4 @@ def get_reward( List[float]: A list of reward scores, one for each response. """ - return math_verify_call(response, reference) * self.scale \ No newline at end of file + return math_verify_call(response, reference) * self.scale diff --git a/rlinf/algorithms/rewards/vqa/__init__.py b/rlinf/algorithms/rewards/vqa/__init__.py index d4dd55f20..8175d72a1 100644 --- a/rlinf/algorithms/rewards/vqa/__init__.py +++ b/rlinf/algorithms/rewards/vqa/__init__.py @@ -1,27 +1,58 @@ -import torch +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import List + +import torch from omegaconf import DictConfig + +from .format_rewards import answer_format_reward, think_format_reward from .qa_rewards import qa_accuracy_reward -from .format_rewards import think_format_reward, answer_format_reward class VQAReward: def __init__(self, config: DictConfig): - self.reward_weights = config.get("reward_weights", { - "qa_accuracy": 1.0, - "think_format": 0.0, - "answer_format": 0.0, - }) - for reward_name, reward_weight in self.reward_weights.items(): - assert reward_name in ["qa_accuracy", "think_format", "answer_format"], f"Reward {reward_name} not supported" - assert reward_weight >= 0, f"Reward weight {reward_weight} must be non-negative" - self.reward_weights = [reward_weight["qa_accuracy"], reward_weight["think_format"], reward_weight["answer_format"]] - - self.reward_functions = [qa_accuracy_reward, think_format_reward, answer_format_reward] + reward_weights_config = config.get( + "reward_weights", + { + "qa_accuracy": 1.0, + "think_format": 0.0, + "answer_format": 0.0, + }, + ) + for reward_name, reward_weight in reward_weights_config.items(): + assert reward_name in ["qa_accuracy", "think_format", "answer_format"], ( + f"Reward {reward_name} not supported" + ) + assert reward_weight >= 0, ( + f"Reward weight {reward_weight} must be non-negative" + ) + self.reward_weights = [ + reward_weights_config["qa_accuracy"], + reward_weights_config["think_format"], + reward_weights_config["answer_format"], + ] + + self.reward_functions = [ + qa_accuracy_reward, + think_format_reward, + answer_format_reward, + ] self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def get_reward(self, completions: List[str], answers: List[str]) -> List[float]: + def get_reward(self, completions: List[str], answers: List[dict]) -> List[float]: rewards = [] for i, reward_function in enumerate(self.reward_functions): if self.reward_weights[i] > 0: @@ -34,9 +65,8 @@ def get_reward(self, completions: List[str], answers: List[str]) -> List[float]: # rewards [num_reward_functions, len(completions)] rewards_tensor = torch.tensor(rewards, device=self.device) weights_tensor = torch.tensor(self.reward_weights, device=self.device) - + # [num_reward_functions, num_completions] * [num_reward_functions, 1] -> [num_completions] final_rewards = (rewards_tensor * weights_tensor.unsqueeze(1)).sum(dim=0) - + return final_rewards.tolist() - \ No newline at end of file diff --git a/rlinf/algorithms/rewards/vqa/format_rewards.py b/rlinf/algorithms/rewards/vqa/format_rewards.py index 2f926e733..205bbe336 100644 --- a/rlinf/algorithms/rewards/vqa/format_rewards.py +++ b/rlinf/algorithms/rewards/vqa/format_rewards.py @@ -1,3 +1,17 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from typing import List @@ -5,49 +19,49 @@ def think_format_reward(completions, answers) -> List[float]: """ Think format reward function compatible with GRPO training. - + Reward function that checks if reasoning is enclosed within tags. - + Args: completions: List of model completions (text strings) - + Returns: List of reward scores (1.0 for correct format, 0.0 otherwise) """ pattern = r"^(?!.*)(.*?).*$" rewards = [] - + for completion in completions: completion_text = str(completion).strip() match = re.match(pattern, completion_text, re.DOTALL | re.MULTILINE) rewards.append(1.0 if match else 0.0) - + return rewards def answer_format_reward(completions, answers) -> List[float]: """ Reward function that checks for proper answer formatting. - + Expected format: X. content where X is a choice letter. - + Args: - completions: List of model completions (text strings) - + completions: List of model completions (text strings) + Returns: List of reward scores (1.0 for correct format, 0.0 otherwise) """ rewards = [] - + for completion in completions: completion_text = str(completion).strip() - + # Check for proper answer format: X. content - answer_pattern = r'\s*[A-E]\.\s*.+?\s*' - has_proper_answer = bool(re.search( - answer_pattern, completion_text, re.DOTALL | re.IGNORECASE - )) - + answer_pattern = r"\s*[A-E]\.\s*.+?\s*" + has_proper_answer = bool( + re.search(answer_pattern, completion_text, re.DOTALL | re.IGNORECASE) + ) + rewards.append(1.0 if has_proper_answer else 0.0) - - return rewards \ No newline at end of file + + return rewards diff --git a/rlinf/algorithms/rewards/vqa/qa_rewards.py b/rlinf/algorithms/rewards/vqa/qa_rewards.py index ce7a3443b..2bc9540d3 100644 --- a/rlinf/algorithms/rewards/vqa/qa_rewards.py +++ b/rlinf/algorithms/rewards/vqa/qa_rewards.py @@ -1,3 +1,17 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from typing import List @@ -5,44 +19,77 @@ def qa_accuracy_reward(completions, answers) -> List[float]: """ Reward function that evaluates question-answering accuracy for VQA tasks. - + Based on TRL's accuracy_reward pattern but adapted for multiple choice VQA. - + Args: completions: List of model completions (text strings) - answers: List of correct answers (text strings) - + answers: List of correct answers (dict) + Returns: List of reward scores (1.0 for correct, 0.0 for incorrect) """ rewards = [] - + for completion, answer in zip(completions, answers): completion_text = str(completion).strip() - + # Extract answer from completion - look for X. 
content - patterns = [ - r'\s*[A-E]\.\s*(.*?)\s*', - r'\s*[A-E]\s*(.*?)\s*', - r'\s*(.*?)\s*', - ] - - answer_match = None - for pattern in patterns: - answer_match = re.search(pattern, completion_text, re.DOTALL | re.IGNORECASE) - if answer_match: - break - + answer_match = re.search( + r"\s*([A-E])\.\s*(.*?)\s*", + completion_text, + re.DOTALL | re.IGNORECASE, + ) + if not answer_match: rewards.append(0.0) continue - - predicted_content = answer_match.group(1).strip() - - content_match = _compare_choice_content(predicted_content, answer) - - rewards.append(1.0 if content_match else 0.0) - + + predicted_letter = answer_match.group(1).upper() + predicted_content = answer_match.group(2).strip() + + # Get ground truth from kwargs + correct_answer = answer.get("correct_answer", None) + choices = answer.get("choices", None) + + if correct_answer is None or choices is None: + rewards.append(0.0) + continue + + # Normalize correct_answer to letter format + if isinstance(correct_answer, int): + correct_letter = chr(65 + correct_answer) # 0->A, 1->B, etc. + elif isinstance(correct_answer, str): + correct_letter = correct_answer.strip().upper() + else: + rewards.append(0.0) + continue + + # Parse choices if string format + if isinstance(choices, str): + try: + import ast + + choices = ast.literal_eval(choices) + except (ValueError, SyntaxError): + choices = [str(choices)] + + # Get correct choice content + letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4} + if correct_letter in letter_to_idx and letter_to_idx[correct_letter] < len( + choices + ): + correct_content = choices[letter_to_idx[correct_letter]].strip() + else: + rewards.append(0.0) + continue + + # Check accuracy: both letter and content must match + letter_match = predicted_letter == correct_letter + content_match = _compare_choice_content(predicted_content, correct_content) + + rewards.append(1.0 if (letter_match and content_match) else 0.0) + return rewards @@ -51,13 +98,13 @@ def _compare_choice_content(predicted: str, correct: str) -> bool: # Simple normalized comparison pred_normalized = predicted.lower().strip() correct_normalized = correct.lower().strip() - + # Direct match if pred_normalized == correct_normalized: return True - + # Partial match for more flexibility if pred_normalized in correct_normalized or correct_normalized in pred_normalized: return True - - return False \ No newline at end of file + + return False diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index 2669b14e0..75922780b 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -67,7 +67,7 @@ def batch_pad_to_fixed_len( class DatasetItem: prompt: torch.Tensor length: int - answer: str + answer: str | dict idx: int solution: Optional[str] = None image_data: Optional[List[Union[bytes, str]]] = None @@ -182,8 +182,6 @@ def __init__( data_paths: Union[List[str], str], config: DictConfig, tokenizer: AutoTokenizer, - *, - lazy_loading: bool = False, ) -> None: super().__init__() self.cfg = config @@ -194,6 +192,7 @@ def __init__( # Delay processor creation; only needed when use_chat_template is True self._processor = None + self.system_prompt = config.data.get("system_prompt", None) self.use_chat_template = bool(config.data.use_chat_template) self.image_keys = list(config.data.image_keys or []) self.prompt_key = config.data.prompt_key @@ -204,7 +203,7 @@ def __init__( self.eos_id = int(self.tokenizer.eos_token_id) # Loading mode - self.lazy_loading = bool(getattr(config.data, "lazy_loading", lazy_loading)) + self.lazy_loading = 
bool(getattr(config.data, "lazy_loading", False)) self._records = [] self._indices = [] # (path, fmt, row_index_or_offset) @@ -280,11 +279,20 @@ def encode_prompt( self._processor = AutoProcessor.from_pretrained( self.cfg.actor.model.model_path ) + messages = [] + if self.system_prompt is not None: + messages.append( + { + "role": "system", + "content": [{"type": "text", "text": self.system_prompt}], + } + ) + content: List[Dict[str, Any]] = [] for _ in range(max(0, image_count)): content.append({"type": "image"}) content.append({"type": "text", "text": prompt_text}) - messages = [{"role": "user", "content": content}] + messages.append({"role": "user", "content": content}) rendered = self._processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @@ -495,6 +503,20 @@ def create( @VLMDatasetRegistry.register("robo2vlm") class Robo2VLMDataset(VLMBaseDataset): + def __init__( + self, + data_paths: Union[List[str], str], + config: DictConfig, + tokenizer: AutoTokenizer, + ) -> None: + super().__init__(data_paths, config, tokenizer) + self.system_prompt = ( + "You are a helpful robotic vision assistant specialized in " + "answering questions about robotic manipulation tasks. " + "Use tags to show your reasoning process, " + "then provide your final answer in tags." + ) + def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]: # Prefer common robo2vlm fields if present, else fallback to configured keys images: List[Any] = [] @@ -558,40 +580,12 @@ def build_prompt_text(self, data_item: Dict[str, Any]) -> str: def postprocess_dataset_item( self, item: DatasetItem, raw: Dict[str, Any] ) -> DatasetItem: - # Derive answer from 'correct_answer' and 'choices' if not provided - if not item.answer or str(item.answer).lower() in {"none", "", "null"}: - choices = raw.get("choices") - ca = raw.get("correct_answer") - try: - # Normalize choices - if isinstance(choices, str): - import ast - - choices = ast.literal_eval(choices) - if not isinstance(choices, list): - choices = [choices] if choices is not None else [] - - ans_val: Optional[str] = None - if isinstance(ca, int) and 0 <= ca < len(choices): - ans_val = str(choices[ca]) - elif isinstance(ca, str): - cstr = ca.strip() - # Letter index like 'A', 'B', ... 
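A sketch of the chat-template message list that encode_prompt assembles above (optional system prompt plus one "image" slot per image, then the question text); build_messages and the tag wording in the system prompt are local illustrations, not repository helpers:

def build_messages(prompt_text, image_count, system_prompt=None):
    messages = []
    if system_prompt is not None:
        messages.append(
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
        )
    content = [{"type": "image"} for _ in range(max(0, image_count))]
    content.append({"type": "text", "text": prompt_text})
    messages.append({"role": "user", "content": content})
    return messages


print(build_messages("What is the robot holding?", image_count=1,
                     system_prompt="Show your reasoning, then give a final answer."))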
- if len(cstr) == 1 and "A" <= cstr <= "Z": - idx = ord(cstr) - ord("A") - if 0 <= idx < len(choices): - ans_val = str(choices[idx]) - # Direct match to a choice value - if ans_val is None and choices: - for ch in choices: - if str(ch) == cstr: - ans_val = cstr - break - if ans_val is not None: - item.answer = ans_val - except Exception: - # Keep original if any - pass + answer_dict = { + "choices": raw.get("choices", None), + "correct_answer": raw.get("correct_answer", None), + } + item.answer = answer_dict + return item diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index e40fed973..78459e548 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -266,7 +266,7 @@ class RolloutResult: advantages: Optional[List[float] | torch.Tensor] = None prompt_texts: Optional[List[str]] = None response_texts: Optional[List[str]] = None - answers: Optional[List[str]] = None + answers: Optional[List[str | dict]] = None image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None # Inference # Only set when recompute_logprobs is False diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/math_runner.py index c0326a8f7..dca52c0c0 100644 --- a/rlinf/runners/math_runner.py +++ b/rlinf/runners/math_runner.py @@ -25,7 +25,7 @@ from tqdm import tqdm from rlinf.data.io_struct import RolloutRequest -from rlinf.scheduler import Channel, Worker +from rlinf.scheduler import Channel from rlinf.scheduler import WorkerGroupFuncResult as Handle from rlinf.utils.data_iter_utils import split_list from rlinf.utils.distributed import ScopedTimer diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index e5432efe9..06c28797d 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -17,7 +17,6 @@ from typing import Dict, List, Tuple import numpy as np -from RLinf.rlinf.algorithms.rewards import get_reward_class import torch from omegaconf import DictConfig from torch.distributed.device_mesh import init_device_mesh @@ -26,6 +25,7 @@ import rlinf.algorithms # noqa: F401 from rlinf.algorithms.registry import actor_loss, calculate_adv_and_returns +from rlinf.algorithms.rewards import get_reward_class from rlinf.algorithms.utils import ( kl_penalty, preprocess_advantages_inputs, @@ -110,7 +110,9 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): # Reward configurations if not self.cfg.reward.use_reward_model: - assert self.cfg.reward.reward_type in ["math", "vqa"], "only support math and vqa reward!" + assert self.cfg.reward.reward_type in ["math", "vqa"], ( + "only support math and vqa reward!" + ) reward_cls = get_reward_class(self.cfg.reward.reward_type) self.reward = reward_cls(self.cfg.reward) diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index bc98d5a6a..9290c23d6 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -1,17 +1,35 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
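A sketch of how the answer dict that Robo2VLMDataset.postprocess_dataset_item now emits is resolved back into a choice letter and choice text, mirroring the qa_accuracy_reward logic earlier in this patch; the sample values are made up:

answer = {"choices": ["a red block", "a blue cup"], "correct_answer": 1}

ca = answer["correct_answer"]
correct_letter = chr(65 + ca) if isinstance(ca, int) else str(ca).strip().upper()
letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}
correct_content = answer["choices"][letter_to_idx[correct_letter]].strip()
print(correct_letter, correct_content)  # B a blue cup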
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + import torch -from typing import Dict, Tuple, List from omegaconf import DictConfig -from rlinf.hybrid_engines.fsdp.fsdp_model_manager import FSDPModelManager -from rlinf.scheduler import Worker, Channel + from rlinf.algorithms.rewards import get_reward_class from rlinf.data.io_struct import RolloutResult +from rlinf.hybrid_engines.fsdp.fsdp_model_manager import FSDPModelManager +from rlinf.scheduler import Channel, Worker +from rlinf.utils.placement import ModelParallelComponentPlacement -class RewardWorker(Worker, FSDPModelManager): - def __init__(self, cfg: DictConfig): +class RewardWorker(FSDPModelManager, Worker): + def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): Worker.__init__(self) super().__init__(cfg.reward) self.cfg = cfg + self.component_placement = placement self.total_batch_size_per_dp = ( self.cfg.data.rollout_batch_size @@ -25,7 +43,7 @@ def init_worker(self): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() else: - self.reward = get_reward_class(self.cfg.reward.name)(self.cfg.reward) + self.reward = get_reward_class(self.cfg.reward.reward_type)(self.cfg.reward) def get_batch( self, channel: Channel @@ -65,7 +83,7 @@ def compute_rewards(self, input_channel: Channel, output_channel: Channel): ) def _compute_batch_rewards( - self, batch: Dict[str, torch.Tensor], answers: List[str] + self, batch: Dict[str, torch.Tensor], answers: List[str | dict] ): """Reward computation using non-model based reward.""" @@ -98,4 +116,4 @@ def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]): with torch.no_grad(): # TODO: fix this rewards = self.model(batch["input_ids"], batch["attention_mask"]) - return rewards \ No newline at end of file + return rewards From 0c38831a87ba3d3473624e56d7ecf6f79732f58c Mon Sep 17 00:00:00 2001 From: guozhen1997 <2997871698@qq.com> Date: Tue, 23 Sep 2025 18:37:55 +0800 Subject: [PATCH 09/38] feat: rename and reorganize example config Signed-off-by: guozhen1997 <2997871698@qq.com> --- .../config/math}/qwen2.5-1.5b-grpo-fsdp.yaml | 2 +- .../qwen2.5-1.5b-grpo-megatron-pipeline.yaml | 0 .../math}/qwen2.5-1.5b-grpo-megatron.yaml | 0 .../config/math}/qwen2.5-1.5b-single-gpu.yaml | 0 .../math}/qwen2.5-32b-grpo-megatron.yaml | 0 .../math}/qwen2.5-7b-grpo-megatron.yaml | 0 .../config/tp_comm_overlap_cfg.yaml | 0 .../config/vqa}/qwen2.5-vl-3b-grpo-fsdp.yaml | 4 +- .../main_math.py => reasoning/main_grpo.py} | 0 .../run_main_grpo_math.sh} | 4 +- .../run_main_grpo_vqa.sh} | 2 +- .../run_placement_autotune.sh | 0 examples/vlm/main_vlm.py | 107 ------------------ rlinf/config.py | 12 +- 14 files changed, 14 insertions(+), 117 deletions(-) rename examples/{math/config => reasoning/config/math}/qwen2.5-1.5b-grpo-fsdp.yaml (99%) rename examples/{math/config => reasoning/config/math}/qwen2.5-1.5b-grpo-megatron-pipeline.yaml (100%) rename examples/{math/config => reasoning/config/math}/qwen2.5-1.5b-grpo-megatron.yaml (100%) rename examples/{math/config => reasoning/config/math}/qwen2.5-1.5b-single-gpu.yaml (100%) rename examples/{math/config => reasoning/config/math}/qwen2.5-32b-grpo-megatron.yaml (100%) rename examples/{math/config => reasoning/config/math}/qwen2.5-7b-grpo-megatron.yaml (100%) rename examples/{math => reasoning}/config/tp_comm_overlap_cfg.yaml (100%) rename examples/{vlm/config => reasoning/config/vqa}/qwen2.5-vl-3b-grpo-fsdp.yaml (99%) rename 
examples/{math/main_math.py => reasoning/main_grpo.py} (100%) rename examples/{math/run_main_math_grpo.sh => reasoning/run_main_grpo_math.sh} (71%) rename examples/{vlm/run_main_vlm_grpo_megatron.sh => reasoning/run_main_grpo_vqa.sh} (79%) rename examples/{math => reasoning}/run_placement_autotune.sh (100%) delete mode 100644 examples/vlm/main_vlm.py diff --git a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml similarity index 99% rename from examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index 822008f04..bc83ef381 100644 --- a/examples/math/config/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -13,7 +13,7 @@ cluster: actor,rollout: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf diff --git a/examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml similarity index 100% rename from examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml diff --git a/examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml similarity index 100% rename from examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml diff --git a/examples/math/config/qwen2.5-1.5b-single-gpu.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml similarity index 100% rename from examples/math/config/qwen2.5-1.5b-single-gpu.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml diff --git a/examples/math/config/qwen2.5-32b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml similarity index 100% rename from examples/math/config/qwen2.5-32b-grpo-megatron.yaml rename to examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml diff --git a/examples/math/config/qwen2.5-7b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml similarity index 100% rename from examples/math/config/qwen2.5-7b-grpo-megatron.yaml rename to examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml diff --git a/examples/math/config/tp_comm_overlap_cfg.yaml b/examples/reasoning/config/tp_comm_overlap_cfg.yaml similarity index 100% rename from examples/math/config/tp_comm_overlap_cfg.yaml rename to examples/reasoning/config/tp_comm_overlap_cfg.yaml diff --git a/examples/vlm/config/qwen2.5-vl-3b-grpo-fsdp.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml similarity index 99% rename from examples/vlm/config/qwen2.5-vl-3b-grpo-fsdp.yaml rename to examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml index 02b5d8ea4..c96445a06 100644 --- a/examples/vlm/config/qwen2.5-vl-3b-grpo-fsdp.yaml +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml @@ -13,7 +13,7 @@ cluster: actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -207,7 +207,7 @@ reward: group_name: "RewardGroup" use_reward_model: false reward_type: 'vqa' - # reward_scale: 5.0 + reward_scale: 1.0 reward_weights: qa_accuracy: 1.0 think_format: 0.0 diff --git a/examples/math/main_math.py 
b/examples/reasoning/main_grpo.py similarity index 100% rename from examples/math/main_math.py rename to examples/reasoning/main_grpo.py diff --git a/examples/math/run_main_math_grpo.sh b/examples/reasoning/run_main_grpo_math.sh similarity index 71% rename from examples/math/run_main_math_grpo.sh rename to examples/reasoning/run_main_grpo_math.sh index dc2f75ee0..56e13c7c2 100644 --- a/examples/math/run_main_math_grpo.sh +++ b/examples/reasoning/run_main_grpo_math.sh @@ -13,9 +13,9 @@ MEGATRON_PATH=/opt/Megatron-LM export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-fsdp" + CONFIG_NAME="qwen2.5-1.5b-grpo-megatron" else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/math/main_math.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/math/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/examples/vlm/run_main_vlm_grpo_megatron.sh b/examples/reasoning/run_main_grpo_vqa.sh similarity index 79% rename from examples/vlm/run_main_vlm_grpo_megatron.sh rename to examples/reasoning/run_main_grpo_vqa.sh index 99165babb..1b41f415c 100644 --- a/examples/vlm/run_main_vlm_grpo_megatron.sh +++ b/examples/reasoning/run_main_grpo_vqa.sh @@ -18,4 +18,4 @@ else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/vlm/main_vlm.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/vqa/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/examples/math/run_placement_autotune.sh b/examples/reasoning/run_placement_autotune.sh similarity index 100% rename from examples/math/run_placement_autotune.sh rename to examples/reasoning/run_placement_autotune.sh diff --git a/examples/vlm/main_vlm.py b/examples/vlm/main_vlm.py deleted file mode 100644 index 6ed11dd75..000000000 --- a/examples/vlm/main_vlm.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2025 The RLinf Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json - -import hydra -import torch.multiprocessing as mp -from omegaconf.omegaconf import OmegaConf - -from rlinf.config import validate_cfg -from rlinf.data.datasets import create_rl_dataset -from rlinf.data.tokenizers import hf_tokenizer -from rlinf.runners.math_runner import MathRunner -from rlinf.scheduler import Cluster -from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode -from rlinf.utils.utils import output_redirector -from rlinf.workers.actor import get_actor_worker -from rlinf.workers.inference.megatron_inference_worker import MegatronInference -from rlinf.workers.reward.reward_worker import RewardWorker -from rlinf.workers.rollout.utils import get_rollout_backend_worker - -"""Script to start GRPO training""" -mp.set_start_method("spawn", force=True) - - -@hydra.main(version_base="1.1") -@output_redirector -def main(cfg) -> None: - cfg = validate_cfg(cfg) - print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2)) - - cluster = Cluster( - num_nodes=cfg.cluster.num_nodes, num_gpus_per_node=cfg.cluster.num_gpus_per_node - ) - component_placement = ModelParallelComponentPlacement(cfg) - - rollout_worker_cls = get_rollout_backend_worker(cfg, component_placement) - - # Rollout group - rollout_placement_strategy = component_placement.get_strategy("rollout") - rollout_group = rollout_worker_cls.create_group(cfg, component_placement).launch( - cluster, - name=cfg.rollout.group_name, - placement_strategy=rollout_placement_strategy, - ) - - # Inference group - inference_group = None - if ( - component_placement.placement_mode == PlacementMode.DISAGGREGATED - and cfg.algorithm.recompute_logprobs - ): - inference_placement_strategy = component_placement.get_strategy("inference") - inference_group = MegatronInference.create_group( - cfg, component_placement - ).launch( - cluster, - name=cfg.inference.group_name, - placement_strategy=inference_placement_strategy, - ) - - # Reward group - reward_placement_strategy = component_placement.get_strategy("reward") - reward_group = RewardWorker.create_group(cfg, component_placement).launch( - cluster, - name=cfg.reward.group_name, - placement_strategy=reward_placement_strategy, - ) - - # GRPO Actor group - actor_worker_cls = get_actor_worker(cfg) - actor_placement_strategy = component_placement.get_strategy("actor") - actor_group = actor_worker_cls.create_group(cfg, component_placement).launch( - cluster, name=cfg.actor.group_name, placement_strategy=actor_placement_strategy - ) - - tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model) - train_ds, val_ds = create_rl_dataset(cfg, tokenizer) - - runner = MathRunner( - cfg=cfg, - placement=component_placement, - train_dataset=train_ds, - val_dataset=val_ds, - rollout=rollout_group, - inference=inference_group, - actor=actor_group, - reward=reward_group, - ) - - runner.init_workers() - runner.run() - - -if __name__ == "__main__": - main() diff --git a/rlinf/config.py b/rlinf/config.py index 0f3a21903..9c31cd49f 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -36,6 +36,7 @@ SUPPORTED_MODEL_ARCHS = ["qwen2.5", "qwen2.5_vl", "openvla", "openvla_oft"] SUPPORTED_ROLLOUT_BACKENDS = ["sglang", "vllm"] +SUPPORTED_TASK_TYPE = ["embodied", "reasoning", "coding_online_rl"] __all__ = ["build_config"] @@ -528,7 +529,7 @@ def get_robot_control_mode(robot: str): return cfg -def validate_math_cfg(cfg: DictConfig) -> DictConfig: +def validate_reasoning_cfg(cfg: DictConfig) -> DictConfig: assert cfg.rollout.model_arch in SUPPORTED_MODEL_ARCHS, ( f"Model 
{cfg.rollout.model_arch} is not supported" ) @@ -607,11 +608,14 @@ def validate_coding_online_rl_cfg(cfg: DictConfig) -> DictConfig: def validate_cfg(cfg: DictConfig) -> DictConfig: OmegaConf.set_struct(cfg, True) + assert cfg.runner.task_type in SUPPORTED_TASK_TYPE, ( + f"task_type must be one of {SUPPORTED_TASK_TYPE}" + ) if cfg.runner.task_type == "embodied": cfg = validate_embodied_cfg(cfg) - if cfg.runner.task_type == "math": - cfg = validate_math_cfg(cfg) - if cfg.runner.task_type == "coding_online_rl": + elif cfg.runner.task_type == "reasoning": + cfg = validate_reasoning_cfg(cfg) + elif cfg.runner.task_type == "coding_online_rl": cfg = validate_coding_online_rl_cfg(cfg) if ( From e6ebd609404a79e0ff3cb6f735020a24e2493b5e Mon Sep 17 00:00:00 2001 From: guozhen1997 <2997871698@qq.com> Date: Tue, 23 Sep 2025 20:43:14 +0800 Subject: [PATCH 10/38] fix: fix ruff, fix merge bugs Signed-off-by: guozhen1997 <2997871698@qq.com> --- .../sglang/sglang_0_4_6/sgl_scheduler.py | 1 - rlinf/utils/convertor/utils.py | 72 +++++++++++-------- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py index 684f8d333..9a69b8548 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py @@ -29,7 +29,6 @@ ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, ) -from sglang.srt.managers.mm_utils import init_embedding_cache from sglang.srt.managers.scheduler import Scheduler as _Scheduler from sglang.srt.managers.scheduler import logger from sglang.srt.server_args import PortArgs, ServerArgs diff --git a/rlinf/utils/convertor/utils.py b/rlinf/utils/convertor/utils.py index e187bb919..761218650 100644 --- a/rlinf/utils/convertor/utils.py +++ b/rlinf/utils/convertor/utils.py @@ -31,49 +31,59 @@ class TransformFunc: def _split_gqa_tensor( tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config ) -> None: - """ - Private helper to split a GQA-combined tensor (weight or bias). 
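A worked shape example (illustrative sizes, not a real model config) for the fused-QKV split that the rewritten _split_gqa_tensor just below performs: per local query group, the fused rows hold q_heads_per_group query heads plus one key head and one value head:

import torch

hidden_size, num_attention_heads, num_query_groups, target_tp = 32, 8, 4, 2
head_dim = hidden_size // num_attention_heads                  # 4
local_groups = num_query_groups // target_tp                   # 2
q_heads_per_group = num_attention_heads // num_query_groups    # 2
num_channel_qkv = q_heads_per_group + 2                        # Q heads + K + V

fused = torch.randn(local_groups * num_channel_qkv * head_dim, hidden_size)  # [32, 32]
qkv = fused.view(local_groups, num_channel_qkv, head_dim, hidden_size)
q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1)
print(q.reshape(-1, hidden_size).shape)  # torch.Size([16, 32])
print(k.reshape(-1, hidden_size).shape)  # torch.Size([8, 32])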
- """ hidden_size = config.model_config.hidden_size num_attention_heads = config.model_config.num_attention_heads - num_key_value_heads = ( - config.model_config.num_query_groups or num_attention_heads - ) + num_query_groups = config.model_config.num_query_groups or num_attention_heads head_dim = hidden_size // num_attention_heads - tp_size = config.model_config.tensor_model_parallel_size - - assert num_key_value_heads % tp_size == 0, ( - "num_key_value_heads must be divisible by tensor parallel size" + target_tp = config.reshard_tp_size + assert num_query_groups % target_tp == 0, ( + "num_query_groups must be divisible by reshard_tp_size" ) + local_num_query_groups = num_query_groups // target_tp - q_heads_per_rank = num_attention_heads // tp_size - kv_heads_per_rank = num_key_value_heads // tp_size - - q_shard_size = q_heads_per_rank * head_dim - k_shard_size = kv_heads_per_rank * head_dim - v_shard_size = kv_heads_per_rank * head_dim + # heads per query group + assert num_attention_heads % num_query_groups == 0, ( + "num_attention_heads must be divisible by num_query_groups" + ) + q_heads_per_group = num_attention_heads // num_query_groups - shard_size = q_shard_size + k_shard_size + v_shard_size + num_channel_qkv = q_heads_per_group + 2 - q_shards, k_shards, v_shards = [], [], [] + if tensor.ndim == 2: + # Weight: [out_features, in_features] + out_features, in_features = tensor.shape + expected_out = local_num_query_groups * num_channel_qkv * head_dim + assert out_features == expected_out, ( + f"Unexpected fused QKV weight shape {tensor.shape}, expect " + f"[{expected_out}, {in_features}] (local groups={local_num_query_groups})" + ) - # [Qi,Ki,Vi] - for shard in tensor.split(shard_size, dim=0): - # Qi, Ki, Vi - q_shard, k_shard, v_shard = shard.split( - [q_shard_size, k_shard_size, v_shard_size], dim=0 + qkv = tensor.view( + local_num_query_groups, num_channel_qkv, head_dim, in_features + ) + q, k, v = torch.split( + qkv, [q_heads_per_group, 1, 1], dim=1 + ) # shapes: [G, qh, D, In], [G,1,D,In], [G,1,D,In] + q_full = q.reshape(-1, in_features).contiguous() + k_full = k.reshape(-1, in_features).contiguous() + v_full = v.reshape(-1, in_features).contiguous() + else: + # Bias: [out_features] + out_features = tensor.shape[0] + expected_out = local_num_query_groups * num_channel_qkv * head_dim + assert out_features == expected_out, ( + f"Unexpected fused QKV bias shape {tensor.shape}, expect " + f"[{expected_out}] (local groups={local_num_query_groups})" ) - q_shards.append(q_shard) - k_shards.append(k_shard) - v_shards.append(v_shard) - # cat - q_full = torch.cat(q_shards, dim=0) - k_full = torch.cat(k_shards, dim=0) - v_full = torch.cat(v_shards, dim=0) + qkv = tensor.view(local_num_query_groups, num_channel_qkv, head_dim) + q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1) + q_full = q.reshape(-1).contiguous() + k_full = k.reshape(-1).contiguous() + v_full = v.reshape(-1).contiguous() - # saved + # Save to target names new_statedict[weight_names[0]] = q_full.clone() new_statedict[weight_names[1]] = k_full.clone() new_statedict[weight_names[2]] = v_full.clone() From dc446fc91b6b01066ccbb4b71ec30c1ea6a316ef Mon Sep 17 00:00:00 2001 From: guozhen1997 <2997871698@qq.com> Date: Thu, 25 Sep 2025 14:38:26 +0800 Subject: [PATCH 11/38] fix: fix multi modal inputs Signed-off-by: guozhen1997 <2997871698@qq.com> --- .../config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 4 +- .../config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml | 1 - examples/reasoning/main_grpo.py | 4 +- rlinf/data/datasets.py | 49 
++++++-- rlinf/data/io_struct.py | 110 ++++++++++++++++-- .../{math_runner.py => reasoning_runner.py} | 15 ++- rlinf/utils/placement.py | 17 +-- rlinf/workers/actor/fsdp_actor_worker.py | 45 ++----- rlinf/workers/rollout/sglang/sglang_worker.py | 1 + rlinf/workers/rollout/vllm/vllm_worker.py | 1 + 10 files changed, 174 insertions(+), 73 deletions(-) rename rlinf/runners/{math_runner.py => reasoning_runner.py} (97%) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index bc83ef381..e17486fe2 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -8,9 +8,9 @@ hydra: cluster: num_nodes: 1 - num_gpus_per_node: 4 component_placement: - actor,rollout: all + actor: 0-3 + rollout: 4-7 runner: task_type: reasoning diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml index c96445a06..6643e74bd 100644 --- a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml @@ -8,7 +8,6 @@ hydra: cluster: num_nodes: 1 - num_gpus_per_node: 8 component_placement: actor,rollout,reward: all diff --git a/examples/reasoning/main_grpo.py b/examples/reasoning/main_grpo.py index 5b9ee61e1..30073d562 100644 --- a/examples/reasoning/main_grpo.py +++ b/examples/reasoning/main_grpo.py @@ -21,7 +21,7 @@ from rlinf.config import validate_cfg from rlinf.data.datasets import create_rl_dataset from rlinf.data.tokenizers import hf_tokenizer -from rlinf.runners.math_runner import MathRunner +from rlinf.runners.reasoning_runner import ReasoningRunner from rlinf.scheduler import Cluster from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode from rlinf.utils.utils import output_redirector @@ -86,7 +86,7 @@ def main(cfg) -> None: tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model) train_ds, val_ds = create_rl_dataset(cfg, tokenizer) - runner = MathRunner( + runner = ReasoningRunner( cfg=cfg, placement=component_placement, train_dataset=train_ds, diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets.py index 75922780b..677377a68 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets.py @@ -16,12 +16,13 @@ import logging import os from dataclasses import dataclass +from io import BytesIO from typing import Any, Callable, Dict, List, Optional, Tuple, Union import pandas as pd import torch from omegaconf import DictConfig -from PIL.Image import Image +from PIL import Image from torch.utils.data import Dataset from transformers import AutoProcessor, AutoTokenizer @@ -73,6 +74,7 @@ class DatasetItem: image_data: Optional[List[Union[bytes, str]]] = None prompt_text: Optional[str] = None meta: Optional[Dict[str, Any]] = None + multi_modal_inputs: Optional[Dict[str, Any]] = None class MathDataset(Dataset): @@ -247,7 +249,7 @@ def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, Non v = dataitem.get(k, None) if v is None: continue - if isinstance(v, Image): + if isinstance(v, Image.Image): images.append(v) elif isinstance(v, dict) and "bytes" in v: images.append(v["bytes"]) @@ -268,7 +270,7 @@ def build_prompt_text(self, data_item: Dict[str, Any]) -> str: return str(q) def encode_prompt( - self, prompt_text: str, image_count: int + self, prompt_text: str, images ) -> Tuple[torch.Tensor, int, Optional[str]]: """ Return (token_ids[L], length, prompt_text_used). 
If using chat template, encode with processor. @@ -289,28 +291,47 @@ def encode_prompt( ) content: List[Dict[str, Any]] = [] - for _ in range(max(0, image_count)): + for _ in range(max(0, len(images))): content.append({"type": "image"}) content.append({"type": "text", "text": prompt_text}) messages.append({"role": "user", "content": content}) rendered = self._processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - ids = self._processor(text=[rendered], padding=True, return_tensors="pt")[ - "input_ids" - ] + + images_inputs = [] + for image in images: + image_obj = None + if isinstance(image, Image.Image): + image_obj = image.convert("RGB") + if isinstance(image, (bytes, bytearray)): + image_obj = Image.open(BytesIO(image)).convert("RGB") + images_inputs.append(image_obj) + + inputs = self._processor( + text=[rendered], images=images_inputs, padding=True, return_tensors="pt" + ) + inputs.pop("attention_mask") + inputs.pop("input_ids") + ids = self._processor( + text=[rendered], images=None, padding=True, return_tensors="pt" + )["input_ids"] if isinstance(ids, torch.Tensor): if ids.dim() == 2 and ids.size(0) == 1: ids = ids.squeeze(0) ids = ids.to(dtype=torch.long) else: ids = torch.tensor(ids, dtype=torch.long) - return ids, int(ids.numel()), rendered + + multi_modal_inputs = {} + for k, v in inputs.items(): + multi_modal_inputs[k] = v + return ids, int(ids.numel()), rendered, multi_modal_inputs else: # fallback: tokenizer only ids_list = self.tokenizer.encode(prompt_text) ids = torch.as_tensor(ids_list, dtype=torch.long) - return ids, int(ids.numel()), prompt_text + return ids, int(ids.numel()), prompt_text, {} def postprocess_dataset_item( self, item: DatasetItem, raw: Dict[str, Any] @@ -450,7 +471,9 @@ def _load_single_lazy(self, path: str, fmt: str, key: Any) -> Dict[str, Any]: def _process_raw_record(self, raw: Dict[str, Any], idx: int) -> DatasetItem: images = self.get_image_list(raw) prompt_text = self.build_prompt_text(raw) - prompt_ids, plen, rendered_text = self.encode_prompt(prompt_text, len(images)) + prompt_ids, plen, rendered_text, multi_modal_inputs = self.encode_prompt( + prompt_text, images + ) if plen > self.max_prompt_length: prompt_ids = prompt_ids[: self.max_prompt_length] @@ -470,6 +493,7 @@ def _process_raw_record(self, raw: Dict[str, Any], idx: int) -> DatasetItem: prompt_text=rendered_text or prompt_text, solution=solution_val, meta=None, + multi_modal_inputs=multi_modal_inputs, ) return self.postprocess_dataset_item(item, raw) @@ -543,7 +567,7 @@ def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, Non for v in images: if v is None: continue - if isinstance(v, Image): + if isinstance(v, Image.Image): normed.append(v) elif isinstance(v, dict) and "bytes" in v: normed.append(v["bytes"]) # raw bytes @@ -686,5 +710,8 @@ def collate_fn(data_list: List["DatasetItem"]) -> Dict[str, Any]: ], # List[Optional[List[bytes|str]]] "prompt_text": [it.prompt_text for it in data_list], # List[Optional[str]] "meta": [it.meta for it in data_list], # List[Optional[dict]] + "multi_modal_inputs": [ + it.multi_modal_inputs for it in data_list + ], # List[Optional[dict]] } return batch diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 78459e548..2c07bc9b2 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -47,16 +47,70 @@ class RolloutRequest: Attr input_ids: List of input token IDs for rollout n: Number of completions to generate for each input - idx: List of unique identifiers for the requests, 
used for tracking - input_lengths: List of lengths of the input sequences, corresponding to input_ids image_data: list of image data (bytes or URLs) for multimodal inputs answers: Optional list of answers for the requests, if available + multi_modal_inputs: list of multi-modal inputs for the requests """ n: int input_ids: List[List[int]] - answers: List[str] image_data: Union[List[List[bytes]], List[List[str]]] + answers: List[str] + multi_modal_inputs: List[Dict] + + def repeat(self) -> "RolloutRequest": + """Repeat each input in the RolloutRequest a specified number of times. + + Args: + times (int): The number of times to repeat each input. + + Returns: + RolloutRequest: A new RolloutRequest with repeated inputs. + """ + assert self.n > 0, "n must be greater than 0" + + input_ids, answers = zip( + *[ + (input_id, answer) + for input_id, answer in zip(self.input_ids, self.answers) + for _ in range(self.n) + ] + ) + return RolloutRequest( + n=self.n, + input_ids=list(input_ids), + answers=list(answers), + ) + + def split(self, num_splits: int) -> List["RolloutRequest"]: + """Split the RolloutRequest into multiple smaller requests. + + Args: + num_splits (int): The number of splits to create. + + Returns: + List[RolloutRequest]: A list of smaller RolloutRequest instances. + """ + assert num_splits > 0, "num_splits must be greater than 0" + assert len(self.input_ids) % num_splits == 0, ( + f"Input IDs length {len(self.input_ids)} is not divisible by num_splits {num_splits}" + ) + + input_ids_split_list = split_list(self.input_ids, num_splits) + answers_split_list = split_list(self.answers, num_splits) + + splitted_requests = [] + for input_ids_batch, answers_batch in zip( + input_ids_split_list, answers_split_list + ): + request = RolloutRequest( + n=self.n, + input_ids=input_ids_batch, + answers=answers_batch, + ) + splitted_requests.append(request) + + return splitted_requests def repeat(self) -> "RolloutRequest": """Repeat each input in the RolloutRequest a specified number of times. 
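A simplified stand-in (not the repository's split_list) illustrating the repeat-then-split semantics of the RolloutRequest helpers above, which duplicate each prompt group_size times and carve the flat list into per-rank chunks:

def repeat(items, n):
    return [item for item in items for _ in range(n)]


def split(items, num_splits):
    assert len(items) % num_splits == 0, "length must divide evenly"
    step = len(items) // num_splits
    return [items[i * step:(i + 1) * step] for i in range(num_splits)]


prompts = ["p0", "p1"]
print(repeat(prompts, 2))            # ['p0', 'p0', 'p1', 'p1']
print(split(repeat(prompts, 2), 2))  # [['p0', 'p0'], ['p1', 'p1']]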
@@ -115,19 +169,23 @@ def split(self, num_splits: int) -> List["RolloutRequest"]: def repeat_and_split( self, rollout_batch_size: Optional[int] = None ) -> List["RolloutRequest"]: - input_ids, answers, image_data = zip( + input_ids, answers, image_data, multi_modal_inputs = zip( *[ - (input_id, answer, image_data) - for input_id, answer, image_data in zip( - self.input_ids, self.answers, self.image_data + (input_id, answer, image_data, multi_modal_inputs) + for input_id, answer, image_data, multi_modal_inputs in zip( + self.input_ids, + self.answers, + self.image_data, + self.multi_modal_inputs, ) for _ in range(self.n) ] ) - input_ids, answers, image_data = ( + input_ids, answers, image_data, multi_modal_inputs = ( list(input_ids), list(answers), list(image_data), + list(multi_modal_inputs), ) # Split input ids based on rollout_batch_size_per_gpu @@ -143,15 +201,25 @@ def repeat_and_split( input_ids_split_list = split_list(input_ids, num_batches) answers_split_list = split_list(answers, num_batches) image_data_split_list = split_list(image_data, num_batches) - - for input_ids_batch, answers_batch, image_data_batch in zip( - input_ids_split_list, answers_split_list, image_data_split_list + multi_modal_inputs_split_list = split_list(multi_modal_inputs, num_batches) + + for ( + input_ids_batch, + answers_batch, + image_data_batch, + multi_modal_inputs_batch, + ) in zip( + input_ids_split_list, + answers_split_list, + image_data_split_list, + multi_modal_inputs_split_list, ): request = RolloutRequest( n=self.n, input_ids=input_ids_batch, answers=answers_batch, image_data=image_data_batch, + multi_modal_inputs=multi_modal_inputs_batch, ) splitted_requests.append(request) @@ -268,6 +336,7 @@ class RolloutResult: response_texts: Optional[List[str]] = None answers: Optional[List[str | dict]] = None image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None + multi_modal_inputs: Optional[List[dict]] = None # Inference # Only set when recompute_logprobs is False rollout_logprobs: Optional[List[List[float]]] = None @@ -320,6 +389,7 @@ def from_vllm_results( group_size: int, results: List[VllmRequestOutput], answers: Optional[List[str]] = None, + multi_modal_inputs: Optional[List[Dict]] = None, return_logprobs: bool = False, ) -> "RolloutResult": def get_logprobs( @@ -378,6 +448,7 @@ def get_logprobs( response_ids=response_ids, response_lengths=response_lengths, response_texts=response_texts, + multi_modal_inputs=multi_modal_inputs, is_end=is_end, ) if return_logprobs: @@ -391,6 +462,7 @@ def from_sglang_results( input_ids: List[List[int]], answers: Optional[List[List[int]]] = None, image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None, + multi_modal_inputs: Optional[List[Dict]] = None, return_logprobs: bool = False, ) -> "RolloutResult": """Create a MathRolloutResult from the given results and input IDs. 
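A sketch of how the per-sample multi_modal_inputs dicts threaded through RolloutRequest and RolloutResult here are eventually merged for a training micro-batch (the FSDP actor later concatenates each key along dim 0); the pixel_values shape is illustrative, not the real Qwen2.5-VL layout:

import torch

sample_inputs = [{"pixel_values": torch.randn(1, 3, 224, 224)} for _ in range(4)]
micro_batch = {key: torch.cat([s[key] for s in sample_inputs], dim=0)
               for key in sample_inputs[0]}
print(micro_batch["pixel_values"].shape)  # torch.Size([4, 3, 224, 224])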
@@ -418,6 +490,7 @@ def from_sglang_results( response_ids=[res["output_ids"] for res in results], answers=answers, image_data=image_data, + multi_modal_inputs=multi_modal_inputs, is_end=[ res["meta_info"]["finish_reason"]["type"] == "stop" for res in results ], @@ -584,6 +657,12 @@ def _split_single_result_by_group( if rollout_result.image_data is not None: image_data_split = split_list(rollout_result.image_data, num_groups) + multi_modal_inputs_split = None + if rollout_result.multi_modal_inputs is not None: + multi_modal_inputs_split = split_list( + rollout_result.multi_modal_inputs, num_groups + ) + prompt_texts_split = None if rollout_result.prompt_texts is not None: prompt_texts_split = split_list(rollout_result.prompt_texts, num_groups) @@ -641,6 +720,9 @@ def _split_single_result_by_group( image_data=image_data_split[i] if image_data_split is not None else None, + multi_modal_inputs=multi_modal_inputs_split[i] + if multi_modal_inputs_split is not None + else None, prompt_texts=prompt_texts_split[i] if prompt_texts_split is not None else None, @@ -761,6 +843,12 @@ def to_actor_batch( "response_lengths": response_lengths.cuda(), } + if ( + self.multi_modal_inputs is not None + and self.multi_modal_inputs[0] is not None + ): + batch["multi_modal_inputs"] = self.multi_modal_inputs + if self.advantages is not None: if isinstance(self.advantages, torch.Tensor): batch["advantages"] = self.advantages.cuda() diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/reasoning_runner.py similarity index 97% rename from rlinf/runners/math_runner.py rename to rlinf/runners/reasoning_runner.py index dca52c0c0..404154fe2 100644 --- a/rlinf/runners/math_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -44,8 +44,8 @@ logging.getLogger().setLevel(logging.INFO) -class MathRunner: - """Runner for math model training.""" +class ReasoningRunner: + """Runner for reasoning task RL training.""" def __init__( self, @@ -277,19 +277,24 @@ def _put_batch(self, batch: Dict[str, torch.Tensor]): lengths = batch["length"].tolist() answers = batch["answer"] image_data = batch["image_data"] - prompts = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)] + multi_modal_inputs = batch["multi_modal_inputs"] + prompt_ids = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)] rollout_dp_size = self.component_placement.rollout_dp_size - for input_ids, answers, image_data in zip( - split_list(prompts, rollout_dp_size, enforce_divisible_batch=False), + for input_ids, answers, image_data, multi_modal_inputs in zip( + split_list(prompt_ids, rollout_dp_size, enforce_divisible_batch=False), split_list(answers, rollout_dp_size, enforce_divisible_batch=False), split_list(image_data, rollout_dp_size, enforce_divisible_batch=False), + split_list( + multi_modal_inputs, rollout_dp_size, enforce_divisible_batch=False + ), ): request = RolloutRequest( n=self.cfg.algorithm.group_size, input_ids=input_ids, answers=answers, image_data=image_data, + multi_modal_inputs=multi_modal_inputs, ) self.dataloader_channel.put(request, async_op=True) diff --git a/rlinf/utils/placement.py b/rlinf/utils/placement.py index d6fa798c8..6ecea767f 100644 --- a/rlinf/utils/placement.py +++ b/rlinf/utils/placement.py @@ -225,7 +225,7 @@ def __init__(self, config: DictConfig, cluster: Cluster): len(self._inference_gpus) if self._inference_gpus else 0 ) self._rollout_num_gpus = len(self._rollout_gpus) - self._reward_num_gpus = len(self._reward_gpus) + self._reward_num_gpus = len(self._reward_gpus) if self._reward_gpus else 0 if 
self._is_collocated(): assert self._inference_gpus is None, ( @@ -279,22 +279,19 @@ def _generate_placements(self): self._actor_gpus[0], self._actor_gpus[-1] ) - actor_tp_size = self._config.actor.model.tensor_model_parallel_size - rollout_tp_size = self._config.rollout.tensor_parallel_size - if actor_tp_size > rollout_tp_size: - assert actor_tp_size % rollout_tp_size == 0, ( - f"Actor TP size ({actor_tp_size}) must be divisible by Rollout TP size ({rollout_tp_size})" + if self.actor_tp_size > self.rollout_tp_size: + assert self.actor_tp_size % self.rollout_tp_size == 0, ( + f"Actor TP size ({self.actor_tp_size}) must be divisible by Rollout TP size ({self.rollout_tp_size})" ) stride = ( self.actor_tp_size // self.rollout_tp_size if self.actor_tp_size > self.rollout_tp_size else 1 ) - stride = actor_tp_size // rollout_tp_size self._placements["rollout"] = PackedPlacementStrategy( self._rollout_gpus[0], self._rollout_gpus[-1], - num_accelerators_per_process=rollout_tp_size, + num_accelerators_per_process=self.rollout_tp_size, stride=stride, ) self._placements["reward"] = PackedPlacementStrategy( @@ -396,3 +393,7 @@ def rollout_tp_size(self) -> int: @property def rollout_world_size(self) -> int: return self._rollout_num_gpus + + @property + def reward_world_size(self) -> int: + return self._reward_num_gpus diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 06c28797d..2604c0bb0 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -198,7 +198,7 @@ def put_result(self, result: RolloutResult, channel: Channel): def _load_weight_and_optimizer(self, channel: Channel): # Acquire the GPUs to ensure that no one is using them before loading models # Otherwise, it may lead to OOM - with channel.gpu_lock: + with channel.device_lock: if self.cfg.actor.get("enable_offload", False): self.load_fsdp_param_and_grad(self.device) self.load_fsdp_optimizer(self.device) @@ -234,30 +234,18 @@ def run_training(self, input_channel: Channel): == 0 ) - self.gradient_accumulation = ( - self.cfg.actor.global_batch_size - // self.cfg.actor.micro_batch_size - // self._world_size - ) - training_metrics_list = [] # Global batch iterations with self.worker_timer(): for global_batch in global_batches: train_global_batch_size = global_batch["input_ids"].shape[0] - assert ( - train_global_batch_size - == self.cfg.actor.global_batch_size - // torch.distributed.get_world_size() - ) + assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, ( - f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size}" + f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size=}" ) self.gradient_accumulation = ( - self.cfg.actor.global_batch_size - // self.cfg.actor.micro_batch_size - // self._world_size + train_global_batch_size // self.cfg.actor.micro_batch_size ) # split batch into micro_batches train_micro_batches = get_iterator_k_split( @@ -269,27 +257,18 @@ def run_training(self, input_channel: Channel): metrics = {} for _, m_batch in enumerate(train_micro_batches): for k, v in m_batch.items(): - m_batch[k] = v.to(f"cuda:{int(os.environ['LOCAL_RANK'])}") + m_batch[k] = v.cuda() if isinstance(v, torch.Tensor) else v multi_modal_inputs = {} if "multi_modal_inputs" in m_batch.keys(): - if ( - "image_bound" in m_batch["multi_modal_inputs"][0] - ): # minicpm-o logic - for key in m_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = [ + for key in m_batch["multi_modal_inputs"][0].keys(): + 
multi_modal_inputs[key] = torch.cat( + [ inputs[key] for inputs in m_batch["multi_modal_inputs"] - ] - else: - for key in m_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [ - inputs[key] - for inputs in m_batch["multi_modal_inputs"] - ], - dim=0, - ) + ], + dim=0, + ).cuda() input_ids = m_batch["input_ids"] attention_mask = m_batch["attention_mask"] @@ -308,7 +287,7 @@ def run_training(self, input_channel: Channel): position_ids=position_ids, **multi_modal_inputs, use_cache=False, - ) # prevent model thinks we are generating + ) logits = output.logits diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 41f026879..764702c63 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -193,6 +193,7 @@ def rollout(self, input_channel: Channel, output_channel: Channel): request.input_ids, request.answers, request.image_data, + request.multi_modal_inputs, self._return_logprobs, ) rollout_results.append(rollout_result) diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index d5ecd11fc..54c1c7ee5 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -399,6 +399,7 @@ async def rollout_and_return( group_size=self._cfg.algorithm.group_size, results=vllm_results, answers=request.answers, + multi_modal_inputs=request.multi_modal_inputs, return_logprobs=self._return_logprobs, ) if self._placement.is_disaggregated: From a04f2d0ac8c44a85da8d6ab2b0bc35cb6ca66861 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Mon, 29 Sep 2025 03:51:36 +0000 Subject: [PATCH 12/38] fix(math): fix some bugs when running math model Signed-off-by: Bo Dai --- .../config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 33 +++++++++++-------- rlinf/algorithms/rewards/math/__init__.py | 5 +-- rlinf/runners/reasoning_runner.py | 12 +++---- rlinf/workers/actor/fsdp_actor_worker.py | 14 +------- 4 files changed, 27 insertions(+), 37 deletions(-) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index e17486fe2..c4c646808 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -9,8 +9,7 @@ hydra: cluster: num_nodes: 1 component_placement: - actor: 0-3 - rollout: 4-7 + actor,rollout,reward: all runner: task_type: reasoning @@ -33,7 +32,7 @@ runner: resume_dir: null experiment_name: grpo-1.5b - output_dir: /mnt/public/daibo/results + output_dir: ../results algorithm: group_size: 8 @@ -84,7 +83,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /mnt/public/hf_models/qwen2.5-VL-3B/ + model_dir: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ model_arch: qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -120,8 +119,8 @@ rollout: cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
data: - type: vision_language - dataset_name: robo2vlm + type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 8 @@ -130,9 +129,9 @@ data: shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/mnt/public/daibo/dataset/robo2vlm-1/data/"] - val_data_paths: ["/mnt/public/daibo/dataset/robo2vlm-1/data/"] - prompt_key: question + train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] + val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] + prompt_key: prompt image_keys: [image] answer_key: answer choice_key: choices @@ -165,7 +164,7 @@ actor: seq_length: ${runner.seq_length} encoder_seq_length: ${runner.seq_length} - model_path: /mnt/public/hf_models/qwen2.5-VL-3B/ + model_path: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ optim: optimizer: adam @@ -195,20 +194,26 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /mnt/public/hf_models/qwen2.5-VL-3B/ + tokenizer_model: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ use_fast: False trust_remote_code: True padding_side: 'right' reward: - group_name: "ActorGroup" + group_name: "RewardGroup" use_reward_model: false - reward_type: 'vqa' - # reward_scale: 5.0 + reward_type: 'math' + reward_scale: 5.0 reward_weights: qa_accuracy: 1.0 think_format: 0.0 answer_format: 0.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/rlinf/algorithms/rewards/math/__init__.py b/rlinf/algorithms/rewards/math/__init__.py index 7eb6401a8..1a67e80e1 100644 --- a/rlinf/algorithms/rewards/math/__init__.py +++ b/rlinf/algorithms/rewards/math/__init__.py @@ -21,7 +21,7 @@ class MathReward: def __init__(self, config: DictConfig): - self.scale = config.get("scale", 1.0) + self.scale = config.get("reward_scale", 1.0) def get_reward( self, response: List[str], reference: List[List[str]] @@ -37,4 +37,5 @@ def get_reward( List[float]: A list of reward scores, one for each response. 
""" - return math_verify_call(response, reference) * self.scale + rewards = math_verify_call(response, reference) + return [float(reward) * self.scale for reward in rewards] diff --git a/rlinf/runners/reasoning_runner.py b/rlinf/runners/reasoning_runner.py index 404154fe2..1d2bf64d9 100644 --- a/rlinf/runners/reasoning_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -56,21 +56,20 @@ def __init__( rollout: Union["SGLangWorker", "VLLMWorker"], inference: Optional[MegatronInference], actor: MegatronActor, - reward: Optional[RewardWorker] = None, + reward: RewardWorker, ): """""" self.cfg = cfg self.component_placement = placement self.is_pipeline = self.component_placement.is_disaggregated self.has_dedicated_inference = inference is not None - self.has_dedicated_reward = reward is not None # Workers self.rollout = rollout self.actor = actor # Collocated mode uses actor as inference self.inference = inference if self.has_dedicated_inference else self.actor - self.reward = reward if self.has_dedicated_reward else self.actor + self.reward = reward # Data channels self.dataloader_channel = Channel.create("DataLoader") @@ -80,9 +79,7 @@ def __init__( self.inference_channel = Channel.create( "Inference", local=not self.has_dedicated_inference ) - self.reward_channel = Channel.create( - "Reward", local=not self.has_dedicated_reward - ) + self.reward_channel = Channel.create("Reward") self.actor_channel = Channel.create("Actor", local=True) # Configurations @@ -180,8 +177,7 @@ def init_workers(self): self.actor.init_worker().wait() if self.has_dedicated_inference: self.inference.init_worker().wait() - if self.has_dedicated_reward: - self.reward.init_worker().wait() + self.reward.init_worker().wait() if self.cfg.runner.resume_dir is None: return diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 2604c0bb0..a26c03c16 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -147,19 +147,7 @@ def sync_model_to_rollout(self): self.rollou_state_dict = self.get_model_state_dict() if self._weight_dst_rank_in_rollout is not None: - - def transform_key(k): - if k.startswith("model.language_model."): - return "model." 
+ k[21:] - elif k.startswith("model."): - return k[6:] - else: - return k - - handle = { - transform_key(k): reduce_tensor(v) - for k, v in self.rollou_state_dict.items() - } + handle = {k: reduce_tensor(v) for k, v in self.rollou_state_dict.items()} self.send( handle, self._rollout_group_name, self._weight_dst_rank_in_rollout From a7df8fc21fb9c2422da06f5a5ffc4ad5be51376b Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Mon, 29 Sep 2025 07:59:01 +0000 Subject: [PATCH 13/38] fix(math): fix some merge_batch when item is not tensor,add support for special prefix Signed-off-by: Bo Dai --- rlinf/workers/actor/fsdp_actor_worker.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index a26c03c16..ccf3bc68e 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -138,19 +138,30 @@ def _setup_rollout_weight_dst_ranks(self): def del_reshard_state_dict(self): if hasattr(self, "rollou_state_dict"): - del self.rollou_state_dict + del self.rollout_state_dict def sync_model_to_rollout(self): if next(self.model.parameters()).is_cpu: self.load_fsdp_param_and_grad(self.device) - self.rollou_state_dict = self.get_model_state_dict() + self.rollout_state_dict = self.get_model_state_dict() + + has_visual = any("visual." in k for k in self.rollout_state_dict.keys()) + + state_dict = {} if self._weight_dst_rank_in_rollout is not None: - handle = {k: reduce_tensor(v) for k, v in self.rollou_state_dict.items()} + for k, v in self.rollout_state_dict.items(): + name = k + if has_visual: + if name.startswith("model.language_model."): + name = "model." + name[21:] + elif name.startswith("model."): + name = name[6:] + state_dict[name] = reduce_tensor(v) self.send( - handle, self._rollout_group_name, self._weight_dst_rank_in_rollout + state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout ) if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() From 14cbdf0bb387848ebc334d673631bb5c50390963 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Mon, 29 Sep 2025 08:46:48 +0000 Subject: [PATCH 14/38] chore: add corresponding changes to yaml because of RewardModel and other configurations Signed-off-by: Bo Dai --- .../qwen2.5-1.5b-grpo-megatron-pipeline.yaml | 11 +++++++++- .../math/qwen2.5-1.5b-grpo-megatron.yaml | 20 ++++++++++++++++--- .../config/math/qwen2.5-1.5b-single-gpu.yaml | 11 ++++++++-- .../math/qwen2.5-32b-grpo-megatron.yaml | 10 ++++++++-- .../config/math/qwen2.5-7b-grpo-megatron.yaml | 11 ++++++++-- rlinf/data/io_struct.py | 16 ++++++++------- 6 files changed, 62 insertions(+), 17 deletions(-) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml index 9ba641d15..0815011f3 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml @@ -12,9 +12,10 @@ cluster: rollout: 0-15 inference: 16-23 actor: 24-63 + reward: 0-15 runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -134,6 +135,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 512 @@ -146,6 +148,7 @@ data: train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] val_data_paths: 
["/dataset/boba/AReaL-boba-106k.jsonl"] + actor: group_name: "ActorGroup" training_backend: megatron @@ -271,9 +274,15 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} critic: use_critic_model: false diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml index 6f75dc38e..28a5ca960 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 16 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -121,17 +121,24 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 512 val_rollout_batch_size: null num_workers: 2 - prompt_key: prompt shuffle: True validation_shuffle: True seed: 1234 train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] + prompt_key: prompt + image_keys: [image] + answer_key: answer + choice_key: choices + solution_key: null + use_chat_template: True + lazy_loading: True actor: group_name: "ActorGroup" @@ -258,9 +265,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml index e3f5bf28d..d654c6522 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: 0 + actor,rollout,reward: 0 runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -258,9 +258,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml index 6e397dfda..f7fb2e16c 100644 --- a/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 32 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -259,5 +259,11 @@ reward: reward_type: 
'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml index 63146687e..5f33d9cb2 100644 --- a/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 16 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -257,9 +257,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 2c07bc9b2..ea78685b5 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -894,14 +894,16 @@ def merge_batches( return merged_batch if len(batches) == 1: return batches[0] + for key in batches[0].keys(): - assert torch.is_tensor(batches[0][key]), ( - f"Expected tensor for key {key} in batches, got {type(batches[0][key])}" - ) - assert torch.is_tensor(batches[0][key]), ( - f"Expected tensor for key {key} in batches, got {type(batches[0][key])}" - ) - merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0) + if torch.is_tensor(batches[0][key]): + merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0) + elif isinstance(batches[0][key], list): + merged_batch[key] = [] + for batch in batches: + merged_batch[key].extend(batch[key]) + else: + raise ValueError(f"Unsupported batch key type: {type(batches[0][key])}") return merged_batch From 124598874d4eb11bd163ee154a92a41695ca1aea Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Tue, 30 Sep 2025 04:05:18 +0000 Subject: [PATCH 15/38] fix(megatron): apply corresponding changes due to fsdp Signed-off-by: Bo Dai --- .../qwen2.5-1.5b-grpo-megatron-pipeline.yaml | 2 + .../math/qwen2.5-1.5b-grpo-megatron.yaml | 30 +++++++-------- .../config/math/qwen2.5-1.5b-single-gpu.yaml | 2 + .../math/qwen2.5-32b-grpo-megatron.yaml | 2 + .../config/math/qwen2.5-7b-grpo-megatron.yaml | 2 + rlinf/algorithms/losses.py | 12 +++--- rlinf/runners/reasoning_runner.py | 29 ++++++--------- rlinf/workers/actor/megatron_actor_worker.py | 24 +++++++++--- rlinf/workers/reward/reward_worker.py | 3 +- rlinf/workers/rollout/sglang/sglang_worker.py | 37 +++++++++---------- 10 files changed, 75 insertions(+), 68 deletions(-) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml index 0815011f3..ea4606003 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml @@ -66,6 +66,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: True clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + 
clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: False diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml index 28a5ca960..63e972b3a 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml @@ -7,7 +7,7 @@ hydra: output_subdir: null cluster: - num_nodes: 16 + num_nodes: 1 component_placement: actor,rollout,reward: all @@ -25,14 +25,14 @@ runner: val_check_interval: 1 save_interval: 50 - seq_length: 28672 + seq_length: 10240 enable_dynamic_batch_size: False max_tokens_per_mbs: 28672 resume_dir: null - experiment_name: grpo-1.5b - output_dir: ../results + experiment_name: megatron-vllm-1.5b-math-test + output_dir: /mnt/public/daibo/results algorithm: group_size: 16 @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True @@ -82,7 +84,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ + model_dir: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B/ model_arch: qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -91,7 +93,7 @@ rollout: padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. - rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] sglang: attention_backend: triton # [flashinfer, triton] for more, see sglang's doc @@ -124,21 +126,15 @@ data: dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 512 + rollout_batch_size: 8 val_rollout_batch_size: null num_workers: 2 + prompt_key: prompt shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] - val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] - prompt_key: prompt - image_keys: [image] - answer_key: answer - choice_key: choices - solution_key: null - use_chat_template: True - lazy_loading: True + train_data_paths: ["/mnt/public/daibo/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/mnt/public/daibo/dataset/boba_106k_0319_prompt_1024.jsonl"] actor: group_name: "ActorGroup" @@ -220,7 +216,7 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ + tokenizer_model: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B/ use_fast: False trust_remote_code: True padding_side: 'right' diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml index d654c6522..1829050c1 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo 
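  # A minimal example of setting the asymmetric clipping knobs introduced above (the
  # values below are hypothetical, not taken from this patch series). Leaving either key
  # null falls back to ratio_clip_eps for that side, as the actor worker change later in
  # this series shows; the pair is commonly applied as
  # clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high):
  #   ratio_clip_eps: 0.2
  #   clip_ratio_low: 0.2
  #   clip_ratio_high: 0.28   # "clip-higher" style, looser upper bound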
normalize_advantages: True diff --git a/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml index f7fb2e16c..e9eb2089e 100644 --- a/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True diff --git a/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml index 5f33d9cb2..b2a70d6f8 100644 --- a/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True diff --git a/rlinf/algorithms/losses.py b/rlinf/algorithms/losses.py index 798e5330f..1d66885ea 100644 --- a/rlinf/algorithms/losses.py +++ b/rlinf/algorithms/losses.py @@ -240,12 +240,12 @@ def compute_math_ppo_actor_loss(**kwargs): # Compile metrics for logging metrics_data = { - "policy_loss": masked_mean(policy_loss.detach(), loss_mask).cpu(), - "ratio": masked_mean(ratio.detach(), loss_mask).cpu(), - "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask).cpu(), - "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask).cpu(), - "approx_kl": approx_kl.detach().cpu(), - "clip_fraction": clip_fraction.detach().cpu(), + "policy_loss": masked_mean(policy_loss.detach(), loss_mask), + "ratio": masked_mean(ratio.detach(), loss_mask), + "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask), + "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask), + "approx_kl": approx_kl.detach(), + "clip_fraction": clip_fraction.detach(), } return policy_loss, metrics_data diff --git a/rlinf/runners/reasoning_runner.py b/rlinf/runners/reasoning_runner.py index 1d2bf64d9..d68abf72a 100644 --- a/rlinf/runners/reasoning_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -76,9 +76,7 @@ def __init__( self.rollout_channel = Channel.create("Rollout") # Create a local channel (i.e., a channel that is different in every process) # if inference is not a dedicated worker - self.inference_channel = Channel.create( - "Inference", local=not self.has_dedicated_inference - ) + self.inference_channel = Channel.create("Inference") self.reward_channel = Channel.create("Reward") self.actor_channel = Channel.create("Actor", local=True) @@ -332,38 +330,33 @@ def run(self): output_channel=self.rollout_channel, ) + # Rewards + reward_handle: Handle = self.reward.compute_rewards( + input_channel=self.rollout_channel, + output_channel=self.reward_channel, + ) + if self.recompute_logprobs: # Inference prev/ref logprobs infer_handle: Handle = self.inference.run_inference( - input_channel=self.rollout_channel, + input_channel=self.reward_channel, output_channel=self.inference_channel, compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel else: infer_handle = None - inference_channel = self.rollout_channel - - # Rewards - reward_handle: Handle = 
self.reward.compute_rewards( - input_channel=inference_channel, - output_channel=self.reward_channel, - ) + inference_channel = self.reward_channel # Advantages and returns adv_handle: Handle = self.actor.compute_advantages_and_returns( - input_channel=self.reward_channel, + input_channel=inference_channel, output_channel=self.actor_channel, ) # Actor training - actor_input_channel = self.actor_channel - if self.is_pipeline: - # In pipeline mode, the rollout already contains the advantages and returns - # So the above two steps are in fact no-ops, and we should directly use the inference channel as the input - actor_input_channel = inference_channel actor_handle: Handle = self.actor.run_training( - input_channel=actor_input_channel, + input_channel=self.actor_channel, ) metrics = actor_handle.wait() diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index 54376e1a5..57c54dbb9 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -115,10 +115,21 @@ def __init__( self.calculate_entropy_loss = ( self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy ) - self.ratio_eps = self.cfg.algorithm.ratio_clip_eps + clip_ratio = self.cfg.algorithm.ratio_clip_eps + self.clip_ratio_low = ( + self.cfg.algorithm.get("clip_ratio_low") + if self.cfg.algorithm.get("clip_ratio_low") is not None + else clip_ratio + ) + self.clip_ratio_high = ( + self.cfg.algorithm.get("clip_ratio_high") + if self.cfg.algorithm.get("clip_ratio_high") is not None + else clip_ratio + ) self.logprob_forward_micro_batch_size = ( self.cfg.algorithm.logprob_forward_micro_batch_size ) + self.kl_beta = self.cfg.algorithm.kl_beta self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type self.clip_ratio_c = self.cfg.algorithm.clip_ratio_c @@ -382,7 +393,8 @@ def loss_func(output): logprobs=curr_logprobs, old_logprobs=prev_logprobs, advantages=advantages, - eps_clip=self.ratio_eps, + clip_ratio_low=self.clip_ratio_low, + clip_ratio_high=self.clip_ratio_high, loss_mask=mask, ) @@ -843,7 +855,6 @@ def run_inference( while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence - # Must be called after batch is retrieved, suggesting that rollout has stopped # Otherwise, loading model might cause OOM in the collocated mode self._load_weight_and_optimizer(input_channel) @@ -859,7 +870,6 @@ def run_inference( with cpu_weight_swap(self.model[0], self.ref_policy_state_dict): ref_logprobs = self.inference_step(batch) rollout_result.ref_logprobs = ref_logprobs.cpu() - self.put_result(rollout_result, output_channel) assert recv_batch_size == self.total_batch_size_per_dp, ( @@ -963,7 +973,6 @@ def compute_advantages_and_returns( while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence - with self.worker_timer(): if rollout_result.advantages is None: mask = batch["attention_mask"][:, -self.response_len :] @@ -1033,7 +1042,10 @@ def sync_model_to_rollout(self): def _compute_rollout_metrics(self, batch): rollout_metrics, total_prompt_lengths, total_decode_lengths = ( compute_rollout_metrics( - batch, self.cfg.data.max_prompt_length, self.response_len + batch, + self.cfg.data.max_prompt_length, + self.response_len, + self._world_size, ) ) diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index 
9290c23d6..fefd422d6 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -49,7 +49,6 @@ def get_batch( self, channel: Channel ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: result: RolloutResult = channel.get() - batch = result.to_actor_batch( self.cfg.data.max_prompt_length, self.cfg.actor.model.encoder_seq_length, @@ -69,8 +68,8 @@ def compute_rewards(self, input_channel: Channel, output_channel: Channel): recv_batch_size = 0 while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) - recv_batch_size += rollout_result.num_sequence + recv_batch_size += rollout_result.num_sequence # Compute rule-based reward if rollout_result.rewards is None: rollout_result.rewards = self._compute_batch_rewards( diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 764702c63..a339edd9e 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -323,7 +323,6 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): total_reqs = len(rollout_tasks) required_reqs = total_reqs // self._cfg.algorithm.max_num_gen_batches - droped_reqs = 0 finished_reqs = 0 abort_flag = False @@ -334,20 +333,20 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): if self._completion_info.is_completed(hash_id): results = self._completion_info.get_results(hash_id) - ( - rewards, - advantages, - ) = await self._compute_reward_and_advantage( - results, - self._current_request.answers[raw_id], - ) - if ( - all_floats_equal(rewards) - and self._cfg.algorithm.get("max_num_gen_batches", 1) > 1 - ): - if (total_reqs - droped_reqs) > required_reqs: - droped_reqs += rollout_request.n - continue + # ( + # rewards, + # advantages, + # ) = await self._compute_reward_and_advantage( + # results, + # self._current_request.answers[raw_id], + # ) + # if ( + # all_floats_equal(rewards) + # and self._cfg.algorithm.get("max_num_gen_batches", 1) > 1 + # ): + # if (total_reqs - droped_reqs) > required_reqs: + # droped_reqs += rollout_request.n + # continue input_ids = [input_ids] * len(results) rollout_result = RolloutResult.from_sglang_results( @@ -356,10 +355,10 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): input_ids, return_logprobs=self._return_logprobs, ) - rollout_result.rewards = torch.tensor( - rewards, dtype=torch.float32 - ).reshape(-1, 1) - rollout_result.advantages = advantages + # rollout_result.rewards = torch.tensor( + # rewards, dtype=torch.float32 + # ).reshape(-1, 1) + # rollout_result.advantages = advantages return_tasks.append( asyncio.create_task( self._put_result(rollout_result, output_channel) From 4bf2d812ef4c864ede6ce6150782d56c06834059 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Tue, 30 Sep 2025 10:21:46 +0000 Subject: [PATCH 16/38] fix(reward): change math_verify_call's result from {0,1} to {-1,1} Signed-off-by: Bo Dai --- rlinf/data/io_struct.py | 2 +- rlinf/workers/actor/megatron_actor_worker.py | 81 +------------------- rlinf/workers/rollout/vllm/vllm_worker.py | 2 +- toolkits/math_verifier/verify.py | 12 ++- 4 files changed, 8 insertions(+), 89 deletions(-) diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index ea78685b5..b69b680bd 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -55,7 +55,7 @@ class RolloutRequest: n: int input_ids: List[List[int]] image_data: Union[List[List[bytes]], List[List[str]]] - 
answers: List[str] + answers: List[str] multi_modal_inputs: List[Dict] def repeat(self) -> "RolloutRequest": diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index 57c54dbb9..a08415ea9 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -14,7 +14,7 @@ import copy from functools import partial -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.distributed @@ -53,7 +53,6 @@ ) from rlinf.utils.distributed import ( RolloutDataBalance, - broadcast_tensor_within_mp, broadcast_tensor_within_pp, compute_rollout_metrics, masked_normalization, @@ -876,84 +875,6 @@ def run_inference( f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" ) - # Rewards - def compute_rewards(self, input_channel: Channel, output_channel: Channel): - """Compute rewards. - - Args: - input_channel: The input channel to read from. - output_channel: The output channel to send results to. - """ - assert self.reward_fn is not None, "reward_fn is not set" - if self.is_pipeline: - # In pipeline mode, rewards are computed in the rollout - with self.worker_timer(): - return - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - recv_batch_size += rollout_result.num_sequence - - # Compute rule-based reward - with self.worker_timer(): - if rollout_result.rewards is None: - rollout_result.rewards = self._compute_batch_rewards( - batch, rollout_result.answers - ).cpu() - - self.put_result(rollout_result, output_channel) - - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) - - def _compute_batch_rewards( - self, batch: Dict[str, torch.Tensor], answers: List[str] - ): - """Reward computation using non-model based reward.""" - all_reward_scores = [] - texts = [] - for response, response_len in zip( - batch["input_ids"], - batch["response_lengths"], - ): - response = response[ - self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length - + response_len - ] - texts.append( - self.tokenizer.decode(response.tolist(), skip_special_tokens=True) - ) - - if torch.distributed.get_rank() == parallel_state.get_model_parallel_src_rank(): - rewards = self.reward_fn(texts, answers) - if self.cfg.reward.reward_type == "math": - reward_scores = [ - self.cfg.reward.reward_scale - if reward == 1 - else -self.cfg.reward.reward_scale - for reward in rewards - ] - else: - reward_scores = rewards - - all_reward_scores.extend(reward_scores) - - if len(all_reward_scores) > 0: - new_all_rewards = [] - - for response in all_reward_scores: - if response is None: - response = 0.0 - new_all_rewards.append(response) - - all_reward_scores = torch.as_tensor( - new_all_rewards, - dtype=torch.float, - device=torch.cuda.current_device(), - ).view(-1, 1) - return broadcast_tensor_within_mp(all_reward_scores).flatten().to("cpu") - # Advantages and returns def compute_advantages_and_returns( self, input_channel: Channel, output_channel: Channel diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index 54c1c7ee5..899e07f1e 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -399,7 +399,7 @@ async def rollout_and_return( group_size=self._cfg.algorithm.group_size, 
results=vllm_results, answers=request.answers, - multi_modal_inputs=request.multi_modal_inputs, + multi_modal_inputs=request.multi_modal_inputs, return_logprobs=self._return_logprobs, ) if self._placement.is_disaggregated: diff --git a/toolkits/math_verifier/verify.py b/toolkits/math_verifier/verify.py index 80bf4b552..31d92c280 100644 --- a/toolkits/math_verifier/verify.py +++ b/toolkits/math_verifier/verify.py @@ -348,22 +348,22 @@ def process_results(answer, solution): extracted_solution = extract_answer(solution, "math", use_last_number=True) if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]: - retval = 0 + retval = -1 elif extracted_solution is None or extracted_solution.strip() in [ "None", "none", "", ]: - retval = 0 + retval = -1 elif math_equal(extracted_answer, extracted_solution, timeout=False): # elif call_with_timeout(math_equal, extracted_answer, extracted_solution): retval = 1 else: - retval = 0 + retval = -1 return retval, (extracted_answer, extracted_solution) except Exception: - return 0, ("None", "None") + return -1, ("None", "None") def process_results_process(a, b, output_queue): @@ -406,14 +406,12 @@ def math_verify_call( labels = [] has_timeout = False for jobs in all_jobs: - label = 0 try: for job in as_completed(jobs, timeout=timeout): x = job.result() - label = label or x + labels.append(x) except TimeoutError: has_timeout = True - labels.append(label) if has_timeout: reset_global_process_pool() From fc77e2aac77324d5720643cdba6c04d9bb86a005 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Tue, 30 Sep 2025 10:46:13 +0000 Subject: [PATCH 17/38] feat(ci): change corresponding ci config for refactored code Signed-off-by: Bo Dai --- rlinf/data/io_struct.py | 54 ------------------- rlinf/hybrid_engines/fsdp/utils.py | 1 - .../auto_placement/qwen2.5-1.5b-grpo.yaml | 12 ++++- ...1.5b-grpo-collocated-rollout-logprobs.yaml | 12 ++++- .../sglang/qwen2.5-1.5b-grpo-collocated.yaml | 12 ++++- ...5-1.5b-grpo-pipeline-rollout-logprobs.yaml | 12 ++++- .../sglang/qwen2.5-1.5b-grpo-pipeline.yaml | 11 +++- tests/e2e_tests/math/sglang/run_collocated.sh | 2 +- tests/e2e_tests/math/sglang/run_pipeline.sh | 2 +- ...1.5b-grpo-collocated-rollout-logprobs.yaml | 12 ++++- .../vllm/qwen2.5-1.5b-grpo-collocated.yaml | 12 ++++- tests/e2e_tests/math/vllm/run_collocated.sh | 2 +- tests/unit_tests/test_placement.py | 28 ---------- 13 files changed, 73 insertions(+), 99 deletions(-) diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index b69b680bd..b7dc359b5 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -58,60 +58,6 @@ class RolloutRequest: answers: List[str] multi_modal_inputs: List[Dict] - def repeat(self) -> "RolloutRequest": - """Repeat each input in the RolloutRequest a specified number of times. - - Args: - times (int): The number of times to repeat each input. - - Returns: - RolloutRequest: A new RolloutRequest with repeated inputs. - """ - assert self.n > 0, "n must be greater than 0" - - input_ids, answers = zip( - *[ - (input_id, answer) - for input_id, answer in zip(self.input_ids, self.answers) - for _ in range(self.n) - ] - ) - return RolloutRequest( - n=self.n, - input_ids=list(input_ids), - answers=list(answers), - ) - - def split(self, num_splits: int) -> List["RolloutRequest"]: - """Split the RolloutRequest into multiple smaller requests. - - Args: - num_splits (int): The number of splits to create. - - Returns: - List[RolloutRequest]: A list of smaller RolloutRequest instances. 
- """ - assert num_splits > 0, "num_splits must be greater than 0" - assert len(self.input_ids) % num_splits == 0, ( - f"Input IDs length {len(self.input_ids)} is not divisible by num_splits {num_splits}" - ) - - input_ids_split_list = split_list(self.input_ids, num_splits) - answers_split_list = split_list(self.answers, num_splits) - - splitted_requests = [] - for input_ids_batch, answers_batch in zip( - input_ids_split_list, answers_split_list - ): - request = RolloutRequest( - n=self.n, - input_ids=input_ids_batch, - answers=answers_batch, - ) - splitted_requests.append(request) - - return splitted_requests - def repeat(self) -> "RolloutRequest": """Repeat each input in the RolloutRequest a specified number of times. diff --git a/rlinf/hybrid_engines/fsdp/utils.py b/rlinf/hybrid_engines/fsdp/utils.py index 2461f7006..0a9f0054d 100644 --- a/rlinf/hybrid_engines/fsdp/utils.py +++ b/rlinf/hybrid_engines/fsdp/utils.py @@ -97,7 +97,6 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False, is_vla_model=False) # Add vision transformer policies for VLA models if is_vla_model: - from prismatic.extern.hf.modeling_prismatic import PrismaticProjector from timm.models.vision_transformer import VisionTransformer from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy diff --git a/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml b/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml index 6555cb9bf..d1c65161b 100644 --- a/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml +++ b/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -119,6 +119,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 128 @@ -256,13 +257,20 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} critic: use_critic_model: false + profile_data: actor_cost: 95.7 inference_cost: 30.8 diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml index 7e2c5164a..3516fe44b 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml +++ b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -121,6 +121,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -259,9 +260,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git 
a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml index 1dfe47aeb..79b5e1595 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml +++ b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -126,6 +126,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -264,9 +265,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml index 25344d0bf..34bcff492 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml +++ b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml @@ -11,9 +11,10 @@ cluster: component_placement: rollout: 0-3 actor: 4-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -124,6 +125,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 rollout_batch_size: 8 val_rollout_batch_size: null @@ -257,11 +259,17 @@ actor: schedule_repeat: 1 # inference and training will repeat such times # schedule_wait: it will be set at runtime - reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml index d5cd2b4a4..b48eb6057 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml +++ b/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml @@ -12,9 +12,10 @@ cluster: rollout: 0-3 inference: 4-5 actor: 6-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -138,6 +139,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 rollout_batch_size: 8 val_rollout_batch_size: null @@ -273,9 +275,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/run_collocated.sh b/tests/e2e_tests/math/sglang/run_collocated.sh index 1610f8fac..5911653e7 100644 --- a/tests/e2e_tests/math/sglang/run_collocated.sh +++ 
b/tests/e2e_tests/math/sglang/run_collocated.sh @@ -14,4 +14,4 @@ else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/run_pipeline.sh b/tests/e2e_tests/math/sglang/run_pipeline.sh index 85e2e5c2d..f18012bb4 100644 --- a/tests/e2e_tests/math/sglang/run_pipeline.sh +++ b/tests/e2e_tests/math/sglang/run_pipeline.sh @@ -14,4 +14,4 @@ else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml index fe61ab16c..edeaee9c3 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml +++ b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -121,6 +121,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -259,9 +260,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml index 099fe7268..09df84ca8 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml +++ b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -122,6 +122,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -260,9 +261,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/run_collocated.sh b/tests/e2e_tests/math/vllm/run_collocated.sh index b4e924b1d..6ce4067fd 100644 --- a/tests/e2e_tests/math/vllm/run_collocated.sh +++ b/tests/e2e_tests/math/vllm/run_collocated.sh @@ -14,4 +14,4 @@ else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/math/main_math.py --config-path 
$REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/unit_tests/test_placement.py b/tests/unit_tests/test_placement.py index c16ff7fb9..22c75ab68 100644 --- a/tests/unit_tests/test_placement.py +++ b/tests/unit_tests/test_placement.py @@ -1087,34 +1087,6 @@ def test_model_parallel_component_placement_init_missing_rollout_gpus(self): cluster = mock_cluster(num_nodes=1, num_accelerators_per_node=4) ModelParallelComponentPlacement(config, cluster) - def test_model_parallel_component_placement_init_collocated_mode_invalid_tp_sizes( - self, - ): - """Test ModelParallelComponentPlacement raises error when actor TP size < rollout TP size in collocated mode.""" - config = DictConfig( - { - "cluster": { - "num_nodes": 1, - "component_placement": {"actor,rollout": "0-3"}, - }, - "actor": { - "model": { - "tensor_model_parallel_size": 2, - "context_parallel_size": 1, - "pipeline_model_parallel_size": 1, - } - }, - "rollout": {"tensor_parallel_size": 4, "pipeline_parallel_size": 1}, - } - ) - - with pytest.raises( - AssertionError, - match="Actor TP size 2 must be greater or equal to Rollout TP size 4", - ): - cluster = mock_cluster(num_nodes=1, num_accelerators_per_node=4) - ModelParallelComponentPlacement(config, cluster) - def test_model_parallel_component_placement_init_collocated_mode_with_inference_gpus( self, ): From 9d40cb40186189dec4e2c4cb0051b16d0d045ab2 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Thu, 2 Oct 2025 09:03:44 +0000 Subject: [PATCH 18/38] chore: refactor dataset parts Signed-off-by: Bo Dai --- rlinf/data/datasets/__init__.py | 150 ++++++++++ rlinf/data/datasets/item.py | 59 ++++ rlinf/data/datasets/math.py | 153 ++++++++++ rlinf/data/datasets/utils.py | 68 +++++ rlinf/data/{datasets.py => datasets/vlm.py} | 275 ++---------------- rlinf/data/io_struct.py | 29 +- rlinf/workers/actor/fsdp_actor_worker.py | 2 +- rlinf/workers/reward/reward_worker.py | 4 +- rlinf/workers/rollout/sglang/sglang_worker.py | 21 +- rlinf/workers/rollout/vllm/vllm_worker.py | 31 +- tests/unit_tests/test_io_struct.py | 135 +++++++++ 11 files changed, 613 insertions(+), 314 deletions(-) create mode 100644 rlinf/data/datasets/__init__.py create mode 100644 rlinf/data/datasets/item.py create mode 100644 rlinf/data/datasets/math.py create mode 100644 rlinf/data/datasets/utils.py rename rlinf/data/{datasets.py => datasets/vlm.py} (68%) create mode 100644 tests/unit_tests/test_io_struct.py diff --git a/rlinf/data/datasets/__init__.py b/rlinf/data/datasets/__init__.py new file mode 100644 index 000000000..f86f3576e --- /dev/null +++ b/rlinf/data/datasets/__init__.py @@ -0,0 +1,150 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, List, Tuple + +import torch +from omegaconf import DictConfig +from torch.utils.data import Dataset +from transformers import AutoTokenizer + +from rlinf.data.datasets.item import DatasetItem +from rlinf.data.datasets.math import MathDataset +from rlinf.data.datasets.vlm import VLMDatasetRegistry + + +def create_rl_dataset( + config: DictConfig, tokenizer: AutoTokenizer +) -> Tuple[Dataset, Dataset]: + """Create rl datasets. + + Arguments: + config: The RLinf config. + tokenizer (Tokenizer): The tokenizer. + + Returns: + train_dataset (Dataset): The training dataset. + + val_dataset (Dataset): The validation dataset. + """ + + if config.data.type == "math": + dataset_cls = MathDataset + elif config.data.type == "vision_language": + # Prefer new factory-based VLM datasets; fallback to legacy if requested + dataset_name = getattr(config.data, "dataset_name", None) + lazy_loading = bool(getattr(config.data, "lazy_loading", False)) + + logging.info( + f"Using VLM dataset: name={dataset_name}, lazy_loading={lazy_loading}" + ) + + train_dataset = VLMDatasetRegistry.create( + dataset_name, + data_paths=config.data.train_data_paths, + config=config, + tokenizer=tokenizer, + ) + val_dataset = VLMDatasetRegistry.create( + dataset_name, + data_paths=config.data.val_data_paths, + config=config, + tokenizer=tokenizer, + ) + return train_dataset, val_dataset + else: + return None, None + + logging.info(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + train_dataset = dataset_cls( + data_paths=config.data.train_data_paths, + config=config, + tokenizer=tokenizer, + ) + + val_dataset = dataset_cls( + data_paths=config.data.val_data_paths, + config=config, + tokenizer=tokenizer, + ) + + return train_dataset, val_dataset + + +def collate_fn(data_list: List["DatasetItem"]) -> Dict[str, Any]: + """ + Collate function for batching dataset items. 
+ """ + prompts = [] + lens = [] + for it in data_list: + p = ( + it.prompt + if isinstance(it.prompt, torch.Tensor) + else torch.as_tensor(it.prompt, dtype=torch.long) + ) + if p.dim() == 2 and p.size(0) == 1: + p = p.squeeze(0) + assert p.dim() == 1, ( + f"DatasetItem.prompt must be 1-D tensor, current shape is: {p.shape}" + ) + prompts.append(p) + lens.append(p.numel()) + + if len(set(lens)) == 1: + target_len = lens[0] + else: + target_len = min(lens) + prompts = [p[-target_len:] if p.numel() > target_len else p for p in prompts] + + batch_prompt = torch.stack(prompts, dim=0) # [B, L] + batch_length = torch.tensor( + [min(int(it.length), target_len) for it in data_list], dtype=torch.long + ) + + batch_idx = torch.tensor([int(it.idx) for it in data_list], dtype=torch.long) + + batch: Dict[str, Any] = { + "prompt": batch_prompt, # [B, L] + "length": batch_length, # [B] + "answer": [it.answer for it in data_list], # List[str] + "idx": batch_idx, # [B] + "solution": [it.solution for it in data_list], # List[Optional[str]] + "image_data": [ + it.image_data for it in data_list + ], # List[Optional[List[bytes|str]]] + "prompt_text": [it.prompt_text for it in data_list], # List[Optional[str]] + "meta": [it.meta for it in data_list], # List[Optional[dict]] + "multi_modal_inputs": [ + it.multi_modal_inputs for it in data_list + ], # List[Optional[dict]] + } + return batch diff --git a/rlinf/data/datasets/item.py b/rlinf/data/datasets/item.py new file mode 100644 index 000000000..e75155dcb --- /dev/null +++ b/rlinf/data/datasets/item.py @@ -0,0 +1,59 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + + +@dataclass +class DatasetItem: + """ + A single item in processed dataset. + + Attributes: + prompt (torch.Tensor): Tokenized prompt input_ids tensor. + length (int): Length of the prompt input_ids. + answer (str | dict): The answer associated with the prompt. + idx (int): Index of the item in the dataset. + solution (Optional[str]): Optional solution text if exists. + prompt_text (Optional[str]): Optional original prompt text before tokenization. + meta (Optional[Dict[str, Any]]): Optional metadata dictionary. + multi_modal_inputs (Optional[Dict[str, Any]]): Optional dictionary for additional multi-modal inputs. 
+ """ + + prompt: torch.Tensor + length: int + answer: str | dict + idx: int + solution: Optional[str] = None + image_data: Optional[List[Union[bytes, str]]] = None + prompt_text: Optional[str] = None + meta: Optional[Dict[str, Any]] = None + multi_modal_inputs: Optional[Dict[str, Any]] = None diff --git a/rlinf/data/datasets/math.py b/rlinf/data/datasets/math.py new file mode 100644 index 000000000..821074bf6 --- /dev/null +++ b/rlinf/data/datasets/math.py @@ -0,0 +1,153 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import logging +import os +from typing import Any, List, Tuple, Union + +import torch +from omegaconf import DictConfig +from torch.utils.data import Dataset +from transformers import AutoTokenizer + +from rlinf.data.datasets.item import DatasetItem +from rlinf.data.datasets.utils import batch_pad_to_fixed_len + + +class MathDataset(Dataset): + def __init__( + self, + data_paths: Union[str, List[str]], + config: DictConfig, + tokenizer: AutoTokenizer, + ): + super().__init__() + self.data_paths = data_paths + if isinstance(self.data_paths, str): + self.data_paths = [self.data_paths] + + self.max_prompt_length = config.data.max_prompt_length + self.tokenizer = tokenizer + self.prompt_key = config.data.prompt_key + + self.data = self._load_data() + if config.data.get("filter_prompt_by_length", False): + total = len(self.data) + filtered = [] + failed = 0 + + for item in self.data: + try: + _, L = self.encode(item[self.prompt_key]) + if L <= self.max_prompt_length: + filtered.append(item) + except Exception: + failed += 1 + + self.data = filtered + assert len(self.data) > 0, ( + f"No samples found within max_prompt_length={self.max_prompt_length}. " + "Please check your dataset or increase max_prompt_length." + ) + + if failed > 0: + logging.warning( + f"{failed} samples were skipped due to format issues " + f"(kept {len(self.data)} / {total})." + ) + + def _load_data(self) -> List[Any]: + """ + Load and merge data from multiple files(json or jsonl). 
+ """ + merged_data = [] + + for path in self.data_paths: + _, file_extension = os.path.splitext(path) + try: + with open(path, "r", encoding="utf-8") as file: + if file_extension == ".jsonl": + merged_data.extend([json.loads(line.strip()) for line in file]) + elif file_extension == ".json": + content = json.load(file) + if isinstance(content, list): + merged_data.extend(content) + else: + merged_data.append(content) + else: + print(f"Unsupport {file_extension}, skip: {path}") + except Exception: + raise RuntimeError("Load data error") + + return merged_data + + def __len__(self): + return len(self.data) + + def encode(self, text: str) -> Tuple[List[int], int]: + """ + Use tokenizer to encode the text and return the token ids and length. + """ + text_ids = self.tokenizer.encode(text) + return text_ids, len(text_ids) + + def __getitem__(self, idx): + """ + Return a single prompt. + """ + + prompt = self.data[idx][self.prompt_key] + + answer = self.data[idx]["solutions"] + + prompt_tokens, prompt_length = self.encode(prompt) + prompt_tokens_tensor = torch.as_tensor(prompt_tokens, dtype=torch.int64) + + if prompt_length > self.max_prompt_length: + print( + f"prompt_tokens_tensor length {prompt_length} exceeds the max_prompt_length {self.max_prompt_length}", + ) + prompt_tokens_tensor = prompt_tokens_tensor[: self.max_prompt_length] + prompt_length = self.max_prompt_length + + prompt_tokens_tensor = batch_pad_to_fixed_len( + [prompt_tokens_tensor], + self.max_prompt_length, + self.tokenizer.eos_token_id, + left_pad=True, + )[0] + output = DatasetItem( + prompt=prompt_tokens_tensor, + length=prompt_length, + answer=answer, + idx=idx, + image_data=[], + ) + return output diff --git a/rlinf/data/datasets/utils.py b/rlinf/data/datasets/utils.py new file mode 100644 index 000000000..db4dbdb58 --- /dev/null +++ b/rlinf/data/datasets/utils.py @@ -0,0 +1,68 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
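# A brief usage sketch for the padding helper defined just below, assuming a pad token
# id of 0 and the hypothetical sequences [1, 2] and [3, 4, 5]; MathDataset calls it with
# left_pad=True and the tokenizer's eos token id so prompts are left-padded to
# max_prompt_length:
#
#   seqs = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]
#   batch_pad_to_fixed_len(seqs, max_batch_len=4, pad_token=0, left_pad=True)
#   # -> tensor([[0, 0, 1, 2],
#   #            [0, 3, 4, 5]])
#   batch_pad_to_fixed_len(seqs, max_batch_len=4, pad_token=0, left_pad=False)
#   # -> tensor([[1, 2, 0, 0],
#   #            [3, 4, 5, 0]])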
+ +from typing import List + +import torch + + +def batch_pad_to_fixed_len( + batch: List[torch.Tensor], + max_batch_len: int, + pad_token: int, + left_pad: bool = False, +) -> torch.Tensor: + if left_pad: + batch_pad = torch.stack( + [ + torch.cat( + [ + torch.full( + (max_batch_len - len(seq),), pad_token, dtype=seq.dtype + ), # pad on the left + seq, + ] + ) + for seq in batch + ] + ) + else: + batch_pad = torch.stack( + [ + torch.cat( + [ + seq, + torch.full( + (max_batch_len - len(seq),), pad_token, dtype=seq.dtype + ), + ] + ) + for seq in batch + ] + ) + return batch_pad diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets/vlm.py similarity index 68% rename from rlinf/data/datasets.py rename to rlinf/data/datasets/vlm.py index 677377a68..18956dd81 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets/vlm.py @@ -12,10 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import json import logging import os -from dataclasses import dataclass from io import BytesIO from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -26,156 +40,8 @@ from torch.utils.data import Dataset from transformers import AutoProcessor, AutoTokenizer - -def batch_pad_to_fixed_len( - batch: List[torch.Tensor], - max_batch_len: int, - pad_token: int, - left_pad: bool = False, -) -> torch.Tensor: - if left_pad: - batch_pad = torch.stack( - [ - torch.cat( - [ - torch.full( - (max_batch_len - len(seq),), pad_token, dtype=seq.dtype - ), # pad on the left - seq, - ] - ) - for seq in batch - ] - ) - else: - batch_pad = torch.stack( - [ - torch.cat( - [ - seq, - torch.full( - (max_batch_len - len(seq),), pad_token, dtype=seq.dtype - ), - ] - ) - for seq in batch - ] - ) - return batch_pad - - -@dataclass -class DatasetItem: - prompt: torch.Tensor - length: int - answer: str | dict - idx: int - solution: Optional[str] = None - image_data: Optional[List[Union[bytes, str]]] = None - prompt_text: Optional[str] = None - meta: Optional[Dict[str, Any]] = None - multi_modal_inputs: Optional[Dict[str, Any]] = None - - -class MathDataset(Dataset): - def __init__(self, data_paths, config, tokenizer): - super().__init__() - self.data_paths = data_paths - if isinstance(self.data_paths, str): - self.data_paths = [self.data_paths] - - self.max_prompt_length = config.data.max_prompt_length - self.tokenizer = tokenizer - self.prompt_key = config.data.prompt_key - - self.data = self._load_data() - if config.data.get("filter_prompt_by_length", False): - total = len(self.data) - filtered = [] - failed = 0 - - for item in self.data: - try: - _, L = self.encode(item[self.prompt_key]) - if L <= self.max_prompt_length: - filtered.append(item) - except Exception: - failed += 1 - - self.data = filtered - assert len(self.data) > 0, ( - f"No samples found within max_prompt_length={self.max_prompt_length}. 
" - "Please check your dataset or increase max_prompt_length." - ) - - if failed > 0: - logging.warning( - f"{failed} samples were skipped due to format issues " - f"(kept {len(self.data)} / {total})." - ) - - def _load_data(self): - merged_data = [] - - for path in self.data_paths: - _, file_extension = os.path.splitext(path) - try: - with open(path, "r", encoding="utf-8") as file: - if file_extension == ".jsonl": - merged_data.extend([json.loads(line.strip()) for line in file]) - elif file_extension == ".json": - content = json.load(file) - if isinstance(content, list): - merged_data.extend(content) - else: - merged_data.append(content) - else: - print(f"Unsupport {file_extension}, skip: {path}") - except Exception: - raise RuntimeError("Load data error") - - return merged_data - - def __len__(self): - return len(self.data) - - def encode(self, text): - text_ids = self.tokenizer.encode(text) - return text_ids, len(text_ids) - - def __getitem__(self, idx): - """ - Return a single prompt. - """ - - prompt = self.data[idx][self.prompt_key] - - answer = self.data[idx]["solutions"] - - prompt_tokens, prompt_length = self.encode(prompt) - prompt_tokens_tensor = torch.as_tensor(prompt_tokens, dtype=torch.int64) - - if prompt_length > self.max_prompt_length: - print( - f"prompt_tokens_tensor length {prompt_length} exceeds the max_prompt_length {self.max_prompt_length}", - ) - prompt_tokens_tensor = prompt_tokens_tensor[: self.max_prompt_length] - prompt_length = self.max_prompt_length - - prompt_tokens_tensor = batch_pad_to_fixed_len( - [prompt_tokens_tensor], - self.max_prompt_length, - self.tokenizer.eos_token_id, - left_pad=True, - )[0] - output = DatasetItem( - prompt=prompt_tokens_tensor, - length=prompt_length, - answer=answer, - idx=idx, - image_data=[], - ) - return output +from rlinf.data.datasets.item import DatasetItem +from rlinf.data.datasets.utils import batch_pad_to_fixed_len class VLMBaseDataset(Dataset): @@ -542,7 +408,6 @@ def __init__( ) def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]: - # Prefer common robo2vlm fields if present, else fallback to configured keys images: List[Any] = [] if "images" in dataitem: v = dataitem.get("images") @@ -559,10 +424,8 @@ def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, Non else: images = [None] else: - # fallback to base behavior using configured image_keys return super().get_image_list(dataitem) - # Normalize each element similar to base behavior normed: List[Union[bytes, str, None]] = [] for v in images: if v is None: @@ -611,107 +474,3 @@ def postprocess_dataset_item( item.answer = answer_dict return item - - -def create_rl_dataset(config: DictConfig, tokenizer): - """Create rl datasets. - - Arguments: - config: The RLinf config. - tokenizer (Tokenizer): The tokenizer. - - Returns: - train_dataset (Dataset): The training dataset. - - val_dataset (Dataset): The validation dataset. 
- """ - - if config.data.type == "math": - dataset_cls = MathDataset - elif config.data.type == "vision_language": - # Prefer new factory-based VLM datasets; fallback to legacy if requested - dataset_name = getattr(config.data, "dataset_name", None) - lazy_loading = bool(getattr(config.data, "lazy_loading", False)) - - print(f"Using VLM dataset: name={dataset_name}, lazy_loading={lazy_loading}") - - train_dataset = VLMDatasetRegistry.create( - dataset_name, - data_paths=config.data.train_data_paths, - config=config, - tokenizer=tokenizer, - ) - val_dataset = VLMDatasetRegistry.create( - dataset_name, - data_paths=config.data.val_data_paths, - config=config, - tokenizer=tokenizer, - ) - return train_dataset, val_dataset - else: - return None, None - - print(f"Using dataset class: {dataset_cls.__name__}") - - # Instantiate the dataset using the determined dataset class - train_dataset = dataset_cls( - data_paths=config.data.train_data_paths, - config=config, - tokenizer=tokenizer, - ) - - val_dataset = dataset_cls( - data_paths=config.data.val_data_paths, - config=config, - tokenizer=tokenizer, - ) - - return train_dataset, val_dataset - - -def collate_fn(data_list: List["DatasetItem"]) -> Dict[str, Any]: - prompts = [] - lens = [] - for it in data_list: - p = ( - it.prompt - if isinstance(it.prompt, torch.Tensor) - else torch.as_tensor(it.prompt, dtype=torch.long) - ) - if p.dim() == 2 and p.size(0) == 1: - p = p.squeeze(0) - assert p.dim() == 1, ( - f"DatasetItem.prompt must be 1-D tensor, current shape is: {p.shape}" - ) - prompts.append(p) - lens.append(p.numel()) - - if len(set(lens)) == 1: - target_len = lens[0] - else: - target_len = min(lens) - prompts = [p[-target_len:] if p.numel() > target_len else p for p in prompts] - - batch_prompt = torch.stack(prompts, dim=0) # [B, L] - batch_length = torch.tensor( - [min(int(it.length), target_len) for it in data_list], dtype=torch.long - ) - - batch_idx = torch.tensor([int(it.idx) for it in data_list], dtype=torch.long) - - batch: Dict[str, Any] = { - "prompt": batch_prompt, # [B, L] - "length": batch_length, # [B] - "answer": [it.answer for it in data_list], # List[str] - "idx": batch_idx, # [B] - "solution": [it.solution for it in data_list], # List[Optional[str]] - "image_data": [ - it.image_data for it in data_list - ], # List[Optional[List[bytes|str]]] - "prompt_text": [it.prompt_text for it in data_list], # List[Optional[str]] - "meta": [it.meta for it in data_list], # List[Optional[dict]] - "multi_modal_inputs": [ - it.multi_modal_inputs for it in data_list - ], # List[Optional[dict]] - } - return batch diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index b7dc359b5..cb5e6133e 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -69,10 +69,15 @@ def repeat(self) -> "RolloutRequest": """ assert self.n > 0, "n must be greater than 0" - input_ids, answers = zip( + input_ids, answers, image_data, multi_modal_inputs = zip( *[ - (input_id, answer) - for input_id, answer in zip(self.input_ids, self.answers) + (input_id, answer, image_data, multi_modal_inputs) + for input_id, answer, image_data, multi_modal_inputs in zip( + self.input_ids, + self.answers, + self.image_data, + self.multi_modal_inputs, + ) for _ in range(self.n) ] ) @@ -80,6 +85,8 @@ def repeat(self) -> "RolloutRequest": n=self.n, input_ids=list(input_ids), answers=list(answers), + image_data=list(image_data), + multi_modal_inputs=list(multi_modal_inputs), ) def split(self, num_splits: int) -> List["RolloutRequest"]: @@ -98,15 +105,27 @@ def 
split(self, num_splits: int) -> List["RolloutRequest"]: input_ids_split_list = split_list(self.input_ids, num_splits) answers_split_list = split_list(self.answers, num_splits) + image_data_split_list = split_list(self.image_data, num_splits) + multi_modal_inputs_split_list = split_list(self.multi_modal_inputs, num_splits) splitted_requests = [] - for input_ids_batch, answers_batch in zip( - input_ids_split_list, answers_split_list + for ( + input_ids_batch, + answers_batch, + image_data_batch, + multi_modal_inputs_batch, + ) in zip( + input_ids_split_list, + answers_split_list, + image_data_split_list, + multi_modal_inputs_split_list, ): request = RolloutRequest( n=self.n, input_ids=input_ids_batch, answers=answers_batch, + image_data=image_data_batch, + multi_modal_inputs=multi_modal_inputs_batch, ) splitted_requests.append(request) diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index ccf3bc68e..429476402 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -137,7 +137,7 @@ def _setup_rollout_weight_dst_ranks(self): ) def del_reshard_state_dict(self): - if hasattr(self, "rollou_state_dict"): + if hasattr(self, "rollout_state_dict"): del self.rollout_state_dict def sync_model_to_rollout(self): diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index fefd422d6..03b2311af 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -114,5 +114,7 @@ def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]): self.model.eval() with torch.no_grad(): # TODO: fix this - rewards = self.model(batch["input_ids"], batch["attention_mask"]) + rewards = ( + self.model(batch["input_ids"], batch["attention_mask"]).detach().cpu() + ) return rewards diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index a339edd9e..93f268d6c 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -276,7 +276,7 @@ async def _compute_reward_and_advantage( ) results = math_verify_call(texts, answers) - rewards = [(1 if r else -1) * self._reward_model.scale for r in results] + rewards = [r * self._reward_model.scale for r in results] rewards_tensor = torch.tensor(rewards, dtype=torch.float) mean = rewards_tensor.mean() @@ -333,20 +333,6 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): if self._completion_info.is_completed(hash_id): results = self._completion_info.get_results(hash_id) - # ( - # rewards, - # advantages, - # ) = await self._compute_reward_and_advantage( - # results, - # self._current_request.answers[raw_id], - # ) - # if ( - # all_floats_equal(rewards) - # and self._cfg.algorithm.get("max_num_gen_batches", 1) > 1 - # ): - # if (total_reqs - droped_reqs) > required_reqs: - # droped_reqs += rollout_request.n - # continue input_ids = [input_ids] * len(results) rollout_result = RolloutResult.from_sglang_results( @@ -355,10 +341,7 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): input_ids, return_logprobs=self._return_logprobs, ) - # rollout_result.rewards = torch.tensor( - # rewards, dtype=torch.float32 - # ).reshape(-1, 1) - # rollout_result.advantages = advantages + return_tasks.append( asyncio.create_task( self._put_result(rollout_result, output_channel) diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index 
899e07f1e..edeab9f0d 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -19,7 +19,6 @@ from typing import AsyncGenerator, List, Optional, Union import requests -import torch from omegaconf import DictConfig from PIL.Image import Image from transformers import AutoTokenizer @@ -36,7 +35,7 @@ from rlinf.scheduler import Channel, Worker from rlinf.utils.placement import ComponentPlacement from rlinf.workers.rollout.utils import print_vllm_outputs -from toolkits.math_verifier.verify import MathRewardModel, math_verify_call +from toolkits.math_verifier.verify import MathRewardModel from . import VLLMExecutor @@ -363,32 +362,6 @@ async def _stop(self) -> None: if not self._placement.is_disaggregated: await self.offload_model_weights() - async def _compute_reward_and_advantage(self, rollout_result: RolloutResult): - """ - Compute rewards and advantages for the rollout result using math verification. - """ - answers = rollout_result.answers - outputs = rollout_result.response_texts - num_sequence = rollout_result.num_sequence - assert len(answers) == len(outputs), ( - f"Answers length {len(answers)} != outputs length {len(outputs)}" - ) - assert len(answers) == num_sequence, ( - f"Answers length {len(answers)} != num_sequence {num_sequence}" - ) - - math_verify_results = math_verify_call(outputs, answers) - rewards = [ - (1 if r else -1) * self._reward_model.scale for r in math_verify_results - ] - rewards_tensor = torch.tensor(rewards, dtype=torch.float) - rollout_result.rewards = rewards_tensor.reshape(-1, 1) - - mean = rewards_tensor.mean() - std = rewards_tensor.std(unbiased=False) - advantages = (rewards_tensor - mean) / (std + 1e-6) - rollout_result.advantages = advantages.tolist() - async def rollout_and_return( self, request: RolloutRequest, output_channel: Channel ): @@ -402,8 +375,6 @@ async def rollout_and_return( multi_modal_inputs=request.multi_modal_inputs, return_logprobs=self._return_logprobs, ) - if self._placement.is_disaggregated: - await self._compute_reward_and_advantage(rollout_result) await self._put_result(result=rollout_result, output_channel=output_channel) diff --git a/tests/unit_tests/test_io_struct.py b/tests/unit_tests/test_io_struct.py new file mode 100644 index 000000000..7104c277a --- /dev/null +++ b/tests/unit_tests/test_io_struct.py @@ -0,0 +1,135 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
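An illustrative sketch (not part of the patch) of the behaviour the tests below exercise: repeat() duplicates every prompt n times, and split() then shards the repeated request evenly across rollout instances, with image_data and multi_modal_inputs now travelling alongside the prompts.

from rlinf.data.io_struct import RolloutRequest

req = RolloutRequest(
    n=2,
    input_ids=[[1, 2], [3, 4]],
    answers=["a", "b"],
    image_data=[[], []],
    multi_modal_inputs=[{}, {}],
)
halves = req.repeat().split(num_splits=2)
# Each half keeps n=2 and holds one prompt repeated twice:
# halves[0].input_ids == [[1, 2], [1, 2]], halves[1].input_ids == [[3, 4], [3, 4]]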
+ + +import torch + +from rlinf.data.io_struct import RolloutRequest, RolloutResult + + +def test_rollout_request_repeat_preserves_multimodal(): + request = RolloutRequest( + n=2, + input_ids=[[1, 2, 3], [4, 5]], + image_data=[[b"img1-1", b"img1-2"], []], + answers=["ans1", "ans2"], + multi_modal_inputs=[{"pixels": [1, 2]}, {"pixels": [3]}], + ) + + repeated = request.repeat() + + assert repeated.n == 2 + assert repeated.input_ids == [[1, 2, 3], [1, 2, 3], [4, 5], [4, 5]] + assert repeated.answers == ["ans1", "ans1", "ans2", "ans2"] + assert repeated.image_data == [ + [b"img1-1", b"img1-2"], + [b"img1-1", b"img1-2"], + [], + [], + ] + assert repeated.multi_modal_inputs == [ + {"pixels": [1, 2]}, + {"pixels": [1, 2]}, + {"pixels": [3]}, + {"pixels": [3]}, + ] + + +def _make_rollout_result(): + num_sequence = 4 + group_size = 2 + return RolloutResult( + num_sequence=num_sequence, + group_size=group_size, + prompt_lengths=[3, 3, 4, 4], + prompt_ids=[[11, 12, 13], [11, 12, 13], [21, 22, 23, 24], [21, 22, 23, 24]], + response_lengths=[2, 2, 2, 2], + response_ids=[[101, 102], [201, 202], [301, 302], [401, 402]], + is_end=[True, False, True, True], + answers=[{"answer": "a"}, {"answer": "b"}, {"answer": "c"}, {"answer": "d"}], + image_data=[[b"a"], [b"b"], [b"c"], [b"d"]], + multi_modal_inputs=[ + {"vision": "img-a"}, + {"vision": "img-b"}, + {"vision": "img-c"}, + {"vision": "img-d"}, + ], + prompt_texts=["prompt-a", "prompt-a", "prompt-b", "prompt-b"], + response_texts=["resp-a1", "resp-a2", "resp-b1", "resp-b2"], + rollout_logprobs=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]], + rewards=torch.tensor([[1.0], [0.5], [0.2], [0.1]]), + advantages=[0.1, 0.2, 0.3, 0.4], + prev_logprobs=torch.tensor( + [ + [0.01, 0.02], + [0.03, 0.04], + [0.05, 0.06], + [0.07, 0.08], + ] + ), + ref_logprobs=torch.tensor( + [ + [0.11, 0.12], + [0.13, 0.14], + [0.15, 0.16], + [0.17, 0.18], + ] + ), + ) + + +def test_rollout_result_split_and_merge_roundtrip(): + result = _make_rollout_result() + + split_results = RolloutResult.split_result_list_by_group([result]) + + assert len(split_results) == result.num_sequence // result.group_size + first, second = split_results + + assert first.num_sequence == result.group_size + assert second.num_sequence == result.group_size + assert first.prompt_ids == result.prompt_ids[: result.group_size] + assert second.prompt_ids == result.prompt_ids[result.group_size :] + assert first.response_ids == result.response_ids[: result.group_size] + assert second.response_ids == result.response_ids[result.group_size :] + assert first.prompt_texts == result.prompt_texts[: result.group_size] + assert second.prompt_texts == result.prompt_texts[result.group_size :] + assert first.response_texts == result.response_texts[: result.group_size] + assert second.response_texts == result.response_texts[result.group_size :] + assert first.image_data == result.image_data[: result.group_size] + assert second.image_data == result.image_data[result.group_size :] + assert first.multi_modal_inputs == result.multi_modal_inputs[: result.group_size] + assert second.multi_modal_inputs == result.multi_modal_inputs[result.group_size :] + assert first.rollout_logprobs == result.rollout_logprobs[: result.group_size] + assert second.rollout_logprobs == result.rollout_logprobs[result.group_size :] + assert torch.equal(first.rewards, result.rewards[: result.group_size]) + assert torch.equal(second.rewards, result.rewards[result.group_size :]) + assert first.advantages == result.advantages[: result.group_size] + assert 
second.advantages == result.advantages[result.group_size :] + + merged = RolloutResult.merge_result_list(split_results) + + assert merged.num_sequence == result.num_sequence + assert merged.group_size == result.group_size + assert merged.prompt_ids == result.prompt_ids + assert merged.prompt_lengths == result.prompt_lengths + assert merged.response_ids == result.response_ids + assert merged.response_lengths == result.response_lengths + assert merged.is_end == result.is_end + assert merged.answers == result.answers + assert merged.rollout_logprobs == result.rollout_logprobs + assert merged.advantages == result.advantages + assert torch.equal(merged.rewards, result.rewards) + assert torch.equal(merged.prev_logprobs, result.prev_logprobs) + assert torch.equal(merged.ref_logprobs, result.ref_logprobs) From ecb1ed0b463ec7df5ef4c0417942be9ee0ad4a3e Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Thu, 2 Oct 2025 11:34:25 +0000 Subject: [PATCH 19/38] fix(mm_data): unify vllm/sglang's mm_data passing Signed-off-by: Bo Dai --- rlinf/data/datasets/vlm.py | 5 +- rlinf/data/io_struct.py | 2 +- rlinf/workers/actor/fsdp_actor_worker.py | 4 +- rlinf/workers/rollout/sglang/sglang_worker.py | 4 +- rlinf/workers/rollout/vllm/vllm_worker.py | 62 ++++++++++++++++--- 5 files changed, 64 insertions(+), 13 deletions(-) diff --git a/rlinf/data/datasets/vlm.py b/rlinf/data/datasets/vlm.py index 18956dd81..73509c32a 100644 --- a/rlinf/data/datasets/vlm.py +++ b/rlinf/data/datasets/vlm.py @@ -178,7 +178,10 @@ def encode_prompt( text=[rendered], images=images_inputs, padding=True, return_tensors="pt" ) inputs.pop("attention_mask") - inputs.pop("input_ids") + # NOTE: + # we use these input_ids in inputs rather than belows + # because sglang need corresponding pixel_values len's placeholder + # in input_ids, while vllm does not need. ids = self._processor( text=[rendered], images=None, padding=True, return_tensors="pt" )["input_ids"] diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index cb5e6133e..8a01ddc1c 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -20,7 +20,7 @@ from vllm.outputs import CompletionOutput from vllm.outputs import RequestOutput as VllmRequestOutput -from rlinf.data.datasets import batch_pad_to_fixed_len +from rlinf.data.datasets.utils import batch_pad_to_fixed_len from rlinf.utils.data_iter_utils import ( get_iterator_k_split, split_list, diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 429476402..1e9e2c598 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -356,9 +356,9 @@ def run_training(self, input_channel: Channel): ) append_to_dict(metrics, mbs_metrics_data) - mean_metric_dict = { - key: np.mean(value) for key, value in metrics.items() + key: torch.mean(torch.stack(value)) + for key, value in metrics.items() } mean_metric_dict = all_reduce_dict( mean_metric_dict, op=torch.distributed.ReduceOp.AVG diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 93f268d6c..3fcda8fca 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -179,7 +179,9 @@ def rollout(self, input_channel: Channel, output_channel: Channel): # Generate outputs using the SGLang engine. 
with self.worker_timer(): results = self._engine.generate( - input_ids=request.input_ids, + input_ids=request.input_ids + if request.multi_modal_inputs + else request.multi_modal_inputs["input_ids"], # 0.4.4 has modality bug,can't pass non-None image_data image_data=request.image_data if any(request.image_data) else None, sampling_params=self._sampling_params, diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index edeab9f0d..b3f82d9d9 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -20,7 +20,7 @@ import requests from omegaconf import DictConfig -from PIL.Image import Image +from PIL import Image from transformers import AutoTokenizer from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -226,6 +226,23 @@ async def generate( Union[List[List[Union[bytes, str]]], List[Union[bytes, str]]] ] = None, ) -> List[RequestOutput]: + """ + Do Generate Task using the vllm async engine. + + Args: + input_ids: The input token ids to generate. It can be a list of list of int, + or a list of int (single prompt). + sampling_params: The sampling parameters to use for generation. + prompt_texts: The input prompt texts to generate. It can be a list of strings + or a single string. If provided, it will be used instead of input_ids. + image_data: The input multi-modal data to generate. It can be a list of list + of bytes or image paths (local or URL), or a list of bytes or image paths + (single prompt). + + Returns: + List[RequestOutput]: A list of RequestOutput from vllm engine. + """ + def check_input_ids() -> List[List[int]]: assert isinstance(input_ids, list), ( "input_ids should be a list or list of list of int." @@ -266,19 +283,22 @@ def check_image_data() -> Optional[List[List[Image]]]: if prompt_texts is not None: for i, prompt_text in enumerate(prompt_texts): if image_list is not None: - image_list = self._process_image_data(image_data=image_list[i]) + images = self._process_image_data(image_data=image_list[i]) inputs.append( - TextPrompt(prompt=prompt_text, multi_modal_data=image_list) + TextPrompt( + prompt=prompt_text, multi_modal_data={"image": images} + ) ) else: inputs.append(TextPrompt(prompt=prompt_text)) else: for i, input_id in enumerate(input_ids): if image_list is not None: - image_list = self._process_image_data(image_data=image_list[i]) + images = self._process_image_data(image_data=image_list[i]) inputs.append( TokensPrompt( - prompt_token_ids=input_id, multi_modal_data=image_list + prompt_token_ids=input_id, + multi_modal_data={"image": images}, ) ) else: @@ -302,7 +322,8 @@ def check_image_data() -> Optional[List[List[Image]]]: async def init_worker(self) -> None: """ Use EngineArgs and VllmConfig to initialize VLLM async engine. - Then offload the model weights, ready to use weights sent from actor. + If mode is collocated, it will additionally offload model weights, + ready to use parameters sent from actor. """ engine_args: EngineArgs = EngineArgs( model=self._cfg.rollout.model_dir, @@ -349,6 +370,13 @@ async def init_worker(self) -> None: await self.offload_model_weights() async def _put_result(self, result: RolloutResult, output_channel: Channel) -> None: + """ + Helper function to put the result to output channel. + + Args: + result: The RolloutResult to put to the channel. + output_channel: The output channel to send results to. 
+ """ await output_channel.put(result, async_op=True).async_wait() async def _stop(self) -> None: @@ -365,8 +393,18 @@ async def _stop(self) -> None: async def rollout_and_return( self, request: RolloutRequest, output_channel: Channel ): + """ + Helper function to rollout for a single RolloutRequest and build RolloutResult then + put it to output channel. + + Args: + request: The RolloutRequest to process. + output_channel: The output channel to send results to. + """ vllm_results: List[RequestOutput] = await self.generate( - input_ids=request.input_ids, sampling_params=self._sampling_params + input_ids=request.input_ids, + image_data=request.image_data, + sampling_params=self._sampling_params, ) rollout_result: RolloutResult = RolloutResult.from_vllm_results( group_size=self._cfg.algorithm.group_size, @@ -375,10 +413,18 @@ async def rollout_and_return( multi_modal_inputs=request.multi_modal_inputs, return_logprobs=self._return_logprobs, ) - await self._put_result(result=rollout_result, output_channel=output_channel) async def rollout(self, input_channel: Channel, output_channel: Channel) -> None: + """ + Perform rollout using vllm engine. + It will read `RolloutRequest` from input_channel and put `RolloutResult` to output_channel. + If the input request is None, it will stop the rollout. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + """ rollout_request: RolloutRequest = await input_channel.get( async_op=True ).async_wait() From 582a438525ea79d66e7044937bc8a5ea766b5908 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Thu, 2 Oct 2025 12:59:02 +0000 Subject: [PATCH 20/38] fix(rollout): fix some problems in sglang/vllm, now both are ok Signed-off-by: Bo Dai --- rlinf/data/datasets/vlm.py | 18 +++++++++++------- rlinf/data/io_struct.py | 9 ++++++++- rlinf/utils/data_iter_utils.py | 8 ++++++-- rlinf/workers/actor/fsdp_actor_worker.py | 2 -- rlinf/workers/rollout/sglang/sglang_worker.py | 4 +--- rlinf/workers/rollout/vllm/vllm_worker.py | 9 ++++++++- 6 files changed, 34 insertions(+), 16 deletions(-) diff --git a/rlinf/data/datasets/vlm.py b/rlinf/data/datasets/vlm.py index 73509c32a..da9e952b9 100644 --- a/rlinf/data/datasets/vlm.py +++ b/rlinf/data/datasets/vlm.py @@ -178,13 +178,17 @@ def encode_prompt( text=[rendered], images=images_inputs, padding=True, return_tensors="pt" ) inputs.pop("attention_mask") - # NOTE: - # we use these input_ids in inputs rather than belows - # because sglang need corresponding pixel_values len's placeholder - # in input_ids, while vllm does not need. 
- ids = self._processor( - text=[rendered], images=None, padding=True, return_tensors="pt" - )["input_ids"] + if self.cfg.rollout.rollout_backend == "sglang": + ids = inputs.pop("input_ids") + elif self.cfg.rollout.rollout_backend == "vllm": + inputs.pop("input_ids") + ids = self._processor( + text=[rendered], images=None, padding=True, return_tensors="pt" + )["input_ids"] + else: + raise ValueError( + f"Unsupported rollout backend {self.cfg.rollout.rollout_backend}" + ) if isinstance(ids, torch.Tensor): if ids.dim() == 2 and ids.size(0) == 1: ids = ids.squeeze(0) diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 8a01ddc1c..bb1978754 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -371,6 +371,13 @@ def get_logprobs( num_sequences = len(results) * group_size + if multi_modal_inputs: + mm_inputs = [] + for mm_input in multi_modal_inputs: + mm_inputs.extend([mm_input] * group_size) + else: + mm_inputs = None + prompt_lengths = [] prompt_ids = [] response_lengths = [] @@ -413,7 +420,7 @@ def get_logprobs( response_ids=response_ids, response_lengths=response_lengths, response_texts=response_texts, - multi_modal_inputs=multi_modal_inputs, + multi_modal_inputs=mm_inputs, is_end=is_end, ) if return_logprobs: diff --git a/rlinf/utils/data_iter_utils.py b/rlinf/utils/data_iter_utils.py index 27bbd7215..31e440201 100644 --- a/rlinf/utils/data_iter_utils.py +++ b/rlinf/utils/data_iter_utils.py @@ -60,13 +60,17 @@ def concat_dict_list(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, Any]: return result -def split_list(inputs, num_chunks, enforce_divisible_batch: Optional[bool] = True): +def split_list( + inputs: List, num_chunks: int, enforce_divisible_batch: Optional[bool] = True +): """ Split a list into equal sized chunks """ if enforce_divisible_batch: chunk_size = len(inputs) // num_chunks - assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!" + assert len(inputs) % chunk_size == 0, ( + f"Issue with batch size configuration! inputs len:{len(inputs)} num_chunks:{num_chunks}" + ) return [inputs[i : i + chunk_size] for i in range(0, len(inputs), chunk_size)] else: k, m = divmod(len(inputs), num_chunks) diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 1e9e2c598..7bc376fba 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -214,7 +214,6 @@ def run_training(self, input_channel: Channel): f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" ) batch = RolloutResult.merge_batches(batches) - # Must be called after batch is retrieved, which is when rollout has stopped # Otherwise, loading model might cause OOM self._load_weight_and_optimizer(input_channel) @@ -279,7 +278,6 @@ def run_training(self, input_channel: Channel): ref_logprobs = m_batch["ref_logprobs"] loss_mask = m_batch["attention_mask"][:, -self.response_len :] - output = self.model( input_ids=input_ids, attention_mask=attention_mask, diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 3fcda8fca..93f268d6c 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -179,9 +179,7 @@ def rollout(self, input_channel: Channel, output_channel: Channel): # Generate outputs using the SGLang engine. 
with self.worker_timer(): results = self._engine.generate( - input_ids=request.input_ids - if request.multi_modal_inputs - else request.multi_modal_inputs["input_ids"], + input_ids=request.input_ids, # 0.4.4 has modality bug,can't pass non-None image_data image_data=request.image_data if any(request.image_data) else None, sampling_params=self._sampling_params, diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index b3f82d9d9..1c6e859a5 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -377,7 +377,14 @@ async def _put_result(self, result: RolloutResult, output_channel: Channel) -> N result: The RolloutResult to put to the channel. output_channel: The output channel to send results to. """ - await output_channel.put(result, async_op=True).async_wait() + # NOTE: + # To fit reward worker and actor workers' expected input count, + # currently we can only split result into groups. + splited_results = RolloutResult.split_result_list_by_group([result]) + put_tasks = [ + output_channel.put(r, async_op=True).async_wait() for r in splited_results + ] + await asyncio.gather(*put_tasks) async def _stop(self) -> None: """ From fa5b8610d1adf6cb71768c42bad593b719e5e484 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Thu, 2 Oct 2025 16:23:30 +0000 Subject: [PATCH 21/38] fix(ci): add ci for vqa Signed-off-by: Bo Dai --- .github/workflows/vqa_e2e.yml | 63 +++++ rlinf/algorithms/rewards/__init__.py | 4 +- .../sglang/qwen2.5-vl-3b-grpo-collocated.yaml | 222 ++++++++++++++++++ tests/e2e_tests/vqa/sglang/run_collocated.sh | 17 ++ .../vllm/qwen2.5-vl-3b-grpo-collocated.yaml | 222 ++++++++++++++++++ tests/e2e_tests/vqa/vllm/run_collocated.sh | 17 ++ 6 files changed, 543 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/vqa_e2e.yml create mode 100644 tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml create mode 100644 tests/e2e_tests/vqa/sglang/run_collocated.sh create mode 100644 tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml create mode 100644 tests/e2e_tests/vqa/vllm/run_collocated.sh diff --git a/.github/workflows/vqa_e2e.yml b/.github/workflows/vqa_e2e.yml new file mode 100644 index 000000000..58ce056e2 --- /dev/null +++ b/.github/workflows/vqa_e2e.yml @@ -0,0 +1,63 @@ +name: VQA End2End + +on: + push: + branches: + - 'release/v[0-9].[0-9]' + - main + paths: + - '**/*.py' + - 'tests/**' + - '.github/workflows/*.yml' + - '!docs/**' + - '!README.md' + - '!*.yaml' + - '!*.toml' + - '!ray_utils/**' + - '!requirements/**' + + pull_request: + branches: + - 'release/v[0-9].[0-9]' + - main + paths: + - '**/*.py' + - 'tests/**' + - '.github/workflows/*.yml' + - '!docs/**' + - '!README.md' + - '*.yaml' + - '*.toml' + - '!ray_utils/**' + - '!requirements/**' + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + qwen-vl-grpo-test: + runs-on: rlinf + container: + image: rlinf/rlinf:math-rlinf0.1-torch2.6.0-sglang0.4.6.post5-vllm0.8.5-megatron0.13.0-te2.1 + volumes: + - /mnt/public/dataset:/workspace/dataset + - /mnt/public/tokenizer:/workspace/tokenizer + options: --gpus="all" --shm-size=80g + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: SGLang Collocated mode + run: | + export REPO_PATH=$(pwd) + bash tests/e2e_tests/vqa/sglang/run_collocated.sh + + - name: vLLM Collocated mode + run: | + export REPO_PATH=$(pwd) + bash tests/e2e_tests/vqa/vllm/run_collocated.sh diff --git 
a/rlinf/algorithms/rewards/__init__.py b/rlinf/algorithms/rewards/__init__.py index 3d354437b..2ab6528ca 100644 --- a/rlinf/algorithms/rewards/__init__.py +++ b/rlinf/algorithms/rewards/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .math import MathReward -from .vqa import VQAReward +from rlinf.algorithms.rewards.math import MathReward +from rlinf.algorithms.rewards.vqa import VQAReward def register_reward(name: str, reward_class: type): diff --git a/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml b/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml new file mode 100644 index 000000000..30c7a150c --- /dev/null +++ b/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml @@ -0,0 +1,222 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-qwen2.5-vl-3b + output_dir: /workspace/results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. 
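# Editor's note (illustrative, not part of this config): this file and its vLLM twin
# under tests/e2e_tests/vqa/vllm differ essentially only in the `rollout_backend`
# value below, and the collocated `component_placement: actor,rollout,reward: all`
# above shares every GPU among the three components. A disaggregated placement, as
# used by the math pipeline configs in this PR, would instead look roughly like:
#   component_placement:
#     rollout: 0-3
#     actor: 4-7
#     reward: 0-3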
+ + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/workspace/dataset/robo2vlm-1/data/train-00000-of-00262.parquet"] + val_data_paths: ["/workspace/dataset/robo2vlm-1/data/test-00000-of-00003.parquet"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + 
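# Editor's sketch (not part of this config): with only qa_accuracy weighted, the VQA
# reward presumably reduces to answer correctness scaled by reward_scale; enabling
# light format shaping would then just redistribute the weights, e.g.:
#   reward_weights:
#     qa_accuracy: 0.8
#     think_format: 0.1
#     answer_format: 0.1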
tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/vqa/sglang/run_collocated.sh b/tests/e2e_tests/vqa/sglang/run_collocated.sh new file mode 100644 index 000000000..43fa65fd0 --- /dev/null +++ b/tests/e2e_tests/vqa/sglang/run_collocated.sh @@ -0,0 +1,17 @@ +#! /bin/bash +set -x + +tabs 4 +export VLLM_ATTENTION_BACKEND=XFORMERS +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${REPO_PATH}:$PYTHONPATH + +if [ -z "$1" ]; then + CONFIG_NAME="qwen2.5-vl-3b-grpo-collocated" +else + CONFIG_NAME=$1 +fi + +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml b/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml new file mode 100644 index 000000000..aef3ec271 --- /dev/null +++ b/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml @@ -0,0 +1,222 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-qwen2.5-vl-3b + output_dir: /workspace/results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. 
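# Editor's note (not part of this config): in this collocated placement the rollout
# engine shares GPUs with the FSDP actor, so gpu_memory_utilization=0.55 above caps
# the vLLM weight/KV-cache footprint and leaves room for the actor's weights and
# optimizer state that are loaded back after rollout offloads.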
+ distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
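# Worked numbers for this config (editor's note, not part of the file):
#   sequences generated per step = data.rollout_batch_size * algorithm.group_size
#                                = 8 * 8 = 64
#   max_new_tokens               = runner.seq_length - data.max_prompt_length
#                                = 2048 - 1024 = 1024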
+ +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/workspace/dataset/robo2vlm-1/data/train-00000-of-00262.parquet"] + val_data_paths: ["/workspace/dataset/robo2vlm-1/data/test-00000-of-00003.parquet"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/vqa/vllm/run_collocated.sh b/tests/e2e_tests/vqa/vllm/run_collocated.sh new file mode 100644 index 000000000..ab406c4ea --- /dev/null +++ b/tests/e2e_tests/vqa/vllm/run_collocated.sh @@ -0,0 +1,17 @@ +#! 
/bin/bash +set -x + +tabs 4 +export VLLM_ATTENTION_BACKEND=XFORMERS +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${REPO_PATH}:$PYTHONPATH + +if [ -z "$1" ]; then + CONFIG_NAME="qwen2.5-vl-3b-grpo-collocated" +else + CONFIG_NAME=$1 +fi + +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/vqa/vllm --config-name $CONFIG_NAME \ No newline at end of file From 74535e68e99e1160849750bef670f810ee9b236d Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Fri, 3 Oct 2025 04:57:47 +0000 Subject: [PATCH 22/38] fix(ci): fix some bugs in ci Signed-off-by: Bo Dai --- rlinf/data/io_struct.py | 12 +++++++----- .../megatron/megatron_model_manager.py | 1 + rlinf/workers/actor/fsdp_actor_worker.py | 3 ++- rlinf/workers/actor/megatron_actor_worker.py | 4 ---- rlinf/workers/rollout/sglang/sglang_worker.py | 2 ++ rlinf/workers/rollout/vllm/vllm_worker.py | 4 ++-- ...5-1.5b-grpo-pipeline-rollout-logprobs.yaml | 12 ++++++++++-- .../math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml | 12 ++++++++++-- tests/e2e_tests/math/vllm/run_pipeline.sh | 2 +- tests/e2e_tests/vqa/sglang/run_collocated.sh | 2 +- tests/unit_tests/test_auto_placement.py | 4 ++-- toolkits/auto_placement/scheduler_task.py | 10 ++++++---- toolkits/math_verifier/verify.py | 19 +++++++++++++++---- 13 files changed, 59 insertions(+), 28 deletions(-) diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index bb1978754..4948131a6 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -13,12 +13,14 @@ # limitations under the License. from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import torch from omegaconf import DictConfig -from vllm.outputs import CompletionOutput -from vllm.outputs import RequestOutput as VllmRequestOutput + +if TYPE_CHECKING: + from vllm.outputs import CompletionOutput + from vllm.outputs import RequestOutput as VllmRequestOutput from rlinf.data.datasets.utils import batch_pad_to_fixed_len from rlinf.utils.data_iter_utils import ( @@ -352,13 +354,13 @@ def _get_attention_masks_and_position_ids( @staticmethod def from_vllm_results( group_size: int, - results: List[VllmRequestOutput], + results: List["VllmRequestOutput"], answers: Optional[List[str]] = None, multi_modal_inputs: Optional[List[Dict]] = None, return_logprobs: bool = False, ) -> "RolloutResult": def get_logprobs( - response_ids: List[int], output: CompletionOutput + response_ids: List[int], output: "CompletionOutput" ) -> List[float]: logprobs = [] returned_logprobs = output.logprobs diff --git a/rlinf/hybrid_engines/megatron/megatron_model_manager.py b/rlinf/hybrid_engines/megatron/megatron_model_manager.py index 57f34dbf5..6fe4fbfd6 100644 --- a/rlinf/hybrid_engines/megatron/megatron_model_manager.py +++ b/rlinf/hybrid_engines/megatron/megatron_model_manager.py @@ -184,6 +184,7 @@ def model_provider_func(self, pre_process, post_process): return model def optimizer_step(self, increment): + clear_memory() success, grad_norm, num_zeros_in_grad = self.optimizer.step() self.lr_scheduler.step(increment=increment) diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 7bc376fba..e13a8949f 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -25,7 +25,6 @@ import rlinf.algorithms # noqa: F401 from rlinf.algorithms.registry import actor_loss, 
calculate_adv_and_returns -from rlinf.algorithms.rewards import get_reward_class from rlinf.algorithms.utils import ( kl_penalty, preprocess_advantages_inputs, @@ -113,6 +112,8 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): assert self.cfg.reward.reward_type in ["math", "vqa"], ( "only support math and vqa reward!" ) + from rlinf.algorithms.rewards import get_reward_class + reward_cls = get_reward_class(self.cfg.reward.reward_type) self.reward = reward_cls(self.cfg.reward) diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index a08415ea9..bbac0701c 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -885,10 +885,6 @@ def compute_advantages_and_returns( input_channel: The input channel to read from. output_channel: The output channel to send results to. """ - if self.is_pipeline: - # In pipeline mode, advantages are computed in the rollout - with self.worker_timer(): - return clear_memory() recv_batch_size = 0 while recv_batch_size < self.total_batch_size_per_dp: diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 93f268d6c..abdf59365 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -335,10 +335,12 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): results = self._completion_info.get_results(hash_id) input_ids = [input_ids] * len(results) + answers = [rollout_request.answers[raw_id]] * len(results) rollout_result = RolloutResult.from_sglang_results( results, rollout_request.n, input_ids, + answers=answers, return_logprobs=self._return_logprobs, ) diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index 1c6e859a5..a67d967ed 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -265,8 +265,8 @@ def check_prompt_text() -> Optional[List[str]]: assert len(prompt_texts) > 0, "prompt_text should not be empty." return prompt_texts - def check_image_data() -> Optional[List[List[Image]]]: - if image_data is None: + def check_image_data() -> Optional[List[List[Image.Image]]]: + if image_data is None or not any(image_data): return None assert isinstance(image_data, list), "image_data should be a list." 
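            # (Editor's note, not part of the patch) image_data may arrive either as a
            # flat list for a single prompt or as a list of per-prompt lists, per the
            # generate() docstring above; the branch below presumably normalises the
            # flat form, e.g. [b"img"] -> [[b"img"]], before per-prompt processing.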
if isinstance(image_data[0], list): diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml index 62d0c5247..111969ff4 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml +++ b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml @@ -12,9 +12,10 @@ cluster: component_placement: rollout: 0-3 actor: 4-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -66,7 +67,7 @@ algorithm: calculate_entropy: True clip_ratio_c: null # 3.0 - adv_type: grpo + adv_type: math_grpo normalize_advantages: False early_stop_imp_ratio: 5.0 use_valid_token_scale: True @@ -260,9 +261,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml index 3f7821587..705757c31 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml +++ b/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml @@ -13,9 +13,10 @@ cluster: rollout: 0-3 inference: 4-5 actor: 6-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -67,7 +68,7 @@ algorithm: calculate_entropy: True clip_ratio_c: null # 3.0 - adv_type: grpo + adv_type: math_grpo normalize_advantages: False early_stop_imp_ratio: 5.0 use_valid_token_scale: True @@ -274,9 +275,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/run_pipeline.sh b/tests/e2e_tests/math/vllm/run_pipeline.sh index 0a21368f6..59fb19454 100644 --- a/tests/e2e_tests/math/vllm/run_pipeline.sh +++ b/tests/e2e_tests/math/vllm/run_pipeline.sh @@ -14,4 +14,4 @@ else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/vqa/sglang/run_collocated.sh b/tests/e2e_tests/vqa/sglang/run_collocated.sh index 43fa65fd0..793c92417 100644 --- a/tests/e2e_tests/vqa/sglang/run_collocated.sh +++ b/tests/e2e_tests/vqa/sglang/run_collocated.sh @@ -14,4 +14,4 @@ else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/vqa/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/unit_tests/test_auto_placement.py b/tests/unit_tests/test_auto_placement.py 
index a70b01c9d..559229763 100644 --- a/tests/unit_tests/test_auto_placement.py +++ b/tests/unit_tests/test_auto_placement.py @@ -598,7 +598,7 @@ def test_scheduler_task_initialization(self, mock_validate): """Test SchedulerTask initialization.""" # Create a mock config mock_cfg = MagicMock() - mock_cfg.runner.task_type = "math" + mock_cfg.runner.task_type = "reasoning" mock_cfg.actor.model.tensor_model_parallel_size = 2 mock_cfg.actor.model.pipeline_model_parallel_size = 1 mock_cfg.rollout.tensor_parallel_size = 1 @@ -620,7 +620,7 @@ def test_scheduler_task_initialization(self, mock_validate): scheduler_task = SchedulerTask(mock_cfg, mock_cluster) - assert scheduler_task.is_math is True + assert scheduler_task.is_reasoning is True assert scheduler_task.total_gpus == 8 assert scheduler_task.group_size == 4 assert "actor" in scheduler_task.components_config diff --git a/toolkits/auto_placement/scheduler_task.py b/toolkits/auto_placement/scheduler_task.py index 00d8ea8aa..b3be46012 100644 --- a/toolkits/auto_placement/scheduler_task.py +++ b/toolkits/auto_placement/scheduler_task.py @@ -31,8 +31,10 @@ def __init__( workflow_graph: Optional[Dict[ComponentNode, List[ComponentNode]]] = None, ): self.cfg = cfg - self.is_math = cfg.runner.task_type == "math" - assert self.is_math, "Only math task is supported" + self.is_reasoning = cfg.runner.task_type == "reasoning" + assert self.is_reasoning, ( + f"Only reasoning task is supported, current task type: {cfg.runner.task_type}" + ) self.components_config = { "actor": { @@ -71,7 +73,7 @@ def __init__( self.global_step_batch_size = self.rollout_batch_size * self.group_size if workflow_graph is None: - if self.is_math: + if self.is_reasoning: actor = ComponentNode("actor") inference = ComponentNode("inference") rollout = ComponentNode("rollout") @@ -179,7 +181,7 @@ def parse_partition_allocation_to_cfg( def time_division_multiplexing(self) -> List[Dict[str, Workflow]]: partitions: List[Dict[str, Workflow]] = get_workflow_partition(self.workflow) - if self.is_math: + if self.is_reasoning: valid_partitions = [ i for i in partitions if len(i) in [1, len(self.components_config)] ] diff --git a/toolkits/math_verifier/verify.py b/toolkits/math_verifier/verify.py index 31d92c280..988b86045 100644 --- a/toolkits/math_verifier/verify.py +++ b/toolkits/math_verifier/verify.py @@ -14,7 +14,13 @@ import multiprocessing import re -from concurrent.futures import ProcessPoolExecutor, as_completed +from concurrent.futures import ( + ProcessPoolExecutor, + as_completed, +) +from concurrent.futures import ( + TimeoutError as FuturesTimeoutError, +) from typing import List, Union import regex @@ -403,15 +409,20 @@ def math_verify_call( jobs.append(job) all_jobs.append(jobs) - labels = [] + labels: List[int] = [] has_timeout = False for jobs in all_jobs: + label = 0 try: for job in as_completed(jobs, timeout=timeout): x = job.result() - labels.append(x) - except TimeoutError: + label = label or x + except FuturesTimeoutError: has_timeout = True + for job in jobs: + job.cancel() + finally: + labels.append(label) if has_timeout: reset_global_process_pool() From 62df3139450a83baa087bc06425ff638721b7f8c Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Sun, 5 Oct 2025 00:28:27 +0000 Subject: [PATCH 23/38] fix(fsdp): add forgotten backward and optimizer step Signed-off-by: Bo Dai --- rlinf/workers/actor/fsdp_actor_worker.py | 29 ++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py 
index e13a8949f..5e4a48848 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -346,6 +346,10 @@ def run_training(self, input_channel: Channel): loss = loss + kl_loss * self.kl_beta # add to log + # scale loss for gradient accumulation and backprop + loss = loss / self.gradient_accumulation + loss.backward() + mbs_metrics_data.update( { "final_loss": loss.detach().cpu(), @@ -355,6 +359,18 @@ def run_training(self, input_channel: Channel): ) append_to_dict(metrics, mbs_metrics_data) + # apply gradient clipping and optimizer step at the end of a global batch + grad_norm = None + try: + grad_norm = self.model.clip_grad_norm_( + max_norm=self.cfg.actor.optim.clip_grad + ) + except Exception: + pass + self.optimizer.step() + self.optimizer.zero_grad() + + # aggregate metrics across micro-batches mean_metric_dict = { key: torch.mean(torch.stack(value)) for key, value in metrics.items() @@ -362,6 +378,19 @@ def run_training(self, input_channel: Channel): mean_metric_dict = all_reduce_dict( mean_metric_dict, op=torch.distributed.ReduceOp.AVG ) + # add optimizer stats + if grad_norm is not None: + mean_metric_dict["actor/grad_norm"] = ( + torch.as_tensor( + grad_norm + if torch.is_tensor(grad_norm) + else float(grad_norm) + ) + .float() + .cpu() + ) + lr = self.optimizer.param_groups[0]["lr"] + mean_metric_dict["actor/lr"] = torch.as_tensor(lr).float().cpu() training_metrics_list.append(mean_metric_dict) # Rollout metrics From 1ec7e5450948764605e05a0bc812b9603d4bee96 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Sun, 5 Oct 2025 11:08:01 +0000 Subject: [PATCH 24/38] fix(collocated): fix inference/rollout do jobs parallelly which causes oom in collocated mode Signed-off-by: Bo Dai --- rlinf/runners/reasoning_runner.py | 1 + rlinf/workers/actor/megatron_actor_worker.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/rlinf/runners/reasoning_runner.py b/rlinf/runners/reasoning_runner.py index d68abf72a..b53010e18 100644 --- a/rlinf/runners/reasoning_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -341,6 +341,7 @@ def run(self): infer_handle: Handle = self.inference.run_inference( input_channel=self.reward_channel, output_channel=self.inference_channel, + rollout_channel=self.rollout_channel, compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index bbac0701c..cb74b0d00 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -841,6 +841,7 @@ def run_inference( self, input_channel: Channel, output_channel: Channel, + rollout_channel: Channel, compute_ref_logprobs: bool, ): """Compute prev/ref logprobs using the actor Model's forward. @@ -851,6 +852,8 @@ def run_inference( compute_ref_logprobs: Whether to compute reference logprobs. 
""" recv_batch_size = 0 + if not self.is_pipeline: + rollout_channel.device_lock.acquire() while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence @@ -874,6 +877,8 @@ def run_inference( assert recv_batch_size == self.total_batch_size_per_dp, ( f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" ) + if not self.is_pipeline: + rollout_channel.device_lock.release() # Advantages and returns def compute_advantages_and_returns( From e57d10df7007f71cff52484cb22b896f946c6c64 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Wed, 8 Oct 2025 18:44:16 +0000 Subject: [PATCH 25/38] fix(sync_weight): fix oom bugs Signed-off-by: Bo Dai --- rlinf/runners/reasoning_runner.py | 1 - rlinf/workers/actor/fsdp_actor_worker.py | 16 +++-- rlinf/workers/actor/megatron_actor_worker.py | 5 -- rlinf/workers/reward/reward_worker.py | 63 +++++++++----------- 4 files changed, 36 insertions(+), 49 deletions(-) diff --git a/rlinf/runners/reasoning_runner.py b/rlinf/runners/reasoning_runner.py index b53010e18..d68abf72a 100644 --- a/rlinf/runners/reasoning_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -341,7 +341,6 @@ def run(self): infer_handle: Handle = self.inference.run_inference( input_channel=self.reward_channel, output_channel=self.inference_channel, - rollout_channel=self.rollout_channel, compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 5e4a48848..ba35362e0 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -144,9 +144,12 @@ def del_reshard_state_dict(self): def sync_model_to_rollout(self): if next(self.model.parameters()).is_cpu: self.load_fsdp_param_and_grad(self.device) - self.rollout_state_dict = self.get_model_state_dict() + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad(offload_grad=True) + self.offload_fsdp_optimizer() + has_visual = any("visual." in k for k in self.rollout_state_dict.keys()) state_dict = {} @@ -161,14 +164,9 @@ def sync_model_to_rollout(self): name = name[6:] state_dict[name] = reduce_tensor(v) - self.send( - state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout - ) - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_param_and_grad() - torch.cuda.synchronize() - gc.collect() - torch.cuda.empty_cache() + self.send( + state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout + ) def compute_logprobs(self): self.model.eval() diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index cb74b0d00..bbac0701c 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -841,7 +841,6 @@ def run_inference( self, input_channel: Channel, output_channel: Channel, - rollout_channel: Channel, compute_ref_logprobs: bool, ): """Compute prev/ref logprobs using the actor Model's forward. @@ -852,8 +851,6 @@ def run_inference( compute_ref_logprobs: Whether to compute reference logprobs. 
""" recv_batch_size = 0 - if not self.is_pipeline: - rollout_channel.device_lock.acquire() while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence @@ -877,8 +874,6 @@ def run_inference( assert recv_batch_size == self.total_batch_size_per_dp, ( f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" ) - if not self.is_pipeline: - rollout_channel.device_lock.release() # Advantages and returns def compute_advantages_and_returns( diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index 03b2311af..e186eebac 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, Tuple import torch from omegaconf import DictConfig @@ -67,48 +67,43 @@ def compute_rewards(self, input_channel: Channel, output_channel: Channel): with self.worker_timer(): recv_batch_size = 0 while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - + rollout_result: RolloutResult = input_channel.get() recv_batch_size += rollout_result.num_sequence - # Compute rule-based reward + if rollout_result.rewards is None: - rollout_result.rewards = self._compute_batch_rewards( - batch, rollout_result.answers - ) + if self.cfg.reward.use_reward_model: + with input_channel.device_lock: + batch = rollout_result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + rollout_result.rewards = ( + self.compute_batch_rewards_with_model(batch) + ) + else: + rollout_result.rewards = self._compute_rule_based_rewards( + rollout_result + ) + output_channel.put(rollout_result) assert recv_batch_size == self.total_batch_size_per_dp, ( f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" ) - def _compute_batch_rewards( - self, batch: Dict[str, torch.Tensor], answers: List[str | dict] - ): - """Reward computation using non-model based reward.""" + def _compute_rule_based_rewards(self, rollout_result: RolloutResult): + # Decode only the generated tokens; response_ids are already the post-prompt tokens + texts = self.tokenizer.batch_decode( + rollout_result.response_ids, skip_special_tokens=True + ) - if self.cfg.reward.use_reward_model: - return self.compute_batch_rewards_with_model(batch) - - texts = [] - for response, response_len in zip( - batch["input_ids"], - batch["response_lengths"], - ): - response = response[ - self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length - + response_len - ] - texts.append( - self.tokenizer.decode(response.tolist(), skip_special_tokens=True) - ) - reward_scores = self.reward.get_reward(texts, answers) - - all_reward_scores = torch.as_tensor( - reward_scores, - dtype=torch.float, - device=torch.device("cpu"), - ).view(-1, 1) - return all_reward_scores.flatten() + scores = self.reward.get_reward(texts, rollout_result.answers) + return ( + torch.as_tensor(scores, dtype=torch.float, device=torch.device("cpu")) + .view(-1, 1) + .flatten() + ) def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]): self.model.eval() From d0edcd0dfc0b9ce24c2032c2f0d6690403ec366d Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Thu, 9 Oct 2025 06:36:55 +0000 
Subject: [PATCH 26/38] fix(vlm): in torch260's image, transformers version is 4.51.1 and it's ok, use 4.51.1 rather than 4.56.1 Signed-off-by: Bo Dai --- .../math/qwen2.5-1.5b-grpo-megatron.yaml | 20 +++++++++---------- rlinf/runners/reasoning_runner.py | 1 + rlinf/workers/actor/fsdp_actor_worker.py | 8 ++++++-- rlinf/workers/actor/megatron_actor_worker.py | 5 ++++- .../sglang/qwen2.5-vl-3b-grpo-collocated.yaml | 4 ++-- .../vllm/qwen2.5-vl-3b-grpo-collocated.yaml | 4 ++-- 6 files changed, 25 insertions(+), 17 deletions(-) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml index 63e972b3a..4d851f894 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml @@ -7,7 +7,7 @@ hydra: output_subdir: null cluster: - num_nodes: 1 + num_nodes: 16 component_placement: actor,rollout,reward: all @@ -25,14 +25,14 @@ runner: val_check_interval: 1 save_interval: 50 - seq_length: 10240 + seq_length: 28672 enable_dynamic_batch_size: False max_tokens_per_mbs: 28672 resume_dir: null - experiment_name: megatron-vllm-1.5b-math-test - output_dir: /mnt/public/daibo/results + experiment_name: grpo-1.5b + output_dir: ../results algorithm: group_size: 16 @@ -84,7 +84,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B/ + model_dir: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ model_arch: qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -93,7 +93,7 @@ rollout: padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. 
- rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] sglang: attention_backend: triton # [flashinfer, triton] for more, see sglang's doc @@ -126,15 +126,15 @@ data: dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 512 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/mnt/public/daibo/dataset/boba_106k_0319_prompt_1024.jsonl"] - val_data_paths: ["/mnt/public/daibo/dataset/boba_106k_0319_prompt_1024.jsonl"] + train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] + val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] actor: group_name: "ActorGroup" @@ -216,7 +216,7 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /mnt/public/hf_models/DeepSeek-R1-Distill-Qwen-1.5B/ + tokenizer_model: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ use_fast: False trust_remote_code: True padding_side: 'right' diff --git a/rlinf/runners/reasoning_runner.py b/rlinf/runners/reasoning_runner.py index d68abf72a..b53010e18 100644 --- a/rlinf/runners/reasoning_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -341,6 +341,7 @@ def run(self): infer_handle: Handle = self.inference.run_inference( input_channel=self.reward_channel, output_channel=self.inference_channel, + rollout_channel=self.rollout_channel, compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index ba35362e0..f4a047766 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -160,8 +160,12 @@ def sync_model_to_rollout(self): if has_visual: if name.startswith("model.language_model."): name = "model." + name[21:] - elif name.startswith("model."): - name = name[6:] + # NOTE: + # if transformers version is 4.56.1 or older(not tested), + # the following line should be uncommented + + # elif name.startswith("model."): + # name = name[6:] state_dict[name] = reduce_tensor(v) self.send( diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index bbac0701c..6736e550b 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -841,6 +841,7 @@ def run_inference( self, input_channel: Channel, output_channel: Channel, + rollout_channel: Channel, compute_ref_logprobs: bool, ): """Compute prev/ref logprobs using the actor Model's forward. 
@@ -856,7 +857,9 @@ def run_inference( recv_batch_size += rollout_result.num_sequence # Must be called after batch is retrieved, suggesting that rollout has stopped # Otherwise, loading model might cause OOM in the collocated mode - self._load_weight_and_optimizer(input_channel) + self._load_weight_and_optimizer( + input_channel if self.is_pipeline else rollout_channel + ) # Prev logprobs with self.worker_timer(): diff --git a/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml b/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml index 30c7a150c..bcbae2f94 100644 --- a/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml +++ b/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml @@ -19,8 +19,8 @@ runner: experiment_name: ${runner.experiment_name} logger_backends: ["tensorboard"] # wandb, swanlab - max_epochs: 5 - max_steps: -1 + max_epochs: 1 + max_steps: 3 val_check_interval: 1 save_interval: 50 diff --git a/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml b/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml index aef3ec271..f555a7292 100644 --- a/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml +++ b/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml @@ -19,8 +19,8 @@ runner: experiment_name: ${runner.experiment_name} logger_backends: ["tensorboard"] # wandb, swanlab - max_epochs: 5 - max_steps: -1 + max_epochs: 1 + max_steps: 3 val_check_interval: 1 save_interval: 50 From 19f2a278489154aca4a198200c3ae6460b2c0c4a Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Fri, 10 Oct 2025 09:08:22 +0000 Subject: [PATCH 27/38] fix(fsdp): use bf16 instead of fp16 for training Signed-off-by: Bo Dai --- .../config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 6 +- rlinf/algorithms/losses.py | 16 ++- rlinf/workers/actor/fsdp_actor_worker.py | 104 +++++------------- rlinf/workers/rollout/vllm/vllm_worker.py | 2 + 4 files changed, 41 insertions(+), 87 deletions(-) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index c4c646808..e1e97c215 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -158,7 +158,7 @@ actor: seed: 1234 model: - precision: fp16 + precision: bf16 sharding_strategy: full_shard is_lora: False @@ -168,8 +168,8 @@ actor: optim: optimizer: adam - bf16: False - fp16: True + bf16: True + fp16: False lr: 2e-05 adam_beta1: 0.9 adam_beta2: 0.95 diff --git a/rlinf/algorithms/losses.py b/rlinf/algorithms/losses.py index 1d66885ea..f1bf025cd 100644 --- a/rlinf/algorithms/losses.py +++ b/rlinf/algorithms/losses.py @@ -233,17 +233,21 @@ def compute_math_ppo_actor_loss(**kwargs): clip_mask = policy_loss1.detach() < policy_loss2.detach() dual_clip_mask.logical_and_(loss_mask) - clip_fraction = clip_mask.logical_and_(loss_mask).count_nonzero() / loss_mask_count - approx_kl = -approx_kl.sum() / loss_mask_count + num_clipped = clip_mask.logical_and_(loss_mask).count_nonzero() + + clip_fraction = num_clipped.float() / float(loss_mask_count) + approx_kl = -approx_kl.sum() / float(loss_mask_count) dual_cliped_ratio = torch.where(dual_clip_mask, ratio, 0) # Compile metrics for logging metrics_data = { - "policy_loss": masked_mean(policy_loss.detach(), loss_mask), - "ratio": masked_mean(ratio.detach(), loss_mask), - "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask), - "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask), + "policy_loss": 
masked_mean(policy_loss.detach(), loss_mask).detach(), + "ratio": masked_mean(ratio.detach(), loss_mask).detach(), + "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask).detach(), + "dual_cliped_ratio": masked_mean( + dual_cliped_ratio.detach(), loss_mask + ).detach(), "approx_kl": approx_kl.detach(), "clip_fraction": clip_fraction.detach(), } diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index f4a047766..8c467d553 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -14,7 +14,7 @@ import gc import os -from typing import Dict, List, Tuple +from typing import Dict, Tuple import numpy as np import torch @@ -142,14 +142,13 @@ def del_reshard_state_dict(self): del self.rollout_state_dict def sync_model_to_rollout(self): + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_optimizer() + if next(self.model.parameters()).is_cpu: self.load_fsdp_param_and_grad(self.device) self.rollout_state_dict = self.get_model_state_dict() - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_param_and_grad(offload_grad=True) - self.offload_fsdp_optimizer() - has_visual = any("visual." in k for k in self.rollout_state_dict.keys()) state_dict = {} @@ -168,9 +167,12 @@ def sync_model_to_rollout(self): # name = name[6:] state_dict[name] = reduce_tensor(v) - self.send( - state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout - ) + self.send( + state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout + ) + + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() def compute_logprobs(self): self.model.eval() @@ -354,22 +356,23 @@ def run_training(self, input_channel: Channel): mbs_metrics_data.update( { - "final_loss": loss.detach().cpu(), - "entropy_loss": entropy_loss.detach().cpu(), - "kl_loss": kl_loss.detach().cpu(), + "final_loss": loss.detach(), + "entropy_loss": entropy_loss.detach(), + "kl_loss": kl_loss.detach(), } ) append_to_dict(metrics, mbs_metrics_data) # apply gradient clipping and optimizer step at the end of a global batch - grad_norm = None - try: - grad_norm = self.model.clip_grad_norm_( - max_norm=self.cfg.actor.optim.clip_grad + grad_norm = self.model.clip_grad_norm_( + max_norm=self.cfg.actor.optim.clip_grad + ) + if not torch.isfinite(grad_norm).all(): + self.log_warning( + "grad norm is not finite, skip this optimizer step." 
) - except Exception: - pass - self.optimizer.step() + else: + self.optimizer.step() self.optimizer.zero_grad() # aggregate metrics across micro-batches @@ -381,16 +384,12 @@ def run_training(self, input_channel: Channel): mean_metric_dict, op=torch.distributed.ReduceOp.AVG ) # add optimizer stats - if grad_norm is not None: - mean_metric_dict["actor/grad_norm"] = ( - torch.as_tensor( - grad_norm - if torch.is_tensor(grad_norm) - else float(grad_norm) - ) - .float() - .cpu() + if torch.is_tensor(grad_norm): + mean_metric_dict["actor/grad_norm"] = float( + grad_norm.detach().item() ) + else: + mean_metric_dict["actor/grad_norm"] = float(grad_norm) lr = self.optimizer.param_groups[0]["lr"] mean_metric_dict["actor/lr"] = torch.as_tensor(lr).float().cpu() training_metrics_list.append(mean_metric_dict) @@ -412,57 +411,6 @@ def save_checkpoint(self, save_base_path, step): torch.save(optim_state, os.path.join(save_base_path, "optim.pt")) torch.distributed.barrier() - def _compute_batch_rewards( - self, batch: Dict[str, torch.Tensor], answers: List[str] - ): - """Reward computation using non-model based reward.""" - texts = [] - for response, response_len in zip( - batch["input_ids"], - batch["response_lengths"], - ): - response = response[ - self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length - + response_len - ] - texts.append( - self.tokenizer.decode(response.tolist(), skip_special_tokens=True) - ) - reward_scores = self.reward.get_reward(texts, answers) - - all_reward_scores = torch.as_tensor( - reward_scores, - dtype=torch.float, - device=torch.device("cpu"), - ).view(-1, 1) - return all_reward_scores.flatten() - - # Rewards - def compute_rewards(self, input_channel: Channel, output_channel: Channel): - """Compute rewards. - - Args: - input_channel: The input channel to read from. - output_channel: The output channel to send results to. 
- """ - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - recv_batch_size += rollout_result.num_sequence - - # Compute rule-based reward - with self.worker_timer(): - if rollout_result.rewards is None: - rollout_result.rewards = self._compute_batch_rewards( - batch, rollout_result.answers - ) - - self.put_result(rollout_result, output_channel) - - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) - # Advantages and returns def compute_advantages_and_returns( self, input_channel: Channel, output_channel: Channel diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index a67d967ed..b3629b170 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -420,6 +420,8 @@ async def rollout_and_return( multi_modal_inputs=request.multi_modal_inputs, return_logprobs=self._return_logprobs, ) + if self._cfg.rollout.print_outputs: + print_vllm_outputs(outputs=vllm_results) await self._put_result(result=rollout_result, output_channel=output_channel) async def rollout(self, input_channel: Channel, output_channel: Channel) -> None: From d67365ca444fef81f00f2f6a5d74e4200e33e0cf Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Fri, 10 Oct 2025 10:09:43 +0000 Subject: [PATCH 28/38] feat(ci): add fsdp ci Signed-off-by: Bo Dai --- .github/workflows/code-test.yml | 77 +++++-- .github/workflows/vqa_e2e.yml | 4 +- examples/reasoning/run_main_grpo_math.sh | 1 - examples/reasoning/run_main_grpo_vqa.sh | 1 - tests/e2e_tests/math/sglang/run_collocated.sh | 17 -- tests/e2e_tests/math/vllm/run_collocated.sh | 17 -- tests/e2e_tests/math/vllm/run_pipeline.sh | 17 -- ...-collocated-fsdp-sgl-rollout-logprobs.yaml | 213 ++++++++++++++++++ ...qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml | 213 ++++++++++++++++++ ...collocated-fsdp-vllm-rollout-logprobs.yaml | 213 ++++++++++++++++++ ...wen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml | 213 ++++++++++++++++++ ...o-collocated-mg-sgl-rollout-logprobs.yaml} | 8 +- .../qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml} | 5 - ...-collocated-mg-vllm-rollout-logprobs.yaml} | 0 ...qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml} | 0 ...rpo-pipeline-mg-sgl-rollout-logprobs.yaml} | 0 .../qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml} | 0 ...po-pipeline-mg-vllm-rollout-logprobs.yaml} | 0 .../qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml} | 0 ...en2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml} | 0 ...n2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml} | 0 .../sglang => reasoning}/run_collocated.sh | 7 +- .../sglang => reasoning}/run_pipeline.sh | 6 +- tests/e2e_tests/vqa/vllm/run_collocated.sh | 17 -- 24 files changed, 917 insertions(+), 112 deletions(-) delete mode 100644 tests/e2e_tests/math/sglang/run_collocated.sh delete mode 100644 tests/e2e_tests/math/vllm/run_collocated.sh delete mode 100644 tests/e2e_tests/math/vllm/run_pipeline.sh create mode 100644 tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml create mode 100644 tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml create mode 100644 tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml create mode 100644 tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml rename tests/e2e_tests/{math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml => 
reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml} (96%) rename tests/e2e_tests/{math/sglang/qwen2.5-1.5b-grpo-collocated.yaml => reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml} (96%) rename tests/e2e_tests/{math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml => reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml} (100%) rename tests/e2e_tests/{math/vllm/qwen2.5-1.5b-grpo-collocated.yaml => reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml} (100%) rename tests/e2e_tests/{math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml => reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml} (100%) rename tests/e2e_tests/{math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml => reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml} (100%) rename tests/e2e_tests/{math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml => reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml} (100%) rename tests/e2e_tests/{math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml => reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml} (100%) rename tests/e2e_tests/{vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml => reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml} (100%) rename tests/e2e_tests/{vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml => reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml} (100%) rename tests/e2e_tests/{vqa/sglang => reasoning}/run_collocated.sh (63%) rename tests/e2e_tests/{math/sglang => reasoning}/run_pipeline.sh (63%) delete mode 100644 tests/e2e_tests/vqa/vllm/run_collocated.sh diff --git a/.github/workflows/code-test.yml b/.github/workflows/code-test.yml index c9c34a6c6..6fd1c1409 100644 --- a/.github/workflows/code-test.yml +++ b/.github/workflows/code-test.yml @@ -113,33 +113,48 @@ jobs: - name: Checkout code uses: actions/checkout@v5 - - name: SGLang Collocated mode + - name: Megatron SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/sglang/run_collocated.sh + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-sgl - - name: vLLM Collocated mode + - name: Megatron vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/vllm/run_collocated.sh + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-vllm - - name: SGLang Pipeline mode + - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/sglang/run_pipeline.sh + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl - - name: vLLM Pipeline mode + - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/vllm/run_pipeline.sh + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm + + - name: FSDO SGLang Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl + + - name: FSDP vLLM Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm + 
reason-qwen-grpo-test-rollout-logprobs: needs: [check-changes] @@ -149,33 +164,47 @@ jobs: - name: Checkout code uses: actions/checkout@v5 - - name: SGLang Collocated mode + - name: Megatron SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/sglang/run_collocated.sh qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs - - name: vLLM Collocated mode + - name: Megatron vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/vllm/run_collocated.sh qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs - - name: SGLang Pipeline mode + - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/sglang/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs - - name: vLLM Pipeline mode + - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reason - bash tests/e2e_tests/math/vllm/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs + + - name: FSDP SGLang Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sglang-rollout-logprobs + + - name: FSDP vLLM Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reasoning + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs coding-online-rl-qwen-ppo-test: needs: [check-changes] diff --git a/.github/workflows/vqa_e2e.yml b/.github/workflows/vqa_e2e.yml index 58ce056e2..50b6217aa 100644 --- a/.github/workflows/vqa_e2e.yml +++ b/.github/workflows/vqa_e2e.yml @@ -55,9 +55,9 @@ jobs: - name: SGLang Collocated mode run: | export REPO_PATH=$(pwd) - bash tests/e2e_tests/vqa/sglang/run_collocated.sh + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl - name: vLLM Collocated mode run: | export REPO_PATH=$(pwd) - bash tests/e2e_tests/vqa/vllm/run_collocated.sh + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm diff --git a/examples/reasoning/run_main_grpo_math.sh b/examples/reasoning/run_main_grpo_math.sh index 56e13c7c2..18a48d780 100644 --- a/examples/reasoning/run_main_grpo_math.sh +++ b/examples/reasoning/run_main_grpo_math.sh @@ -2,7 +2,6 @@ set -x tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 export TOKENIZERS_PARALLELISM=false export RAY_DEDUP_LOGS=0 diff --git a/examples/reasoning/run_main_grpo_vqa.sh b/examples/reasoning/run_main_grpo_vqa.sh index 1b41f415c..3cc526f0e 100644 --- a/examples/reasoning/run_main_grpo_vqa.sh +++ b/examples/reasoning/run_main_grpo_vqa.sh @@ -2,7 +2,6 @@ set -x tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 export 
TOKENIZERS_PARALLELISM=false export RAY_DEDUP_LOGS=0 diff --git a/tests/e2e_tests/math/sglang/run_collocated.sh b/tests/e2e_tests/math/sglang/run_collocated.sh deleted file mode 100644 index 5911653e7..000000000 --- a/tests/e2e_tests/math/sglang/run_collocated.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-collocated" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/run_collocated.sh b/tests/e2e_tests/math/vllm/run_collocated.sh deleted file mode 100644 index 6ce4067fd..000000000 --- a/tests/e2e_tests/math/vllm/run_collocated.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-collocated" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/run_pipeline.sh b/tests/e2e_tests/math/vllm/run_pipeline.sh deleted file mode 100644 index 59fb19454..000000000 --- a/tests/e2e_tests/math/vllm/run_pipeline.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-pipeline" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml new file mode 100644 index 000000000..fd208e21b --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml @@ -0,0 +1,213 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
+ + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
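+  # NOTE: `${not:${algorithm.recompute_logprobs}}` and `${subtract:...}` above are
+  # custom OmegaConf resolvers (assumed to be registered at startup, e.g. via
+  # `OmegaConf.register_new_resolver("not", lambda x: not x)` and
+  # `OmegaConf.register_new_resolver("subtract", lambda a, b: a - b)`), so with
+  # recompute_logprobs: False this config asks the rollout engine to return logprobs
+  # instead of having the actor recompute them.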
+ +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml new file mode 100644 index 000000000..bc7ab77e2 --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml @@ -0,0 +1,213 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
+ + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: True + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
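+  # NOTE: only the backend block matching rollout_backend (here `sglang:`) should
+  # take effect; the `vllm:` block is presumably kept so the backend can be switched
+  # without rewriting the config. Because algorithm.recompute_logprobs is True in
+  # this config, return_logprobs resolves to False and logprobs are recomputed by
+  # the actor's forward pass instead of being taken from the rollout engine.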
+ +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml new file mode 100644 index 000000000..f2b15f9cb --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml @@ -0,0 +1,213 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
+ + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: vllm # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
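+  # NOTE: the vLLM attention backend is selected here (FLASH_ATTN, with XFORMERS
+  # listed as an alternative above) rather than via the VLLM_ATTENTION_BACKEND
+  # environment variable, which the run scripts in this series no longer export
+  # (assumption based on those script changes).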
+ +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml new file mode 100644 index 000000000..19a76629f --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml @@ -0,0 +1,213 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
+ + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: True + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: vllm # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
+ +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml similarity index 96% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml index 3516fe44b..4edc14979 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml @@ -84,11 +84,11 @@ rollout: model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B model_arch: qwen2.5 - enforce_eager: False # if False, vllm will capture cuda graph, which will take more time to initialize. + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp disable_log_stats: False detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. - padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for vllm rollout + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. 
rollout_backend: sglang # [sglang, vllm] @@ -114,9 +114,9 @@ rollout: validate_weight: False # whether to send all weights at first for weight comparison. validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. - print_outputs: False # whether to print the outputs (token ids, texts, etc.) of inference engine. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. - max_running_requests: 64 # the maximum number of running requests in the inference engine. + max_running_requests: 64 # the maximum number of running requests in the rollout engine. cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. data: diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml similarity index 96% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml index 79b5e1595..1854fd8c3 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml @@ -107,7 +107,6 @@ rollout: max_num_batched_tokens: null # the maximum number of tokens to be batched together in vllm. If set to null, vllm will use its default value. torch_profiler_dir: null # if not null, vllm will enable torch profiler and save the result to the specified directory. - return_logprobs: ${not:${algorithm.recompute_logprobs}} tensor_parallel_size: 1 @@ -117,13 +116,9 @@ rollout: validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. - sglang_decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. max_running_requests: 64 # the maximum number of running requests in the rollout engine. cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. - use_torch_compile: False # enable torch_compile in SGLang for rollout. - torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. 
- data: type: math dataset_name: boba diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml similarity index 100% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml similarity index 100% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml similarity index 100% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml similarity index 100% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml similarity index 100% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml similarity index 100% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml diff --git a/tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml similarity index 100% rename from tests/e2e_tests/vqa/sglang/qwen2.5-vl-3b-grpo-collocated.yaml rename to tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml diff --git a/tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml similarity index 100% rename from tests/e2e_tests/vqa/vllm/qwen2.5-vl-3b-grpo-collocated.yaml rename to tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml diff --git a/tests/e2e_tests/vqa/sglang/run_collocated.sh b/tests/e2e_tests/reasoning/run_collocated.sh similarity index 63% rename from tests/e2e_tests/vqa/sglang/run_collocated.sh rename to tests/e2e_tests/reasoning/run_collocated.sh index 793c92417..92e43866f 100644 --- a/tests/e2e_tests/vqa/sglang/run_collocated.sh +++ b/tests/e2e_tests/reasoning/run_collocated.sh @@ -2,16 +2,15 @@ set -x tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 export TOKENIZERS_PARALLELISM=false export PYTHONPATH=${REPO_PATH}:$PYTHONPATH if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-vl-3b-grpo-collocated" + echo "Please provide a config name as the first argument." 
+ exit 1 else CONFIG_NAME=$1 fi - -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/vqa/sglang --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/reasoning/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/run_pipeline.sh b/tests/e2e_tests/reasoning/run_pipeline.sh similarity index 63% rename from tests/e2e_tests/math/sglang/run_pipeline.sh rename to tests/e2e_tests/reasoning/run_pipeline.sh index f18012bb4..3ca1574f0 100644 --- a/tests/e2e_tests/math/sglang/run_pipeline.sh +++ b/tests/e2e_tests/reasoning/run_pipeline.sh @@ -2,16 +2,16 @@ set -x tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 export TOKENIZERS_PARALLELISM=false export PYTHONPATH=${REPO_PATH}:$PYTHONPATH if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-pipeline" + echo "Please provide a config name as the first argument." + exit 1 else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/reasoning --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/vqa/vllm/run_collocated.sh b/tests/e2e_tests/vqa/vllm/run_collocated.sh deleted file mode 100644 index ab406c4ea..000000000 --- a/tests/e2e_tests/vqa/vllm/run_collocated.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-vl-3b-grpo-collocated" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/vqa/vllm --config-name $CONFIG_NAME \ No newline at end of file From 01f95ff36ea5d4d72d880ba25680331fd06fcfb6 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Fri, 10 Oct 2025 15:15:26 +0000 Subject: [PATCH 29/38] feat(fsdp): fix ci, add fsdp optimizations like overlap and gradient accumulation Signed-off-by: Bo Dai --- .../config/libero_10_grpo_openvlaoft.yaml | 6 ++++++ .../libero_10_grpo_openvlaoft_eval.yaml | 6 ++++++ .../config/libero_10_ppo_openvlaoft.yaml | 6 ++++++ .../config/libero_goal_grpo_openvlaoft.yaml | 6 ++++++ .../config/libero_object_grpo_openvlaoft.yaml | 6 ++++++ .../libero_spatial_grpo_openvlaoft.yaml | 6 ++++++ .../config/maniskill_grpo_openvla.yaml | 6 ++++++ .../config/maniskill_grpo_openvlaoft.yaml | 6 ++++++ .../config/maniskill_ppo_openvla.yaml | 6 ++++++ .../maniskill_ppo_openvla_quickstart.yaml | 6 ++++++ .../config/maniskill_ppo_openvlaoft.yaml | 6 ++++++ .../config/robotwin_ppo_openvlaoft.yaml | 6 ++++++ .../config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 8 ++++++- .../config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml | 18 ++++++++++------ rlinf/config.py | 21 +++++++++++++++++++ .../hybrid_engines/fsdp/fsdp_model_manager.py | 15 ++++++++++++- rlinf/workers/actor/fsdp_actor_worker.py | 11 ++++++++-- .../embodied/maniskill_ppo_openvla.yaml | 6 ++++++ ...-collocated-fsdp-sgl-rollout-logprobs.yaml | 6 ++++++ ...qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml | 6 ++++++ ...collocated-fsdp-vllm-rollout-logprobs.yaml | 6 ++++++ ...wen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml | 6 ++++++ ...wen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml 
| 6 ++++++ ...en2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml | 6 ++++++ 24 files changed, 177 insertions(+), 10 deletions(-) diff --git a/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml index 1d9720fdd..525b1c951 100644 --- a/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml @@ -157,6 +157,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml b/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml index 628272ed1..69708e7c0 100644 --- a/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml +++ b/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml @@ -158,6 +158,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml b/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml index bf7e31667..15b33bbbc 100644 --- a/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml @@ -152,6 +152,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml index d6356c314..699d7ab40 100644 --- a/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml index b767fd25a..d2bacd05e 100644 --- a/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml index 69aec20eb..9469166d9 100644 --- a/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_grpo_openvla.yaml b/examples/embodiment/config/maniskill_grpo_openvla.yaml index 16dc2af06..3679d533d 100644 --- a/examples/embodiment/config/maniskill_grpo_openvla.yaml +++ b/examples/embodiment/config/maniskill_grpo_openvla.yaml @@ 
-158,6 +158,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 1.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml b/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml index def45aafb..7bd32855b 100644 --- a/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 10.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_ppo_openvla.yaml b/examples/embodiment/config/maniskill_ppo_openvla.yaml index 6aeae632d..bf9cab8eb 100644 --- a/examples/embodiment/config/maniskill_ppo_openvla.yaml +++ b/examples/embodiment/config/maniskill_ppo_openvla.yaml @@ -154,6 +154,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml b/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml index 969dc85cb..b81d1390f 100644 --- a/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml +++ b/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml @@ -170,6 +170,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 1.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml b/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml index 1e5893f43..f957f1997 100644 --- a/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml +++ b/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml @@ -159,6 +159,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml b/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml index 30700e1c2..56d936659 100644 --- a/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml +++ b/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml @@ -158,6 +158,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index e1e97c215..6a7bef298 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -25,7 +25,7 @@ runner: val_check_interval: 1 save_interval: 50 - seq_length: 2048 + seq_length: 28672 enable_dynamic_batch_size: False max_tokens_per_mbs: 28672 @@ -199,6 +199,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml 
b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml index 6643e74bd..8d32a6a33 100644 --- a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml @@ -84,7 +84,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct + model_dir: /path/to/model/Qwen2.5-VL-3B-Instruct model_arch: qwen2.5_vl #qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -137,8 +137,8 @@ data: shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/mnt/public/guozhen/data/robo2vlm/train/"] - val_data_paths: ["/mnt/public/guozhen/data/robo2vlm/test/"] + train_data_paths: ["/dataset/robo2vlm-1/data/train/"] + val_data_paths: ["/dataset/robo2vlm-1/data/val/"] actor: group_name: "ActorGroup" @@ -165,7 +165,7 @@ actor: seq_length: ${runner.seq_length} encoder_seq_length: ${runner.seq_length} - model_path: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct/ + model_path: /path/to/model/Qwen2.5-VL-3B-Instruct/ model_arch: ${rollout.model_arch} @@ -197,11 +197,17 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct + tokenizer_model: /path/to/model/Qwen2.5-VL-3B-Instruct use_fast: False trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false @@ -213,7 +219,7 @@ reward: answer_format: 0.0 tokenizer: - tokenizer_model: /mnt/public/hf_models/Qwen2.5-VL-3B-Instruct + tokenizer_model: /path/to/model/Qwen2.5-VL-3B-Instruct use_fast: False trust_remote_code: True padding_side: 'right' diff --git a/rlinf/config.py b/rlinf/config.py index 9c31cd49f..16f69f2e8 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -37,6 +37,7 @@ SUPPORTED_MODEL_ARCHS = ["qwen2.5", "qwen2.5_vl", "openvla", "openvla_oft"] SUPPORTED_ROLLOUT_BACKENDS = ["sglang", "vllm"] SUPPORTED_TASK_TYPE = ["embodied", "reasoning", "coding_online_rl"] +SUPPORTED_TRAINING_BACKENDS = ["megatron", "fsdp"] __all__ = ["build_config"] @@ -222,6 +223,16 @@ def validate_model_cfg_by_hf_config(cfg, hf_model_path): return cfg +def validate_fsdp_cfg(cfg: DictConfig) -> DictConfig: + OmegaConf.set_struct(cfg, True) + with open_dict(cfg): + cfg.fsdp.forward_prefetch = cfg.fsdp.get("forward_prefetch", False) + cfg.fsdp.limit_all_gathers = cfg.fsdp.get("limit_all_gathers", False) + cfg.fsdp.backward_prefetch = cfg.fsdp.get("backward_prefetch", False) + cfg.fsdp.use_orig_params = cfg.fsdp.get("use_orig_params", False) + return cfg + + def validate_megatron_cfg(cfg: DictConfig) -> DictConfig: OmegaConf.set_struct(cfg, True) @@ -624,13 +635,23 @@ def validate_cfg(cfg: DictConfig) -> DictConfig: ): assert cfg.algorithm.group_size > 1 + assert cfg.actor.training_backend in SUPPORTED_TRAINING_BACKENDS, ( + f"Unsupported training_backend {cfg.actor.training_backend}. Supported training backends are {SUPPORTED_TRAINING_BACKENDS}." 
+ ) + if cfg.actor.training_backend == "megatron": cfg.actor = validate_megatron_cfg(cfg.actor) cfg.actor = validate_model_cfg_by_hf_config(cfg.actor, cfg.rollout.model_dir) + elif cfg.actor.training_backend == "fsdp": + cfg.actor = validate_fsdp_cfg(cfg.actor) + cfg.actor = validate_model_cfg_by_hf_config(cfg.actor, cfg.rollout.model_dir) if cfg.critic.use_critic_model and cfg.critic.training_backend == "megatron": cfg.critic = validate_megatron_cfg(cfg.critic) cfg = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) + elif cfg.critic.use_critic_model and cfg.critic.training_backend == "fsdp": + cfg.critic = validate_fsdp_cfg(cfg.critic) + cfg.critic = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) return cfg diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py index c3bd9475a..9679514c0 100644 --- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py @@ -17,8 +17,13 @@ import torch import torch.optim as optim from omegaconf import DictConfig +from torch.distributed.fsdp import ( + BackwardPrefetch, + MixedPrecision, + ShardingStrategy, + StateDictType, +) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq from rlinf.config import torch_dtype_from_precision @@ -123,6 +128,14 @@ def setup_model_and_optimizer(self): sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, + forward_prefetch=self._cfg.fsdp.forward_prefetch, + backward_prefetch=( + BackwardPrefetch.BACKWARD_PRE + if self._cfg.fsdp.backward_prefetch + else BackwardPrefetch.NONE + ), + limit_all_gathers=self._cfg.fsdp.limit_all_gathers, + use_orig_params=self._cfg.fsdp.use_orig_params, ) # NOTE: Currently we assume that only the value head contains "value_head" in its name. 
diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 8c467d553..3276cd744 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -14,6 +14,7 @@ import gc import os +from contextlib import nullcontext from typing import Dict, Tuple import numpy as np @@ -258,7 +259,12 @@ def run_training(self, input_channel: Channel): self.optimizer.zero_grad() metrics = {} - for _, m_batch in enumerate(train_micro_batches): + for idx, m_batch in enumerate(train_micro_batches): + backward_ctx = ( + self.model.no_sync() + if idx < self.gradient_accumulation - 1 + else nullcontext() + ) for k, v in m_batch.items(): m_batch[k] = v.cuda() if isinstance(v, torch.Tensor) else v @@ -352,7 +358,8 @@ def run_training(self, input_channel: Channel): # add to log # scale loss for gradient accumulation and backprop loss = loss / self.gradient_accumulation - loss.backward() + with backward_ctx: + loss.backward() mbs_metrics_data.update( { diff --git a/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml b/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml index 2063ab483..5d9258743 100644 --- a/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml +++ b/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml @@ -165,6 +165,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 1.0 + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: use_reward_model: False diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml index fd208e21b..658a730ca 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml @@ -193,6 +193,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml index bc7ab77e2..29ac00d07 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml @@ -193,6 +193,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml index f2b15f9cb..6b94ddf21 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml @@ -193,6 +193,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml index 19a76629f..ad06eabf3 
100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml @@ -193,6 +193,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml index bcbae2f94..45575d9e7 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml @@ -202,6 +202,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml index f555a7292..870731fbb 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml @@ -202,6 +202,12 @@ actor: trust_remote_code: True padding_side: 'right' + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: group_name: "RewardGroup" use_reward_model: false From a8023c82d16c80f2fcc43acb5872117b79640fc9 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Sat, 11 Oct 2025 09:35:28 +0000 Subject: [PATCH 30/38] fix(ci): add fsdp's run_inference, fix ci Signed-off-by: Bo Dai --- .github/workflows/code-test.yml | 24 ++-- .../config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 2 +- rlinf/config.py | 2 +- rlinf/workers/actor/fsdp_actor_worker.py | 103 ++++++++++++++++-- rlinf/workers/actor/megatron_actor_worker.py | 4 +- rlinf/workers/reward/reward_worker.py | 16 +-- ...-collocated-fsdp-sgl-rollout-logprobs.yaml | 2 +- ...qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml | 2 +- ...collocated-fsdp-vllm-rollout-logprobs.yaml | 2 +- ...wen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml | 2 +- ...wen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml | 2 +- ...en2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml | 2 +- 12 files changed, 119 insertions(+), 44 deletions(-) diff --git a/.github/workflows/code-test.yml b/.github/workflows/code-test.yml index 6fd1c1409..075b7ac8d 100644 --- a/.github/workflows/code-test.yml +++ b/.github/workflows/code-test.yml @@ -117,42 +117,42 @@ jobs: timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-sgl - name: Megatron vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-vllm - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm - name: FSDO SGLang 
Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm @@ -168,42 +168,42 @@ jobs: timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs - name: Megatron vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs - name: FSDP SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sglang-rollout-logprobs - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) - source switch_env reasoning + source switch_env reason bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs coding-online-rl-qwen-ppo-test: diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index 6a7bef298..19281cbc9 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -123,7 +123,7 @@ data: dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 512 val_rollout_batch_size: null num_workers: 2 shuffle: True diff --git a/rlinf/config.py b/rlinf/config.py index 16f69f2e8..ae7b85985 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -186,7 +186,7 @@ def validate_vllm_cfg(cfg): def validate_model_cfg_by_hf_config(cfg, hf_model_path): # validate by hf config - hf_config = AutoConfig.from_pretrained(hf_model_path) + hf_config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True) if "Qwen2ForCausalLM" in hf_config.architectures: qkv_bias = True diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 3276cd744..3b18c8e1e 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -56,7 +56,9 @@ from rlinf.utils.utils import ( compute_entropy_from_logits, compute_logprobs_from_logits, + cpu_weight_swap, masked_mean, + retrieve_model_state_dict_in_cpu, seq_mean_token_mean, seq_mean_token_sum, ) @@ -96,6 +98,8 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): self._rollout_group_name = cfg.rollout.group_name self._component_placement = placement 
self.is_data_io_rank = True + self.is_pipeline = self._component_placement.is_disaggregated + self.ref_policy_state_dict = None if self.cfg.algorithm.loss_agg_func == "token-mean": self.loss_agg_func = masked_mean @@ -118,8 +122,13 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): reward_cls = get_reward_class(self.cfg.reward.reward_type) self.reward = reward_cls(self.cfg.reward) - def init_worker(self): + def init_worker(self) -> None: self.setup_model_and_optimizer() + if self.cfg.algorithm.kl_beta > 0 and self.cfg.actor.get( + "combine_reference_model", True + ): + self.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(self.model) + if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() @@ -128,7 +137,7 @@ def init_worker(self): torch.cuda.empty_cache() self._setup_rollout_weight_dst_ranks() - def _setup_rollout_weight_dst_ranks(self): + def _setup_rollout_weight_dst_ranks(self) -> None: """Setup destination ranks for token and weight communication.""" rank_map = RankMapper.get_actor_rank_to_rollout_rank_map( self._component_placement @@ -138,11 +147,11 @@ def _setup_rollout_weight_dst_ranks(self): f"Actor rank {self._rank} will send weights to {self._weight_dst_rank_in_rollout}" ) - def del_reshard_state_dict(self): + def del_reshard_state_dict(self) -> None: if hasattr(self, "rollout_state_dict"): del self.rollout_state_dict - def sync_model_to_rollout(self): + def sync_model_to_rollout(self) -> None: if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_optimizer() @@ -175,7 +184,7 @@ def sync_model_to_rollout(self): if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() - def compute_logprobs(self): + def compute_logprobs(self) -> None: self.model.eval() self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"] @@ -191,7 +200,7 @@ def get_batch( ) return batch, result - def put_result(self, result: RolloutResult, channel: Channel): + def put_result(self, result: RolloutResult, channel: Channel) -> None: if channel.is_local: # Local channel, every process will put its own data locally # No need to broadcast @@ -200,7 +209,7 @@ def put_result(self, result: RolloutResult, channel: Channel): if self.is_data_io_rank: channel.put(result) - def _load_weight_and_optimizer(self, channel: Channel): + def _load_weight_and_optimizer(self, channel: Channel) -> None: # Acquire the GPUs to ensure that no one is using them before loading models # Otherwise, it may lead to OOM with channel.device_lock: @@ -208,7 +217,81 @@ def _load_weight_and_optimizer(self, channel: Channel): self.load_fsdp_param_and_grad(self.device) self.load_fsdp_optimizer(self.device) - def run_training(self, input_channel: Channel): + @torch.no_grad() + def inference_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + self.model.eval() + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] + + multi_modal_inputs = {} + if "multi_modal_inputs" in batch.keys(): + for key in batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in batch["multi_modal_inputs"]], + dim=0, + ).cuda() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + **multi_modal_inputs, + ) + + logits = outputs.logits + logits = logits[:, -self.response_len - 1 : -1, :] + logits = logits / 
self.cfg.algorithm.sampling_params.temperature + + responses = input_ids[:, -self.response_len :] + logprobs = compute_logprobs_from_logits( + logits, responses, task_type=self.cfg.runner.task_type + ) + return logprobs + + def run_inference( + self, + input_channel: Channel, + output_channel: Channel, + rollout_channel: Channel, + compute_ref_logprobs: bool, + ) -> None: + """ + Compute prev/ref logprobs using the actor Model's forward. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + rollout_channel: get the rollout channel's device lock in case of collision. + compute_ref_logprobs: Whether to compute reference logprobs. + """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + self._load_weight_and_optimizer( + input_channel if self.is_pipeline else rollout_channel + ) + + with self.worker_timer(): + prev_logprobs = self.inference_step(batch) + rollout_result.prev_logprobs = prev_logprobs.cpu() + + if compute_ref_logprobs: + assert self.ref_policy_state_dict is not None, ( + "Reference policy state dict is None but compute_ref_logprobs is True" + ) + with cpu_weight_swap(self.model, self.ref_policy_state_dict): + ref_logprobs = self.inference_step(batch) + rollout_result.ref_logprobs = ref_logprobs.cpu() + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: # Get all batches for this DP batches = [] recv_batch_size = 0 @@ -408,7 +491,7 @@ def run_training(self, input_channel: Channel): return rollout_metrics, training_metrics_list - def save_checkpoint(self, save_base_path, step): + def save_checkpoint(self, save_base_path: str, step: int) -> None: torch.distributed.barrier() model_state = self.get_model_state_dict() optim_state = self.get_optimizer_state_dict() @@ -421,7 +504,7 @@ def save_checkpoint(self, save_base_path, step): # Advantages and returns def compute_advantages_and_returns( self, input_channel: Channel, output_channel: Channel - ): + ) -> None: """Compute the advantages and returns. Args: diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index 6736e550b..c7129a321 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -844,11 +844,13 @@ def run_inference( rollout_channel: Channel, compute_ref_logprobs: bool, ): - """Compute prev/ref logprobs using the actor Model's forward. + """ + Compute prev/ref logprobs using the actor Model's forward. Args: input_channel: The input channel to read from. output_channel: The output channel to send results to. + rollout_channel: get the rollout channel's device lock in case of collision. compute_ref_logprobs: Whether to compute reference logprobs. 
""" recv_batch_size = 0 diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index e186eebac..eea493117 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -19,15 +19,13 @@ from rlinf.algorithms.rewards import get_reward_class from rlinf.data.io_struct import RolloutResult -from rlinf.hybrid_engines.fsdp.fsdp_model_manager import FSDPModelManager from rlinf.scheduler import Channel, Worker from rlinf.utils.placement import ModelParallelComponentPlacement -class RewardWorker(FSDPModelManager, Worker): +class RewardWorker(Worker): def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): Worker.__init__(self) - super().__init__(cfg.reward) self.cfg = cfg self.component_placement = placement @@ -39,9 +37,7 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): def init_worker(self): if self.cfg.reward.use_reward_model: - self.setup_model_and_optimizer() - self.offload_fsdp_param_and_grad() - self.offload_fsdp_optimizer() + raise NotImplementedError("Reward model is not implemented yet.") else: self.reward = get_reward_class(self.cfg.reward.reward_type)(self.cfg.reward) @@ -106,10 +102,4 @@ def _compute_rule_based_rewards(self, rollout_result: RolloutResult): ) def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]): - self.model.eval() - with torch.no_grad(): - # TODO: fix this - rewards = ( - self.model(batch["input_ids"], batch["attention_mask"]).detach().cpu() - ) - return rewards + raise NotImplementedError("Reward model is not implemented yet.") diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml index 658a730ca..86c6f362c 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml @@ -123,7 +123,7 @@ data: dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml index 29ac00d07..dbf7925f2 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml @@ -123,7 +123,7 @@ data: dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml index 6b94ddf21..e85fd4146 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml @@ -123,7 +123,7 @@ data: dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml 
b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml index ad06eabf3..1aab4cb7b 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml @@ -123,7 +123,7 @@ data: dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml index 45575d9e7..46278b949 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml @@ -124,7 +124,7 @@ data: dataset_name: robo2vlm max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml index 870731fbb..fb7c8aa02 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml @@ -124,7 +124,7 @@ data: dataset_name: robo2vlm max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 8 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 prompt_key: prompt From 803c4c68be2b43b87611209f824464b32ae6791b Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Sun, 12 Oct 2025 12:20:23 +0000 Subject: [PATCH 31/38] fix(ci): fix some errors Signed-off-by: Bo Dai --- .github/workflows/code-test.yml | 28 +++- .github/workflows/vqa_e2e.yml | 63 -------- rlinf/config.py | 4 +- .../hybrid_engines/fsdp/fsdp_model_manager.py | 38 ++--- .../hybrid_engines/vllm/vllm_0_8_5/worker.py | 13 ++ rlinf/runners/coding_online_rl_runner.py | 1 + rlinf/workers/actor/fsdp_actor_worker.py | 8 -- rlinf/workers/actor/megatron_actor_worker.py | 6 +- rlinf/workers/reward/reward_worker.py | 3 +- .../coding_online_rl/qwen2.5-1.5b-ppo.yaml | 1 + .../embodied/libero_130_grpo_openvlaoft.yaml | 6 + tests/unit_tests/test_io_struct.py | 135 ------------------ 12 files changed, 75 insertions(+), 231 deletions(-) delete mode 100644 .github/workflows/vqa_e2e.yml delete mode 100644 tests/unit_tests/test_io_struct.py diff --git a/.github/workflows/code-test.yml b/.github/workflows/code-test.yml index 075b7ac8d..830ee82e7 100644 --- a/.github/workflows/code-test.yml +++ b/.github/workflows/code-test.yml @@ -141,7 +141,7 @@ jobs: source switch_env reason bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm - - name: FSDO SGLang Collocated mode + - name: FSDP SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) @@ -197,7 +197,7 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sglang-rollout-logprobs + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs - name: FSDP vLLM Collocated mode timeout-minutes: 20 @@ -225,6 +225,28 @@ jobs: source switch_env reason bash tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh + qwen-vl-grpo-test: + needs: [check-changes] + if: needs.check-changes.outputs.file_filter == 'true' + runs-on: reason + steps: + - 
name: Checkout code + uses: actions/checkout@v5 + + - name: FSDP SGLang Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl + + - name: FSDP vLLM Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm + # =============================================== embodied e2e tests ==================================================== embodied-maniskill-ppo-openvla-test: @@ -312,7 +334,7 @@ jobs: # Reason e2e tests reason-qwen-grpo-test, reason-qwen-grpo-test-rollout-logprobs, - coding-online-rl-qwen-ppo-test, + coding-online-rl-qwen-ppo-test, qwen-vl-grpo-test, # Embodied e2e tests embodied-maniskill-ppo-openvla-test, embodied-maniskill-grpo-openvlaoft-test, embodied-libero-goal-grpo-openvlaoft-test,embodied-libero-130-grpo-openvlaoft-test, diff --git a/.github/workflows/vqa_e2e.yml b/.github/workflows/vqa_e2e.yml deleted file mode 100644 index 50b6217aa..000000000 --- a/.github/workflows/vqa_e2e.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: VQA End2End - -on: - push: - branches: - - 'release/v[0-9].[0-9]' - - main - paths: - - '**/*.py' - - 'tests/**' - - '.github/workflows/*.yml' - - '!docs/**' - - '!README.md' - - '!*.yaml' - - '!*.toml' - - '!ray_utils/**' - - '!requirements/**' - - pull_request: - branches: - - 'release/v[0-9].[0-9]' - - main - paths: - - '**/*.py' - - 'tests/**' - - '.github/workflows/*.yml' - - '!docs/**' - - '!README.md' - - '*.yaml' - - '*.toml' - - '!ray_utils/**' - - '!requirements/**' - -permissions: - contents: read - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - qwen-vl-grpo-test: - runs-on: rlinf - container: - image: rlinf/rlinf:math-rlinf0.1-torch2.6.0-sglang0.4.6.post5-vllm0.8.5-megatron0.13.0-te2.1 - volumes: - - /mnt/public/dataset:/workspace/dataset - - /mnt/public/tokenizer:/workspace/tokenizer - options: --gpus="all" --shm-size=80g - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: SGLang Collocated mode - run: | - export REPO_PATH=$(pwd) - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl - - - name: vLLM Collocated mode - run: | - export REPO_PATH=$(pwd) - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm diff --git a/rlinf/config.py b/rlinf/config.py index ae7b85985..05de15254 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -644,14 +644,12 @@ def validate_cfg(cfg: DictConfig) -> DictConfig: cfg.actor = validate_model_cfg_by_hf_config(cfg.actor, cfg.rollout.model_dir) elif cfg.actor.training_backend == "fsdp": cfg.actor = validate_fsdp_cfg(cfg.actor) - cfg.actor = validate_model_cfg_by_hf_config(cfg.actor, cfg.rollout.model_dir) if cfg.critic.use_critic_model and cfg.critic.training_backend == "megatron": cfg.critic = validate_megatron_cfg(cfg.critic) - cfg = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) + cfg.critic = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) elif cfg.critic.use_critic_model and cfg.critic.training_backend == "fsdp": cfg.critic = validate_fsdp_cfg(cfg.critic) - cfg.critic = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) return cfg diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py index 
9679514c0..aad2c130f 100644 --- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py @@ -49,22 +49,35 @@ def __init__(self, cfg: DictConfig): self.tokenizer = hf_tokenizer(cfg.tokenizer.tokenizer_model) def model_provider_func(self) -> torch.nn.Module: + cfg = self._cfg + use_gptq = cfg.model.get("gptq_model", False) + load_in_8bit = cfg.model.get("load_in_8bit", False) + + use_triton = cfg.get("use_triton", True) + + assert torch.cuda.is_available(), "CUDA is not available." + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + device = torch.device(f"cuda:{local_rank}") + model_config = AutoConfig.from_pretrained( - self._cfg.model.model_path, + cfg.model.model_path, trust_remote_code=True, attn_implementation="flash_attention_2", ) - if self._cfg.model.get("gptq_model", False): + if use_gptq: from auto_gptq import AutoGPTQForCausalLM model_wrapper = AutoGPTQForCausalLM.from_quantized( - self._cfg.model.model_path, device="cuda:0", use_triton=True + cfg.model.model_path, + device=device, + use_triton=use_triton, ) model = model_wrapper.model - elif self._cfg.model.get("load_in_8bit", False): + elif load_in_8bit: model = AutoModelForCausalLM.from_pretrained( - self._cfg.model.model_path, + cfg.model.model_path, + config=model_config, load_in_8bit=True, ) else: @@ -73,22 +86,15 @@ def model_provider_func(self) -> torch.nn.Module: else: auto_model_class = AutoModelForCausalLM - # default load in float16 model = auto_model_class.from_pretrained( - self._cfg.model.model_path, + cfg.model.model_path, torch_dtype=self.torch_dtype, config=model_config, trust_remote_code=True, ) - model.to(self.torch_dtype) - - if torch.cuda.is_available(): - model = model.cuda() - if self.torch_dtype == torch.float16: - model = model.half() - - torch.distributed.barrier() + if torch.distributed.is_initialized(): + torch.distributed.barrier() return model def setup_model_and_optimizer(self): @@ -132,7 +138,7 @@ def setup_model_and_optimizer(self): backward_prefetch=( BackwardPrefetch.BACKWARD_PRE if self._cfg.fsdp.backward_prefetch - else BackwardPrefetch.NONE + else None ), limit_all_gathers=self._cfg.fsdp.limit_all_gathers, use_orig_params=self._cfg.fsdp.use_orig_params, diff --git a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py index 519895e49..104334c1a 100644 --- a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +++ b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py @@ -75,6 +75,19 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: super().initialize_from_config(kv_cache_config) def offload_model_weights(self) -> None: + torch.cuda.synchronize() + + model = self.model_runner.model + with torch.no_grad(): + for mod in model.modules(): + for name, buf in list(getattr(mod, "_buffers", {}).items()): + if isinstance(buf, torch.Tensor) and buf.is_cuda: + cpu_buf = ( + buf.detach().to("cpu", non_blocking=False).contiguous() + ) + mod._buffers[name] = cpu_buf + torch.cuda.empty_cache() + super().sleep(level=2) def sync_hf_weight(self) -> None: diff --git a/rlinf/runners/coding_online_rl_runner.py b/rlinf/runners/coding_online_rl_runner.py index 46be3423a..4ea2ec024 100644 --- a/rlinf/runners/coding_online_rl_runner.py +++ b/rlinf/runners/coding_online_rl_runner.py @@ -212,6 +212,7 @@ def run(self): infer_handle: Handle = self.inference.run_inference( input_channel=self.dataloader_channel, output_channel=self.inference_channel, + rollout_channel=None, 
compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 3b18c8e1e..17b9f42ec 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -54,7 +54,6 @@ ModelParallelComponentPlacement, ) from rlinf.utils.utils import ( - compute_entropy_from_logits, compute_logprobs_from_logits, cpu_weight_swap, masked_mean, @@ -132,9 +131,6 @@ def init_worker(self) -> None: if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() - torch.cuda.synchronize() - gc.collect() - torch.cuda.empty_cache() self._setup_rollout_weight_dst_ranks() def _setup_rollout_weight_dst_ranks(self) -> None: @@ -391,10 +387,6 @@ def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: logprobs = compute_logprobs_from_logits( logits, responses, task_type=self.cfg.runner.task_type ) - if self.calculate_entropy: - entropy = compute_entropy_from_logits( - logits, task_type=self.cfg.runner.task_type - ) # (bsz, response_length) clip_ratio = self.cfg.algorithm.ratio_clip_eps clip_ratio_low = ( diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index c7129a321..f7bb3631e 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -841,7 +841,7 @@ def run_inference( self, input_channel: Channel, output_channel: Channel, - rollout_channel: Channel, + rollout_channel: Optional[Channel], compute_ref_logprobs: bool, ): """ @@ -860,7 +860,9 @@ def run_inference( # Must be called after batch is retrieved, suggesting that rollout has stopped # Otherwise, loading model might cause OOM in the collocated mode self._load_weight_and_optimizer( - input_channel if self.is_pipeline else rollout_channel + input_channel + if self.is_pipeline or rollout_channel is None + else rollout_channel ) # Prev logprobs diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index eea493117..fd2472611 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -19,6 +19,7 @@ from rlinf.algorithms.rewards import get_reward_class from rlinf.data.io_struct import RolloutResult +from rlinf.data.tokenizers import hf_tokenizer from rlinf.scheduler import Channel, Worker from rlinf.utils.placement import ModelParallelComponentPlacement @@ -28,7 +29,7 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): Worker.__init__(self) self.cfg = cfg self.component_placement = placement - + self.tokenizer = hf_tokenizer(cfg.reward.tokenizer.tokenizer_model) self.total_batch_size_per_dp = ( self.cfg.data.rollout_batch_size * self.cfg.algorithm.get("group_size", 1) diff --git a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml index 1de266648..a87c3a474 100644 --- a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml +++ b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml @@ -12,6 +12,7 @@ cluster: rollout: 0-3 inference: 4-5 actor: 6-7 + reward: 0-3 runner: task_type: coding_online_rl diff --git a/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml index 2365b5a86..2672dcd01 100644 --- a/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml +++ 
b/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml @@ -157,6 +157,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: use_reward_model: False diff --git a/tests/unit_tests/test_io_struct.py b/tests/unit_tests/test_io_struct.py deleted file mode 100644 index 7104c277a..000000000 --- a/tests/unit_tests/test_io_struct.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2025 The RLinf Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch - -from rlinf.data.io_struct import RolloutRequest, RolloutResult - - -def test_rollout_request_repeat_preserves_multimodal(): - request = RolloutRequest( - n=2, - input_ids=[[1, 2, 3], [4, 5]], - image_data=[[b"img1-1", b"img1-2"], []], - answers=["ans1", "ans2"], - multi_modal_inputs=[{"pixels": [1, 2]}, {"pixels": [3]}], - ) - - repeated = request.repeat() - - assert repeated.n == 2 - assert repeated.input_ids == [[1, 2, 3], [1, 2, 3], [4, 5], [4, 5]] - assert repeated.answers == ["ans1", "ans1", "ans2", "ans2"] - assert repeated.image_data == [ - [b"img1-1", b"img1-2"], - [b"img1-1", b"img1-2"], - [], - [], - ] - assert repeated.multi_modal_inputs == [ - {"pixels": [1, 2]}, - {"pixels": [1, 2]}, - {"pixels": [3]}, - {"pixels": [3]}, - ] - - -def _make_rollout_result(): - num_sequence = 4 - group_size = 2 - return RolloutResult( - num_sequence=num_sequence, - group_size=group_size, - prompt_lengths=[3, 3, 4, 4], - prompt_ids=[[11, 12, 13], [11, 12, 13], [21, 22, 23, 24], [21, 22, 23, 24]], - response_lengths=[2, 2, 2, 2], - response_ids=[[101, 102], [201, 202], [301, 302], [401, 402]], - is_end=[True, False, True, True], - answers=[{"answer": "a"}, {"answer": "b"}, {"answer": "c"}, {"answer": "d"}], - image_data=[[b"a"], [b"b"], [b"c"], [b"d"]], - multi_modal_inputs=[ - {"vision": "img-a"}, - {"vision": "img-b"}, - {"vision": "img-c"}, - {"vision": "img-d"}, - ], - prompt_texts=["prompt-a", "prompt-a", "prompt-b", "prompt-b"], - response_texts=["resp-a1", "resp-a2", "resp-b1", "resp-b2"], - rollout_logprobs=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]], - rewards=torch.tensor([[1.0], [0.5], [0.2], [0.1]]), - advantages=[0.1, 0.2, 0.3, 0.4], - prev_logprobs=torch.tensor( - [ - [0.01, 0.02], - [0.03, 0.04], - [0.05, 0.06], - [0.07, 0.08], - ] - ), - ref_logprobs=torch.tensor( - [ - [0.11, 0.12], - [0.13, 0.14], - [0.15, 0.16], - [0.17, 0.18], - ] - ), - ) - - -def test_rollout_result_split_and_merge_roundtrip(): - result = _make_rollout_result() - - split_results = RolloutResult.split_result_list_by_group([result]) - - assert len(split_results) == result.num_sequence // result.group_size - first, second = split_results - - assert first.num_sequence == result.group_size - assert second.num_sequence == result.group_size - assert first.prompt_ids == result.prompt_ids[: result.group_size] - assert second.prompt_ids == result.prompt_ids[result.group_size :] - assert first.response_ids == result.response_ids[: 
result.group_size] - assert second.response_ids == result.response_ids[result.group_size :] - assert first.prompt_texts == result.prompt_texts[: result.group_size] - assert second.prompt_texts == result.prompt_texts[result.group_size :] - assert first.response_texts == result.response_texts[: result.group_size] - assert second.response_texts == result.response_texts[result.group_size :] - assert first.image_data == result.image_data[: result.group_size] - assert second.image_data == result.image_data[result.group_size :] - assert first.multi_modal_inputs == result.multi_modal_inputs[: result.group_size] - assert second.multi_modal_inputs == result.multi_modal_inputs[result.group_size :] - assert first.rollout_logprobs == result.rollout_logprobs[: result.group_size] - assert second.rollout_logprobs == result.rollout_logprobs[result.group_size :] - assert torch.equal(first.rewards, result.rewards[: result.group_size]) - assert torch.equal(second.rewards, result.rewards[result.group_size :]) - assert first.advantages == result.advantages[: result.group_size] - assert second.advantages == result.advantages[result.group_size :] - - merged = RolloutResult.merge_result_list(split_results) - - assert merged.num_sequence == result.num_sequence - assert merged.group_size == result.group_size - assert merged.prompt_ids == result.prompt_ids - assert merged.prompt_lengths == result.prompt_lengths - assert merged.response_ids == result.response_ids - assert merged.response_lengths == result.response_lengths - assert merged.is_end == result.is_end - assert merged.answers == result.answers - assert merged.rollout_logprobs == result.rollout_logprobs - assert merged.advantages == result.advantages - assert torch.equal(merged.rewards, result.rewards) - assert torch.equal(merged.prev_logprobs, result.prev_logprobs) - assert torch.equal(merged.ref_logprobs, result.ref_logprobs) From c1a74b0afe6391fd9b1197d1e2099914e2e2e8dd Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Mon, 13 Oct 2025 02:17:56 +0000 Subject: [PATCH 32/38] feat(ci): fix ci Signed-off-by: Bo Dai --- .github/workflows/code-test.yml | 32 +++++++++---------- .../hybrid_engines/fsdp/fsdp_model_manager.py | 4 +-- .../hybrid_engines/vllm/vllm_0_8_5/worker.py | 11 ++----- rlinf/runners/coding_online_rl_runner.py | 6 +--- .../{run_auto_placement.sh => run.sh} | 0 .../{run_coding_online_rl.sh => run.sh} | 0 .../embodied/libero_goal_grpo_openvlaoft.yaml | 6 ++++ .../embodied/maniskill_grpo_openvlaoft.yaml | 6 ++++ ...wen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml | 4 +-- ...en2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml | 4 +-- .../reasoning/{run_collocated.sh => run.sh} | 0 tests/e2e_tests/reasoning/run_pipeline.sh | 17 ---------- 12 files changed, 38 insertions(+), 52 deletions(-) rename tests/e2e_tests/auto_placement/{run_auto_placement.sh => run.sh} (100%) rename tests/e2e_tests/coding_online_rl/{run_coding_online_rl.sh => run.sh} (100%) rename tests/e2e_tests/reasoning/{run_collocated.sh => run.sh} (100%) delete mode 100644 tests/e2e_tests/reasoning/run_pipeline.sh diff --git a/.github/workflows/code-test.yml b/.github/workflows/code-test.yml index 830ee82e7..0ec8ca3ef 100644 --- a/.github/workflows/code-test.yml +++ b/.github/workflows/code-test.yml @@ -118,42 +118,42 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-sgl + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-sgl - name: Megatron vLLM Collocated mode timeout-minutes: 20 run: | 
export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-vllm + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-vllm - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm - name: FSDP SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm reason-qwen-grpo-test-rollout-logprobs: @@ -169,42 +169,42 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs - name: Megatron vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs - name: FSDP SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs coding-online-rl-qwen-ppo-test: needs: [check-changes] @@ -223,7 +223,7 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh + bash tests/e2e_tests/coding_online_rl/run.sh qwen-vl-grpo-test: needs: [check-changes] @@ -238,14 +238,14 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env 
reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl + bash tests/e2e_tests/reasoning/run.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/reasoning/run_collocated.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm + bash tests/e2e_tests/reasoning/run.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm # =============================================== embodied e2e tests ==================================================== @@ -321,7 +321,7 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/auto_placement/run_auto_placement.sh + bash tests/e2e_tests/auto_placement/run.sh # =============================================== finale ==================================================== diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py index aad2c130f..c16ddc3a8 100644 --- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py @@ -110,8 +110,8 @@ def setup_model_and_optimizer(self): mixed_precision = MixedPrecision( param_dtype=self.torch_dtype, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32, + reduce_dtype=self.torch_dtype, + buffer_dtype=self.torch_dtype, ) if self._cfg.model.sharding_strategy == "full_shard": diff --git a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py index 104334c1a..9a3cf4d85 100644 --- a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +++ b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py @@ -78,14 +78,9 @@ def offload_model_weights(self) -> None: torch.cuda.synchronize() model = self.model_runner.model - with torch.no_grad(): - for mod in model.modules(): - for name, buf in list(getattr(mod, "_buffers", {}).items()): - if isinstance(buf, torch.Tensor) and buf.is_cuda: - cpu_buf = ( - buf.detach().to("cpu", non_blocking=False).contiguous() - ) - mod._buffers[name] = cpu_buf + self._sleep_saved_buffers = { + name: buffer.cpu().clone() for name, buffer in model.named_buffers() + } torch.cuda.empty_cache() super().sleep(level=2) diff --git a/rlinf/runners/coding_online_rl_runner.py b/rlinf/runners/coding_online_rl_runner.py index 4ea2ec024..377667523 100644 --- a/rlinf/runners/coding_online_rl_runner.py +++ b/rlinf/runners/coding_online_rl_runner.py @@ -222,16 +222,12 @@ def run(self): # Advantages and returns adv_handle: Handle = self.actor.compute_advantages_and_returns( - input_channel=self.inference_channel, + input_channel=inference_channel, output_channel=self.actor_channel, ) # Actor training actor_input_channel = self.actor_channel - if self.is_pipeline: - # In pipeline mode, the rollout already contains the advantages and returns - # So the above two steps are in fact no-ops, and we should directly use the inference channel as the input - actor_input_channel = inference_channel actor_handle: Handle = self.actor.run_training( input_channel=actor_input_channel, ) diff --git a/tests/e2e_tests/auto_placement/run_auto_placement.sh b/tests/e2e_tests/auto_placement/run.sh similarity index 100% rename from tests/e2e_tests/auto_placement/run_auto_placement.sh rename to tests/e2e_tests/auto_placement/run.sh diff --git a/tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh b/tests/e2e_tests/coding_online_rl/run.sh similarity index 100% rename from tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh rename to 
tests/e2e_tests/coding_online_rl/run.sh diff --git a/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml index ddfe4a500..6dc7893d3 100644 --- a/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml +++ b/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml index ab384947a..f26ce7c3d 100644 --- a/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml +++ b/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml @@ -155,6 +155,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 10.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml index 46278b949..ddc087f99 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml @@ -20,10 +20,10 @@ runner: logger_backends: ["tensorboard"] # wandb, swanlab max_epochs: 1 - max_steps: 3 + max_steps: 2 val_check_interval: 1 - save_interval: 50 + save_interval: -1 seq_length: 2048 diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml index fb7c8aa02..86c64ebe0 100644 --- a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml @@ -20,10 +20,10 @@ runner: logger_backends: ["tensorboard"] # wandb, swanlab max_epochs: 1 - max_steps: 3 + max_steps: 2 val_check_interval: 1 - save_interval: 50 + save_interval: -1 seq_length: 2048 diff --git a/tests/e2e_tests/reasoning/run_collocated.sh b/tests/e2e_tests/reasoning/run.sh similarity index 100% rename from tests/e2e_tests/reasoning/run_collocated.sh rename to tests/e2e_tests/reasoning/run.sh diff --git a/tests/e2e_tests/reasoning/run_pipeline.sh b/tests/e2e_tests/reasoning/run_pipeline.sh deleted file mode 100644 index 3ca1574f0..000000000 --- a/tests/e2e_tests/reasoning/run_pipeline.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - echo "Please provide a config name as the first argument." 
- exit 1 -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/reasoning --config-name $CONFIG_NAME \ No newline at end of file From 2d5131394fc6b5f9ac73049e86c768f39bb3312b Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Mon, 13 Oct 2025 12:14:24 +0000 Subject: [PATCH 33/38] fix(reward): remove redundant reward definitions Signed-off-by: Bo Dai --- rlinf/algorithms/registry.py | 19 ------- rlinf/algorithms/rewards/__init__.py | 2 + rlinf/algorithms/rewards/code/__init__.py | 30 ++++++++++ rlinf/algorithms/rewards/vqa/__init__.py | 56 ++++++++----------- rlinf/workers/actor/fsdp_actor_worker.py | 10 ---- rlinf/workers/actor/megatron_actor_worker.py | 11 ++-- rlinf/workers/rollout/sglang/sglang_worker.py | 26 --------- rlinf/workers/rollout/vllm/vllm_worker.py | 2 - .../coding_online_rl/qwen2.5-1.5b-ppo.yaml | 2 +- toolkits/__init__.py | 19 ------- toolkits/code_verifier/verify.py | 2 - toolkits/math_verifier/verify.py | 27 --------- 12 files changed, 60 insertions(+), 146 deletions(-) create mode 100644 rlinf/algorithms/rewards/code/__init__.py diff --git a/rlinf/algorithms/registry.py b/rlinf/algorithms/registry.py index 96b19f2b4..11bfc6c6a 100644 --- a/rlinf/algorithms/registry.py +++ b/rlinf/algorithms/registry.py @@ -73,22 +73,3 @@ def calculate_adv_and_returns(**kwargs) -> Tuple[torch.Tensor, Optional[torch.Te adv_type = kwargs["adv_type"] fn = get_adv_and_returns(adv_type) return fn(**kwargs) - - -REWARD_REGISTRY: Dict[str, Callable] = {} - - -def register_reward_fn(name: str): - def decorator(fn): - REWARD_REGISTRY[name] = fn - return fn - - return decorator - - -def get_reward_fn(name: Optional[str]): - if name is None: - return None - if name not in REWARD_REGISTRY: - raise ValueError(f"Reward function {name} not registered") - return REWARD_REGISTRY[name] diff --git a/rlinf/algorithms/rewards/__init__.py b/rlinf/algorithms/rewards/__init__.py index 2ab6528ca..380cfa102 100644 --- a/rlinf/algorithms/rewards/__init__.py +++ b/rlinf/algorithms/rewards/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from rlinf.algorithms.rewards.code import CodeReward from rlinf.algorithms.rewards.math import MathReward from rlinf.algorithms.rewards.vqa import VQAReward @@ -30,3 +31,4 @@ def get_reward_class(name: str): register_reward("math", MathReward) register_reward("vqa", VQAReward) +register_reward("code", CodeReward) diff --git a/rlinf/algorithms/rewards/code/__init__.py b/rlinf/algorithms/rewards/code/__init__.py new file mode 100644 index 000000000..0fc75f971 --- /dev/null +++ b/rlinf/algorithms/rewards/code/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +from omegaconf import DictConfig + +from toolkits.code_verifier.verify import fim_verify_call + + +class CodeReward: + def __init__(self, config: DictConfig): + self.scale = config.get("reward_scale", 1.0) + + def get_reward( + self, response: List[str], reference: List[List[str]] + ) -> List[float]: + rewards = fim_verify_call(response, reference) + return [float(reward) * self.scale for reward in rewards] diff --git a/rlinf/algorithms/rewards/vqa/__init__.py b/rlinf/algorithms/rewards/vqa/__init__.py index 8175d72a1..77b009369 100644 --- a/rlinf/algorithms/rewards/vqa/__init__.py +++ b/rlinf/algorithms/rewards/vqa/__init__.py @@ -22,51 +22,41 @@ class VQAReward: - def __init__(self, config: DictConfig): - reward_weights_config = config.get( - "reward_weights", - { - "qa_accuracy": 1.0, - "think_format": 0.0, - "answer_format": 0.0, - }, - ) - for reward_name, reward_weight in reward_weights_config.items(): - assert reward_name in ["qa_accuracy", "think_format", "answer_format"], ( - f"Reward {reward_name} not supported" - ) - assert reward_weight >= 0, ( - f"Reward weight {reward_weight} must be non-negative" - ) - self.reward_weights = [ - reward_weights_config["qa_accuracy"], - reward_weights_config["think_format"], - reward_weights_config["answer_format"], - ] + NEEDED_REWARD_FUNCTIONS = { + "qa_accuracy": qa_accuracy_reward, + "think_format": think_format_reward, + "answer_format": answer_format_reward, + } - self.reward_functions = [ - qa_accuracy_reward, - think_format_reward, - answer_format_reward, - ] + def __init__(self, config: DictConfig): + assert "reward_weights" in config, "VQAReward requires reward_weights in config" + self.reward_weights_config = config.reward_weights + assert set(self.reward_weights_config.keys()) == set( + self.NEEDED_REWARD_FUNCTIONS.keys() + ), ( + f"Reward weights must contains all of: {self.NEEDED_REWARD_FUNCTIONS.keys()} but got {list(self.reward_weights_config.keys())}" + ) + assert all( + reward_weight >= 0 for reward_weight in self.reward_weights_config.values() + ), ( + f"All reward weights must be non-negative but got {list(self.reward_weights_config.values())}" + ) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_reward(self, completions: List[str], answers: List[dict]) -> List[float]: rewards = [] - for i, reward_function in enumerate(self.reward_functions): - if self.reward_weights[i] > 0: + reward_weights = [] + for reward_name, reward_function in self.NEEDED_REWARD_FUNCTIONS.items(): + if self.reward_weights_config[reward_name] > 0: rewards.append(reward_function(completions, answers)) else: rewards.append([0.0] * len(completions)) + reward_weights.append(self.reward_weights_config[reward_name]) - # Apply weights to each reward function's output and sum - - # rewards [num_reward_functions, len(completions)] rewards_tensor = torch.tensor(rewards, device=self.device) - weights_tensor = torch.tensor(self.reward_weights, device=self.device) + weights_tensor = torch.tensor(reward_weights, device=self.device) - # [num_reward_functions, num_completions] * [num_reward_functions, 1] -> [num_completions] final_rewards = (rewards_tensor * weights_tensor.unsqueeze(1)).sum(dim=0) return final_rewards.tolist() diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 17b9f42ec..06607903f 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -111,16 +111,6 @@ def __init__(self, cfg: DictConfig, 
placement: ModelParallelComponentPlacement): f"algorithm.loss_agg_func={self.cfg.algorithm.loss_agg_func} is not supported!" ) - # Reward configurations - if not self.cfg.reward.use_reward_model: - assert self.cfg.reward.reward_type in ["math", "vqa"], ( - "only support math and vqa reward!" - ) - from rlinf.algorithms.rewards import get_reward_class - - reward_cls = get_reward_class(self.cfg.reward.reward_type) - self.reward = reward_cls(self.cfg.reward) - def init_worker(self) -> None: self.setup_model_and_optimizer() if self.cfg.algorithm.kl_beta > 0 and self.cfg.actor.get( diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index f7bb3631e..5bd19ddc5 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -31,7 +31,6 @@ from rlinf.algorithms.registry import ( actor_loss, calculate_adv_and_returns, - get_reward_fn, ) from rlinf.algorithms.utils import kl_penalty from rlinf.data.io_struct import ( @@ -79,7 +78,6 @@ seq_mean_token_sum, ) from rlinf.workers.rollout.utils import RankMapper -from toolkits import register_rewards class MegatronActor(MegatronModelManager, Worker): @@ -102,6 +100,10 @@ def __init__( self.cfg = cfg self.component_placement = placement + # check placement validity when actor backend is megatron + assert placement.rollout_tp_size <= placement.actor_tp_size, ( + f" rollout tensor parallel size {placement.rollout_tp_size} must be less than or equal to actor tensor parallel size {placement.actor_tp_size}." + ) # Data configurations self.response_len = ( role_cfg.model.encoder_seq_length - cfg.data.max_prompt_length @@ -154,11 +156,6 @@ def __init__( self.ref_policy_state_dict = None self.is_pipeline = self.component_placement.is_disaggregated - # Reward configurations - if not self.cfg.reward.use_reward_model: - register_rewards() - self.reward_fn = get_reward_fn(self.cfg.reward.reward_type) - # Rollout configurations self.rollout_group_name = self.cfg.rollout.group_name diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index abdf59365..4d5c51552 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -18,7 +18,6 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import torch from omegaconf import DictConfig from sglang.srt.server_args import ServerArgs from transformers import AutoTokenizer @@ -35,7 +34,6 @@ from rlinf.workers.rollout.utils import ( print_sglang_outputs, ) -from toolkits.math_verifier.verify import MathRewardModel, math_verify_call class SGLangWorker(Worker): @@ -233,7 +231,6 @@ def __init__(self, config: DictConfig, placement: ComponentPlacement): self._rollout_end_event = asyncio.Event() self._sync_weight_end_event = asyncio.Event() - self._reward_model = MathRewardModel(scale=self._cfg.reward.reward_scale) assert self._rollout_batch_size is None, ( "rollout_batch_size_per_gpu is not supported in AsyncSGLangWorker" ) @@ -262,29 +259,6 @@ async def init_worker(self): if self._cfg.rollout.validate_weight: await self._validate_weight_at_first() - async def _compute_reward_and_advantage( - self, engine_results: List[Dict], answer: str - ): - answers = [answer] * len(engine_results) - texts: List[str] = [] - for res in engine_results: - if hasattr(res, "text"): - texts.append(res["text"]) - else: - texts.append( - self._tokenizer.decode(res["output_ids"], skip_special_tokens=True) - ) - - results = 
math_verify_call(texts, answers) - rewards = [r * self._reward_model.scale for r in results] - rewards_tensor = torch.tensor(rewards, dtype=torch.float) - - mean = rewards_tensor.mean() - std = rewards_tensor.std() - advantages = (rewards_tensor - mean) / (std + 1e-6) - - return rewards, advantages.tolist() - async def _async_generate( self, raw_id: int, input_ids: List[int], sampling_params: dict ): diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index b3629b170..3d36b9a44 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -35,7 +35,6 @@ from rlinf.scheduler import Channel, Worker from rlinf.utils.placement import ComponentPlacement from rlinf.workers.rollout.utils import print_vllm_outputs -from toolkits.math_verifier.verify import MathRewardModel from . import VLLMExecutor @@ -68,7 +67,6 @@ def __init__(self, config: DictConfig, placement: ComponentPlacement): "The capital of France is", "The future of AI is", ] - self._reward_model = MathRewardModel(self._cfg.reward.reward_scale) self.request_counter = Counter() def _prepare_vllm_environment(self) -> None: diff --git a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml index a87c3a474..f34ee4cae 100644 --- a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml +++ b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml @@ -283,7 +283,7 @@ actor: reward: use_reward_model: False - reward_type: fim_verify_call + reward_type: code reward_scale: 5.0 critic: diff --git a/toolkits/__init__.py b/toolkits/__init__.py index 8b6f0114a..5b365ea1e 100644 --- a/toolkits/__init__.py +++ b/toolkits/__init__.py @@ -11,22 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - -from rlinf.algorithms.registry import get_reward_fn - - -def register_rewards(): - try: - from toolkits.code_verifier.verify import fim_verify_call - - assert get_reward_fn("fim_verify_call") == fim_verify_call - except ImportError: - pass - - try: - from toolkits.math_verifier.verify import math_verify_call - - assert get_reward_fn("math") == math_verify_call - except ImportError: - pass diff --git a/toolkits/code_verifier/verify.py b/toolkits/code_verifier/verify.py index 0ad756d02..5017b9e54 100644 --- a/toolkits/code_verifier/verify.py +++ b/toolkits/code_verifier/verify.py @@ -21,10 +21,8 @@ except ImportError: fuzz = None FUZZY_AVAILABLE = False -from rlinf.algorithms.registry import register_reward_fn -@register_reward_fn("fim_verify_call") def fim_verify_call( responses: List[str], references: List[str], diff --git a/toolkits/math_verifier/verify.py b/toolkits/math_verifier/verify.py index 988b86045..8d0cbdb11 100644 --- a/toolkits/math_verifier/verify.py +++ b/toolkits/math_verifier/verify.py @@ -29,7 +29,6 @@ from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr -from rlinf.algorithms.registry import register_reward_fn from toolkits.math_verifier.parser import extract_answer global_executor = ProcessPoolExecutor(max_workers=40) @@ -389,7 +388,6 @@ def verify_math_solution(answer: str, solution: str): return process_results(answer, solution)[0] -@register_reward_fn("math") def math_verify_call( responses: List[str], references: List[str], @@ -429,31 +427,6 @@ def math_verify_call( return labels -class MathRewardModel: - def __init__(self, scale: float): - self.scale = scale - - def get_reward( - self, response: List[str], reference: List[List[str]] - ) -> List[float]: - """ - Calculates reward scores for a list of responses compared to corresponding lists of reference answers. - For each response, the function checks if it matches any of the provided references using the `process_results` function. - The reward for each response is computed as the first element of the result (converted to float) multiplied by `self.scale`. - Args: - response (List[str]): A list of response strings to be evaluated. - reference (List[List[str]]): A list where each element is a list of reference strings corresponding to each response. - Returns: - List[float]: A list of reward scores, one for each response. 
- """ - - results = [] - for resp, refs in zip(response, reference): - result = any(process_results(resp, ref)[0] for ref in refs) - results.append((1 if result else -1) * self.scale) - return results - - if __name__ == "__main__": sample = { "answers": ["\\boxed{-\\frac{2}{3}}"], From fbc9be75f6eb807cc7772956d38e0f185a3c8661 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Mon, 13 Oct 2025 14:19:05 +0000 Subject: [PATCH 34/38] fix(lock): set fsdp's recompute_logprobs True for lock competition safety Signed-off-by: Bo Dai --- examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index 19281cbc9..6d76c780e 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -48,7 +48,7 @@ algorithm: # val rollout mbs val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} - recompute_logprobs: False + recompute_logprobs: True shuffle_rollout: False # GRPO loss params From 29693590fcc4cc706f22829e5d946a62c73209d0 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Tue, 14 Oct 2025 03:33:41 +0000 Subject: [PATCH 35/38] chore: remove useless code, add correct dp_group param for mg Signed-off-by: Bo Dai --- rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py | 8 -------- rlinf/workers/actor/megatron_actor_worker.py | 1 + 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py index 9a3cf4d85..519895e49 100644 --- a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +++ b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py @@ -75,14 +75,6 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: super().initialize_from_config(kv_cache_config) def offload_model_weights(self) -> None: - torch.cuda.synchronize() - - model = self.model_runner.model - self._sleep_saved_buffers = { - name: buffer.cpu().clone() for name, buffer in model.named_buffers() - } - torch.cuda.empty_cache() - super().sleep(level=2) def sync_hf_weight(self) -> None: diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index 5bd19ddc5..40289ddf0 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -967,6 +967,7 @@ def _compute_rollout_metrics(self, batch): self.cfg.data.max_prompt_length, self.response_len, self._world_size, + dp_group=parallel_state.get_data_parallel_group(), ) ) From 9dae32e8e9024bec7f27c29c429ddbe01024c150 Mon Sep 17 00:00:00 2001 From: Bo Dai Date: Tue, 14 Oct 2025 03:55:36 +0000 Subject: [PATCH 36/38] fix(reward): move reward worker's timer to where reward computation really happens Signed-off-by: Bo Dai --- rlinf/workers/reward/reward_worker.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py index fd2472611..88be65ddc 100644 --- a/rlinf/workers/reward/reward_worker.py +++ b/rlinf/workers/reward/reward_worker.py @@ -60,13 +60,11 @@ def compute_rewards(self, input_channel: Channel, output_channel: Channel): input_channel: The input channel to read from. output_channel: The output channel to send results to. 
""" - - with self.worker_timer(): - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - rollout_result: RolloutResult = input_channel.get() - recv_batch_size += rollout_result.num_sequence - + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + rollout_result: RolloutResult = input_channel.get() + recv_batch_size += rollout_result.num_sequence + with self.worker_timer(): if rollout_result.rewards is None: if self.cfg.reward.use_reward_model: with input_channel.device_lock: @@ -83,11 +81,11 @@ def compute_rewards(self, input_channel: Channel, output_channel: Channel): rollout_result ) - output_channel.put(rollout_result) + output_channel.put(rollout_result) - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) def _compute_rule_based_rewards(self, rollout_result: RolloutResult): # Decode only the generated tokens; response_ids are already the post-prompt tokens From 5935efb584355c4672d425181b143e04b3883dec Mon Sep 17 00:00:00 2001 From: Varian-cym <1842506975@qq.com> Date: Fri, 7 Nov 2025 08:23:41 +0000 Subject: [PATCH 37/38] ADD:support npu --- .../config/math/qwen2.5-1.5b-grpo-fsdp.yaml | 20 +- .../config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml | 20 +- .../vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml | 222 ++ examples/reasoning/run_main_grpo_math.sh | 4 +- examples/reasoning/run_main_grpo_vqa.sh | 0 fusion_result.json | 21 + kernel_meta/buildPidInfo.json | 2754 +++++++++++++++++ rlinf/algorithms/losses.py | 14 +- rlinf/data/io_struct.py | 28 +- .../hybrid_engines/fsdp/fsdp_model_manager.py | 5 +- .../sglang/sglang_0_5_2/__init__.py | 13 + .../sglang/sglang_0_5_2/io_struct.py | 59 + .../sglang/sglang_0_5_2/sgl_engine.py | 363 +++ .../sglang/sglang_0_5_2/sgl_scheduler.py | 476 +++ .../sglang/sglang_0_5_2/tokenizer_manager.py | 129 + rlinf/utils/distributed.py | 44 +- rlinf/utils/utils.py | 28 +- rlinf/workers/actor/fsdp_actor_worker.py | 66 +- rlinf/workers/actor/fsdp_actor_worker_bak.py | 903 ++++++ rlinf/workers/rollout/sglang/__init__.py | 6 + rlinf/workers/rollout/sglang/sglang_worker.py | 1 + rlinf/workers/rollout/vllm/vllm_worker.py | 3 +- test.py | 3 + 23 files changed, 5079 insertions(+), 103 deletions(-) create mode 100644 examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml mode change 100644 => 100755 examples/reasoning/run_main_grpo_vqa.sh create mode 100644 fusion_result.json create mode 100644 kernel_meta/buildPidInfo.json create mode 100644 rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py create mode 100644 rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py create mode 100644 rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py create mode 100644 rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py create mode 100644 rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py create mode 100644 rlinf/workers/actor/fsdp_actor_worker_bak.py create mode 100644 test.py diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml index 6d76c780e..3b31eecdc 100644 --- a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -83,7 +83,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: 
/path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ + model_dir: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/ model_arch: qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -92,10 +92,10 @@ rollout: padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. - rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] sglang: - attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + attention_backend: ascend # [flashinfer, triton] for more, see sglang's doc decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. use_torch_compile: False # enable torch_compile in SGLang for rollout. torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. @@ -108,7 +108,7 @@ rollout: return_logprobs: ${not:${algorithm.recompute_logprobs}} - tensor_parallel_size: 2 + tensor_parallel_size: 1 pipeline_parallel_size: 1 validate_weight: False # whether to send all weights at first for weight comparison. @@ -123,14 +123,14 @@ data: dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True - rollout_batch_size: 512 + rollout_batch_size: 16 val_rollout_batch_size: null num_workers: 2 shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] - val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] + train_data_paths: ["/home/dataset/boba/AReaL-boba-106k.jsonl"] + val_data_paths: ["/home/dataset/boba/AReaL-boba-106k.jsonl"] prompt_key: prompt image_keys: [image] answer_key: answer @@ -164,7 +164,7 @@ actor: seq_length: ${runner.seq_length} encoder_seq_length: ${runner.seq_length} - model_path: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ + model_path: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/ optim: optimizer: adam @@ -194,7 +194,7 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /path/to/model/DeepSeek-R1-Distill-Qwen-1.5B/ + tokenizer_model: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/ use_fast: False trust_remote_code: True padding_side: 'right' @@ -222,4 +222,4 @@ reward: padding_side: ${actor.tokenizer.padding_side} critic: - use_critic_model: false \ No newline at end of file + use_critic_model: false diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml index 8d32a6a33..17c192fa2 100644 --- a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml @@ -84,7 +84,7 @@ rollout: gpu_memory_utilization: 0.55 - model_dir: /path/to/model/Qwen2.5-VL-3B-Instruct + model_dir: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct model_arch: qwen2.5_vl #qwen2.5 enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp @@ -93,10 +93,10 @@ rollout: padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. 
- rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] sglang: - attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + attention_backend: ascend # [flashinfer, triton] for more, see sglang's doc decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. use_torch_compile: False # enable torch_compile in SGLang for rollout. torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. @@ -109,7 +109,7 @@ rollout: return_logprobs: ${not:${algorithm.recompute_logprobs}} - tensor_parallel_size: 2 + tensor_parallel_size: 1 pipeline_parallel_size: 1 validate_weight: False # whether to send all weights at first for weight comparison. @@ -137,8 +137,8 @@ data: shuffle: True validation_shuffle: True seed: 1234 - train_data_paths: ["/dataset/robo2vlm-1/data/train/"] - val_data_paths: ["/dataset/robo2vlm-1/data/val/"] + train_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + val_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] actor: group_name: "ActorGroup" @@ -165,7 +165,7 @@ actor: seq_length: ${runner.seq_length} encoder_seq_length: ${runner.seq_length} - model_path: /path/to/model/Qwen2.5-VL-3B-Instruct/ + model_path: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct model_arch: ${rollout.model_arch} @@ -197,7 +197,7 @@ actor: lr_decay_iters: 10 tokenizer: - tokenizer_model: /path/to/model/Qwen2.5-VL-3B-Instruct + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct use_fast: False trust_remote_code: True padding_side: 'right' @@ -219,10 +219,10 @@ reward: answer_format: 0.0 tokenizer: - tokenizer_model: /path/to/model/Qwen2.5-VL-3B-Instruct + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct use_fast: False trust_remote_code: True padding_side: 'right' critic: - use_critic_model: false \ No newline at end of file + use_critic_model: false diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml new file mode 100644 index 000000000..fa5b16080 --- /dev/null +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml @@ -0,0 +1,222 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: ../results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
+ + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. taoxu 1010 + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: ascend # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. 
+ +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + val_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +critic: + use_critic_model: false \ No newline at end of file diff --git a/examples/reasoning/run_main_grpo_math.sh b/examples/reasoning/run_main_grpo_math.sh index 18a48d780..1b123fd6f 100644 --- a/examples/reasoning/run_main_grpo_math.sh +++ b/examples/reasoning/run_main_grpo_math.sh @@ -14,7 +14,7 @@ export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH if [ -z "$1" ]; then CONFIG_NAME="qwen2.5-1.5b-grpo-megatron" else - CONFIG_NAME=$1 + CONFIG_NAME="qwen2.5-1.5b-grpo-fsdp.yaml" fi -python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/math/ --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/math/ --config-name $CONFIG_NAME diff --git a/examples/reasoning/run_main_grpo_vqa.sh b/examples/reasoning/run_main_grpo_vqa.sh old mode 100644 new mode 100755 diff --git a/fusion_result.json b/fusion_result.json new file mode 100644 index 000000000..20b56950d --- /dev/null +++ b/fusion_result.json @@ -0,0 +1,21 @@ +{ + "session_and_graph_id_0_0": { + "graph_fusion": { + "IndexByTensorStaticFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "RefreshInt64ToInt32FusionPass": { + "effect_times": "0", + "match_times": "1" + } + }, + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "0", + 
"match_times": "2", + "repository_hit_times": "0" + } + } + } +} \ No newline at end of file diff --git a/kernel_meta/buildPidInfo.json b/kernel_meta/buildPidInfo.json new file mode 100644 index 000000000..9900ce4f6 --- /dev/null +++ b/kernel_meta/buildPidInfo.json @@ -0,0 +1,2754 @@ +[ + [ + 39911, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12797813670961250527" + ], + [ + 39913, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2319069557219941997" + ], + [ + 39917, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11862212240406296461" + ], + [ + 39922, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_885784172161213002" + ], + [ + 39926, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7410510130683415152" + ], + [ + 39930, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9194155428752533289" + ], + [ + 39934, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10424833431148169855" + ], + [ + 39936, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9855745055097868770" + ], + [ + 39939, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18191654686185573965" + ], + [ + 39942, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1065590051412680278" + ], + [ + 39949, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16967210759787679084" + ], + [ + 39950, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8381815239432563187" + ], + [ + 39958, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2867384563160197328" + ], + [ + 39965, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4547732597630948415" + ], + [ + 39971, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15778275423241821855" + ], + [ + 39980, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5207596602532600872" + ], + [ + 40666, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15835674102581580671" + ], + [ + 82796, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8782270786933035597" + ], + [ + 100620, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7353802482069301333" + ], + [ + 112704, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7725153434515021512" + ], + [ + 122239, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7210453806497323497" + ], + [ + 146045, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3405930793784815214" + ], + [ + 146076, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12695548580890247073" + ], + [ + 146140, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1605098968774552756" + ], + [ + 146205, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5766664014834865721" + ], + [ + 146347, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9612635458804393825" + ], + [ + 146447, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7725679307351709877" + ], + [ + 146559, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14784440107212476086" + ], + [ + 146567, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17267848541246919924" + ], + [ + 146569, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17784643373181208859" + ], + [ + 146573, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10067831816227681826" + ], + [ + 146589, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1044498738075030000" + ], + [ + 158951, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15191603587064731993" + ], + [ + 159046, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6443795088650938062" + ], + [ + 159116, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7617045580398011601" + ], + [ + 159145, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17761038868925529398" + ], + [ + 159234, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8618471054285342305" + ], + [ + 159360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16579794217881527508" + ], + [ + 159440, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10254394513325194666" + ], + [ + 159460, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11087666240252757180" + ], + [ + 159513, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9428322862212375668" + ], + [ + 159626, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16449556766463462195" + ], + [ + 159768, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1071949320798082478" + ], + [ + 159861, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9565697075048113424" + ], + [ + 159886, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3062431087256661843" + ], + [ + 160022, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13353626036029525043" + ], + [ + 161093, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_508908986595913110" + ], + [ + 162181, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3563660611022476863" + ], + [ + 259575, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8201276825499370867" + ], + [ + 259577, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15529659103780327911" + ], + [ + 259581, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10174461859794700718" + ], + [ + 259585, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16507297788703908047" + ], + [ + 259587, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11966777041657863002" + ], + [ + 259592, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2726330588881717531" + ], + [ + 259596, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17030647279409832227" + ], + [ + 259598, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12186607257866309998" + ], + [ + 259601, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3768772217671096392" + ], + [ + 259607, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3911160194709892748" + ], + [ + 259612, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7977521137351637493" + ], + [ + 259619, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4128563410056480856" + ], + [ + 259626, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9581134290190490336" + ], + [ + 259628, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_702154103913980934" + ], + [ + 259636, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14027257268528873801" + ], + [ + 259643, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17044983378904657187" + ], + [ + 260342, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3738242728542939509" + ], + [ + 260402, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15806158438452141266" + ], + [ + 260467, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4567683529506793175" + ], + [ + 260537, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10271642930345239273" + ], + [ + 260629, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3945686445184445770" + ], + [ + 260690, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17057962697473254798" + ], + [ + 260755, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16871514587483264100" + ], + [ + 260763, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2675279169353877702" + ], + [ + 260795, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14119700360366649411" + ], + [ + 315952, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11641596342314442912" + ], + [ + 329434, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17017774772109951435" + ], + [ + 340293, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_584578301371656169" + ], + [ + 352185, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11288852585934677911" + ], + [ + 367542, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6263883255072188097" + ], + [ + 367633, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5888503093289751929" + ], + [ + 367666, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5420145819910730026" + ], + [ + 379262, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1783494569175948891" + ], + [ + 379266, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_537678043780734220" + ], + [ + 379285, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5877535174558663131" + ], + [ + 379286, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10249842372811971323" + ], + [ + 379383, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_656090948065222947" + ], + [ + 379674, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11244763063226368392" + ], + [ + 380304, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_231890364717191513" + ], + [ + 380313, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5692311019013524597" + ], + [ + 380315, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11068432615228024246" + ], + [ + 380440, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15012465798244511954" + ], + [ + 380532, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3776119111621864505" + ], + [ + 380533, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6221575824346386447" + ], + [ + 380834, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11465674300335516174" + ], + [ + 381115, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7186726943654900649" + ], + [ + 381318, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12507534548909153748" 
+ ], + [ + 381749, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15611949621196386618" + ], + [ + 446067, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15585473429645162327" + ], + [ + 446070, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2905662668684958995" + ], + [ + 446077, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9261524943542376262" + ], + [ + 446081, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11733848515447041665" + ], + [ + 446083, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13074589690020057687" + ], + [ + 446087, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6740951685981297380" + ], + [ + 446089, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1230291348499102449" + ], + [ + 446095, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8512469513541935614" + ], + [ + 446098, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6734225289254301183" + ], + [ + 446100, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3193003074983928221" + ], + [ + 446107, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17730893955086431941" + ], + [ + 446108, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14296202533891893120" + ], + [ + 446115, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17201341106939116953" + ], + [ + 446122, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17335486450140669799" + ], + [ + 446128, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_172524440323681068" + ], + [ + 446139, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8488055523413543743" + ], + [ + 446862, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11022287578591461948" + ], + [ + 475137, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5524899443716147254" + ], + [ + 484622, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9030794170733368903" + ], + [ + 494550, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10177095432241817550" + ], + [ + 508239, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10453435395226908113" + ], + [ + 519345, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15868233357086722582" + ], + [ + 552276, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15438796748740496551" + ], + [ + 552339, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6239891206884327744" + ], + [ + 552379, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9273241722792499069" + ], + [ + 552468, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16552715273014474462" + ], + [ + 552569, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3196337734268543926" + ], + [ + 552675, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5847337843613087850" + ], + [ + 552742, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6352307575227886689" + ], + [ + 552750, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1982943654522997836" + ], + [ + 552783, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3828732781290684800" + ], + [ + 552794, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16977504117571897465" + ], + [ + 565573, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5204528607269617447" + ], + [ + 565699, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4727225444131891245" + ], + [ + 565716, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3170838120379853841" + ], + [ + 565750, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2776707022079058597" + ], + [ + 566045, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12036042907429605982" + ], + [ + 566310, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12421069780630334687" + ], + [ + 566383, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3217743380181589526" + ], + [ + 566422, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1276333807841191048" + ], + [ + 566494, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7963328572114012833" + ], + [ + 566534, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6674089693496856100" + ], + [ + 566844, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9863761807279720849" + ], + [ + 567668, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16074272625414500687" + ], + [ + 567773, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4282295767343065807" + ], + [ + 567937, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2161356974128856154" + ], + [ + 567993, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8242040281073172353" + ], + [ + 568229, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15920297452776001566" + ], + [ + 637429, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_959368100671980870" + ], + [ + 637431, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14153015096824152693" + ], + [ + 637436, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_899818523044574745" + ], + [ + 637439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10247733762527698150" + ], + [ + 637444, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14145287080845176151" + ], + [ + 637447, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17222031087440245104" + ], + [ + 637453, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12136397179187398942" + ], + [ + 637456, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8731812876370325350" + ], + [ + 637459, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10602237885762750118" + ], + [ + 637462, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16594426567840719304" + ], + [ + 637468, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7366125432199894168" + ], + [ + 637476, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11110698096089951252" + ], + [ + 637484, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18188682871993650937" + ], + [ + 637489, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10878533416589658592" + ], + [ + 637495, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8079586613243454881" + ], + [ + 637502, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2215817028210768871" + ], + [ + 670212, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15344779562198194740" + ], + [ + 680392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18388393187931224557" + ], + [ + 690762, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3514287218513747531" + ], + [ + 743351, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6478191836071807723" + ], + [ + 743388, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12832048017481839151" + ], + [ + 743400, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10541782531111842730" + ], + [ + 743475, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10763222511672298619" + ], + [ + 743664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17197230416625906030" + ], + [ + 743787, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18047780774813550571" + ], + [ + 743849, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2446754395186006741" + ], + [ + 743875, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5679073243145878981" + ], + [ + 743879, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16603745088718856023" + ], + [ + 743887, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13969638480830774322" + ], + [ + 743895, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13369228484466567544" + ], + [ + 743998, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12762855693433035364" + ], + [ + 744081, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3347649343482407740" + ], + [ + 757025, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15796729346284784432" + ], + [ + 757028, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14967701666120261818" + ], + [ + 757039, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5372387541384308030" + ], + [ + 757071, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3245068616512596448" + ], + [ + 757197, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14277746927271838195" + ], + [ + 757272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1831510647579988021" + ], + [ + 758082, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1963453102454522286" + ], + [ + 803617, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15623709472216942560" + ], + [ + 803620, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5962434898661482018" + ], + [ + 803625, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14549239843071968661" + ], + [ + 803630, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8567544072600264650" + ], + [ + 803633, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18407024629481036527" + ], + [ + 803664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15960439059008180895" + ], + [ + 803666, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3263727124264509013" + ], + [ + 803667, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9451034478073827629" + ], + [ + 803669, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18226372968931963744" + ], + [ + 803670, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14202429746807222158" + ], + [ + 803674, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7069975997530525703" + ], + [ + 803675, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8834989374251547639" + ], + [ + 803679, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17156716697792514660" + ], + [ + 803681, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7928243516609743611" + ], + [ + 803683, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17472278663547573068" + ], + [ + 803688, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9297387177610951812" + ], + [ + 803858, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9931525164958826644" + ], + [ + 803962, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18009242967987784208" + ], + [ + 804016, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2874613183183755843" + ], + [ + 804136, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4244083726001215679" + ], + [ + 804179, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9469466012086032567" + ], + [ + 804247, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6732244663862958131" + ], + [ + 804321, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17195272942189944812" + ], + [ + 804392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10755163416630026042" + ], + [ + 804513, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12350748768321932278" + ], + [ + 804688, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6894846644228586388" + ], + [ + 873633, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15791320709907010305" + ], + [ + 888932, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17017344803712162999" + ], + [ + 898268, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12878681295278483117" + ], + [ + 904854, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13733947334674992446" + ], + [ + 910408, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3037428981486910277" + ], + [ + 910622, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12284873316976139363" + ], + [ + 922414, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4043581537781464345" + ], + [ + 922752, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14805333628576514738" + ], + [ + 922753, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5878253078724431956" + ], + [ + 922763, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15057904892991266440" + ], + [ + 922779, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18116003072026181513" + ], + [ + 922864, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3039955536516904891" + ], + [ + 922878, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8073047691309365525" + ], + [ + 923385, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3835924806113435602" + ], + [ + 923770, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9436136851994508472" + ], + [ + 923881, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12348575658552678709" + ], + [ + 924169, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3100439312273551508" + ], + [ + 924681, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3582699337697878406" + ], + [ + 925395, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17118759297342529823" + ], + [ + 925409, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18056433885446449459" + ], + [ + 925631, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13111124903733223810" + ], + [ + 925674, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2449886720924222030" + ], + [ + 974754, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_206876370358077584" + ], + [ + 974757, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11257961443249735269" + ], + [ + 998802, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4978121236847435213" + ], + [ + 1003838, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17874922247557918464" + ], + [ + 1008420, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14756252425461487877" + ], + [ + 1051386, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15672847212087630880" + ], + [ + 1056292, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11690764893981564089" + ], + [ + 1077929, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18243756878085111260" + ], + [ + 1078664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13101390950818077021" + ], + [ + 1092149, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4702484821586339293" + ], + [ + 1092150, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8700607087392051310" + ], + [ + 1092167, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11867004705859069539" + ], + [ + 1092168, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14425483843466519342" + ], + [ + 1092170, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3137345424683392113" + ], + [ + 1092171, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5580715293573031613" + ], + [ + 1092247, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9573985685283690942" + ], + [ + 1092248, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1117326061986011595" + ], + [ + 1092430, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12852273377442787843" + ], + [ + 1092431, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2948619980269787118" + ], + [ + 1092814, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16608503441526049509" + ], + [ + 1092815, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8557565976778605998" + ], + [ + 1092974, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1586571163572556133" + ], + [ + 1092975, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15935510270417225075" + ], + [ + 1093267, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4402275187066910043" + ], + [ + 1093268, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5557153902550791572" + ], + [ + 1196049, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8425342506929295889" + ], + [ + 1196053, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18378899707502877030" + ], + [ + 1196056, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3094735019997891187" + ], + [ + 1196060, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6889997414760764962" + ], + [ + 1196065, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16740248499631073491" + ], + [ + 1196068, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8147418516272875846" + ], + [ + 1196070, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7125085183662271974" + ], + [ + 1196076, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17768317808754919941" + ], + [ + 1196345, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12412202691532259202" + ], + [ + 1196434, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10867201177510314690" + ], + [ + 1196530, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17300580725555235990" + ], + [ + 1196591, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11922602998666607133" + ], + [ + 1196658, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17067792987140972942" + ], + [ + 1196718, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6156226292697541158" + ], + [ + 1196789, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11972615607024541236" + ], + [ + 1225343, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16644111413419115375" + ], + [ + 1248801, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18426022297842133396" + ], + [ + 1254600, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13015264027012895286" + ], + [ + 1302314, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2178400540252727344" + ], + [ + 1302417, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15389620440927117005" + ], + [ + 1302502, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13548285859569365661" + ], + [ + 1302542, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3104574529036971715" + ], + [ + 1302580, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9161340348181567069" + ], + [ + 1302678, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2235073576100882104" + ], + [ + 1313288, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10714691217339955812" + ], + [ + 1313289, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15436382972359969753" + ], + [ + 1313295, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10087935773393978978" + ], + [ + 1313297, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1956045676144969910" + ], + [ + 1313413, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10346769137148346031" + ], + [ + 1313414, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12500094804178281846" + ], + [ + 1313594, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2233046969904106531" + ], + [ + 1313602, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17274697924548824848" + ], + [ + 1313910, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1971305477396827550" + ], + [ + 1313911, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1123480945946248614" + ], + [ + 1313913, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13632501687006175122" + ], + [ + 1313914, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17948247176888184637" + ], + [ + 1314424, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4838020015312522463" + ], + [ + 1314425, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5662771786044476863" + ], + [ + 1314464, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11644308174162608419" + ], + [ + 1314465, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14904900394694390614" + ], + [ + 1364104, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12599208991475254956" + ], + [ + 1364109, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11714438223684649350" + ], + [ + 1364112, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12536573013377539960" + ], + [ + 1364118, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9176497740794462077" + ], + [ + 1364120, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12851108834171289286" + ], + [ + 1364124, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7874367383020722434" + ], + [ + 1364133, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11768575761097821347" + ], + [ + 1364136, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9136218849828640261" + ], + [ + 1364436, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15227420143784422207" + ], + [ + 1364530, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_678152374747580855" + ], + [ + 1364650, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1863822375592647081" + ], + [ + 1364684, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14158128526209592340" + ], + [ + 1364721, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4584616283426803771" + ], + [ + 1364806, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13235463644610255264" + ], + [ + 1364877, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7481851041269423666" + ], + [ + 1364909, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9126134503219493660" + ], + [ + 1365000, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12501320043627540492" + ], + [ + 1365092, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16659304774143478016" + ], + [ + 1398994, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7719667656758625427" + ], + [ + 1414384, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8539755814366516822" + ], + [ + 1425199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12779466743610801848" + ], + [ + 1434341, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5346047771258662063" + ], + [ + 1443931, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8500329968264923909" + ], + [ + 1471398, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2835195402119160546" + ], + [ + 1480308, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17921460277698942329" + ], + [ + 1480309, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6060307077504199420" + ], + [ + 1480544, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6633913915366293595" + ], + [ + 1480545, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5929325753921586179" + ], + [ + 1480762, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5087823519142655646" + ], + [ + 1480763, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17650771157019299908" + ], + [ + 1481439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17960349755202839121" + ], + [ + 1481440, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11344166716576316064" + ], + [ + 1481478, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18011110268539351239" + ], + [ + 1481479, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13212044995188295238" + ], + [ + 1481983, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7786193257091011544" + ], + [ + 1481984, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4535156793004568521" + ], + [ + 1482479, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7392738499711986039" + ], + [ + 1482480, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16803447451398881609" + ], + [ + 1482675, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5816531416034365646" + ], + [ + 1482676, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10586423650728537576" + ], + [ + 1531952, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15737151482683442731" + ], + [ + 1531958, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5428763241135066291" + ], + [ + 1531961, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10672990360669572384" + ], + [ + 1531965, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7671654195981105868" + ], + [ + 1531970, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14942561136263416232" + ], + [ + 1531973, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16581135063665980351" + ], + [ + 1531977, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4856394302363438782" + ], + [ + 1531981, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15877297897966392256" + ], + [ + 1532252, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6425863017087684784" + ], + [ + 1532297, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1423327815346902785" + ], + [ + 1532410, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16344047352049658364" + ], + [ + 1532505, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17305619576099580631" + ], + [ + 1532566, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8395635612082422722" + ], + [ + 1532630, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8519186747189401928" + ], + [ + 1532664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1253792069041758400" + ], + [ + 1532734, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16904214805911888839" + ], + [ + 1532855, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8506737276851019991" + ], + [ + 1532973, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_649535223346514500" + ], + [ + 1533011, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5898366698551436494" + ], + [ + 1533074, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5096822643445805939" + ], + [ + 1533138, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14031496100727112023" + ], + [ + 1533230, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6266968138694899706" + ], + [ + 1533290, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16937440140735016382" + ], + [ + 1533334, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10380564371013497541" + ], + [ + 1646912, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16348313361094870629" + ], + [ + 1646913, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16036313845476488225" + ], + [ + 1646927, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7913415713848588049" + ], + [ + 1646928, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4103568604206357407" + ], + [ + 1647271, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3917280544582020227" + ], + [ + 1647272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13152661351180665451" + ], + [ + 1647797, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6759349546217838702" + ], + [ + 1647798, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10771757146841459066" + ], + [ + 1649758, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6186015994533271987" + ], + [ + 1649759, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9297921486035376125" + ], + [ + 1650053, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17376928305113794204" + ], + [ + 1650054, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14705463419645039338" + ], + [ + 1650074, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17507795361488404022" + ], + [ + 1650075, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_791693539870127207" + ], + [ + 1650077, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11601476096660261739" + ], + [ + 1650078, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12330363696036321630" + ], + [ + 1698356, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3664737074668484302" + ], + [ + 1698360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_307407573315947332" + ], + [ + 1698364, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9769783361472155454" + ], + [ + 1698365, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11620844252758829593" + ], + [ + 1698369, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18049344638709473559" + ], + [ + 1698371, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12020994052316156595" + ], + [ + 1698376, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11281246552810662244" + ], + [ + 1698381, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2375830332811294762" + ], + [ + 1698622, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17464125010525395920" + ], + [ + 1698715, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8252497263861944219" + ], + [ + 1698754, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2367183499606118175" + ], + [ + 1698875, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3307535582648818046" + ], + [ + 1698960, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14070843042311209568" + ], + [ + 1699023, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3711321641783824900" + ], + [ + 1699057, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4067246543661527215" + ], + [ + 1699128, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9510008041698682979" + ], + [ + 1699222, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15070956721185346780" + ], + [ + 1699307, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16822435983994287417" + ], + [ + 1699350, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9200131402658211969" + ], + [ + 1699410, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4312347285030704895" + ], + [ + 1699501, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8452932345320744230" + ], + [ + 1699592, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18319826375533952136" + ], + [ + 1745273, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9427035578749238306" + ], + [ + 1752419, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11861522220206622819" + ], + [ + 1815709, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8461056289216019937" + ], + [ + 1815710, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15670507367599225435" + ], + [ + 1815718, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14014685830231205955" + ], + [ + 1815719, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8933292223775536158" + ], + [ + 1815722, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4670593314428329079" + ], + [ + 1815723, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14753117404955500077" + ], + [ + 1815761, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14867940080031341584" + ], + [ + 1815765, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9603778874236517966" + ], + [ + 1815801, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12882566826319414462" + ], + [ + 1815802, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10278462134454322731" + ], + [ + 1815807, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_634460488610617389" + ], + [ + 1815809, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3787426685243459175" + ], + [ + 1816092, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6900943374411146311" + ], + [ + 1816098, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2936618988510319244" + ], + [ + 1816749, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5025953421479037474" + ], + [ + 1816757, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10504877859244548446" + ], + [ + 3034511, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3109980201143968959" + ], + [ + 3034516, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16225395238917630830" + ], + [ + 3034519, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3014914144290525407" + ], + [ + 3034520, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4368939354329067568" + ], + [ + 3034525, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8231836696933284847" + ], + [ + 3034528, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11274985324515364055" + ], + [ + 3034533, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16776878341103524463" + ], + [ + 3034534, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11836972797422572451" + ], + [ + 3034538, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11222685037651458786" + ], + [ + 3034543, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2358340237215830214" + ], + [ + 3034547, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5530069306192163345" + ], + [ + 3034553, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15467614126156452015" + ], + [ + 3034559, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5799901149805008020" + ], + [ + 3034568, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13480646072272034112" + ], + [ + 3034574, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9925042805192677700" + ], + [ + 3034584, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5749985854107212444" + ], + [ + 3139002, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16606806383596787697" + ], + [ + 3139034, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13693873830274303283" + ], + [ + 3139037, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1174647644274299681" + ], + [ + 3139041, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8463715729614322689" + ], + [ + 3139046, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13622137732458691231" + ], + [ + 3139061, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13976382821074310028" + ], + [ + 3139103, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5674315837922161200" + ], + [ + 3139186, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9519308949519055442" + ], + [ + 3139256, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11870662408869016824" + ], + [ + 3139325, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9203351225653898140" + ], + [ + 3139379, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15928929015260880554" + ], + [ + 3139390, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4002276556586837015" + ], + [ + 3139395, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_869027509639957243" + ], + [ + 3139439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7590980568514383017" + ], + [ + 3139477, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5691785906114378471" + ], + [ + 3139578, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1636167670197209792" + ], + [ + 3153344, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14721186196409181902" + ], + [ + 3153350, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8038130317126733515" + ], + [ + 3153355, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9363551798712942410" + ], + [ + 3153360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4001983642343932187" + ], + [ + 3153908, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3967378375263321794" + ], + [ + 3154349, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1956150361493637850" + ], + [ + 3154374, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9515009690803494350" + ], + [ + 3154941, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6403793289284352763" + ], + [ + 3155135, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18377681884643089435" + ], + [ + 3155270, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2894252287023698380" + ], + [ + 3155272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1176757978366636180" + ], + [ + 3155356, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10243682700450778485" + ], + [ + 3156048, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13419435171562057387" + ], + [ + 3156349, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5310938153691480087" + ], + [ + 3156373, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2839757395093087158" + ], + [ + 3156759, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_189645908260721847" + ], + [ + 3357366, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14712843742675416789" + ], + [ + 3357370, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14441746631406223373" + ], + [ + 3357372, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16051981893101423452" + ], + [ + 3357379, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7313357428863899572" + ], + [ + 3357382, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15741756964888684277" + ], + [ + 3357385, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_518244284579097490" + ], + [ + 3357390, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9623044147940949233" + ], + [ + 3357396, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_188046812284482511" + ], + [ + 3357815, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2070546872302042284" + ], + [ + 3357878, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11580410533744131165" + ], + [ + 3357919, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9721818878768485015" + ], + [ + 3358001, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5097555846270763399" + ], + [ + 3358063, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12100682775419753259" + ], + [ + 3358121, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_831802341219462707" + ], + [ + 3358154, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3511804342372595424" + ], + [ + 3358199, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13392040954474567685" + ], + [ + 3358318, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1488600423820696507" + ], + [ + 3358415, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7033478551626861368" + ], + [ + 3358501, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10949369291795034167" + ], + [ + 3358546, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9575369317971095792" + ], + [ + 3358599, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4171338315313792908" + ], + [ + 3358636, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3479887625518456208" + ], + [ + 3358704, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10816808840412347489" + ], + [ + 3358735, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4565693084836619869" + ], + [ + 3473543, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6198223556594770384" + ], + [ + 3473551, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10258537246918590156" + ], + [ + 3473616, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8769553907161686792" + ], + [ + 3473617, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3773271510051838348" + ], + [ + 3473932, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13489168884906494727" + ], + [ + 3473933, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18253176944629380716" + ], + [ + 3473943, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10561446346587147503" + ], + [ + 3473944, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1057740105323334353" + ], + [ + 3474186, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16534136807012196550" + ], + [ + 3474187, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3053221032925412808" + ], + [ + 3474202, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11135988479496829721" + ], + [ + 3474205, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15763171237614044906" + ], + [ + 3475761, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18163430628544684227" + ], + [ + 3475762, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_906162624150815971" + ], + [ + 3475895, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16284538720262805411" + ], + [ + 3475896, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6471773242206149096" + ], + [ + 3524930, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7351702211122833722" + ], + [ + 3524938, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_849812185320057193" + ], + [ + 3524944, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_899185330692270860" + ], + [ + 3524949, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16251077242215705442" + ], + [ + 3524957, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17261980608195435381" + ], + [ + 3524959, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15728989693202867994" + ], + [ + 3524963, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_578314422480040874" + ], + [ + 3524972, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5858027835250679442" + ], + [ + 3625661, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8595121483727162241" + ], + [ + 3628698, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6644896138201857560" + ], + [ + 3628765, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14380676338487055612" + ], + [ + 3629117, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3884026636792605184" + ], + [ + 3629128, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8471987555244152278" + ], + [ + 3629159, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10706976348774196355" + ], + [ + 3629253, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6353439240478980341" + ], + [ + 3629283, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11201244953910665068" + ], + [ + 3629360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13965969812663104297" + ], + [ + 3629455, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15143351641328536949" + ], + [ + 3629517, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1731228908521478078" + ], + [ + 3629588, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1280418295423790350" + ], + [ + 3629613, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8333685548859735493" + ], + [ + 3629770, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6530807606746596767" + ], + [ + 3629792, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7249227128131175886" + ], + [ + 3629874, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17572995736186529549" + ], + [ + 3642042, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_504718820247128253" + ], + [ + 3642043, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5621471611596574380" + ], + [ + 3642063, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5859789574421468396" + ], + [ + 3642065, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1749104467585305323" + ], + [ + 3642198, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16903094956545060933" + ], + [ + 3642199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15476704792505050546" + ], + [ + 3642237, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16807932559841593208" + ], + [ + 3642238, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8310833007698505204" + ], + [ + 3642367, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4707148721279721064" + ], + [ + 3642371, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9013616000502015446" + ], + [ + 3642732, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12951651460115558391" + ], + [ + 3642733, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6856766171535732498" + ], + [ + 3643042, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14506781066108968506" + ], + [ + 3643043, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7374279504609390465" + ], + [ + 3643572, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5955043855143664956" + ], + [ + 3643573, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8196104095227417096" + ], + [ + 3691270, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2620465584826061471" + ], + [ + 3691273, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1757388310219074848" + ], + [ + 3691281, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11097774192912884566" + ], + [ + 3691283, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4742545615202517810" + ], + [ + 3691285, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16070580061589069085" + ], + [ + 3691287, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6948592204353944837" + ], + [ + 3691292, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4433630213579020385" + ], + [ + 3691296, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15584731480290088561" + ], + [ + 3691301, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_221849414828219018" + ], + [ + 3691305, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_447840502870009883" + ], + [ + 3691312, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15894757923477358902" + ], + [ + 3691316, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2064442071811218632" + ], + [ + 3691318, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16020379562811154133" + ], + [ + 3691328, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1281388017954056527" + ], + [ + 3691337, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2331797143551461098" + ], + [ + 3691344, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13357220641161590082" + ], + [ + 3692044, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18315790265804955058" + ], + [ + 3692064, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17507869003332814399" + ], + [ + 3692144, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4507988044373530697" + ], + [ + 3692262, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18426271347408772907" + ], + [ + 3722920, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5017140278663970478" + ], + [ + 3736238, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12999327330513004683" + ], + [ + 3756817, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11445498054270621315" + ], + [ + 3765828, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1383945043061831987" + ], + [ + 3774374, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12318664942362170822" + ], + [ + 3783232, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3418024057005528102" + ], + [ + 3797970, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8613331899495270727" + ], + [ + 3798096, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9878016966715019087" + ], + [ + 3798154, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6627377776921054141" + ], + [ + 3798272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10804569796438592151" + ], + [ + 3798362, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13557064605936120080" + ], + [ + 3798416, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5828012701014328691" + ], + [ + 3810814, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8002480603490928081" + ], + [ + 3811199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15215890719473899494" + ], + [ + 3811238, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5765224812454476842" + ], + [ + 3811276, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16426201200428659968" + ], + [ + 3811288, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17001284103737435638" + ], + [ + 3811588, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15465033144293705320" + ], + [ + 3811661, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11237554842213579307" + ], + [ + 3811984, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4208370878176288374" + ], + [ + 3812122, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_859880967845588792" + ], + [ + 3812520, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11075357886525732487" + ], + [ + 3812639, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8682143685678125481" + ], + [ + 3812662, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12325636616321695255" + ], + [ + 3812698, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14800074808472192333" + ], + [ + 3812997, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12158776810784339196" + ], + [ + 3813382, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7085330449950753708" + ], + [ + 3813392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9314724528373063174" + ], + [ + 3868673, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8427308000189841720" + ], + [ + 3868680, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_731075373502833868" + ], + [ + 3868682, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9807646971983976701" + ], + [ + 3868684, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_708313394526412267" + ], + [ + 3868688, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13447384123535294566" + ], + [ + 3868692, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14036027334014170087" + ], + [ + 3868696, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3187295055034313317" + ], + [ + 3868700, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3850755952385802565" + ], + [ + 3868705, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10106160760700975956" + ], + [ + 3868706, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17824178791258425552" + ], + [ + 3868711, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11909562160231335840" + ], + [ + 3868719, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13607563626468215022" + ], + [ + 3868727, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16787076582919642185" + ], + [ + 3868737, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8246440632798892168" + ], + [ + 3868739, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4159297518721824730" + ], + [ + 3868745, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18015449808950218451" + ], + [ + 3869439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14196112055158873872" + ], + [ + 3869482, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14077174710107593045" + ], + [ + 3869512, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8061984655211037805" + ], + [ + 3869576, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16985047152068049914" + ], + [ + 3869667, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12041070490733630674" + ], + [ + 3869787, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5526959636930734514" + ], + [ + 3869824, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14021406351436652831" + ], + [ + 3903350, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8500099511158252282" + ], + [ + 3908889, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12265936708546584396" + ], + [ + 3919175, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6922556638284219234" + ], + [ + 3939646, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17523183083010355471" + ], + [ + 3953020, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6290038540260580030" + ], + [ + 3965009, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14272343095608536296" + ], + [ + 3976392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5029439665772523493" + ], + [ + 3976512, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11697369406167774290" + ], + [ + 3976632, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3263338739247539272" + ], + [ + 3988513, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_787927194113102781" + ], + [ + 3988531, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18023508788842543544" + ], + [ + 3988558, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12183092996673180058" + ], + [ + 3988563, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4831310188807215850" + ], + [ + 3988807, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2198912152797944480" + ], + [ + 3988848, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4213118868561974759" + ], + [ + 3988850, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3232631311510150968" + ], + [ + 3988957, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11047352039168056415" + ], + [ + 3989106, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16542734713691156018" + ], + [ + 3989387, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15165735529962902451" + ], + [ + 3989535, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12857832019022611393" + ], + [ + 3989541, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8287593005309603838" + ], + [ + 3989647, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18367334379009715015" + ], + [ + 3990050, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4446833231210301139" + ], + [ + 3990104, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5145384017407119481" + ], + [ + 3990359, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6801533416498745154" + ], + [ + 4039903, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17584089607946077100" + ], + [ + 4039907, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_503447699485187456" + ], + [ + 4039912, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15069983805839197528" + ], + [ + 4039916, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14542568312346822221" + ], + [ + 4039918, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4892714578234319308" + ], + [ + 4039925, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5411057577383729197" + ], + [ + 4039928, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15464081172484410235" + ], + [ + 4039932, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6225688351493779577" + ], + [ + 4039935, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15625622540562040764" + ], + [ + 4039937, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10502628395411644717" + ], + [ + 4039943, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11797930136011262397" + ], + [ + 4039946, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13885865319396877227" + ], + [ + 4039957, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5671629446760297180" + ], + [ + 4039960, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1393594888891471458" + ], + [ + 4039967, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3969615252157795066" + ], + [ + 4039972, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17691785108059131348" + ], + [ + 4040670, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15630378231912934353" + ], + [ + 4040733, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1125260354809627775" + ], + [ + 4040797, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_250967488981294306" + ], + [ + 4040888, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6978632213513861862" + ], + [ + 4040963, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14539782102651016167" + ], + [ + 4073698, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4027982239315583183" + ], + [ + 4087187, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5952558817727402033" + ], + [ + 4108124, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8896124445966427863" + ], + [ + 4121492, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3777541655201128969" + ], + [ + 4134649, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14525555472067694467" + ], + [ + 4145652, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9698817224586140305" + ], + [ + 4147258, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11549998609739639611" + ], + [ + 4147317, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6537741669769489394" + ], + [ + 4147443, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7264494085295711900" + ], + [ + 4147482, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1691570203381345331" + ], + [ + 4147600, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_805114710948033208" + ], + [ + 4158981, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1404183938416935234" + ], + [ + 4159025, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_366169914351560748" + ], + [ + 4159276, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6543673387235939768" + ], + [ + 4159300, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_583449162339166668" + ], + [ + 4159319, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15439605152958871044" + ], + [ + 4159324, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11830655025939200238" + ], + [ + 4159646, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2897405567433811008" + ], + [ + 4159775, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11323777380456298142" + ], + [ + 4159784, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18401703462013938226" + ], + [ + 4159918, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3439594419195257361" + ], + [ + 4160584, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8461195438972411491" + ], + [ + 4160770, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14710116434070740759" + ], + [ + 4161660, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8602813889361540004" + ], + [ + 4161781, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6869104894578156210" + ], + [ + 4161828, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2207584602328927841" + ], + [ + 4162147, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13886441570120016825" + ] +] \ No newline at end of file diff --git a/rlinf/algorithms/losses.py b/rlinf/algorithms/losses.py index f1bf025cd..72e9004c6 100644 --- a/rlinf/algorithms/losses.py +++ b/rlinf/algorithms/losses.py @@ -201,11 +201,15 @@ def compute_math_ppo_actor_loss(**kwargs): advantages = kwargs["advantages"] loss_mask = kwargs.get("loss_mask", None) c_clip = kwargs.get("c_clip", None) - - assert logprobs.dtype == torch.float32 - assert old_logprobs.dtype == torch.float32 - assert advantages.dtype == torch.float32 - + if logprobs.dtype != torch.float32: + logprobs = logprobs.float() # 转换为 float32 + #assert logprobs.dtype == torch.float32 + #assert old_logprobs.dtype == torch.float32 + #assert advantages.dtype == torch.float32 + if old_logprobs.dtype != torch.float32: + old_logprobs = old_logprobs.float() + if advantages.dtype != torch.float32: + advantages = advantages.float() assert loss_mask is not None loss_mask_count = loss_mask.count_nonzero() or 1 diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 4948131a6..2fa0fea8a 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -27,7 +27,7 @@ get_iterator_k_split, split_list, ) - +import torch_npu def get_batch_size( batch: Dict[str, torch.Tensor], batch_tensor_key: str = "input_ids" @@ -809,12 +809,12 @@ def to_actor_batch( ) # [B, training_seq_length] batch = { - "input_ids": input_ids.cuda(), - "attention_mask": attention_mask.cuda(), - "is_end": is_end.cuda(), - "position_ids": position_ids.cuda(), - "prompt_lengths": prompt_lengths.cuda(), - "response_lengths": response_lengths.cuda(), + "input_ids": input_ids.npu(), + "attention_mask": attention_mask.npu(), + "is_end": is_end.npu(), + "position_ids": 
position_ids.npu(),
+            "prompt_lengths": prompt_lengths.npu(),
+            "response_lengths": response_lengths.npu(),
         }
 
         if (
@@ -825,7 +825,7 @@ def to_actor_batch(
 
         if self.advantages is not None:
             if isinstance(self.advantages, torch.Tensor):
-                batch["advantages"] = self.advantages.cuda()
+                batch["advantages"] = self.advantages.npu()
             else:
                 response_attention_mask = attention_mask[
                     :, -max_response_len:
@@ -833,17 +833,17 @@ def to_actor_batch(
                 advantages = torch.tensor(self.advantages, dtype=torch.float32).reshape(
                     -1, 1
                 )  # [B, 1]
-                advantages = response_attention_mask.float().cuda() * advantages.cuda()
-                batch["advantages"] = advantages.cuda()
+                advantages = response_attention_mask.float().npu() * advantages.npu()
+                batch["advantages"] = advantages.npu()
 
         if self.prev_logprobs is not None:
-            batch["prev_logprobs"] = self.prev_logprobs.cuda()
+            batch["prev_logprobs"] = self.prev_logprobs.npu()
 
         if self.ref_logprobs is not None:
-            batch["ref_logprobs"] = self.ref_logprobs.cuda()
+            batch["ref_logprobs"] = self.ref_logprobs.npu()
 
         if self.rewards is not None:
-            batch["rewards"] = self.rewards.cuda()
+            batch["rewards"] = self.rewards.npu()
 
         if self.rollout_logprobs is not None:
             logprobs = batch_pad_to_fixed_len(
@@ -854,7 +854,7 @@ def to_actor_batch(
                 max_batch_len=max_response_len,
                 pad_token=pad_token,
             )
-            batch["prev_logprobs"] = logprobs.cuda()
+            batch["prev_logprobs"] = logprobs.npu()
 
         return batch
 
diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py
index c16ddc3a8..690b3d4f5 100644
--- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py
+++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py
@@ -35,6 +35,7 @@
 from rlinf.utils.logging import get_logger
 from rlinf.utils.utils import clear_memory
 
+import torch_npu
 
 class FSDPModelManager:
     """
@@ -55,9 +56,9 @@ def model_provider_func(self) -> torch.nn.Module:
 
         use_triton = cfg.get("use_triton", True)
 
-        assert torch.cuda.is_available(), "CUDA is not available."
+        assert torch.npu.is_available(), "NPU is not available."
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        device = torch.device(f"cuda:{local_rank}")
+        device = torch.device(f"npu:{local_rank}")
 
         model_config = AutoConfig.from_pretrained(
             cfg.model.model_path,
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py
new file mode 100644
index 000000000..5b365ea1e
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py
new file mode 100644
index 000000000..960d40eb0
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py
@@ -0,0 +1,59 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class TaskMethodInput: + method_name: str + args: List[Any] = field(default_factory=list) + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class TaskMethodOutput: + method_name: str + result: Optional[Any] = None + + +@dataclass +class OffloadReqInput: + pass + + +@dataclass +class OffloadReqOutput: + pass + + +@dataclass +class SyncWeightInput: + pass + + +@dataclass +class SyncWeightOutput: + pass + + +@dataclass +class SyncHFWeightInput: + pass + + +@dataclass +class SyncHFWeightOutput: + pass diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py new file mode 100644 index 000000000..e8e05c88b --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py @@ -0,0 +1,363 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
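+
+# NOTE: This module overrides SGLang's `Engine` so that RLinf can pass its own
+# worker address, component placement, config, and data-parallel rank down to
+# the scheduler subprocesses, and so that the engine exposes weight
+# synchronization (`sync_weight`, `sync_hf_weight`) and weight offloading
+# (`offload_model_weights`) entry points for the actor/rollout loop.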
+ +import asyncio +import atexit +import logging +import multiprocessing as mp +import os +import random +import signal +import threading +import time +from typing import Dict, Optional, Tuple + +import uvloop +import zmq +from omegaconf import DictConfig +from sglang.srt.entrypoints.engine import Engine as _Engine +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) +from sglang.srt.managers.detokenizer_manager import run_detokenizer_process +from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter + +# from sglang.srt.managers.scheduler import run_scheduler_process +# from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.managers.template_manager import TemplateManager +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + assert_pkg_version, + configure_logger, + get_bool_env_var, + get_zmq_socket, + is_cuda, + kill_process_tree, + launch_dummy_health_check_server, + prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, + set_ulimit, +) + +from rlinf.scheduler import WorkerAddress +from rlinf.utils.placement import ComponentPlacement + +from .io_struct import OffloadReqInput, SyncHFWeightInput, SyncWeightInput +from .sgl_scheduler import run_scheduler_process +from .tokenizer_manager import TokenizerManager + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +_is_cuda = is_cuda() + + +class Engine(_Engine): + def __init__( + self, + parent_address: WorkerAddress, + placement: ComponentPlacement, + config: DictConfig, + dp_rank: int, + **kwargs, + ): + """ + The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`. + Please refer to `ServerArgs` for the documentation. 
+ """ + if "server_args" in kwargs: + # Directly load server_args + server_args = kwargs["server_args"] + else: + # Construct server_args from kwargs + if "log_level" not in kwargs: + # Do not print logs by default + kwargs["log_level"] = "error" + server_args = ServerArgs(**kwargs) + + # Shutdown the subprocesses automatically when the program exits + atexit.register(self.shutdown) + + # Allocate ports for inter-process communications + self.port_args = PortArgs.init_new(server_args) + + # Launch subprocesses + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( + parent_address=parent_address, + placement=placement, + config=config, + dp_rank=dp_rank, + server_args=server_args, + port_args=self.port_args, + ) + + self.server_args = server_args + self.tokenizer_manager = tokenizer_manager + self.scheduler_info = scheduler_info + self.template_manager = template_manager + + context = zmq.Context(2) + self.send_to_rpc = get_zmq_socket( + context, zmq.DEALER, self.port_args.rpc_ipc_name, True + ) + + def offload_model_weights(self): + """Offload model weights to meta.""" + obj = OffloadReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.offload_model_weights(obj, None) + ) + + def sync_hf_weight(self): + obj = SyncHFWeightInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.tokenizer_manager.sync_hf_weight(obj)) + + def sync_weight(self): + obj = SyncWeightInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.tokenizer_manager.sync_weight(obj)) + + +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem)) + if not server_args.enable_symm_mem: + os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + os.environ["CUDA_MODULE_LOADING"] = "AUTO" + # flashinfer uses this environment variable for various kernels from MoE to quant kernels + os.environ["TRTLLM_ENABLE_PDL"] = "1" + + # Can also be passed as argument + os.environ["SGLANG_RUN_ID"] = ( + f"sglang-run-{time.time()}-{random.randint(0, 100000000)}" + ) + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer_python", + "0.3.0", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"): + assert_pkg_version( + "sgl-kernel", + "0.3.8", + "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", + ) + + if True: # Keep this check for internal code compatibility + # Register the signal handler. + # The child processes will send SIGQUIT to this process when any error happens + # This process then clean up the whole process tree + # Note: This sigquit handler is used in the launch phase, and may be replaced by + # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched. + def launch_phase_sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child process. It usually means the child failed." 
+ ) + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses( + parent_address: WorkerAddress, + placement: ComponentPlacement, + config: DictConfig, + dp_rank: int, + server_args: ServerArgs, + port_args: Optional[PortArgs] = None, +) -> Tuple[TokenizerManager, TemplateManager, Dict]: + """ + Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + """ + + assert server_args.pp_size == 1, ( + "RLinf currently only supports and validates pp_size=1." + ) + + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports for inter-process communications + if port_args is None: + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # If using model from www.modelscope.cn, first download the model. + server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + scheduler_pipe_readers = [] + + nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) + tp_size_per_node = server_args.tp_size // nnodes_per_tp_group + tp_rank_range = range( + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), + ) + + pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1) + pp_rank_range = range( + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group), + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1), + ) + + for pp_rank in pp_rank_range: + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = ( + server_args.base_gpu_id + + ((pp_rank % pp_size_per_node) * tp_size_per_node) + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step + ) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + proc = mp.Process( + target=run_scheduler_process, + args=( + parent_address, + placement, + config, + server_args.tp_size * server_args.pp_size, + tp_rank + pp_rank * server_args.pp_size, + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + None, + writer, + None, + ), + ) + + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. 
+ return None, None, None + + launch_dummy_health_check_server( + server_args.host, server_args.port, server_args.enable_metrics + ) + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return None, None, None + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + if server_args.tokenizer_worker_num > 1: + # Launch multi-tokenizer router + tokenizer_manager = MultiTokenizerRouter(server_args, port_args) + + # Initialize templates + template_manager = None + else: + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + + # Initialize templates + template_manager = TemplateManager() + template_manager.initialize_templates( + tokenizer_manager=tokenizer_manager, + model_path=server_args.model_path, + chat_template=server_args.chat_template, + completion_template=server_args.completion_template, + ) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, template_manager, scheduler_info diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py new file mode 100644 index 000000000..f6aac9d27 --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py @@ -0,0 +1,476 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
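+
+# NOTE: The `Scheduler` below combines SGLang's TP scheduler with RLinf's
+# `Worker` so that every tensor-parallel rank can receive resharded weights
+# directly from its corresponding actor rank and can offload / reload model
+# weights and the KV cache between rollout steps.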
+
+import faulthandler
+import inspect
+import logging
+import os
+import signal
+from typing import Optional
+
+import psutil
+import setproctitle
+import torch
+from omegaconf import DictConfig
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+)
+from sglang.srt.managers.io_struct import (
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
+)
+from sglang.srt.managers.scheduler import Scheduler as _Scheduler
+from sglang.srt.managers.scheduler import logger
+from sglang.srt.managers.utils import DPBalanceMeta
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    broadcast_pyobj,
+    configure_logger,
+    get_bool_env_var,
+    kill_itself_when_parent_died,
+    set_gpu_proc_affinity,
+    suppress_other_loggers,
+)
+from sglang.utils import get_exception_traceback
+
+from rlinf.scheduler import Worker, WorkerAddress
+from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode
+from rlinf.workers.rollout.utils import (
+    RankMapper,
+    get_module_from_name,
+    rebind_param_attr,
+    swap_tensor_pointer,
+)
+
+from .io_struct import (
+    OffloadReqInput,
+    OffloadReqOutput,
+    SyncHFWeightInput,
+    SyncHFWeightOutput,
+    SyncWeightInput,
+    SyncWeightOutput,
+    TaskMethodInput,
+    TaskMethodOutput,
+)
+
+import torch_npu
+
+logger.setLevel(logging.WARNING)
+
+
+def safe_load_weights(model, weights: list):
+    """
+    Safely load weights while accepting both naming conventions:
+    - 'visual.xxx' (Hugging Face Qwen-VL format)
+    - 'model.visual.xxx' (some SGLang or custom formats)
+
+    Parameters:
+        model: PyTorch model instance (must provide named_parameters())
+        weights: List of (name, torch.Tensor)
+    """
+    params_dict = dict(model.named_parameters())
+
+    # Build a mapping: normalized key -> actual parameter name.
+    # E.g. 'visual.patch_embed.proj.weight' may exist in params_dict as
+    # 'visual...' or as 'model.visual...'.
+    normalized_to_actual = {}
+    for param_name in params_dict.keys():
+        if param_name.startswith("model.visual."):
+            # Map to the normalized name without the "model." prefix.
+            normalized = param_name[len("model."):]  # "visual.xxx"
+            normalized_to_actual[normalized] = param_name
+            normalized_to_actual[param_name] = param_name  # keep the original name too
+        elif param_name.startswith("visual."):
+            normalized = param_name
+            normalized_to_actual[normalized] = param_name
+            normalized_to_actual["model." + normalized] = param_name  # also accept "model."-prefixed input
+        else:
+            # Non-visual parameters map to themselves.
+            normalized_to_actual[param_name] = param_name
+
+    # Load each weight.
+    for name, loaded_weight in weights:
+        if name in normalized_to_actual:
+            actual_name = normalized_to_actual[name]
+            param = params_dict[actual_name]
+            assert param.shape == loaded_weight.shape, (
+                f"Shape mismatch for {name}: expected {param.shape}, got {loaded_weight.shape}"
+            )
+            # Copy into .data so the load also works for parameters that require grad.
+            param.data.copy_(loaded_weight)
+        else:
+            # Skip weights that do not exist in the model (e.g. optimizer state or other non-model entries).
+            logger.warning(f"Skipping weight not in model: {name}")
+
+
+class Scheduler(_Scheduler, Worker):
+    """
+    Overridden class of SGLang's TP worker class _Scheduler.
+    A Scheduler is a Task that manages the TP worker, and performs necessary weight synchronization with actor and weight offloading.
+ """ + + def __init__( + self, + parent_address: WorkerAddress, + placement: ModelParallelComponentPlacement, + config: DictConfig, + world_size: int, + rank: int, + server_args: ServerArgs, + port_args: PortArgs, + gpu_id: int, + tp_rank: int, + moe_ep_rank: int, + pp_rank: int, + dp_rank: Optional[int], + dp_balance_meta: Optional[DPBalanceMeta] = None, + ): + Worker.__init__( + self, parent_address=parent_address, world_size=world_size, rank=rank + ) + + # since 0.4.6.post2, pp_rank is added into Scheduler init's parameters + # but we don't use it in our implementation, so we set it to 0 + _Scheduler.__init__( + self, + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + dp_rank, + dp_balance_meta, + ) + # `TpModelWorkerClient` is used when ServerArgs.enable_overlap=True, and it has 'worker' attribute. + # But in early SGLang version, `TpModelWorker` doesn't have 'worker' attribute. + if not hasattr(self.tp_worker, "worker"): + self.tp_worker.worker = self.tp_worker + + self._request_dispatcher._mapping.extend( + [ + (TaskMethodInput, self.run_task_method), + (OffloadReqInput, self.offload_model_weights), + (SyncWeightInput, self.sync_weight), + (SyncHFWeightInput, self.sync_hf_weight), + ] + ) + self.cfg = config + self.binded_attr = {} + + self._actor_group_name = self.cfg.actor.group_name + self.placement_mode = placement.placement_mode + self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map( + placement + )[(self.get_parent_rank(), self._rank)] + # it's important to use load_weight to load resharded weight from megatron + for _, module in self.tp_worker.worker.model_runner.model.named_modules(): + if hasattr(module, "use_presharded_weights"): + module.use_presharded_weights = False + + self._logger.info( + f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" + ) + + def sync_in_tp(self, fn: str = ""): + broadcast_pyobj( + [], self.tp_rank, self.tp_worker.worker.model_runner.tp_group.cpu_group + ) + # logger.info(f"{fn}: Sync in tp success!") + + def cuda_info(self, text: str = ""): + free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info() + free_gpu_memory /= 2**30 + total_gpu_memory /= 2**30 + + memory_allocated = torch.npu.memory_allocated() / 2**30 + memory_reserved = torch.npu.memory_reserved() / 2**30 + + self._logger.info( + f"[dp {self.get_parent_rank()}-tp {self.tp_rank}] {text} " + f"{memory_allocated=:.2f} GiB, {memory_reserved=:.2f} GiB, " + f"{free_gpu_memory=:.2f} GiB, {total_gpu_memory=:.2f} GiB" + ) + + def offload_model_weights(self, recv_req: OffloadReqInput): + use_cudagraph = not self.cfg.rollout.enforce_eager + colocate = self.placement_mode == PlacementMode.COLLOCATED + if not colocate: + assert use_cudagraph, "If not colocate, use_cudagraph must be True now." 
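+
+        # Two offload paths follow: if graph capture is enabled
+        # (`use_cudagraph`) or the placement is not colocated, rely on
+        # SGLang's release_memory_occupation(); otherwise snapshot buffers and
+        # parameter attributes, move the parameters to the meta device, and
+        # clear the KV-cache buffers manually.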
+
+        if use_cudagraph or not colocate:
+            self.release_memory_occupation(ReleaseMemoryOccupationReqInput())
+            # self.cuda_info("After offload Model weights and kv cache")
+            return OffloadReqOutput()
+
+        # manually offload
+        self.named_buffers = {
+            n: buf.clone()
+            for n, buf in self.tp_worker.worker.model_runner.model.named_buffers()
+        }
+
+        self.binded_attr = {
+            name: param.__dict__
+            for name, param in self.tp_worker.worker.model_runner.model.named_parameters()
+        }
+
+        # offload parameters
+        self.tp_worker.worker.model_runner.model.to("meta")
+
+        # offload kv cache
+        self.tp_worker.worker.model_runner.token_to_kv_pool._clear_buffers()
+
+        self.flush_cache()
+        self.sync_in_tp("offload_model_weights")
+        # self.cuda_info("After offload Model weights and kv cache")
+        return OffloadReqOutput()
+
+    def sync_hf_weight(self, recv_req: SyncHFWeightInput):
+        use_cudagraph = not self.cfg.rollout.enforce_eager
+        colocate = self.placement_mode == PlacementMode.COLLOCATED
+
+        assert use_cudagraph, "use_cudagraph must be True now."
+
+        state_dict = self.recv(
+            src_group_name=self._actor_group_name,
+            src_rank=self.actor_weight_rank,
+        )
+
+        model = self.tp_worker.worker.model_runner.model
+
+        if colocate:
+            self.resume_memory_occupation(ResumeMemoryOccupationReqInput())
+            for name, handle in state_dict.items():
+                # The actor sends rebuild handles rather than tensors; call the
+                # rebuild function with its arguments converted to kwargs so the
+                # device argument can be remapped by name.
+                func, args = handle
+                sig = inspect.signature(func)
+                param_names = list(sig.parameters.keys())
+
+                # Convert the positional args into kwargs.
+                kwargs = {}
+                args = list(args)
+                for i, param_name in enumerate(param_names):
+                    if i < len(args):
+                        kwargs[param_name] = args[i]
+                    else:
+                        break
+
+                # NOTE: the key is to change the device id to the current device
+                # id in case the two processes have different visible devices
+                # (assuming the rebuild function names this parameter
+                # 'map_location' or 'device').
+                if 'map_location' in kwargs:
+                    kwargs['map_location'] = f"npu:{torch.npu.current_device()}"
+                elif 'device' in kwargs:
+                    kwargs['device'] = torch.npu.current_device()
+
+                new_weight = func(**kwargs)
+                model.load_weights([(name, new_weight)])
+                # Alternatively, safe_load_weights(model, [(name, new_weight)])
+                # can be used to tolerate "model."-prefixed parameter names.
+                del new_weight
+        else:
+            # disaggregate mode, recv tensor directly
+            for name, tensor in state_dict.items():
+                model.load_weights([(name, tensor)])
+        self.flush_cache()
+        self.sync_in_tp("sync_hf_weight")
+        return SyncHFWeightOutput()
+
+    def sync_weight(self, recv_req: SyncWeightInput):
+        use_cudagraph = not self.cfg.rollout.enforce_eager
+        colocate = self.placement_mode == PlacementMode.COLLOCATED
+        if not colocate:
+            assert use_cudagraph, "If not colocate, use_cudagraph must be True now."
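+
+        # In colocated mode the actor sends rebuild handles instead of
+        # tensors: each entry of the received state_dict is a (func, args)
+        # pair that reconstructs a shared tensor on this process, e.g.
+        # (illustrative name only):
+        #
+        #     func, args = state_dict["model.layers.0.mlp.gate_up_proj.weight"]
+        #     args = list(args)
+        #     args[6] = torch.npu.current_device()  # remap the device index
+        #     weight = func(*args)
+        #
+        # Index 6 of the rebuild args holds the device index; it must be
+        # remapped because the actor and rollout processes may have different
+        # visible devices.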
+ + state_dict = self.recv( + src_group_name=self._actor_group_name, + src_rank=self.actor_weight_rank, + ) + model = self.tp_worker.worker.model_runner.model + + if use_cudagraph and colocate: + self.resume_memory_occupation(ResumeMemoryOccupationReqInput()) + + if colocate: + if use_cudagraph: + for name, handle in state_dict.items(): + func, args = handle + list_args = list(args) + # NOTE: the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = torch.npu.current_device() + new_weight = func(*list_args) + + self.tp_worker.worker.model_runner.update_weights_from_tensor( + [(name, new_weight)], load_format="direct" + ) + del new_weight + + else: + named_params = dict(model.named_parameters()) + for name, handle in state_dict.items(): + rebind_param_attr(model, name, self.binded_attr, materialize=False) + func, args = handle + list_args = list(args) + list_args[6] = torch.npu.current_device() + new_weight = func(*list_args) + vllm_weight = named_params[name] + assert vllm_weight.shape == new_weight.shape, ( + f"{name}: {vllm_weight.shape=}, {new_weight.shape=}" + ) + assert vllm_weight.dtype == new_weight.dtype, ( + f"{name}: {vllm_weight.dtype=}, {new_weight.dtype=}" + ) + + swap_tensor_pointer(vllm_weight, new_weight) + del new_weight + + for name, buffer in self.named_buffers.items(): + vllm_buffer = get_module_from_name(model, name) + assert vllm_buffer.shape == buffer.shape + assert vllm_buffer.dtype == buffer.dtype + swap_tensor_pointer(vllm_buffer, buffer) + + self.named_buffers = {} + + self.tp_worker.worker.model_runner.token_to_kv_pool._create_buffers() + else: + # disaggregate mode, recv tensor directly + named_tensors = [(n, p) for n, p in state_dict.items()] + self.tp_worker.worker.model_runner.update_weights_from_tensor( + named_tensors, load_format="direct" + ) + self.sync_in_tp("sync_weight") + + return SyncWeightOutput() + + def run_task_method(self, obj: TaskMethodInput): + """ + Run a CommTask method with the given name and arguments. + NOTE: will call wait() if async_op is True. 
+ """ + result = getattr(self, obj.method_name)(*obj.args, **obj.kwargs) + if "async_op" in obj.kwargs and obj.kwargs["async_op"]: + result = result.wait() + return TaskMethodOutput(method_name=obj.method_name, result=result) + + +def run_scheduler_process( + parent_address: WorkerAddress, + placement: ModelParallelComponentPlacement, + config: DictConfig, + world_size: int, + rank: int, + server_args: ServerArgs, + port_args: PortArgs, + gpu_id: int, + tp_rank: int, + moe_ep_rank: int, + pp_rank: int, + dp_rank: Optional[int], + pipe_writer, + balance_meta: Optional[DPBalanceMeta] = None, +): + # Generate the prefix + prefix = "" + if dp_rank is not None: + prefix += f" DP{dp_rank}" + if server_args.tp_size > 1: + prefix += f" TP{tp_rank}" + if server_args.ep_size > 1: + prefix += f" EP{moe_ep_rank}" + if server_args.pp_size > 1: + prefix += f" PP{pp_rank}" + + # Config the process + setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}") + faulthandler.enable() + kill_itself_when_parent_died() + parent_process = psutil.Process().parent() + + # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var + if dp_rank is None and "SGLANG_DP_RANK" in os.environ: + dp_rank = int(os.environ["SGLANG_DP_RANK"]) + + # Configure the logger + configure_logger(server_args, prefix=prefix) + suppress_other_loggers() + + # Set cpu affinity to this gpu process + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + + # Create a scheduler and run the event loop + try: + scheduler = Scheduler( + parent_address, + placement, + config, + world_size, + rank, + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + dp_rank, + dp_balance_meta=balance_meta, + ) + pipe_writer.send( + { + "status": "ready", + "max_total_num_tokens": scheduler.max_total_num_tokens, + "max_req_input_len": scheduler.max_req_input_len, + } + ) + + disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode + if disaggregation_mode == DisaggregationMode.NULL: + if server_args.pp_size > 1: + scheduler.event_loop_pp() + elif scheduler.enable_overlap: + scheduler.event_loop_overlap() + else: + scheduler.event_loop_normal() + elif disaggregation_mode == DisaggregationMode.PREFILL: + if scheduler.enable_overlap: + scheduler.event_loop_overlap_disagg_prefill() + else: + if server_args.pp_size > 1: + scheduler.event_loop_pp_disagg_prefill() + else: + scheduler.event_loop_normal_disagg_prefill() + + elif disaggregation_mode == DisaggregationMode.DECODE: + if scheduler.enable_overlap: + scheduler.event_loop_overlap_disagg_decode() + else: + scheduler.event_loop_normal_disagg_decode() + + except Exception: + traceback = get_exception_traceback() + logger.error(f"Scheduler hit an exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py new file mode 100644 index 000000000..07b77b200 --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py @@ -0,0 +1,129 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import fastapi +from sglang.srt.managers.io_struct import AbortReq +from sglang.srt.managers.tokenizer_manager import TokenizerManager as _TokenizerManager +#from sglang.srt.managers.tokenizer_manager import _Communicator +from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator +from sglang.srt.server_args import PortArgs, ServerArgs +from .io_struct import ( + OffloadReqInput, + OffloadReqOutput, + SyncHFWeightInput, + SyncHFWeightOutput, + SyncWeightInput, + SyncWeightOutput, + TaskMethodInput, + TaskMethodOutput, +) + + +# Add two methods and their communicators, input/output structs. +class TokenizerManager(_TokenizerManager): + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + super().__init__( + server_args=server_args, + port_args=port_args, + ) + + self.run_task_method_communicator = _Communicator( + self.send_to_scheduler, + fan_out=server_args.dp_size, + ) + self.offload_model_weights_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.sync_weight_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.sync_hf_weight_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + + self._result_dispatcher._mapping.extend( + [ + ( + TaskMethodOutput, + self.run_task_method_communicator.handle_recv, + ), + ( + OffloadReqOutput, + self.offload_model_weights_communicator.handle_recv, + ), + ( + SyncWeightOutput, + self.sync_weight_communicator.handle_recv, + ), + ( + SyncHFWeightOutput, + self.sync_hf_weight_communicator.handle_recv, + ), + ] + ) + + async def run_task_method( + self, + obj: TaskMethodInput = None, + request: Optional[fastapi.Request] = None, + ): + """ + Run a task method with the given name and arguments. 
+ """ + self.auto_create_handle_loop() + if isinstance(obj, str): + obj = TaskMethodInput(method_name=obj) + res: List[TaskMethodOutput] = await self.run_task_method_communicator(obj) + return res[0].result + + async def offload_model_weights( + self, + obj: OffloadReqInput = None, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + if obj is None: + obj = OffloadReqInput() + await self.offload_model_weights_communicator(obj) + + async def sync_hf_weight( + self, + obj: SyncHFWeightInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.sync_hf_weight_communicator(obj) + + async def sync_weight( + self, + obj: SyncWeightInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.sync_weight_communicator(obj) + + def abort_request(self, rid: str): + if rid != "" and rid not in self.rid_to_state: + return + req = AbortReq(rid) + self.send_to_scheduler.send_pyobj(req) + + async def pause_generation(self): + self.abort_request("") diff --git a/rlinf/utils/distributed.py b/rlinf/utils/distributed.py index a54f444d8..1f5ddc9c6 100644 --- a/rlinf/utils/distributed.py +++ b/rlinf/utils/distributed.py @@ -29,7 +29,7 @@ from rlinf.utils.timers import NamedTimer - +import torch_npu def compute_rollout_metrics( rollout_batch, max_prompt_len, @@ -38,7 +38,7 @@ def compute_rollout_metrics( dp_group=None, use_critic=False, ): - device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = torch.device(f"npu:{torch.npu.current_device()}") advantages = rollout_batch["advantages"].to(device=device) mask = rollout_batch["attention_mask"][:, -response_len:].to(device=device) prompt_lengths = rollout_batch["prompt_lengths"].clone().to(device=device) @@ -105,7 +105,7 @@ def compute_rollout_metrics( adv_max = torch.max(valid_adv).detach().item() adv_min = torch.min(valid_adv).detach().item() reduce_tensor = torch.as_tensor( - [-adv_min, adv_max], device=torch.cuda.current_device(), dtype=torch.float32 + [-adv_min, adv_max], device=torch.npu.current_device(), dtype=torch.float32 ) torch.distributed.all_reduce( reduce_tensor, @@ -175,7 +175,7 @@ def from_rollout_batches( dp_group: Optional[ProcessGroup], partitioning_tool: Callable, ) -> Self: - current_device = torch.cuda.current_device() + current_device = torch.npu.current_device() attn_mask = rollout_batches.get("attention_mask") current_num_samples = attn_mask.size(0) @@ -406,12 +406,12 @@ def rebalance_nd_tensor(tensor, group): NOTE: assumes all other (i.e., non-zero) dimensions are equal. 
""" num_samples = torch.as_tensor( - tensor.size(0), dtype=torch.int64, device=torch.cuda.current_device() + tensor.size(0), dtype=torch.int64, device=torch.npu.current_device() ) batch_num_per_rank = torch.zeros( torch.distributed.get_world_size(group), dtype=torch.int64, - device=torch.cuda.current_device(), + device=torch.npu.current_device(), ) torch.distributed.all_gather_into_tensor( batch_num_per_rank, num_samples, group=group @@ -422,7 +422,7 @@ def rebalance_nd_tensor(tensor, group): indices = batch_num_per_rank.cumsum(dim=0) output_tensor = torch.zeros( - B, *other_dims, dtype=tensor.dtype, device=torch.cuda.current_device() + B, *other_dims, dtype=tensor.dtype, device=torch.npu.current_device() ) # tensor_split is a view we can copy into @@ -454,7 +454,7 @@ def broadcast_tensor( """ if torch.distributed.get_rank() == src: - tensor = tensor.cuda() + tensor = tensor.npu() if dtype: tensor = tensor.to(dtype) @@ -467,7 +467,7 @@ def broadcast_tensor( torch.distributed.broadcast_object_list(metadata, src, group) dtype, input_shape = metadata - tensor = torch.empty(input_shape, dtype=dtype, device="cuda") + tensor = torch.empty(input_shape, dtype=dtype, device="npu") torch.distributed.broadcast(tensor, src, group) return tensor @@ -519,7 +519,7 @@ def broadcast_tensor_within_dp(tensor: torch.Tensor, dtype: torch.dtype): def gather_tensor(tensor, dst, group, dtype=None): """Gather any tensor to the dst rank from every other rank in the given group. All the ranks that send or receive data must call this function.""" - tensor = tensor.to(device=torch.cuda.current_device(), dtype=dtype) + tensor = tensor.to(device=torch.npu.current_device(), dtype=dtype) if torch.distributed.get_rank() == dst: gather_list = [ torch.empty_like(tensor) @@ -549,8 +549,8 @@ def normalize_tensor(tensor, mask, group=None): """normalizes a tensor using global mean and std""" dtype = torch.float64 tensor = tensor.to(dtype) - tensor = tensor.to(device=torch.cuda.current_device()) - mask = mask.to(device=torch.cuda.current_device()) + tensor = tensor.to(device=torch.npu.current_device()) + mask = mask.to(device=torch.npu.current_device()) tensor_global_mean, tensor_global_var = masked_global_mean_var( tensor, mask, group=group @@ -589,7 +589,7 @@ def masked_normalization( Normalized x, with the same shape as x. 
""" dtype = torch.float64 if high_precision else torch.float32 - x = x.to(dtype=dtype).cuda() + x = x.to(dtype=dtype).npu() if not inplace: x = x.clone() if dim is None: @@ -599,7 +599,7 @@ def masked_normalization( np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device ) else: - mask = mask.to(dtype=dtype).cuda() + mask = mask.to(dtype=dtype).npu() assert len(mask.shape) == len(x.shape), (mask.shape, x.shape, dim) for i in range(len(x.shape)): if i in dim: @@ -643,8 +643,8 @@ def masked_global_mean_var(values, mask, group=None): mask and values must have same shape, with mask being {0,1} with 1 being the values we want to keep """ assert values.shape == mask.shape, (values.shape, mask.shape) - values = values.to(device=torch.cuda.current_device()) - mask = mask.to(device=torch.cuda.current_device()) + values = values.to(device=torch.npu.current_device()) + mask = mask.to(device=torch.npu.current_device()) values = values * mask @@ -652,7 +652,7 @@ def masked_global_mean_var(values, mask, group=None): sum_and_count = torch.tensor( [values.sum(), mask.sum()], dtype=torch.float64, - device=torch.cuda.current_device(), + device=torch.npu.current_device(), ) torch.distributed.all_reduce(sum_and_count, group=group) global_sum, global_count = sum_and_count @@ -660,7 +660,7 @@ def masked_global_mean_var(values, mask, group=None): variance_summed = ( (((values - global_mean) ** 2) * mask) .sum() - .to(device=torch.cuda.current_device(), dtype=torch.float64) + .to(device=torch.npu.current_device(), dtype=torch.float64) ) torch.distributed.all_reduce(variance_summed, group=group) @@ -669,12 +669,12 @@ def masked_global_mean_var(values, mask, group=None): def report_device_info(info_str): - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info() free_gpu_memory /= 2**30 total_gpu_memory /= 2**30 - memory_allocated = torch.cuda.memory_allocated() / 2**30 - memory_reserved = torch.cuda.memory_reserved() / 2**30 + memory_allocated = torch.npu.memory_allocated() / 2**30 + memory_reserved = torch.npu.memory_reserved() / 2**30 print( f"[Rank {torch.distributed.get_rank()}] {info_str}, {free_gpu_memory=:.2f} GiB, {total_gpu_memory=:.2f} GiB, {memory_allocated=:.2f} GiB, {memory_reserved=:.2f} GiB" @@ -725,7 +725,7 @@ def all_reduce_dict( ): keys = sorted(dictionary) tensor = torch.as_tensor( - [dictionary[k] for k in keys], dtype=dtype, device=torch.cuda.current_device() + [dictionary[k] for k in keys], dtype=dtype, device=torch.npu.current_device() ) torch.distributed.all_reduce(tensor, op=op, group=group) return dict(zip(keys, tensor.tolist())) diff --git a/rlinf/utils/utils.py b/rlinf/utils/utils.py index d117f4dc4..b840538cc 100644 --- a/rlinf/utils/utils.py +++ b/rlinf/utils/utils.py @@ -21,13 +21,13 @@ import torch import torch.nn.functional as F - +import torch_npu def clear_memory(sync=True): if sync: - torch.cuda.synchronize() + torch.npu.synchronize() gc.collect() - torch.cuda.empty_cache() + torch.npu.empty_cache() def apply_func_to_dict(func, dictionary): @@ -54,7 +54,7 @@ def retrieve_model_state_dict_in_cpu(model): cpu_dict[name] = item - torch.cuda.synchronize() + torch.npu.synchronize() return cpu_dict @@ -126,13 +126,21 @@ def seq_mean_token_mean(values, mask): def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): - from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + #from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - output = cross_entropy_loss(logits, 
labels, inplace_backward=inplace_backward) - assert isinstance(output, tuple), ( - "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." - ) - return -output[0] + #output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) + #assert isinstance(output, tuple), ( + # "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + #) + #return -output[0] + import torch.nn.functional as F + + # numerically stable log_softmax + log_probs = F.log_softmax(logits, dim=-1) + + # gather the logprobs at the label positions + labels = labels.unsqueeze(-1) + return torch.gather(log_probs, -1, labels).squeeze(-1) def compute_logprobs_from_logits(logits, target, task_type="embodied"): diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 06607903f..1b3b015bb 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -62,7 +62,7 @@ seq_mean_token_sum, ) from rlinf.workers.rollout.utils import RankMapper - +import torch_npu class FSDPActor(FSDPModelManager, Worker): def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): @@ -87,11 +87,11 @@ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): // self._world_size ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - self.device = torch.cuda.current_device() + torch.npu.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.npu.current_device() world_size = self._world_size self.device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] ) self._rollout_group_name = cfg.rollout.group_name @@ -216,7 +216,7 @@ def inference_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: multi_modal_inputs[key] = torch.cat( [inputs[key] for inputs in batch["multi_modal_inputs"]], dim=0, - ).cuda() + ).npu() outputs = self.model( input_ids=input_ids, @@ -257,20 +257,32 @@ def run_inference( batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence self._load_weight_and_optimizer( - input_channel if self.is_pipeline else rollout_channel + input_channel if self.is_pipeline else rollout_channel + ) + num_splits = ( + rollout_result.num_sequence + // self.cfg.algorithm.logprob_forward_micro_batch_size ) + micro_batches_iter = get_iterator_k_split( + batch, + num_splits=num_splits, + ) + micro_batches = list(micro_batches_iter) + prev_logprobs = [] with self.worker_timer(): - prev_logprobs = self.inference_step(batch) - rollout_result.prev_logprobs = prev_logprobs.cpu() - + for micro_batch in micro_batches: + prev_logprobs.append(self.inference_step(micro_batch).cpu()) + rollout_result.prev_logprobs = torch.cat(prev_logprobs) if compute_ref_logprobs: assert self.ref_policy_state_dict is not None, ( "Reference policy state dict is None but compute_ref_logprobs is True" ) + ref_logprobs = [] with cpu_weight_swap(self.model, self.ref_policy_state_dict): - ref_logprobs = self.inference_step(batch) - rollout_result.ref_logprobs = ref_logprobs.cpu() + for micro_batch in micro_batches: + ref_logprobs.append(self.inference_step(micro_batch).cpu()) + rollout_result.ref_logprobs = torch.cat(ref_logprobs) self.put_result(rollout_result, output_channel) assert recv_batch_size == self.total_batch_size_per_dp, ( @@ -335,7 +347,7 @@ def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: else nullcontext() ) for k, v in 
m_batch.items(): - m_batch[k] = v.cuda() if isinstance(v, torch.Tensor) else v + m_batch[k] = v.npu() if isinstance(v, torch.Tensor) else v multi_modal_inputs = {} if "multi_modal_inputs" in m_batch.keys(): @@ -346,7 +358,7 @@ def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: for inputs in m_batch["multi_modal_inputs"] ], dim=0, - ).cuda() + ).npu() input_ids = m_batch["input_ids"] attention_mask = m_batch["attention_mask"] @@ -403,7 +415,7 @@ def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: loss_mask=loss_mask, ) - entropy_loss = torch.tensor(0.0, device=torch.cuda.current_device()) + entropy_loss = torch.tensor(0.0, device=torch.npu.current_device()) if self.calculate_entropy: entropy = output["entropy"][ :, -self.response_len - 1 : -1 @@ -414,7 +426,7 @@ def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: loss - self.cfg.algorithm.entropy_bonus * entropy_loss ) - kl_loss = torch.tensor(0.0, device=torch.cuda.current_device()) + kl_loss = torch.tensor(0.0, device=torch.npu.current_device()) if self.kl_beta > 0 and ref_logprobs is not None: kld = kl_penalty(ref_logprobs, logprobs, self.kl_penalty_type) kl_loss = self.loss_agg_func(kld, loss_mask) @@ -503,8 +515,8 @@ def compute_advantages_and_returns( mask = batch["attention_mask"][:, -self.response_len :] advantages, returns = calculate_adv_and_returns( adv_type=self.cfg.algorithm.adv_type, - reward_scores=batch["rewards"].cuda(), - mask=mask.cuda(), + reward_scores=batch["rewards"].npu(), + mask=mask.npu(), num_responses=self.cfg.algorithm.group_size, ) rollout_result.advantages = advantages.cpu() @@ -522,11 +534,11 @@ def __init__(self, cfg: DictConfig): super().__init__(cfg.actor) self.cfg = cfg - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - self.device = torch.cuda.current_device() + torch.npu.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.npu.current_device() world_size = self._world_size self.device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] ) self._env_group_name = cfg.env.group_name @@ -554,9 +566,9 @@ def init_worker(self): if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() - torch.cuda.synchronize() + torch.npu.synchronize() gc.collect() - torch.cuda.empty_cache() + torch.npu.empty_cache() def model_provider_func(self): model = get_model(self.cfg.actor.checkpoint_load_path, self.cfg.actor.model) @@ -577,10 +589,10 @@ def sync_model_to_rollout(self): if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() - torch.cuda.synchronize() + torch.npu.synchronize() del state_dict gc.collect() - torch.cuda.empty_cache() + torch.npu.empty_cache() async def recv_rollout_batch(self): send_num = self._component_placement.get_world_size("rollout") * self.stage_num @@ -864,7 +876,7 @@ def run_training(self): metrics_data["loss"] = loss.detach().item() append_to_dict(metrics, metrics_data) - torch.cuda.empty_cache() + torch.npu.empty_cache() grad_norm = self.model.clip_grad_norm_( max_norm=self.cfg.actor.optim.clip_grad @@ -886,9 +898,9 @@ def run_training(self): ) self.optimizer.zero_grad() - torch.cuda.synchronize() + torch.npu.synchronize() torch.distributed.barrier() - torch.cuda.empty_cache() + torch.npu.empty_cache() return mean_metric_dict diff --git a/rlinf/workers/actor/fsdp_actor_worker_bak.py 
b/rlinf/workers/actor/fsdp_actor_worker_bak.py new file mode 100644 index 000000000..84d8e09c5 --- /dev/null +++ b/rlinf/workers/actor/fsdp_actor_worker_bak.py @@ -0,0 +1,903 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from contextlib import nullcontext +from typing import Dict, Tuple + +import numpy as np +import torch +from omegaconf import DictConfig +from torch.distributed.device_mesh import init_device_mesh +from torch.multiprocessing.reductions import reduce_tensor +from tqdm import tqdm + +import rlinf.algorithms # noqa: F401 +from rlinf.algorithms.registry import actor_loss, calculate_adv_and_returns +from rlinf.algorithms.utils import ( + kl_penalty, + preprocess_advantages_inputs, + preprocess_loss_inputs, +) +from rlinf.data.io_struct import RolloutResult +from rlinf.hybrid_engines.fsdp.fsdp_model_manager import ( + FSDPModelManager, +) +from rlinf.models import get_model +from rlinf.models.embodiment.model_utils import custom_forward +from rlinf.scheduler import Channel, Cluster, Worker +from rlinf.utils.data_iter_utils import get_iterator_k_split +from rlinf.utils.distributed import all_reduce_dict +from rlinf.utils.distributed import ( + compute_rollout_metrics as compute_math_rollout_metrics, +) +from rlinf.utils.metric_utils import ( + append_to_dict, + compute_loss_mask, + compute_rollout_metrics, + compute_split_num, +) +from rlinf.utils.placement import ( + HybridComponentPlacement, + ModelParallelComponentPlacement, +) +from rlinf.utils.utils import ( + compute_logprobs_from_logits, + cpu_weight_swap, + masked_mean, + retrieve_model_state_dict_in_cpu, + seq_mean_token_mean, + seq_mean_token_sum, +) +from rlinf.workers.rollout.utils import RankMapper +import torch_npu + +class FSDPActor(FSDPModelManager, Worker): + def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): + Worker.__init__(self) + super().__init__(cfg.actor) + + self.cfg = cfg + + self.response_len = ( + cfg.actor.model.encoder_seq_length - cfg.data.max_prompt_length + ) + self.calculate_entropy = self.cfg.algorithm.calculate_entropy + self.calculate_entropy_loss = ( + self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy + ) + self.kl_beta = self.cfg.algorithm.kl_beta + self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type + + self.total_batch_size_per_dp = ( + self.cfg.data.rollout_batch_size + * self.cfg.algorithm.group_size + // self._world_size + ) + + torch.npu.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.npu.current_device() + world_size = self._world_size + self.device_mesh = init_device_mesh( + "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) + + self._rollout_group_name = cfg.rollout.group_name + self._component_placement = placement + self.is_data_io_rank = True + self.is_pipeline = self._component_placement.is_disaggregated + self.ref_policy_state_dict = None + + if self.cfg.algorithm.loss_agg_func == "token-mean": + self.loss_agg_func = masked_mean + elif 
self.cfg.algorithm.loss_agg_func == "seq-mean-token-sum": + self.loss_agg_func = seq_mean_token_sum + elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-mean": + self.loss_agg_func = seq_mean_token_mean + else: + raise NotImplementedError( + f"algorithm.loss_agg_func={self.cfg.algorithm.loss_agg_func} is not supported!" + ) + + def init_worker(self) -> None: + self.setup_model_and_optimizer() + if self.cfg.algorithm.kl_beta > 0 and self.cfg.actor.get( + "combine_reference_model", True + ): + self.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(self.model) + + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + self.offload_fsdp_optimizer() + self._setup_rollout_weight_dst_ranks() + + def _setup_rollout_weight_dst_ranks(self) -> None: + """Setup destination ranks for token and weight communication.""" + rank_map = RankMapper.get_actor_rank_to_rollout_rank_map( + self._component_placement + ) + self._weight_dst_rank_in_rollout = rank_map[self._rank] + self.log_info( + f"Actor rank {self._rank} will send weights to {self._weight_dst_rank_in_rollout}" + ) + + def del_reshard_state_dict(self) -> None: + if hasattr(self, "rollout_state_dict"): + del self.rollout_state_dict + + def sync_model_to_rollout(self) -> None: + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_optimizer() + + if next(self.model.parameters()).is_cpu: + self.load_fsdp_param_and_grad(self.device) + self.rollout_state_dict = self.get_model_state_dict() + + has_visual = any("visual." in k for k in self.rollout_state_dict.keys()) + + state_dict = {} + + if self._weight_dst_rank_in_rollout is not None: + for k, v in self.rollout_state_dict.items(): + name = k + if has_visual: + if name.startswith("model.language_model."): + name = "model." 
+ name[21:] + # NOTE: + # if transformers version is 4.56.1 or older(not tested), + # the following line should be uncommented + + # elif name.startswith("model."): + # name = name[6:] + state_dict[name] = reduce_tensor(v) + + self.send( + state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout + ) + + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + + def compute_logprobs(self) -> None: + self.model.eval() + self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"] + + def get_batch( + self, channel: Channel + ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: + result: RolloutResult = channel.get() + + batch = result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + return batch, result + + def put_result(self, result: RolloutResult, channel: Channel) -> None: + if channel.is_local: + # Local channel, every process will put its own data locally + # No need to broadcast + channel.put(result) + else: + if self.is_data_io_rank: + channel.put(result) + + def _load_weight_and_optimizer(self, channel: Channel) -> None: + # Acquire the GPUs to ensure that no one is using them before loading models + # Otherwise, it may lead to OOM + with channel.device_lock: + if self.cfg.actor.get("enable_offload", False): + self.load_fsdp_param_and_grad(self.device) + self.load_fsdp_optimizer(self.device) + + @torch.no_grad() + def inference_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + self.model.eval() + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] + + multi_modal_inputs = {} + if "multi_modal_inputs" in batch.keys(): + for key in batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in batch["multi_modal_inputs"]], + dim=0, + ).npu() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + **multi_modal_inputs, + ) + + logits = outputs.logits + logits = logits[:, -self.response_len - 1 : -1, :] + logits = logits / self.cfg.algorithm.sampling_params.temperature + + responses = input_ids[:, -self.response_len :] + logprobs = compute_logprobs_from_logits( + logits, responses, task_type=self.cfg.runner.task_type + ) + return logprobs + + def run_inference( + self, + input_channel: Channel, + output_channel: Channel, + rollout_channel: Channel, + compute_ref_logprobs: bool, + ) -> None: + """ + Compute prev/ref logprobs using the actor Model's forward. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + rollout_channel: get the rollout channel's device lock in case of collision. + compute_ref_logprobs: Whether to compute reference logprobs. 
+ """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + self._load_weight_and_optimizer( + input_channel if self.is_pipeline else rollout_channel + ) + + with self.worker_timer(): + prev_logprobs = self.inference_step(batch) + rollout_result.prev_logprobs = prev_logprobs.cpu() + + if compute_ref_logprobs: + assert self.ref_policy_state_dict is not None, ( + "Reference policy state dict is None but compute_ref_logprobs is True" + ) + with cpu_weight_swap(self.model, self.ref_policy_state_dict): + ref_logprobs = self.inference_step(batch) + rollout_result.ref_logprobs = ref_logprobs.cpu() + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: + # Get all batches for this DP + batches = [] + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + batches.append(batch) + recv_batch_size += rollout_result.num_sequence + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + batch = RolloutResult.merge_batches(batches) + # Must be called after batch is retrieved, which is when rollout has stopped + # Otherwise, loading model might cause OOM + self._load_weight_and_optimizer(input_channel) + + global_batches = get_iterator_k_split( + batch, + num_splits=self.cfg.algorithm.n_minibatches, + shuffle=self.cfg.algorithm.get("shuffle_rollout", True), + shuffle_seed=self.cfg.actor.seed, + ) + + self.model.train() + assert ( + self.cfg.actor.global_batch_size + % (self.cfg.actor.micro_batch_size * self._world_size) + == 0 + ) + + training_metrics_list = [] + # Global batch iterations + with self.worker_timer(): + for global_batch in global_batches: + train_global_batch_size = global_batch["input_ids"].shape[0] + + assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, ( + f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size=}" + ) + + self.gradient_accumulation = ( + train_global_batch_size // self.cfg.actor.micro_batch_size + ) + # split batch into micro_batches + train_micro_batches = get_iterator_k_split( + global_batch, + train_global_batch_size // self.cfg.actor.micro_batch_size, + ) + + self.optimizer.zero_grad() + metrics = {} + for idx, m_batch in enumerate(train_micro_batches): + backward_ctx = ( + self.model.no_sync() + if idx < self.gradient_accumulation - 1 + else nullcontext() + ) + for k, v in m_batch.items(): + m_batch[k] = v.npu() if isinstance(v, torch.Tensor) else v + + multi_modal_inputs = {} + if "multi_modal_inputs" in m_batch.keys(): + for key in m_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [ + inputs[key] + for inputs in m_batch["multi_modal_inputs"] + ], + dim=0, + ).npu() + + input_ids = m_batch["input_ids"] + attention_mask = m_batch["attention_mask"] + position_ids = m_batch["position_ids"] + prev_logprobs = m_batch["prev_logprobs"] + advantages = m_batch["advantages"] + ref_logprobs = None + if "ref_logprobs" in m_batch: + ref_logprobs = m_batch["ref_logprobs"] + + loss_mask = m_batch["attention_mask"][:, -self.response_len :] + output = self.model( + input_ids=input_ids, + 
attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + ) + + logits = output.logits + + logits.div_(self.cfg.algorithm.sampling_params.temperature) + + responses = input_ids[:, -self.response_len :] + logits = logits[ + :, -self.response_len - 1 : -1, : + ] # (bsz, response_length, vocab_size) + logprobs = compute_logprobs_from_logits( + logits, responses, task_type=self.cfg.runner.task_type + ) + + clip_ratio = self.cfg.algorithm.ratio_clip_eps + clip_ratio_low = ( + self.cfg.algorithm.clip_ratio_low + if self.cfg.algorithm.clip_ratio_low is not None + else clip_ratio + ) + clip_ratio_high = ( + self.cfg.algorithm.clip_ratio_high + if self.cfg.algorithm.clip_ratio_high is not None + else clip_ratio + ) + clip_ratio_c = self.cfg.algorithm.get("clip_ratio_c", 3.0) + + loss, mbs_metrics_data = actor_loss( + loss_type=self.cfg.algorithm.loss_type, + loss_agg_func=self.loss_agg_func, + logprobs=logprobs, + old_logprobs=prev_logprobs, + advantages=advantages, + clip_ratio_low=clip_ratio_low, + clip_ratio_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_mask=loss_mask, + ) + + entropy_loss = torch.tensor(0.0, device=torch.npu.current_device()) + if self.calculate_entropy: + entropy = output["entropy"][ + :, -self.response_len - 1 : -1 + ].contiguous() + entropy_loss = self.loss_agg_func(entropy, mask=loss_mask) + if self.calculate_entropy_loss: + loss = ( + loss - self.cfg.algorithm.entropy_bonus * entropy_loss + ) + + kl_loss = torch.tensor(0.0, device=torch.npu.current_device()) + if self.kl_beta > 0 and ref_logprobs is not None: + kld = kl_penalty(ref_logprobs, logprobs, self.kl_penalty_type) + kl_loss = self.loss_agg_func(kld, loss_mask) + loss = loss + kl_loss * self.kl_beta + + # add to log + # scale loss for gradient accumulation and backprop + loss = loss / self.gradient_accumulation + with backward_ctx: + loss.backward() + + mbs_metrics_data.update( + { + "final_loss": loss.detach(), + "entropy_loss": entropy_loss.detach(), + "kl_loss": kl_loss.detach(), + } + ) + + append_to_dict(metrics, mbs_metrics_data) + # apply gradient clipping and optimizer step at the end of a global batch + grad_norm = self.model.clip_grad_norm_( + max_norm=self.cfg.actor.optim.clip_grad + ) + if not torch.isfinite(grad_norm).all(): + self.log_warning( + "grad norm is not finite, skip this optimizer step." 
+ ) + else: + self.optimizer.step() + self.optimizer.zero_grad() + + # aggregate metrics across micro-batches + mean_metric_dict = { + key: torch.mean(torch.stack(value)) + for key, value in metrics.items() + } + mean_metric_dict = all_reduce_dict( + mean_metric_dict, op=torch.distributed.ReduceOp.AVG + ) + # add optimizer stats + if torch.is_tensor(grad_norm): + mean_metric_dict["actor/grad_norm"] = float( + grad_norm.detach().item() + ) + else: + mean_metric_dict["actor/grad_norm"] = float(grad_norm) + lr = self.optimizer.param_groups[0]["lr"] + mean_metric_dict["actor/lr"] = torch.as_tensor(lr).float().cpu() + training_metrics_list.append(mean_metric_dict) + + # Rollout metrics + rollout_metrics, _, _ = compute_math_rollout_metrics( + batch, self.cfg.data.max_prompt_length, self.response_len, self._world_size + ) + + return rollout_metrics, training_metrics_list + + def save_checkpoint(self, save_base_path: str, step: int) -> None: + torch.distributed.barrier() + model_state = self.get_model_state_dict() + optim_state = self.get_optimizer_state_dict() + if self._rank == 0: + os.makedirs(save_base_path, exist_ok=True) + torch.save(model_state, os.path.join(save_base_path, "model.pt")) + torch.save(optim_state, os.path.join(save_base_path, "optim.pt")) + torch.distributed.barrier() + + # Advantages and returns + def compute_advantages_and_returns( + self, input_channel: Channel, output_channel: Channel + ) -> None: + """Compute the advantages and returns. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + + with self.worker_timer(): + if rollout_result.advantages is None: + mask = batch["attention_mask"][:, -self.response_len :] + advantages, returns = calculate_adv_and_returns( + adv_type=self.cfg.algorithm.adv_type, + reward_scores=batch["rewards"].npu(), + mask=mask.npu(), + num_responses=self.cfg.algorithm.group_size, + ) + rollout_result.advantages = advantages.cpu() + + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + +class EmbodiedFSDPActor(FSDPModelManager, Worker): + def __init__(self, cfg: DictConfig): + Worker.__init__(self) + super().__init__(cfg.actor) + + self.cfg = cfg + torch.npu.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.npu.current_device() + world_size = self._world_size + self.device_mesh = init_device_mesh( + "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) + + self._env_group_name = cfg.env.group_name + self._rollout_group_name = cfg.rollout.group_name + self._component_placement = HybridComponentPlacement(cfg, Cluster()) + self._weight_dst_rank_in_rollout = self._rank + if self._weight_dst_rank_in_rollout >= self._component_placement.get_world_size( + "rollout" + ): + self._weight_dst_rank_in_rollout = None + + self._obs_queue_name = cfg.env.channel.queue_name + self._action_queue_name = cfg.rollout.channel.queue_name + self._replay_buffer_name = cfg.actor.channel.queue_name + # stage_num: default to 2, use for pipeline rollout process + self.stage_num = cfg.rollout.pipeline_stage_num + + self.channel = self.connect_channel(cfg.actor.channel.name) + self.channel.create_queue( + cfg.actor.channel.queue_name, 
maxsize=cfg.actor.channel.queue_size + ) + + def init_worker(self): + self.setup_model_and_optimizer() + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + self.offload_fsdp_optimizer() + torch.npu.synchronize() + gc.collect() + torch.npu.empty_cache() + + def model_provider_func(self): + model = get_model(self.cfg.actor.checkpoint_load_path, self.cfg.actor.model) + if model is not None: + return model + return super().model_provider_func() + + def sync_model_to_rollout(self): + if next(self.model.parameters()).is_cpu: + self.load_fsdp_param_and_grad(self.device) + self.load_fsdp_optimizer(self.device) + + state_dict = self.get_model_state_dict() + if self._weight_dst_rank_in_rollout is not None: + self.send( + state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout + ) + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + self.offload_fsdp_optimizer() + torch.npu.synchronize() + del state_dict + gc.collect() + torch.npu.empty_cache() + + async def recv_rollout_batch(self): + send_num = self._component_placement.get_world_size("rollout") * self.stage_num + recv_num = self._component_placement.get_world_size("actor") + split_num = compute_split_num(send_num, recv_num) + + self.rollout_batch = {} + recv_list = [] + for i in range(split_num): + recv_list.append( + await self.channel.get( + queue_name=self._replay_buffer_name, async_op=True + ).async_wait() + ) + + # shape [num_chunk, bsz, chunk_size], cat dim 1 + for key in recv_list[0].keys(): + if "env_info/" not in key: + self.rollout_batch[key] = torch.cat( + [recv_list[i][key] for i in range(split_num)], dim=1 + ) + else: + self.rollout_batch[key] = torch.cat( + [recv_list[i][key] for i in range(split_num)], dim=0 + ) + + self.rollout_batch = self._process_received_rollout_batch(self.rollout_batch) + + def _process_received_rollout_batch(self, rollout_batch): + """ + original shape: [rollout_epoch x n_chunk_steps, bsz, num_action_chunks, ...] + target shape: [n_chunk_steps, rollout_epoch x bsz, num_action_chunks, ...] + """ + rollout_epoch = self.cfg.algorithm.rollout_epoch + for key, value in rollout_batch.items(): + new_value = value.reshape( + rollout_epoch, -1, *value.shape[1:] + ) # [rollout_epoch, n_chunk_step, bsz, ...] + new_value = new_value.transpose( + 0, 1 + ) # [n_chunk_step, rollout_epoch, bsz, ...] 
+ new_value = new_value.reshape(new_value.shape[0], -1, *new_value.shape[3:]) + rollout_batch[key] = new_value + + if ( + not self.cfg.env.train.auto_reset + and not self.cfg.env.train.ignore_terminations + ): + dones = rollout_batch[ + "dones" + ] # [n_chunk_step, rollout_epoch x bsz, num_action_chunks] + loss_mask, loss_mask_sum = compute_loss_mask(dones) + + if self.cfg.algorithm.reward_type == "chunk_level": + loss_mask = loss_mask.any(dim=-1, keepdim=True) + loss_mask_sum = loss_mask_sum[..., -1:] + + rollout_batch["loss_mask"] = loss_mask + rollout_batch["loss_mask_sum"] = loss_mask_sum + + # filter data by rewards + if self.cfg.algorithm.get("filter_rewards", False): + rewards = rollout_batch[ + "rewards" + ] # [n_chunk_step, batch, num_action_chunks] + if self.rollout_batch.get("loss_mask", None) is not None: + rewards = rewards * rollout_batch["loss_mask"] + n_chunk_step, batch_size, num_action_chunks = rewards.shape + + group_size = self.cfg.algorithm.group_size + assert batch_size % group_size == 0, ( + f"batch {batch_size} not divisible by group_size {group_size}" + ) + n_prompts = batch_size // group_size + + # calculate rewards by prompt + rewards = rewards.transpose( + 0, 1 + ) # [batch, n_chunk_step, num_action_chunks] + rewards = rewards.reshape(rewards.shape[0], -1) # [batch, n_step] + reward_matrix = rewards.reshape( + n_prompts, group_size, rewards.shape[-1] + ) # [n_prompts, group_size, n_step] + reward_matrix = reward_matrix.sum(dim=-1) # [n_prompts, group_size] + mean_reward_in_group = reward_matrix.mean(dim=1) # [n_prompts] + + # mask + reward_filter_mask = ( + mean_reward_in_group >= self.cfg.algorithm.rewards_lower_bound + ) & ( + mean_reward_in_group <= self.cfg.algorithm.rewards_upper_bound + ) # [n_prompts] + + # extend mask dimension + reward_filter_mask = reward_filter_mask.repeat_interleave( + group_size + ) # [batch] + reward_filter_mask = ( + reward_filter_mask.unsqueeze(0).expand(n_chunk_step, -1).unsqueeze(-1) + ) # [n_chunk_step, batch, 1] + + # update loss_mask + if self.rollout_batch.get("loss_mask", None) is not None: + rollout_batch["loss_mask"] = ( + reward_filter_mask & self.rollout_batch["loss_mask"] + ) + else: + rollout_batch["loss_mask"] = reward_filter_mask + + return rollout_batch + + def compute_logprobs(self): + self.model.eval() + self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"] + + def compute_advantages_and_returns(self): + stage_num = self.cfg.rollout.pipeline_stage_num + env_world_size = self._component_placement.get_world_size("env") + actor_world_size = self._component_placement.get_world_size("actor") + num_group_envs_for_train = ( + self.cfg.algorithm.num_group_envs + * stage_num + * env_world_size + // actor_world_size + ) + + kwargs = { + "adv_type": self.cfg.algorithm.adv_type, + "rewards": self.rollout_batch["rewards"], + "dones": self.rollout_batch["dones"], + "normalize_advantages": self.cfg.algorithm.get( + "normalize_advantages", True + ), + "values": self.rollout_batch.get("prev_values", None), + "gamma": self.cfg.algorithm.get("gamma", 1), + "gae_lambda": self.cfg.algorithm.get("gae_lambda", 1), + "num_group_envs": num_group_envs_for_train, + "group_size": self.cfg.algorithm.get("group_size", 8), + "reward_type": self.cfg.algorithm.reward_type, + "loss_mask": self.rollout_batch.get("loss_mask", None), + "rollout_epoch": self.cfg.algorithm.get("rollout_epoch", 1), + } + kwargs = preprocess_advantages_inputs(**kwargs) + advantages, returns = calculate_adv_and_returns(**kwargs) + + 
self.rollout_batch.update({"advantages": advantages, "returns": returns}) + rollout_metrics = compute_rollout_metrics(self.rollout_batch) + return rollout_metrics + + def run_training(self): + if self.cfg.actor.get("enable_offload", False): + self.load_fsdp_param_and_grad(self.device) + self.load_fsdp_optimizer(self.device) + + self.model.train() + self.optimizer.zero_grad() + rollout_size = ( + self.rollout_batch["input_ids"].shape[0] + * self.rollout_batch["input_ids"].shape[1] + ) + shuffle_id = torch.randperm(rollout_size) + + for key, value in self.rollout_batch.items(): + self.log_on_first_rank(f"run training, {key}: {value.shape}") + + with torch.no_grad(): + for key, value in self.rollout_batch.items(): + if key in ["dones", "prev_values"]: + value = value[:-1] + if "env_info" in key: + continue + value = value.reshape(rollout_size, *value.shape[2:]) + self.rollout_batch[key] = value[shuffle_id] + + assert ( + self.cfg.actor.global_batch_size + % (self.cfg.actor.micro_batch_size * self._world_size) + == 0 + ) + + self.gradient_accumulation = ( + self.cfg.actor.global_batch_size + // self.cfg.actor.micro_batch_size + // self._world_size + ) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + rollout_size = self.rollout_batch["input_ids"].size(0) + batch_size_per_rank = self.cfg.actor.global_batch_size // self._world_size + assert rollout_size % batch_size_per_rank == 0, ( + f"{rollout_size} is not divisible by {batch_size_per_rank}" + ) + rollout_dataloader_iter = get_iterator_k_split( + self.rollout_batch, + rollout_size // batch_size_per_rank, + ) + + metrics = {} + for _, train_global_batch in tqdm( + enumerate(rollout_dataloader_iter), desc="get loss and metrics" + ): + # split batch into micro_batches + train_global_batch_size = train_global_batch["input_ids"].shape[0] + assert ( + train_global_batch_size + == self.cfg.actor.global_batch_size + // torch.distributed.get_world_size() + ) + assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, ( + f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size}" + ) + train_micro_batch = get_iterator_k_split( + train_global_batch, + train_global_batch_size // self.cfg.actor.micro_batch_size, + ) + + self.optimizer.zero_grad() + for data_idx, data in enumerate(train_micro_batch): + for k, v in data.items(): + data[k] = v.to(f"cuda:{int(os.environ['LOCAL_RANK'])}") + + data = self.model.preprocess_for_train(data) + input_ids = data["input_ids"] + action_tokens = data["action_tokens"] + attention_mask = data["attention_mask"] + pixel_values = data["pixel_values"] + + action_token_len = self.model.action_dim * self.model.num_action_chunks + + logits_processor_args = { + "action_tokens": action_tokens, + "vocab_size": self.model.vocab_size, + "n_action_bins": self.model.config.n_action_bins, + } + + output_dict = custom_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + action_token_len=action_token_len, + value_model=True + if self.cfg.algorithm.adv_type == "embodied_gae" + else False, + value_head_mode=self.cfg.actor.model.get("vh_mode", None), + temperature=self.cfg.algorithm.sampling_params.temperature_train, + top_k=self.cfg.algorithm.sampling_params.top_k, + logits_processor_args=logits_processor_args, + ) + + kwargs = { + "loss_type": self.cfg.algorithm.loss_type, + "logprob_type": self.cfg.algorithm.logprob_type, + "entropy_type": self.cfg.algorithm.entropy_type, + 
"single_action_dim": self.model.action_dim, + "logprobs": output_dict["logprobs"], + "entropy": output_dict["entropy"], + "values": output_dict.get("values", None), + "old_logprobs": data["prev_logprobs"], + "advantages": data["advantages"], + "returns": data["returns"], + "prev_values": data["prev_values"], + "clip_ratio_high": self.cfg.algorithm.clip_ratio_high, + "clip_ratio_low": self.cfg.algorithm.clip_ratio_low, + "value_clip": self.cfg.algorithm.get("value_clip", None), + "huber_delta": self.cfg.algorithm.get("huber_delta", None), + "entropy_bonus": self.cfg.algorithm.entropy_bonus, + "loss_mask": data.get("loss_mask", None), + "loss_mask_sum": data.get("loss_mask_sum", None), + "max_episode_steps": self.cfg.env.train.max_episode_steps, + } + + kwargs = preprocess_loss_inputs(**kwargs) + + loss, metrics_data = actor_loss(**kwargs) + + loss /= self.gradient_accumulation + loss.backward() + + metrics_data["loss"] = loss.detach().item() + append_to_dict(metrics, metrics_data) + + torch.npu.empty_cache() + + grad_norm = self.model.clip_grad_norm_( + max_norm=self.cfg.actor.optim.clip_grad + ) + self.optimizer.step() + + self.optimizer.zero_grad() + data = { + "actor/grad_norm": grad_norm.detach().item(), + "actor/lr": self.optimizer.param_groups[0]["lr"], + } + if self.cfg.algorithm.adv_type == "embodied_gae": + data["critic/lr"] = self.optimizer.param_groups[1]["lr"] + append_to_dict(metrics, data) + + mean_metric_dict = {key: np.mean(value) for key, value in metrics.items()} + mean_metric_dict = all_reduce_dict( + mean_metric_dict, op=torch.distributed.ReduceOp.AVG + ) + + self.optimizer.zero_grad() + torch.npu.synchronize() + torch.distributed.barrier() + torch.npu.empty_cache() + + return mean_metric_dict + + def save_checkpoint(self, save_base_path, step): + torch.distributed.barrier() + model_state = self.get_model_state_dict() + optim_state = self.get_optimizer_state_dict() + if self._rank == 0: + os.makedirs(save_base_path, exist_ok=True) + torch.save(model_state, os.path.join(save_base_path, "model.pt")) + torch.save(optim_state, os.path.join(save_base_path, "optim.pt")) + torch.distributed.barrier() diff --git a/rlinf/workers/rollout/sglang/__init__.py b/rlinf/workers/rollout/sglang/__init__.py index 5e0fb219f..cc78162b9 100644 --- a/rlinf/workers/rollout/sglang/__init__.py +++ b/rlinf/workers/rollout/sglang/__init__.py @@ -49,6 +49,12 @@ def get_version(pkg): from rlinf.hybrid_engines.sglang.sglang_0_4_9.sgl_engine import ( Engine, ) +elif package_version >= parse("0.5.0") and package_version < parse("0.5.3"): + sglang_version = package_version + from rlinf.hybrid_engines.sglang.sglang_0_5_2 import io_struct + from rlinf.hybrid_engines.sglang.sglang_0_5_2.sgl_engine import ( + Engine, + ) else: raise ValueError(f"sglang version {package_version} not supported") diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 4d5c51552..a24b297f2 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -113,6 +113,7 @@ def _init_engine(self): log_level="info", max_running_requests=self._cfg.rollout.max_running_requests, dist_init_addr=f"127.0.0.1:{str(Cluster.find_free_port())}", + device="npu", ) self.log_on_first_rank(f"{server_args=}") diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index 3d36b9a44..7324fe841 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -337,7 
+337,8 @@ async def init_worker(self) -> None: trust_remote_code=self._cfg.actor.tokenizer.trust_remote_code, max_model_len=self._cfg.runner.seq_length, max_num_seqs=self._cfg.rollout.max_running_requests, - enable_sleep_mode=True, # it enables offload weights + enable_sleep_mode=False, + device="npu", ) vllm_config: VllmConfig = engine_args.create_engine_config() diff --git a/test.py b/test.py new file mode 100644 index 000000000..485f51f99 --- /dev/null +++ b/test.py @@ -0,0 +1,3 @@ +from safetensors import safe_open +with safe_open("/home/weight/Qwen2.5-VL-3B-Instruct/model-00001-of-00002.safetensors", framework="pt") as f: + print(f.keys()) From 73fb07e444ae90462e478c9723905ae79c9e5e8d Mon Sep 17 00:00:00 2001 From: Varian-cym <1842506975@qq.com> Date: Fri, 7 Nov 2025 08:54:23 +0000 Subject: [PATCH 38/38] . --- rlinf/workers/actor/fsdp_actor_worker_bak.py | 903 ------------------- 1 file changed, 903 deletions(-) delete mode 100644 rlinf/workers/actor/fsdp_actor_worker_bak.py diff --git a/rlinf/workers/actor/fsdp_actor_worker_bak.py b/rlinf/workers/actor/fsdp_actor_worker_bak.py deleted file mode 100644 index 84d8e09c5..000000000 --- a/rlinf/workers/actor/fsdp_actor_worker_bak.py +++ /dev/null @@ -1,903 +0,0 @@ -# Copyright 2025 The RLinf Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import gc -import os -from contextlib import nullcontext -from typing import Dict, Tuple - -import numpy as np -import torch -from omegaconf import DictConfig -from torch.distributed.device_mesh import init_device_mesh -from torch.multiprocessing.reductions import reduce_tensor -from tqdm import tqdm - -import rlinf.algorithms # noqa: F401 -from rlinf.algorithms.registry import actor_loss, calculate_adv_and_returns -from rlinf.algorithms.utils import ( - kl_penalty, - preprocess_advantages_inputs, - preprocess_loss_inputs, -) -from rlinf.data.io_struct import RolloutResult -from rlinf.hybrid_engines.fsdp.fsdp_model_manager import ( - FSDPModelManager, -) -from rlinf.models import get_model -from rlinf.models.embodiment.model_utils import custom_forward -from rlinf.scheduler import Channel, Cluster, Worker -from rlinf.utils.data_iter_utils import get_iterator_k_split -from rlinf.utils.distributed import all_reduce_dict -from rlinf.utils.distributed import ( - compute_rollout_metrics as compute_math_rollout_metrics, -) -from rlinf.utils.metric_utils import ( - append_to_dict, - compute_loss_mask, - compute_rollout_metrics, - compute_split_num, -) -from rlinf.utils.placement import ( - HybridComponentPlacement, - ModelParallelComponentPlacement, -) -from rlinf.utils.utils import ( - compute_logprobs_from_logits, - cpu_weight_swap, - masked_mean, - retrieve_model_state_dict_in_cpu, - seq_mean_token_mean, - seq_mean_token_sum, -) -from rlinf.workers.rollout.utils import RankMapper -import torch_npu - -class FSDPActor(FSDPModelManager, Worker): - def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): - Worker.__init__(self) - super().__init__(cfg.actor) - - self.cfg = cfg - - self.response_len = ( - cfg.actor.model.encoder_seq_length - cfg.data.max_prompt_length - ) - self.calculate_entropy = self.cfg.algorithm.calculate_entropy - self.calculate_entropy_loss = ( - self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy - ) - self.kl_beta = self.cfg.algorithm.kl_beta - self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type - - self.total_batch_size_per_dp = ( - self.cfg.data.rollout_batch_size - * self.cfg.algorithm.group_size - // self._world_size - ) - - torch.npu.set_device(int(os.environ["LOCAL_RANK"])) - self.device = torch.npu.current_device() - world_size = self._world_size - self.device_mesh = init_device_mesh( - "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] - ) - - self._rollout_group_name = cfg.rollout.group_name - self._component_placement = placement - self.is_data_io_rank = True - self.is_pipeline = self._component_placement.is_disaggregated - self.ref_policy_state_dict = None - - if self.cfg.algorithm.loss_agg_func == "token-mean": - self.loss_agg_func = masked_mean - elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-sum": - self.loss_agg_func = seq_mean_token_sum - elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-mean": - self.loss_agg_func = seq_mean_token_mean - else: - raise NotImplementedError( - f"algorithm.loss_agg_func={self.cfg.algorithm.loss_agg_func} is not supported!" 
- ) - - def init_worker(self) -> None: - self.setup_model_and_optimizer() - if self.cfg.algorithm.kl_beta > 0 and self.cfg.actor.get( - "combine_reference_model", True - ): - self.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(self.model) - - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_param_and_grad() - self.offload_fsdp_optimizer() - self._setup_rollout_weight_dst_ranks() - - def _setup_rollout_weight_dst_ranks(self) -> None: - """Setup destination ranks for token and weight communication.""" - rank_map = RankMapper.get_actor_rank_to_rollout_rank_map( - self._component_placement - ) - self._weight_dst_rank_in_rollout = rank_map[self._rank] - self.log_info( - f"Actor rank {self._rank} will send weights to {self._weight_dst_rank_in_rollout}" - ) - - def del_reshard_state_dict(self) -> None: - if hasattr(self, "rollout_state_dict"): - del self.rollout_state_dict - - def sync_model_to_rollout(self) -> None: - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_optimizer() - - if next(self.model.parameters()).is_cpu: - self.load_fsdp_param_and_grad(self.device) - self.rollout_state_dict = self.get_model_state_dict() - - has_visual = any("visual." in k for k in self.rollout_state_dict.keys()) - - state_dict = {} - - if self._weight_dst_rank_in_rollout is not None: - for k, v in self.rollout_state_dict.items(): - name = k - if has_visual: - if name.startswith("model.language_model."): - name = "model." + name[21:] - # NOTE: - # if transformers version is 4.56.1 or older(not tested), - # the following line should be uncommented - - # elif name.startswith("model."): - # name = name[6:] - state_dict[name] = reduce_tensor(v) - - self.send( - state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout - ) - - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_param_and_grad() - - def compute_logprobs(self) -> None: - self.model.eval() - self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"] - - def get_batch( - self, channel: Channel - ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: - result: RolloutResult = channel.get() - - batch = result.to_actor_batch( - self.cfg.data.max_prompt_length, - self.cfg.actor.model.encoder_seq_length, - self.tokenizer.eos_token_id, - ) - return batch, result - - def put_result(self, result: RolloutResult, channel: Channel) -> None: - if channel.is_local: - # Local channel, every process will put its own data locally - # No need to broadcast - channel.put(result) - else: - if self.is_data_io_rank: - channel.put(result) - - def _load_weight_and_optimizer(self, channel: Channel) -> None: - # Acquire the GPUs to ensure that no one is using them before loading models - # Otherwise, it may lead to OOM - with channel.device_lock: - if self.cfg.actor.get("enable_offload", False): - self.load_fsdp_param_and_grad(self.device) - self.load_fsdp_optimizer(self.device) - - @torch.no_grad() - def inference_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - self.model.eval() - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] - position_ids = batch["position_ids"] - - multi_modal_inputs = {} - if "multi_modal_inputs" in batch.keys(): - for key in batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [inputs[key] for inputs in batch["multi_modal_inputs"]], - dim=0, - ).npu() - - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False, - **multi_modal_inputs, - ) - 
- logits = outputs.logits - logits = logits[:, -self.response_len - 1 : -1, :] - logits = logits / self.cfg.algorithm.sampling_params.temperature - - responses = input_ids[:, -self.response_len :] - logprobs = compute_logprobs_from_logits( - logits, responses, task_type=self.cfg.runner.task_type - ) - return logprobs - - def run_inference( - self, - input_channel: Channel, - output_channel: Channel, - rollout_channel: Channel, - compute_ref_logprobs: bool, - ) -> None: - """ - Compute prev/ref logprobs using the actor Model's forward. - - Args: - input_channel: The input channel to read from. - output_channel: The output channel to send results to. - rollout_channel: get the rollout channel's device lock in case of collision. - compute_ref_logprobs: Whether to compute reference logprobs. - """ - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - recv_batch_size += rollout_result.num_sequence - self._load_weight_and_optimizer( - input_channel if self.is_pipeline else rollout_channel - ) - - with self.worker_timer(): - prev_logprobs = self.inference_step(batch) - rollout_result.prev_logprobs = prev_logprobs.cpu() - - if compute_ref_logprobs: - assert self.ref_policy_state_dict is not None, ( - "Reference policy state dict is None but compute_ref_logprobs is True" - ) - with cpu_weight_swap(self.model, self.ref_policy_state_dict): - ref_logprobs = self.inference_step(batch) - rollout_result.ref_logprobs = ref_logprobs.cpu() - self.put_result(rollout_result, output_channel) - - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) - - def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: - # Get all batches for this DP - batches = [] - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - batches.append(batch) - recv_batch_size += rollout_result.num_sequence - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) - batch = RolloutResult.merge_batches(batches) - # Must be called after batch is retrieved, which is when rollout has stopped - # Otherwise, loading model might cause OOM - self._load_weight_and_optimizer(input_channel) - - global_batches = get_iterator_k_split( - batch, - num_splits=self.cfg.algorithm.n_minibatches, - shuffle=self.cfg.algorithm.get("shuffle_rollout", True), - shuffle_seed=self.cfg.actor.seed, - ) - - self.model.train() - assert ( - self.cfg.actor.global_batch_size - % (self.cfg.actor.micro_batch_size * self._world_size) - == 0 - ) - - training_metrics_list = [] - # Global batch iterations - with self.worker_timer(): - for global_batch in global_batches: - train_global_batch_size = global_batch["input_ids"].shape[0] - - assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, ( - f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size=}" - ) - - self.gradient_accumulation = ( - train_global_batch_size // self.cfg.actor.micro_batch_size - ) - # split batch into micro_batches - train_micro_batches = get_iterator_k_split( - global_batch, - train_global_batch_size // self.cfg.actor.micro_batch_size, - ) - - self.optimizer.zero_grad() - metrics = {} - for idx, m_batch in enumerate(train_micro_batches): - backward_ctx = ( - self.model.no_sync() - if idx < 
self.gradient_accumulation - 1 - else nullcontext() - ) - for k, v in m_batch.items(): - m_batch[k] = v.npu() if isinstance(v, torch.Tensor) else v - - multi_modal_inputs = {} - if "multi_modal_inputs" in m_batch.keys(): - for key in m_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [ - inputs[key] - for inputs in m_batch["multi_modal_inputs"] - ], - dim=0, - ).npu() - - input_ids = m_batch["input_ids"] - attention_mask = m_batch["attention_mask"] - position_ids = m_batch["position_ids"] - prev_logprobs = m_batch["prev_logprobs"] - advantages = m_batch["advantages"] - ref_logprobs = None - if "ref_logprobs" in m_batch: - ref_logprobs = m_batch["ref_logprobs"] - - loss_mask = m_batch["attention_mask"][:, -self.response_len :] - output = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - ) - - logits = output.logits - - logits.div_(self.cfg.algorithm.sampling_params.temperature) - - responses = input_ids[:, -self.response_len :] - logits = logits[ - :, -self.response_len - 1 : -1, : - ] # (bsz, response_length, vocab_size) - logprobs = compute_logprobs_from_logits( - logits, responses, task_type=self.cfg.runner.task_type - ) - - clip_ratio = self.cfg.algorithm.ratio_clip_eps - clip_ratio_low = ( - self.cfg.algorithm.clip_ratio_low - if self.cfg.algorithm.clip_ratio_low is not None - else clip_ratio - ) - clip_ratio_high = ( - self.cfg.algorithm.clip_ratio_high - if self.cfg.algorithm.clip_ratio_high is not None - else clip_ratio - ) - clip_ratio_c = self.cfg.algorithm.get("clip_ratio_c", 3.0) - - loss, mbs_metrics_data = actor_loss( - loss_type=self.cfg.algorithm.loss_type, - loss_agg_func=self.loss_agg_func, - logprobs=logprobs, - old_logprobs=prev_logprobs, - advantages=advantages, - clip_ratio_low=clip_ratio_low, - clip_ratio_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_mask=loss_mask, - ) - - entropy_loss = torch.tensor(0.0, device=torch.npu.current_device()) - if self.calculate_entropy: - entropy = output["entropy"][ - :, -self.response_len - 1 : -1 - ].contiguous() - entropy_loss = self.loss_agg_func(entropy, mask=loss_mask) - if self.calculate_entropy_loss: - loss = ( - loss - self.cfg.algorithm.entropy_bonus * entropy_loss - ) - - kl_loss = torch.tensor(0.0, device=torch.npu.current_device()) - if self.kl_beta > 0 and ref_logprobs is not None: - kld = kl_penalty(ref_logprobs, logprobs, self.kl_penalty_type) - kl_loss = self.loss_agg_func(kld, loss_mask) - loss = loss + kl_loss * self.kl_beta - - # add to log - # scale loss for gradient accumulation and backprop - loss = loss / self.gradient_accumulation - with backward_ctx: - loss.backward() - - mbs_metrics_data.update( - { - "final_loss": loss.detach(), - "entropy_loss": entropy_loss.detach(), - "kl_loss": kl_loss.detach(), - } - ) - - append_to_dict(metrics, mbs_metrics_data) - # apply gradient clipping and optimizer step at the end of a global batch - grad_norm = self.model.clip_grad_norm_( - max_norm=self.cfg.actor.optim.clip_grad - ) - if not torch.isfinite(grad_norm).all(): - self.log_warning( - "grad norm is not finite, skip this optimizer step." 
- ) - else: - self.optimizer.step() - self.optimizer.zero_grad() - - # aggregate metrics across micro-batches - mean_metric_dict = { - key: torch.mean(torch.stack(value)) - for key, value in metrics.items() - } - mean_metric_dict = all_reduce_dict( - mean_metric_dict, op=torch.distributed.ReduceOp.AVG - ) - # add optimizer stats - if torch.is_tensor(grad_norm): - mean_metric_dict["actor/grad_norm"] = float( - grad_norm.detach().item() - ) - else: - mean_metric_dict["actor/grad_norm"] = float(grad_norm) - lr = self.optimizer.param_groups[0]["lr"] - mean_metric_dict["actor/lr"] = torch.as_tensor(lr).float().cpu() - training_metrics_list.append(mean_metric_dict) - - # Rollout metrics - rollout_metrics, _, _ = compute_math_rollout_metrics( - batch, self.cfg.data.max_prompt_length, self.response_len, self._world_size - ) - - return rollout_metrics, training_metrics_list - - def save_checkpoint(self, save_base_path: str, step: int) -> None: - torch.distributed.barrier() - model_state = self.get_model_state_dict() - optim_state = self.get_optimizer_state_dict() - if self._rank == 0: - os.makedirs(save_base_path, exist_ok=True) - torch.save(model_state, os.path.join(save_base_path, "model.pt")) - torch.save(optim_state, os.path.join(save_base_path, "optim.pt")) - torch.distributed.barrier() - - # Advantages and returns - def compute_advantages_and_returns( - self, input_channel: Channel, output_channel: Channel - ) -> None: - """Compute the advantages and returns. - - Args: - input_channel: The input channel to read from. - output_channel: The output channel to send results to. - """ - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - recv_batch_size += rollout_result.num_sequence - - with self.worker_timer(): - if rollout_result.advantages is None: - mask = batch["attention_mask"][:, -self.response_len :] - advantages, returns = calculate_adv_and_returns( - adv_type=self.cfg.algorithm.adv_type, - reward_scores=batch["rewards"].npu(), - mask=mask.npu(), - num_responses=self.cfg.algorithm.group_size, - ) - rollout_result.advantages = advantages.cpu() - - self.put_result(rollout_result, output_channel) - - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) - - -class EmbodiedFSDPActor(FSDPModelManager, Worker): - def __init__(self, cfg: DictConfig): - Worker.__init__(self) - super().__init__(cfg.actor) - - self.cfg = cfg - torch.npu.set_device(int(os.environ["LOCAL_RANK"])) - self.device = torch.npu.current_device() - world_size = self._world_size - self.device_mesh = init_device_mesh( - "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] - ) - - self._env_group_name = cfg.env.group_name - self._rollout_group_name = cfg.rollout.group_name - self._component_placement = HybridComponentPlacement(cfg, Cluster()) - self._weight_dst_rank_in_rollout = self._rank - if self._weight_dst_rank_in_rollout >= self._component_placement.get_world_size( - "rollout" - ): - self._weight_dst_rank_in_rollout = None - - self._obs_queue_name = cfg.env.channel.queue_name - self._action_queue_name = cfg.rollout.channel.queue_name - self._replay_buffer_name = cfg.actor.channel.queue_name - # stage_num: default to 2, use for pipeline rollout process - self.stage_num = cfg.rollout.pipeline_stage_num - - self.channel = self.connect_channel(cfg.actor.channel.name) - self.channel.create_queue( - cfg.actor.channel.queue_name, 
maxsize=cfg.actor.channel.queue_size - ) - - def init_worker(self): - self.setup_model_and_optimizer() - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_param_and_grad() - self.offload_fsdp_optimizer() - torch.npu.synchronize() - gc.collect() - torch.npu.empty_cache() - - def model_provider_func(self): - model = get_model(self.cfg.actor.checkpoint_load_path, self.cfg.actor.model) - if model is not None: - return model - return super().model_provider_func() - - def sync_model_to_rollout(self): - if next(self.model.parameters()).is_cpu: - self.load_fsdp_param_and_grad(self.device) - self.load_fsdp_optimizer(self.device) - - state_dict = self.get_model_state_dict() - if self._weight_dst_rank_in_rollout is not None: - self.send( - state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout - ) - if self.cfg.actor.get("enable_offload", False): - self.offload_fsdp_param_and_grad() - self.offload_fsdp_optimizer() - torch.npu.synchronize() - del state_dict - gc.collect() - torch.npu.empty_cache() - - async def recv_rollout_batch(self): - send_num = self._component_placement.get_world_size("rollout") * self.stage_num - recv_num = self._component_placement.get_world_size("actor") - split_num = compute_split_num(send_num, recv_num) - - self.rollout_batch = {} - recv_list = [] - for i in range(split_num): - recv_list.append( - await self.channel.get( - queue_name=self._replay_buffer_name, async_op=True - ).async_wait() - ) - - # shape [num_chunk, bsz, chunk_size], cat dim 1 - for key in recv_list[0].keys(): - if "env_info/" not in key: - self.rollout_batch[key] = torch.cat( - [recv_list[i][key] for i in range(split_num)], dim=1 - ) - else: - self.rollout_batch[key] = torch.cat( - [recv_list[i][key] for i in range(split_num)], dim=0 - ) - - self.rollout_batch = self._process_received_rollout_batch(self.rollout_batch) - - def _process_received_rollout_batch(self, rollout_batch): - """ - original shape: [rollout_epoch x n_chunk_steps, bsz, num_action_chunks, ...] - target shape: [n_chunk_steps, rollout_epoch x bsz, num_action_chunks, ...] - """ - rollout_epoch = self.cfg.algorithm.rollout_epoch - for key, value in rollout_batch.items(): - new_value = value.reshape( - rollout_epoch, -1, *value.shape[1:] - ) # [rollout_epoch, n_chunk_step, bsz, ...] - new_value = new_value.transpose( - 0, 1 - ) # [n_chunk_step, rollout_epoch, bsz, ...] 
-            new_value = new_value.reshape(new_value.shape[0], -1, *new_value.shape[3:])
-            rollout_batch[key] = new_value
-
-        if (
-            not self.cfg.env.train.auto_reset
-            and not self.cfg.env.train.ignore_terminations
-        ):
-            dones = rollout_batch[
-                "dones"
-            ]  # [n_chunk_step, rollout_epoch x bsz, num_action_chunks]
-            loss_mask, loss_mask_sum = compute_loss_mask(dones)
-
-            if self.cfg.algorithm.reward_type == "chunk_level":
-                loss_mask = loss_mask.any(dim=-1, keepdim=True)
-                loss_mask_sum = loss_mask_sum[..., -1:]
-
-            rollout_batch["loss_mask"] = loss_mask
-            rollout_batch["loss_mask_sum"] = loss_mask_sum
-
-        # filter data by rewards
-        if self.cfg.algorithm.get("filter_rewards", False):
-            rewards = rollout_batch[
-                "rewards"
-            ]  # [n_chunk_step, batch, num_action_chunks]
-            if self.rollout_batch.get("loss_mask", None) is not None:
-                rewards = rewards * rollout_batch["loss_mask"]
-            n_chunk_step, batch_size, num_action_chunks = rewards.shape
-
-            group_size = self.cfg.algorithm.group_size
-            assert batch_size % group_size == 0, (
-                f"batch {batch_size} not divisible by group_size {group_size}"
-            )
-            n_prompts = batch_size // group_size
-
-            # calculate rewards by prompt
-            rewards = rewards.transpose(
-                0, 1
-            )  # [batch, n_chunk_step, num_action_chunks]
-            rewards = rewards.reshape(rewards.shape[0], -1)  # [batch, n_step]
-            reward_matrix = rewards.reshape(
-                n_prompts, group_size, rewards.shape[-1]
-            )  # [n_prompts, group_size, n_step]
-            reward_matrix = reward_matrix.sum(dim=-1)  # [n_prompts, group_size]
-            mean_reward_in_group = reward_matrix.mean(dim=1)  # [n_prompts]
-
-            # mask
-            reward_filter_mask = (
-                mean_reward_in_group >= self.cfg.algorithm.rewards_lower_bound
-            ) & (
-                mean_reward_in_group <= self.cfg.algorithm.rewards_upper_bound
-            )  # [n_prompts]
-
-            # extend mask dimension
-            reward_filter_mask = reward_filter_mask.repeat_interleave(
-                group_size
-            )  # [batch]
-            reward_filter_mask = (
-                reward_filter_mask.unsqueeze(0).expand(n_chunk_step, -1).unsqueeze(-1)
-            )  # [n_chunk_step, batch, 1]
-
-            # update loss_mask
-            if self.rollout_batch.get("loss_mask", None) is not None:
-                rollout_batch["loss_mask"] = (
-                    reward_filter_mask & self.rollout_batch["loss_mask"]
-                )
-            else:
-                rollout_batch["loss_mask"] = reward_filter_mask
-
-        return rollout_batch
-
-    def compute_logprobs(self):
-        self.model.eval()
-        self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"]
-
-    def compute_advantages_and_returns(self):
-        stage_num = self.cfg.rollout.pipeline_stage_num
-        env_world_size = self._component_placement.get_world_size("env")
-        actor_world_size = self._component_placement.get_world_size("actor")
-        num_group_envs_for_train = (
-            self.cfg.algorithm.num_group_envs
-            * stage_num
-            * env_world_size
-            // actor_world_size
-        )
-
-        kwargs = {
-            "adv_type": self.cfg.algorithm.adv_type,
-            "rewards": self.rollout_batch["rewards"],
-            "dones": self.rollout_batch["dones"],
-            "normalize_advantages": self.cfg.algorithm.get(
-                "normalize_advantages", True
-            ),
-            "values": self.rollout_batch.get("prev_values", None),
-            "gamma": self.cfg.algorithm.get("gamma", 1),
-            "gae_lambda": self.cfg.algorithm.get("gae_lambda", 1),
-            "num_group_envs": num_group_envs_for_train,
-            "group_size": self.cfg.algorithm.get("group_size", 8),
-            "reward_type": self.cfg.algorithm.reward_type,
-            "loss_mask": self.rollout_batch.get("loss_mask", None),
-            "rollout_epoch": self.cfg.algorithm.get("rollout_epoch", 1),
-        }
-        kwargs = preprocess_advantages_inputs(**kwargs)
-        advantages, returns = calculate_adv_and_returns(**kwargs)
-
-        self.rollout_batch.update({"advantages": advantages, "returns": returns})
-        rollout_metrics = compute_rollout_metrics(self.rollout_batch)
-        return rollout_metrics
-
-    def run_training(self):
-        if self.cfg.actor.get("enable_offload", False):
-            self.load_fsdp_param_and_grad(self.device)
-            self.load_fsdp_optimizer(self.device)
-
-        self.model.train()
-        self.optimizer.zero_grad()
-        rollout_size = (
-            self.rollout_batch["input_ids"].shape[0]
-            * self.rollout_batch["input_ids"].shape[1]
-        )
-        shuffle_id = torch.randperm(rollout_size)
-
-        for key, value in self.rollout_batch.items():
-            self.log_on_first_rank(f"run training, {key}: {value.shape}")
-
-        with torch.no_grad():
-            for key, value in self.rollout_batch.items():
-                if key in ["dones", "prev_values"]:
-                    value = value[:-1]
-                if "env_info" in key:
-                    continue
-                value = value.reshape(rollout_size, *value.shape[2:])
-                self.rollout_batch[key] = value[shuffle_id]
-
-        assert (
-            self.cfg.actor.global_batch_size
-            % (self.cfg.actor.micro_batch_size * self._world_size)
-            == 0
-        )
-
-        self.gradient_accumulation = (
-            self.cfg.actor.global_batch_size
-            // self.cfg.actor.micro_batch_size
-            // self._world_size
-        )
-
-        # Split to make minibatch iterator for updating the actor
-        # See PPO paper for details. https://arxiv.org/abs/1707.06347
-        rollout_size = self.rollout_batch["input_ids"].size(0)
-        batch_size_per_rank = self.cfg.actor.global_batch_size // self._world_size
-        assert rollout_size % batch_size_per_rank == 0, (
-            f"{rollout_size} is not divisible by {batch_size_per_rank}"
-        )
-        rollout_dataloader_iter = get_iterator_k_split(
-            self.rollout_batch,
-            rollout_size // batch_size_per_rank,
-        )
-
-        metrics = {}
-        for _, train_global_batch in tqdm(
-            enumerate(rollout_dataloader_iter), desc="get loss and metrics"
-        ):
-            # split batch into micro_batches
-            train_global_batch_size = train_global_batch["input_ids"].shape[0]
-            assert (
-                train_global_batch_size
-                == self.cfg.actor.global_batch_size
-                // torch.distributed.get_world_size()
-            )
-            assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, (
-                f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size}"
-            )
-            train_micro_batch = get_iterator_k_split(
-                train_global_batch,
-                train_global_batch_size // self.cfg.actor.micro_batch_size,
-            )
-
-            self.optimizer.zero_grad()
-            for data_idx, data in enumerate(train_micro_batch):
-                for k, v in data.items():
-                    data[k] = v.to(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-
-                data = self.model.preprocess_for_train(data)
-                input_ids = data["input_ids"]
-                action_tokens = data["action_tokens"]
-                attention_mask = data["attention_mask"]
-                pixel_values = data["pixel_values"]
-
-                action_token_len = self.model.action_dim * self.model.num_action_chunks
-
-                logits_processor_args = {
-                    "action_tokens": action_tokens,
-                    "vocab_size": self.model.vocab_size,
-                    "n_action_bins": self.model.config.n_action_bins,
-                }
-
-                output_dict = custom_forward(
-                    self.model,
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    pixel_values=pixel_values,
-                    action_token_len=action_token_len,
-                    value_model=True
-                    if self.cfg.algorithm.adv_type == "embodied_gae"
-                    else False,
-                    value_head_mode=self.cfg.actor.model.get("vh_mode", None),
-                    temperature=self.cfg.algorithm.sampling_params.temperature_train,
-                    top_k=self.cfg.algorithm.sampling_params.top_k,
-                    logits_processor_args=logits_processor_args,
-                )
-
-                kwargs = {
-                    "loss_type": self.cfg.algorithm.loss_type,
-                    "logprob_type": self.cfg.algorithm.logprob_type,
-                    "entropy_type": self.cfg.algorithm.entropy_type,
-                    "single_action_dim": self.model.action_dim,
-                    "logprobs": output_dict["logprobs"],
-                    "entropy": output_dict["entropy"],
-                    "values": output_dict.get("values", None),
-                    "old_logprobs": data["prev_logprobs"],
-                    "advantages": data["advantages"],
-                    "returns": data["returns"],
-                    "prev_values": data["prev_values"],
-                    "clip_ratio_high": self.cfg.algorithm.clip_ratio_high,
-                    "clip_ratio_low": self.cfg.algorithm.clip_ratio_low,
-                    "value_clip": self.cfg.algorithm.get("value_clip", None),
-                    "huber_delta": self.cfg.algorithm.get("huber_delta", None),
-                    "entropy_bonus": self.cfg.algorithm.entropy_bonus,
-                    "loss_mask": data.get("loss_mask", None),
-                    "loss_mask_sum": data.get("loss_mask_sum", None),
-                    "max_episode_steps": self.cfg.env.train.max_episode_steps,
-                }
-
-                kwargs = preprocess_loss_inputs(**kwargs)
-
-                loss, metrics_data = actor_loss(**kwargs)
-
-                loss /= self.gradient_accumulation
-                loss.backward()
-
-                metrics_data["loss"] = loss.detach().item()
-                append_to_dict(metrics, metrics_data)
-
-                torch.npu.empty_cache()
-
-            grad_norm = self.model.clip_grad_norm_(
-                max_norm=self.cfg.actor.optim.clip_grad
-            )
-            self.optimizer.step()
-
-            self.optimizer.zero_grad()
-            data = {
-                "actor/grad_norm": grad_norm.detach().item(),
-                "actor/lr": self.optimizer.param_groups[0]["lr"],
-            }
-            if self.cfg.algorithm.adv_type == "embodied_gae":
-                data["critic/lr"] = self.optimizer.param_groups[1]["lr"]
-            append_to_dict(metrics, data)
-
-        mean_metric_dict = {key: np.mean(value) for key, value in metrics.items()}
-        mean_metric_dict = all_reduce_dict(
-            mean_metric_dict, op=torch.distributed.ReduceOp.AVG
-        )
-
-        self.optimizer.zero_grad()
-        torch.npu.synchronize()
-        torch.distributed.barrier()
-        torch.npu.empty_cache()
-
-        return mean_metric_dict
-
-    def save_checkpoint(self, save_base_path, step):
-        torch.distributed.barrier()
-        model_state = self.get_model_state_dict()
-        optim_state = self.get_optimizer_state_dict()
-        if self._rank == 0:
-            os.makedirs(save_base_path, exist_ok=True)
-            torch.save(model_state, os.path.join(save_base_path, "model.pt"))
-            torch.save(optim_state, os.path.join(save_base_path, "optim.pt"))
-        torch.distributed.barrier()