diff --git a/.github/workflows/code-test.yml b/.github/workflows/code-test.yml
index c9c34a6c6..0ec8ca3ef 100644
--- a/.github/workflows/code-test.yml
+++ b/.github/workflows/code-test.yml
@@ -113,33 +113,48 @@ jobs:
- name: Checkout code
uses: actions/checkout@v5
- - name: SGLang Collocated mode
+ - name: Megatron SGLang Collocated mode
+ timeout-minutes: 20
+ run: |
+ export REPO_PATH=$(pwd)
+ source switch_env reason
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-sgl
+
+ - name: Megatron vLLM Collocated mode
+ timeout-minutes: 20
+ run: |
+ export REPO_PATH=$(pwd)
+ source switch_env reason
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-vllm
+
+ - name: Megatron SGLang Pipeline mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/sglang/run_collocated.sh
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl
- - name: vLLM Collocated mode
+ - name: Megatron vLLM Pipeline mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/vllm/run_collocated.sh
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm
- - name: SGLang Pipeline mode
+ - name: FSDP SGLang Collocated mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/sglang/run_pipeline.sh
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl
- - name: vLLM Pipeline mode
+ - name: FSDP vLLM Collocated mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/vllm/run_pipeline.sh
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm
+
reason-qwen-grpo-test-rollout-logprobs:
needs: [check-changes]
@@ -149,33 +164,47 @@ jobs:
- name: Checkout code
uses: actions/checkout@v5
- - name: SGLang Collocated mode
+ - name: Megatron SGLang Collocated mode
+ timeout-minutes: 20
+ run: |
+ export REPO_PATH=$(pwd)
+ source switch_env reason
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs
+
+ - name: Megatron vLLM Collocated mode
+ timeout-minutes: 20
+ run: |
+ export REPO_PATH=$(pwd)
+ source switch_env reason
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs
+
+ - name: Megatron SGLang Pipeline mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/sglang/run_collocated.sh qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs
- - name: vLLM Collocated mode
+ - name: Megatron vLLM Pipeline mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/vllm/run_collocated.sh qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs
- - name: SGLang Pipeline mode
+ - name: FSDP SGLang Collocated mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/sglang/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs
- - name: vLLM Pipeline mode
+ - name: FSDP vLLM Collocated mode
timeout-minutes: 20
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/math/vllm/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs
coding-online-rl-qwen-ppo-test:
needs: [check-changes]
@@ -194,7 +223,29 @@ jobs:
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh
+ bash tests/e2e_tests/coding_online_rl/run.sh
+
+ qwen-vl-grpo-test:
+ needs: [check-changes]
+ if: needs.check-changes.outputs.file_filter == 'true'
+ runs-on: reason
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v5
+
+ - name: FSDP SGLang Collocated mode
+ timeout-minutes: 20
+ run: |
+ export REPO_PATH=$(pwd)
+ source switch_env reason
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl
+
+ - name: FSDP vLLM Collocated mode
+ timeout-minutes: 20
+ run: |
+ export REPO_PATH=$(pwd)
+ source switch_env reason
+ bash tests/e2e_tests/reasoning/run.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm
# =============================================== embodied e2e tests ====================================================
@@ -270,7 +321,7 @@ jobs:
run: |
export REPO_PATH=$(pwd)
source switch_env reason
- bash tests/e2e_tests/auto_placement/run_auto_placement.sh
+ bash tests/e2e_tests/auto_placement/run.sh
# =============================================== finale ====================================================
@@ -283,7 +334,7 @@ jobs:
# Reason e2e tests
reason-qwen-grpo-test, reason-qwen-grpo-test-rollout-logprobs,
- coding-online-rl-qwen-ppo-test,
+ coding-online-rl-qwen-ppo-test, qwen-vl-grpo-test,
# Embodied e2e tests
embodied-maniskill-ppo-openvla-test, embodied-maniskill-grpo-openvlaoft-test, embodied-libero-goal-grpo-openvlaoft-test,embodied-libero-130-grpo-openvlaoft-test,
diff --git a/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml
index 1d9720fdd..525b1c951 100644
--- a/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml
+++ b/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml
@@ -157,6 +157,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml b/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml
index 628272ed1..69708e7c0 100644
--- a/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml
+++ b/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml
@@ -158,6 +158,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml b/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml
index bf7e31667..15b33bbbc 100644
--- a/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml
+++ b/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml
@@ -152,6 +152,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml
index d6356c314..699d7ab40 100644
--- a/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml
+++ b/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml
@@ -156,6 +156,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml
index b767fd25a..d2bacd05e 100644
--- a/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml
+++ b/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml
@@ -156,6 +156,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml
index 69aec20eb..9469166d9 100644
--- a/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml
+++ b/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml
@@ -156,6 +156,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/maniskill_grpo_openvla.yaml b/examples/embodiment/config/maniskill_grpo_openvla.yaml
index 16dc2af06..3679d533d 100644
--- a/examples/embodiment/config/maniskill_grpo_openvla.yaml
+++ b/examples/embodiment/config/maniskill_grpo_openvla.yaml
@@ -158,6 +158,12 @@ actor:
adam_eps: 1.0e-05
clip_grad: 1.0
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml b/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml
index def45aafb..7bd32855b 100644
--- a/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml
+++ b/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml
@@ -156,6 +156,12 @@ actor:
adam_eps: 1.0e-05
clip_grad: 10.0
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/maniskill_ppo_openvla.yaml b/examples/embodiment/config/maniskill_ppo_openvla.yaml
index 6aeae632d..bf9cab8eb 100644
--- a/examples/embodiment/config/maniskill_ppo_openvla.yaml
+++ b/examples/embodiment/config/maniskill_ppo_openvla.yaml
@@ -154,6 +154,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml b/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml
index 969dc85cb..b81d1390f 100644
--- a/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml
+++ b/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml
@@ -170,6 +170,12 @@ actor:
adam_eps: 1.0e-05
clip_grad: 1.0
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml b/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml
index 1e5893f43..f957f1997 100644
--- a/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml
+++ b/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml
@@ -159,6 +159,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml b/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml
index 30700e1c2..56d936659 100644
--- a/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml
+++ b/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml
@@ -158,6 +158,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml
new file mode 100644
index 000000000..3b31eecdc
--- /dev/null
+++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml
@@ -0,0 +1,225 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: ${runner.output_dir}/${runner.experiment_name}
+ project_name: rlinf
+ experiment_name: ${runner.experiment_name}
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 5
+ max_steps: -1
+
+ val_check_interval: 1
+ save_interval: 50
+
+ seq_length: 28672
+
+ enable_dynamic_batch_size: False
+ max_tokens_per_mbs: 28672
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: ../results
+algorithm:
+ group_size: 8
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: True
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
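+ # i.e. runner.seq_length (28672) - data.max_prompt_length (1024) = 27648 for this config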
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/
+ model_arch: qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: sglang # choose which backend to use for rollout; supported: [sglang, vllm]
+
+ sglang:
+ attention_backend: ascend # [flashinfer, triton]; for more options, see SGLang's docs
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN # [FLASH_ATTN, XFORMERS]; for more options, see vLLM's docs
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
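+ # i.e. ask the rollout engine for logprobs only when they are not recomputed;
+ # with recompute_logprobs: True above, this resolves to False.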
+
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: math
+ dataset_name: boba
+ max_prompt_length: 1024
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ shuffle: True
+ validation_shuffle: True
+ seed: 1234
+ train_data_paths: ["/home/dataset/boba/AReaL-boba-106k.jsonl"]
+ val_data_paths: ["/home/dataset/boba/AReaL-boba-106k.jsonl"]
+ prompt_key: prompt
+ image_keys: [image]
+ answer_key: answer
+ choice_key: choices
+ solution_key: null
+ use_chat_template: True
+ lazy_loading: True
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/
+
+ optim:
+ optimizer: adam
+ bf16: True
+ fp16: False
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'math'
+ reward_scale: 5.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
diff --git a/examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml
similarity index 95%
rename from examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml
rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml
index 9ba641d15..ea4606003 100644
--- a/examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml
+++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml
@@ -12,9 +12,10 @@ cluster:
rollout: 0-15
inference: 16-23
actor: 24-63
+ reward: 0-15
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
@@ -65,6 +66,8 @@ algorithm:
entropy_bonus: 0.0
calculate_entropy: True
clip_ratio_c: null # 3.0
+ clip_ratio_low: null # if null or not set, will use ratio_clip_eps
+ clip_ratio_high: null # if null or not set, will use ratio_clip_eps
adv_type: math_grpo
normalize_advantages: False
@@ -134,6 +137,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 1024
filter_prompt_by_length: True
rollout_batch_size: 512
@@ -146,6 +150,7 @@ data:
train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"]
val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"]
+
actor:
group_name: "ActorGroup"
training_backend: megatron
@@ -271,9 +276,15 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
critic:
use_critic_model: false
diff --git a/examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml
similarity index 94%
rename from examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml
rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml
index 6f75dc38e..4d851f894 100644
--- a/examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml
+++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 16
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
@@ -61,6 +61,8 @@ algorithm:
entropy_bonus: 0.0
calculate_entropy: False
clip_ratio_c: null # 3.0
+ clip_ratio_low: null # if null or not set, will use ratio_clip_eps
+ clip_ratio_high: null # if null or not set, will use ratio_clip_eps
adv_type: math_grpo
normalize_advantages: True
@@ -121,6 +123,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 1024
filter_prompt_by_length: True
rollout_batch_size: 512
@@ -258,9 +261,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/examples/math/config/qwen2.5-1.5b-single-gpu.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml
similarity index 94%
rename from examples/math/config/qwen2.5-1.5b-single-gpu.yaml
rename to examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml
index e3f5bf28d..1829050c1 100644
--- a/examples/math/config/qwen2.5-1.5b-single-gpu.yaml
+++ b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 1
component_placement:
- actor,rollout: 0
+ actor,rollout,reward: 0
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
@@ -61,6 +61,8 @@ algorithm:
entropy_bonus: 0.0
calculate_entropy: False
clip_ratio_c: null # 3.0
+ clip_ratio_low: null # if null or not set, will use ratio_clip_eps
+ clip_ratio_high: null # if null or not set, will use ratio_clip_eps
adv_type: math_grpo
normalize_advantages: True
@@ -258,9 +260,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/examples/math/config/qwen2.5-32b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml
similarity index 95%
rename from examples/math/config/qwen2.5-32b-grpo-megatron.yaml
rename to examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml
index 6e397dfda..e9eb2089e 100644
--- a/examples/math/config/qwen2.5-32b-grpo-megatron.yaml
+++ b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 32
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
@@ -61,6 +61,8 @@ algorithm:
entropy_bonus: 0.0
calculate_entropy: False
clip_ratio_c: null # 3.0
+ clip_ratio_low: null # if null or not set, will use ratio_clip_eps
+ clip_ratio_high: null # if null or not set, will use ratio_clip_eps
adv_type: math_grpo
normalize_advantages: True
@@ -259,5 +261,11 @@ reward:
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/examples/math/config/qwen2.5-7b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml
similarity index 94%
rename from examples/math/config/qwen2.5-7b-grpo-megatron.yaml
rename to examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml
index 63146687e..b2a70d6f8 100644
--- a/examples/math/config/qwen2.5-7b-grpo-megatron.yaml
+++ b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 16
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
@@ -61,6 +61,8 @@ algorithm:
entropy_bonus: 0.0
calculate_entropy: False
clip_ratio_c: null # 3.0
+ clip_ratio_low: null # if null or not set, will use ratio_clip_eps
+ clip_ratio_high: null # if null or not set, will use ratio_clip_eps
adv_type: math_grpo
normalize_advantages: True
@@ -257,9 +259,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/examples/math/config/tp_comm_overlap_cfg.yaml b/examples/reasoning/config/tp_comm_overlap_cfg.yaml
similarity index 100%
rename from examples/math/config/tp_comm_overlap_cfg.yaml
rename to examples/reasoning/config/tp_comm_overlap_cfg.yaml
diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml
new file mode 100644
index 000000000..17c192fa2
--- /dev/null
+++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml
@@ -0,0 +1,228 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: ${runner.output_dir}/${runner.experiment_name}
+ project_name: rlinf
+ experiment_name: ${runner.experiment_name}
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 5
+ max_steps: -1
+
+ val_check_interval: 1
+ save_interval: 50
+
+ seq_length: 2048
+
+ enable_dynamic_batch_size: False
+ max_tokens_per_mbs: 28672
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: ../results
+
+algorithm:
+ group_size: 8
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: False
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
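+ # i.e. runner.seq_length (2048) - data.max_prompt_length (1024) = 1024 for this config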
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+ model_arch: qwen2.5_vl #qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: sglang # choose which backend to use for rollout; supported: [sglang, vllm]
+
+ sglang:
+ attention_backend: ascend # [flashinfer, triton]; for more options, see SGLang's docs
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN # [FLASH_ATTN, XFORMERS]; for more options, see vLLM's docs
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
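+ # i.e. ask the rollout engine for logprobs only when they are not recomputed;
+ # with recompute_logprobs: False above, this resolves to True.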
+
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: vision_language
+ dataset_name: robo2vlm
+ max_prompt_length: 1024
+ filter_prompt_by_length: True
+ rollout_batch_size: 8
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ image_keys: ["image"] # some vlm datasets may have multiple image columns
+ choice_key: "choices"
+ answer_key: "answer"
+ solution_key: "solution"
+ use_chat_template: True
+ lazy_loading: True
+ shuffle: True
+ validation_shuffle: True
+ seed: 1234
+ train_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"]
+ val_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+
+ model_arch: ${rollout.model_arch}
+
+ optim:
+ optimizer: adam
+ bf16: True #False
+ fp16: False #True
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'vqa'
+ reward_scale: 1.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+critic:
+ use_critic_model: false
diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml
new file mode 100644
index 000000000..fa5b16080
--- /dev/null
+++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml
@@ -0,0 +1,222 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: ${runner.output_dir}/${runner.experiment_name}
+ project_name: rlinf
+ experiment_name: ${runner.experiment_name}
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 5
+ max_steps: -1
+
+ val_check_interval: 1
+ save_interval: 50
+
+ seq_length: 2048
+
+ enable_dynamic_batch_size: False
+ max_tokens_per_mbs: 28672
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: ../results
+
+algorithm:
+ group_size: 8
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: False
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+ model_arch: qwen2.5_vl #qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: sglang # choose which backend to use for rollout; supported: [sglang, vllm]
+
+ sglang:
+ attention_backend: ascend # [flashinfer, triton]; for more options, see SGLang's docs
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN # [FLASH_ATTN, XFORMERS]; for more options, see vLLM's docs
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 2
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: vision_language
+ dataset_name: robo2vlm
+ max_prompt_length: 1024
+ filter_prompt_by_length: True
+ rollout_batch_size: 8
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ image_keys: ["image"] # some vlm datasets may have multiple image columns
+ choice_key: "choices"
+ answer_key: "answer"
+ solution_key: "solution"
+ use_chat_template: True
+ lazy_loading: True
+ shuffle: True
+ validation_shuffle: True
+ seed: 1234
+ train_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"]
+ val_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+
+ model_arch: ${rollout.model_arch}
+
+ optim:
+ optimizer: adam
+ bf16: True #False
+ fp16: False #True
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'vqa'
+ reward_scale: 1.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+critic:
+ use_critic_model: false
\ No newline at end of file
diff --git a/examples/math/main_math.py b/examples/reasoning/main_grpo.py
similarity index 81%
rename from examples/math/main_math.py
rename to examples/reasoning/main_grpo.py
index 7150566fb..30073d562 100644
--- a/examples/math/main_math.py
+++ b/examples/reasoning/main_grpo.py
@@ -21,12 +21,13 @@
from rlinf.config import validate_cfg
from rlinf.data.datasets import create_rl_dataset
from rlinf.data.tokenizers import hf_tokenizer
-from rlinf.runners.math_runner import MathRunner
+from rlinf.runners.reasoning_runner import ReasoningRunner
from rlinf.scheduler import Cluster
from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode
from rlinf.utils.utils import output_redirector
-from rlinf.workers.actor.megatron_actor_worker import MegatronActor
+from rlinf.workers.actor import get_actor_worker
from rlinf.workers.inference.megatron_inference_worker import MegatronInference
+from rlinf.workers.reward.reward_worker import RewardWorker
from rlinf.workers.rollout.utils import get_rollout_backend_worker
"""Script to start GRPO training"""
@@ -67,16 +68,25 @@ def main(cfg) -> None:
placement_strategy=inference_placement_strategy,
)
+ # Reward group
+ reward_placement_strategy = component_placement.get_strategy("reward")
+ reward_group = RewardWorker.create_group(cfg, component_placement).launch(
+ cluster,
+ name=cfg.reward.group_name,
+ placement_strategy=reward_placement_strategy,
+ )
+
# GRPO Actor group
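+ # choose the actor worker class based on the config (e.g. Megatron or FSDP training backend)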
+ actor_worker_cls = get_actor_worker(cfg)
actor_placement_strategy = component_placement.get_strategy("actor")
- actor_group = MegatronActor.create_group(cfg, component_placement).launch(
+ actor_group = actor_worker_cls.create_group(cfg, component_placement).launch(
cluster, name=cfg.actor.group_name, placement_strategy=actor_placement_strategy
)
tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model)
- train_ds, val_ds = create_rl_dataset(cfg.data, tokenizer)
+ train_ds, val_ds = create_rl_dataset(cfg, tokenizer)
- runner = MathRunner(
+ runner = ReasoningRunner(
cfg=cfg,
placement=component_placement,
train_dataset=train_ds,
@@ -84,6 +94,7 @@ def main(cfg) -> None:
rollout=rollout_group,
inference=inference_group,
actor=actor_group,
+ reward=reward_group,
)
runner.init_workers()
diff --git a/examples/math/run_main_math_grpo_megatron.sh b/examples/reasoning/run_main_grpo_math.sh
similarity index 70%
rename from examples/math/run_main_math_grpo_megatron.sh
rename to examples/reasoning/run_main_grpo_math.sh
index f826f882f..1b123fd6f 100644
--- a/examples/math/run_main_math_grpo_megatron.sh
+++ b/examples/reasoning/run_main_grpo_math.sh
@@ -2,7 +2,6 @@
set -x
tabs 4
-export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TOKENIZERS_PARALLELISM=false
export RAY_DEDUP_LOGS=0
@@ -15,7 +14,7 @@ export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH
if [ -z "$1" ]; then
CONFIG_NAME="qwen2.5-1.5b-grpo-megatron"
else
CONFIG_NAME=$1
fi
-python ${REPO_PATH}/examples/math/main_math.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME
\ No newline at end of file
+python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/math/ --config-name $CONFIG_NAME
diff --git a/examples/math/run_main_math_pipeline_grpo_megatron.sh b/examples/reasoning/run_main_grpo_vqa.sh
old mode 100644
new mode 100755
similarity index 64%
rename from examples/math/run_main_math_pipeline_grpo_megatron.sh
rename to examples/reasoning/run_main_grpo_vqa.sh
index 7deb96519..3cc526f0e
--- a/examples/math/run_main_math_pipeline_grpo_megatron.sh
+++ b/examples/reasoning/run_main_grpo_vqa.sh
@@ -2,7 +2,6 @@
set -x
tabs 4
-export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TOKENIZERS_PARALLELISM=false
export RAY_DEDUP_LOGS=0
@@ -13,9 +12,9 @@ MEGATRON_PATH=/opt/Megatron-LM
export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH
if [ -z "$1" ]; then
- CONFIG_NAME="qwen2.5-1.5b-grpo-megatron-pipeline"
+ CONFIG_NAME="qwen2.5-vl-3b-grpo-fsdp"
else
CONFIG_NAME=$1
fi
-python ${REPO_PATH}/examples/math/main_math.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME
\ No newline at end of file
+python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/vqa/ --config-name $CONFIG_NAME
\ No newline at end of file
diff --git a/examples/math/run_placement_autotune.sh b/examples/reasoning/run_placement_autotune.sh
similarity index 100%
rename from examples/math/run_placement_autotune.sh
rename to examples/reasoning/run_placement_autotune.sh
diff --git a/fusion_result.json b/fusion_result.json
new file mode 100644
index 000000000..20b56950d
--- /dev/null
+++ b/fusion_result.json
@@ -0,0 +1,21 @@
+{
+ "session_and_graph_id_0_0": {
+ "graph_fusion": {
+ "IndexByTensorStaticFusionPass": {
+ "effect_times": "0",
+ "match_times": "1"
+ },
+ "RefreshInt64ToInt32FusionPass": {
+ "effect_times": "0",
+ "match_times": "1"
+ }
+ },
+ "ub_fusion": {
+ "AutomaticUbFusion": {
+ "effect_times": "0",
+ "match_times": "2",
+ "repository_hit_times": "0"
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/kernel_meta/buildPidInfo.json b/kernel_meta/buildPidInfo.json
new file mode 100644
index 000000000..9900ce4f6
--- /dev/null
+++ b/kernel_meta/buildPidInfo.json
@@ -0,0 +1,2754 @@
+[
+ [
+ 39911,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12797813670961250527"
+ ],
+ [
+ 39913,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2319069557219941997"
+ ],
+ [
+ 39917,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11862212240406296461"
+ ],
+ [
+ 39922,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_885784172161213002"
+ ],
+ [
+ 39926,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7410510130683415152"
+ ],
+ [
+ 39930,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9194155428752533289"
+ ],
+ [
+ 39934,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10424833431148169855"
+ ],
+ [
+ 39936,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9855745055097868770"
+ ],
+ [
+ 39939,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18191654686185573965"
+ ],
+ [
+ 39942,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1065590051412680278"
+ ],
+ [
+ 39949,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16967210759787679084"
+ ],
+ [
+ 39950,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8381815239432563187"
+ ],
+ [
+ 39958,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2867384563160197328"
+ ],
+ [
+ 39965,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4547732597630948415"
+ ],
+ [
+ 39971,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15778275423241821855"
+ ],
+ [
+ 39980,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5207596602532600872"
+ ],
+ [
+ 40666,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15835674102581580671"
+ ],
+ [
+ 82796,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8782270786933035597"
+ ],
+ [
+ 100620,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7353802482069301333"
+ ],
+ [
+ 112704,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7725153434515021512"
+ ],
+ [
+ 122239,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7210453806497323497"
+ ],
+ [
+ 146045,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3405930793784815214"
+ ],
+ [
+ 146076,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12695548580890247073"
+ ],
+ [
+ 146140,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1605098968774552756"
+ ],
+ [
+ 146205,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5766664014834865721"
+ ],
+ [
+ 146347,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9612635458804393825"
+ ],
+ [
+ 146447,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7725679307351709877"
+ ],
+ [
+ 146559,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14784440107212476086"
+ ],
+ [
+ 146567,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17267848541246919924"
+ ],
+ [
+ 146569,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17784643373181208859"
+ ],
+ [
+ 146573,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10067831816227681826"
+ ],
+ [
+ 146589,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1044498738075030000"
+ ],
+ [
+ 158951,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15191603587064731993"
+ ],
+ [
+ 159046,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6443795088650938062"
+ ],
+ [
+ 159116,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7617045580398011601"
+ ],
+ [
+ 159145,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17761038868925529398"
+ ],
+ [
+ 159234,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8618471054285342305"
+ ],
+ [
+ 159360,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16579794217881527508"
+ ],
+ [
+ 159440,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10254394513325194666"
+ ],
+ [
+ 159460,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11087666240252757180"
+ ],
+ [
+ 159513,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9428322862212375668"
+ ],
+ [
+ 159626,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16449556766463462195"
+ ],
+ [
+ 159768,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1071949320798082478"
+ ],
+ [
+ 159861,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9565697075048113424"
+ ],
+ [
+ 159886,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3062431087256661843"
+ ],
+ [
+ 160022,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13353626036029525043"
+ ],
+ [
+ 161093,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_508908986595913110"
+ ],
+ [
+ 162181,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3563660611022476863"
+ ],
+ [
+ 259575,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8201276825499370867"
+ ],
+ [
+ 259577,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15529659103780327911"
+ ],
+ [
+ 259581,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10174461859794700718"
+ ],
+ [
+ 259585,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16507297788703908047"
+ ],
+ [
+ 259587,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11966777041657863002"
+ ],
+ [
+ 259592,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2726330588881717531"
+ ],
+ [
+ 259596,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17030647279409832227"
+ ],
+ [
+ 259598,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12186607257866309998"
+ ],
+ [
+ 259601,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3768772217671096392"
+ ],
+ [
+ 259607,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3911160194709892748"
+ ],
+ [
+ 259612,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7977521137351637493"
+ ],
+ [
+ 259619,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4128563410056480856"
+ ],
+ [
+ 259626,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9581134290190490336"
+ ],
+ [
+ 259628,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_702154103913980934"
+ ],
+ [
+ 259636,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14027257268528873801"
+ ],
+ [
+ 259643,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17044983378904657187"
+ ],
+ [
+ 260342,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3738242728542939509"
+ ],
+ [
+ 260402,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15806158438452141266"
+ ],
+ [
+ 260467,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4567683529506793175"
+ ],
+ [
+ 260537,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10271642930345239273"
+ ],
+ [
+ 260629,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3945686445184445770"
+ ],
+ [
+ 260690,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17057962697473254798"
+ ],
+ [
+ 260755,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16871514587483264100"
+ ],
+ [
+ 260763,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2675279169353877702"
+ ],
+ [
+ 260795,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14119700360366649411"
+ ],
+ [
+ 315952,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11641596342314442912"
+ ],
+ [
+ 329434,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17017774772109951435"
+ ],
+ [
+ 340293,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_584578301371656169"
+ ],
+ [
+ 352185,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11288852585934677911"
+ ],
+ [
+ 367542,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6263883255072188097"
+ ],
+ [
+ 367633,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5888503093289751929"
+ ],
+ [
+ 367666,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5420145819910730026"
+ ],
+ [
+ 379262,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1783494569175948891"
+ ],
+ [
+ 379266,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_537678043780734220"
+ ],
+ [
+ 379285,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5877535174558663131"
+ ],
+ [
+ 379286,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10249842372811971323"
+ ],
+ [
+ 379383,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_656090948065222947"
+ ],
+ [
+ 379674,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11244763063226368392"
+ ],
+ [
+ 380304,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_231890364717191513"
+ ],
+ [
+ 380313,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5692311019013524597"
+ ],
+ [
+ 380315,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11068432615228024246"
+ ],
+ [
+ 380440,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15012465798244511954"
+ ],
+ [
+ 380532,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3776119111621864505"
+ ],
+ [
+ 380533,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6221575824346386447"
+ ],
+ [
+ 380834,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11465674300335516174"
+ ],
+ [
+ 381115,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7186726943654900649"
+ ],
+ [
+ 381318,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12507534548909153748"
+ ],
+ [
+ 381749,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15611949621196386618"
+ ],
+ [
+ 446067,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15585473429645162327"
+ ],
+ [
+ 446070,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2905662668684958995"
+ ],
+ [
+ 446077,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9261524943542376262"
+ ],
+ [
+ 446081,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11733848515447041665"
+ ],
+ [
+ 446083,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13074589690020057687"
+ ],
+ [
+ 446087,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6740951685981297380"
+ ],
+ [
+ 446089,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1230291348499102449"
+ ],
+ [
+ 446095,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8512469513541935614"
+ ],
+ [
+ 446098,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6734225289254301183"
+ ],
+ [
+ 446100,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3193003074983928221"
+ ],
+ [
+ 446107,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17730893955086431941"
+ ],
+ [
+ 446108,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14296202533891893120"
+ ],
+ [
+ 446115,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17201341106939116953"
+ ],
+ [
+ 446122,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17335486450140669799"
+ ],
+ [
+ 446128,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_172524440323681068"
+ ],
+ [
+ 446139,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8488055523413543743"
+ ],
+ [
+ 446862,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11022287578591461948"
+ ],
+ [
+ 475137,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5524899443716147254"
+ ],
+ [
+ 484622,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9030794170733368903"
+ ],
+ [
+ 494550,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10177095432241817550"
+ ],
+ [
+ 508239,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10453435395226908113"
+ ],
+ [
+ 519345,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15868233357086722582"
+ ],
+ [
+ 552276,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15438796748740496551"
+ ],
+ [
+ 552339,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6239891206884327744"
+ ],
+ [
+ 552379,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9273241722792499069"
+ ],
+ [
+ 552468,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16552715273014474462"
+ ],
+ [
+ 552569,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3196337734268543926"
+ ],
+ [
+ 552675,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5847337843613087850"
+ ],
+ [
+ 552742,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6352307575227886689"
+ ],
+ [
+ 552750,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1982943654522997836"
+ ],
+ [
+ 552783,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3828732781290684800"
+ ],
+ [
+ 552794,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16977504117571897465"
+ ],
+ [
+ 565573,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5204528607269617447"
+ ],
+ [
+ 565699,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4727225444131891245"
+ ],
+ [
+ 565716,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3170838120379853841"
+ ],
+ [
+ 565750,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2776707022079058597"
+ ],
+ [
+ 566045,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12036042907429605982"
+ ],
+ [
+ 566310,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12421069780630334687"
+ ],
+ [
+ 566383,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3217743380181589526"
+ ],
+ [
+ 566422,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1276333807841191048"
+ ],
+ [
+ 566494,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7963328572114012833"
+ ],
+ [
+ 566534,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6674089693496856100"
+ ],
+ [
+ 566844,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9863761807279720849"
+ ],
+ [
+ 567668,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16074272625414500687"
+ ],
+ [
+ 567773,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4282295767343065807"
+ ],
+ [
+ 567937,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2161356974128856154"
+ ],
+ [
+ 567993,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8242040281073172353"
+ ],
+ [
+ 568229,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15920297452776001566"
+ ],
+ [
+ 637429,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_959368100671980870"
+ ],
+ [
+ 637431,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14153015096824152693"
+ ],
+ [
+ 637436,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_899818523044574745"
+ ],
+ [
+ 637439,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10247733762527698150"
+ ],
+ [
+ 637444,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14145287080845176151"
+ ],
+ [
+ 637447,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17222031087440245104"
+ ],
+ [
+ 637453,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12136397179187398942"
+ ],
+ [
+ 637456,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8731812876370325350"
+ ],
+ [
+ 637459,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10602237885762750118"
+ ],
+ [
+ 637462,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16594426567840719304"
+ ],
+ [
+ 637468,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7366125432199894168"
+ ],
+ [
+ 637476,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11110698096089951252"
+ ],
+ [
+ 637484,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18188682871993650937"
+ ],
+ [
+ 637489,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10878533416589658592"
+ ],
+ [
+ 637495,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8079586613243454881"
+ ],
+ [
+ 637502,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2215817028210768871"
+ ],
+ [
+ 670212,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15344779562198194740"
+ ],
+ [
+ 680392,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18388393187931224557"
+ ],
+ [
+ 690762,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3514287218513747531"
+ ],
+ [
+ 743351,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6478191836071807723"
+ ],
+ [
+ 743388,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12832048017481839151"
+ ],
+ [
+ 743400,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10541782531111842730"
+ ],
+ [
+ 743475,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10763222511672298619"
+ ],
+ [
+ 743664,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17197230416625906030"
+ ],
+ [
+ 743787,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18047780774813550571"
+ ],
+ [
+ 743849,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2446754395186006741"
+ ],
+ [
+ 743875,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5679073243145878981"
+ ],
+ [
+ 743879,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16603745088718856023"
+ ],
+ [
+ 743887,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13969638480830774322"
+ ],
+ [
+ 743895,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13369228484466567544"
+ ],
+ [
+ 743998,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12762855693433035364"
+ ],
+ [
+ 744081,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3347649343482407740"
+ ],
+ [
+ 757025,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15796729346284784432"
+ ],
+ [
+ 757028,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14967701666120261818"
+ ],
+ [
+ 757039,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5372387541384308030"
+ ],
+ [
+ 757071,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3245068616512596448"
+ ],
+ [
+ 757197,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14277746927271838195"
+ ],
+ [
+ 757272,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1831510647579988021"
+ ],
+ [
+ 758082,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1963453102454522286"
+ ],
+ [
+ 803617,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15623709472216942560"
+ ],
+ [
+ 803620,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5962434898661482018"
+ ],
+ [
+ 803625,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14549239843071968661"
+ ],
+ [
+ 803630,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8567544072600264650"
+ ],
+ [
+ 803633,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18407024629481036527"
+ ],
+ [
+ 803664,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15960439059008180895"
+ ],
+ [
+ 803666,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3263727124264509013"
+ ],
+ [
+ 803667,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9451034478073827629"
+ ],
+ [
+ 803669,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18226372968931963744"
+ ],
+ [
+ 803670,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14202429746807222158"
+ ],
+ [
+ 803674,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7069975997530525703"
+ ],
+ [
+ 803675,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8834989374251547639"
+ ],
+ [
+ 803679,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17156716697792514660"
+ ],
+ [
+ 803681,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7928243516609743611"
+ ],
+ [
+ 803683,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17472278663547573068"
+ ],
+ [
+ 803688,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9297387177610951812"
+ ],
+ [
+ 803858,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9931525164958826644"
+ ],
+ [
+ 803962,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18009242967987784208"
+ ],
+ [
+ 804016,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2874613183183755843"
+ ],
+ [
+ 804136,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4244083726001215679"
+ ],
+ [
+ 804179,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9469466012086032567"
+ ],
+ [
+ 804247,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6732244663862958131"
+ ],
+ [
+ 804321,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17195272942189944812"
+ ],
+ [
+ 804392,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10755163416630026042"
+ ],
+ [
+ 804513,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12350748768321932278"
+ ],
+ [
+ 804688,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6894846644228586388"
+ ],
+ [
+ 873633,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15791320709907010305"
+ ],
+ [
+ 888932,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17017344803712162999"
+ ],
+ [
+ 898268,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12878681295278483117"
+ ],
+ [
+ 904854,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13733947334674992446"
+ ],
+ [
+ 910408,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3037428981486910277"
+ ],
+ [
+ 910622,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12284873316976139363"
+ ],
+ [
+ 922414,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4043581537781464345"
+ ],
+ [
+ 922752,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14805333628576514738"
+ ],
+ [
+ 922753,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5878253078724431956"
+ ],
+ [
+ 922763,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15057904892991266440"
+ ],
+ [
+ 922779,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18116003072026181513"
+ ],
+ [
+ 922864,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3039955536516904891"
+ ],
+ [
+ 922878,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8073047691309365525"
+ ],
+ [
+ 923385,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3835924806113435602"
+ ],
+ [
+ 923770,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9436136851994508472"
+ ],
+ [
+ 923881,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12348575658552678709"
+ ],
+ [
+ 924169,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3100439312273551508"
+ ],
+ [
+ 924681,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3582699337697878406"
+ ],
+ [
+ 925395,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17118759297342529823"
+ ],
+ [
+ 925409,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18056433885446449459"
+ ],
+ [
+ 925631,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13111124903733223810"
+ ],
+ [
+ 925674,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2449886720924222030"
+ ],
+ [
+ 974754,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_206876370358077584"
+ ],
+ [
+ 974757,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11257961443249735269"
+ ],
+ [
+ 998802,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4978121236847435213"
+ ],
+ [
+ 1003838,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17874922247557918464"
+ ],
+ [
+ 1008420,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14756252425461487877"
+ ],
+ [
+ 1051386,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15672847212087630880"
+ ],
+ [
+ 1056292,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11690764893981564089"
+ ],
+ [
+ 1077929,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18243756878085111260"
+ ],
+ [
+ 1078664,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13101390950818077021"
+ ],
+ [
+ 1092149,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4702484821586339293"
+ ],
+ [
+ 1092150,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8700607087392051310"
+ ],
+ [
+ 1092167,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11867004705859069539"
+ ],
+ [
+ 1092168,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14425483843466519342"
+ ],
+ [
+ 1092170,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3137345424683392113"
+ ],
+ [
+ 1092171,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5580715293573031613"
+ ],
+ [
+ 1092247,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9573985685283690942"
+ ],
+ [
+ 1092248,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1117326061986011595"
+ ],
+ [
+ 1092430,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12852273377442787843"
+ ],
+ [
+ 1092431,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2948619980269787118"
+ ],
+ [
+ 1092814,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16608503441526049509"
+ ],
+ [
+ 1092815,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8557565976778605998"
+ ],
+ [
+ 1092974,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1586571163572556133"
+ ],
+ [
+ 1092975,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15935510270417225075"
+ ],
+ [
+ 1093267,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4402275187066910043"
+ ],
+ [
+ 1093268,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5557153902550791572"
+ ],
+ [
+ 1196049,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8425342506929295889"
+ ],
+ [
+ 1196053,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18378899707502877030"
+ ],
+ [
+ 1196056,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3094735019997891187"
+ ],
+ [
+ 1196060,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6889997414760764962"
+ ],
+ [
+ 1196065,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16740248499631073491"
+ ],
+ [
+ 1196068,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8147418516272875846"
+ ],
+ [
+ 1196070,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7125085183662271974"
+ ],
+ [
+ 1196076,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17768317808754919941"
+ ],
+ [
+ 1196345,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12412202691532259202"
+ ],
+ [
+ 1196434,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10867201177510314690"
+ ],
+ [
+ 1196530,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17300580725555235990"
+ ],
+ [
+ 1196591,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11922602998666607133"
+ ],
+ [
+ 1196658,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17067792987140972942"
+ ],
+ [
+ 1196718,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6156226292697541158"
+ ],
+ [
+ 1196789,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11972615607024541236"
+ ],
+ [
+ 1225343,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16644111413419115375"
+ ],
+ [
+ 1248801,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18426022297842133396"
+ ],
+ [
+ 1254600,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13015264027012895286"
+ ],
+ [
+ 1302314,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2178400540252727344"
+ ],
+ [
+ 1302417,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15389620440927117005"
+ ],
+ [
+ 1302502,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13548285859569365661"
+ ],
+ [
+ 1302542,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3104574529036971715"
+ ],
+ [
+ 1302580,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9161340348181567069"
+ ],
+ [
+ 1302678,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2235073576100882104"
+ ],
+ [
+ 1313288,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10714691217339955812"
+ ],
+ [
+ 1313289,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15436382972359969753"
+ ],
+ [
+ 1313295,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10087935773393978978"
+ ],
+ [
+ 1313297,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1956045676144969910"
+ ],
+ [
+ 1313413,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10346769137148346031"
+ ],
+ [
+ 1313414,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12500094804178281846"
+ ],
+ [
+ 1313594,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2233046969904106531"
+ ],
+ [
+ 1313602,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17274697924548824848"
+ ],
+ [
+ 1313910,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1971305477396827550"
+ ],
+ [
+ 1313911,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1123480945946248614"
+ ],
+ [
+ 1313913,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13632501687006175122"
+ ],
+ [
+ 1313914,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17948247176888184637"
+ ],
+ [
+ 1314424,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4838020015312522463"
+ ],
+ [
+ 1314425,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5662771786044476863"
+ ],
+ [
+ 1314464,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11644308174162608419"
+ ],
+ [
+ 1314465,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14904900394694390614"
+ ],
+ [
+ 1364104,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12599208991475254956"
+ ],
+ [
+ 1364109,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11714438223684649350"
+ ],
+ [
+ 1364112,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12536573013377539960"
+ ],
+ [
+ 1364118,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9176497740794462077"
+ ],
+ [
+ 1364120,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12851108834171289286"
+ ],
+ [
+ 1364124,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7874367383020722434"
+ ],
+ [
+ 1364133,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11768575761097821347"
+ ],
+ [
+ 1364136,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9136218849828640261"
+ ],
+ [
+ 1364436,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15227420143784422207"
+ ],
+ [
+ 1364530,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_678152374747580855"
+ ],
+ [
+ 1364650,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1863822375592647081"
+ ],
+ [
+ 1364684,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14158128526209592340"
+ ],
+ [
+ 1364721,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4584616283426803771"
+ ],
+ [
+ 1364806,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13235463644610255264"
+ ],
+ [
+ 1364877,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7481851041269423666"
+ ],
+ [
+ 1364909,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9126134503219493660"
+ ],
+ [
+ 1365000,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12501320043627540492"
+ ],
+ [
+ 1365092,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16659304774143478016"
+ ],
+ [
+ 1398994,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7719667656758625427"
+ ],
+ [
+ 1414384,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8539755814366516822"
+ ],
+ [
+ 1425199,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12779466743610801848"
+ ],
+ [
+ 1434341,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5346047771258662063"
+ ],
+ [
+ 1443931,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8500329968264923909"
+ ],
+ [
+ 1471398,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2835195402119160546"
+ ],
+ [
+ 1480308,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17921460277698942329"
+ ],
+ [
+ 1480309,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6060307077504199420"
+ ],
+ [
+ 1480544,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6633913915366293595"
+ ],
+ [
+ 1480545,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5929325753921586179"
+ ],
+ [
+ 1480762,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5087823519142655646"
+ ],
+ [
+ 1480763,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17650771157019299908"
+ ],
+ [
+ 1481439,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17960349755202839121"
+ ],
+ [
+ 1481440,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11344166716576316064"
+ ],
+ [
+ 1481478,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18011110268539351239"
+ ],
+ [
+ 1481479,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13212044995188295238"
+ ],
+ [
+ 1481983,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7786193257091011544"
+ ],
+ [
+ 1481984,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4535156793004568521"
+ ],
+ [
+ 1482479,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7392738499711986039"
+ ],
+ [
+ 1482480,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16803447451398881609"
+ ],
+ [
+ 1482675,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5816531416034365646"
+ ],
+ [
+ 1482676,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10586423650728537576"
+ ],
+ [
+ 1531952,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15737151482683442731"
+ ],
+ [
+ 1531958,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5428763241135066291"
+ ],
+ [
+ 1531961,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10672990360669572384"
+ ],
+ [
+ 1531965,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7671654195981105868"
+ ],
+ [
+ 1531970,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14942561136263416232"
+ ],
+ [
+ 1531973,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16581135063665980351"
+ ],
+ [
+ 1531977,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4856394302363438782"
+ ],
+ [
+ 1531981,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15877297897966392256"
+ ],
+ [
+ 1532252,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6425863017087684784"
+ ],
+ [
+ 1532297,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1423327815346902785"
+ ],
+ [
+ 1532410,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16344047352049658364"
+ ],
+ [
+ 1532505,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17305619576099580631"
+ ],
+ [
+ 1532566,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8395635612082422722"
+ ],
+ [
+ 1532630,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8519186747189401928"
+ ],
+ [
+ 1532664,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1253792069041758400"
+ ],
+ [
+ 1532734,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16904214805911888839"
+ ],
+ [
+ 1532855,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8506737276851019991"
+ ],
+ [
+ 1532973,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_649535223346514500"
+ ],
+ [
+ 1533011,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5898366698551436494"
+ ],
+ [
+ 1533074,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5096822643445805939"
+ ],
+ [
+ 1533138,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14031496100727112023"
+ ],
+ [
+ 1533230,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6266968138694899706"
+ ],
+ [
+ 1533290,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16937440140735016382"
+ ],
+ [
+ 1533334,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10380564371013497541"
+ ],
+ [
+ 1646912,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16348313361094870629"
+ ],
+ [
+ 1646913,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16036313845476488225"
+ ],
+ [
+ 1646927,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7913415713848588049"
+ ],
+ [
+ 1646928,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4103568604206357407"
+ ],
+ [
+ 1647271,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3917280544582020227"
+ ],
+ [
+ 1647272,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13152661351180665451"
+ ],
+ [
+ 1647797,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6759349546217838702"
+ ],
+ [
+ 1647798,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10771757146841459066"
+ ],
+ [
+ 1649758,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6186015994533271987"
+ ],
+ [
+ 1649759,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9297921486035376125"
+ ],
+ [
+ 1650053,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17376928305113794204"
+ ],
+ [
+ 1650054,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14705463419645039338"
+ ],
+ [
+ 1650074,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17507795361488404022"
+ ],
+ [
+ 1650075,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_791693539870127207"
+ ],
+ [
+ 1650077,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11601476096660261739"
+ ],
+ [
+ 1650078,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12330363696036321630"
+ ],
+ [
+ 1698356,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3664737074668484302"
+ ],
+ [
+ 1698360,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_307407573315947332"
+ ],
+ [
+ 1698364,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9769783361472155454"
+ ],
+ [
+ 1698365,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11620844252758829593"
+ ],
+ [
+ 1698369,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18049344638709473559"
+ ],
+ [
+ 1698371,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12020994052316156595"
+ ],
+ [
+ 1698376,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11281246552810662244"
+ ],
+ [
+ 1698381,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2375830332811294762"
+ ],
+ [
+ 1698622,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17464125010525395920"
+ ],
+ [
+ 1698715,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8252497263861944219"
+ ],
+ [
+ 1698754,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2367183499606118175"
+ ],
+ [
+ 1698875,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3307535582648818046"
+ ],
+ [
+ 1698960,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14070843042311209568"
+ ],
+ [
+ 1699023,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3711321641783824900"
+ ],
+ [
+ 1699057,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4067246543661527215"
+ ],
+ [
+ 1699128,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9510008041698682979"
+ ],
+ [
+ 1699222,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15070956721185346780"
+ ],
+ [
+ 1699307,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16822435983994287417"
+ ],
+ [
+ 1699350,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9200131402658211969"
+ ],
+ [
+ 1699410,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4312347285030704895"
+ ],
+ [
+ 1699501,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8452932345320744230"
+ ],
+ [
+ 1699592,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18319826375533952136"
+ ],
+ [
+ 1745273,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9427035578749238306"
+ ],
+ [
+ 1752419,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11861522220206622819"
+ ],
+ [
+ 1815709,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8461056289216019937"
+ ],
+ [
+ 1815710,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15670507367599225435"
+ ],
+ [
+ 1815718,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14014685830231205955"
+ ],
+ [
+ 1815719,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8933292223775536158"
+ ],
+ [
+ 1815722,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4670593314428329079"
+ ],
+ [
+ 1815723,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14753117404955500077"
+ ],
+ [
+ 1815761,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14867940080031341584"
+ ],
+ [
+ 1815765,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9603778874236517966"
+ ],
+ [
+ 1815801,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12882566826319414462"
+ ],
+ [
+ 1815802,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10278462134454322731"
+ ],
+ [
+ 1815807,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_634460488610617389"
+ ],
+ [
+ 1815809,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3787426685243459175"
+ ],
+ [
+ 1816092,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6900943374411146311"
+ ],
+ [
+ 1816098,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2936618988510319244"
+ ],
+ [
+ 1816749,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5025953421479037474"
+ ],
+ [
+ 1816757,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10504877859244548446"
+ ],
+ [
+ 3034511,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3109980201143968959"
+ ],
+ [
+ 3034516,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16225395238917630830"
+ ],
+ [
+ 3034519,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3014914144290525407"
+ ],
+ [
+ 3034520,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4368939354329067568"
+ ],
+ [
+ 3034525,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8231836696933284847"
+ ],
+ [
+ 3034528,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11274985324515364055"
+ ],
+ [
+ 3034533,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16776878341103524463"
+ ],
+ [
+ 3034534,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11836972797422572451"
+ ],
+ [
+ 3034538,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11222685037651458786"
+ ],
+ [
+ 3034543,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2358340237215830214"
+ ],
+ [
+ 3034547,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5530069306192163345"
+ ],
+ [
+ 3034553,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15467614126156452015"
+ ],
+ [
+ 3034559,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5799901149805008020"
+ ],
+ [
+ 3034568,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13480646072272034112"
+ ],
+ [
+ 3034574,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9925042805192677700"
+ ],
+ [
+ 3034584,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5749985854107212444"
+ ],
+ [
+ 3139002,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16606806383596787697"
+ ],
+ [
+ 3139034,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13693873830274303283"
+ ],
+ [
+ 3139037,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1174647644274299681"
+ ],
+ [
+ 3139041,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8463715729614322689"
+ ],
+ [
+ 3139046,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13622137732458691231"
+ ],
+ [
+ 3139061,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13976382821074310028"
+ ],
+ [
+ 3139103,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5674315837922161200"
+ ],
+ [
+ 3139186,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9519308949519055442"
+ ],
+ [
+ 3139256,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11870662408869016824"
+ ],
+ [
+ 3139325,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9203351225653898140"
+ ],
+ [
+ 3139379,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15928929015260880554"
+ ],
+ [
+ 3139390,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4002276556586837015"
+ ],
+ [
+ 3139395,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_869027509639957243"
+ ],
+ [
+ 3139439,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7590980568514383017"
+ ],
+ [
+ 3139477,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5691785906114378471"
+ ],
+ [
+ 3139578,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1636167670197209792"
+ ],
+ [
+ 3153344,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14721186196409181902"
+ ],
+ [
+ 3153350,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8038130317126733515"
+ ],
+ [
+ 3153355,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9363551798712942410"
+ ],
+ [
+ 3153360,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4001983642343932187"
+ ],
+ [
+ 3153908,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3967378375263321794"
+ ],
+ [
+ 3154349,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1956150361493637850"
+ ],
+ [
+ 3154374,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9515009690803494350"
+ ],
+ [
+ 3154941,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6403793289284352763"
+ ],
+ [
+ 3155135,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18377681884643089435"
+ ],
+ [
+ 3155270,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2894252287023698380"
+ ],
+ [
+ 3155272,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1176757978366636180"
+ ],
+ [
+ 3155356,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10243682700450778485"
+ ],
+ [
+ 3156048,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13419435171562057387"
+ ],
+ [
+ 3156349,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5310938153691480087"
+ ],
+ [
+ 3156373,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2839757395093087158"
+ ],
+ [
+ 3156759,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_189645908260721847"
+ ],
+ [
+ 3357366,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14712843742675416789"
+ ],
+ [
+ 3357370,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14441746631406223373"
+ ],
+ [
+ 3357372,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16051981893101423452"
+ ],
+ [
+ 3357379,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7313357428863899572"
+ ],
+ [
+ 3357382,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15741756964888684277"
+ ],
+ [
+ 3357385,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_518244284579097490"
+ ],
+ [
+ 3357390,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9623044147940949233"
+ ],
+ [
+ 3357396,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_188046812284482511"
+ ],
+ [
+ 3357815,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2070546872302042284"
+ ],
+ [
+ 3357878,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11580410533744131165"
+ ],
+ [
+ 3357919,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9721818878768485015"
+ ],
+ [
+ 3358001,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5097555846270763399"
+ ],
+ [
+ 3358063,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12100682775419753259"
+ ],
+ [
+ 3358121,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_831802341219462707"
+ ],
+ [
+ 3358154,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3511804342372595424"
+ ],
+ [
+ 3358199,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13392040954474567685"
+ ],
+ [
+ 3358318,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1488600423820696507"
+ ],
+ [
+ 3358415,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7033478551626861368"
+ ],
+ [
+ 3358501,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10949369291795034167"
+ ],
+ [
+ 3358546,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9575369317971095792"
+ ],
+ [
+ 3358599,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4171338315313792908"
+ ],
+ [
+ 3358636,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3479887625518456208"
+ ],
+ [
+ 3358704,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10816808840412347489"
+ ],
+ [
+ 3358735,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4565693084836619869"
+ ],
+ [
+ 3473543,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6198223556594770384"
+ ],
+ [
+ 3473551,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10258537246918590156"
+ ],
+ [
+ 3473616,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8769553907161686792"
+ ],
+ [
+ 3473617,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3773271510051838348"
+ ],
+ [
+ 3473932,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13489168884906494727"
+ ],
+ [
+ 3473933,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18253176944629380716"
+ ],
+ [
+ 3473943,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10561446346587147503"
+ ],
+ [
+ 3473944,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1057740105323334353"
+ ],
+ [
+ 3474186,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16534136807012196550"
+ ],
+ [
+ 3474187,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3053221032925412808"
+ ],
+ [
+ 3474202,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11135988479496829721"
+ ],
+ [
+ 3474205,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15763171237614044906"
+ ],
+ [
+ 3475761,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18163430628544684227"
+ ],
+ [
+ 3475762,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_906162624150815971"
+ ],
+ [
+ 3475895,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16284538720262805411"
+ ],
+ [
+ 3475896,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6471773242206149096"
+ ],
+ [
+ 3524930,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7351702211122833722"
+ ],
+ [
+ 3524938,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_849812185320057193"
+ ],
+ [
+ 3524944,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_899185330692270860"
+ ],
+ [
+ 3524949,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16251077242215705442"
+ ],
+ [
+ 3524957,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17261980608195435381"
+ ],
+ [
+ 3524959,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15728989693202867994"
+ ],
+ [
+ 3524963,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_578314422480040874"
+ ],
+ [
+ 3524972,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5858027835250679442"
+ ],
+ [
+ 3625661,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8595121483727162241"
+ ],
+ [
+ 3628698,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6644896138201857560"
+ ],
+ [
+ 3628765,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14380676338487055612"
+ ],
+ [
+ 3629117,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3884026636792605184"
+ ],
+ [
+ 3629128,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8471987555244152278"
+ ],
+ [
+ 3629159,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10706976348774196355"
+ ],
+ [
+ 3629253,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6353439240478980341"
+ ],
+ [
+ 3629283,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11201244953910665068"
+ ],
+ [
+ 3629360,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13965969812663104297"
+ ],
+ [
+ 3629455,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15143351641328536949"
+ ],
+ [
+ 3629517,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1731228908521478078"
+ ],
+ [
+ 3629588,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1280418295423790350"
+ ],
+ [
+ 3629613,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8333685548859735493"
+ ],
+ [
+ 3629770,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6530807606746596767"
+ ],
+ [
+ 3629792,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7249227128131175886"
+ ],
+ [
+ 3629874,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17572995736186529549"
+ ],
+ [
+ 3642042,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_504718820247128253"
+ ],
+ [
+ 3642043,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5621471611596574380"
+ ],
+ [
+ 3642063,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5859789574421468396"
+ ],
+ [
+ 3642065,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1749104467585305323"
+ ],
+ [
+ 3642198,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16903094956545060933"
+ ],
+ [
+ 3642199,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15476704792505050546"
+ ],
+ [
+ 3642237,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16807932559841593208"
+ ],
+ [
+ 3642238,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8310833007698505204"
+ ],
+ [
+ 3642367,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4707148721279721064"
+ ],
+ [
+ 3642371,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9013616000502015446"
+ ],
+ [
+ 3642732,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12951651460115558391"
+ ],
+ [
+ 3642733,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6856766171535732498"
+ ],
+ [
+ 3643042,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14506781066108968506"
+ ],
+ [
+ 3643043,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7374279504609390465"
+ ],
+ [
+ 3643572,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5955043855143664956"
+ ],
+ [
+ 3643573,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8196104095227417096"
+ ],
+ [
+ 3691270,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2620465584826061471"
+ ],
+ [
+ 3691273,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1757388310219074848"
+ ],
+ [
+ 3691281,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11097774192912884566"
+ ],
+ [
+ 3691283,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4742545615202517810"
+ ],
+ [
+ 3691285,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16070580061589069085"
+ ],
+ [
+ 3691287,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6948592204353944837"
+ ],
+ [
+ 3691292,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4433630213579020385"
+ ],
+ [
+ 3691296,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15584731480290088561"
+ ],
+ [
+ 3691301,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_221849414828219018"
+ ],
+ [
+ 3691305,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_447840502870009883"
+ ],
+ [
+ 3691312,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15894757923477358902"
+ ],
+ [
+ 3691316,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2064442071811218632"
+ ],
+ [
+ 3691318,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16020379562811154133"
+ ],
+ [
+ 3691328,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1281388017954056527"
+ ],
+ [
+ 3691337,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2331797143551461098"
+ ],
+ [
+ 3691344,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13357220641161590082"
+ ],
+ [
+ 3692044,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18315790265804955058"
+ ],
+ [
+ 3692064,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17507869003332814399"
+ ],
+ [
+ 3692144,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4507988044373530697"
+ ],
+ [
+ 3692262,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18426271347408772907"
+ ],
+ [
+ 3722920,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5017140278663970478"
+ ],
+ [
+ 3736238,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12999327330513004683"
+ ],
+ [
+ 3756817,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11445498054270621315"
+ ],
+ [
+ 3765828,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1383945043061831987"
+ ],
+ [
+ 3774374,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12318664942362170822"
+ ],
+ [
+ 3783232,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3418024057005528102"
+ ],
+ [
+ 3797970,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8613331899495270727"
+ ],
+ [
+ 3798096,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9878016966715019087"
+ ],
+ [
+ 3798154,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6627377776921054141"
+ ],
+ [
+ 3798272,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10804569796438592151"
+ ],
+ [
+ 3798362,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13557064605936120080"
+ ],
+ [
+ 3798416,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5828012701014328691"
+ ],
+ [
+ 3810814,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8002480603490928081"
+ ],
+ [
+ 3811199,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15215890719473899494"
+ ],
+ [
+ 3811238,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5765224812454476842"
+ ],
+ [
+ 3811276,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16426201200428659968"
+ ],
+ [
+ 3811288,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17001284103737435638"
+ ],
+ [
+ 3811588,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15465033144293705320"
+ ],
+ [
+ 3811661,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11237554842213579307"
+ ],
+ [
+ 3811984,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4208370878176288374"
+ ],
+ [
+ 3812122,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_859880967845588792"
+ ],
+ [
+ 3812520,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11075357886525732487"
+ ],
+ [
+ 3812639,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8682143685678125481"
+ ],
+ [
+ 3812662,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12325636616321695255"
+ ],
+ [
+ 3812698,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14800074808472192333"
+ ],
+ [
+ 3812997,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12158776810784339196"
+ ],
+ [
+ 3813382,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7085330449950753708"
+ ],
+ [
+ 3813392,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9314724528373063174"
+ ],
+ [
+ 3868673,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8427308000189841720"
+ ],
+ [
+ 3868680,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_731075373502833868"
+ ],
+ [
+ 3868682,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9807646971983976701"
+ ],
+ [
+ 3868684,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_708313394526412267"
+ ],
+ [
+ 3868688,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13447384123535294566"
+ ],
+ [
+ 3868692,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14036027334014170087"
+ ],
+ [
+ 3868696,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3187295055034313317"
+ ],
+ [
+ 3868700,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3850755952385802565"
+ ],
+ [
+ 3868705,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10106160760700975956"
+ ],
+ [
+ 3868706,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17824178791258425552"
+ ],
+ [
+ 3868711,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11909562160231335840"
+ ],
+ [
+ 3868719,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13607563626468215022"
+ ],
+ [
+ 3868727,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16787076582919642185"
+ ],
+ [
+ 3868737,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8246440632798892168"
+ ],
+ [
+ 3868739,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4159297518721824730"
+ ],
+ [
+ 3868745,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18015449808950218451"
+ ],
+ [
+ 3869439,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14196112055158873872"
+ ],
+ [
+ 3869482,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14077174710107593045"
+ ],
+ [
+ 3869512,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8061984655211037805"
+ ],
+ [
+ 3869576,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16985047152068049914"
+ ],
+ [
+ 3869667,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12041070490733630674"
+ ],
+ [
+ 3869787,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5526959636930734514"
+ ],
+ [
+ 3869824,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14021406351436652831"
+ ],
+ [
+ 3903350,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8500099511158252282"
+ ],
+ [
+ 3908889,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12265936708546584396"
+ ],
+ [
+ 3919175,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6922556638284219234"
+ ],
+ [
+ 3939646,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17523183083010355471"
+ ],
+ [
+ 3953020,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6290038540260580030"
+ ],
+ [
+ 3965009,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14272343095608536296"
+ ],
+ [
+ 3976392,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5029439665772523493"
+ ],
+ [
+ 3976512,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11697369406167774290"
+ ],
+ [
+ 3976632,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3263338739247539272"
+ ],
+ [
+ 3988513,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_787927194113102781"
+ ],
+ [
+ 3988531,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18023508788842543544"
+ ],
+ [
+ 3988558,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12183092996673180058"
+ ],
+ [
+ 3988563,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4831310188807215850"
+ ],
+ [
+ 3988807,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2198912152797944480"
+ ],
+ [
+ 3988848,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4213118868561974759"
+ ],
+ [
+ 3988850,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3232631311510150968"
+ ],
+ [
+ 3988957,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11047352039168056415"
+ ],
+ [
+ 3989106,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16542734713691156018"
+ ],
+ [
+ 3989387,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15165735529962902451"
+ ],
+ [
+ 3989535,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12857832019022611393"
+ ],
+ [
+ 3989541,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8287593005309603838"
+ ],
+ [
+ 3989647,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18367334379009715015"
+ ],
+ [
+ 3990050,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4446833231210301139"
+ ],
+ [
+ 3990104,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5145384017407119481"
+ ],
+ [
+ 3990359,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6801533416498745154"
+ ],
+ [
+ 4039903,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17584089607946077100"
+ ],
+ [
+ 4039907,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_503447699485187456"
+ ],
+ [
+ 4039912,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15069983805839197528"
+ ],
+ [
+ 4039916,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14542568312346822221"
+ ],
+ [
+ 4039918,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4892714578234319308"
+ ],
+ [
+ 4039925,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5411057577383729197"
+ ],
+ [
+ 4039928,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15464081172484410235"
+ ],
+ [
+ 4039932,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6225688351493779577"
+ ],
+ [
+ 4039935,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15625622540562040764"
+ ],
+ [
+ 4039937,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10502628395411644717"
+ ],
+ [
+ 4039943,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11797930136011262397"
+ ],
+ [
+ 4039946,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13885865319396877227"
+ ],
+ [
+ 4039957,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5671629446760297180"
+ ],
+ [
+ 4039960,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1393594888891471458"
+ ],
+ [
+ 4039967,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3969615252157795066"
+ ],
+ [
+ 4039972,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17691785108059131348"
+ ],
+ [
+ 4040670,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15630378231912934353"
+ ],
+ [
+ 4040733,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1125260354809627775"
+ ],
+ [
+ 4040797,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_250967488981294306"
+ ],
+ [
+ 4040888,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6978632213513861862"
+ ],
+ [
+ 4040963,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14539782102651016167"
+ ],
+ [
+ 4073698,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4027982239315583183"
+ ],
+ [
+ 4087187,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5952558817727402033"
+ ],
+ [
+ 4108124,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8896124445966427863"
+ ],
+ [
+ 4121492,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3777541655201128969"
+ ],
+ [
+ 4134649,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14525555472067694467"
+ ],
+ [
+ 4145652,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9698817224586140305"
+ ],
+ [
+ 4147258,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11549998609739639611"
+ ],
+ [
+ 4147317,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6537741669769489394"
+ ],
+ [
+ 4147443,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7264494085295711900"
+ ],
+ [
+ 4147482,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1691570203381345331"
+ ],
+ [
+ 4147600,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_805114710948033208"
+ ],
+ [
+ 4158981,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1404183938416935234"
+ ],
+ [
+ 4159025,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_366169914351560748"
+ ],
+ [
+ 4159276,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6543673387235939768"
+ ],
+ [
+ 4159300,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_583449162339166668"
+ ],
+ [
+ 4159319,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15439605152958871044"
+ ],
+ [
+ 4159324,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11830655025939200238"
+ ],
+ [
+ 4159646,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2897405567433811008"
+ ],
+ [
+ 4159775,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11323777380456298142"
+ ],
+ [
+ 4159784,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18401703462013938226"
+ ],
+ [
+ 4159918,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3439594419195257361"
+ ],
+ [
+ 4160584,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8461195438972411491"
+ ],
+ [
+ 4160770,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14710116434070740759"
+ ],
+ [
+ 4161660,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8602813889361540004"
+ ],
+ [
+ 4161781,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6869104894578156210"
+ ],
+ [
+ 4161828,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2207584602328927841"
+ ],
+ [
+ 4162147,
+ "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13886441570120016825"
+ ]
+]
\ No newline at end of file
diff --git a/rlinf/algorithms/losses.py b/rlinf/algorithms/losses.py
index 3980f9136..72e9004c6 100644
--- a/rlinf/algorithms/losses.py
+++ b/rlinf/algorithms/losses.py
@@ -196,15 +196,20 @@ def compute_math_ppo_actor_loss(**kwargs):
loss_agg_func = kwargs["loss_agg_func"]
logprobs = kwargs["logprobs"]
old_logprobs = kwargs["old_logprobs"]
- eps_clip = kwargs["eps_clip"]
+ clip_ratio_low = kwargs["clip_ratio_low"]
+ clip_ratio_high = kwargs["clip_ratio_high"]
advantages = kwargs["advantages"]
loss_mask = kwargs.get("loss_mask", None)
c_clip = kwargs.get("c_clip", None)
-
- assert logprobs.dtype == torch.float32
- assert old_logprobs.dtype == torch.float32
- assert advantages.dtype == torch.float32
-
+ if logprobs.dtype != torch.float32:
+ logprobs = logprobs.float()  # cast to float32
+ # Cast to float32 instead of asserting on the dtype, so that
+ # lower-precision (e.g. bf16) inputs are accepted rather than
+ # rejected with an AssertionError.
+ if old_logprobs.dtype != torch.float32:
+ old_logprobs = old_logprobs.float()
+ if advantages.dtype != torch.float32:
+ advantages = advantages.float()
assert loss_mask is not None
loss_mask_count = loss_mask.count_nonzero() or 1
@@ -212,7 +217,7 @@ def compute_math_ppo_actor_loss(**kwargs):
ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
approx_kl = torch.where(loss_mask, (logprobs - old_logprobs).detach(), 0.0)
- clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
+ clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
policy_loss1 = -advantages * ratio
policy_loss2 = -advantages * clipped_ratio
@@ -232,17 +237,21 @@ def compute_math_ppo_actor_loss(**kwargs):
clip_mask = policy_loss1.detach() < policy_loss2.detach()
dual_clip_mask.logical_and_(loss_mask)
- clip_fraction = clip_mask.logical_and_(loss_mask).count_nonzero() / loss_mask_count
- approx_kl = -approx_kl.sum() / loss_mask_count
+ num_clipped = clip_mask.logical_and_(loss_mask).count_nonzero()
+
+ clip_fraction = num_clipped.float() / float(loss_mask_count)
+ approx_kl = -approx_kl.sum() / float(loss_mask_count)
dual_cliped_ratio = torch.where(dual_clip_mask, ratio, 0)
# Compile metrics for logging
metrics_data = {
- "policy_loss": masked_mean(policy_loss.detach(), loss_mask),
- "ratio": masked_mean(ratio.detach(), loss_mask),
- "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask),
- "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask),
+ "policy_loss": masked_mean(policy_loss.detach(), loss_mask).detach(),
+ "ratio": masked_mean(ratio.detach(), loss_mask).detach(),
+ "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask).detach(),
+ "dual_cliped_ratio": masked_mean(
+ dual_cliped_ratio.detach(), loss_mask
+ ).detach(),
"approx_kl": approx_kl.detach(),
"clip_fraction": clip_fraction.detach(),
}
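
Illustrative sketch (not part of the patch): the decoupled clipping introduced above reduces to the following standalone computation; tensor shapes and the default clip values here are assumptions.

import torch

def clipped_policy_loss(logprobs, old_logprobs, advantages, loss_mask,
                        clip_ratio_low=0.2, clip_ratio_high=0.28):
    zeros = torch.zeros_like(logprobs)
    # Importance ratio, computed only on tokens covered by the loss mask.
    ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), zeros)
    # Asymmetric clipping: the lower and upper bounds are controlled separately.
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
    loss = torch.maximum(-advantages * ratio, -advantages * clipped_ratio)
    denom = loss_mask.count_nonzero().clamp(min=1)
    return torch.where(loss_mask, loss, zeros).sum() / denom

logprobs = torch.randn(8)
old_logprobs = torch.randn(8)
advantages = torch.randn(8)
loss_mask = torch.tensor([True] * 6 + [False] * 2)
print(clipped_policy_loss(logprobs, old_logprobs, advantages, loss_mask))
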
diff --git a/rlinf/algorithms/registry.py b/rlinf/algorithms/registry.py
index 96b19f2b4..11bfc6c6a 100644
--- a/rlinf/algorithms/registry.py
+++ b/rlinf/algorithms/registry.py
@@ -73,22 +73,3 @@ def calculate_adv_and_returns(**kwargs) -> Tuple[torch.Tensor, Optional[torch.Te
adv_type = kwargs["adv_type"]
fn = get_adv_and_returns(adv_type)
return fn(**kwargs)
-
-
-REWARD_REGISTRY: Dict[str, Callable] = {}
-
-
-def register_reward_fn(name: str):
- def decorator(fn):
- REWARD_REGISTRY[name] = fn
- return fn
-
- return decorator
-
-
-def get_reward_fn(name: Optional[str]):
- if name is None:
- return None
- if name not in REWARD_REGISTRY:
- raise ValueError(f"Reward function {name} not registered")
- return REWARD_REGISTRY[name]
diff --git a/rlinf/algorithms/rewards/__init__.py b/rlinf/algorithms/rewards/__init__.py
new file mode 100644
index 000000000..380cfa102
--- /dev/null
+++ b/rlinf/algorithms/rewards/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from rlinf.algorithms.rewards.code import CodeReward
+from rlinf.algorithms.rewards.math import MathReward
+from rlinf.algorithms.rewards.vqa import VQAReward
+
+
+def register_reward(name: str, reward_class: type):
+ assert name not in reward_registry, f"Reward {name} already registered"
+ reward_registry[name] = reward_class
+
+
+def get_reward_class(name: str):
+ assert name in reward_registry, f"Reward {name} not found"
+ return reward_registry[name]
+
+
+reward_registry = {}
+
+register_reward("math", MathReward)
+register_reward("vqa", VQAReward)
+register_reward("code", CodeReward)
diff --git a/rlinf/algorithms/rewards/code/__init__.py b/rlinf/algorithms/rewards/code/__init__.py
new file mode 100644
index 000000000..0fc75f971
--- /dev/null
+++ b/rlinf/algorithms/rewards/code/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from omegaconf import DictConfig
+
+from toolkits.code_verifier.verify import fim_verify_call
+
+
+class CodeReward:
+ def __init__(self, config: DictConfig):
+ self.scale = config.get("reward_scale", 1.0)
+
+ def get_reward(
+ self, response: List[str], reference: List[List[str]]
+ ) -> List[float]:
+ rewards = fim_verify_call(response, reference)
+ return [float(reward) * self.scale for reward in rewards]
diff --git a/rlinf/algorithms/rewards/math/__init__.py b/rlinf/algorithms/rewards/math/__init__.py
new file mode 100644
index 000000000..1a67e80e1
--- /dev/null
+++ b/rlinf/algorithms/rewards/math/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from omegaconf import DictConfig
+
+from toolkits.math_verifier.verify import math_verify_call
+
+
+class MathReward:
+ def __init__(self, config: DictConfig):
+ self.scale = config.get("reward_scale", 1.0)
+
+ def get_reward(
+ self, response: List[str], reference: List[List[str]]
+ ) -> List[float]:
+ """
+ Calculates reward scores for a list of responses compared to corresponding lists of reference answers.
+ For each response, the function checks whether it matches any of the provided references using `math_verify_call`.
+ The reward for each response is the verifier result converted to float and multiplied by `self.scale`.
+ Args:
+ response (List[str]): A list of response strings to be evaluated.
+ reference (List[List[str]]): A list where each element is a list of reference strings corresponding to each response.
+ Returns:
+ List[float]: A list of reward scores, one for each response.
+ """
+
+ rewards = math_verify_call(response, reference)
+ return [float(reward) * self.scale for reward in rewards]
diff --git a/rlinf/algorithms/rewards/vqa/__init__.py b/rlinf/algorithms/rewards/vqa/__init__.py
new file mode 100644
index 000000000..77b009369
--- /dev/null
+++ b/rlinf/algorithms/rewards/vqa/__init__.py
@@ -0,0 +1,62 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+from omegaconf import DictConfig
+
+from .format_rewards import answer_format_reward, think_format_reward
+from .qa_rewards import qa_accuracy_reward
+
+
+class VQAReward:
+ NEEDED_REWARD_FUNCTIONS = {
+ "qa_accuracy": qa_accuracy_reward,
+ "think_format": think_format_reward,
+ "answer_format": answer_format_reward,
+ }
+
+ def __init__(self, config: DictConfig):
+ assert "reward_weights" in config, "VQAReward requires reward_weights in config"
+
+ self.reward_weights_config = config.reward_weights
+ assert set(self.reward_weights_config.keys()) == set(
+ self.NEEDED_REWARD_FUNCTIONS.keys()
+ ), (
+ f"Reward weights must contains all of: {self.NEEDED_REWARD_FUNCTIONS.keys()} but got {list(self.reward_weights_config.keys())}"
+ )
+ assert all(
+ reward_weight >= 0 for reward_weight in self.reward_weights_config.values()
+ ), (
+ f"All reward weights must be non-negative but got {list(self.reward_weights_config.values())}"
+ )
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def get_reward(self, completions: List[str], answers: List[dict]) -> List[float]:
+ rewards = []
+ reward_weights = []
+ for reward_name, reward_function in self.NEEDED_REWARD_FUNCTIONS.items():
+ if self.reward_weights_config[reward_name] > 0:
+ rewards.append(reward_function(completions, answers))
+ else:
+ rewards.append([0.0] * len(completions))
+ reward_weights.append(self.reward_weights_config[reward_name])
+
+ rewards_tensor = torch.tensor(rewards, device=self.device)
+ weights_tensor = torch.tensor(reward_weights, device=self.device)
+
+ final_rewards = (rewards_tensor * weights_tensor.unsqueeze(1)).sum(dim=0)
+
+ return final_rewards.tolist()
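
Illustrative sketch (not part of the patch): the weighted combination in `VQAReward.get_reward` reduces to the arithmetic below; the per-function rewards and weights are made up.

import torch

# Rows: qa_accuracy, think_format, answer_format; columns: two completions.
rewards = torch.tensor([[1.0, 0.0],
                        [1.0, 1.0],
                        [0.0, 1.0]])
weights = torch.tensor([1.0, 0.25, 0.25])
final = (rewards * weights.unsqueeze(1)).sum(dim=0)
print(final)  # tensor([1.2500, 0.5000])
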
diff --git a/rlinf/algorithms/rewards/vqa/format_rewards.py b/rlinf/algorithms/rewards/vqa/format_rewards.py
new file mode 100644
index 000000000..205bbe336
--- /dev/null
+++ b/rlinf/algorithms/rewards/vqa/format_rewards.py
@@ -0,0 +1,67 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import List
+
+
+def think_format_reward(completions, answers) -> List[float]:
+ """
+ Think format reward function compatible with GRPO training.
+
+ Reward function that checks if the reasoning is enclosed within <think>...</think> tags.
+
+ Args:
+ completions: List of model completions (text strings)
+
+ Returns:
+ List of reward scores (1.0 for correct format, 0.0 otherwise)
+ """
+ pattern = r"^(?!.*)(.*?).*$"
+ rewards = []
+
+ for completion in completions:
+ completion_text = str(completion).strip()
+ match = re.match(pattern, completion_text, re.DOTALL | re.MULTILINE)
+ rewards.append(1.0 if match else 0.0)
+
+ return rewards
+
+
+def answer_format_reward(completions, answers) -> List[float]:
+ """
+ Reward function that checks for proper answer formatting.
+
+ Expected format: <answer>X. content</answer>, where X is a choice letter.
+
+ Args:
+ completions: List of model completions (text strings)
+
+ Returns:
+ List of reward scores (1.0 for correct format, 0.0 otherwise)
+ """
+ rewards = []
+
+ for completion in completions:
+ completion_text = str(completion).strip()
+
+ # Check for proper answer format: <answer>X. content</answer>
+ answer_pattern = r"<answer>\s*[A-E]\.\s*.+?\s*</answer>"
+ has_proper_answer = bool(
+ re.search(answer_pattern, completion_text, re.DOTALL | re.IGNORECASE)
+ )
+
+ rewards.append(1.0 if has_proper_answer else 0.0)
+
+ return rewards
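
Illustrative sketch (not part of the patch): with the patterns above, a quick check of the two format rewards; the completions are made up and `answers` is unused by these functions.

from rlinf.algorithms.rewards.vqa.format_rewards import (
    answer_format_reward,
    think_format_reward,
)

completions = [
    "<think>The gripper is closed around the mug.</think><answer>A. yes</answer>",
    "Answer: A",
]
print(think_format_reward(completions, answers=[{}, {}]))   # [1.0, 0.0]
print(answer_format_reward(completions, answers=[{}, {}]))  # [1.0, 0.0]
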
diff --git a/rlinf/algorithms/rewards/vqa/qa_rewards.py b/rlinf/algorithms/rewards/vqa/qa_rewards.py
new file mode 100644
index 000000000..2bc9540d3
--- /dev/null
+++ b/rlinf/algorithms/rewards/vqa/qa_rewards.py
@@ -0,0 +1,110 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import List
+
+
+def qa_accuracy_reward(completions, answers) -> List[float]:
+ """
+ Reward function that evaluates question-answering accuracy for VQA tasks.
+
+ Based on TRL's accuracy_reward pattern but adapted for multiple choice VQA.
+
+ Args:
+ completions: List of model completions (text strings)
+ answers: List of correct answers (dict)
+
+ Returns:
+ List of reward scores (1.0 for correct, 0.0 for incorrect)
+ """
+ rewards = []
+
+ for completion, answer in zip(completions, answers):
+ completion_text = str(completion).strip()
+
+ # Extract answer from completion - look for <answer>X. content</answer>
+ answer_match = re.search(
+ r"\s*([A-E])\.\s*(.*?)\s*",
+ completion_text,
+ re.DOTALL | re.IGNORECASE,
+ )
+
+ if not answer_match:
+ rewards.append(0.0)
+ continue
+
+ predicted_letter = answer_match.group(1).upper()
+ predicted_content = answer_match.group(2).strip()
+
+ # Get ground truth from the answer dict
+ correct_answer = answer.get("correct_answer", None)
+ choices = answer.get("choices", None)
+
+ if correct_answer is None or choices is None:
+ rewards.append(0.0)
+ continue
+
+ # Normalize correct_answer to letter format
+ if isinstance(correct_answer, int):
+ correct_letter = chr(65 + correct_answer) # 0->A, 1->B, etc.
+ elif isinstance(correct_answer, str):
+ correct_letter = correct_answer.strip().upper()
+ else:
+ rewards.append(0.0)
+ continue
+
+ # Parse choices if string format
+ if isinstance(choices, str):
+ try:
+ import ast
+
+ choices = ast.literal_eval(choices)
+ except (ValueError, SyntaxError):
+ choices = [str(choices)]
+
+ # Get correct choice content
+ letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}
+ if correct_letter in letter_to_idx and letter_to_idx[correct_letter] < len(
+ choices
+ ):
+ correct_content = choices[letter_to_idx[correct_letter]].strip()
+ else:
+ rewards.append(0.0)
+ continue
+
+ # Check accuracy: both letter and content must match
+ letter_match = predicted_letter == correct_letter
+ content_match = _compare_choice_content(predicted_content, correct_content)
+
+ rewards.append(1.0 if (letter_match and content_match) else 0.0)
+
+ return rewards
+
+
+def _compare_choice_content(predicted: str, correct: str) -> bool:
+ """Compare predicted choice content with correct content."""
+ # Simple normalized comparison
+ pred_normalized = predicted.lower().strip()
+ correct_normalized = correct.lower().strip()
+
+ # Direct match
+ if pred_normalized == correct_normalized:
+ return True
+
+ # Partial match for more flexibility
+ if pred_normalized in correct_normalized or correct_normalized in pred_normalized:
+ return True
+
+ return False
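
Illustrative sketch (not part of the patch): how `qa_accuracy_reward` consumes the answer dict; the completion and choices are made up.

from rlinf.algorithms.rewards.vqa.qa_rewards import qa_accuracy_reward

completions = ["<think>Count the blocks.</think><answer>B. three</answer>"]
answers = [{"correct_answer": 1, "choices": ["two", "three", "four"]}]
print(qa_accuracy_reward(completions, answers))  # [1.0]
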
diff --git a/rlinf/config.py b/rlinf/config.py
index 86198334a..05de15254 100644
--- a/rlinf/config.py
+++ b/rlinf/config.py
@@ -13,6 +13,7 @@
# limitations under the License.
import dataclasses
+import importlib.util
import logging
import os
from dataclasses import asdict
@@ -33,16 +34,10 @@
logging.getLogger().setLevel(logging.INFO)
-try:
- import transformer_engine
-
- HAVE_TE = True
-except ImportError:
- transformer_engine = None
- HAVE_TE = False
-
-SUPPORTED_MODEL_ARCHS = ["qwen2.5", "openvla", "openvla_oft"]
+SUPPORTED_MODEL_ARCHS = ["qwen2.5", "qwen2.5_vl", "openvla", "openvla_oft"]
SUPPORTED_ROLLOUT_BACKENDS = ["sglang", "vllm"]
+SUPPORTED_TASK_TYPE = ["embodied", "reasoning", "coding_online_rl"]
+SUPPORTED_TRAINING_BACKENDS = ["megatron", "fsdp"]
__all__ = ["build_config"]
@@ -191,7 +186,7 @@ def validate_vllm_cfg(cfg):
def validate_model_cfg_by_hf_config(cfg, hf_model_path):
# validate by hf config
- hf_config = AutoConfig.from_pretrained(hf_model_path)
+ hf_config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True)
if "Qwen2ForCausalLM" in hf_config.architectures:
qkv_bias = True
@@ -199,8 +194,16 @@ def validate_model_cfg_by_hf_config(cfg, hf_model_path):
qkv_bias = getattr(hf_config, "attention_bias", False)
with open_dict(cfg):
- if hf_config.rope_scaling is not None:
- cfg.model.seq_len_interpolation_factor = hf_config.rope_scaling["factor"]
+ rs = getattr(hf_config, "rope_scaling", None)
+ if isinstance(rs, dict):
+ rtype = rs.get("type", "")
+ if rtype in {"linear", "dynamic", "ntk", "yarn"}:
+ f = rs.get("factor")
+ if f is not None:
+ cfg.model.seq_len_interpolation_factor = float(f)
+ else:
+ # mrope
+ cfg.model.seq_len_interpolation_factor = None
cfg.model.override_vocab_size = hf_config.vocab_size
cfg.model.max_position_embeddings = hf_config.max_position_embeddings
cfg.model.rotary_base = hf_config.rope_theta
@@ -220,6 +223,16 @@ def validate_model_cfg_by_hf_config(cfg, hf_model_path):
return cfg
+def validate_fsdp_cfg(cfg: DictConfig) -> DictConfig:
+ OmegaConf.set_struct(cfg, True)
+ with open_dict(cfg):
+ cfg.fsdp.forward_prefetch = cfg.fsdp.get("forward_prefetch", False)
+ cfg.fsdp.limit_all_gathers = cfg.fsdp.get("limit_all_gathers", False)
+ cfg.fsdp.backward_prefetch = cfg.fsdp.get("backward_prefetch", False)
+ cfg.fsdp.use_orig_params = cfg.fsdp.get("use_orig_params", False)
+ return cfg
+
+
def validate_megatron_cfg(cfg: DictConfig) -> DictConfig:
OmegaConf.set_struct(cfg, True)
@@ -527,7 +540,7 @@ def get_robot_control_mode(robot: str):
return cfg
-def validate_math_cfg(cfg: DictConfig) -> DictConfig:
+def validate_reasoning_cfg(cfg: DictConfig) -> DictConfig:
assert cfg.rollout.model_arch in SUPPORTED_MODEL_ARCHS, (
f"Model {cfg.rollout.model_arch} is not supported"
)
@@ -606,11 +619,14 @@ def validate_coding_online_rl_cfg(cfg: DictConfig) -> DictConfig:
def validate_cfg(cfg: DictConfig) -> DictConfig:
OmegaConf.set_struct(cfg, True)
+ assert cfg.runner.task_type in SUPPORTED_TASK_TYPE, (
+ f"task_type must be one of {SUPPORTED_TASK_TYPE}"
+ )
if cfg.runner.task_type == "embodied":
cfg = validate_embodied_cfg(cfg)
- if cfg.runner.task_type == "math":
- cfg = validate_math_cfg(cfg)
- if cfg.runner.task_type == "coding_online_rl":
+ elif cfg.runner.task_type == "reasoning":
+ cfg = validate_reasoning_cfg(cfg)
+ elif cfg.runner.task_type == "coding_online_rl":
cfg = validate_coding_online_rl_cfg(cfg)
if (
@@ -619,13 +635,21 @@ def validate_cfg(cfg: DictConfig) -> DictConfig:
):
assert cfg.algorithm.group_size > 1
+ assert cfg.actor.training_backend in SUPPORTED_TRAINING_BACKENDS, (
+ f"Unsupported training_backend {cfg.actor.training_backend}. Supported training backends are {SUPPORTED_TRAINING_BACKENDS}."
+ )
+
if cfg.actor.training_backend == "megatron":
cfg.actor = validate_megatron_cfg(cfg.actor)
cfg.actor = validate_model_cfg_by_hf_config(cfg.actor, cfg.rollout.model_dir)
+ elif cfg.actor.training_backend == "fsdp":
+ cfg.actor = validate_fsdp_cfg(cfg.actor)
if cfg.critic.use_critic_model and cfg.critic.training_backend == "megatron":
cfg.critic = validate_megatron_cfg(cfg.critic)
- cfg = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir)
+ cfg.critic = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir)
+ elif cfg.critic.use_critic_model and cfg.critic.training_backend == "fsdp":
+ cfg.critic = validate_fsdp_cfg(cfg.critic)
return cfg
@@ -756,7 +780,10 @@ def build_transformer_config(cfg) -> "TransformerConfig":
tp_only_amax_red = cfg.get("tp_only_amax_red", False)
if cfg.get("enable_cuda_graph", False):
- assert HAVE_TE, "Transformer Engine is required for cudagraphs."
+ if importlib.util.find_spec("transformer_engine") is None:
+ raise ImportError(
+ "Can not import transformer_engine, which is required for cudagraphs."
+ )
assert cfg.get("use_te_rng_tracker", False), (
"Transformer engine's RNG tracker is required for cudagraphs, this can be enabled with \
'use_te_rng_tracker=True'."
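
Illustrative sketch (not part of the patch): what `validate_fsdp_cfg` does to an actor config that omits the optional FSDP knobs; the config contents are assumptions and the snippet assumes `rlinf.config` imports cleanly.

from omegaconf import OmegaConf

from rlinf.config import validate_fsdp_cfg

actor_cfg = OmegaConf.create({"training_backend": "fsdp", "fsdp": {}})
actor_cfg = validate_fsdp_cfg(actor_cfg)
print(actor_cfg.fsdp.forward_prefetch,   # False
      actor_cfg.fsdp.limit_all_gathers,  # False
      actor_cfg.fsdp.backward_prefetch,  # False
      actor_cfg.fsdp.use_orig_params)    # False
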
diff --git a/rlinf/data/datasets/__init__.py b/rlinf/data/datasets/__init__.py
new file mode 100644
index 000000000..f86f3576e
--- /dev/null
+++ b/rlinf/data/datasets/__init__.py
@@ -0,0 +1,150 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Dict, List, Tuple
+
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import Dataset
+from transformers import AutoTokenizer
+
+from rlinf.data.datasets.item import DatasetItem
+from rlinf.data.datasets.math import MathDataset
+from rlinf.data.datasets.vlm import VLMDatasetRegistry
+
+
+def create_rl_dataset(
+ config: DictConfig, tokenizer: AutoTokenizer
+) -> Tuple[Dataset, Dataset]:
+ """Create rl datasets.
+
+ Arguments:
+ config: The RLinf config.
+ tokenizer (Tokenizer): The tokenizer.
+
+ Returns:
+ train_dataset (Dataset): The training dataset.
+
+ val_dataset (Dataset): The validation dataset.
+ """
+
+ if config.data.type == "math":
+ dataset_cls = MathDataset
+ elif config.data.type == "vision_language":
+ # Factory-based VLM datasets, selected by config.data.dataset_name
+ dataset_name = getattr(config.data, "dataset_name", None)
+ lazy_loading = bool(getattr(config.data, "lazy_loading", False))
+
+ logging.info(
+ f"Using VLM dataset: name={dataset_name}, lazy_loading={lazy_loading}"
+ )
+
+ train_dataset = VLMDatasetRegistry.create(
+ dataset_name,
+ data_paths=config.data.train_data_paths,
+ config=config,
+ tokenizer=tokenizer,
+ )
+ val_dataset = VLMDatasetRegistry.create(
+ dataset_name,
+ data_paths=config.data.val_data_paths,
+ config=config,
+ tokenizer=tokenizer,
+ )
+ return train_dataset, val_dataset
+ else:
+ return None, None
+
+ logging.info(f"Using dataset class: {dataset_cls.__name__}")
+
+ # Instantiate the dataset using the determined dataset class
+ train_dataset = dataset_cls(
+ data_paths=config.data.train_data_paths,
+ config=config,
+ tokenizer=tokenizer,
+ )
+
+ val_dataset = dataset_cls(
+ data_paths=config.data.val_data_paths,
+ config=config,
+ tokenizer=tokenizer,
+ )
+
+ return train_dataset, val_dataset
+
+
+def collate_fn(data_list: List["DatasetItem"]) -> Dict[str, Any]:
+ """
+ Collate function for batching dataset items.
+ """
+ prompts = []
+ lens = []
+ for it in data_list:
+ p = (
+ it.prompt
+ if isinstance(it.prompt, torch.Tensor)
+ else torch.as_tensor(it.prompt, dtype=torch.long)
+ )
+ if p.dim() == 2 and p.size(0) == 1:
+ p = p.squeeze(0)
+ assert p.dim() == 1, (
+ f"DatasetItem.prompt must be 1-D tensor, current shape is: {p.shape}"
+ )
+ prompts.append(p)
+ lens.append(p.numel())
+
+ if len(set(lens)) == 1:
+ target_len = lens[0]
+ else:
+ target_len = min(lens)
+ prompts = [p[-target_len:] if p.numel() > target_len else p for p in prompts]
+
+ batch_prompt = torch.stack(prompts, dim=0) # [B, L]
+ batch_length = torch.tensor(
+ [min(int(it.length), target_len) for it in data_list], dtype=torch.long
+ )
+
+ batch_idx = torch.tensor([int(it.idx) for it in data_list], dtype=torch.long)
+
+ batch: Dict[str, Any] = {
+ "prompt": batch_prompt, # [B, L]
+ "length": batch_length, # [B]
+ "answer": [it.answer for it in data_list], # List[str]
+ "idx": batch_idx, # [B]
+ "solution": [it.solution for it in data_list], # List[Optional[str]]
+ "image_data": [
+ it.image_data for it in data_list
+ ], # List[Optional[List[bytes|str]]]
+ "prompt_text": [it.prompt_text for it in data_list], # List[Optional[str]]
+ "meta": [it.meta for it in data_list], # List[Optional[dict]]
+ "multi_modal_inputs": [
+ it.multi_modal_inputs for it in data_list
+ ], # List[Optional[dict]]
+ }
+ return batch
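
Illustrative sketch (not part of the patch): a small usage example for `collate_fn`; `DatasetItem` is added later in this diff, and the token ids and answers are made up.

import torch

from rlinf.data.datasets import collate_fn
from rlinf.data.datasets.item import DatasetItem

items = [
    DatasetItem(prompt=torch.tensor([101, 9906, 11, 1879]), length=4, answer="42", idx=0, image_data=[]),
    DatasetItem(prompt=torch.tensor([101, 101, 3838, 374]), length=2, answer="7", idx=1, image_data=[]),
]
batch = collate_fn(items)
print(batch["prompt"].shape)  # torch.Size([2, 4])
print(batch["length"])        # tensor([4, 2])
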
diff --git a/rlinf/data/datasets/item.py b/rlinf/data/datasets/item.py
new file mode 100644
index 000000000..e75155dcb
--- /dev/null
+++ b/rlinf/data/datasets/item.py
@@ -0,0 +1,59 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+
+@dataclass
+class DatasetItem:
+ """
+ A single item in processed dataset.
+
+ Attributes:
+ prompt (torch.Tensor): Tokenized prompt input_ids tensor.
+ length (int): Length of the prompt input_ids.
+ answer (str | dict): The answer associated with the prompt.
+ idx (int): Index of the item in the dataset.
+ solution (Optional[str]): Optional solution text if exists.
+ prompt_text (Optional[str]): Optional original prompt text before tokenization.
+ meta (Optional[Dict[str, Any]]): Optional metadata dictionary.
+ multi_modal_inputs (Optional[Dict[str, Any]]): Optional dictionary for additional multi-modal inputs.
+ """
+
+ prompt: torch.Tensor
+ length: int
+ answer: str | dict
+ idx: int
+ solution: Optional[str] = None
+ image_data: Optional[List[Union[bytes, str]]] = None
+ prompt_text: Optional[str] = None
+ meta: Optional[Dict[str, Any]] = None
+ multi_modal_inputs: Optional[Dict[str, Any]] = None
diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets/math.py
similarity index 52%
rename from rlinf/data/datasets.py
rename to rlinf/data/datasets/math.py
index fcce53f47..821074bf6 100644
--- a/rlinf/data/datasets.py
+++ b/rlinf/data/datasets/math.py
@@ -12,67 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import json
import logging
import os
-from collections import defaultdict
-from typing import List
+from typing import Any, List, Tuple, Union
-import numpy as np
import torch
+from omegaconf import DictConfig
from torch.utils.data import Dataset
+from transformers import AutoTokenizer
-
-def batch_pad_to_fixed_len(
- batch: List[torch.Tensor],
- max_batch_len: int,
- pad_token: int,
- left_pad: bool = False,
-) -> torch.Tensor:
- if left_pad:
- batch_pad = torch.stack(
- [
- torch.cat(
- [
- torch.full(
- (max_batch_len - len(seq),), pad_token, dtype=seq.dtype
- ), # pad on the left
- seq,
- ]
- )
- for seq in batch
- ]
- )
- else:
- batch_pad = torch.stack(
- [
- torch.cat(
- [
- seq,
- torch.full(
- (max_batch_len - len(seq),), pad_token, dtype=seq.dtype
- ),
- ]
- )
- for seq in batch
- ]
- )
- return batch_pad
+from rlinf.data.datasets.item import DatasetItem
+from rlinf.data.datasets.utils import batch_pad_to_fixed_len
class MathDataset(Dataset):
- def __init__(self, data_paths, config, tokenizer):
+ def __init__(
+ self,
+ data_paths: Union[str, List[str]],
+ config: DictConfig,
+ tokenizer: AutoTokenizer,
+ ):
super().__init__()
self.data_paths = data_paths
if isinstance(self.data_paths, str):
self.data_paths = [self.data_paths]
- self.max_prompt_length = config.max_prompt_length
+ self.max_prompt_length = config.data.max_prompt_length
self.tokenizer = tokenizer
- self.prompt_key = config.prompt_key
+ self.prompt_key = config.data.prompt_key
self.data = self._load_data()
- if config.get("filter_prompt_by_length", False):
+ if config.data.get("filter_prompt_by_length", False):
total = len(self.data)
filtered = []
failed = 0
@@ -97,7 +83,10 @@ def __init__(self, data_paths, config, tokenizer):
f"(kept {len(self.data)} / {total})."
)
- def _load_data(self):
+ def _load_data(self) -> List[Any]:
+ """
+ Load and merge data from multiple files(json or jsonl).
+ """
merged_data = []
for path in self.data_paths:
@@ -122,7 +111,10 @@ def _load_data(self):
def __len__(self):
return len(self.data)
- def encode(self, text):
+ def encode(self, text: str) -> Tuple[List[int], int]:
+ """
+ Use tokenizer to encode the text and return the token ids and length.
+ """
text_ids = self.tokenizer.encode(text)
return text_ids, len(text_ids)
@@ -151,78 +143,11 @@ def __getitem__(self, idx):
self.tokenizer.eos_token_id,
left_pad=True,
)[0]
-
- output = {
- "prompt": prompt_tokens_tensor,
- "length": prompt_length,
- "answer": answer,
- "idx": idx,
- }
+ output = DatasetItem(
+ prompt=prompt_tokens_tensor,
+ length=prompt_length,
+ answer=answer,
+ idx=idx,
+ image_data=[],
+ )
return output
-
-
-def create_rl_dataset(data_config, tokenizer):
- """Create rl datasets.
-
- Arguments:
- data_config: The data config.
- tokenizer (Tokenizer): The tokenizer.
-
- Returns:
- train_dataset (Dataset): The training dataset.
-
- val_dataset (Dataset): The validation dataset.
- """
-
- if data_config.type == "math":
- dataset_cls = MathDataset
- else:
- return None, None
-
- print(f"Using dataset class: {dataset_cls.__name__}")
-
- # Instantiate the dataset using the determined dataset class
- train_dataset = dataset_cls(
- data_paths=data_config.train_data_paths,
- config=data_config,
- tokenizer=tokenizer,
- )
-
- val_dataset = dataset_cls(
- data_paths=data_config.val_data_paths,
- config=data_config,
- tokenizer=tokenizer,
- )
-
- return train_dataset, val_dataset
-
-
-def collate_fn(data_list: list[dict]) -> dict:
- r"""
- Collate a batch of sample dicts into batched tensors and arrays.
-
- Args:
- data_list: List of dicts mapping feature names to torch.Tensor or other values.
-
- Returns:
- Dict where tensor entries are stacked into a torch.Tensor of shape
- (batch_size, \*dims) and non-tensor entries are converted to
- np.ndarray of dtype object with shape (batch_size,).
- """
- tensors = defaultdict(list)
- non_tensors = defaultdict(list)
-
- for data in data_list:
- for key, val in data.items():
- if isinstance(val, torch.Tensor):
- tensors[key].append(val)
- else:
- non_tensors[key].append(val)
-
- for key, val in tensors.items():
- tensors[key] = torch.stack(val, dim=0)
-
- for key, val in non_tensors.items():
- non_tensors[key] = np.array(val, dtype=object)
-
- return {**tensors, **non_tensors}
diff --git a/rlinf/data/datasets/utils.py b/rlinf/data/datasets/utils.py
new file mode 100644
index 000000000..db4dbdb58
--- /dev/null
+++ b/rlinf/data/datasets/utils.py
@@ -0,0 +1,68 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+
+def batch_pad_to_fixed_len(
+ batch: List[torch.Tensor],
+ max_batch_len: int,
+ pad_token: int,
+ left_pad: bool = False,
+) -> torch.Tensor:
+ if left_pad:
+ batch_pad = torch.stack(
+ [
+ torch.cat(
+ [
+ torch.full(
+ (max_batch_len - len(seq),), pad_token, dtype=seq.dtype
+ ), # pad on the left
+ seq,
+ ]
+ )
+ for seq in batch
+ ]
+ )
+ else:
+ batch_pad = torch.stack(
+ [
+ torch.cat(
+ [
+ seq,
+ torch.full(
+ (max_batch_len - len(seq),), pad_token, dtype=seq.dtype
+ ),
+ ]
+ )
+ for seq in batch
+ ]
+ )
+ return batch_pad
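
Illustrative sketch (not part of the patch): left-padding with `batch_pad_to_fixed_len` on toy tensors.

import torch

from rlinf.data.datasets.utils import batch_pad_to_fixed_len

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = batch_pad_to_fixed_len(seqs, max_batch_len=5, pad_token=0, left_pad=True)
print(padded)
# tensor([[0, 0, 1, 2, 3],
#         [0, 0, 0, 4, 5]])
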
diff --git a/rlinf/data/datasets/vlm.py b/rlinf/data/datasets/vlm.py
new file mode 100644
index 000000000..da9e952b9
--- /dev/null
+++ b/rlinf/data/datasets/vlm.py
@@ -0,0 +1,483 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import logging
+import os
+from io import BytesIO
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import pandas as pd
+import torch
+from omegaconf import DictConfig
+from PIL import Image
+from torch.utils.data import Dataset
+from transformers import AutoProcessor, AutoTokenizer
+
+from rlinf.data.datasets.item import DatasetItem
+from rlinf.data.datasets.utils import batch_pad_to_fixed_len
+
+
+class VLMBaseDataset(Dataset):
+ def __init__(
+ self,
+ data_paths: Union[List[str], str],
+ config: DictConfig,
+ tokenizer: AutoTokenizer,
+ ) -> None:
+ super().__init__()
+ self.cfg = config
+ raw_paths = [data_paths] if isinstance(data_paths, str) else list(data_paths)
+ # Expand directories into file lists recursively (json/jsonl/parquet)
+ self.data_paths = self._expand_data_paths(raw_paths)
+ self.tokenizer = tokenizer
+ # Delay processor creation; only needed when use_chat_template is True
+ self._processor = None
+
+ self.system_prompt = config.data.get("system_prompt", None)
+ self.use_chat_template = bool(config.data.use_chat_template)
+ self.image_keys = list(config.data.image_keys or [])
+ self.prompt_key = config.data.prompt_key
+ self.choice_key = config.data.get("choice_key", None)
+ self.answer_key = config.data.get("answer_key", None)
+ self.solution_key = config.data.get("solution_key", None)
+ self.max_prompt_length = int(config.data.max_prompt_length)
+ self.eos_id = int(self.tokenizer.eos_token_id)
+
+ # Loading mode
+ self.lazy_loading = bool(getattr(config.data, "lazy_loading", False))
+
+ self._records = []
+ self._indices = [] # (path, fmt, row_index_or_offset)
+
+ if self.lazy_loading:
+ self._build_lazy_indices()
+ else:
+ self._eager_load_all()
+
+ def __len__(self) -> int:
+ return len(self._indices) if self.lazy_loading else len(self._records)
+
+ def __getitem__(self, idx: int) -> DatasetItem:
+ if self.lazy_loading:
+ path, fmt, key = self._indices[idx]
+ raw = self._load_single_lazy(path, fmt, key)
+ return self._process_raw_record(raw, idx)
+ else:
+ raw = self._records[idx]
+ return self._process_raw_record(raw, idx)
+
+ # Ensure dataset is picklable for multi-process DataLoader by removing
+ # unpicklable cache objects like pyarrow.ParquetFile from state.
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # Drop heavy/unpicklable caches; they will be rebuilt on-demand in workers
+ for k in ("_parquet_cache", "_parquet_df_cache"):
+ if k in state:
+ state[k] = {}
+ return state
+
+ def __setstate__(self, state):
+ # Restore state and ensure cache dicts exist
+ self.__dict__.update(state)
+ self._parquet_cache = getattr(self, "_parquet_cache", {})
+ self._parquet_df_cache = getattr(self, "_parquet_df_cache", {})
+
+ def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]:
+ images: List[Union[bytes, str, None]] = []
+ for k in self.image_keys:
+ v = dataitem.get(k, None)
+ if v is None:
+ continue
+ if isinstance(v, Image.Image):
+ images.append(v)
+ elif isinstance(v, dict) and "bytes" in v:
+ images.append(v["bytes"])
+ else:
+ images.append(v) # path or url
+ if not images:
+ images = [None]
+ return images
+
+ def build_prompt_text(self, data_item: Dict[str, Any]) -> str:
+ # Default: prompt + optional choices rendered inline
+ q = data_item.get(self.prompt_key, "")
+ choices = data_item.get(self.choice_key, []) if self.choice_key else []
+ if not isinstance(choices, list):
+ choices = [choices]
+ if choices:
+ return f"{q}{choices}\n"
+ return str(q)
+
+ def encode_prompt(
+ self, prompt_text: str, images
+ ) -> Tuple[torch.Tensor, int, Optional[str], Dict[str, Any]]:
+ """
+ Return (token_ids[L], length, prompt_text_used, multi_modal_inputs). If using the chat template, encode with the processor.
+ Subclasses may override to support alternative prompting.
+ """
+ if self.use_chat_template:
+ if self._processor is None:
+ self._processor = AutoProcessor.from_pretrained(
+ self.cfg.actor.model.model_path
+ )
+ messages = []
+ if self.system_prompt is not None:
+ messages.append(
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": self.system_prompt}],
+ }
+ )
+
+ content: List[Dict[str, Any]] = []
+ for _ in range(max(0, len(images))):
+ content.append({"type": "image"})
+ content.append({"type": "text", "text": prompt_text})
+ messages.append({"role": "user", "content": content})
+ rendered = self._processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ images_inputs = []
+ for image in images:
+ image_obj = None
+ if isinstance(image, Image.Image):
+ image_obj = image.convert("RGB")
+ if isinstance(image, (bytes, bytearray)):
+ image_obj = Image.open(BytesIO(image)).convert("RGB")
+ images_inputs.append(image_obj)
+
+ inputs = self._processor(
+ text=[rendered], images=images_inputs, padding=True, return_tensors="pt"
+ )
+ inputs.pop("attention_mask")
+ if self.cfg.rollout.rollout_backend == "sglang":
+ ids = inputs.pop("input_ids")
+ elif self.cfg.rollout.rollout_backend == "vllm":
+ inputs.pop("input_ids")
+ ids = self._processor(
+ text=[rendered], images=None, padding=True, return_tensors="pt"
+ )["input_ids"]
+ else:
+ raise ValueError(
+ f"Unsupported rollout backend {self.cfg.rollout.rollout_backend}"
+ )
+ if isinstance(ids, torch.Tensor):
+ if ids.dim() == 2 and ids.size(0) == 1:
+ ids = ids.squeeze(0)
+ ids = ids.to(dtype=torch.long)
+ else:
+ ids = torch.tensor(ids, dtype=torch.long)
+
+ multi_modal_inputs = {}
+ for k, v in inputs.items():
+ multi_modal_inputs[k] = v
+ return ids, int(ids.numel()), rendered, multi_modal_inputs
+ else:
+ # fallback: tokenizer only
+ ids_list = self.tokenizer.encode(prompt_text)
+ ids = torch.as_tensor(ids_list, dtype=torch.long)
+ return ids, int(ids.numel()), prompt_text, {}
+
+ def postprocess_dataset_item(
+ self, item: DatasetItem, raw: Dict[str, Any]
+ ) -> DatasetItem:
+ return item
+
+ def _expand_data_paths(self, inputs: List[str]) -> List[str]:
+ exts = {".jsonl", ".json", ".parquet"}
+ files: List[str] = []
+ for p in inputs:
+ if os.path.isdir(p):
+ for root, _, fnames in os.walk(p):
+ for fn in fnames:
+ ext = os.path.splitext(fn)[1].lower()
+ if ext in exts:
+ files.append(os.path.join(root, fn))
+ else:
+ files.append(p)
+ files = sorted(set(files))
+ return files
+
+ def _eager_load_all(self) -> None:
+ merged: List[Dict[str, Any]] = []
+ for path in self.data_paths:
+ fmt = os.path.splitext(path)[1].lower()
+ if fmt == ".jsonl":
+ with open(path, "r", encoding="utf-8") as f:
+ merged.extend(json.loads(line) for line in f)
+ elif fmt == ".json":
+ with open(path, "r", encoding="utf-8") as f:
+ content = json.load(f)
+ if isinstance(content, list):
+ merged.extend(content)
+ else:
+ merged.append(content)
+ elif fmt == ".parquet":
+ try:
+ merged.extend(pd.read_parquet(path).to_dict(orient="records"))
+ except Exception as e:
+ raise RuntimeError(f"Failed to load parquet eagerly: {path}: {e}")
+ else:
+ logging.warning(f"Unsupported format {fmt} for path {path}, skipping.")
+ self._records = merged
+ # Build indices for consistency
+ self._indices = [("", "eager", i) for i in range(len(self._records))]
+
+ def _build_lazy_indices(self) -> None:
+ self._indices.clear()
+ for path in self.data_paths:
+ fmt = os.path.splitext(path)[1].lower()
+ if fmt == ".jsonl":
+ # index by byte offsets for each line
+ offsets: List[int] = []
+ with open(path, "rb") as fb:
+ pos = 0
+ for line in fb:
+ offsets.append(pos)
+ pos += len(line)
+ self._indices.extend((path, "jsonl", off) for off in offsets)
+ elif fmt == ".json":
+ try:
+ with open(path, "r", encoding="utf-8") as f:
+ content = json.load(f)
+ if not isinstance(content, list):
+ content = [content]
+ # store the content to avoid re-reading
+ # keep perfile cache
+ self._json_cache = getattr(self, "_json_cache", {})
+ self._json_cache[path] = content
+ self._indices.extend((path, "json", i) for i in range(len(content)))
+ except Exception as e:
+ raise RuntimeError(f"Failed to index json lazily: {path}: {e}")
+ elif fmt == ".parquet":
+ try:
+ import pyarrow.parquet as pq # type: ignore
+
+ pf = pq.ParquetFile(path)
+ num_rows = pf.metadata.num_rows
+ # file handle cache
+ self._parquet_cache = getattr(self, "_parquet_cache", {})
+ self._parquet_cache[path] = pf
+ self._indices.extend((path, "parquet", i) for i in range(num_rows))
+ except Exception:
+ df = pd.read_parquet(path)
+ self._parquet_df_cache = getattr(self, "_parquet_df_cache", {})
+ self._parquet_df_cache[path] = df
+ self._indices.extend(
+ (path, "parquet_pd", i) for i in range(len(df))
+ )
+ else:
+ logging.warning(f"Unsupported format {fmt} for path {path}, skipping.")
+
+ def _load_single_lazy(self, path: str, fmt: str, key: Any) -> Dict[str, Any]:
+ if fmt == "eager":
+ return self._records[int(key)]
+ if fmt == "jsonl":
+ with open(path, "rb") as fb:
+ fb.seek(int(key))
+ line = fb.readline()
+ return json.loads(line.decode("utf-8").strip())
+ if fmt == "json":
+ return self._json_cache[path][int(key)] # type: ignore[attr-defined]
+ if fmt == "parquet":
+ # Try to use pyarrow lazily; rebuild cache if missing
+ self._parquet_cache = getattr(self, "_parquet_cache", {})
+ pf = self._parquet_cache.get(path)
+ if pf is None:
+ try:
+ import pyarrow.parquet as pq # type: ignore
+
+ pf = pq.ParquetFile(path)
+ self._parquet_cache[path] = pf
+ except Exception:
+ # Fall back to pandas-based cache
+ self._parquet_df_cache = getattr(self, "_parquet_df_cache", {})
+ df = self._parquet_df_cache.get(path)
+ if df is None:
+ df = pd.read_parquet(path)
+ self._parquet_df_cache[path] = df
+ return df.iloc[int(key)].to_dict()
+ table = pf.read_row_group(key // max(1, pf.metadata.num_rows), columns=None)
+ try:
+ df = table.to_pandas()
+ return df.iloc[int(key) % len(df)].to_dict()
+ except Exception:
+ df_all = pf.read().to_pandas()
+ return df_all.iloc[int(key)].to_dict()
+ if fmt == "parquet_pd":
+ self._parquet_df_cache = getattr(self, "_parquet_df_cache", {})
+ df = self._parquet_df_cache.get(path)
+ if df is None:
+ df = pd.read_parquet(path)
+ self._parquet_df_cache[path] = df
+ return df.iloc[int(key)].to_dict()
+ raise RuntimeError(f"Unknown lazy fmt {fmt}")
+
+ def _process_raw_record(self, raw: Dict[str, Any], idx: int) -> DatasetItem:
+ images = self.get_image_list(raw)
+ prompt_text = self.build_prompt_text(raw)
+ prompt_ids, plen, rendered_text, multi_modal_inputs = self.encode_prompt(
+ prompt_text, images
+ )
+
+ if plen > self.max_prompt_length:
+ prompt_ids = prompt_ids[: self.max_prompt_length]
+ plen = self.max_prompt_length
+ prompt_ids = batch_pad_to_fixed_len(
+ [prompt_ids], self.max_prompt_length, self.eos_id, left_pad=True
+ )[0]
+
+ answer_val = raw.get(self.answer_key, None) if self.answer_key else None
+ solution_val = raw.get(self.solution_key, None) if self.solution_key else None
+ item = DatasetItem(
+ prompt=prompt_ids,
+ length=plen,
+ answer=str(answer_val) if answer_val is not None else None,
+ idx=idx,
+ image_data=images,
+ prompt_text=rendered_text or prompt_text,
+ solution=solution_val,
+ meta=None,
+ multi_modal_inputs=multi_modal_inputs,
+ )
+ return self.postprocess_dataset_item(item, raw)
+
+
+class VLMDatasetRegistry:
+ registry: Dict[str, Callable[..., VLMBaseDataset]] = {}
+
+ @classmethod
+ def register(
+ cls, name: str
+ ) -> Callable[[Callable[..., VLMBaseDataset]], Callable[..., VLMBaseDataset]]:
+ def decorator(klass: Callable[..., VLMBaseDataset]):
+ cls.registry[name] = klass
+ return klass
+
+ return decorator
+
+ @classmethod
+ def create(
+ cls,
+ dataset_name: Optional[str],
+ *,
+ data_paths: Union[List[str], str],
+ config: DictConfig,
+ tokenizer: AutoTokenizer,
+ ) -> VLMBaseDataset:
+ key = dataset_name.lower()
+ dataset_class = cls.registry.get(key)
+ return dataset_class(data_paths=data_paths, config=config, tokenizer=tokenizer)
+
+
+@VLMDatasetRegistry.register("robo2vlm")
+class Robo2VLMDataset(VLMBaseDataset):
+ def __init__(
+ self,
+ data_paths: Union[List[str], str],
+ config: DictConfig,
+ tokenizer: AutoTokenizer,
+ ) -> None:
+ super().__init__(data_paths, config, tokenizer)
+ self.system_prompt = (
+ "You are a helpful robotic vision assistant specialized in "
+ "answering questions about robotic manipulation tasks. "
+ "Use tags to show your reasoning process, "
+ "then provide your final answer in tags."
+ )
+
+ def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]:
+ images: List[Any] = []
+ if "images" in dataitem:
+ v = dataitem.get("images")
+ if isinstance(v, list):
+ images = list(v)
+ elif v is not None:
+ images = [v]
+ else:
+ images = [None]
+ elif "image" in dataitem:
+ v = dataitem.get("image")
+ if v is not None:
+ images = [v]
+ else:
+ images = [None]
+ else:
+ return super().get_image_list(dataitem)
+
+ normed: List[Union[bytes, str, None]] = []
+ for v in images:
+ if v is None:
+ continue
+ if isinstance(v, Image.Image):
+ normed.append(v)
+ elif isinstance(v, dict) and "bytes" in v:
+ normed.append(v["bytes"]) # raw bytes
+ else:
+ normed.append(v) # path/uri/string
+ if not normed:
+ normed = [None]
+ return normed
+
+ def build_prompt_text(self, data_item: Dict[str, Any]) -> str:
+ # Use 'question' and 'choices' if present; else fallback to base using configured prompt/choice keys
+ question = data_item.get("question", None)
+ choices = data_item.get("choices", None)
+ if question is None:
+ return super().build_prompt_text(data_item)
+ # normalize choices
+ if isinstance(choices, str):
+ try:
+ import ast
+
+ choices = ast.literal_eval(choices)
+ except Exception:
+ choices = [choices]
+ if not isinstance(choices, list):
+ choices = [choices] if choices is not None else []
+
+ text = f"{question}\n"
+ if choices:
+ text += "Choices:\n"
+ for i, c in enumerate(choices):
+ text += f"{chr(65 + i)}. {c}\n"
+ return text
+
+ def postprocess_dataset_item(
+ self, item: DatasetItem, raw: Dict[str, Any]
+ ) -> DatasetItem:
+ answer_dict = {
+ "choices": raw.get("choices", None),
+ "correct_answer": raw.get("correct_answer", None),
+ }
+ item.answer = answer_dict
+
+ return item
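
Illustrative sketch (not part of the patch): how a new VLM dataset could plug into the registry above; the dataset name, prompt rendering, and config fields are assumptions.

from typing import Any, Dict

from rlinf.data.datasets.vlm import VLMBaseDataset, VLMDatasetRegistry

@VLMDatasetRegistry.register("my_vqa")
class MyVQADataset(VLMBaseDataset):
    def build_prompt_text(self, data_item: Dict[str, Any]) -> str:
        # Only customize prompt rendering; loading and encoding come from the base class.
        return f"{data_item.get('question', '')}\nAnswer with a single choice letter.\n"

# Instantiation goes through the factory, e.g.:
# dataset = VLMDatasetRegistry.create("my_vqa", data_paths=..., config=cfg, tokenizer=tok)
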
diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py
index 1c455fa70..2fa0fea8a 100644
--- a/rlinf/data/io_struct.py
+++ b/rlinf/data/io_struct.py
@@ -13,19 +13,21 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
from omegaconf import DictConfig
-from vllm.outputs import CompletionOutput
-from vllm.outputs import RequestOutput as VllmRequestOutput
-from rlinf.data.datasets import batch_pad_to_fixed_len
+if TYPE_CHECKING:
+ from vllm.outputs import CompletionOutput
+ from vllm.outputs import RequestOutput as VllmRequestOutput
+
+from rlinf.data.datasets.utils import batch_pad_to_fixed_len
from rlinf.utils.data_iter_utils import (
get_iterator_k_split,
split_list,
)
-
+import torch_npu
def get_batch_size(
batch: Dict[str, torch.Tensor], batch_tensor_key: str = "input_ids"
@@ -47,14 +49,16 @@ class RolloutRequest:
     Attributes:
         input_ids: List of input token IDs for rollout
         n: Number of completions to generate for each input
-        idx: List of unique identifiers for the requests, used for tracking
-        input_lengths: List of lengths of the input sequences, corresponding to input_ids
+        image_data: List of image data (bytes or URLs) for multimodal inputs, one list per request
         answers: Optional list of answers for the requests, if available
+        multi_modal_inputs: List of multi-modal inputs for the requests
     """
     n: int
     input_ids: List[List[int]]
+    image_data: Union[List[List[bytes]], List[List[str]]]
     answers: List[str]
+    multi_modal_inputs: List[Dict]
def repeat(self) -> "RolloutRequest":
"""Repeat each input in the RolloutRequest a specified number of times.
@@ -67,10 +71,15 @@ def repeat(self) -> "RolloutRequest":
"""
assert self.n > 0, "n must be greater than 0"
- input_ids, answers = zip(
+ input_ids, answers, image_data, multi_modal_inputs = zip(
*[
- (input_id, answer)
- for input_id, answer in zip(self.input_ids, self.answers)
+ (input_id, answer, image_data, multi_modal_inputs)
+ for input_id, answer, image_data, multi_modal_inputs in zip(
+ self.input_ids,
+ self.answers,
+ self.image_data,
+ self.multi_modal_inputs,
+ )
for _ in range(self.n)
]
)
@@ -78,6 +87,8 @@ def repeat(self) -> "RolloutRequest":
n=self.n,
input_ids=list(input_ids),
answers=list(answers),
+ image_data=list(image_data),
+ multi_modal_inputs=list(multi_modal_inputs),
)
def split(self, num_splits: int) -> List["RolloutRequest"]:
@@ -96,15 +107,27 @@ def split(self, num_splits: int) -> List["RolloutRequest"]:
input_ids_split_list = split_list(self.input_ids, num_splits)
answers_split_list = split_list(self.answers, num_splits)
+ image_data_split_list = split_list(self.image_data, num_splits)
+ multi_modal_inputs_split_list = split_list(self.multi_modal_inputs, num_splits)
splitted_requests = []
- for input_ids_batch, answers_batch in zip(
- input_ids_split_list, answers_split_list
+ for (
+ input_ids_batch,
+ answers_batch,
+ image_data_batch,
+ multi_modal_inputs_batch,
+ ) in zip(
+ input_ids_split_list,
+ answers_split_list,
+ image_data_split_list,
+ multi_modal_inputs_split_list,
):
request = RolloutRequest(
n=self.n,
input_ids=input_ids_batch,
answers=answers_batch,
+ image_data=image_data_batch,
+ multi_modal_inputs=multi_modal_inputs_batch,
)
splitted_requests.append(request)
@@ -113,14 +136,24 @@ def split(self, num_splits: int) -> List["RolloutRequest"]:
def repeat_and_split(
self, rollout_batch_size: Optional[int] = None
) -> List["RolloutRequest"]:
- input_ids, answers = zip(
+ input_ids, answers, image_data, multi_modal_inputs = zip(
*[
- (input_id, answer)
- for input_id, answer in zip(self.input_ids, self.answers)
+ (input_id, answer, image_data, multi_modal_inputs)
+ for input_id, answer, image_data, multi_modal_inputs in zip(
+ self.input_ids,
+ self.answers,
+ self.image_data,
+ self.multi_modal_inputs,
+ )
for _ in range(self.n)
]
)
- input_ids, answers = (list(input_ids), list(answers))
+ input_ids, answers, image_data, multi_modal_inputs = (
+ list(input_ids),
+ list(answers),
+ list(image_data),
+ list(multi_modal_inputs),
+ )
# Split input ids based on rollout_batch_size_per_gpu
if rollout_batch_size is None:
@@ -134,14 +167,26 @@ def repeat_and_split(
splitted_requests = []
input_ids_split_list = split_list(input_ids, num_batches)
answers_split_list = split_list(answers, num_batches)
-
- for input_ids_batch, answers_batch in zip(
- input_ids_split_list, answers_split_list
+ image_data_split_list = split_list(image_data, num_batches)
+ multi_modal_inputs_split_list = split_list(multi_modal_inputs, num_batches)
+
+ for (
+ input_ids_batch,
+ answers_batch,
+ image_data_batch,
+ multi_modal_inputs_batch,
+ ) in zip(
+ input_ids_split_list,
+ answers_split_list,
+ image_data_split_list,
+ multi_modal_inputs_split_list,
):
request = RolloutRequest(
n=self.n,
input_ids=input_ids_batch,
answers=answers_batch,
+ image_data=image_data_batch,
+ multi_modal_inputs=multi_modal_inputs_batch,
)
splitted_requests.append(request)
@@ -256,8 +301,9 @@ class RolloutResult:
     advantages: Optional[List[float] | torch.Tensor] = None
     prompt_texts: Optional[List[str]] = None
     response_texts: Optional[List[str]] = None
-    answers: Optional[List[str]] = None
-
+    answers: Optional[List[str | dict]] = None
+    image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None
+    multi_modal_inputs: Optional[List[dict]] = None
     # Inference
     # Only set when recompute_logprobs is False
     rollout_logprobs: Optional[List[List[float]]] = None
@@ -308,12 +354,13 @@ def _get_attention_masks_and_position_ids(
@staticmethod
def from_vllm_results(
group_size: int,
- results: List[VllmRequestOutput],
+ results: List["VllmRequestOutput"],
answers: Optional[List[str]] = None,
+ multi_modal_inputs: Optional[List[Dict]] = None,
return_logprobs: bool = False,
) -> "RolloutResult":
def get_logprobs(
- response_ids: List[int], output: CompletionOutput
+ response_ids: List[int], output: "CompletionOutput"
) -> List[float]:
logprobs = []
returned_logprobs = output.logprobs
@@ -326,6 +373,13 @@ def get_logprobs(
num_sequences = len(results) * group_size
+ if multi_modal_inputs:
+ mm_inputs = []
+ for mm_input in multi_modal_inputs:
+ mm_inputs.extend([mm_input] * group_size)
+ else:
+ mm_inputs = None
+
prompt_lengths = []
prompt_ids = []
response_lengths = []
@@ -368,6 +422,7 @@ def get_logprobs(
response_ids=response_ids,
response_lengths=response_lengths,
response_texts=response_texts,
+ multi_modal_inputs=mm_inputs,
is_end=is_end,
)
if return_logprobs:
@@ -380,6 +435,8 @@ def from_sglang_results(
group_size: int,
input_ids: List[List[int]],
answers: Optional[List[List[int]]] = None,
+ image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None,
+ multi_modal_inputs: Optional[List[Dict]] = None,
return_logprobs: bool = False,
) -> "RolloutResult":
"""Create a MathRolloutResult from the given results and input IDs.
@@ -406,6 +463,8 @@ def from_sglang_results(
response_lengths=[len(res["output_ids"]) for res in results],
response_ids=[res["output_ids"] for res in results],
answers=answers,
+ image_data=image_data,
+ multi_modal_inputs=multi_modal_inputs,
is_end=[
res["meta_info"]["finish_reason"]["type"] == "stop" for res in results
],
@@ -507,6 +566,161 @@ def merge_list(dst_list: List, src_list: List):
return merged_result
+ @staticmethod
+ def split_result_list_by_group(
+ rollout_results: List["RolloutResult"],
+ ) -> List["RolloutResult"]:
+ """
+ Split RolloutResult objects by group_size.
+
+ If input has only one RolloutResult, split it into multiple RolloutResult objects by group_size.
+ If input has multiple RolloutResult objects, split each one and merge the results.
+
+ Args:
+ rollout_results: List of input RolloutResult objects
+
+ Returns:
+ List of RolloutResult objects grouped by group_size
+ """
+ assert len(rollout_results) > 0, "No rollout results to split."
+
+ all_split_results = []
+
+ for rollout_result in rollout_results:
+ split_results = RolloutResult._split_single_result_by_group(rollout_result)
+ all_split_results.extend(split_results)
+
+ return all_split_results
+
+ @staticmethod
+ def _split_single_result_by_group(
+ rollout_result: "RolloutResult",
+ ) -> List["RolloutResult"]:
+ """
+ Split a single RolloutResult into multiple RolloutResult objects by group_size.
+
+ Args:
+ rollout_result: The RolloutResult to be split
+
+ Returns:
+ List of split RolloutResult objects
+ """
+ group_size = rollout_result.group_size
+ num_sequence = rollout_result.num_sequence
+
+ assert num_sequence % group_size == 0, (
+ f"num_sequence ({num_sequence}) must be divisible by group_size ({group_size})"
+ )
+
+ num_groups = num_sequence // group_size
+ split_results = []
+
+ # Split list fields
+ prompt_lengths_split = split_list(rollout_result.prompt_lengths, num_groups)
+ prompt_ids_split = split_list(rollout_result.prompt_ids, num_groups)
+ response_lengths_split = split_list(rollout_result.response_lengths, num_groups)
+ response_ids_split = split_list(rollout_result.response_ids, num_groups)
+ is_end_split = split_list(rollout_result.is_end, num_groups)
+
+ # Handle optional fields
+ answers_split = None
+ if rollout_result.answers is not None:
+ answers_split = split_list(rollout_result.answers, num_groups)
+
+ image_data_split = None
+ if rollout_result.image_data is not None:
+ image_data_split = split_list(rollout_result.image_data, num_groups)
+
+ multi_modal_inputs_split = None
+ if rollout_result.multi_modal_inputs is not None:
+ multi_modal_inputs_split = split_list(
+ rollout_result.multi_modal_inputs, num_groups
+ )
+
+ prompt_texts_split = None
+ if rollout_result.prompt_texts is not None:
+ prompt_texts_split = split_list(rollout_result.prompt_texts, num_groups)
+
+ response_texts_split = None
+ if rollout_result.response_texts is not None:
+ response_texts_split = split_list(rollout_result.response_texts, num_groups)
+
+ rollout_logprobs_split = None
+ if rollout_result.rollout_logprobs is not None:
+ rollout_logprobs_split = split_list(
+ rollout_result.rollout_logprobs, num_groups
+ )
+
+ # Handle tensor fields
+ rewards_split = None
+ if rollout_result.rewards is not None:
+ if isinstance(rollout_result.rewards, torch.Tensor):
+ rewards_split = torch.chunk(rollout_result.rewards, num_groups, dim=0)
+ else:
+ rewards_split = split_list(rollout_result.rewards, num_groups)
+
+ advantages_split = None
+ if rollout_result.advantages is not None:
+ if isinstance(rollout_result.advantages, torch.Tensor):
+ advantages_split = torch.chunk(
+ rollout_result.advantages, num_groups, dim=0
+ )
+ else:
+ advantages_split = split_list(rollout_result.advantages, num_groups)
+
+ prev_logprobs_split = None
+ if rollout_result.prev_logprobs is not None:
+ prev_logprobs_split = torch.chunk(
+ rollout_result.prev_logprobs, num_groups, dim=0
+ )
+
+ ref_logprobs_split = None
+ if rollout_result.ref_logprobs is not None:
+ ref_logprobs_split = torch.chunk(
+ rollout_result.ref_logprobs, num_groups, dim=0
+ )
+
+ # Create split RolloutResult objects
+ for i in range(num_groups):
+ split_result = RolloutResult(
+ num_sequence=group_size,
+ group_size=group_size,
+ prompt_lengths=prompt_lengths_split[i],
+ prompt_ids=prompt_ids_split[i],
+ response_lengths=response_lengths_split[i],
+ response_ids=response_ids_split[i],
+ is_end=is_end_split[i],
+ answers=answers_split[i] if answers_split is not None else None,
+ image_data=image_data_split[i]
+ if image_data_split is not None
+ else None,
+ multi_modal_inputs=multi_modal_inputs_split[i]
+ if multi_modal_inputs_split is not None
+ else None,
+ prompt_texts=prompt_texts_split[i]
+ if prompt_texts_split is not None
+ else None,
+ response_texts=response_texts_split[i]
+ if response_texts_split is not None
+ else None,
+ rollout_logprobs=rollout_logprobs_split[i]
+ if rollout_logprobs_split is not None
+ else None,
+ rewards=rewards_split[i] if rewards_split is not None else None,
+ advantages=advantages_split[i]
+ if advantages_split is not None
+ else None,
+ prev_logprobs=prev_logprobs_split[i]
+ if prev_logprobs_split is not None
+ else None,
+ ref_logprobs=ref_logprobs_split[i]
+ if ref_logprobs_split is not None
+ else None,
+ )
+ split_results.append(split_result)
+
+ return split_results
+
def to_actor_batch(
self,
data_seq_length: int,
@@ -595,17 +809,23 @@ def to_actor_batch(
) # [B, training_seq_length]
batch = {
- "input_ids": input_ids.cuda(),
- "attention_mask": attention_mask.cuda(),
- "is_end": is_end.cuda(),
- "position_ids": position_ids.cuda(),
- "prompt_lengths": prompt_lengths.cuda(),
- "response_lengths": response_lengths.cuda(),
+ "input_ids": input_ids.npu(),
+ "attention_mask": attention_mask.npu(),
+ "is_end": is_end.npu(),
+ "position_ids": position_ids.npu(),
+ "prompt_lengths": prompt_lengths.npu(),
+ "response_lengths": response_lengths.npu(),
}
+ if (
+ self.multi_modal_inputs is not None
+ and self.multi_modal_inputs[0] is not None
+ ):
+ batch["multi_modal_inputs"] = self.multi_modal_inputs
+
if self.advantages is not None:
if isinstance(self.advantages, torch.Tensor):
- batch["advantages"] = self.advantages.cuda()
+ batch["advantages"] = self.advantages.npu()
else:
response_attention_mask = attention_mask[
:, -max_response_len:
@@ -613,17 +833,17 @@ def to_actor_batch(
advantages = torch.tensor(self.advantages, dtype=torch.float32).reshape(
-1, 1
) # [B, 1]
- advantages = response_attention_mask.float().cuda() * advantages.cuda()
- batch["advantages"] = advantages.cuda()
+ advantages = response_attention_mask.float().npu() * advantages.npu()
+ batch["advantages"] = advantages.npu()
if self.prev_logprobs is not None:
- batch["prev_logprobs"] = self.prev_logprobs.cuda()
+ batch["prev_logprobs"] = self.prev_logprobs.npu()
if self.ref_logprobs is not None:
- batch["ref_logprobs"] = self.ref_logprobs.cuda()
+ batch["ref_logprobs"] = self.ref_logprobs.npu()
if self.rewards is not None:
- batch["rewards"] = self.rewards.cuda()
+ batch["rewards"] = self.rewards.npu()
if self.rollout_logprobs is not None:
logprobs = batch_pad_to_fixed_len(
@@ -634,7 +854,7 @@ def to_actor_batch(
max_batch_len=max_response_len,
pad_token=pad_token,
)
- batch["prev_logprobs"] = logprobs.cuda()
+ batch["prev_logprobs"] = logprobs.npu()
return batch
@@ -648,14 +868,16 @@ def merge_batches(
return merged_batch
if len(batches) == 1:
return batches[0]
+
         for key in batches[0].keys():
-            assert torch.is_tensor(batches[0][key]), (
-                f"Expected tensor for key {key} in batches, got {type(batches[0][key])}"
-            )
-            assert torch.is_tensor(batches[0][key]), (
-                f"Expected tensor for key {key} in batches, got {type(batches[0][key])}"
-            )
-            merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0)
+            if torch.is_tensor(batches[0][key]):
+                merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0)
+            elif isinstance(batches[0][key], list):
+                merged_batch[key] = []
+                for batch in batches:
+                    merged_batch[key].extend(batch[key])
+            else:
+                raise ValueError(f"Unsupported batch key type: {type(batches[0][key])}")
         return merged_batch
diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py
index 8e3b9f16c..690b3d4f5 100644
--- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py
+++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py
@@ -17,11 +17,17 @@
import torch
import torch.optim as optim
from omegaconf import DictConfig
+from torch.distributed.fsdp import (
+ BackwardPrefetch,
+ MixedPrecision,
+ ShardingStrategy,
+ StateDictType,
+)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq
from rlinf.config import torch_dtype_from_precision
+from rlinf.data.tokenizers import hf_tokenizer
from rlinf.hybrid_engines.fsdp.utils import (
get_fsdp_wrap_policy,
init_fn,
@@ -29,6 +35,7 @@
from rlinf.utils.logging import get_logger
from rlinf.utils.utils import clear_memory
+import torch_npu
class FSDPModelManager:
"""
@@ -40,40 +47,55 @@ def __init__(self, cfg: DictConfig):
self.logger = get_logger()
self.torch_dtype = torch_dtype_from_precision(self._cfg.model.precision)
- assert (
- self.torch_dtype == torch.float16 or self.torch_dtype == torch.bfloat16
- ), (
- f"Precision {self._cfg.model.precision} is not supported, only support bf16 and fp16."
- )
+ self.tokenizer = hf_tokenizer(cfg.tokenizer.tokenizer_model)
def model_provider_func(self) -> torch.nn.Module:
- if self._cfg.model.get("gptq_model", False):
+ cfg = self._cfg
+ use_gptq = cfg.model.get("gptq_model", False)
+ load_in_8bit = cfg.model.get("load_in_8bit", False)
+
+ use_triton = cfg.get("use_triton", True)
+
+        assert torch.npu.is_available(), "NPU is not available."
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        device = torch.device(f"npu:{local_rank}")
+
+ model_config = AutoConfig.from_pretrained(
+ cfg.model.model_path,
+ trust_remote_code=True,
+ attn_implementation="flash_attention_2",
+ )
+
+ if use_gptq:
from auto_gptq import AutoGPTQForCausalLM
model_wrapper = AutoGPTQForCausalLM.from_quantized(
- self._cfg.model.model_path, device="cuda:0", use_triton=True
+ cfg.model.model_path,
+ device=device,
+ use_triton=use_triton,
)
model = model_wrapper.model
- elif self._cfg.model.get("load_in_8bit", False):
+ elif load_in_8bit:
model = AutoModelForCausalLM.from_pretrained(
- self._cfg.model.model_path,
- device_map=self._cfg.model.get("device_map", "auto"),
+ cfg.model.model_path,
+ config=model_config,
load_in_8bit=True,
)
else:
- # default load in float16
- model = AutoModelForCausalLM.from_pretrained(
- self._cfg.model.model_path,
+ if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():
+ auto_model_class = AutoModelForVision2Seq
+ else:
+ auto_model_class = AutoModelForCausalLM
+
+ model = auto_model_class.from_pretrained(
+ cfg.model.model_path,
torch_dtype=self.torch_dtype,
- device_map=self._cfg.model.get("device_map", "auto"),
+ config=model_config,
trust_remote_code=True,
- use_safetensors=self._cfg.model.get("use_safetensors", False),
)
- if torch.cuda.is_available():
- model = model.cuda()
- if self.torch_dtype == torch.float16:
- model = model.half()
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
return model
def setup_model_and_optimizer(self):
@@ -108,12 +130,19 @@ def setup_model_and_optimizer(self):
self.model = FSDP(
module,
param_init_fn=init_fn,
- use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=int(os.environ["LOCAL_RANK"]),
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
+ forward_prefetch=self._cfg.fsdp.forward_prefetch,
+ backward_prefetch=(
+ BackwardPrefetch.BACKWARD_PRE
+ if self._cfg.fsdp.backward_prefetch
+ else None
+ ),
+ limit_all_gathers=self._cfg.fsdp.limit_all_gathers,
+ use_orig_params=self._cfg.fsdp.use_orig_params,
)
# NOTE: Currently we assume that only the value head contains "value_head" in its name.
@@ -130,7 +159,7 @@ def setup_model_and_optimizer(self):
},
]
- if self._cfg.model.vh_mode in ["a", "a0", "a6"]:
+ if self._cfg.model.get("vh_mode", None) in ["a", "a0", "a6"]:
param_groups.append(
{
"params": [
diff --git a/rlinf/hybrid_engines/fsdp/utils.py b/rlinf/hybrid_engines/fsdp/utils.py
index 2334b1fa7..0a9f0054d 100644
--- a/rlinf/hybrid_engines/fsdp/utils.py
+++ b/rlinf/hybrid_engines/fsdp/utils.py
@@ -58,7 +58,7 @@ def cpu_init_weights():
return init_context
-def get_fsdp_wrap_policy(module, config=None, is_lora=False):
+def get_fsdp_wrap_policy(module, config=None, is_lora=False, is_vla_model=False):
"""
FSDP wrap policy that handles both standard transformer models and VLA models.
@@ -76,11 +76,8 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False):
if config.get("disable", False):
return None
- # Check if this is a VLA model by looking for language_model attribute
- is_vla_model = hasattr(module, "language_model")
-
# Get transformer layer classes to wrap
- if is_vla_model:
+ if hasattr(module, "language_model"):
# For VLA models, get transformer classes from language_model submodule
default_transformer_cls_names_to_wrap = getattr(
module.language_model, "_no_split_modules", None
diff --git a/rlinf/hybrid_engines/megatron/megatron_model_manager.py b/rlinf/hybrid_engines/megatron/megatron_model_manager.py
index 57f34dbf5..6fe4fbfd6 100644
--- a/rlinf/hybrid_engines/megatron/megatron_model_manager.py
+++ b/rlinf/hybrid_engines/megatron/megatron_model_manager.py
@@ -184,6 +184,7 @@ def model_provider_func(self, pre_process, post_process):
return model
def optimizer_step(self, increment):
+ clear_memory()
success, grad_norm, num_zeros_in_grad = self.optimizer.step()
self.lr_scheduler.step(increment=increment)
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py
index 509c56bab..87f736798 100644
--- a/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py
+++ b/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py
@@ -108,10 +108,13 @@ def __init__(
placement
)[(self.get_parent_rank(), self._rank)]
+ use_presharded_weights = (
+ False if self.cfg.actor.training_backend == "fsdp" else True
+ )
# it's important to use load_weight to load resharded weight from megatron
for _, module in self.tp_worker.worker.model_runner.model.named_modules():
if hasattr(module, "use_presharded_weights"):
- module.use_presharded_weights = True
+ module.use_presharded_weights = use_presharded_weights
self._logger.info(
f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}"
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py
index ef503527b..9a69b8548 100644
--- a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py
+++ b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py
@@ -110,10 +110,13 @@ def __init__(
self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map(
placement
)[(self.get_parent_rank(), self._rank)]
+ use_presharded_weights = (
+ False if self.cfg.actor.training_backend == "fsdp" else True
+ )
# it's important to use load_weight to load resharded weight from megatron
for _, module in self.tp_worker.worker.model_runner.model.named_modules():
if hasattr(module, "use_presharded_weights"):
- module.use_presharded_weights = True
+ module.use_presharded_weights = use_presharded_weights
self._logger.info(
f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}"
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py
index a7057a161..1f9beb409 100644
--- a/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py
+++ b/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py
@@ -112,10 +112,13 @@ def __init__(
self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map(
placement
)[(self.get_parent_rank(), self._rank)]
+ use_presharded_weights = (
+ False if self.cfg.actor.training_backend == "fsdp" else True
+ )
# it's important to use load_weight to load resharded weight from megatron
for _, module in self.tp_worker.worker.model_runner.model.named_modules():
if hasattr(module, "use_presharded_weights"):
- module.use_presharded_weights = True
+ module.use_presharded_weights = use_presharded_weights
self._logger.info(
f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}"
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py
new file mode 100644
index 000000000..5b365ea1e
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py
new file mode 100644
index 000000000..960d40eb0
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py
@@ -0,0 +1,59 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+
+@dataclass
+class TaskMethodInput:
+ method_name: str
+ args: List[Any] = field(default_factory=list)
+ kwargs: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class TaskMethodOutput:
+ method_name: str
+ result: Optional[Any] = None
+
+
+@dataclass
+class OffloadReqInput:
+ pass
+
+
+@dataclass
+class OffloadReqOutput:
+ pass
+
+
+@dataclass
+class SyncWeightInput:
+ pass
+
+
+@dataclass
+class SyncWeightOutput:
+ pass
+
+
+@dataclass
+class SyncHFWeightInput:
+ pass
+
+
+@dataclass
+class SyncHFWeightOutput:
+ pass
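+
+
+# Each *Input dataclass above is sent from the Engine / TokenizerManager to the
+# Scheduler, which replies with the matching *Output dataclass; the empty
+# payloads act as typed acknowledgements.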
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py
new file mode 100644
index 000000000..e8e05c88b
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py
@@ -0,0 +1,363 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import atexit
+import logging
+import multiprocessing as mp
+import os
+import random
+import signal
+import threading
+import time
+from typing import Dict, Optional, Tuple
+
+import uvloop
+import zmq
+from omegaconf import DictConfig
+from sglang.srt.entrypoints.engine import Engine as _Engine
+from sglang.srt.managers.data_parallel_controller import (
+ run_data_parallel_controller_process,
+)
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
+
+# from sglang.srt.managers.scheduler import run_scheduler_process
+# from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.managers.template_manager import TemplateManager
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import (
+ assert_pkg_version,
+ configure_logger,
+ get_bool_env_var,
+ get_zmq_socket,
+ is_cuda,
+ kill_process_tree,
+ launch_dummy_health_check_server,
+ prepare_model_and_tokenizer,
+ set_prometheus_multiproc_dir,
+ set_ulimit,
+)
+
+from rlinf.scheduler import WorkerAddress
+from rlinf.utils.placement import ComponentPlacement
+
+from .io_struct import OffloadReqInput, SyncHFWeightInput, SyncWeightInput
+from .sgl_scheduler import run_scheduler_process
+from .tokenizer_manager import TokenizerManager
+
+# Fix a bug of Python threading
+setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
+
+logger = logging.getLogger(__name__)
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+_is_cuda = is_cuda()
+
+
+class Engine(_Engine):
+ def __init__(
+ self,
+ parent_address: WorkerAddress,
+ placement: ComponentPlacement,
+ config: DictConfig,
+ dp_rank: int,
+ **kwargs,
+ ):
+ """
+        The arguments of this function are the same as `sglang/srt/server_args.py::ServerArgs`.
+ Please refer to `ServerArgs` for the documentation.
+ """
+ if "server_args" in kwargs:
+ # Directly load server_args
+ server_args = kwargs["server_args"]
+ else:
+ # Construct server_args from kwargs
+ if "log_level" not in kwargs:
+ # Do not print logs by default
+ kwargs["log_level"] = "error"
+ server_args = ServerArgs(**kwargs)
+
+ # Shutdown the subprocesses automatically when the program exits
+ atexit.register(self.shutdown)
+
+ # Allocate ports for inter-process communications
+ self.port_args = PortArgs.init_new(server_args)
+
+ # Launch subprocesses
+ tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+ parent_address=parent_address,
+ placement=placement,
+ config=config,
+ dp_rank=dp_rank,
+ server_args=server_args,
+ port_args=self.port_args,
+ )
+
+ self.server_args = server_args
+ self.tokenizer_manager = tokenizer_manager
+ self.scheduler_info = scheduler_info
+ self.template_manager = template_manager
+
+ context = zmq.Context(2)
+ self.send_to_rpc = get_zmq_socket(
+ context, zmq.DEALER, self.port_args.rpc_ipc_name, True
+ )
+
+ def offload_model_weights(self):
+ """Offload model weights to meta."""
+ obj = OffloadReqInput()
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.offload_model_weights(obj, None)
+ )
+
+ def sync_hf_weight(self):
+ obj = SyncHFWeightInput()
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(self.tokenizer_manager.sync_hf_weight(obj))
+
+ def sync_weight(self):
+ obj = SyncWeightInput()
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(self.tokenizer_manager.sync_weight(obj))
+
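+# Minimal usage sketch (illustrative only; the extra keyword arguments are
+# forwarded to sglang's ServerArgs, and the variable names here are placeholders):
+#
+#   engine = Engine(
+#       parent_address=parent_address,
+#       placement=placement,
+#       config=cfg,
+#       dp_rank=0,
+#       model_path=cfg.rollout.model_dir,
+#       tp_size=cfg.rollout.tensor_parallel_size,
+#   )
+#   engine.offload_model_weights()  # drop weights / KV cache before training
+#   engine.sync_weight()            # reload weights synced from the actor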
+
+def _set_envs_and_config(server_args: ServerArgs):
+ # Set global environments
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
+ if not server_args.enable_symm_mem:
+ os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
+ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
+ os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+ # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+ os.environ["TRTLLM_ENABLE_PDL"] = "1"
+
+ # Can also be passed as argument
+ os.environ["SGLANG_RUN_ID"] = (
+ f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
+ )
+
+ # Set prometheus env vars
+ if server_args.enable_metrics:
+ set_prometheus_multiproc_dir()
+
+ # Set ulimit
+ set_ulimit()
+
+ # Check flashinfer version
+ if server_args.attention_backend == "flashinfer":
+ assert_pkg_version(
+ "flashinfer_python",
+ "0.3.0",
+ "Please uninstall the old version and "
+ "reinstall the latest version by following the instructions "
+ "at https://docs.flashinfer.ai/installation.html.",
+ )
+ if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
+ assert_pkg_version(
+ "sgl-kernel",
+ "0.3.8",
+ "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+ )
+
+ if True: # Keep this check for internal code compatibility
+ # Register the signal handler.
+ # The child processes will send SIGQUIT to this process when any error happens
+ # This process then clean up the whole process tree
+ # Note: This sigquit handler is used in the launch phase, and may be replaced by
+ # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
+ def launch_phase_sigquit_handler(signum, frame):
+ logger.error(
+ "Received sigquit from a child process. It usually means the child failed."
+ )
+ kill_process_tree(os.getpid())
+
+ signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
+
+ # Set mp start method
+ mp.set_start_method("spawn", force=True)
+
+
+def _launch_subprocesses(
+ parent_address: WorkerAddress,
+ placement: ComponentPlacement,
+ config: DictConfig,
+ dp_rank: int,
+ server_args: ServerArgs,
+ port_args: Optional[PortArgs] = None,
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
+ """
+ Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
+ """
+
+ assert server_args.pp_size == 1, (
+ "RLinf currently only supports and validates pp_size=1."
+ )
+
+ # Configure global environment
+ configure_logger(server_args)
+ server_args.check_server_args()
+ _set_envs_and_config(server_args)
+
+ # Allocate ports for inter-process communications
+ if port_args is None:
+ port_args = PortArgs.init_new(server_args)
+ logger.info(f"{server_args=}")
+
+ # If using model from www.modelscope.cn, first download the model.
+ server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+ server_args.model_path, server_args.tokenizer_path
+ )
+
+ scheduler_procs = []
+ if server_args.dp_size == 1:
+ memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=server_args.enable_memory_saver
+ )
+ scheduler_pipe_readers = []
+
+ nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+ tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
+ tp_rank_range = range(
+ tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+ tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+ )
+
+ pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+ pp_rank_range = range(
+ pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+ pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+ )
+
+ for pp_rank in pp_rank_range:
+ for tp_rank in tp_rank_range:
+ reader, writer = mp.Pipe(duplex=False)
+ gpu_id = (
+ server_args.base_gpu_id
+ + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+ )
+ moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
+ proc = mp.Process(
+ target=run_scheduler_process,
+ args=(
+ parent_address,
+ placement,
+ config,
+ server_args.tp_size * server_args.pp_size,
+ tp_rank + pp_rank * server_args.pp_size,
+ server_args,
+ port_args,
+ gpu_id,
+ tp_rank,
+ moe_ep_rank,
+ pp_rank,
+ None,
+ writer,
+ None,
+ ),
+ )
+
+ with memory_saver_adapter.configure_subprocess():
+ proc.start()
+ scheduler_procs.append(proc)
+ scheduler_pipe_readers.append(reader)
+ else:
+ # Launch the data parallel controller
+ reader, writer = mp.Pipe(duplex=False)
+ scheduler_pipe_readers = [reader]
+ proc = mp.Process(
+ target=run_data_parallel_controller_process,
+ args=(server_args, port_args, writer),
+ )
+ proc.start()
+ scheduler_procs.append(proc)
+
+ if server_args.node_rank >= 1:
+ # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
+ # so they can just wait here.
+
+ for reader in scheduler_pipe_readers:
+ data = reader.recv()
+ assert data["status"] == "ready"
+
+ if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
+ # When using `Engine` as a Python API, we don't want to block here.
+ return None, None, None
+
+ launch_dummy_health_check_server(
+ server_args.host, server_args.port, server_args.enable_metrics
+ )
+
+ for proc in scheduler_procs:
+ proc.join()
+ logger.error(
+ f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+ )
+ return None, None, None
+
+ # Launch detokenizer process
+ detoken_proc = mp.Process(
+ target=run_detokenizer_process,
+ args=(
+ server_args,
+ port_args,
+ ),
+ )
+ detoken_proc.start()
+ if server_args.tokenizer_worker_num > 1:
+ # Launch multi-tokenizer router
+ tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
+
+ # Initialize templates
+ template_manager = None
+ else:
+ # Launch tokenizer process
+ tokenizer_manager = TokenizerManager(server_args, port_args)
+
+ # Initialize templates
+ template_manager = TemplateManager()
+ template_manager.initialize_templates(
+ tokenizer_manager=tokenizer_manager,
+ model_path=server_args.model_path,
+ chat_template=server_args.chat_template,
+ completion_template=server_args.completion_template,
+ )
+
+ # Wait for the model to finish loading
+ scheduler_infos = []
+ for i in range(len(scheduler_pipe_readers)):
+ try:
+ data = scheduler_pipe_readers[i].recv()
+ except EOFError:
+ logger.error(
+ f"Rank {i} scheduler is dead. Please check if there are relevant logs."
+ )
+ scheduler_procs[i].join()
+ logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
+ raise
+
+ if data["status"] != "ready":
+ raise RuntimeError(
+ "Initialization failed. Please see the error messages above."
+ )
+ scheduler_infos.append(data)
+
+ # Assume all schedulers have the same scheduler_info
+ scheduler_info = scheduler_infos[0]
+ tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
+ return tokenizer_manager, template_manager, scheduler_info
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py
new file mode 100644
index 000000000..f6aac9d27
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py
@@ -0,0 +1,476 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import faulthandler
+import logging
+import os
+import signal
+from typing import Optional
+
+import psutil
+import setproctitle
+import torch
+from omegaconf import DictConfig
+from sglang.srt.disaggregation.utils import (
+ DisaggregationMode,
+)
+from sglang.srt.managers.io_struct import (
+ ReleaseMemoryOccupationReqInput,
+ ResumeMemoryOccupationReqInput,
+)
+from sglang.srt.managers.scheduler import Scheduler as _Scheduler
+from sglang.srt.managers.scheduler import logger
+from sglang.srt.managers.utils import DPBalanceMeta
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+ broadcast_pyobj,
+ configure_logger,
+ get_bool_env_var,
+ kill_itself_when_parent_died,
+ set_gpu_proc_affinity,
+ suppress_other_loggers,
+)
+from sglang.utils import get_exception_traceback
+
+from rlinf.scheduler import Worker, WorkerAddress
+from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode
+from rlinf.workers.rollout.utils import (
+ RankMapper,
+ get_module_from_name,
+ rebind_param_attr,
+ swap_tensor_pointer,
+)
+
+from .io_struct import (
+ OffloadReqInput,
+ OffloadReqOutput,
+ SyncHFWeightInput,
+ SyncHFWeightOutput,
+ SyncWeightInput,
+ SyncWeightOutput,
+ TaskMethodInput,
+ TaskMethodOutput,
+)
+import torch_npu
+
+logger.setLevel(logging.WARNING)
+
+
+def safe_load_weights(model, weights: list):
+    """
+    Safely load weights while supporting both parameter naming conventions:
+    - 'visual.xxx' (Hugging Face Qwen-VL format)
+    - 'model.visual.xxx' (some SGLang or custom formats)
+
+    Parameters:
+        model: PyTorch model instance (must provide named_parameters())
+        weights: List of (name, torch.Tensor)
+    """
+    params_dict = dict(model.named_parameters())
+
+    # Build a mapping from normalized key to actual parameter name.
+    # For example, 'visual.patch_embed.proj.weight' may appear in params_dict
+    # as either 'visual...' or 'model.visual...'.
+    normalized_to_actual = {}
+    for param_name in params_dict.keys():
+        if param_name.startswith("model.visual."):
+            # Map to the name without the 'model.' prefix
+            normalized = param_name[len("model."):]  # "visual.xxx"
+            normalized_to_actual[normalized] = param_name
+            normalized_to_actual[param_name] = param_name  # keep the original name too
+        elif param_name.startswith("visual."):
+            normalized = param_name
+            normalized_to_actual[normalized] = param_name
+            normalized_to_actual["model." + normalized] = param_name  # accept 'model.'-prefixed input
+        else:
+            # Non-visual parameters map directly
+            normalized_to_actual[param_name] = param_name
+
+    # Load each weight
+    for name, loaded_weight in weights:
+        if name in normalized_to_actual:
+            actual_name = normalized_to_actual[name]
+            param = params_dict[actual_name]
+            assert param.shape == loaded_weight.shape, (
+                f"Shape mismatch for {name}: expected {param.shape}, got {loaded_weight.shape}"
+            )
+            param.data.copy_(loaded_weight)
+        else:
+            # Skip weights that are not model parameters (e.g. optimizer state)
+            logger.warning(f"Skipping weight not in model: {name}")
+            continue
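+
+
+# Illustrative call (names are placeholders): load a weight that may be stored
+# under either naming convention without knowing which one the model uses.
+#   safe_load_weights(model_runner.model, [("visual.patch_embed.proj.weight", w)])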
+
+class Scheduler(_Scheduler, Worker):
+ """
+ Overridden class of SGLang's TP worker class _Scheduler.
+ A Scheduler is a Task that manages the TP worker, and performs necessary weight synchronization with actor and weight offloading.
+ """
+
+ def __init__(
+ self,
+ parent_address: WorkerAddress,
+ placement: ModelParallelComponentPlacement,
+ config: DictConfig,
+ world_size: int,
+ rank: int,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ gpu_id: int,
+ tp_rank: int,
+ moe_ep_rank: int,
+ pp_rank: int,
+ dp_rank: Optional[int],
+ dp_balance_meta: Optional[DPBalanceMeta] = None,
+ ):
+ Worker.__init__(
+ self, parent_address=parent_address, world_size=world_size, rank=rank
+ )
+
+ # since 0.4.6.post2, pp_rank is added into Scheduler init's parameters
+ # but we don't use it in our implementation, so we set it to 0
+ _Scheduler.__init__(
+ self,
+ server_args,
+ port_args,
+ gpu_id,
+ tp_rank,
+ moe_ep_rank,
+ pp_rank,
+ dp_rank,
+ dp_balance_meta,
+ )
+ # `TpModelWorkerClient` is used when ServerArgs.enable_overlap=True, and it has 'worker' attribute.
+ # But in early SGLang version, `TpModelWorker` doesn't have 'worker' attribute.
+ if not hasattr(self.tp_worker, "worker"):
+ self.tp_worker.worker = self.tp_worker
+
+ self._request_dispatcher._mapping.extend(
+ [
+ (TaskMethodInput, self.run_task_method),
+ (OffloadReqInput, self.offload_model_weights),
+ (SyncWeightInput, self.sync_weight),
+ (SyncHFWeightInput, self.sync_hf_weight),
+ ]
+ )
+ self.cfg = config
+ self.binded_attr = {}
+
+ self._actor_group_name = self.cfg.actor.group_name
+ self.placement_mode = placement.placement_mode
+ self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map(
+ placement
+ )[(self.get_parent_rank(), self._rank)]
+ # it's important to use load_weight to load resharded weight from megatron
+ for _, module in self.tp_worker.worker.model_runner.model.named_modules():
+ if hasattr(module, "use_presharded_weights"):
+ module.use_presharded_weights = False
+
+ self._logger.info(
+ f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}"
+ )
+
+ def sync_in_tp(self, fn: str = ""):
+ broadcast_pyobj(
+ [], self.tp_rank, self.tp_worker.worker.model_runner.tp_group.cpu_group
+ )
+ # logger.info(f"{fn}: Sync in tp success!")
+
+ def cuda_info(self, text: str = ""):
+ free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
+ free_gpu_memory /= 2**30
+ total_gpu_memory /= 2**30
+
+ memory_allocated = torch.npu.memory_allocated() / 2**30
+ memory_reserved = torch.npu.memory_reserved() / 2**30
+
+ self._logger.info(
+ f"[dp {self.get_parent_rank()}-tp {self.tp_rank}] {text} "
+ f"{memory_allocated=:.2f} GiB, {memory_reserved=:.2f} GiB, "
+ f"{free_gpu_memory=:.2f} GiB, {total_gpu_memory=:.2f} GiB"
+ )
+
+ def offload_model_weights(self, recv_req: OffloadReqInput):
+ use_cudagraph = not self.cfg.rollout.enforce_eager
+ colocate = self.placement_mode == PlacementMode.COLLOCATED
+ if not colocate:
+ assert use_cudagraph, "If not colocate, use_cudagraph must be True now."
+
+ if use_cudagraph or not colocate:
+ self.release_memory_occupation(ReleaseMemoryOccupationReqInput())
+ # self.cuda_info("After offload Model weights and kv cache")
+ return OffloadReqOutput()
+
+ # manually offload
+ self.named_buffers = {
+ n: buf.clone()
+ for n, buf in self.tp_worker.worker.model_runner.model.named_buffers()
+ }
+
+ self.binded_attr = {
+ name: param.__dict__
+ for name, param in self.tp_worker.worker.model_runner.model.named_parameters()
+ }
+
+ # offload parameters
+ self.tp_worker.worker.model_runner.model.to("meta")
+
+ # offload kv cache
+ self.tp_worker.worker.model_runner.token_to_kv_pool._clear_buffers()
+
+ self.flush_cache()
+ self.sync_in_tp("offload_model_weights")
+ # self.cuda_info("After offload Model weights and kv cache")
+ return OffloadReqOutput()
+
+ def sync_hf_weight(self, recv_req: SyncHFWeightInput):
+ use_cudagraph = not self.cfg.rollout.enforce_eager
+ colocate = self.placement_mode == PlacementMode.COLLOCATED
+
+ assert use_cudagraph, "use_cudagraph must be True now."
+
+ state_dict = self.recv(
+ src_group_name=self._actor_group_name,
+ src_rank=self.actor_weight_rank,
+ )
+
+ model = self.tp_worker.worker.model_runner.model
+
+ if colocate:
+ self.resume_memory_occupation(ResumeMemoryOccupationReqInput())
+            import inspect
+
+            for name, handle in state_dict.items():
+                # Each entry is a (rebuild_func, args) handle for a shared tensor.
+                # Rebuild it with keyword arguments so the device argument can be
+                # redirected to this process's current NPU, since the sender may
+                # have a different visible-device mapping.
+                func, args = handle
+                sig = inspect.signature(func)
+                param_names = list(sig.parameters.keys())
+
+                # Convert the positional args into kwargs
+                kwargs = {}
+                args = list(args)
+                for i, param_name in enumerate(param_names):
+                    if i < len(args):
+                        kwargs[param_name] = args[i]
+                    else:
+                        break
+
+                # Rewrite the device argument (assumed to be named 'map_location' or 'device')
+                if "map_location" in kwargs:
+                    kwargs["map_location"] = f"npu:{torch.npu.current_device()}"
+                elif "device" in kwargs:
+                    kwargs["device"] = torch.npu.current_device()
+
+                new_weight = func(**kwargs)
+                model.load_weights([(name, new_weight)])
+                del new_weight
+ else:
+ # disaggregate mode, recv tensor directly
+ for name, tensor in state_dict.items():
+ model.load_weights([(name, tensor)])
+ self.flush_cache()
+ self.sync_in_tp("sync_hf_weight")
+ return SyncHFWeightOutput()
+
+ def sync_weight(self, recv_req: SyncWeightInput):
+ use_cudagraph = not self.cfg.rollout.enforce_eager
+ colocate = self.placement_mode == PlacementMode.COLLOCATED
+ if not colocate:
+ assert use_cudagraph, "If not colocate, use_cudagraph must be True now."
+
+ state_dict = self.recv(
+ src_group_name=self._actor_group_name,
+ src_rank=self.actor_weight_rank,
+ )
+ model = self.tp_worker.worker.model_runner.model
+
+ if use_cudagraph and colocate:
+ self.resume_memory_occupation(ResumeMemoryOccupationReqInput())
+
+ if colocate:
+ if use_cudagraph:
+ for name, handle in state_dict.items():
+ func, args = handle
+ list_args = list(args)
+ # NOTE: the key is to change device id to the current device id
+ # in case two processes have different CUDA_VISIBLE_DEVICES
+ list_args[6] = torch.npu.current_device()
+ new_weight = func(*list_args)
+
+ self.tp_worker.worker.model_runner.update_weights_from_tensor(
+ [(name, new_weight)], load_format="direct"
+ )
+ del new_weight
+
+ else:
+ named_params = dict(model.named_parameters())
+ for name, handle in state_dict.items():
+ rebind_param_attr(model, name, self.binded_attr, materialize=False)
+ func, args = handle
+ list_args = list(args)
+ list_args[6] = torch.npu.current_device()
+ new_weight = func(*list_args)
+ vllm_weight = named_params[name]
+ assert vllm_weight.shape == new_weight.shape, (
+ f"{name}: {vllm_weight.shape=}, {new_weight.shape=}"
+ )
+ assert vllm_weight.dtype == new_weight.dtype, (
+ f"{name}: {vllm_weight.dtype=}, {new_weight.dtype=}"
+ )
+
+ swap_tensor_pointer(vllm_weight, new_weight)
+ del new_weight
+
+ for name, buffer in self.named_buffers.items():
+ vllm_buffer = get_module_from_name(model, name)
+ assert vllm_buffer.shape == buffer.shape
+ assert vllm_buffer.dtype == buffer.dtype
+ swap_tensor_pointer(vllm_buffer, buffer)
+
+ self.named_buffers = {}
+
+ self.tp_worker.worker.model_runner.token_to_kv_pool._create_buffers()
+ else:
+ # disaggregate mode, recv tensor directly
+ named_tensors = [(n, p) for n, p in state_dict.items()]
+ self.tp_worker.worker.model_runner.update_weights_from_tensor(
+ named_tensors, load_format="direct"
+ )
+ self.sync_in_tp("sync_weight")
+
+ return SyncWeightOutput()
+
+ def run_task_method(self, obj: TaskMethodInput):
+ """
+ Run a CommTask method with the given name and arguments.
+ NOTE: will call wait() if async_op is True.
+ """
+ result = getattr(self, obj.method_name)(*obj.args, **obj.kwargs)
+ if "async_op" in obj.kwargs and obj.kwargs["async_op"]:
+ result = result.wait()
+ return TaskMethodOutput(method_name=obj.method_name, result=result)
+
+
+def run_scheduler_process(
+ parent_address: WorkerAddress,
+ placement: ModelParallelComponentPlacement,
+ config: DictConfig,
+ world_size: int,
+ rank: int,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ gpu_id: int,
+ tp_rank: int,
+ moe_ep_rank: int,
+ pp_rank: int,
+ dp_rank: Optional[int],
+ pipe_writer,
+ balance_meta: Optional[DPBalanceMeta] = None,
+):
+ # Generate the prefix
+ prefix = ""
+ if dp_rank is not None:
+ prefix += f" DP{dp_rank}"
+ if server_args.tp_size > 1:
+ prefix += f" TP{tp_rank}"
+ if server_args.ep_size > 1:
+ prefix += f" EP{moe_ep_rank}"
+ if server_args.pp_size > 1:
+ prefix += f" PP{pp_rank}"
+
+ # Config the process
+ setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
+ faulthandler.enable()
+ kill_itself_when_parent_died()
+ parent_process = psutil.Process().parent()
+
+ # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+ if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+ dp_rank = int(os.environ["SGLANG_DP_RANK"])
+
+ # Configure the logger
+ configure_logger(server_args, prefix=prefix)
+ suppress_other_loggers()
+
+ # Set cpu affinity to this gpu process
+ if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+ set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
+ # Create a scheduler and run the event loop
+ try:
+ scheduler = Scheduler(
+ parent_address,
+ placement,
+ config,
+ world_size,
+ rank,
+ server_args,
+ port_args,
+ gpu_id,
+ tp_rank,
+ moe_ep_rank,
+ pp_rank,
+ dp_rank,
+ dp_balance_meta=balance_meta,
+ )
+ pipe_writer.send(
+ {
+ "status": "ready",
+ "max_total_num_tokens": scheduler.max_total_num_tokens,
+ "max_req_input_len": scheduler.max_req_input_len,
+ }
+ )
+
+ disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+ if disaggregation_mode == DisaggregationMode.NULL:
+ if server_args.pp_size > 1:
+ scheduler.event_loop_pp()
+ elif scheduler.enable_overlap:
+ scheduler.event_loop_overlap()
+ else:
+ scheduler.event_loop_normal()
+ elif disaggregation_mode == DisaggregationMode.PREFILL:
+ if scheduler.enable_overlap:
+ scheduler.event_loop_overlap_disagg_prefill()
+ else:
+ if server_args.pp_size > 1:
+ scheduler.event_loop_pp_disagg_prefill()
+ else:
+ scheduler.event_loop_normal_disagg_prefill()
+
+ elif disaggregation_mode == DisaggregationMode.DECODE:
+ if scheduler.enable_overlap:
+ scheduler.event_loop_overlap_disagg_decode()
+ else:
+ scheduler.event_loop_normal_disagg_decode()
+
+ except Exception:
+ traceback = get_exception_traceback()
+ logger.error(f"Scheduler hit an exception: {traceback}")
+ parent_process.send_signal(signal.SIGQUIT)
diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py
new file mode 100644
index 000000000..07b77b200
--- /dev/null
+++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py
@@ -0,0 +1,129 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import fastapi
+from sglang.srt.managers.io_struct import AbortReq
+from sglang.srt.managers.tokenizer_manager import TokenizerManager as _TokenizerManager
+# In sglang 0.5.2, _Communicator is provided by tokenizer_communicator_mixin
+# rather than tokenizer_manager.
+from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+from sglang.srt.server_args import PortArgs, ServerArgs
+from .io_struct import (
+ OffloadReqInput,
+ OffloadReqOutput,
+ SyncHFWeightInput,
+ SyncHFWeightOutput,
+ SyncWeightInput,
+ SyncWeightOutput,
+ TaskMethodInput,
+ TaskMethodOutput,
+)
+
+
+# Add two methods and their communicators, input/output structs.
+class TokenizerManager(_TokenizerManager):
+ def __init__(
+ self,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ ):
+ super().__init__(
+ server_args=server_args,
+ port_args=port_args,
+ )
+
+ self.run_task_method_communicator = _Communicator(
+ self.send_to_scheduler,
+ fan_out=server_args.dp_size,
+ )
+ self.offload_model_weights_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
+ self.sync_weight_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
+ self.sync_hf_weight_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
+
+ self._result_dispatcher._mapping.extend(
+ [
+ (
+ TaskMethodOutput,
+ self.run_task_method_communicator.handle_recv,
+ ),
+ (
+ OffloadReqOutput,
+ self.offload_model_weights_communicator.handle_recv,
+ ),
+ (
+ SyncWeightOutput,
+ self.sync_weight_communicator.handle_recv,
+ ),
+ (
+ SyncHFWeightOutput,
+ self.sync_hf_weight_communicator.handle_recv,
+ ),
+ ]
+ )
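+        # The Scheduler registers the mirror mapping (each *Input to its handler)
+        # in its _request_dispatcher, so an await on a communicator above resolves
+        # once all dp_size schedulers have replied with the matching *Output.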
+
+ async def run_task_method(
+ self,
+ obj: TaskMethodInput = None,
+ request: Optional[fastapi.Request] = None,
+ ):
+ """
+ Run a task method with the given name and arguments.
+ """
+ self.auto_create_handle_loop()
+ if isinstance(obj, str):
+ obj = TaskMethodInput(method_name=obj)
+ res: List[TaskMethodOutput] = await self.run_task_method_communicator(obj)
+ return res[0].result
+
+ async def offload_model_weights(
+ self,
+ obj: OffloadReqInput = None,
+ request: Optional[fastapi.Request] = None,
+ ):
+ self.auto_create_handle_loop()
+ if obj is None:
+ obj = OffloadReqInput()
+ await self.offload_model_weights_communicator(obj)
+
+ async def sync_hf_weight(
+ self,
+ obj: SyncHFWeightInput,
+ request: Optional[fastapi.Request] = None,
+ ):
+ self.auto_create_handle_loop()
+ await self.sync_hf_weight_communicator(obj)
+
+ async def sync_weight(
+ self,
+ obj: SyncWeightInput,
+ request: Optional[fastapi.Request] = None,
+ ):
+ self.auto_create_handle_loop()
+ await self.sync_weight_communicator(obj)
+
+ def abort_request(self, rid: str):
+ if rid != "" and rid not in self.rid_to_state:
+ return
+ req = AbortReq(rid)
+ self.send_to_scheduler.send_pyobj(req)
+
+ async def pause_generation(self):
+ self.abort_request("")
diff --git a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py
index 3e021e922..519895e49 100644
--- a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py
+++ b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py
@@ -48,6 +48,9 @@ def __init__(
)
# rlinf specific
self.rlinf_config = rlinf_config
+ self.using_sharded_weight = (
+ False if self.rlinf_config.actor.training_backend == "fsdp" else True
+ )
self._rlinf_worker = _RLinfWorker(
parent_address=parent_address,
world_size=vllm_config.parallel_config.world_size,
@@ -103,7 +106,7 @@ def sync_hf_weight(self) -> None:
def use_sharded_weights(self) -> None:
model = self.model_runner.model
for _, param in model.named_parameters():
- setattr(param, "is_sharded_weight", True)
+ setattr(param, "is_sharded_weight", self.using_sharded_weight)
def get_dp_rank(self) -> int:
return self._rlinf_worker.get_parent_rank()
diff --git a/rlinf/models/__init__.py b/rlinf/models/__init__.py
index 08207900a..617ac7467 100644
--- a/rlinf/models/__init__.py
+++ b/rlinf/models/__init__.py
@@ -17,7 +17,6 @@
import torch
from omegaconf import DictConfig
-from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
AutoConfig,
AutoImageProcessor,
@@ -172,6 +171,8 @@ def get_model(model_path, cfg: DictConfig, override_config_kwargs=None):
model = model.cuda()
if cfg.is_lora:
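+ # Lazy import: peft is only required when LoRA is enabled.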
+ from peft import LoraConfig, PeftModel, get_peft_model
+
if not hasattr(cfg, "lora_path") or cfg.lora_path is None:
lora_config = LoraConfig(
r=cfg.lora_rank,
diff --git a/rlinf/models/embodiment/model_utils.py b/rlinf/models/embodiment/model_utils.py
index 04425cfdc..8e7aebeb1 100644
--- a/rlinf/models/embodiment/model_utils.py
+++ b/rlinf/models/embodiment/model_utils.py
@@ -15,9 +15,10 @@
from typing import Any, Optional
import torch
-import torch.nn.functional as F
from transformers.generation import TopKLogitsWarper
+from rlinf.utils.utils import compute_entropy_from_logits, compute_logprobs_from_logits
+
def default_logits_processor(logits, action_tokens, vocab_size, n_action_bins):
logits = logits.permute(0, 2, 1) # [B, vocab-size, action-dim]
@@ -34,28 +35,6 @@ def default_logits_processor(logits, action_tokens, vocab_size, n_action_bins):
return ret
-def compute_logprobs_from_logits(logits, target):
- logprobs = -F.cross_entropy(
- logits, target=target, reduction="none"
- ) # [B, action-dim]
- return logprobs
-
-
-def compute_entropy_from_logits(logits, epsilon=1e-10):
- """
- Compute entropy by logits.
-
- Args:
- logits: [B, vocab-size, seq-len]
- Returns:
- entropy: [B, seq-len]
- """
- all_probs = F.softmax(logits, dim=1) # [B, vocab-size, seq-len]
- all_log_probs = torch.log(all_probs + epsilon)
- entropy = -torch.sum(all_probs * all_log_probs, dim=1) # [B, seq-len]
- return entropy
-
-
def custom_forward(
model,
input_ids,
diff --git a/rlinf/runners/coding_online_rl_runner.py b/rlinf/runners/coding_online_rl_runner.py
index 46be3423a..377667523 100644
--- a/rlinf/runners/coding_online_rl_runner.py
+++ b/rlinf/runners/coding_online_rl_runner.py
@@ -212,6 +212,7 @@ def run(self):
infer_handle: Handle = self.inference.run_inference(
input_channel=self.dataloader_channel,
output_channel=self.inference_channel,
+ rollout_channel=None,
compute_ref_logprobs=self.compute_ref_logprobs,
)
inference_channel = self.inference_channel
@@ -221,16 +222,12 @@ def run(self):
# Advantages and returns
adv_handle: Handle = self.actor.compute_advantages_and_returns(
- input_channel=self.inference_channel,
+ input_channel=inference_channel,
output_channel=self.actor_channel,
)
# Actor training
actor_input_channel = self.actor_channel
- if self.is_pipeline:
- # In pipeline mode, the rollout already contains the advantages and returns
- # So the above two steps are in fact no-ops, and we should directly use the inference channel as the input
- actor_input_channel = inference_channel
actor_handle: Handle = self.actor.run_training(
input_channel=actor_input_channel,
)
diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/reasoning_runner.py
similarity index 91%
rename from rlinf/runners/math_runner.py
rename to rlinf/runners/reasoning_runner.py
index e3f5b3750..b53010e18 100644
--- a/rlinf/runners/math_runner.py
+++ b/rlinf/runners/reasoning_runner.py
@@ -25,7 +25,7 @@
from tqdm import tqdm
from rlinf.data.io_struct import RolloutRequest
-from rlinf.scheduler import Channel, Worker
+from rlinf.scheduler import Channel
from rlinf.scheduler import WorkerGroupFuncResult as Handle
from rlinf.utils.data_iter_utils import split_list
from rlinf.utils.distributed import ScopedTimer
@@ -35,6 +35,7 @@
from rlinf.utils.timers import Timer
from rlinf.workers.actor.megatron_actor_worker import MegatronActor
from rlinf.workers.inference.megatron_inference_worker import MegatronInference
+from rlinf.workers.reward.reward_worker import RewardWorker
if typing.TYPE_CHECKING:
from rlinf.workers.rollout.sglang.sglang_worker import SGLangWorker
@@ -43,8 +44,8 @@
logging.getLogger().setLevel(logging.INFO)
-class MathRunner:
- """Runner for math model training."""
+class ReasoningRunner:
+ """Runner for reasoning task RL training."""
def __init__(
self,
@@ -55,33 +56,28 @@ def __init__(
rollout: Union["SGLangWorker", "VLLMWorker"],
inference: Optional[MegatronInference],
actor: MegatronActor,
- reward: Optional[Worker] = None,
+ reward: RewardWorker,
):
""""""
self.cfg = cfg
self.component_placement = placement
self.is_pipeline = self.component_placement.is_disaggregated
self.has_dedicated_inference = inference is not None
- self.has_dedicated_reward = reward is not None
# Workers
self.rollout = rollout
self.actor = actor
# Collocated mode uses actor as inference
self.inference = inference if self.has_dedicated_inference else self.actor
- self.reward = reward if self.has_dedicated_reward else self.actor
+ self.reward = reward
# Data channels
self.dataloader_channel = Channel.create("DataLoader")
self.rollout_channel = Channel.create("Rollout")
# Create a local channel (i.e., a channel that is different in every process)
# if inference is not a dedicated worker
- self.inference_channel = Channel.create(
- "Inference", local=not self.has_dedicated_inference
- )
- self.reward_channel = Channel.create(
- "Reward", local=not self.has_dedicated_reward
- )
+ self.inference_channel = Channel.create("Inference")
+ self.reward_channel = Channel.create("Reward")
self.actor_channel = Channel.create("Actor", local=True)
# Configurations
@@ -179,8 +175,7 @@ def init_workers(self):
self.actor.init_worker().wait()
if self.has_dedicated_inference:
self.inference.init_worker().wait()
- if self.has_dedicated_reward:
- self.reward.init_worker().wait()
+ self.reward.init_worker().wait()
if self.cfg.runner.resume_dir is None:
return
@@ -274,18 +269,26 @@ def epoch(self):
def _put_batch(self, batch: Dict[str, torch.Tensor]):
prompt_ids = batch["prompt"].tolist()
lengths = batch["length"].tolist()
- answers = batch["answer"].tolist()
- prompts = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)]
+ answers = batch["answer"]
+ image_data = batch["image_data"]
+ multi_modal_inputs = batch["multi_modal_inputs"]
+ prompt_ids = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)]
rollout_dp_size = self.component_placement.rollout_dp_size
- for input_ids, answers in zip(
- split_list(prompts, rollout_dp_size, enforce_divisible_batch=False),
+ for input_ids, answers, image_data, multi_modal_inputs in zip(
+ split_list(prompt_ids, rollout_dp_size, enforce_divisible_batch=False),
split_list(answers, rollout_dp_size, enforce_divisible_batch=False),
+ split_list(image_data, rollout_dp_size, enforce_divisible_batch=False),
+ split_list(
+ multi_modal_inputs, rollout_dp_size, enforce_divisible_batch=False
+ ),
):
request = RolloutRequest(
n=self.cfg.algorithm.group_size,
input_ids=input_ids,
answers=answers,
+ image_data=image_data,
+ multi_modal_inputs=multi_modal_inputs,
)
self.dataloader_channel.put(request, async_op=True)
@@ -327,38 +330,34 @@ def run(self):
output_channel=self.rollout_channel,
)
+ # Rewards
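+ # Rewards now run directly on rollout output so that, when logprobs are
+ # recomputed below, inference reads reward-annotated batches from the reward channel.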
+ reward_handle: Handle = self.reward.compute_rewards(
+ input_channel=self.rollout_channel,
+ output_channel=self.reward_channel,
+ )
+
if self.recompute_logprobs:
# Inference prev/ref logprobs
infer_handle: Handle = self.inference.run_inference(
- input_channel=self.rollout_channel,
+ input_channel=self.reward_channel,
output_channel=self.inference_channel,
+ rollout_channel=self.rollout_channel,
compute_ref_logprobs=self.compute_ref_logprobs,
)
inference_channel = self.inference_channel
else:
infer_handle = None
- inference_channel = self.rollout_channel
-
- # Rewards
- reward_handle: Handle = self.reward.compute_rewards(
- input_channel=inference_channel,
- output_channel=self.reward_channel,
- )
+ inference_channel = self.reward_channel
# Advantages and returns
adv_handle: Handle = self.actor.compute_advantages_and_returns(
- input_channel=self.reward_channel,
+ input_channel=inference_channel,
output_channel=self.actor_channel,
)
# Actor training
- actor_input_channel = self.actor_channel
- if self.is_pipeline:
- # In pipeline mode, the rollout already contains the advantages and returns
- # So the above two steps are in fact no-ops, and we should directly use the inference channel as the input
- actor_input_channel = inference_channel
actor_handle: Handle = self.actor.run_training(
- input_channel=actor_input_channel,
+ input_channel=self.actor_channel,
)
metrics = actor_handle.wait()
@@ -421,7 +420,7 @@ def run(self):
}
self.metric_logger.log(training_metrics, logging_steps + i)
- logging_metrics = time_metrics
+ logging_metrics = {f"{k}_time": v for k, v in time_metrics.items()}
if self.cfg.actor.get("calculate_flops", False):
flops_metrics = self._compute_flops_metrics(
diff --git a/rlinf/utils/convertor/__init__.py b/rlinf/utils/convertor/__init__.py
new file mode 100644
index 000000000..5b365ea1e
--- /dev/null
+++ b/rlinf/utils/convertor/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/rlinf/utils/convertor/utils.py b/rlinf/utils/convertor/utils.py
new file mode 100644
index 000000000..761218650
--- /dev/null
+++ b/rlinf/utils/convertor/utils.py
@@ -0,0 +1,467 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+
+
+class TransformType(Enum):
+ SPLIT_QKV = "split_qkv"
+ SPLIT_QKV_BIAS = "split_qkv_bias"
+ SPLIT_FC1 = "split_fc1"
+ SPLIT_NONE = "split_none"
+
+
+class TransformFunc:
+ @staticmethod
+ def _split_gqa_tensor(
+ tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config
+ ) -> None:
+ hidden_size = config.model_config.hidden_size
+ num_attention_heads = config.model_config.num_attention_heads
+ num_query_groups = config.model_config.num_query_groups or num_attention_heads
+ head_dim = hidden_size // num_attention_heads
+
+ target_tp = config.reshard_tp_size
+ assert num_query_groups % target_tp == 0, (
+ "num_query_groups must be divisible by reshard_tp_size"
+ )
+ local_num_query_groups = num_query_groups // target_tp
+
+ # heads per query group
+ assert num_attention_heads % num_query_groups == 0, (
+ "num_attention_heads must be divisible by num_query_groups"
+ )
+ q_heads_per_group = num_attention_heads // num_query_groups
+
+ num_channel_qkv = q_heads_per_group + 2
+
+ if tensor.ndim == 2:
+ # Weight: [out_features, in_features]
+ out_features, in_features = tensor.shape
+ expected_out = local_num_query_groups * num_channel_qkv * head_dim
+ assert out_features == expected_out, (
+ f"Unexpected fused QKV weight shape {tensor.shape}, expect "
+ f"[{expected_out}, {in_features}] (local groups={local_num_query_groups})"
+ )
+
+ qkv = tensor.view(
+ local_num_query_groups, num_channel_qkv, head_dim, in_features
+ )
+ q, k, v = torch.split(
+ qkv, [q_heads_per_group, 1, 1], dim=1
+ ) # shapes: [G, qh, D, In], [G,1,D,In], [G,1,D,In]
+ q_full = q.reshape(-1, in_features).contiguous()
+ k_full = k.reshape(-1, in_features).contiguous()
+ v_full = v.reshape(-1, in_features).contiguous()
+ else:
+ # Bias: [out_features]
+ out_features = tensor.shape[0]
+ expected_out = local_num_query_groups * num_channel_qkv * head_dim
+ assert out_features == expected_out, (
+ f"Unexpected fused QKV bias shape {tensor.shape}, expect "
+ f"[{expected_out}] (local groups={local_num_query_groups})"
+ )
+
+ qkv = tensor.view(local_num_query_groups, num_channel_qkv, head_dim)
+ q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1)
+ q_full = q.reshape(-1).contiguous()
+ k_full = k.reshape(-1).contiguous()
+ v_full = v.reshape(-1).contiguous()
+
+ # Save to target names
+ new_statedict[weight_names[0]] = q_full.clone()
+ new_statedict[weight_names[1]] = k_full.clone()
+ new_statedict[weight_names[2]] = v_full.clone()
+
+ @staticmethod
+ def split_fc1(
+ linear_fc1: torch.Tensor, new_statedict: dict, weight_names: List[str], config
+ ) -> None:
+ assert weight_names is not None and len(weight_names) == 2, (
+ f"split_fc1 transform expects two weight names, got {weight_names}"
+ )
+
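+ # The fused gate/up projection is laid out per TP shard; split the gathered
+ # tensor back into shards before separating the gate and up halves of each.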
+ tp_size = config.model_config.tensor_model_parallel_size
+ target_tp = config.reshard_tp_size
+ split_size = linear_fc1.shape[0] // (tp_size // target_tp)
+ linear_fc1_slice = torch.split(linear_fc1, split_size, dim=0)
+
+ gate_proj_shards = []
+ up_proj_shards = []
+ for weight in linear_fc1_slice:
+ assert weight.shape[0] % 2 == 0, (
+ f"linear_fc1 weight shape {weight.shape} is not even along dim 0"
+ )
+ weight_chunk = torch.chunk(weight, 2, dim=0)
+ gate_proj_shards.append(weight_chunk[0])
+ up_proj_shards.append(weight_chunk[1])
+ gate_proj = torch.cat(gate_proj_shards, dim=0)
+ up_proj = torch.cat(up_proj_shards, dim=0)
+
+ new_statedict[weight_names[0]] = gate_proj.clone()
+ new_statedict[weight_names[1]] = up_proj.clone()
+
+ @staticmethod
+ def split_none(
+ tensor: torch.Tensor, new_statedict: dict, weight_names: List[str]
+ ) -> None:
+ assert weight_names is not None and len(weight_names) == 1, (
+ f"split_none transform expects one weight name, got {weight_names}"
+ )
+ new_statedict[weight_names[0]] = tensor.clone()
+
+
+@dataclass
+class ConvertorRule:
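+ """One Megatron-to-HuggingFace mapping rule.
+
+ pattern matches the full Megatron parameter name, transform selects how the
+ tensor is split, targets are HF name templates expanded from the regex's named
+ groups, and post can optionally adjust the targets based on the match.
+ """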
+ pattern: re.Pattern
+ transform: TransformType
+ targets: List[str]
+ post: Optional[Callable] = None
+
+
+class BaseConvertor:
+ def __init__(self, config, strict: bool = False):
+ self.cfg = config
+ self.strict = strict
+ self.rules = self.build_rules()
+
+ def map_name(self, name: str) -> Optional[Tuple[TransformType, List[str]]]:
+ def _get_targets_from_match(templates: list[str], m: re.Match) -> list[str]:
+ gd = m.groupdict()
+ out = []
+ for t in templates:
+ if "{" in t and "}" in t:
+ out.append(t.format(**gd))
+ else:
+ out.append(m.expand(t))
+ return out
+
+ for r in self.rules:
+ m = r.pattern.fullmatch(name)
+ if not m:
+ continue
+ targets = r.targets
+ if r.post:
+ targets = r.post(targets, m)
+ full_names = _get_targets_from_match(targets, m)
+ return r.transform, full_names
+ return None
+
+ def convert(self, state_dict: Dict) -> Dict:
+ converted = {}
+ for k, v in state_dict.items():
+ mapped = self.map_name(k)
+ if mapped is None:
+ if self.strict:
+ raise KeyError(f"Unmapped key {k}")
+ continue
+ transform, targets = mapped
+ if transform in (TransformType.SPLIT_QKV, TransformType.SPLIT_QKV_BIAS):
+ TransformFunc._split_gqa_tensor(v, converted, targets, self.cfg)
+ elif transform == TransformType.SPLIT_FC1:
+ TransformFunc.split_fc1(v, converted, targets, self.cfg)
+ elif transform == TransformType.SPLIT_NONE:
+ TransformFunc.split_none(v, converted, targets)
+ else:
+ raise ValueError(f"Unknown transform type {transform}")
+ return converted
+
+ def build_rules(self) -> List[ConvertorRule]:
+ """
+ Should be implemented in subclass to build the conversion rules.
+ """
+ raise NotImplementedError
+
+
+class Qwen2_5Convertor(BaseConvertor):
+ def build_rules(self) -> List[ConvertorRule]:
+ LID = r"(?P<layer>\d+)"
+ WB = r"(?P<wb>weight|bias)"
+
+ return [
+ # embeddings
+ ConvertorRule(
+ re.compile(r"embedding\.word_embeddings\.weight$"),
+ TransformType.SPLIT_NONE,
+ [r"model.embed_tokens.weight"],
+ ),
+ # final_layernorm
+ ConvertorRule(
+ re.compile(r"decoder\.final_layernorm\.weight$"),
+ TransformType.SPLIT_NONE,
+ [r"model.norm.weight"],
+ ),
+ # lm_head
+ ConvertorRule(
+ re.compile(r"output_layer\.weight$"),
+ TransformType.SPLIT_NONE,
+ [r"lm_head.weight"],
+ ),
+ # attn qkv norm
+ ConvertorRule(
+ re.compile(
+ rf"decoder\.layers\.{LID}\.self_attention\.linear_qkv\.layer_norm_weight$"
+ ),
+ TransformType.SPLIT_NONE,
+ [r"model.layers.\g<layer>.input_layernorm.weight"],
+ ),
+ # attn qkv weights/bias
+ ConvertorRule(
+ re.compile(
+ rf"decoder\.layers\.{LID}\.self_attention\.linear_qkv\.{WB}$"
+ ),
+ TransformType.SPLIT_QKV,
+ [
+ r"model.layers.\g<layer>.self_attn.q_proj.\g<wb>",
+ r"model.layers.\g<layer>.self_attn.k_proj.\g<wb>",
+ r"model.layers.\g<layer>.self_attn.v_proj.\g<wb>",
+ ],
+ ),
+ # attn o proj
+ ConvertorRule(
+ re.compile(
+ rf"decoder\.layers\.{LID}\.self_attention\.linear_proj\.{WB}$"
+ ),
+ TransformType.SPLIT_NONE,
+ [r"model.layers.\g<layer>.self_attn.o_proj.\g<wb>"],
+ ),
+ # mlp fc1
+ ConvertorRule(
+ re.compile(rf"decoder\.layers\.{LID}\.mlp\.linear_fc1\.{WB}$"),
+ TransformType.SPLIT_FC1,
+ [
+ r"model.layers.\g<layer>.mlp.gate_proj.\g<wb>",
+ r"model.layers.\g<layer>.mlp.up_proj.\g<wb>",
+ ],
+ ),
+ # mlp fc2
+ ConvertorRule(
+ re.compile(rf"decoder\.layers\.{LID}\.mlp\.linear_fc2\.{WB}$"),
+ TransformType.SPLIT_NONE,
+ [r"model.layers.\g<layer>.mlp.down_proj.\g<wb>"],
+ ),
+ # mlp norms
+ ConvertorRule(
+ re.compile(
+ rf"decoder\.layers\.{LID}\.mlp\.linear_fc1\.layer_norm_weight$"
+ ),
+ TransformType.SPLIT_NONE,
+ [r"model.layers.\g<layer>.post_attention_layernorm.weight"],
+ ),
+ ]
+
+
+class Qwen2_5VLConvertor(BaseConvertor):
+ def _build_vision_rules(self) -> List[ConvertorRule]:
+ B = r"(?P<layer>\d+)"
+ WB = r"(?P<wb>weight|bias)"
+ HF_V_PREFIX = "model.visual"
+ HF_V_DECODER_PREFIX = f"{HF_V_PREFIX}.blocks"
+ MG_V_PREFIX = "vision_model"
+ MG_V_DECODER_PREFIX = rf"{MG_V_PREFIX}\.decoder\.layers"
+
+ vision_rules = [
+ # vision patch embed
+ ConvertorRule(
+ re.compile(rf"^{MG_V_PREFIX}\.patch_embed\.proj\.weight$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_PREFIX}.patch_embed.proj.weight"],
+ ),
+ # final layer norm
+ ConvertorRule(
+ re.compile(rf"^{MG_V_PREFIX}\.decoder\.final_layernorm\.weight$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_PREFIX}.merger.ln_q.weight"],
+ ),
+ # attn norm
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.layer_norm_weight$"
+ ),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.norm1.weight"],
+ ),
+ # attn qkv
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.linear_qkv\.{WB}$"
+ ),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.attn.qkv.\g<wb>"],
+ ),
+ # attn proj
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.linear_proj\.{WB}$"
+ ),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.attn.proj.\g<wb>"],
+ ),
+ # mlp fc1
+ ConvertorRule(
+ re.compile(rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.{WB}$"),
+ TransformType.SPLIT_FC1,
+ [
+ f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.mlp.gate_proj.\g<wb>",
+ f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.mlp.up_proj.\g<wb>",
+ ],
+ ),
+ # mlp fc2
+ ConvertorRule(
+ re.compile(rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc2\.{WB}$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.mlp.down_proj.\g<wb>"],
+ ),
+ # mlp norm
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.layer_norm_weight$"
+ ),
+ TransformType.SPLIT_NONE,
+ [f"{HF_V_DECODER_PREFIX}" + r".\g<layer>.norm2.weight"],
+ ),
+ ]
+ return vision_rules
+
+ def _build_llm_rules(self) -> List[ConvertorRule]:
+ B = r"(?P<layer>\d+)"
+ WB = r"(?P<wb>weight|bias)"
+ HF_LLM_PREFIX = "model.language_model"
+ MG_LLM_PREFIX = "language_model"
+ MG_LLM_DECODER_PREFIX = rf"{MG_LLM_PREFIX}\.decoder\.layers"
+
+ llm_rules = [
+ # embeddings
+ ConvertorRule(
+ re.compile(rf"^{MG_LLM_PREFIX}\.embed_tokens\.weight$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_LLM_PREFIX}.embedding.weight"],
+ ),
+ # final_layernorm
+ ConvertorRule(
+ re.compile(rf"^{MG_LLM_PREFIX}\.final_layernorm\.weight$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_LLM_PREFIX}.norm.weight"],
+ ),
+ # attn norm
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.layer_norm_weight$"
+ ),
+ TransformType.SPLIT_NONE,
+ [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g<layer>.input_layernorm.weight"],
+ ),
+ # attn qkv
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.linear_qkv\.{WB}$"
+ ),
+ TransformType.SPLIT_QKV,
+ [
+ f"{HF_LLM_PREFIX}"
+ + r".decoder.layers.\g<layer>.self_attn.q_proj.\g<wb>",
+ f"{HF_LLM_PREFIX}"
+ + r".decoder.layers.\g<layer>.self_attn.k_proj.\g<wb>",
+ f"{HF_LLM_PREFIX}"
+ + r".decoder.layers.\g<layer>.self_attn.v_proj.\g<wb>",
+ ],
+ ),
+ # attn proj
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.linear_proj\.{WB}$"
+ ),
+ TransformType.SPLIT_NONE,
+ [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g<layer>.self_attn.o_proj.\g<wb>"],
+ ),
+ # mlp fc1
+ ConvertorRule(
+ re.compile(rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.{WB}$"),
+ TransformType.SPLIT_FC1,
+ [
+ f"{HF_LLM_PREFIX}" + r".decoder.layers.\g<layer>.mlp.gate_proj.\g<wb>",
+ f"{HF_LLM_PREFIX}" + r".decoder.layers.\g<layer>.mlp.up_proj.\g<wb>",
+ ],
+ ),
+ # mlp fc2
+ ConvertorRule(
+ re.compile(rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc2\.{WB}$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g<layer>.mlp.down_proj.\g<wb>"],
+ ),
+ # mlp norm
+ ConvertorRule(
+ re.compile(
+ rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.layer_norm_weight$"
+ ),
+ TransformType.SPLIT_NONE,
+ [
+ f"{HF_LLM_PREFIX}"
+ + r".decoder.layers.\g<layer>.post_attention_layernorm.weight"
+ ],
+ ),
+ ]
+ return llm_rules
+
+ def _build_projector_rules(self) -> List[ConvertorRule]:
+ HF_PROJECTOR_PREFIX = "model.visual.merger"
+ MG_PROJECTOR_PREFIX = "vision_model.projection.encoder"
+ WB = r"(?P<wb>weight|bias)"
+
+ projector_rules = [
+ # projector fc1
+ ConvertorRule(
+ re.compile(rf"^{MG_PROJECTOR_PREFIX}\.linear_fc1\.{WB}$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_PROJECTOR_PREFIX}" + r".mlp.0.\g<wb>"],
+ ),
+ # projector fc2
+ ConvertorRule(
+ re.compile(rf"^{MG_PROJECTOR_PREFIX}\.linear_fc2\.{WB}$"),
+ TransformType.SPLIT_NONE,
+ [f"{HF_PROJECTOR_PREFIX}" + r".mlp.2.\g<wb>"],
+ ),
+ ]
+ return projector_rules
+
+ def build_rules(self) -> List[ConvertorRule]:
+ rules = []
+ rules.extend(self._build_vision_rules())
+ rules.extend(self._build_llm_rules())
+ rules.extend(self._build_projector_rules())
+ return rules
+
+
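+# Registry of Megatron-to-HuggingFace weight convertors, keyed by model architecture.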
+_MG2HF_CONVERTOR_REGISTRY = {}
+
+
+def register_mg2hf_convertor(model_arch: str, convertor_cls: Callable) -> None:
+ if model_arch in _MG2HF_CONVERTOR_REGISTRY:
+ raise ValueError(f"Convertor for {model_arch} already registered")
+ _MG2HF_CONVERTOR_REGISTRY[model_arch] = convertor_cls
+
+
+register_mg2hf_convertor("qwen2.5", Qwen2_5Convertor)
+register_mg2hf_convertor("qwen2.5_vl", Qwen2_5VLConvertor)
+
+
+def get_mg2hf_convertor(model_arch: str, config, strict: bool = False) -> BaseConvertor:
+ if model_arch not in _MG2HF_CONVERTOR_REGISTRY:
+ raise ValueError(f"No convertor registered for {model_arch}")
+ convertor_cls = _MG2HF_CONVERTOR_REGISTRY[model_arch]
+ return convertor_cls(config=config, strict=strict)
diff --git a/rlinf/utils/data_iter_utils.py b/rlinf/utils/data_iter_utils.py
index 27bbd7215..31e440201 100644
--- a/rlinf/utils/data_iter_utils.py
+++ b/rlinf/utils/data_iter_utils.py
@@ -60,13 +60,17 @@ def concat_dict_list(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, Any]:
return result
-def split_list(inputs, num_chunks, enforce_divisible_batch: Optional[bool] = True):
+def split_list(
+ inputs: List, num_chunks: int, enforce_divisible_batch: Optional[bool] = True
+):
"""
Split a list into equal sized chunks
"""
if enforce_divisible_batch:
chunk_size = len(inputs) // num_chunks
- assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!"
+ assert len(inputs) % chunk_size == 0, (
+ f"Issue with batch size configuration! inputs len:{len(inputs)} num_chunks:{num_chunks}"
+ )
return [inputs[i : i + chunk_size] for i in range(0, len(inputs), chunk_size)]
else:
k, m = divmod(len(inputs), num_chunks)
diff --git a/rlinf/utils/distributed.py b/rlinf/utils/distributed.py
index e9da8d6da..1f5ddc9c6 100644
--- a/rlinf/utils/distributed.py
+++ b/rlinf/utils/distributed.py
@@ -29,11 +29,16 @@
from rlinf.utils.timers import NamedTimer
-
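+# Importing torch_npu registers the Ascend NPU backend (torch.npu) used throughout this file.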
+import torch_npu
def compute_rollout_metrics(
- rollout_batch, max_prompt_len, response_len, use_critic=False
+ rollout_batch,
+ max_prompt_len,
+ response_len,
+ dp_world_size,
+ dp_group=None,
+ use_critic=False,
):
- device = torch.device(f"cuda:{torch.cuda.current_device()}")
+ device = torch.device(f"npu:{torch.npu.current_device()}")
advantages = rollout_batch["advantages"].to(device=device)
mask = rollout_batch["attention_mask"][:, -response_len:].to(device=device)
prompt_lengths = rollout_batch["prompt_lengths"].clone().to(device=device)
@@ -41,8 +46,6 @@ def compute_rollout_metrics(
reward_scores = rollout_batch["rewards"].clone().to(device=device)
is_end = rollout_batch["is_end"].clone().float().to(device=device)
- dp_world_size = parallel_state.get_data_parallel_world_size()
-
prompt_lengths_list = [
torch.empty_like(prompt_lengths) for _ in range(dp_world_size)
]
@@ -52,12 +55,12 @@ def compute_rollout_metrics(
torch.distributed.all_gather(
prompt_lengths_list,
prompt_lengths,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
torch.distributed.all_gather(
decode_lengths_list,
response_lengths,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
total_prompt_lengths = torch.cat(prompt_lengths_list, dim=0)
@@ -66,22 +69,22 @@ def compute_rollout_metrics(
torch.distributed.all_reduce(
prompt_lengths,
torch.distributed.ReduceOp.AVG,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
torch.distributed.all_reduce(
response_lengths,
torch.distributed.ReduceOp.AVG,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
torch.distributed.all_reduce(
reward_scores,
torch.distributed.ReduceOp.AVG,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
torch.distributed.all_reduce(
is_end,
torch.distributed.ReduceOp.AVG,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
valid_adv = torch.masked_select(advantages, mask)
@@ -90,24 +93,24 @@ def compute_rollout_metrics(
torch.distributed.all_reduce(
n_valid_token,
op=torch.distributed.ReduceOp.SUM,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
torch.distributed.all_reduce(
adv_sum,
op=torch.distributed.ReduceOp.SUM,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
adv_mean = adv_sum / n_valid_token
adv_max = torch.max(valid_adv).detach().item()
adv_min = torch.min(valid_adv).detach().item()
reduce_tensor = torch.as_tensor(
- [-adv_min, adv_max], device=torch.cuda.current_device(), dtype=torch.float32
+ [-adv_min, adv_max], device=torch.npu.current_device(), dtype=torch.float32
)
torch.distributed.all_reduce(
reduce_tensor,
torch.distributed.ReduceOp.MAX,
- group=parallel_state.get_data_parallel_group(),
+ group=dp_group,
)
adv_min, adv_max = reduce_tensor.tolist()
@@ -172,7 +175,7 @@ def from_rollout_batches(
dp_group: Optional[ProcessGroup],
partitioning_tool: Callable,
) -> Self:
- current_device = torch.cuda.current_device()
+ current_device = torch.npu.current_device()
attn_mask = rollout_batches.get("attention_mask")
current_num_samples = attn_mask.size(0)
@@ -403,12 +406,12 @@ def rebalance_nd_tensor(tensor, group):
NOTE: assumes all other (i.e., non-zero) dimensions are equal.
"""
num_samples = torch.as_tensor(
- tensor.size(0), dtype=torch.int64, device=torch.cuda.current_device()
+ tensor.size(0), dtype=torch.int64, device=torch.npu.current_device()
)
batch_num_per_rank = torch.zeros(
torch.distributed.get_world_size(group),
dtype=torch.int64,
- device=torch.cuda.current_device(),
+ device=torch.npu.current_device(),
)
torch.distributed.all_gather_into_tensor(
batch_num_per_rank, num_samples, group=group
@@ -419,7 +422,7 @@ def rebalance_nd_tensor(tensor, group):
indices = batch_num_per_rank.cumsum(dim=0)
output_tensor = torch.zeros(
- B, *other_dims, dtype=tensor.dtype, device=torch.cuda.current_device()
+ B, *other_dims, dtype=tensor.dtype, device=torch.npu.current_device()
)
# tensor_split is a view we can copy into
@@ -451,7 +454,7 @@ def broadcast_tensor(
"""
if torch.distributed.get_rank() == src:
- tensor = tensor.cuda()
+ tensor = tensor.npu()
if dtype:
tensor = tensor.to(dtype)
@@ -464,7 +467,7 @@ def broadcast_tensor(
torch.distributed.broadcast_object_list(metadata, src, group)
dtype, input_shape = metadata
- tensor = torch.empty(input_shape, dtype=dtype, device="cuda")
+ tensor = torch.empty(input_shape, dtype=dtype, device="npu")
torch.distributed.broadcast(tensor, src, group)
return tensor
@@ -516,7 +519,7 @@ def broadcast_tensor_within_dp(tensor: torch.Tensor, dtype: torch.dtype):
def gather_tensor(tensor, dst, group, dtype=None):
"""Gather any tensor to the dst rank from every other rank in the given group.
All the ranks that send or receive data must call this function."""
- tensor = tensor.to(device=torch.cuda.current_device(), dtype=dtype)
+ tensor = tensor.to(device=torch.npu.current_device(), dtype=dtype)
if torch.distributed.get_rank() == dst:
gather_list = [
torch.empty_like(tensor)
@@ -546,8 +549,8 @@ def normalize_tensor(tensor, mask, group=None):
"""normalizes a tensor using global mean and std"""
dtype = torch.float64
tensor = tensor.to(dtype)
- tensor = tensor.to(device=torch.cuda.current_device())
- mask = mask.to(device=torch.cuda.current_device())
+ tensor = tensor.to(device=torch.npu.current_device())
+ mask = mask.to(device=torch.npu.current_device())
tensor_global_mean, tensor_global_var = masked_global_mean_var(
tensor, mask, group=group
@@ -586,7 +589,7 @@ def masked_normalization(
Normalized x, with the same shape as x.
"""
dtype = torch.float64 if high_precision else torch.float32
- x = x.to(dtype=dtype).cuda()
+ x = x.to(dtype=dtype).npu()
if not inplace:
x = x.clone()
if dim is None:
@@ -596,7 +599,7 @@ def masked_normalization(
np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
)
else:
- mask = mask.to(dtype=dtype).cuda()
+ mask = mask.to(dtype=dtype).npu()
assert len(mask.shape) == len(x.shape), (mask.shape, x.shape, dim)
for i in range(len(x.shape)):
if i in dim:
@@ -640,8 +643,8 @@ def masked_global_mean_var(values, mask, group=None):
mask and values must have same shape, with mask being {0,1} with 1 being the values we want to keep
"""
assert values.shape == mask.shape, (values.shape, mask.shape)
- values = values.to(device=torch.cuda.current_device())
- mask = mask.to(device=torch.cuda.current_device())
+ values = values.to(device=torch.npu.current_device())
+ mask = mask.to(device=torch.npu.current_device())
values = values * mask
@@ -649,7 +652,7 @@ def masked_global_mean_var(values, mask, group=None):
sum_and_count = torch.tensor(
[values.sum(), mask.sum()],
dtype=torch.float64,
- device=torch.cuda.current_device(),
+ device=torch.npu.current_device(),
)
torch.distributed.all_reduce(sum_and_count, group=group)
global_sum, global_count = sum_and_count
@@ -657,7 +660,7 @@ def masked_global_mean_var(values, mask, group=None):
variance_summed = (
(((values - global_mean) ** 2) * mask)
.sum()
- .to(device=torch.cuda.current_device(), dtype=torch.float64)
+ .to(device=torch.npu.current_device(), dtype=torch.float64)
)
torch.distributed.all_reduce(variance_summed, group=group)
@@ -666,12 +669,12 @@ def masked_global_mean_var(values, mask, group=None):
def report_device_info(info_str):
- free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+ free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
free_gpu_memory /= 2**30
total_gpu_memory /= 2**30
- memory_allocated = torch.cuda.memory_allocated() / 2**30
- memory_reserved = torch.cuda.memory_reserved() / 2**30
+ memory_allocated = torch.npu.memory_allocated() / 2**30
+ memory_reserved = torch.npu.memory_reserved() / 2**30
print(
f"[Rank {torch.distributed.get_rank()}] {info_str}, {free_gpu_memory=:.2f} GiB, {total_gpu_memory=:.2f} GiB, {memory_allocated=:.2f} GiB, {memory_reserved=:.2f} GiB"
@@ -722,7 +725,7 @@ def all_reduce_dict(
):
keys = sorted(dictionary)
tensor = torch.as_tensor(
- [dictionary[k] for k in keys], dtype=dtype, device=torch.cuda.current_device()
+ [dictionary[k] for k in keys], dtype=dtype, device=torch.npu.current_device()
)
torch.distributed.all_reduce(tensor, op=op, group=group)
return dict(zip(keys, tensor.tolist()))
diff --git a/rlinf/utils/placement.py b/rlinf/utils/placement.py
index 7b707894e..6ecea767f 100644
--- a/rlinf/utils/placement.py
+++ b/rlinf/utils/placement.py
@@ -202,6 +202,7 @@ def __init__(self, config: DictConfig, cluster: Cluster):
self._actor_gpus = self._component_gpu_map.get("actor", None)
self._inference_gpus = self._component_gpu_map.get("inference", None)
self._rollout_gpus = self._component_gpu_map.get("rollout", None)
+ self._reward_gpus = self._component_gpu_map.get("reward", None)
assert self._actor_gpus is not None, (
"Actor GPUs must be specified in the component_placement config."
)
@@ -224,11 +225,9 @@ def __init__(self, config: DictConfig, cluster: Cluster):
len(self._inference_gpus) if self._inference_gpus else 0
)
self._rollout_num_gpus = len(self._rollout_gpus)
+ self._reward_num_gpus = len(self._reward_gpus) if self._reward_gpus else 0
if self._is_collocated():
- assert self.actor_tp_size >= self.rollout_tp_size, (
- f"Actor TP size {self.actor_tp_size} must be greater or equal to Rollout TP size {self.rollout_tp_size}."
- )
assert self._inference_gpus is None, (
"Inference GPUs must not be specified in collocated mode."
)
@@ -280,21 +279,24 @@ def _generate_placements(self):
self._actor_gpus[0], self._actor_gpus[-1]
)
- actor_tp_size = self._config.actor.model.tensor_model_parallel_size
- rollout_tp_size = self._config.rollout.tensor_parallel_size
- assert actor_tp_size >= rollout_tp_size, (
- f"Actor TP size ({actor_tp_size}) must be greater or equal to Rollout TP size ({rollout_tp_size})"
- )
- assert actor_tp_size % rollout_tp_size == 0, (
- f"Actor TP size ({actor_tp_size}) must be divisible by Rollout TP size ({rollout_tp_size})"
+ if self.actor_tp_size > self.rollout_tp_size:
+ assert self.actor_tp_size % self.rollout_tp_size == 0, (
+ f"Actor TP size ({self.actor_tp_size}) must be divisible by Rollout TP size ({self.rollout_tp_size})"
+ )
+ stride = (
+ self.actor_tp_size // self.rollout_tp_size
+ if self.actor_tp_size > self.rollout_tp_size
+ else 1
)
- stride = actor_tp_size // rollout_tp_size
self._placements["rollout"] = PackedPlacementStrategy(
self._rollout_gpus[0],
self._rollout_gpus[-1],
- num_accelerators_per_process=rollout_tp_size,
+ num_accelerators_per_process=self.rollout_tp_size,
stride=stride,
)
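+ # The reward worker gets its own packed placement over the configured reward GPUs.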
+ self._placements["reward"] = PackedPlacementStrategy(
+ self._reward_gpus[0], self._reward_gpus[-1]
+ )
elif self._placement_mode == PlacementMode.DISAGGREGATED:
# Generate continuous placement strategies for components in a cluster.
num_gpus_per_rollout_dp = len(self._rollout_gpus) // self.rollout_dp_size
@@ -310,6 +312,9 @@ def _generate_placements(self):
self._placements["actor"] = PackedPlacementStrategy(
self._actor_gpus[0], self._actor_gpus[-1]
)
+ self._placements["reward"] = PackedPlacementStrategy(
+ self._reward_gpus[0], self._reward_gpus[-1]
+ )
@property
def is_disaggregated(self):
@@ -325,18 +330,18 @@ def has_dedicated_inference(self):
@property
def actor_dp_size(self) -> int:
return self._actor_num_gpus // (
- self._config.actor.model.tensor_model_parallel_size
- * self._config.actor.model.context_parallel_size
- * self._config.actor.model.pipeline_model_parallel_size
+ self._config.actor.model.get("tensor_model_parallel_size", 1)
+ * self._config.actor.model.get("context_parallel_size", 1)
+ * self._config.actor.model.get("pipeline_model_parallel_size", 1)
)
@property
def actor_tp_size(self) -> int:
- return self._config.actor.model.tensor_model_parallel_size
+ return self._config.actor.model.get("tensor_model_parallel_size", 1)
@property
def actor_pp_size(self) -> int:
- return self._config.actor.model.pipeline_model_parallel_size
+ return self._config.actor.model.get("pipeline_model_parallel_size", 1)
@property
def actor_world_size(self) -> int:
@@ -349,7 +354,7 @@ def inference_tp_size(self) -> int:
and hasattr(self._config.inference, "model")
and hasattr(self._config.inference.model, "tensor_model_parallel_size")
):
- return self._config.inference.model.tensor_model_parallel_size
+ return self._config.inference.model.get("tensor_model_parallel_size", 1)
else:
return self.actor_tp_size
@@ -360,7 +365,7 @@ def inference_pp_size(self) -> int:
and hasattr(self._config.inference, "model")
and hasattr(self._config.inference.model, "pipeline_model_parallel_size")
):
- return self._config.inference.model.pipeline_model_parallel_size
+ return self._config.inference.model.get("pipeline_model_parallel_size", 1)
else:
return self.actor_pp_size
@@ -377,14 +382,18 @@ def inference_world_size(self) -> int:
@property
def rollout_dp_size(self) -> int:
return self._rollout_num_gpus // (
- self._config.rollout.tensor_parallel_size
- * self._config.rollout.pipeline_parallel_size
+ self._config.rollout.get("tensor_parallel_size", 1)
+ * self._config.rollout.get("pipeline_parallel_size", 1)
)
@property
def rollout_tp_size(self) -> int:
- return self._config.rollout.tensor_parallel_size
+ return self._config.rollout.get("tensor_parallel_size", 1)
@property
def rollout_world_size(self) -> int:
return self._rollout_num_gpus
+
+ @property
+ def reward_world_size(self) -> int:
+ return self._reward_num_gpus
diff --git a/rlinf/utils/resharding/mcore_weight_reshard.py b/rlinf/utils/resharding/mcore_weight_reshard.py
index 9ec44eba0..90d8277fa 100644
--- a/rlinf/utils/resharding/mcore_weight_reshard.py
+++ b/rlinf/utils/resharding/mcore_weight_reshard.py
@@ -183,6 +183,6 @@ def get_layer_num(param_name):
)
if self.config.convert_fn is not None:
- model_state_dict = self.config.convert_fn(model_state_dict, self.config)
+ model_state_dict = self.config.convert_fn(model_state_dict)
return model_state_dict
diff --git a/rlinf/utils/resharding/reshard_config.py b/rlinf/utils/resharding/reshard_config.py
index 2b493a839..79e089bcd 100644
--- a/rlinf/utils/resharding/reshard_config.py
+++ b/rlinf/utils/resharding/reshard_config.py
@@ -17,7 +17,9 @@
from megatron.core.transformer import TransformerConfig
-from .utils import get_convert_fn, get_pp_reshard_fn, get_tp_reshard_fn
+from rlinf.utils.convertor.utils import get_mg2hf_convertor
+
+from .utils import get_pp_reshard_fn, get_tp_reshard_fn
@dataclass
@@ -37,7 +39,7 @@ class ReshardConfig:
"""Resharding pp size."""
convert_fn: Callable = None
- """Convert function to use for converting the model parameters' weight and name from training engine to rollout engine."""
+ """Function to convert the model weights from megatron format to HuggingFace format."""
tp_reshard_fn: Callable = None
"""Resharding function to use for resharding the model parallelism from tensor_model_parallel_size to reshard_tp_size."""
@@ -59,7 +61,8 @@ def __post_init__(self):
)
if self.convert_fn is None and self.reshard_weights_format != "mcore":
- self.convert_fn = get_convert_fn(self.model_arch)
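+ # The convertor receives the ReshardConfig itself so it can read model_config
+ # and reshard_tp_size when splitting fused weights.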
+ self._convertor = get_mg2hf_convertor(self.model_arch, self, strict=True)
+ self.convert_fn = self._convertor.convert
if self.tp_reshard_fn is None:
self.tp_reshard_fn = get_tp_reshard_fn(self.model_arch)
diff --git a/rlinf/utils/resharding/utils.py b/rlinf/utils/resharding/utils.py
index 1fae2b05a..d7a4af231 100644
--- a/rlinf/utils/resharding/utils.py
+++ b/rlinf/utils/resharding/utils.py
@@ -12,21 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-from enum import Enum
-from typing import List, Tuple
import torch
-from megatron.core import parallel_state
-
-
-def get_convert_fn(model_arch: str):
- if model_arch == "qwen2.5":
- return TransformFunc.convert_mega_qwen2_5_to_hf
- else:
- raise NotImplementedError(
- f"get_convert_fn for model_arch {model_arch} is not implemented"
- )
def get_tp_reshard_fn(model_arch: str):
@@ -47,212 +34,6 @@ def get_pp_reshard_fn(model_arch: str):
)
-###########################
-# convert fn implementation
-###########################
-
-
-class TransformType(Enum):
- SPLIT_QKV = "split_qkv"
- SPLIT_QKV_BIAS = "split_qkv_bias"
- SPLIT_FC1 = "split_fc1"
- SPLIT_NONE = "split_none"
-
-
-class TransformFunc:
- @staticmethod
- def _split_gqa_tensor(
- tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config
- ) -> None:
- hidden_size = config.model_config.hidden_size
- num_attention_heads = config.model_config.num_attention_heads
- num_query_groups = config.model_config.num_query_groups or num_attention_heads
- head_dim = hidden_size // num_attention_heads
-
- target_tp = config.reshard_tp_size
- assert num_query_groups % target_tp == 0, (
- "num_query_groups must be divisible by reshard_tp_size"
- )
- local_num_query_groups = num_query_groups // target_tp
-
- # heads per query group
- assert num_attention_heads % num_query_groups == 0, (
- "num_attention_heads must be divisible by num_query_groups"
- )
- q_heads_per_group = num_attention_heads // num_query_groups
-
- num_channel_qkv = q_heads_per_group + 2
-
- if tensor.ndim == 2:
- # Weight: [out_features, in_features]
- out_features, in_features = tensor.shape
- expected_out = local_num_query_groups * num_channel_qkv * head_dim
- assert out_features == expected_out, (
- f"Unexpected fused QKV weight shape {tensor.shape}, expect "
- f"[{expected_out}, {in_features}] (local groups={local_num_query_groups})"
- )
-
- qkv = tensor.view(
- local_num_query_groups, num_channel_qkv, head_dim, in_features
- )
- q, k, v = torch.split(
- qkv, [q_heads_per_group, 1, 1], dim=1
- ) # shapes: [G, qh, D, In], [G,1,D,In], [G,1,D,In]
- q_full = q.reshape(-1, in_features).contiguous()
- k_full = k.reshape(-1, in_features).contiguous()
- v_full = v.reshape(-1, in_features).contiguous()
- else:
- # Bias: [out_features]
- out_features = tensor.shape[0]
- expected_out = local_num_query_groups * num_channel_qkv * head_dim
- assert out_features == expected_out, (
- f"Unexpected fused QKV bias shape {tensor.shape}, expect "
- f"[{expected_out}] (local groups={local_num_query_groups})"
- )
-
- qkv = tensor.view(local_num_query_groups, num_channel_qkv, head_dim)
- q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1)
- q_full = q.reshape(-1).contiguous()
- k_full = k.reshape(-1).contiguous()
- v_full = v.reshape(-1).contiguous()
-
- # Save to target names
- new_statedict[weight_names[0]] = q_full.clone()
- new_statedict[weight_names[1]] = k_full.clone()
- new_statedict[weight_names[2]] = v_full.clone()
-
- @staticmethod
- def split_fc1(
- linear_fc1: torch.Tensor, new_statedict: dict, weight_names: List[str], config
- ) -> None:
- assert weight_names is not None and len(weight_names) == 2, (
- f"split_fc1 transform expects two weight names, got {weight_names}"
- )
-
- tp_size = config.model_config.tensor_model_parallel_size
- target_tp = config.reshard_tp_size
- split_size = linear_fc1.shape[0] // (tp_size // target_tp)
- linear_fc1_slice = torch.split(linear_fc1, split_size, dim=0)
-
- gate_proj_shards = []
- up_proj_shards = []
- for weight in linear_fc1_slice:
- assert weight.shape[0] % 2 == 0, (
- f"linear_fc1 weight shape {weight.shape} is not even along dim 0"
- )
- weight_chunk = torch.chunk(weight, 2, dim=0)
- gate_proj_shards.append(weight_chunk[0])
- up_proj_shards.append(weight_chunk[1])
- gate_proj = torch.cat(gate_proj_shards, dim=0)
- up_proj = torch.cat(up_proj_shards, dim=0)
-
- new_statedict[weight_names[0]] = gate_proj.clone()
- new_statedict[weight_names[1]] = up_proj.clone()
-
- @staticmethod
- def split_none(
- tensor: torch.Tensor, new_statedict: dict, weight_names: List[str]
- ) -> None:
- assert weight_names is not None and len(weight_names) == 1, (
- f"split_none transform expects one weight name, got {weight_names}"
- )
- new_statedict[weight_names[0]] = tensor.clone()
-
- @staticmethod
- def mega_name_qwen2_5_to_hf(name: str) -> Tuple[TransformType, List[str]]:
- """
- Convert qwen2_5 model weight megatron name to hf name and do shape transform if needed.
-
- Args:
- name (str): megatron model weight name
-
- Returns:
- (TransformType, List[str]): transform type and the corresponding hf model weight name
- """
- if "embedding.word_embeddings.weight" in name:
- return (TransformType.SPLIT_NONE, ["model.embed_tokens.weight"])
- if "decoder.final_layernorm.weight" in name:
- return (TransformType.SPLIT_NONE, ["model.norm.weight"])
- if "output_layer.weight" in name:
- return (TransformType.SPLIT_NONE, ["lm_head.weight"])
- layer_id, suffix = TransformFunc.extract_layer_info(name)
- assert layer_id is not None, f"Cannot extract layer info from {name}"
- result_pattern = "model.layers.{}.{}"
- nmap = {
- "self_attention.linear_proj.weight": (
- TransformType.SPLIT_NONE,
- ["self_attn.o_proj.weight"],
- ),
- "self_attention.linear_qkv.layer_norm_weight": (
- TransformType.SPLIT_NONE,
- ["input_layernorm.weight"],
- ),
- "self_attention.linear_qkv.weight": (
- TransformType.SPLIT_QKV,
- [
- "self_attn.q_proj.weight",
- "self_attn.k_proj.weight",
- "self_attn.v_proj.weight",
- ],
- ),
- "self_attention.linear_qkv.bias": (
- TransformType.SPLIT_QKV_BIAS,
- [
- "self_attn.q_proj.bias",
- "self_attn.k_proj.bias",
- "self_attn.v_proj.bias",
- ],
- ),
- "mlp.linear_fc1.layer_norm_weight": (
- TransformType.SPLIT_NONE,
- ["post_attention_layernorm.weight"],
- ),
- "mlp.linear_fc1.weight": (
- TransformType.SPLIT_FC1,
- ["mlp.gate_proj.weight", "mlp.up_proj.weight"],
- ),
- "mlp.linear_fc2.weight": (
- TransformType.SPLIT_NONE,
- ["mlp.down_proj.weight"],
- ),
- }
-
- assert suffix in nmap, f"Cannot find mapping for {suffix}"
-
- transform_type, suffixes = nmap[suffix]
- return (
- transform_type,
- [result_pattern.format(layer_id, suffix) for suffix in suffixes],
- )
-
- @staticmethod
- def convert_mega_qwen2_5_to_hf(model_state_dict: dict, config) -> dict:
- new_statedict = {}
- for name, param in model_state_dict.items():
- transform_type, hf_names = TransformFunc.mega_name_qwen2_5_to_hf(name)
- if transform_type == TransformType.SPLIT_QKV:
- TransformFunc._split_gqa_tensor(param, new_statedict, hf_names, config)
- elif transform_type == TransformType.SPLIT_QKV_BIAS:
- TransformFunc._split_gqa_tensor(param, new_statedict, hf_names, config)
- elif transform_type == TransformType.SPLIT_FC1:
- TransformFunc.split_fc1(param, new_statedict, hf_names, config)
- elif transform_type == TransformType.SPLIT_NONE:
- TransformFunc.split_none(param, new_statedict, hf_names)
- else:
- raise NotImplementedError(
- f"Transform type {transform_type} not implemented"
- )
- return new_statedict
-
- @staticmethod
- def extract_layer_info(s):
- pattern = r"layers\.(\d+)\.(.+)"
- match = re.search(pattern, s)
- if match:
- return match.group(1), match.group(2)
- return None, None
-
-
##############################
# tp reshard fn implementation
##############################
@@ -314,6 +95,8 @@ def _gather_pp_group_tensor_and_reshard(
def pp_reshard_fn_qwen2_5(model_state_dict, pp_group, dtype):
+ from megatron.core import parallel_state
+
pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank()
pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
diff --git a/rlinf/utils/utils.py b/rlinf/utils/utils.py
index 449412f8b..b840538cc 100644
--- a/rlinf/utils/utils.py
+++ b/rlinf/utils/utils.py
@@ -20,13 +20,14 @@
from functools import partial, wraps
import torch
-
+import torch.nn.functional as F
+import torch_npu
def clear_memory(sync=True):
if sync:
- torch.cuda.synchronize()
+ torch.npu.synchronize()
gc.collect()
- torch.cuda.empty_cache()
+ torch.npu.empty_cache()
def apply_func_to_dict(func, dictionary):
@@ -53,7 +54,7 @@ def retrieve_model_state_dict_in_cpu(model):
cpu_dict[name] = item
- torch.cuda.synchronize()
+ torch.npu.synchronize()
return cpu_dict
@@ -124,6 +125,65 @@ def seq_mean_token_mean(values, mask):
return loss
+def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True):
+ # The fused flash-attn cross_entropy_loss path is kept for reference but disabled;
+ # logprobs are computed with a plain log_softmax + gather fallback instead.
+ # from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
+ # output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)
+ # assert isinstance(output, tuple), (
+ #     "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
+ # )
+ # return -output[0]
+
+ # Numerically stable log_softmax
+ log_probs = F.log_softmax(logits, dim=-1)
+
+ # Gather the logprob of the target label at each position
+ labels = labels.unsqueeze(-1)
+ return torch.gather(log_probs, -1, labels).squeeze(-1)
+
+
+def compute_logprobs_from_logits(logits, target, task_type="embodied"):
+ if task_type == "embodied":
+ logprobs = -F.cross_entropy(
+ logits, target=target, reduction="none"
+ ) # [B, action-dim]
+ return logprobs
+ batch_dim = logits.shape[:-1]
+ last_dim = logits.shape[-1]
+ logits = logits.reshape(-1, last_dim)
+ labels = target.reshape(-1)
+ logprobs = logprobs_from_logits_flash_attn(
+ logits, labels=labels, inplace_backward=False
+ )
+ logprobs = logprobs.view(*batch_dim)
+ return logprobs
+
+
+def entropy_from_logits(logits: torch.Tensor):
+ """Calculate entropy from logits."""
+ pd = torch.nn.functional.softmax(logits, dim=-1)
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
+ return entropy
+
+
+def compute_entropy_from_logits(logits, epsilon=1e-10, task_type="embodied"):
+ """
+ Compute entropy by logits.
+
+ Args:
+ logits: [B, vocab-size, seq-len]
+ Returns:
+ entropy: [B, seq-len]
+ """
+ if task_type == "embodied":
+ all_probs = F.softmax(logits, dim=1) # [B, vocab-size, seq-len]
+ all_log_probs = torch.log(all_probs + epsilon)
+ entropy = -torch.sum(all_probs * all_log_probs, dim=1) # [B, seq-len]
+ return entropy
+ return entropy_from_logits(logits=logits)
+
+
class DualOutput:
def __init__(self, file, terminal):
self.file = file
diff --git a/rlinf/workers/actor/__init__.py b/rlinf/workers/actor/__init__.py
index 5b365ea1e..2d315469e 100644
--- a/rlinf/workers/actor/__init__.py
+++ b/rlinf/workers/actor/__init__.py
@@ -11,3 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from omegaconf import DictConfig
+
+from rlinf.scheduler.worker.worker import Worker
+
+
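+# Select the actor implementation based on the configured training backend.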
+def get_actor_worker(cfg: DictConfig) -> Worker:
+ if cfg.actor.training_backend == "fsdp":
+ from .fsdp_actor_worker import FSDPActor
+
+ return FSDPActor
+ elif cfg.actor.training_backend == "megatron":
+ from .megatron_actor_worker import MegatronActor
+
+ return MegatronActor
+ else:
+ raise ValueError(f"Unsupported training backend: {cfg.actor.training_backend}")
diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py
index 61b05c9b4..1b3b015bb 100644
--- a/rlinf/workers/actor/fsdp_actor_worker.py
+++ b/rlinf/workers/actor/fsdp_actor_worker.py
@@ -14,31 +14,518 @@
import gc
import os
+from contextlib import nullcontext
+from typing import Dict, Tuple
import numpy as np
import torch
from omegaconf import DictConfig
from torch.distributed.device_mesh import init_device_mesh
+from torch.multiprocessing.reductions import reduce_tensor
from tqdm import tqdm
import rlinf.algorithms # noqa: F401
from rlinf.algorithms.registry import actor_loss, calculate_adv_and_returns
-from rlinf.algorithms.utils import preprocess_advantages_inputs, preprocess_loss_inputs
+from rlinf.algorithms.utils import (
+ kl_penalty,
+ preprocess_advantages_inputs,
+ preprocess_loss_inputs,
+)
+from rlinf.data.io_struct import RolloutResult
from rlinf.hybrid_engines.fsdp.fsdp_model_manager import (
FSDPModelManager,
)
from rlinf.models import get_model
from rlinf.models.embodiment.model_utils import custom_forward
-from rlinf.scheduler import Cluster, Worker
+from rlinf.scheduler import Channel, Cluster, Worker
from rlinf.utils.data_iter_utils import get_iterator_k_split
from rlinf.utils.distributed import all_reduce_dict
+from rlinf.utils.distributed import (
+ compute_rollout_metrics as compute_math_rollout_metrics,
+)
from rlinf.utils.metric_utils import (
append_to_dict,
compute_loss_mask,
compute_rollout_metrics,
compute_split_num,
)
-from rlinf.utils.placement import HybridComponentPlacement
+from rlinf.utils.placement import (
+ HybridComponentPlacement,
+ ModelParallelComponentPlacement,
+)
+from rlinf.utils.utils import (
+ compute_logprobs_from_logits,
+ cpu_weight_swap,
+ masked_mean,
+ retrieve_model_state_dict_in_cpu,
+ seq_mean_token_mean,
+ seq_mean_token_sum,
+)
+from rlinf.workers.rollout.utils import RankMapper
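+# torch_npu registers the torch.npu (Ascend) device backend used by this worker.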
+import torch_npu
+
+class FSDPActor(FSDPModelManager, Worker):
+ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement):
+ Worker.__init__(self)
+ super().__init__(cfg.actor)
+
+ self.cfg = cfg
+
+ self.response_len = (
+ cfg.actor.model.encoder_seq_length - cfg.data.max_prompt_length
+ )
+ self.calculate_entropy = self.cfg.algorithm.calculate_entropy
+ self.calculate_entropy_loss = (
+ self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy
+ )
+ self.kl_beta = self.cfg.algorithm.kl_beta
+ self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type
+
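+ # Each data-parallel actor rank handles rollout_batch_size * group_size / world_size
+ # sequences per training step.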
+ self.total_batch_size_per_dp = (
+ self.cfg.data.rollout_batch_size
+ * self.cfg.algorithm.group_size
+ // self._world_size
+ )
+
+ torch.npu.set_device(int(os.environ["LOCAL_RANK"]))
+ self.device = torch.npu.current_device()
+ world_size = self._world_size
+ self.device_mesh = init_device_mesh(
+ "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
+ )
+
+ self._rollout_group_name = cfg.rollout.group_name
+ self._component_placement = placement
+ self.is_data_io_rank = True
+ self.is_pipeline = self._component_placement.is_disaggregated
+ self.ref_policy_state_dict = None
+
+ if self.cfg.algorithm.loss_agg_func == "token-mean":
+ self.loss_agg_func = masked_mean
+ elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-sum":
+ self.loss_agg_func = seq_mean_token_sum
+ elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-mean":
+ self.loss_agg_func = seq_mean_token_mean
+ else:
+ raise NotImplementedError(
+ f"algorithm.loss_agg_func={self.cfg.algorithm.loss_agg_func} is not supported!"
+ )
+
+ def init_worker(self) -> None:
+ self.setup_model_and_optimizer()
+ if self.cfg.algorithm.kl_beta > 0 and self.cfg.actor.get(
+ "combine_reference_model", True
+ ):
+ self.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(self.model)
+
+ if self.cfg.actor.get("enable_offload", False):
+ self.offload_fsdp_param_and_grad()
+ self.offload_fsdp_optimizer()
+ self._setup_rollout_weight_dst_ranks()
+
+ def _setup_rollout_weight_dst_ranks(self) -> None:
+ """Setup destination ranks for token and weight communication."""
+ rank_map = RankMapper.get_actor_rank_to_rollout_rank_map(
+ self._component_placement
+ )
+ self._weight_dst_rank_in_rollout = rank_map[self._rank]
+ self.log_info(
+ f"Actor rank {self._rank} will send weights to {self._weight_dst_rank_in_rollout}"
+ )
+
+ def del_reshard_state_dict(self) -> None:
+ if hasattr(self, "rollout_state_dict"):
+ del self.rollout_state_dict
+
+ def sync_model_to_rollout(self) -> None:
+ if self.cfg.actor.get("enable_offload", False):
+ self.offload_fsdp_optimizer()
+
+ if next(self.model.parameters()).is_cpu:
+ self.load_fsdp_param_and_grad(self.device)
+ self.rollout_state_dict = self.get_model_state_dict()
+
+ has_visual = any("visual." in k for k in self.rollout_state_dict.keys())
+
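+ # Weights are shared with the collocated rollout worker through torch.multiprocessing
+ # reduction handles (reduce_tensor) instead of full copies, after remapping VLM names.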
+ state_dict = {}
+
+ if self._weight_dst_rank_in_rollout is not None:
+ for k, v in self.rollout_state_dict.items():
+ name = k
+ if has_visual:
+ if name.startswith("model.language_model."):
+ name = "model." + name[21:]
+ # NOTE:
+ # if transformers version is 4.56.1 or older(not tested),
+ # the following line should be uncommented
+
+ # elif name.startswith("model."):
+ # name = name[6:]
+ state_dict[name] = reduce_tensor(v)
+
+ self.send(
+ state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout
+ )
+
+ if self.cfg.actor.get("enable_offload", False):
+ self.offload_fsdp_param_and_grad()
+
+ def compute_logprobs(self) -> None:
+ self.model.eval()
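+ # Reuse the logprobs already stored in the rollout batch instead of recomputing them with a forward pass.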
+ self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"]
+
+ def get_batch(
+ self, channel: Channel
+ ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]:
+ result: RolloutResult = channel.get()
+
+ batch = result.to_actor_batch(
+ self.cfg.data.max_prompt_length,
+ self.cfg.actor.model.encoder_seq_length,
+ self.tokenizer.eos_token_id,
+ )
+ return batch, result
+
+ def put_result(self, result: RolloutResult, channel: Channel) -> None:
+ if channel.is_local:
+ # Local channel, every process will put its own data locally
+ # No need to broadcast
+ channel.put(result)
+ else:
+ if self.is_data_io_rank:
+ channel.put(result)
+
+ def _load_weight_and_optimizer(self, channel: Channel) -> None:
+ # Acquire the GPUs to ensure that no one is using them before loading models
+ # Otherwise, it may lead to OOM
+ with channel.device_lock:
+ if self.cfg.actor.get("enable_offload", False):
+ self.load_fsdp_param_and_grad(self.device)
+ self.load_fsdp_optimizer(self.device)
+
+ @torch.no_grad()
+ def inference_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+ self.model.eval()
+ input_ids = batch["input_ids"]
+ attention_mask = batch["attention_mask"]
+ position_ids = batch["position_ids"]
+
+ multi_modal_inputs = {}
+ if "multi_modal_inputs" in batch.keys():
+ for key in batch["multi_modal_inputs"][0].keys():
+ multi_modal_inputs[key] = torch.cat(
+ [inputs[key] for inputs in batch["multi_modal_inputs"]],
+ dim=0,
+ ).npu()
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ use_cache=False,
+ **multi_modal_inputs,
+ )
+
+ logits = outputs.logits
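+ # Logits at position t predict token t + 1, so take the response_len positions
+ # shifted left by one and divide by the sampling temperature used at rollout.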
+ logits = logits[:, -self.response_len - 1 : -1, :]
+ logits = logits / self.cfg.algorithm.sampling_params.temperature
+
+ responses = input_ids[:, -self.response_len :]
+ logprobs = compute_logprobs_from_logits(
+ logits, responses, task_type=self.cfg.runner.task_type
+ )
+ return logprobs
+
+ def run_inference(
+ self,
+ input_channel: Channel,
+ output_channel: Channel,
+ rollout_channel: Channel,
+ compute_ref_logprobs: bool,
+ ) -> None:
+ """
+ Compute prev/ref logprobs using the actor model's forward pass.
+
+ Args:
+ input_channel: The input channel to read from.
+ output_channel: The output channel to send results to.
+ rollout_channel: The rollout channel whose device lock is acquired to avoid memory collisions.
+ compute_ref_logprobs: Whether to compute reference logprobs.
+ """
+ recv_batch_size = 0
+ while recv_batch_size < self.total_batch_size_per_dp:
+ batch, rollout_result = self.get_batch(input_channel)
+ recv_batch_size += rollout_result.num_sequence
+ self._load_weight_and_optimizer(
+ input_channel if self.is_pipeline else rollout_channel
+ )
+ num_splits = (
+ rollout_result.num_sequence
+ // self.cfg.algorithm.logprob_forward_micro_batch_size
+ )
+ micro_batches_iter = get_iterator_k_split(
+ batch,
+ num_splits=num_splits,
+ )
+ micro_batches = list(micro_batches_iter)
+
+ prev_logprobs = []
+ with self.worker_timer():
+ for micro_batch in micro_batches:
+ prev_logprobs.append(self.inference_step(micro_batch).cpu())
+ rollout_result.prev_logprobs = torch.cat(prev_logprobs)
+ if compute_ref_logprobs:
+ assert self.ref_policy_state_dict is not None, (
+ "Reference policy state dict is None but compute_ref_logprobs is True"
+ )
+ ref_logprobs = []
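+ # Temporarily swap the CPU-stored reference weights into the model so the same
+ # forward pass produces reference logprobs; the trained weights are restored on exit.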
+ with cpu_weight_swap(self.model, self.ref_policy_state_dict):
+ for micro_batch in micro_batches:
+ ref_logprobs.append(self.inference_step(micro_batch).cpu())
+ rollout_result.ref_logprobs = torch.cat(ref_logprobs)
+ self.put_result(rollout_result, output_channel)
+
+ assert recv_batch_size == self.total_batch_size_per_dp, (
+ f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}"
+ )
+
+ def run_training(self, input_channel: Channel) -> Tuple[Dict, list]:
+ # Get all batches for this DP
+ batches = []
+ recv_batch_size = 0
+ while recv_batch_size < self.total_batch_size_per_dp:
+ batch, rollout_result = self.get_batch(input_channel)
+ batches.append(batch)
+ recv_batch_size += rollout_result.num_sequence
+ assert recv_batch_size == self.total_batch_size_per_dp, (
+ f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}"
+ )
+ batch = RolloutResult.merge_batches(batches)
+ # Must be called after batch is retrieved, which is when rollout has stopped
+ # Otherwise, loading model might cause OOM
+ self._load_weight_and_optimizer(input_channel)
+
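+ # Split the merged rollout batch into n_minibatches global batches, optionally shuffled.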
+ global_batches = get_iterator_k_split(
+ batch,
+ num_splits=self.cfg.algorithm.n_minibatches,
+ shuffle=self.cfg.algorithm.get("shuffle_rollout", True),
+ shuffle_seed=self.cfg.actor.seed,
+ )
+
+ self.model.train()
+ assert (
+ self.cfg.actor.global_batch_size
+ % (self.cfg.actor.micro_batch_size * self._world_size)
+ == 0
+ )
+
+ training_metrics_list = []
+ # Global batch iterations
+ with self.worker_timer():
+ for global_batch in global_batches:
+ train_global_batch_size = global_batch["input_ids"].shape[0]
+
+ assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, (
+ f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size=}"
+ )
+
+ self.gradient_accumulation = (
+ train_global_batch_size // self.cfg.actor.micro_batch_size
+ )
+ # split batch into micro_batches
+ train_micro_batches = get_iterator_k_split(
+ global_batch,
+ train_global_batch_size // self.cfg.actor.micro_batch_size,
+ )
+
+ self.optimizer.zero_grad()
+ metrics = {}
+ for idx, m_batch in enumerate(train_micro_batches):
+ backward_ctx = (
+ self.model.no_sync()
+ if idx < self.gradient_accumulation - 1
+ else nullcontext()
+ )
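+ # no_sync() skips gradient synchronization on all but the last micro-batch,
+ # so gradients are only reduced once per accumulation window.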
+ for k, v in m_batch.items():
+ m_batch[k] = v.npu() if isinstance(v, torch.Tensor) else v
+
+ multi_modal_inputs = {}
+ if "multi_modal_inputs" in m_batch.keys():
+ for key in m_batch["multi_modal_inputs"][0].keys():
+ multi_modal_inputs[key] = torch.cat(
+ [
+ inputs[key]
+ for inputs in m_batch["multi_modal_inputs"]
+ ],
+ dim=0,
+ ).npu()
+
+ input_ids = m_batch["input_ids"]
+ attention_mask = m_batch["attention_mask"]
+ position_ids = m_batch["position_ids"]
+ prev_logprobs = m_batch["prev_logprobs"]
+ advantages = m_batch["advantages"]
+ ref_logprobs = None
+ if "ref_logprobs" in m_batch:
+ ref_logprobs = m_batch["ref_logprobs"]
+
+ loss_mask = m_batch["attention_mask"][:, -self.response_len :]
+ output = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ **multi_modal_inputs,
+ use_cache=False,
+ )
+
+ logits = output.logits
+
+ logits.div_(self.cfg.algorithm.sampling_params.temperature)
+
+ responses = input_ids[:, -self.response_len :]
+ logits = logits[
+ :, -self.response_len - 1 : -1, :
+ ] # (bsz, response_length, vocab_size)
+ logprobs = compute_logprobs_from_logits(
+ logits, responses, task_type=self.cfg.runner.task_type
+ )
+
+ clip_ratio = self.cfg.algorithm.ratio_clip_eps
+ clip_ratio_low = (
+ self.cfg.algorithm.clip_ratio_low
+ if self.cfg.algorithm.clip_ratio_low is not None
+ else clip_ratio
+ )
+ clip_ratio_high = (
+ self.cfg.algorithm.clip_ratio_high
+ if self.cfg.algorithm.clip_ratio_high is not None
+ else clip_ratio
+ )
+ clip_ratio_c = self.cfg.algorithm.get("clip_ratio_c", 3.0)
+
+ loss, mbs_metrics_data = actor_loss(
+ loss_type=self.cfg.algorithm.loss_type,
+ loss_agg_func=self.loss_agg_func,
+ logprobs=logprobs,
+ old_logprobs=prev_logprobs,
+ advantages=advantages,
+ clip_ratio_low=clip_ratio_low,
+ clip_ratio_high=clip_ratio_high,
+ clip_ratio_c=clip_ratio_c,
+ loss_mask=loss_mask,
+ )
+
+ entropy_loss = torch.tensor(0.0, device=torch.npu.current_device())
+ if self.calculate_entropy:
+ entropy = output["entropy"][
+ :, -self.response_len - 1 : -1
+ ].contiguous()
+ entropy_loss = self.loss_agg_func(entropy, mask=loss_mask)
+ if self.calculate_entropy_loss:
+ loss = (
+ loss - self.cfg.algorithm.entropy_bonus * entropy_loss
+ )
+
+ kl_loss = torch.tensor(0.0, device=torch.npu.current_device())
+ if self.kl_beta > 0 and ref_logprobs is not None:
+ kld = kl_penalty(ref_logprobs, logprobs, self.kl_penalty_type)
+ kl_loss = self.loss_agg_func(kld, loss_mask)
+ loss = loss + kl_loss * self.kl_beta
+
+ # add to log
+ # scale loss for gradient accumulation and backprop
+ loss = loss / self.gradient_accumulation
+ with backward_ctx:
+ loss.backward()
+
+ mbs_metrics_data.update(
+ {
+ "final_loss": loss.detach(),
+ "entropy_loss": entropy_loss.detach(),
+ "kl_loss": kl_loss.detach(),
+ }
+ )
+
+ append_to_dict(metrics, mbs_metrics_data)
+ # apply gradient clipping and optimizer step at the end of a global batch
+ grad_norm = self.model.clip_grad_norm_(
+ max_norm=self.cfg.actor.optim.clip_grad
+ )
+ if not torch.isfinite(grad_norm).all():
+ self.log_warning(
+ "grad norm is not finite, skip this optimizer step."
+ )
+ else:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ # aggregate metrics across micro-batches
+ mean_metric_dict = {
+ key: torch.mean(torch.stack(value))
+ for key, value in metrics.items()
+ }
+ mean_metric_dict = all_reduce_dict(
+ mean_metric_dict, op=torch.distributed.ReduceOp.AVG
+ )
+ # add optimizer stats
+ if torch.is_tensor(grad_norm):
+ mean_metric_dict["actor/grad_norm"] = float(
+ grad_norm.detach().item()
+ )
+ else:
+ mean_metric_dict["actor/grad_norm"] = float(grad_norm)
+ lr = self.optimizer.param_groups[0]["lr"]
+ mean_metric_dict["actor/lr"] = torch.as_tensor(lr).float().cpu()
+ training_metrics_list.append(mean_metric_dict)
+
+ # Rollout metrics
+ rollout_metrics, _, _ = compute_math_rollout_metrics(
+ batch, self.cfg.data.max_prompt_length, self.response_len, self._world_size
+ )
+
+ return rollout_metrics, training_metrics_list
+
+ def save_checkpoint(self, save_base_path: str, step: int) -> None:
+ torch.distributed.barrier()
+ model_state = self.get_model_state_dict()
+ optim_state = self.get_optimizer_state_dict()
+ if self._rank == 0:
+ os.makedirs(save_base_path, exist_ok=True)
+ torch.save(model_state, os.path.join(save_base_path, "model.pt"))
+ torch.save(optim_state, os.path.join(save_base_path, "optim.pt"))
+ torch.distributed.barrier()
+
+ # Advantages and returns
+ def compute_advantages_and_returns(
+ self, input_channel: Channel, output_channel: Channel
+ ) -> None:
+ """Compute the advantages and returns.
+
+ Args:
+ input_channel: The input channel to read from.
+ output_channel: The output channel to send results to.
+ """
+ recv_batch_size = 0
+ while recv_batch_size < self.total_batch_size_per_dp:
+ batch, rollout_result = self.get_batch(input_channel)
+ recv_batch_size += rollout_result.num_sequence
+
+ with self.worker_timer():
+ if rollout_result.advantages is None:
+ mask = batch["attention_mask"][:, -self.response_len :]
+ advantages, returns = calculate_adv_and_returns(
+ adv_type=self.cfg.algorithm.adv_type,
+ reward_scores=batch["rewards"].npu(),
+ mask=mask.npu(),
+ num_responses=self.cfg.algorithm.group_size,
+ )
+ rollout_result.advantages = advantages.cpu()
+
+ self.put_result(rollout_result, output_channel)
+
+ assert recv_batch_size == self.total_batch_size_per_dp, (
+ f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}"
+ )
class EmbodiedFSDPActor(FSDPModelManager, Worker):
@@ -47,11 +534,11 @@ def __init__(self, cfg: DictConfig):
super().__init__(cfg.actor)
self.cfg = cfg
- torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
- self.device = torch.cuda.current_device()
+ torch.npu.set_device(int(os.environ["LOCAL_RANK"]))
+ self.device = torch.npu.current_device()
world_size = self._world_size
self.device_mesh = init_device_mesh(
- "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
+ "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
)
self._env_group_name = cfg.env.group_name
@@ -79,9 +566,9 @@ def init_worker(self):
if self.cfg.actor.get("enable_offload", False):
self.offload_fsdp_param_and_grad()
self.offload_fsdp_optimizer()
- torch.cuda.synchronize()
+ torch.npu.synchronize()
gc.collect()
- torch.cuda.empty_cache()
+ torch.npu.empty_cache()
def model_provider_func(self):
model = get_model(self.cfg.actor.checkpoint_load_path, self.cfg.actor.model)
@@ -102,10 +589,10 @@ def sync_model_to_rollout(self):
if self.cfg.actor.get("enable_offload", False):
self.offload_fsdp_param_and_grad()
self.offload_fsdp_optimizer()
- torch.cuda.synchronize()
+ torch.npu.synchronize()
del state_dict
gc.collect()
- torch.cuda.empty_cache()
+ torch.npu.empty_cache()
async def recv_rollout_batch(self):
send_num = self._component_placement.get_world_size("rollout") * self.stage_num
@@ -389,7 +876,7 @@ def run_training(self):
metrics_data["loss"] = loss.detach().item()
append_to_dict(metrics, metrics_data)
- torch.cuda.empty_cache()
+ torch.npu.empty_cache()
grad_norm = self.model.clip_grad_norm_(
max_norm=self.cfg.actor.optim.clip_grad
@@ -411,9 +898,9 @@ def run_training(self):
)
self.optimizer.zero_grad()
- torch.cuda.synchronize()
+ torch.npu.synchronize()
torch.distributed.barrier()
- torch.cuda.empty_cache()
+ torch.npu.empty_cache()
return mean_metric_dict
diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py
index 54376e1a5..40289ddf0 100644
--- a/rlinf/workers/actor/megatron_actor_worker.py
+++ b/rlinf/workers/actor/megatron_actor_worker.py
@@ -14,7 +14,7 @@
import copy
from functools import partial
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
import torch
import torch.distributed
@@ -31,7 +31,6 @@
from rlinf.algorithms.registry import (
actor_loss,
calculate_adv_and_returns,
- get_reward_fn,
)
from rlinf.algorithms.utils import kl_penalty
from rlinf.data.io_struct import (
@@ -53,7 +52,6 @@
)
from rlinf.utils.distributed import (
RolloutDataBalance,
- broadcast_tensor_within_mp,
broadcast_tensor_within_pp,
compute_rollout_metrics,
masked_normalization,
@@ -80,7 +78,6 @@
seq_mean_token_sum,
)
from rlinf.workers.rollout.utils import RankMapper
-from toolkits import register_rewards
class MegatronActor(MegatronModelManager, Worker):
@@ -103,6 +100,10 @@ def __init__(
self.cfg = cfg
self.component_placement = placement
+ # check placement validity when actor backend is megatron
+ assert placement.rollout_tp_size <= placement.actor_tp_size, (
+ f" rollout tensor parallel size {placement.rollout_tp_size} must be less than or equal to actor tensor parallel size {placement.actor_tp_size}."
+ )
# Data configurations
self.response_len = (
role_cfg.model.encoder_seq_length - cfg.data.max_prompt_length
@@ -115,10 +116,21 @@ def __init__(
self.calculate_entropy_loss = (
self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy
)
- self.ratio_eps = self.cfg.algorithm.ratio_clip_eps
+ clip_ratio = self.cfg.algorithm.ratio_clip_eps
+ self.clip_ratio_low = (
+ self.cfg.algorithm.get("clip_ratio_low")
+ if self.cfg.algorithm.get("clip_ratio_low") is not None
+ else clip_ratio
+ )
+ self.clip_ratio_high = (
+ self.cfg.algorithm.get("clip_ratio_high")
+ if self.cfg.algorithm.get("clip_ratio_high") is not None
+ else clip_ratio
+ )
self.logprob_forward_micro_batch_size = (
self.cfg.algorithm.logprob_forward_micro_batch_size
)
+
self.kl_beta = self.cfg.algorithm.kl_beta
self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type
self.clip_ratio_c = self.cfg.algorithm.clip_ratio_c
@@ -144,11 +156,6 @@ def __init__(
self.ref_policy_state_dict = None
self.is_pipeline = self.component_placement.is_disaggregated
- # Reward configurations
- if not self.cfg.reward.use_reward_model:
- register_rewards()
- self.reward_fn = get_reward_fn(self.cfg.reward.reward_type)
-
# Rollout configurations
self.rollout_group_name = self.cfg.rollout.group_name
@@ -382,7 +389,8 @@ def loss_func(output):
logprobs=curr_logprobs,
old_logprobs=prev_logprobs,
advantages=advantages,
- eps_clip=self.ratio_eps,
+ clip_ratio_low=self.clip_ratio_low,
+ clip_ratio_high=self.clip_ratio_high,
loss_mask=mask,
)
@@ -830,23 +838,29 @@ def run_inference(
self,
input_channel: Channel,
output_channel: Channel,
+ rollout_channel: Optional[Channel],
compute_ref_logprobs: bool,
):
- """Compute prev/ref logprobs using the actor Model's forward.
+ """
+ Compute prev/ref logprobs using the actor model's forward pass.
Args:
input_channel: The input channel to read from.
output_channel: The output channel to send results to.
+ rollout_channel: The rollout channel whose device lock is acquired to avoid memory collisions.
compute_ref_logprobs: Whether to compute reference logprobs.
"""
recv_batch_size = 0
while recv_batch_size < self.total_batch_size_per_dp:
batch, rollout_result = self.get_batch(input_channel)
recv_batch_size += rollout_result.num_sequence
-
# Must be called after batch is retrieved, suggesting that rollout has stopped
# Otherwise, loading model might cause OOM in the collocated mode
- self._load_weight_and_optimizer(input_channel)
+ self._load_weight_and_optimizer(
+ input_channel
+ if self.is_pipeline or rollout_channel is None
+ else rollout_channel
+ )
# Prev logprobs
with self.worker_timer():
@@ -859,91 +873,12 @@ def run_inference(
with cpu_weight_swap(self.model[0], self.ref_policy_state_dict):
ref_logprobs = self.inference_step(batch)
rollout_result.ref_logprobs = ref_logprobs.cpu()
-
- self.put_result(rollout_result, output_channel)
-
- assert recv_batch_size == self.total_batch_size_per_dp, (
- f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}"
- )
-
- # Rewards
- def compute_rewards(self, input_channel: Channel, output_channel: Channel):
- """Compute rewards.
-
- Args:
- input_channel: The input channel to read from.
- output_channel: The output channel to send results to.
- """
- assert self.reward_fn is not None, "reward_fn is not set"
- if self.is_pipeline:
- # In pipeline mode, rewards are computed in the rollout
- with self.worker_timer():
- return
- recv_batch_size = 0
- while recv_batch_size < self.total_batch_size_per_dp:
- batch, rollout_result = self.get_batch(input_channel)
- recv_batch_size += rollout_result.num_sequence
-
- # Compute rule-based reward
- with self.worker_timer():
- if rollout_result.rewards is None:
- rollout_result.rewards = self._compute_batch_rewards(
- batch, rollout_result.answers
- ).cpu()
-
self.put_result(rollout_result, output_channel)
assert recv_batch_size == self.total_batch_size_per_dp, (
f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}"
)
- def _compute_batch_rewards(
- self, batch: Dict[str, torch.Tensor], answers: List[str]
- ):
- """Reward computation using non-model based reward."""
- all_reward_scores = []
- texts = []
- for response, response_len in zip(
- batch["input_ids"],
- batch["response_lengths"],
- ):
- response = response[
- self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length
- + response_len
- ]
- texts.append(
- self.tokenizer.decode(response.tolist(), skip_special_tokens=True)
- )
-
- if torch.distributed.get_rank() == parallel_state.get_model_parallel_src_rank():
- rewards = self.reward_fn(texts, answers)
- if self.cfg.reward.reward_type == "math":
- reward_scores = [
- self.cfg.reward.reward_scale
- if reward == 1
- else -self.cfg.reward.reward_scale
- for reward in rewards
- ]
- else:
- reward_scores = rewards
-
- all_reward_scores.extend(reward_scores)
-
- if len(all_reward_scores) > 0:
- new_all_rewards = []
-
- for response in all_reward_scores:
- if response is None:
- response = 0.0
- new_all_rewards.append(response)
-
- all_reward_scores = torch.as_tensor(
- new_all_rewards,
- dtype=torch.float,
- device=torch.cuda.current_device(),
- ).view(-1, 1)
- return broadcast_tensor_within_mp(all_reward_scores).flatten().to("cpu")
-
# Advantages and returns
def compute_advantages_and_returns(
self, input_channel: Channel, output_channel: Channel
@@ -954,16 +889,11 @@ def compute_advantages_and_returns(
input_channel: The input channel to read from.
output_channel: The output channel to send results to.
"""
- if self.is_pipeline:
- # In pipeline mode, advantages are computed in the rollout
- with self.worker_timer():
- return
clear_memory()
recv_batch_size = 0
while recv_batch_size < self.total_batch_size_per_dp:
batch, rollout_result = self.get_batch(input_channel)
recv_batch_size += rollout_result.num_sequence
-
with self.worker_timer():
if rollout_result.advantages is None:
mask = batch["attention_mask"][:, -self.response_len :]
@@ -1033,7 +963,11 @@ def sync_model_to_rollout(self):
def _compute_rollout_metrics(self, batch):
rollout_metrics, total_prompt_lengths, total_decode_lengths = (
compute_rollout_metrics(
- batch, self.cfg.data.max_prompt_length, self.response_len
+ batch,
+ self.cfg.data.max_prompt_length,
+ self.response_len,
+ self._world_size,
+ dp_group=parallel_state.get_data_parallel_group(),
)
)
diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py
new file mode 100644
index 000000000..88be65ddc
--- /dev/null
+++ b/rlinf/workers/reward/reward_worker.py
@@ -0,0 +1,104 @@
+# Copyright 2025 The RLinf Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Tuple
+
+import torch
+from omegaconf import DictConfig
+
+from rlinf.algorithms.rewards import get_reward_class
+from rlinf.data.io_struct import RolloutResult
+from rlinf.data.tokenizers import hf_tokenizer
+from rlinf.scheduler import Channel, Worker
+from rlinf.utils.placement import ModelParallelComponentPlacement
+
+
+class RewardWorker(Worker):
+ def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement):
+ Worker.__init__(self)
+ self.cfg = cfg
+ self.component_placement = placement
+ self.tokenizer = hf_tokenizer(cfg.reward.tokenizer.tokenizer_model)
+ self.total_batch_size_per_dp = (
+ self.cfg.data.rollout_batch_size
+ * self.cfg.algorithm.get("group_size", 1)
+ // self._world_size
+ )
+
+ def init_worker(self):
+ if self.cfg.reward.use_reward_model:
+ raise NotImplementedError("Reward model is not implemented yet.")
+ else:
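+ # Rule-based reward: instantiate the reward class registered for this reward_type.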
+ self.reward = get_reward_class(self.cfg.reward.reward_type)(self.cfg.reward)
+
+ def get_batch(
+ self, channel: Channel
+ ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]:
+ result: RolloutResult = channel.get()
+ batch = result.to_actor_batch(
+ self.cfg.data.max_prompt_length,
+ self.cfg.actor.model.encoder_seq_length,
+ self.tokenizer.eos_token_id,
+ )
+ return batch, result
+
+ def compute_rewards(self, input_channel: Channel, output_channel: Channel):
+ """Compute rewards.
+
+ Args:
+ input_channel: The input channel to read from.
+ output_channel: The output channel to send results to.
+ """
+ recv_batch_size = 0
+ while recv_batch_size < self.total_batch_size_per_dp:
+ rollout_result: RolloutResult = input_channel.get()
+ recv_batch_size += rollout_result.num_sequence
+ with self.worker_timer():
+ if rollout_result.rewards is None:
+ if self.cfg.reward.use_reward_model:
+ with input_channel.device_lock:
+ batch = rollout_result.to_actor_batch(
+ self.cfg.data.max_prompt_length,
+ self.cfg.actor.model.encoder_seq_length,
+ self.tokenizer.eos_token_id,
+ )
+ rollout_result.rewards = (
+ self.compute_batch_rewards_with_model(batch)
+ )
+ else:
+ rollout_result.rewards = self._compute_rule_based_rewards(
+ rollout_result
+ )
+
+ output_channel.put(rollout_result)
+
+ assert recv_batch_size == self.total_batch_size_per_dp, (
+ f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}"
+ )
+
+ def _compute_rule_based_rewards(self, rollout_result: RolloutResult):
+ # Decode only the generated tokens; response_ids are already the post-prompt tokens
+ texts = self.tokenizer.batch_decode(
+ rollout_result.response_ids, skip_special_tokens=True
+ )
+
+ scores = self.reward.get_reward(texts, rollout_result.answers)
+ return (
+ torch.as_tensor(scores, dtype=torch.float, device=torch.device("cpu"))
+ .view(-1, 1)
+ .flatten()
+ )
+
+ def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]):
+ raise NotImplementedError("Reward model is not implemented yet.")
diff --git a/rlinf/workers/rollout/sglang/__init__.py b/rlinf/workers/rollout/sglang/__init__.py
index 5e0fb219f..cc78162b9 100644
--- a/rlinf/workers/rollout/sglang/__init__.py
+++ b/rlinf/workers/rollout/sglang/__init__.py
@@ -49,6 +49,12 @@ def get_version(pkg):
from rlinf.hybrid_engines.sglang.sglang_0_4_9.sgl_engine import (
Engine,
)
+elif package_version >= parse("0.5.0") and package_version < parse("0.5.3"):
+ sglang_version = package_version
+ from rlinf.hybrid_engines.sglang.sglang_0_5_2 import io_struct
+ from rlinf.hybrid_engines.sglang.sglang_0_5_2.sgl_engine import (
+ Engine,
+ )
else:
raise ValueError(f"sglang version {package_version} not supported")
diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py
index 8d4a15cb7..a24b297f2 100644
--- a/rlinf/workers/rollout/sglang/sglang_worker.py
+++ b/rlinf/workers/rollout/sglang/sglang_worker.py
@@ -18,7 +18,6 @@
from typing import Dict, List, Optional, Tuple
import numpy as np
-import torch
from omegaconf import DictConfig
from sglang.srt.server_args import ServerArgs
from transformers import AutoTokenizer
@@ -35,7 +34,6 @@
from rlinf.workers.rollout.utils import (
print_sglang_outputs,
)
-from toolkits.math_verifier.verify import MathRewardModel, math_verify_call
class SGLangWorker(Worker):
@@ -115,6 +113,7 @@ def _init_engine(self):
log_level="info",
max_running_requests=self._cfg.rollout.max_running_requests,
dist_init_addr=f"127.0.0.1:{str(Cluster.find_free_port())}",
+ device="npu",
)
self.log_on_first_rank(f"{server_args=}")
@@ -169,7 +168,6 @@ def sync_model_from_actor(self):
def rollout(self, input_channel: Channel, output_channel: Channel):
request: RolloutRequest = input_channel.get()
-
# Repeat prompts based on the group_size config
requests = request.repeat_and_split(self._rollout_batch_size)
@@ -181,6 +179,8 @@ def rollout(self, input_channel: Channel, output_channel: Channel):
with self.worker_timer():
results = self._engine.generate(
input_ids=request.input_ids,
+ # SGLang 0.4.4 has a modality bug and cannot accept non-None image_data, so only pass it when images are present
+ image_data=request.image_data if any(request.image_data) else None,
sampling_params=self._sampling_params,
return_logprob=self._return_logprobs,
)
@@ -191,6 +191,8 @@ def rollout(self, input_channel: Channel, output_channel: Channel):
request.n,
request.input_ids,
request.answers,
+ request.image_data,
+ request.multi_modal_inputs,
self._return_logprobs,
)
rollout_results.append(rollout_result)
@@ -205,8 +207,9 @@ def rollout(self, input_channel: Channel, output_channel: Channel):
self._stop()
# Release the GPUs once the engine has offloaded
output_channel.device_lock.release()
- rollout_result = RolloutResult.merge_result_list(rollout_results)
- output_channel.put(rollout_result)
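+ # Split the merged results per prompt group so downstream reward/actor
+ # workers each receive one result per group.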
+ rollout_result_list = RolloutResult.split_result_list_by_group(rollout_results)
+ for rollout_result in rollout_result_list:
+ output_channel.put(rollout_result)
def all_floats_equal(float_list: list[float], epsilon: float = 1e-9) -> bool:
@@ -229,7 +232,6 @@ def __init__(self, config: DictConfig, placement: ComponentPlacement):
self._rollout_end_event = asyncio.Event()
self._sync_weight_end_event = asyncio.Event()
- self._reward_model = MathRewardModel(scale=self._cfg.reward.reward_scale)
assert self._rollout_batch_size is None, (
"rollout_batch_size_per_gpu is not supported in AsyncSGLangWorker"
)
@@ -258,29 +260,6 @@ async def init_worker(self):
if self._cfg.rollout.validate_weight:
await self._validate_weight_at_first()
- async def _compute_reward_and_advantage(
- self, engine_results: List[Dict], answer: str
- ):
- answers = [answer] * len(engine_results)
- texts: List[str] = []
- for res in engine_results:
- if hasattr(res, "text"):
- texts.append(res["text"])
- else:
- texts.append(
- self._tokenizer.decode(res["output_ids"], skip_special_tokens=True)
- )
-
- results = math_verify_call(texts, answers)
- rewards = [(1 if r else -1) * self._reward_model.scale for r in results]
- rewards_tensor = torch.tensor(rewards, dtype=torch.float)
-
- mean = rewards_tensor.mean()
- std = rewards_tensor.std()
- advantages = (rewards_tensor - mean) / (std + 1e-6)
-
- return rewards, advantages.tolist()
-
async def _async_generate(
self, raw_id: int, input_ids: List[int], sampling_params: dict
):
@@ -319,7 +298,6 @@ async def rollout(self, input_channel: Channel, output_channel: Channel):
total_reqs = len(rollout_tasks)
required_reqs = total_reqs // self._cfg.algorithm.max_num_gen_batches
- droped_reqs = 0
finished_reqs = 0
abort_flag = False
@@ -330,32 +308,17 @@ async def rollout(self, input_channel: Channel, output_channel: Channel):
if self._completion_info.is_completed(hash_id):
results = self._completion_info.get_results(hash_id)
- (
- rewards,
- advantages,
- ) = await self._compute_reward_and_advantage(
- results,
- self._current_request.answers[raw_id],
- )
- if (
- all_floats_equal(rewards)
- and self._cfg.algorithm.get("max_num_gen_batches", 1) > 1
- ):
- if (total_reqs - droped_reqs) > required_reqs:
- droped_reqs += rollout_request.n
- continue
input_ids = [input_ids] * len(results)
+ answers = [rollout_request.answers[raw_id]] * len(results)
rollout_result = RolloutResult.from_sglang_results(
results,
rollout_request.n,
input_ids,
+ answers=answers,
return_logprobs=self._return_logprobs,
)
- rollout_result.rewards = torch.tensor(
- rewards, dtype=torch.float32
- ).reshape(-1, 1)
- rollout_result.advantages = advantages
+
return_tasks.append(
asyncio.create_task(
self._put_result(rollout_result, output_channel)
diff --git a/rlinf/workers/rollout/utils.py b/rlinf/workers/rollout/utils.py
index f3845ef2f..f92e48caa 100644
--- a/rlinf/workers/rollout/utils.py
+++ b/rlinf/workers/rollout/utils.py
@@ -376,6 +376,12 @@ def get_actor_rank_to_rollout_rank_map(
"""
Get the global mapping from actor 1D rank to rollout 2D rank as dict.
"""
+ # rank -> (dp, tp)
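+ # Fast path: with an unsharded actor (tp == 1), actor rank r maps directly to
+ # rollout rank (r // rollout_tp_size, r % rollout_tp_size).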
+ if actor_tp_size == 1:
+ return {
+ rank: (rank // rollout_tp_size, rank % rollout_tp_size)
+ for rank in range(actor_world_size)
+ }
rank_map = {}
for actor_rank in range(actor_world_size):
rank_map[actor_rank] = cls._get_actor_rank_to_rollout_rank(
diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py
index d5ecd11fc..7324fe841 100644
--- a/rlinf/workers/rollout/vllm/vllm_worker.py
+++ b/rlinf/workers/rollout/vllm/vllm_worker.py
@@ -19,9 +19,8 @@
from typing import AsyncGenerator, List, Optional, Union
import requests
-import torch
from omegaconf import DictConfig
-from PIL.Image import Image
+from PIL import Image
from transformers import AutoTokenizer
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
@@ -36,7 +35,6 @@
from rlinf.scheduler import Channel, Worker
from rlinf.utils.placement import ComponentPlacement
from rlinf.workers.rollout.utils import print_vllm_outputs
-from toolkits.math_verifier.verify import MathRewardModel, math_verify_call
from . import VLLMExecutor
@@ -69,7 +67,6 @@ def __init__(self, config: DictConfig, placement: ComponentPlacement):
"The capital of France is",
"The future of AI is",
]
- self._reward_model = MathRewardModel(self._cfg.reward.reward_scale)
self.request_counter = Counter()
def _prepare_vllm_environment(self) -> None:
@@ -227,6 +224,23 @@ async def generate(
Union[List[List[Union[bytes, str]]], List[Union[bytes, str]]]
] = None,
) -> List[RequestOutput]:
+ """
+ Run a generation task using the vLLM async engine.
+
+ Args:
+ input_ids: The input token ids to generate. It can be a list of list of int,
+ or a list of int (single prompt).
+ sampling_params: The sampling parameters to use for generation.
+ prompt_texts: The input prompt texts to generate. It can be a list of strings
+ or a single string. If provided, it will be used instead of input_ids.
+ image_data: Optional image data for the prompts. It can be a list of lists
+ of bytes or image paths (local or URL), or a single list of bytes or
+ image paths (single prompt).
+
+ Returns:
+ List[RequestOutput]: A list of RequestOutput from vllm engine.
+ """
+
def check_input_ids() -> List[List[int]]:
assert isinstance(input_ids, list), (
"input_ids should be a list or list of list of int."
@@ -249,8 +263,8 @@ def check_prompt_text() -> Optional[List[str]]:
assert len(prompt_texts) > 0, "prompt_text should not be empty."
return prompt_texts
- def check_image_data() -> Optional[List[List[Image]]]:
- if image_data is None:
+ def check_image_data() -> Optional[List[List[Image.Image]]]:
+ if image_data is None or not any(image_data):
return None
assert isinstance(image_data, list), "image_data should be a list."
if isinstance(image_data[0], list):
@@ -267,19 +281,22 @@ def check_image_data() -> Optional[List[List[Image]]]:
if prompt_texts is not None:
for i, prompt_text in enumerate(prompt_texts):
if image_list is not None:
- image_list = self._process_image_data(image_data=image_list[i])
+ images = self._process_image_data(image_data=image_list[i])
inputs.append(
- TextPrompt(prompt=prompt_text, multi_modal_data=image_list)
+ TextPrompt(
+ prompt=prompt_text, multi_modal_data={"image": images}
+ )
)
else:
inputs.append(TextPrompt(prompt=prompt_text))
else:
for i, input_id in enumerate(input_ids):
if image_list is not None:
- image_list = self._process_image_data(image_data=image_list[i])
+ images = self._process_image_data(image_data=image_list[i])
inputs.append(
TokensPrompt(
- prompt_token_ids=input_id, multi_modal_data=image_list
+ prompt_token_ids=input_id,
+ multi_modal_data={"image": images},
)
)
else:
@@ -303,7 +320,8 @@ def check_image_data() -> Optional[List[List[Image]]]:
async def init_worker(self) -> None:
"""
Use EngineArgs and VllmConfig to initialize VLLM async engine.
- Then offload the model weights, ready to use weights sent from actor.
+ In collocated mode, it additionally offloads the model weights so the engine
+ is ready to use the parameters sent from the actor.
"""
engine_args: EngineArgs = EngineArgs(
model=self._cfg.rollout.model_dir,
@@ -319,7 +337,8 @@ async def init_worker(self) -> None:
trust_remote_code=self._cfg.actor.tokenizer.trust_remote_code,
max_model_len=self._cfg.runner.seq_length,
max_num_seqs=self._cfg.rollout.max_running_requests,
- enable_sleep_mode=True, # it enables offload weights
+ enable_sleep_mode=False,
+ device="npu",
)
vllm_config: VllmConfig = engine_args.create_engine_config()
@@ -350,7 +369,21 @@ async def init_worker(self) -> None:
await self.offload_model_weights()
async def _put_result(self, result: RolloutResult, output_channel: Channel) -> None:
- await output_channel.put(result, async_op=True).async_wait()
+ """
+ Helper function to put the result to the output channel.
+
+ Args:
+ result: The RolloutResult to put to the channel.
+ output_channel: The output channel to send results to.
+ """
+ # NOTE:
+ # The reward and actor workers each expect one result per prompt group,
+ # so the merged result is split into per-group results before sending.
+ split_results = RolloutResult.split_result_list_by_group([result])
+ put_tasks = [
+ output_channel.put(r, async_op=True).async_wait() for r in split_results
+ ]
+ await asyncio.gather(*put_tasks)
async def _stop(self) -> None:
"""
@@ -363,50 +396,43 @@ async def _stop(self) -> None:
if not self._placement.is_disaggregated:
await self.offload_model_weights()
- async def _compute_reward_and_advantage(self, rollout_result: RolloutResult):
- """
- Compute rewards and advantages for the rollout result using math verification.
- """
- answers = rollout_result.answers
- outputs = rollout_result.response_texts
- num_sequence = rollout_result.num_sequence
- assert len(answers) == len(outputs), (
- f"Answers length {len(answers)} != outputs length {len(outputs)}"
- )
- assert len(answers) == num_sequence, (
- f"Answers length {len(answers)} != num_sequence {num_sequence}"
- )
-
- math_verify_results = math_verify_call(outputs, answers)
- rewards = [
- (1 if r else -1) * self._reward_model.scale for r in math_verify_results
- ]
- rewards_tensor = torch.tensor(rewards, dtype=torch.float)
- rollout_result.rewards = rewards_tensor.reshape(-1, 1)
-
- mean = rewards_tensor.mean()
- std = rewards_tensor.std(unbiased=False)
- advantages = (rewards_tensor - mean) / (std + 1e-6)
- rollout_result.advantages = advantages.tolist()
-
async def rollout_and_return(
self, request: RolloutRequest, output_channel: Channel
):
+ """
+ Helper that performs rollout for a single RolloutRequest, builds a RolloutResult,
+ and puts it to the output channel.
+
+ Args:
+ request: The RolloutRequest to process.
+ output_channel: The output channel to send results to.
+ """
vllm_results: List[RequestOutput] = await self.generate(
- input_ids=request.input_ids, sampling_params=self._sampling_params
+ input_ids=request.input_ids,
+ image_data=request.image_data,
+ sampling_params=self._sampling_params,
)
rollout_result: RolloutResult = RolloutResult.from_vllm_results(
group_size=self._cfg.algorithm.group_size,
results=vllm_results,
answers=request.answers,
+ multi_modal_inputs=request.multi_modal_inputs,
return_logprobs=self._return_logprobs,
)
- if self._placement.is_disaggregated:
- await self._compute_reward_and_advantage(rollout_result)
-
+ if self._cfg.rollout.print_outputs:
+ print_vllm_outputs(outputs=vllm_results)
await self._put_result(result=rollout_result, output_channel=output_channel)
async def rollout(self, input_channel: Channel, output_channel: Channel) -> None:
+ """
+ Perform rollout using the vLLM engine.
+ It will read `RolloutRequest` from input_channel and put `RolloutResult` to output_channel.
+ If the input request is None, it will stop the rollout.
+
+ Args:
+ input_channel: The input channel to read from.
+ output_channel: The output channel to send results to.
+ """
rollout_request: RolloutRequest = await input_channel.get(
async_op=True
).async_wait()
diff --git a/test.py b/test.py
new file mode 100644
index 000000000..485f51f99
--- /dev/null
+++ b/test.py
@@ -0,0 +1,3 @@
+from safetensors import safe_open
+with safe_open("/home/weight/Qwen2.5-VL-3B-Instruct/model-00001-of-00002.safetensors", framework="pt") as f:
+ print(f.keys())
diff --git a/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml b/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml
index 6555cb9bf..d1c65161b 100644
--- a/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml
+++ b/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 1
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
@@ -119,6 +119,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 1024
filter_prompt_by_length: True
rollout_batch_size: 128
@@ -256,13 +257,20 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
critic:
use_critic_model: false
+
profile_data:
actor_cost: 95.7
inference_cost: 30.8
diff --git a/tests/e2e_tests/auto_placement/run_auto_placement.sh b/tests/e2e_tests/auto_placement/run.sh
similarity index 100%
rename from tests/e2e_tests/auto_placement/run_auto_placement.sh
rename to tests/e2e_tests/auto_placement/run.sh
diff --git a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml
index 1de266648..f34ee4cae 100644
--- a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml
+++ b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml
@@ -12,6 +12,7 @@ cluster:
rollout: 0-3
inference: 4-5
actor: 6-7
+ reward: 0-3
runner:
task_type: coding_online_rl
@@ -282,7 +283,7 @@ actor:
reward:
use_reward_model: False
- reward_type: fim_verify_call
+ reward_type: code
reward_scale: 5.0
critic:
diff --git a/tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh b/tests/e2e_tests/coding_online_rl/run.sh
similarity index 100%
rename from tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh
rename to tests/e2e_tests/coding_online_rl/run.sh
diff --git a/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml
index 2365b5a86..2672dcd01 100644
--- a/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml
+++ b/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml
@@ -157,6 +157,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
reward:
use_reward_model: False
diff --git a/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml
index ddfe4a500..6dc7893d3 100644
--- a/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml
+++ b/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml
@@ -156,6 +156,12 @@ actor:
trust_remote_code: True
padding_side: "right"
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml
index ab384947a..f26ce7c3d 100644
--- a/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml
+++ b/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml
@@ -155,6 +155,12 @@ actor:
adam_eps: 1.0e-05
clip_grad: 10.0
+ fsdp:
+ forward_prefetch: False
+ limit_all_gathers: False
+ backward_prefetch: False
+ use_orig_params: False
+
reward:
use_reward_model: False
diff --git a/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml b/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml
index 2063ab483..5d9258743 100644
--- a/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml
+++ b/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml
@@ -165,6 +165,12 @@ actor:
adam_eps: 1.0e-05
clip_grad: 1.0
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
reward:
use_reward_model: False
diff --git a/tests/e2e_tests/math/sglang/run_collocated.sh b/tests/e2e_tests/math/sglang/run_collocated.sh
deleted file mode 100644
index 1610f8fac..000000000
--- a/tests/e2e_tests/math/sglang/run_collocated.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#! /bin/bash
-set -x
-
-tabs 4
-export VLLM_ATTENTION_BACKEND=XFORMERS
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export TOKENIZERS_PARALLELISM=false
-
-export PYTHONPATH=${REPO_PATH}:$PYTHONPATH
-
-if [ -z "$1" ]; then
- CONFIG_NAME="qwen2.5-1.5b-grpo-collocated"
-else
- CONFIG_NAME=$1
-fi
-
-python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME
\ No newline at end of file
diff --git a/tests/e2e_tests/math/sglang/run_pipeline.sh b/tests/e2e_tests/math/sglang/run_pipeline.sh
deleted file mode 100644
index 85e2e5c2d..000000000
--- a/tests/e2e_tests/math/sglang/run_pipeline.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#! /bin/bash
-set -x
-
-tabs 4
-export VLLM_ATTENTION_BACKEND=XFORMERS
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export TOKENIZERS_PARALLELISM=false
-
-export PYTHONPATH=${REPO_PATH}:$PYTHONPATH
-
-if [ -z "$1" ]; then
- CONFIG_NAME="qwen2.5-1.5b-grpo-pipeline"
-else
- CONFIG_NAME=$1
-fi
-
-python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME
\ No newline at end of file
diff --git a/tests/e2e_tests/math/vllm/run_collocated.sh b/tests/e2e_tests/math/vllm/run_collocated.sh
deleted file mode 100644
index b4e924b1d..000000000
--- a/tests/e2e_tests/math/vllm/run_collocated.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#! /bin/bash
-set -x
-
-tabs 4
-export VLLM_ATTENTION_BACKEND=XFORMERS
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export TOKENIZERS_PARALLELISM=false
-
-export PYTHONPATH=${REPO_PATH}:$PYTHONPATH
-
-if [ -z "$1" ]; then
- CONFIG_NAME="qwen2.5-1.5b-grpo-collocated"
-else
- CONFIG_NAME=$1
-fi
-
-python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME
\ No newline at end of file
diff --git a/tests/e2e_tests/math/vllm/run_pipeline.sh b/tests/e2e_tests/math/vllm/run_pipeline.sh
deleted file mode 100644
index 0a21368f6..000000000
--- a/tests/e2e_tests/math/vllm/run_pipeline.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#! /bin/bash
-set -x
-
-tabs 4
-export VLLM_ATTENTION_BACKEND=XFORMERS
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export TOKENIZERS_PARALLELISM=false
-
-export PYTHONPATH=${REPO_PATH}:$PYTHONPATH
-
-if [ -z "$1" ]; then
- CONFIG_NAME="qwen2.5-1.5b-grpo-pipeline"
-else
- CONFIG_NAME=$1
-fi
-
-python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME
\ No newline at end of file
diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml
new file mode 100644
index 000000000..86c6f362c
--- /dev/null
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml
@@ -0,0 +1,219 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: /workspace/results/
+ project_name: rlinf
+ experiment_name: "ci-test"
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 1
+ max_steps: 3
+
+ val_check_interval: 1
+ save_interval: -1
+
+ seq_length: 1024
+
+ enable_dynamic_batch_size: True
+ max_tokens_per_mbs: 1024
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: /workspace/results
+algorithm:
+ group_size: 2
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: False
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ model_arch: qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: sglang # [sglang, vllm]
+
+ sglang:
+ attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: math
+ dataset_name: boba
+ max_prompt_length: 256
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ shuffle: True
+ validation_shuffle: True
+ seed: 1
+ train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+ val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+
+ optim:
+ optimizer: adam
+ bf16: True
+ fp16: False
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'math'
+ reward_scale: 5.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml
new file mode 100644
index 000000000..dbf7925f2
--- /dev/null
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml
@@ -0,0 +1,219 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: /workspace/results/
+ project_name: rlinf
+ experiment_name: "ci-test"
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 1
+ max_steps: 3
+
+ val_check_interval: 1
+ save_interval: -1
+
+ seq_length: 1024
+
+ enable_dynamic_batch_size: True
+ max_tokens_per_mbs: 1024
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: /workspace/results
+algorithm:
+ group_size: 2
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: True
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ model_arch: qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: sglang # [sglang, vllm]
+
+ sglang:
+ attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: math
+ dataset_name: boba
+ max_prompt_length: 256
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ shuffle: True
+ validation_shuffle: True
+ seed: 1
+ train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+ val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+
+ optim:
+ optimizer: adam
+ bf16: True
+ fp16: False
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'math'
+ reward_scale: 5.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
\ No newline at end of file
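Note on the ${subtract:...} and ${not:...} interpolations used in these configs (max_new_tokens and return_logprobs): they are not built-in OmegaConf resolvers, so the training entry point is assumed to register equivalents before the config is composed. A minimal Python sketch of such a registration, with hypothetical placement and purely illustrative values:

from omegaconf import OmegaConf

# Hypothetical resolver registration; the repo's entry point is assumed to
# provide equivalents so that interpolations like
#   max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
#   return_logprobs: ${not:${algorithm.recompute_logprobs}}
# resolve when the config is accessed.
OmegaConf.register_new_resolver("subtract", lambda a, b: a - b)
OmegaConf.register_new_resolver("not", lambda x: not x)

cfg = OmegaConf.create("max_new_tokens: ${subtract:1024, 256}\nreturn_logprobs: ${not:true}\n")
assert cfg.max_new_tokens == 768      # 1024 - 256
assert cfg.return_logprobs is False   # recompute_logprobs: true -> engine need not return logprobs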
diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml
new file mode 100644
index 000000000..e85fd4146
--- /dev/null
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml
@@ -0,0 +1,219 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: /workspace/results/
+ project_name: rlinf
+ experiment_name: "ci-test"
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 1
+ max_steps: 3
+
+ val_check_interval: 1
+ save_interval: -1
+
+ seq_length: 1024
+
+ enable_dynamic_batch_size: True
+ max_tokens_per_mbs: 1024
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: /workspace/results
+algorithm:
+ group_size: 2
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: False
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ model_arch: qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: vllm # [sglang, vllm]
+
+ sglang:
+ attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: math
+ dataset_name: boba
+ max_prompt_length: 256
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ shuffle: True
+ validation_shuffle: True
+ seed: 1
+ train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+ val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+
+ optim:
+ optimizer: adam
+ bf16: True
+ fp16: False
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'math'
+ reward_scale: 5.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml
new file mode 100644
index 000000000..1aab4cb7b
--- /dev/null
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml
@@ -0,0 +1,219 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: /workspace/results/
+ project_name: rlinf
+ experiment_name: "ci-test"
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 1
+ max_steps: 3
+
+ val_check_interval: 1
+ save_interval: -1
+
+ seq_length: 1024
+
+ enable_dynamic_batch_size: True
+ max_tokens_per_mbs: 1024
+
+ resume_dir: null
+ experiment_name: grpo-1.5b
+ output_dir: /workspace/results
+algorithm:
+ group_size: 2
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: True
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ model_arch: qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: vllm # [sglang, vllm]
+
+ sglang:
+ attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: math
+ dataset_name: boba
+ max_prompt_length: 256
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ shuffle: True
+ validation_shuffle: True
+ seed: 1
+ train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+ val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+
+ optim:
+ optimizer: adam
+ bf16: True
+ fp16: False
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'math'
+ reward_scale: 5.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
\ No newline at end of file
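As a quick sanity check on how the batch-size knobs in these CI configs relate to each other, assuming the common GRPO convention that each prompt is expanded group_size times and the resulting rollouts are split into n_minibatches (the exact splitting logic lives in the trainer and may differ in detail):

# Hypothetical arithmetic check; values mirror the config above.
rollout_batch_size = 16   # data.rollout_batch_size: prompts sampled per step
group_size = 2            # algorithm.group_size: responses per prompt
n_minibatches = 4         # algorithm.n_minibatches
global_batch_size = 8     # actor.global_batch_size

total_rollouts = rollout_batch_size * group_size   # 32 samples per step
per_minibatch = total_rollouts // n_minibatches    # 8
assert per_minibatch == global_batch_size          # assumed consistency condition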
diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml
similarity index 93%
rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml
index 7e2c5164a..4edc14979 100644
--- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 1
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -84,11 +84,11 @@ rollout:
model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B
model_arch: qwen2.5
- enforce_eager: False # if False, vllm will capture cuda graph, which will take more time to initialize.
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
distributed_executor_backend: mp # ray or mp
disable_log_stats: False
detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
- padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for vllm rollout
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
eos: null # will be tokenizer.eos_token_id if null.
rollout_backend: sglang # [sglang, vllm]
@@ -114,13 +114,14 @@ rollout:
validate_weight: False # whether to send all weights at first for weight comparison.
validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
- print_outputs: False # whether to print the outputs (token ids, texts, etc.) of inference engine.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
- max_running_requests: 64 # the maximum number of running requests in the inference engine.
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
data:
type: math
+ dataset_name: boba
max_prompt_length: 256
filter_prompt_by_length: True
rollout_batch_size: 8
@@ -259,9 +260,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml
similarity index 95%
rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml
index 1dfe47aeb..1854fd8c3 100644
--- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 1
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -107,7 +107,6 @@ rollout:
max_num_batched_tokens: null # the maximum number of tokens to be batched together in vllm. If set to null, vllm will use its default value.
torch_profiler_dir: null # if not null, vllm will enable torch profiler and save the result to the specified directory.
-
return_logprobs: ${not:${algorithm.recompute_logprobs}}
tensor_parallel_size: 1
@@ -117,15 +116,12 @@ rollout:
validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
- sglang_decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
max_running_requests: 64 # the maximum number of running requests in the rollout engine.
cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
- use_torch_compile: False # enable torch_compile in SGLang for rollout.
- torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
-
data:
type: math
+ dataset_name: boba
max_prompt_length: 256
filter_prompt_by_length: True
rollout_batch_size: 8
@@ -264,9 +260,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml
similarity index 96%
rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml
index fe61ab16c..edeaee9c3 100644
--- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 1
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -121,6 +121,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 256
filter_prompt_by_length: True
rollout_batch_size: 8
@@ -259,9 +260,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml
similarity index 96%
rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml
index 099fe7268..09df84ca8 100644
--- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml
@@ -9,10 +9,10 @@ hydra:
cluster:
num_nodes: 1
component_placement:
- actor,rollout: all
+ actor,rollout,reward: all
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -122,6 +122,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 256
filter_prompt_by_length: True
rollout_batch_size: 8
@@ -260,9 +261,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml
similarity index 96%
rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml
index 25344d0bf..34bcff492 100644
--- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml
@@ -11,9 +11,10 @@ cluster:
component_placement:
rollout: 0-3
actor: 4-7
+ reward: 0-3
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -124,6 +125,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 256
rollout_batch_size: 8
val_rollout_batch_size: null
@@ -257,11 +259,17 @@ actor:
schedule_repeat: 1 # inference and training will repeat such times
# schedule_wait: it will be set at runtime
-
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml
similarity index 96%
rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml
index d5cd2b4a4..b48eb6057 100644
--- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml
@@ -12,9 +12,10 @@ cluster:
rollout: 0-3
inference: 4-5
actor: 6-7
+ reward: 0-3
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -138,6 +139,7 @@ rollout:
data:
type: math
+ dataset_name: boba
max_prompt_length: 256
rollout_batch_size: 8
val_rollout_batch_size: null
@@ -273,9 +275,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml
similarity index 96%
rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml
index 62d0c5247..111969ff4 100644
--- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml
@@ -12,9 +12,10 @@ cluster:
component_placement:
rollout: 0-3
actor: 4-7
+ reward: 0-3
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -66,7 +67,7 @@ algorithm:
calculate_entropy: True
clip_ratio_c: null # 3.0
- adv_type: grpo
+ adv_type: math_grpo
normalize_advantages: False
early_stop_imp_ratio: 5.0
use_valid_token_scale: True
@@ -260,9 +261,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml
similarity index 96%
rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml
rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml
index 3f7821587..705757c31 100644
--- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml
+++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml
@@ -13,9 +13,10 @@ cluster:
rollout: 0-3
inference: 4-5
actor: 6-7
+ reward: 0-3
runner:
- task_type: math
+ task_type: reasoning
logger:
log_path: /workspace/results/
project_name: rlinf
@@ -67,7 +68,7 @@ algorithm:
calculate_entropy: True
clip_ratio_c: null # 3.0
- adv_type: grpo
+ adv_type: math_grpo
normalize_advantages: False
early_stop_imp_ratio: 5.0
use_valid_token_scale: True
@@ -274,9 +275,16 @@ actor:
reward:
+ group_name: "RewardGroup"
use_reward_model: false
reward_type: 'math'
reward_scale: 5.0
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
critic:
use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml
new file mode 100644
index 000000000..ddc087f99
--- /dev/null
+++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml
@@ -0,0 +1,228 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: ${runner.output_dir}/${runner.experiment_name}
+ project_name: rlinf
+ experiment_name: ${runner.experiment_name}
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 1
+ max_steps: 2
+
+ val_check_interval: 1
+ save_interval: -1
+
+ seq_length: 2048
+
+ enable_dynamic_batch_size: False
+ max_tokens_per_mbs: 28672
+
+ resume_dir: null
+ experiment_name: grpo-qwen2.5-vl-3b
+ output_dir: /workspace/results
+
+algorithm:
+ group_size: 8
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: False
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+ model_arch: qwen2.5_vl #qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: sglang # choose the rollout backend; supported: [sglang, vllm]
+
+ sglang:
+ attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 2
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: vision_language
+ dataset_name: robo2vlm
+ max_prompt_length: 1024
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ image_keys: ["image"] # some vlm datasets may have multiple image columns
+ choice_key: "choices"
+ answer_key: "answer"
+ solution_key: "solution"
+ use_chat_template: True
+ lazy_loading: True
+ shuffle: True
+ validation_shuffle: True
+ seed: 1234
+ train_data_paths: ["/workspace/dataset/robo2vlm-1/data/train-00000-of-00262.parquet"]
+ val_data_paths: ["/workspace/dataset/robo2vlm-1/data/test-00000-of-00003.parquet"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+
+ model_arch: ${rollout.model_arch}
+
+ optim:
+ optimizer: adam
+ bf16: True #False
+ fp16: False #True
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'vqa'
+ reward_scale: 1.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml
new file mode 100644
index 000000000..86c64ebe0
--- /dev/null
+++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml
@@ -0,0 +1,228 @@
+defaults:
+ - override hydra/job_logging: stdout
+
+hydra:
+ run:
+ dir: .
+ output_subdir: null
+
+cluster:
+ num_nodes: 1
+ component_placement:
+ actor,rollout,reward: all
+
+runner:
+ task_type: reasoning
+ logger:
+ log_path: ${runner.output_dir}/${runner.experiment_name}
+ project_name: rlinf
+ experiment_name: ${runner.experiment_name}
+ logger_backends: ["tensorboard"] # wandb, swanlab
+
+ max_epochs: 1
+ max_steps: 2
+
+ val_check_interval: 1
+ save_interval: -1
+
+ seq_length: 2048
+
+ enable_dynamic_batch_size: False
+ max_tokens_per_mbs: 28672
+
+ resume_dir: null
+ experiment_name: grpo-qwen2.5-vl-3b
+ output_dir: /workspace/results
+
+algorithm:
+ group_size: 8
+
+ n_minibatches: 4
+ training_batch_size_per_gpu: 1 # micro batch size
+ rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory.
+
+ # mbs to do log prob inference, can be set to
+ # lower than rollout_batch_size_per_gpu to reduce
+ # memory usage
+ logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+ # val rollout mbs
+ val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+ recompute_logprobs: False
+ shuffle_rollout: False
+
+ # GRPO loss params
+ loss_type: math_ppo_actor
+ loss_agg_func: "token-mean"
+ kl_beta: 0.0 # 0.001
+ kl_penalty_type: low_var_kl
+ ratio_clip_eps: 0.2
+ entropy_bonus: 0.0
+ calculate_entropy: False
+ clip_ratio_c: null # 3.0
+ clip_ratio_low: null
+ clip_ratio_high: null
+
+ adv_type: math_grpo
+ normalize_advantages: True
+ early_stop_imp_ratio: 5.0
+ use_valid_token_scale: False
+
+ # params for rollout
+ sampling_params:
+ use_greedy: False
+ temperature: 1.0
+ top_k: 1000000
+ top_p: 1.0
+ repetition_penalty: 1.0
+ max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+ min_new_tokens: 1
+
+rollout:
+ group_name: "RolloutGroup"
+
+ gpu_memory_utilization: 0.55
+
+ model_dir: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+ model_arch: qwen2.5_vl #qwen2.5
+ enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+ distributed_executor_backend: mp # ray or mp
+ disable_log_stats: False
+ detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+ padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+ eos: null # will be tokenizer.eos_token_id if null.
+
+ rollout_backend: vllm # choose the rollout backend; supported: [sglang, vllm]
+
+ sglang:
+ attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+ decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+ use_torch_compile: False # enable torch_compile in SGLang for rollout.
+ torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+ vllm:
+ attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+ enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+ enable_prefix_caching: True # enable vllm to use prefix_caching.
+ enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+ return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+ tensor_parallel_size: 2
+ pipeline_parallel_size: 1
+
+ validate_weight: False # whether to send all weights at first for weight comparison.
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+ print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+ max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+ cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
+
+data:
+ type: vision_language
+ dataset_name: robo2vlm
+ max_prompt_length: 1024
+ filter_prompt_by_length: True
+ rollout_batch_size: 16
+ val_rollout_batch_size: null
+ num_workers: 2
+ prompt_key: prompt
+ image_keys: ["image"] # some vlm datasets may have multiple image columns
+ choice_key: "choices"
+ answer_key: "answer"
+ solution_key: "solution"
+ use_chat_template: True
+ lazy_loading: True
+ shuffle: True
+ validation_shuffle: True
+ seed: 1234
+ train_data_paths: ["/workspace/dataset/robo2vlm-1/data/train-00000-of-00262.parquet"]
+ val_data_paths: ["/workspace/dataset/robo2vlm-1/data/test-00000-of-00003.parquet"]
+
+actor:
+ group_name: "ActorGroup"
+ training_backend: fsdp
+ mcore_gpt: True
+ spec_name: decoder_gpt
+
+ enable_offload: True
+ checkpoint_load_path: null
+
+ global_batch_size: 8
+ micro_batch_size: 1
+
+ enable_dp_load_balance: False
+
+ calculate_flops: False
+
+ seed: 1234
+
+ model:
+ precision: bf16
+ sharding_strategy: full_shard
+ is_lora: False
+
+ seq_length: ${runner.seq_length}
+ encoder_seq_length: ${runner.seq_length}
+ model_path: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+
+ model_arch: ${rollout.model_arch}
+
+ optim:
+ optimizer: adam
+ bf16: True #False
+ fp16: False #True
+ lr: 2e-05
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-05
+ min_lr: 2.0e-6
+ weight_decay: 0.05
+ use_distributed_optimizer: True
+ overlap_grad_reduce: False
+ overlap_param_gather: False
+ optimizer_enable_pin: false
+ overlap_param_gather_with_optimizer_step: False
+ clip_grad: 0.8
+ loss_scale: 65536
+
+ lr_sched:
+ lr_warmup_fraction: 0.01
+ lr_warmup_init: 0.0
+ lr_warmup_iters: 0
+ max_lr: 2.0e-5
+ min_lr: 0.0
+ lr_decay_style: constant
+ lr_decay_iters: 10
+
+ tokenizer:
+ tokenizer_model: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+ use_fast: False
+ trust_remote_code: True
+ padding_side: 'right'
+
+ fsdp:
+ forward_prefetch: True
+ limit_all_gathers: True
+ backward_prefetch: True
+ use_orig_params: True
+
+reward:
+ group_name: "RewardGroup"
+ use_reward_model: false
+ reward_type: 'vqa'
+ reward_scale: 1.0
+ reward_weights:
+ qa_accuracy: 1.0
+ think_format: 0.0
+ answer_format: 0.0
+
+ tokenizer:
+ tokenizer_model: ${actor.tokenizer.tokenizer_model}
+ use_fast: ${actor.tokenizer.use_fast}
+ trust_remote_code: ${actor.tokenizer.trust_remote_code}
+ padding_side: ${actor.tokenizer.padding_side}
+
+critic:
+ use_critic_model: false
\ No newline at end of file
diff --git a/tests/e2e_tests/reasoning/run.sh b/tests/e2e_tests/reasoning/run.sh
new file mode 100644
index 000000000..92e43866f
--- /dev/null
+++ b/tests/e2e_tests/reasoning/run.sh
@@ -0,0 +1,16 @@
+#! /bin/bash
+set -x
+
+tabs 4
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=${REPO_PATH}:$PYTHONPATH
+
+if [ -z "$1" ]; then
+ echo "Please provide a config name as the first argument."
+ exit 1
+else
+ CONFIG_NAME=$1
+fi
+python "${REPO_PATH}/examples/reasoning/main_grpo.py" --config-path "${REPO_PATH}/tests/e2e_tests/reasoning/" --config-name "${CONFIG_NAME}"
\ No newline at end of file
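The new run.sh drives the trainer through Hydra's standard --config-path and --config-name flags, so examples/reasoning/main_grpo.py is assumed to be a Hydra application. A minimal, hypothetical sketch of such an entry point (the real script presumably builds the actor, rollout, and reward components from the composed config rather than printing it):

import hydra
from omegaconf import DictConfig, OmegaConf

# Sketch only: config_path/config_name are left unset here because run.sh
# supplies them on the command line via --config-path/--config-name.
@hydra.main(version_base=None, config_path=None, config_name=None)
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()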
diff --git a/tests/unit_tests/test_auto_placement.py b/tests/unit_tests/test_auto_placement.py
index a70b01c9d..559229763 100644
--- a/tests/unit_tests/test_auto_placement.py
+++ b/tests/unit_tests/test_auto_placement.py
@@ -598,7 +598,7 @@ def test_scheduler_task_initialization(self, mock_validate):
"""Test SchedulerTask initialization."""
# Create a mock config
mock_cfg = MagicMock()
- mock_cfg.runner.task_type = "math"
+ mock_cfg.runner.task_type = "reasoning"
mock_cfg.actor.model.tensor_model_parallel_size = 2
mock_cfg.actor.model.pipeline_model_parallel_size = 1
mock_cfg.rollout.tensor_parallel_size = 1
@@ -620,7 +620,7 @@ def test_scheduler_task_initialization(self, mock_validate):
scheduler_task = SchedulerTask(mock_cfg, mock_cluster)
- assert scheduler_task.is_math is True
+ assert scheduler_task.is_reasoning is True
assert scheduler_task.total_gpus == 8
assert scheduler_task.group_size == 4
assert "actor" in scheduler_task.components_config
diff --git a/tests/unit_tests/test_placement.py b/tests/unit_tests/test_placement.py
index c16ff7fb9..22c75ab68 100644
--- a/tests/unit_tests/test_placement.py
+++ b/tests/unit_tests/test_placement.py
@@ -1087,34 +1087,6 @@ def test_model_parallel_component_placement_init_missing_rollout_gpus(self):
cluster = mock_cluster(num_nodes=1, num_accelerators_per_node=4)
ModelParallelComponentPlacement(config, cluster)
- def test_model_parallel_component_placement_init_collocated_mode_invalid_tp_sizes(
- self,
- ):
- """Test ModelParallelComponentPlacement raises error when actor TP size < rollout TP size in collocated mode."""
- config = DictConfig(
- {
- "cluster": {
- "num_nodes": 1,
- "component_placement": {"actor,rollout": "0-3"},
- },
- "actor": {
- "model": {
- "tensor_model_parallel_size": 2,
- "context_parallel_size": 1,
- "pipeline_model_parallel_size": 1,
- }
- },
- "rollout": {"tensor_parallel_size": 4, "pipeline_parallel_size": 1},
- }
- )
-
- with pytest.raises(
- AssertionError,
- match="Actor TP size 2 must be greater or equal to Rollout TP size 4",
- ):
- cluster = mock_cluster(num_nodes=1, num_accelerators_per_node=4)
- ModelParallelComponentPlacement(config, cluster)
-
def test_model_parallel_component_placement_init_collocated_mode_with_inference_gpus(
self,
):
diff --git a/toolkits/__init__.py b/toolkits/__init__.py
index 8b6f0114a..5b365ea1e 100644
--- a/toolkits/__init__.py
+++ b/toolkits/__init__.py
@@ -11,22 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
-from rlinf.algorithms.registry import get_reward_fn
-
-
-def register_rewards():
- try:
- from toolkits.code_verifier.verify import fim_verify_call
-
- assert get_reward_fn("fim_verify_call") == fim_verify_call
- except ImportError:
- pass
-
- try:
- from toolkits.math_verifier.verify import math_verify_call
-
- assert get_reward_fn("math") == math_verify_call
- except ImportError:
- pass
diff --git a/toolkits/auto_placement/scheduler_task.py b/toolkits/auto_placement/scheduler_task.py
index 00d8ea8aa..b3be46012 100644
--- a/toolkits/auto_placement/scheduler_task.py
+++ b/toolkits/auto_placement/scheduler_task.py
@@ -31,8 +31,10 @@ def __init__(
workflow_graph: Optional[Dict[ComponentNode, List[ComponentNode]]] = None,
):
self.cfg = cfg
- self.is_math = cfg.runner.task_type == "math"
- assert self.is_math, "Only math task is supported"
+ self.is_reasoning = cfg.runner.task_type == "reasoning"
+ assert self.is_reasoning, (
+ f"Only reasoning task is supported, current task type: {cfg.runner.task_type}"
+ )
self.components_config = {
"actor": {
@@ -71,7 +73,7 @@ def __init__(
self.global_step_batch_size = self.rollout_batch_size * self.group_size
if workflow_graph is None:
- if self.is_math:
+ if self.is_reasoning:
actor = ComponentNode("actor")
inference = ComponentNode("inference")
rollout = ComponentNode("rollout")
@@ -179,7 +181,7 @@ def parse_partition_allocation_to_cfg(
def time_division_multiplexing(self) -> List[Dict[str, Workflow]]:
partitions: List[Dict[str, Workflow]] = get_workflow_partition(self.workflow)
- if self.is_math:
+ if self.is_reasoning:
valid_partitions = [
i for i in partitions if len(i) in [1, len(self.components_config)]
]
diff --git a/toolkits/code_verifier/verify.py b/toolkits/code_verifier/verify.py
index 0ad756d02..5017b9e54 100644
--- a/toolkits/code_verifier/verify.py
+++ b/toolkits/code_verifier/verify.py
@@ -21,10 +21,8 @@
except ImportError:
fuzz = None
FUZZY_AVAILABLE = False
-from rlinf.algorithms.registry import register_reward_fn
-@register_reward_fn("fim_verify_call")
def fim_verify_call(
responses: List[str],
references: List[str],
diff --git a/toolkits/math_verifier/verify.py b/toolkits/math_verifier/verify.py
index 80bf4b552..8d0cbdb11 100644
--- a/toolkits/math_verifier/verify.py
+++ b/toolkits/math_verifier/verify.py
@@ -14,7 +14,13 @@
import multiprocessing
import re
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import (
+ ProcessPoolExecutor,
+ as_completed,
+)
+from concurrent.futures import (
+ TimeoutError as FuturesTimeoutError,
+)
from typing import List, Union
import regex
@@ -23,7 +29,6 @@
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
-from rlinf.algorithms.registry import register_reward_fn
from toolkits.math_verifier.parser import extract_answer
global_executor = ProcessPoolExecutor(max_workers=40)
@@ -348,22 +353,22 @@ def process_results(answer, solution):
extracted_solution = extract_answer(solution, "math", use_last_number=True)
if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
- retval = 0
+ retval = -1
elif extracted_solution is None or extracted_solution.strip() in [
"None",
"none",
"",
]:
- retval = 0
+ retval = -1
elif math_equal(extracted_answer, extracted_solution, timeout=False):
# elif call_with_timeout(math_equal, extracted_answer, extracted_solution):
retval = 1
else:
- retval = 0
+ retval = -1
return retval, (extracted_answer, extracted_solution)
except Exception:
- return 0, ("None", "None")
+ return -1, ("None", "None")
def process_results_process(a, b, output_queue):
@@ -383,7 +388,6 @@ def verify_math_solution(answer: str, solution: str):
return process_results(answer, solution)[0]
-@register_reward_fn("math")
def math_verify_call(
responses: List[str],
references: List[str],
@@ -403,7 +407,7 @@ def math_verify_call(
jobs.append(job)
all_jobs.append(jobs)
- labels = []
+ labels: List[int] = []
has_timeout = False
for jobs in all_jobs:
label = 0
@@ -411,40 +415,18 @@ def math_verify_call(
for job in as_completed(jobs, timeout=timeout):
x = job.result()
label = label or x
- except TimeoutError:
+ except FuturesTimeoutError:
has_timeout = True
- labels.append(label)
+ for job in jobs:
+ job.cancel()
+ finally:
+ labels.append(label)
if has_timeout:
reset_global_process_pool()
return labels
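The reworked timeout handling above imports concurrent.futures.TimeoutError under an explicit alias so the except clause matches the exception as_completed raises on all Python versions, cancels the not-yet-started jobs of a timed-out group, and keeps the per-group append in a finally block. A self-contained sketch of the same pattern, with a hypothetical slow_check job standing in for the verifier:

from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import TimeoutError as FuturesTimeoutError


def slow_check(x: int) -> int:
    # Stand-in for one verification job.
    return int(x % 2 == 0)


def collect_group_labels(groups, timeout=1.0):
    # One label per group: 1 if any job in the group returns truthy,
    # 0 otherwise; timed-out groups keep the default and cancel pending jobs.
    labels = []
    has_timeout = False
    with ProcessPoolExecutor(max_workers=4) as pool:
        all_jobs = [[pool.submit(slow_check, x) for x in group] for group in groups]
        for jobs in all_jobs:
            label = 0
            try:
                for job in as_completed(jobs, timeout=timeout):
                    label = label or job.result()
            except FuturesTimeoutError:
                has_timeout = True
                for job in jobs:
                    job.cancel()
            finally:
                labels.append(label)
    return labels, has_timeout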
-class MathRewardModel:
- def __init__(self, scale: float):
- self.scale = scale
-
- def get_reward(
- self, response: List[str], reference: List[List[str]]
- ) -> List[float]:
- """
- Calculates reward scores for a list of responses compared to corresponding lists of reference answers.
- For each response, the function checks if it matches any of the provided references using the `process_results` function.
- The reward for each response is computed as the first element of the result (converted to float) multiplied by `self.scale`.
- Args:
- response (List[str]): A list of response strings to be evaluated.
- reference (List[List[str]]): A list where each element is a list of reference strings corresponding to each response.
- Returns:
- List[float]: A list of reward scores, one for each response.
- """
-
- results = []
- for resp, refs in zip(response, reference):
- result = any(process_results(resp, ref)[0] for ref in refs)
- results.append((1 if result else -1) * self.scale)
- return results
-
-
if __name__ == "__main__":
sample = {
"answers": ["\\boxed{-\\frac{2}{3}}"],