diff --git a/.github/workflows/code-test.yml b/.github/workflows/code-test.yml index c9c34a6c6..0ec8ca3ef 100644 --- a/.github/workflows/code-test.yml +++ b/.github/workflows/code-test.yml @@ -113,33 +113,48 @@ jobs: - name: Checkout code uses: actions/checkout@v5 - - name: SGLang Collocated mode + - name: Megatron SGLang Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-sgl + + - name: Megatron vLLM Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-vllm + + - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/sglang/run_collocated.sh + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl - - name: vLLM Collocated mode + - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/vllm/run_collocated.sh + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm - - name: SGLang Pipeline mode + - name: FSDP SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/sglang/run_pipeline.sh + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl - - name: vLLM Pipeline mode + - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/vllm/run_pipeline.sh + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm + reason-qwen-grpo-test-rollout-logprobs: needs: [check-changes] @@ -149,33 +164,47 @@ jobs: - name: Checkout code uses: actions/checkout@v5 - - name: SGLang Collocated mode + - name: Megatron SGLang Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs + + - name: Megatron vLLM Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs + + - name: Megatron SGLang Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/sglang/run_collocated.sh qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs - - name: vLLM Collocated mode + - name: Megatron vLLM Pipeline mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/vllm/run_collocated.sh qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs - - name: SGLang Pipeline mode + - name: FSDP SGLang Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/math/sglang/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs - - name: vLLM Pipeline mode + - name: FSDP vLLM Collocated mode timeout-minutes: 20 run: | export REPO_PATH=$(pwd) source switch_env reason - bash 
tests/e2e_tests/math/vllm/run_pipeline.sh qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml + bash tests/e2e_tests/reasoning/run.sh qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs coding-online-rl-qwen-ppo-test: needs: [check-changes] @@ -194,7 +223,29 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh + bash tests/e2e_tests/coding_online_rl/run.sh + + qwen-vl-grpo-test: + needs: [check-changes] + if: needs.check-changes.outputs.file_filter == 'true' + runs-on: reason + steps: + - name: Checkout code + uses: actions/checkout@v5 + + - name: FSDP SGLang Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run.sh qwen2.5-vl-3b-grpo-collocated-fsdp-sgl + + - name: FSDP vLLM Collocated mode + timeout-minutes: 20 + run: | + export REPO_PATH=$(pwd) + source switch_env reason + bash tests/e2e_tests/reasoning/run.sh qwen2.5-vl-3b-grpo-collocated-fsdp-vllm # =============================================== embodied e2e tests ==================================================== @@ -270,7 +321,7 @@ jobs: run: | export REPO_PATH=$(pwd) source switch_env reason - bash tests/e2e_tests/auto_placement/run_auto_placement.sh + bash tests/e2e_tests/auto_placement/run.sh # =============================================== finale ==================================================== @@ -283,7 +334,7 @@ jobs: # Reason e2e tests reason-qwen-grpo-test, reason-qwen-grpo-test-rollout-logprobs, - coding-online-rl-qwen-ppo-test, + coding-online-rl-qwen-ppo-test, qwen-vl-grpo-test, # Embodied e2e tests embodied-maniskill-ppo-openvla-test, embodied-maniskill-grpo-openvlaoft-test, embodied-libero-goal-grpo-openvlaoft-test,embodied-libero-130-grpo-openvlaoft-test, diff --git a/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml index 1d9720fdd..525b1c951 100644 --- a/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_10_grpo_openvlaoft.yaml @@ -157,6 +157,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml b/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml index 628272ed1..69708e7c0 100644 --- a/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml +++ b/examples/embodiment/config/libero_10_grpo_openvlaoft_eval.yaml @@ -158,6 +158,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml b/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml index bf7e31667..15b33bbbc 100644 --- a/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_10_ppo_openvlaoft.yaml @@ -152,6 +152,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml index d6356c314..699d7ab40 100644 --- 
a/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_goal_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml index b767fd25a..d2bacd05e 100644 --- a/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_object_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml b/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml index 69aec20eb..9469166d9 100644 --- a/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/libero_spatial_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_grpo_openvla.yaml b/examples/embodiment/config/maniskill_grpo_openvla.yaml index 16dc2af06..3679d533d 100644 --- a/examples/embodiment/config/maniskill_grpo_openvla.yaml +++ b/examples/embodiment/config/maniskill_grpo_openvla.yaml @@ -158,6 +158,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 1.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml b/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml index def45aafb..7bd32855b 100644 --- a/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml +++ b/examples/embodiment/config/maniskill_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 10.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_ppo_openvla.yaml b/examples/embodiment/config/maniskill_ppo_openvla.yaml index 6aeae632d..bf9cab8eb 100644 --- a/examples/embodiment/config/maniskill_ppo_openvla.yaml +++ b/examples/embodiment/config/maniskill_ppo_openvla.yaml @@ -154,6 +154,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml b/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml index 969dc85cb..b81d1390f 100644 --- a/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml +++ b/examples/embodiment/config/maniskill_ppo_openvla_quickstart.yaml @@ -170,6 +170,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 1.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml 
b/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml index 1e5893f43..f957f1997 100644 --- a/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml +++ b/examples/embodiment/config/maniskill_ppo_openvlaoft.yaml @@ -159,6 +159,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml b/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml index 30700e1c2..56d936659 100644 --- a/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml +++ b/examples/embodiment/config/robotwin_ppo_openvlaoft.yaml @@ -158,6 +158,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml new file mode 100644 index 000000000..3b31eecdc --- /dev/null +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-fsdp.yaml @@ -0,0 +1,225 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 28672 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: ../results +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: True + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/ + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. 
During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: ascend # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: math + dataset_name: boba + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/home/dataset/boba/AReaL-boba-106k.jsonl"] + val_data_paths: ["/home/dataset/boba/AReaL-boba-106k.jsonl"] + prompt_key: prompt + image_keys: [image] + answer_key: answer + choice_key: choices + solution_key: null + use_chat_template: True + lazy_loading: True + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/ + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /home/weight/DeepSeek-R1-Distill-Qwen-1.5B-2layer/ + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + 
use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false diff --git a/examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml similarity index 95% rename from examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml index 9ba641d15..ea4606003 100644 --- a/examples/math/config/qwen2.5-1.5b-grpo-megatron-pipeline.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron-pipeline.yaml @@ -12,9 +12,10 @@ cluster: rollout: 0-15 inference: 16-23 actor: 24-63 + reward: 0-15 runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -65,6 +66,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: True clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: False @@ -134,6 +137,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 512 @@ -146,6 +150,7 @@ data: train_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] val_data_paths: ["/dataset/boba/AReaL-boba-106k.jsonl"] + actor: group_name: "ActorGroup" training_backend: megatron @@ -271,9 +276,15 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} critic: use_critic_model: false diff --git a/examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml similarity index 94% rename from examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml index 6f75dc38e..4d851f894 100644 --- a/examples/math/config/qwen2.5-1.5b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-grpo-megatron.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 16 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True @@ -121,6 +123,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 512 @@ -258,9 +261,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + 
trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/math/config/qwen2.5-1.5b-single-gpu.yaml b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml similarity index 94% rename from examples/math/config/qwen2.5-1.5b-single-gpu.yaml rename to examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml index e3f5bf28d..1829050c1 100644 --- a/examples/math/config/qwen2.5-1.5b-single-gpu.yaml +++ b/examples/reasoning/config/math/qwen2.5-1.5b-single-gpu.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: 0 + actor,rollout,reward: 0 runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True @@ -258,9 +260,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/math/config/qwen2.5-32b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml similarity index 95% rename from examples/math/config/qwen2.5-32b-grpo-megatron.yaml rename to examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml index 6e397dfda..e9eb2089e 100644 --- a/examples/math/config/qwen2.5-32b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-32b-grpo-megatron.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 32 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True @@ -259,5 +261,11 @@ reward: reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/math/config/qwen2.5-7b-grpo-megatron.yaml b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml similarity index 94% rename from examples/math/config/qwen2.5-7b-grpo-megatron.yaml rename to examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml index 63146687e..b2a70d6f8 100644 --- a/examples/math/config/qwen2.5-7b-grpo-megatron.yaml +++ b/examples/reasoning/config/math/qwen2.5-7b-grpo-megatron.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 16 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ 
-61,6 +61,8 @@ algorithm: entropy_bonus: 0.0 calculate_entropy: False clip_ratio_c: null # 3.0 + clip_ratio_low: null # if null or not set, will use ratio_clip_eps + clip_ratio_high: null # if null or not set, will use ratio_clip_eps adv_type: math_grpo normalize_advantages: True @@ -257,9 +259,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/examples/math/config/tp_comm_overlap_cfg.yaml b/examples/reasoning/config/tp_comm_overlap_cfg.yaml similarity index 100% rename from examples/math/config/tp_comm_overlap_cfg.yaml rename to examples/reasoning/config/tp_comm_overlap_cfg.yaml diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml new file mode 100644 index 000000000..17c192fa2 --- /dev/null +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp.yaml @@ -0,0 +1,228 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: ../results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. 
During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: ascend # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + val_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + use_fast: 
False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +critic: + use_critic_model: false diff --git a/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml new file mode 100644 index 000000000..fa5b16080 --- /dev/null +++ b/examples/reasoning/config/vqa/qwen2.5-vl-3b-grpo-fsdp_baak.yaml @@ -0,0 +1,222 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 5 + max_steps: -1 + + val_check_interval: 1 + save_interval: 50 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: ../results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. taoxu 1010 + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. 
+ + rollout_backend: sglang # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: ascend # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 8 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + val_data_paths: ["/home/x00922209/datasets/advaitgupta/robo2VLM/data/"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + 
tokenizer_model: /home/x00922209/models/Qwen/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + +critic: + use_critic_model: false \ No newline at end of file diff --git a/examples/math/main_math.py b/examples/reasoning/main_grpo.py similarity index 81% rename from examples/math/main_math.py rename to examples/reasoning/main_grpo.py index 7150566fb..30073d562 100644 --- a/examples/math/main_math.py +++ b/examples/reasoning/main_grpo.py @@ -21,12 +21,13 @@ from rlinf.config import validate_cfg from rlinf.data.datasets import create_rl_dataset from rlinf.data.tokenizers import hf_tokenizer -from rlinf.runners.math_runner import MathRunner +from rlinf.runners.reasoning_runner import ReasoningRunner from rlinf.scheduler import Cluster from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode from rlinf.utils.utils import output_redirector -from rlinf.workers.actor.megatron_actor_worker import MegatronActor +from rlinf.workers.actor import get_actor_worker from rlinf.workers.inference.megatron_inference_worker import MegatronInference +from rlinf.workers.reward.reward_worker import RewardWorker from rlinf.workers.rollout.utils import get_rollout_backend_worker """Script to start GRPO training""" @@ -67,16 +68,25 @@ def main(cfg) -> None: placement_strategy=inference_placement_strategy, ) + # Reward group + reward_placement_strategy = component_placement.get_strategy("reward") + reward_group = RewardWorker.create_group(cfg, component_placement).launch( + cluster, + name=cfg.reward.group_name, + placement_strategy=reward_placement_strategy, + ) + # GRPO Actor group + actor_worker_cls = get_actor_worker(cfg) actor_placement_strategy = component_placement.get_strategy("actor") - actor_group = MegatronActor.create_group(cfg, component_placement).launch( + actor_group = actor_worker_cls.create_group(cfg, component_placement).launch( cluster, name=cfg.actor.group_name, placement_strategy=actor_placement_strategy ) tokenizer = hf_tokenizer(cfg.actor.tokenizer.tokenizer_model) - train_ds, val_ds = create_rl_dataset(cfg.data, tokenizer) + train_ds, val_ds = create_rl_dataset(cfg, tokenizer) - runner = MathRunner( + runner = ReasoningRunner( cfg=cfg, placement=component_placement, train_dataset=train_ds, @@ -84,6 +94,7 @@ def main(cfg) -> None: rollout=rollout_group, inference=inference_group, actor=actor_group, + reward=reward_group, ) runner.init_workers() diff --git a/examples/math/run_main_math_grpo_megatron.sh b/examples/reasoning/run_main_grpo_math.sh similarity index 70% rename from examples/math/run_main_math_grpo_megatron.sh rename to examples/reasoning/run_main_grpo_math.sh index f826f882f..1b123fd6f 100644 --- a/examples/math/run_main_math_grpo_megatron.sh +++ b/examples/reasoning/run_main_grpo_math.sh @@ -2,7 +2,6 @@ set -x tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 export TOKENIZERS_PARALLELISM=false export RAY_DEDUP_LOGS=0 @@ -15,7 +14,7 @@ export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH if [ -z "$1" ]; then CONFIG_NAME="qwen2.5-1.5b-grpo-megatron" else - CONFIG_NAME=$1 + CONFIG_NAME="qwen2.5-1.5b-grpo-fsdp.yaml" fi -python ${REPO_PATH}/examples/math/main_math.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/math/ --config-name $CONFIG_NAME diff --git a/examples/math/run_main_math_pipeline_grpo_megatron.sh 
b/examples/reasoning/run_main_grpo_vqa.sh old mode 100644 new mode 100755 similarity index 64% rename from examples/math/run_main_math_pipeline_grpo_megatron.sh rename to examples/reasoning/run_main_grpo_vqa.sh index 7deb96519..3cc526f0e --- a/examples/math/run_main_math_pipeline_grpo_megatron.sh +++ b/examples/reasoning/run_main_grpo_vqa.sh @@ -2,7 +2,6 @@ set -x tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS export CUDA_DEVICE_MAX_CONNECTIONS=1 export TOKENIZERS_PARALLELISM=false export RAY_DEDUP_LOGS=0 @@ -13,9 +12,9 @@ MEGATRON_PATH=/opt/Megatron-LM export PYTHONPATH=${REPO_PATH}:${MEGATRON_PATH}:$PYTHONPATH if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-megatron-pipeline" + CONFIG_NAME="qwen2.5-vl-3b-grpo-fsdp" else CONFIG_NAME=$1 fi -python ${REPO_PATH}/examples/math/main_math.py --config-path ${CONFIG_PATH}/config/ --config-name $CONFIG_NAME \ No newline at end of file +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path ${CONFIG_PATH}/config/vqa/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/examples/math/run_placement_autotune.sh b/examples/reasoning/run_placement_autotune.sh similarity index 100% rename from examples/math/run_placement_autotune.sh rename to examples/reasoning/run_placement_autotune.sh diff --git a/fusion_result.json b/fusion_result.json new file mode 100644 index 000000000..20b56950d --- /dev/null +++ b/fusion_result.json @@ -0,0 +1,21 @@ +{ + "session_and_graph_id_0_0": { + "graph_fusion": { + "IndexByTensorStaticFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "RefreshInt64ToInt32FusionPass": { + "effect_times": "0", + "match_times": "1" + } + }, + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "0", + "match_times": "2", + "repository_hit_times": "0" + } + } + } +} \ No newline at end of file diff --git a/kernel_meta/buildPidInfo.json b/kernel_meta/buildPidInfo.json new file mode 100644 index 000000000..9900ce4f6 --- /dev/null +++ b/kernel_meta/buildPidInfo.json @@ -0,0 +1,2754 @@ +[ + [ + 39911, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12797813670961250527" + ], + [ + 39913, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2319069557219941997" + ], + [ + 39917, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11862212240406296461" + ], + [ + 39922, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_885784172161213002" + ], + [ + 39926, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7410510130683415152" + ], + [ + 39930, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9194155428752533289" + ], + [ + 39934, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10424833431148169855" + ], + [ + 39936, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9855745055097868770" + ], + [ + 39939, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18191654686185573965" + ], + [ + 39942, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1065590051412680278" + ], + [ + 39949, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16967210759787679084" + ], + [ + 39950, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8381815239432563187" + ], + [ + 39958, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2867384563160197328" + ], + [ + 39965, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4547732597630948415" + ], + [ + 39971, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15778275423241821855" + ], + [ + 39980, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5207596602532600872" + ], + [ + 40666, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15835674102581580671" + ], + [ + 82796, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8782270786933035597" + ], + [ + 100620, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7353802482069301333" + ], + [ + 112704, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7725153434515021512" + ], + [ + 122239, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7210453806497323497" + ], + [ + 146045, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3405930793784815214" + ], + [ + 146076, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12695548580890247073" + ], + [ + 146140, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1605098968774552756" + ], + [ + 146205, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5766664014834865721" + ], + [ + 146347, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9612635458804393825" + ], + [ + 146447, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7725679307351709877" + ], + [ + 146559, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14784440107212476086" + ], + [ + 146567, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17267848541246919924" + ], + [ + 146569, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17784643373181208859" + ], + [ + 146573, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10067831816227681826" + ], + [ + 146589, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1044498738075030000" + ], + [ + 158951, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15191603587064731993" + ], + [ + 159046, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6443795088650938062" + ], + [ + 159116, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7617045580398011601" + ], + [ + 159145, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17761038868925529398" + ], + [ + 159234, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8618471054285342305" + ], + [ + 159360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16579794217881527508" + ], + [ + 159440, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10254394513325194666" + ], + [ + 159460, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11087666240252757180" + ], + [ + 159513, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9428322862212375668" + ], + [ + 159626, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16449556766463462195" + ], + [ + 159768, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1071949320798082478" + ], + [ + 159861, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9565697075048113424" + ], + [ + 159886, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3062431087256661843" + ], + [ + 160022, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13353626036029525043" + ], + [ + 161093, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_508908986595913110" + 
], + [ + 162181, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3563660611022476863" + ], + [ + 259575, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8201276825499370867" + ], + [ + 259577, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15529659103780327911" + ], + [ + 259581, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10174461859794700718" + ], + [ + 259585, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16507297788703908047" + ], + [ + 259587, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11966777041657863002" + ], + [ + 259592, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2726330588881717531" + ], + [ + 259596, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17030647279409832227" + ], + [ + 259598, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12186607257866309998" + ], + [ + 259601, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3768772217671096392" + ], + [ + 259607, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3911160194709892748" + ], + [ + 259612, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7977521137351637493" + ], + [ + 259619, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4128563410056480856" + ], + [ + 259626, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9581134290190490336" + ], + [ + 259628, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_702154103913980934" + ], + [ + 259636, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14027257268528873801" + ], + [ + 259643, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17044983378904657187" + ], + [ + 260342, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3738242728542939509" + ], + [ + 260402, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15806158438452141266" + ], + [ + 260467, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4567683529506793175" + ], + [ + 260537, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10271642930345239273" + ], + [ + 260629, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3945686445184445770" + ], + [ + 260690, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17057962697473254798" + ], + [ + 260755, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16871514587483264100" + ], + [ + 260763, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2675279169353877702" + ], + [ + 260795, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14119700360366649411" + ], + [ + 315952, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11641596342314442912" + ], + [ + 329434, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17017774772109951435" + ], + [ + 340293, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_584578301371656169" + ], + [ + 352185, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11288852585934677911" + ], + [ + 367542, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6263883255072188097" + ], + [ + 367633, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5888503093289751929" + ], + [ + 367666, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5420145819910730026" + ], + [ + 379262, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1783494569175948891" + ], + [ + 379266, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_537678043780734220" + ], + [ + 379285, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5877535174558663131" + ], + [ + 379286, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10249842372811971323" + ], + [ + 379383, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_656090948065222947" + ], + [ + 379674, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11244763063226368392" + ], + [ + 380304, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_231890364717191513" + ], + [ + 380313, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5692311019013524597" + ], + [ + 380315, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11068432615228024246" + ], + [ + 380440, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15012465798244511954" + ], + [ + 380532, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3776119111621864505" + ], + [ + 380533, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6221575824346386447" + ], + [ + 380834, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11465674300335516174" + ], + [ + 381115, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7186726943654900649" + ], + [ + 381318, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12507534548909153748" + ], + [ + 381749, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15611949621196386618" + ], + [ + 446067, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15585473429645162327" + ], + [ + 446070, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2905662668684958995" + ], + [ + 446077, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9261524943542376262" + ], + [ + 446081, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11733848515447041665" + ], + [ + 446083, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13074589690020057687" + ], + [ + 446087, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6740951685981297380" + ], + [ + 446089, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1230291348499102449" + ], + [ + 446095, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8512469513541935614" + ], + [ + 446098, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6734225289254301183" + ], + [ + 446100, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3193003074983928221" + ], + [ + 446107, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17730893955086431941" + ], + [ + 446108, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14296202533891893120" + ], + [ + 446115, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17201341106939116953" + ], + [ + 446122, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17335486450140669799" + ], + [ + 446128, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_172524440323681068" + ], + [ + 446139, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8488055523413543743" + 
], + [ + 446862, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11022287578591461948" + ], + [ + 475137, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5524899443716147254" + ], + [ + 484622, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9030794170733368903" + ], + [ + 494550, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10177095432241817550" + ], + [ + 508239, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10453435395226908113" + ], + [ + 519345, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15868233357086722582" + ], + [ + 552276, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15438796748740496551" + ], + [ + 552339, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6239891206884327744" + ], + [ + 552379, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9273241722792499069" + ], + [ + 552468, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16552715273014474462" + ], + [ + 552569, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3196337734268543926" + ], + [ + 552675, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5847337843613087850" + ], + [ + 552742, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6352307575227886689" + ], + [ + 552750, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1982943654522997836" + ], + [ + 552783, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3828732781290684800" + ], + [ + 552794, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16977504117571897465" + ], + [ + 565573, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5204528607269617447" + ], + [ + 565699, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4727225444131891245" + ], + [ + 565716, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3170838120379853841" + ], + [ + 565750, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2776707022079058597" + ], + [ + 566045, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12036042907429605982" + ], + [ + 566310, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12421069780630334687" + ], + [ + 566383, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3217743380181589526" + ], + [ + 566422, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1276333807841191048" + ], + [ + 566494, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7963328572114012833" + ], + [ + 566534, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6674089693496856100" + ], + [ + 566844, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9863761807279720849" + ], + [ + 567668, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16074272625414500687" + ], + [ + 567773, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4282295767343065807" + ], + [ + 567937, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2161356974128856154" + ], + [ + 567993, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8242040281073172353" + ], + [ + 568229, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15920297452776001566" + ], + [ + 637429, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_959368100671980870" + ], + [ + 637431, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14153015096824152693" + ], + [ + 637436, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_899818523044574745" + ], + [ + 637439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10247733762527698150" + ], + [ + 637444, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14145287080845176151" + ], + [ + 637447, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17222031087440245104" + ], + [ + 637453, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12136397179187398942" + ], + [ + 637456, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8731812876370325350" + ], + [ + 637459, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10602237885762750118" + ], + [ + 637462, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16594426567840719304" + ], + [ + 637468, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7366125432199894168" + ], + [ + 637476, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11110698096089951252" + ], + [ + 637484, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18188682871993650937" + ], + [ + 637489, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10878533416589658592" + ], + [ + 637495, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8079586613243454881" + ], + [ + 637502, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2215817028210768871" + ], + [ + 670212, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15344779562198194740" + ], + [ + 680392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18388393187931224557" + ], + [ + 690762, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3514287218513747531" + ], + [ + 743351, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6478191836071807723" + ], + [ + 743388, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12832048017481839151" + ], + [ + 743400, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10541782531111842730" + ], + [ + 743475, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10763222511672298619" + ], + [ + 743664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17197230416625906030" + ], + [ + 743787, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18047780774813550571" + ], + [ + 743849, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2446754395186006741" + ], + [ + 743875, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5679073243145878981" + ], + [ + 743879, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16603745088718856023" + ], + [ + 743887, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13969638480830774322" + ], + [ + 743895, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13369228484466567544" + ], + [ + 743998, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12762855693433035364" + ], + [ + 744081, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3347649343482407740" + ], + [ + 757025, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15796729346284784432" + ], + [ + 757028, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14967701666120261818" + ], + [ + 757039, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5372387541384308030" + ], + [ + 757071, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3245068616512596448" + ], + [ + 757197, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14277746927271838195" + ], + [ + 757272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1831510647579988021" + ], + [ + 758082, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1963453102454522286" + ], + [ + 803617, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15623709472216942560" + ], + [ + 803620, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5962434898661482018" + ], + [ + 803625, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14549239843071968661" + ], + [ + 803630, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8567544072600264650" + ], + [ + 803633, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18407024629481036527" + ], + [ + 803664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15960439059008180895" + ], + [ + 803666, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3263727124264509013" + ], + [ + 803667, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9451034478073827629" + ], + [ + 803669, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18226372968931963744" + ], + [ + 803670, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14202429746807222158" + ], + [ + 803674, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7069975997530525703" + ], + [ + 803675, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8834989374251547639" + ], + [ + 803679, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17156716697792514660" + ], + [ + 803681, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7928243516609743611" + ], + [ + 803683, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17472278663547573068" + ], + [ + 803688, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9297387177610951812" + ], + [ + 803858, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9931525164958826644" + ], + [ + 803962, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18009242967987784208" + ], + [ + 804016, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2874613183183755843" + ], + [ + 804136, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4244083726001215679" + ], + [ + 804179, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9469466012086032567" + ], + [ + 804247, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6732244663862958131" + ], + [ + 804321, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17195272942189944812" + ], + [ + 804392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10755163416630026042" + ], + [ + 804513, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12350748768321932278" + ], + [ + 804688, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6894846644228586388" + ], + [ + 873633, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15791320709907010305" + ], + [ + 888932, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17017344803712162999" + ], + [ + 898268, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12878681295278483117" + ], + [ + 904854, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13733947334674992446" + ], + [ + 910408, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3037428981486910277" + ], + [ + 910622, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12284873316976139363" + ], + [ + 922414, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4043581537781464345" + ], + [ + 922752, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14805333628576514738" + ], + [ + 922753, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5878253078724431956" + ], + [ + 922763, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15057904892991266440" + ], + [ + 922779, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18116003072026181513" + ], + [ + 922864, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3039955536516904891" + ], + [ + 922878, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8073047691309365525" + ], + [ + 923385, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3835924806113435602" + ], + [ + 923770, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9436136851994508472" + ], + [ + 923881, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12348575658552678709" + ], + [ + 924169, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3100439312273551508" + ], + [ + 924681, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3582699337697878406" + ], + [ + 925395, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17118759297342529823" + ], + [ + 925409, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18056433885446449459" + ], + [ + 925631, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13111124903733223810" + ], + [ + 925674, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2449886720924222030" + ], + [ + 974754, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_206876370358077584" + ], + [ + 974757, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11257961443249735269" + ], + [ + 998802, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4978121236847435213" + ], + [ + 1003838, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17874922247557918464" + ], + [ + 1008420, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14756252425461487877" + ], + [ + 1051386, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15672847212087630880" + ], + [ + 1056292, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11690764893981564089" + ], + [ + 1077929, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18243756878085111260" + ], + [ + 1078664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13101390950818077021" + ], + [ + 1092149, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4702484821586339293" + ], + [ + 1092150, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8700607087392051310" + ], + [ + 1092167, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11867004705859069539" + ], + [ + 1092168, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14425483843466519342" + ], + [ + 1092170, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3137345424683392113" + ], + [ + 1092171, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5580715293573031613" + ], + [ + 1092247, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9573985685283690942" + ], + [ + 1092248, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1117326061986011595" + ], + [ + 1092430, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12852273377442787843" + ], + [ + 1092431, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2948619980269787118" + ], + [ + 1092814, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16608503441526049509" + ], + [ + 1092815, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8557565976778605998" + ], + [ + 1092974, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1586571163572556133" + ], + [ + 1092975, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15935510270417225075" + ], + [ + 1093267, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4402275187066910043" + ], + [ + 1093268, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5557153902550791572" + ], + [ + 1196049, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8425342506929295889" + ], + [ + 1196053, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18378899707502877030" + ], + [ + 1196056, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3094735019997891187" + ], + [ + 1196060, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6889997414760764962" + ], + [ + 1196065, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16740248499631073491" + ], + [ + 1196068, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8147418516272875846" + ], + [ + 1196070, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7125085183662271974" + ], + [ + 1196076, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17768317808754919941" + ], + [ + 1196345, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12412202691532259202" + ], + [ + 1196434, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10867201177510314690" + ], + [ + 1196530, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17300580725555235990" + ], + [ + 1196591, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11922602998666607133" + ], + [ + 1196658, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17067792987140972942" + ], + [ + 1196718, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6156226292697541158" + ], + [ + 1196789, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11972615607024541236" + ], + [ + 1225343, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16644111413419115375" + ], + [ + 1248801, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18426022297842133396" + ], + [ + 1254600, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13015264027012895286" + ], + [ + 1302314, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2178400540252727344" + ], + [ + 1302417, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15389620440927117005" + ], + [ + 1302502, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13548285859569365661" + ], + [ + 1302542, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3104574529036971715" + ], + [ + 1302580, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9161340348181567069" + ], + [ + 1302678, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2235073576100882104" + ], + [ + 1313288, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10714691217339955812" + ], + [ + 1313289, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15436382972359969753" + ], + [ + 1313295, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10087935773393978978" + ], + [ + 1313297, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1956045676144969910" + ], + [ + 1313413, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10346769137148346031" + ], + [ + 1313414, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12500094804178281846" + ], + [ + 1313594, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2233046969904106531" + ], + [ + 1313602, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17274697924548824848" + ], + [ + 1313910, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1971305477396827550" + ], + [ + 1313911, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1123480945946248614" + ], + [ + 1313913, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13632501687006175122" + ], + [ + 1313914, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17948247176888184637" + ], + [ + 1314424, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4838020015312522463" + ], + [ + 1314425, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5662771786044476863" + ], + [ + 1314464, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11644308174162608419" + ], + [ + 1314465, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14904900394694390614" + ], + [ + 1364104, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12599208991475254956" + ], + [ + 1364109, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11714438223684649350" + ], + [ + 1364112, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12536573013377539960" + ], + [ + 1364118, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9176497740794462077" + ], + [ + 1364120, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12851108834171289286" + ], + [ + 1364124, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7874367383020722434" + ], + [ + 1364133, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11768575761097821347" + ], + [ + 1364136, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9136218849828640261" + ], + [ + 1364436, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15227420143784422207" + ], + [ + 1364530, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_678152374747580855" + ], + [ + 1364650, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1863822375592647081" + ], + [ + 1364684, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14158128526209592340" + ], + [ + 1364721, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4584616283426803771" + ], + [ + 1364806, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13235463644610255264" + ], + [ + 1364877, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7481851041269423666" + ], + [ + 1364909, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9126134503219493660" + ], + [ + 1365000, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12501320043627540492" + ], + [ + 1365092, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16659304774143478016" + ], + [ + 1398994, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7719667656758625427" + ], + [ + 1414384, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8539755814366516822" + ], + [ + 1425199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12779466743610801848" + ], + [ + 1434341, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5346047771258662063" + ], + [ + 1443931, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8500329968264923909" + ], + [ + 1471398, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2835195402119160546" + ], + [ + 1480308, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17921460277698942329" + ], + [ + 1480309, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6060307077504199420" + ], + [ + 1480544, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6633913915366293595" + ], + [ + 1480545, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5929325753921586179" + ], + [ + 1480762, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5087823519142655646" + ], + [ + 1480763, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17650771157019299908" + ], + [ + 1481439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17960349755202839121" + ], + [ + 1481440, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11344166716576316064" + ], + [ + 1481478, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18011110268539351239" + ], + [ + 1481479, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13212044995188295238" + ], + [ + 1481983, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7786193257091011544" + ], + [ + 1481984, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4535156793004568521" + ], + [ + 1482479, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7392738499711986039" + ], + [ + 1482480, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16803447451398881609" + ], + [ + 1482675, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5816531416034365646" + ], + [ + 1482676, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10586423650728537576" + ], + [ + 1531952, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15737151482683442731" + ], + [ + 1531958, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5428763241135066291" + ], + [ + 1531961, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10672990360669572384" + ], + [ + 1531965, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7671654195981105868" + ], + [ + 1531970, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14942561136263416232" + ], + [ + 1531973, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16581135063665980351" + ], + [ + 1531977, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4856394302363438782" + ], + [ + 1531981, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15877297897966392256" + ], + [ + 1532252, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6425863017087684784" + ], + [ + 1532297, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1423327815346902785" + ], + [ + 1532410, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16344047352049658364" + ], + [ + 1532505, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17305619576099580631" + ], + [ + 1532566, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8395635612082422722" + ], + [ + 1532630, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8519186747189401928" + ], + [ + 1532664, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1253792069041758400" + ], + [ + 1532734, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16904214805911888839" + ], + [ + 1532855, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8506737276851019991" + ], + [ + 1532973, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_649535223346514500" + ], + [ + 1533011, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5898366698551436494" + ], + [ + 1533074, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5096822643445805939" + ], + [ + 1533138, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14031496100727112023" + ], + [ + 1533230, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6266968138694899706" + ], + [ + 1533290, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16937440140735016382" + ], + [ + 1533334, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10380564371013497541" + ], + [ + 1646912, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16348313361094870629" + ], + [ + 1646913, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16036313845476488225" + ], + [ + 1646927, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7913415713848588049" + ], + [ + 1646928, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4103568604206357407" + ], + [ + 1647271, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3917280544582020227" + ], + [ + 1647272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13152661351180665451" + ], + [ + 1647797, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6759349546217838702" + ], + [ + 1647798, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10771757146841459066" + ], + [ + 1649758, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6186015994533271987" + ], + [ + 1649759, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9297921486035376125" + ], + [ + 1650053, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17376928305113794204" + ], + [ + 1650054, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14705463419645039338" + ], + [ + 1650074, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17507795361488404022" + ], + [ + 1650075, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_791693539870127207" + ], + [ + 1650077, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11601476096660261739" + ], + [ + 1650078, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12330363696036321630" + ], + [ + 1698356, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3664737074668484302" + ], + [ + 1698360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_307407573315947332" + ], + [ + 1698364, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9769783361472155454" + ], + [ + 1698365, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11620844252758829593" + ], + [ + 1698369, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18049344638709473559" + ], + [ + 1698371, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12020994052316156595" + ], + [ + 1698376, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11281246552810662244" + ], + [ + 1698381, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2375830332811294762" + ], + [ + 1698622, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17464125010525395920" + ], + [ + 1698715, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8252497263861944219" + ], + [ + 1698754, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2367183499606118175" + ], + [ + 1698875, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3307535582648818046" + ], + [ + 1698960, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14070843042311209568" + ], + [ + 1699023, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3711321641783824900" + ], + [ + 1699057, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4067246543661527215" + ], + [ + 1699128, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9510008041698682979" + ], + [ + 1699222, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15070956721185346780" + ], + [ + 1699307, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16822435983994287417" + ], + [ + 1699350, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9200131402658211969" + ], + [ + 1699410, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4312347285030704895" + ], + [ + 1699501, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8452932345320744230" + ], + [ + 1699592, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18319826375533952136" + ], + [ + 1745273, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9427035578749238306" + ], + [ + 1752419, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11861522220206622819" + ], + [ + 1815709, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8461056289216019937" + ], + [ + 1815710, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15670507367599225435" + ], + [ + 1815718, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14014685830231205955" + ], + [ + 1815719, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8933292223775536158" + ], + [ + 1815722, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4670593314428329079" + ], + [ + 1815723, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14753117404955500077" + ], + [ + 1815761, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14867940080031341584" + ], + [ + 1815765, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9603778874236517966" + ], + [ + 1815801, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12882566826319414462" + ], + [ + 1815802, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10278462134454322731" + ], + [ + 1815807, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_634460488610617389" + ], + [ + 1815809, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3787426685243459175" + ], + [ + 1816092, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6900943374411146311" + ], + [ + 1816098, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2936618988510319244" + ], + [ + 1816749, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5025953421479037474" + ], + [ + 1816757, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10504877859244548446" + ], + [ + 3034511, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3109980201143968959" + ], + [ + 3034516, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16225395238917630830" + ], + [ + 3034519, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3014914144290525407" + ], + [ + 3034520, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4368939354329067568" + ], + [ + 3034525, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8231836696933284847" + ], + [ + 3034528, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11274985324515364055" + ], + [ + 3034533, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16776878341103524463" + ], + [ + 3034534, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11836972797422572451" + ], + [ + 3034538, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11222685037651458786" + ], + [ + 3034543, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2358340237215830214" + ], + [ + 3034547, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5530069306192163345" + ], + [ + 3034553, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15467614126156452015" + ], + [ + 3034559, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5799901149805008020" + ], + [ + 3034568, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13480646072272034112" + ], + [ + 3034574, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9925042805192677700" + ], + [ + 3034584, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5749985854107212444" + ], + [ + 3139002, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16606806383596787697" + ], + [ + 3139034, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13693873830274303283" + ], + [ + 3139037, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1174647644274299681" + ], + [ + 3139041, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8463715729614322689" + ], + [ + 3139046, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13622137732458691231" + ], + [ + 3139061, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13976382821074310028" + ], + [ + 3139103, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5674315837922161200" + ], + [ + 3139186, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9519308949519055442" + ], + [ + 3139256, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11870662408869016824" + ], + [ + 3139325, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9203351225653898140" + ], + [ + 3139379, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15928929015260880554" + ], + [ + 3139390, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4002276556586837015" + ], + [ + 3139395, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_869027509639957243" + ], + [ + 3139439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7590980568514383017" + ], + [ + 3139477, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5691785906114378471" + ], + [ + 3139578, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1636167670197209792" + ], + [ + 3153344, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14721186196409181902" + ], + [ + 3153350, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8038130317126733515" + ], + [ + 3153355, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9363551798712942410" + ], + [ + 3153360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4001983642343932187" + ], + [ + 3153908, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3967378375263321794" + ], + [ + 3154349, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1956150361493637850" + ], + [ + 3154374, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9515009690803494350" + ], + [ + 3154941, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6403793289284352763" + ], + [ + 3155135, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18377681884643089435" + ], + [ + 3155270, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2894252287023698380" + ], + [ + 3155272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1176757978366636180" + ], + [ + 3155356, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10243682700450778485" + ], + [ + 3156048, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13419435171562057387" + ], + [ + 3156349, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5310938153691480087" + ], + [ + 3156373, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2839757395093087158" + ], + [ + 3156759, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_189645908260721847" + ], + [ + 3357366, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14712843742675416789" + ], + [ + 3357370, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14441746631406223373" + ], + [ + 3357372, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16051981893101423452" + ], + [ + 3357379, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7313357428863899572" + ], + [ + 3357382, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15741756964888684277" + ], + [ + 3357385, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_518244284579097490" + ], + [ + 3357390, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9623044147940949233" + ], + [ + 3357396, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_188046812284482511" + ], + [ + 3357815, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2070546872302042284" + ], + [ + 3357878, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11580410533744131165" + ], + [ + 3357919, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9721818878768485015" + ], + [ + 3358001, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5097555846270763399" + ], + [ + 3358063, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12100682775419753259" + ], + [ + 3358121, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_831802341219462707" + ], + [ + 3358154, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3511804342372595424" + ], + [ + 3358199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13392040954474567685" + ], + [ + 3358318, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1488600423820696507" + ], + [ + 3358415, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7033478551626861368" + ], + [ + 3358501, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10949369291795034167" + ], + [ + 3358546, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9575369317971095792" + ], + [ + 3358599, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4171338315313792908" + ], + [ + 3358636, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3479887625518456208" + ], + [ + 3358704, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10816808840412347489" + ], + [ + 3358735, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4565693084836619869" + ], + [ + 3473543, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6198223556594770384" + ], + [ + 3473551, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10258537246918590156" + ], + [ + 3473616, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8769553907161686792" + ], + [ + 3473617, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3773271510051838348" + ], + [ + 3473932, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13489168884906494727" + ], + [ + 3473933, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18253176944629380716" + ], + [ + 3473943, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10561446346587147503" + ], + [ + 3473944, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1057740105323334353" + ], + [ + 3474186, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16534136807012196550" + ], + [ + 3474187, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3053221032925412808" + ], + [ + 3474202, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11135988479496829721" + ], + [ + 3474205, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15763171237614044906" + ], + [ + 3475761, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18163430628544684227" + ], + [ + 3475762, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_906162624150815971" + ], + [ + 3475895, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16284538720262805411" + ], + [ + 3475896, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6471773242206149096" + ], + [ + 3524930, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7351702211122833722" + ], + [ + 3524938, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_849812185320057193" + ], + [ + 3524944, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_899185330692270860" + ], + [ + 3524949, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16251077242215705442" + ], + [ + 3524957, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17261980608195435381" + ], + [ + 3524959, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15728989693202867994" + ], + [ + 3524963, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_578314422480040874" + ], + [ + 3524972, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5858027835250679442" + ], + [ + 3625661, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8595121483727162241" + ], + [ + 3628698, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6644896138201857560" + ], + [ + 3628765, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14380676338487055612" + ], + [ + 3629117, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3884026636792605184" + ], + [ + 3629128, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8471987555244152278" + ], + [ + 3629159, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10706976348774196355" + ], + [ + 3629253, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6353439240478980341" + ], + [ + 3629283, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11201244953910665068" + ], + [ + 3629360, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13965969812663104297" + ], + [ + 3629455, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15143351641328536949" + ], + [ + 3629517, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1731228908521478078" + ], + [ + 3629588, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1280418295423790350" + ], + [ + 3629613, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8333685548859735493" + ], + [ + 3629770, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6530807606746596767" + ], + [ + 3629792, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7249227128131175886" + ], + [ + 3629874, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17572995736186529549" + ], + [ + 3642042, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_504718820247128253" + ], + [ + 3642043, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5621471611596574380" + ], + [ + 3642063, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5859789574421468396" + ], + [ + 3642065, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1749104467585305323" + ], + [ + 3642198, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16903094956545060933" + ], + [ + 3642199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15476704792505050546" + ], + [ + 3642237, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16807932559841593208" + ], + [ + 3642238, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8310833007698505204" + ], + [ + 3642367, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4707148721279721064" + ], + [ + 3642371, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9013616000502015446" + ], + [ + 3642732, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12951651460115558391" + ], + [ + 3642733, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6856766171535732498" + ], + [ + 3643042, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14506781066108968506" + ], + [ + 3643043, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7374279504609390465" + ], + [ + 3643572, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5955043855143664956" + ], + [ + 3643573, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8196104095227417096" + ], + [ + 3691270, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2620465584826061471" + ], + [ + 3691273, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1757388310219074848" + ], + [ + 3691281, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11097774192912884566" + ], + [ + 3691283, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4742545615202517810" + ], + [ + 3691285, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16070580061589069085" + ], + [ + 3691287, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6948592204353944837" + ], + [ + 3691292, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4433630213579020385" + ], + [ + 3691296, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15584731480290088561" + ], + [ + 3691301, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_221849414828219018" + ], + [ + 3691305, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_447840502870009883" + ], + [ + 3691312, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15894757923477358902" + ], + [ + 3691316, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2064442071811218632" + ], + [ + 3691318, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16020379562811154133" + ], + [ + 3691328, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1281388017954056527" + ], + [ + 3691337, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2331797143551461098" + ], + [ + 3691344, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13357220641161590082" + ], + [ + 3692044, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18315790265804955058" + ], + [ + 3692064, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17507869003332814399" + ], + [ + 3692144, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4507988044373530697" + ], + [ + 3692262, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18426271347408772907" + ], + [ + 3722920, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5017140278663970478" + ], + [ + 3736238, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12999327330513004683" + ], + [ + 3756817, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11445498054270621315" + ], + [ + 3765828, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1383945043061831987" + ], + [ + 3774374, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12318664942362170822" + ], + [ + 3783232, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3418024057005528102" + ], + [ + 3797970, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8613331899495270727" + ], + [ + 3798096, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9878016966715019087" + ], + [ + 3798154, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6627377776921054141" + ], + [ + 3798272, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10804569796438592151" + ], + [ + 3798362, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13557064605936120080" + ], + [ + 3798416, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5828012701014328691" + ], + [ + 3810814, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8002480603490928081" + ], + [ + 3811199, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15215890719473899494" + ], + [ + 3811238, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5765224812454476842" + ], + [ + 3811276, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16426201200428659968" + ], + [ + 3811288, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17001284103737435638" + ], + [ + 3811588, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15465033144293705320" + ], + [ + 3811661, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11237554842213579307" + ], + [ + 3811984, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4208370878176288374" + ], + [ + 3812122, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_859880967845588792" + ], + [ + 3812520, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11075357886525732487" + ], + [ + 3812639, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8682143685678125481" + ], + [ + 3812662, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12325636616321695255" + ], + [ + 3812698, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14800074808472192333" + ], + [ + 3812997, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12158776810784339196" + ], + [ + 3813382, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7085330449950753708" + ], + [ + 3813392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9314724528373063174" + ], + [ + 3868673, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8427308000189841720" + ], + [ + 3868680, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_731075373502833868" + ], + [ + 3868682, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9807646971983976701" + ], + [ + 3868684, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_708313394526412267" + ], + [ + 3868688, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13447384123535294566" + ], + [ + 3868692, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14036027334014170087" + ], + [ + 3868696, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3187295055034313317" + ], + [ + 3868700, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3850755952385802565" + ], + [ + 3868705, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10106160760700975956" + ], + [ + 3868706, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17824178791258425552" + ], + [ + 3868711, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11909562160231335840" + ], + [ + 3868719, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13607563626468215022" + ], + [ + 3868727, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16787076582919642185" + ], + [ + 3868737, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8246440632798892168" + ], + [ + 3868739, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4159297518721824730" + ], + [ + 3868745, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18015449808950218451" + ], + [ + 3869439, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14196112055158873872" + ], + [ + 3869482, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14077174710107593045" + ], + [ + 3869512, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8061984655211037805" + ], + [ + 3869576, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16985047152068049914" + ], + [ + 3869667, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12041070490733630674" + ], + [ + 3869787, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5526959636930734514" + ], + [ + 3869824, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14021406351436652831" + ], + [ + 3903350, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8500099511158252282" + ], + [ + 3908889, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12265936708546584396" + ], + [ + 3919175, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6922556638284219234" + ], + [ + 3939646, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17523183083010355471" + ], + [ + 3953020, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6290038540260580030" + ], + [ + 3965009, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14272343095608536296" + ], + [ + 3976392, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5029439665772523493" + ], + [ + 3976512, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11697369406167774290" + ], + [ + 3976632, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3263338739247539272" + ], + [ + 3988513, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_787927194113102781" + ], + [ + 3988531, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18023508788842543544" + ], + [ + 3988558, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12183092996673180058" + ], + [ + 3988563, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4831310188807215850" + ], + [ + 3988807, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2198912152797944480" + ], + [ + 3988848, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4213118868561974759" + ], + [ + 3988850, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3232631311510150968" + ], + [ + 3988957, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11047352039168056415" + ], + [ + 3989106, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_16542734713691156018" + ], + [ + 3989387, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15165735529962902451" + ], + [ + 3989535, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_12857832019022611393" + ], + [ + 3989541, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8287593005309603838" + ], + [ + 3989647, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18367334379009715015" + ], + [ + 3990050, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4446833231210301139" + ], + [ + 3990104, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5145384017407119481" + ], + [ + 3990359, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6801533416498745154" + ], + [ + 4039903, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17584089607946077100" + ], + [ + 4039907, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_503447699485187456" + ], + [ + 4039912, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15069983805839197528" + ], + [ + 4039916, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14542568312346822221" + ], + [ + 4039918, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4892714578234319308" + ], + [ + 4039925, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5411057577383729197" + ], + [ + 4039928, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15464081172484410235" + ], + [ + 4039932, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6225688351493779577" + ], + [ + 4039935, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15625622540562040764" + ], + [ + 4039937, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_10502628395411644717" + ], + [ + 4039943, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11797930136011262397" + ], + [ + 4039946, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13885865319396877227" + ], + [ + 4039957, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5671629446760297180" + ], + [ + 4039960, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1393594888891471458" + ], + [ + 4039967, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3969615252157795066" + ], + [ + 4039972, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_17691785108059131348" + ], + [ + 4040670, + 
"/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15630378231912934353" + ], + [ + 4040733, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1125260354809627775" + ], + [ + 4040797, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_250967488981294306" + ], + [ + 4040888, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6978632213513861862" + ], + [ + 4040963, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14539782102651016167" + ], + [ + 4073698, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_4027982239315583183" + ], + [ + 4087187, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_5952558817727402033" + ], + [ + 4108124, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8896124445966427863" + ], + [ + 4121492, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3777541655201128969" + ], + [ + 4134649, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14525555472067694467" + ], + [ + 4145652, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_9698817224586140305" + ], + [ + 4147258, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11549998609739639611" + ], + [ + 4147317, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6537741669769489394" + ], + [ + 4147443, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_7264494085295711900" + ], + [ + 4147482, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1691570203381345331" + ], + [ + 4147600, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_805114710948033208" + ], + [ + 4158981, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_1404183938416935234" + ], + [ + 4159025, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_366169914351560748" + ], + [ + 4159276, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6543673387235939768" + ], + [ + 4159300, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_583449162339166668" + ], + [ + 4159319, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_15439605152958871044" + ], + [ + 4159324, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11830655025939200238" + ], + [ + 4159646, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2897405567433811008" + ], + [ + 4159775, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_11323777380456298142" + ], + [ + 4159784, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_18401703462013938226" + ], + [ + 4159918, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_3439594419195257361" + ], + [ + 4160584, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8461195438972411491" + ], + [ + 4160770, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_14710116434070740759" + ], + [ + 4161660, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_8602813889361540004" + ], + [ + 4161781, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_6869104894578156210" + ], + [ + 4161828, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_2207584602328927841" + ], + [ + 4162147, + "/home/cym/rf/fsdp-sglang-rf/new-rf/RLinf/kernel_meta/kernel_meta_13886441570120016825" + ] +] \ No newline at end of file diff --git a/rlinf/algorithms/losses.py 
b/rlinf/algorithms/losses.py index 3980f9136..72e9004c6 100644 --- a/rlinf/algorithms/losses.py +++ b/rlinf/algorithms/losses.py @@ -196,15 +196,20 @@ def compute_math_ppo_actor_loss(**kwargs): loss_agg_func = kwargs["loss_agg_func"] logprobs = kwargs["logprobs"] old_logprobs = kwargs["old_logprobs"] - eps_clip = kwargs["eps_clip"] + clip_ratio_low = kwargs["clip_ratio_low"] + clip_ratio_high = kwargs["clip_ratio_high"] advantages = kwargs["advantages"] loss_mask = kwargs.get("loss_mask", None) c_clip = kwargs.get("c_clip", None) - - assert logprobs.dtype == torch.float32 - assert old_logprobs.dtype == torch.float32 - assert advantages.dtype == torch.float32 - + if logprobs.dtype != torch.float32: + logprobs = logprobs.float() # convert to float32 + #assert logprobs.dtype == torch.float32 + #assert old_logprobs.dtype == torch.float32 + #assert advantages.dtype == torch.float32 + if old_logprobs.dtype != torch.float32: + old_logprobs = old_logprobs.float() + if advantages.dtype != torch.float32: + advantages = advantages.float() assert loss_mask is not None loss_mask_count = loss_mask.count_nonzero() or 1 @@ -212,7 +217,7 @@ def compute_math_ppo_actor_loss(**kwargs): ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0) approx_kl = torch.where(loss_mask, (logprobs - old_logprobs).detach(), 0.0) - clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) + clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high) policy_loss1 = -advantages * ratio policy_loss2 = -advantages * clipped_ratio @@ -232,17 +237,21 @@ def compute_math_ppo_actor_loss(**kwargs): clip_mask = policy_loss1.detach() < policy_loss2.detach() dual_clip_mask.logical_and_(loss_mask) - clip_fraction = clip_mask.logical_and_(loss_mask).count_nonzero() / loss_mask_count - approx_kl = -approx_kl.sum() / loss_mask_count + num_clipped = clip_mask.logical_and_(loss_mask).count_nonzero() + + clip_fraction = num_clipped.float() / float(loss_mask_count) + approx_kl = -approx_kl.sum() / float(loss_mask_count) dual_cliped_ratio = torch.where(dual_clip_mask, ratio, 0) # Compile metrics for logging metrics_data = { - "policy_loss": masked_mean(policy_loss.detach(), loss_mask), - "ratio": masked_mean(ratio.detach(), loss_mask), - "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask), - "dual_cliped_ratio": masked_mean(dual_cliped_ratio.detach(), loss_mask), + "policy_loss": masked_mean(policy_loss.detach(), loss_mask).detach(), + "ratio": masked_mean(ratio.detach(), loss_mask).detach(), + "clipped_ratio": masked_mean(clipped_ratio.detach(), loss_mask).detach(), + "dual_cliped_ratio": masked_mean( + dual_cliped_ratio.detach(), loss_mask + ).detach(), "approx_kl": approx_kl.detach(), "clip_fraction": clip_fraction.detach(), } diff --git a/rlinf/algorithms/registry.py b/rlinf/algorithms/registry.py index 96b19f2b4..11bfc6c6a 100644 --- a/rlinf/algorithms/registry.py +++ b/rlinf/algorithms/registry.py @@ -73,22 +73,3 @@ def calculate_adv_and_returns(**kwargs) -> Tuple[torch.Tensor, Optional[torch.Te adv_type = kwargs["adv_type"] fn = get_adv_and_returns(adv_type) return fn(**kwargs) - - -REWARD_REGISTRY: Dict[str, Callable] = {} - - -def register_reward_fn(name: str): - def decorator(fn): - REWARD_REGISTRY[name] = fn - return fn - - return decorator - - -def get_reward_fn(name: Optional[str]): - if name is None: - return None - if name not in REWARD_REGISTRY: - raise ValueError(f"Reward function {name} not registered") - return REWARD_REGISTRY[name]
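The losses.py hunk above replaces the single symmetric eps_clip with separate clip_ratio_low / clip_ratio_high bounds. As a quick orientation, a minimal sketch of what the new asymmetric clamp does to the importance ratio (the numbers are illustrative, not values taken from any RLinf config):

    import torch

    ratio = torch.tensor([0.5, 1.0, 1.6])                    # exp(logprobs - old_logprobs)
    clip_ratio_low, clip_ratio_high = 0.2, 0.3               # illustrative values
    clipped = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
    # clipped -> tensor([0.8000, 1.0000, 1.3000]): the lower and upper bounds can now differ
    advantages = torch.tensor([1.0, -1.0, 1.0])
    policy_loss = torch.max(-advantages * ratio, -advantages * clipped)  # element-wise PPO clip loss

diff --git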
a/rlinf/algorithms/rewards/__init__.py b/rlinf/algorithms/rewards/__init__.py new file mode 100644 index 000000000..380cfa102 --- /dev/null +++ b/rlinf/algorithms/rewards/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rlinf.algorithms.rewards.code import CodeReward +from rlinf.algorithms.rewards.math import MathReward +from rlinf.algorithms.rewards.vqa import VQAReward + + +def register_reward(name: str, reward_class: type): + assert name not in reward_registry, f"Reward {name} already registered" + reward_registry[name] = reward_class + + +def get_reward_class(name: str): + assert name in reward_registry, f"Reward {name} not found" + return reward_registry[name] + + +reward_registry = {} + +register_reward("math", MathReward) +register_reward("vqa", VQAReward) +register_reward("code", CodeReward) diff --git a/rlinf/algorithms/rewards/code/__init__.py b/rlinf/algorithms/rewards/code/__init__.py new file mode 100644 index 000000000..0fc75f971 --- /dev/null +++ b/rlinf/algorithms/rewards/code/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from omegaconf import DictConfig + +from toolkits.code_verifier.verify import fim_verify_call + + +class CodeReward: + def __init__(self, config: DictConfig): + self.scale = config.get("reward_scale", 1.0) + + def get_reward( + self, response: List[str], reference: List[List[str]] + ) -> List[float]: + rewards = fim_verify_call(response, reference) + return [float(reward) * self.scale for reward in rewards] diff --git a/rlinf/algorithms/rewards/math/__init__.py b/rlinf/algorithms/rewards/math/__init__.py new file mode 100644 index 000000000..1a67e80e1 --- /dev/null +++ b/rlinf/algorithms/rewards/math/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
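The function-based REWARD_REGISTRY removed from registry.py above is superseded by this class-based registry. A minimal usage sketch (the response and reference strings are invented; the exact formats are whatever the verifiers under toolkits expect):

    from omegaconf import OmegaConf
    from rlinf.algorithms.rewards import get_reward_class

    reward_cls = get_reward_class("math")                    # -> MathReward
    reward = reward_cls(OmegaConf.create({"reward_scale": 1.0}))
    scores = reward.get_reward(
        ["... so the result is \\boxed{72}."],               # model responses
        [["72"]],                                            # acceptable references per response
    )                                                        # e.g. [1.0] if the verifier accepts it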
+ +from typing import List + +from omegaconf import DictConfig + +from toolkits.math_verifier.verify import math_verify_call + + +class MathReward: + def __init__(self, config: DictConfig): + self.scale = config.get("reward_scale", 1.0) + + def get_reward( + self, response: List[str], reference: List[List[str]] + ) -> List[float]: + """ + Calculates reward scores for a list of responses compared to corresponding lists of reference answers. + For each response, the function checks if it matches any of the provided references using the `process_results` function. + The reward for each response is computed as the first element of the result (converted to float) multiplied by `self.scale`. + Args: + response (List[str]): A list of response strings to be evaluated. + reference (List[List[str]]): A list where each element is a list of reference strings corresponding to each response. + Returns: + List[float]: A list of reward scores, one for each response. + """ + + rewards = math_verify_call(response, reference) + return [float(reward) * self.scale for reward in rewards] diff --git a/rlinf/algorithms/rewards/vqa/__init__.py b/rlinf/algorithms/rewards/vqa/__init__.py new file mode 100644 index 000000000..77b009369 --- /dev/null +++ b/rlinf/algorithms/rewards/vqa/__init__.py @@ -0,0 +1,62 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
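Following the MathReward docstring above, each response is checked against its own list of acceptable references and the verifier result is multiplied by reward_scale. A small illustrative call (values invented for the sketch; the actual matching is delegated to math_verify_call):

    from omegaconf import OmegaConf
    from rlinf.algorithms.rewards.math import MathReward

    reward = MathReward(OmegaConf.create({"reward_scale": 0.5}))
    responses = ["\\boxed{0.5}", "\\boxed{7}"]
    references = [["1/2", "0.5"], ["6"]]    # several equivalent answers may be listed per prompt
    print(reward.get_reward(responses, references))   # e.g. [0.5, 0.0] after scaling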
+ +from typing import List + +import torch +from omegaconf import DictConfig + +from .format_rewards import answer_format_reward, think_format_reward +from .qa_rewards import qa_accuracy_reward + + +class VQAReward: + NEEDED_REWARD_FUNCTIONS = { + "qa_accuracy": qa_accuracy_reward, + "think_format": think_format_reward, + "answer_format": answer_format_reward, + } + + def __init__(self, config: DictConfig): + assert "reward_weights" in config, "VQAReward requires reward_weights in config" + + self.reward_weights_config = config.reward_weights + assert set(self.reward_weights_config.keys()) == set( + self.NEEDED_REWARD_FUNCTIONS.keys() + ), ( + f"Reward weights must contain all of: {self.NEEDED_REWARD_FUNCTIONS.keys()} but got {list(self.reward_weights_config.keys())}" + ) + assert all( + reward_weight >= 0 for reward_weight in self.reward_weights_config.values() + ), ( + f"All reward weights must be non-negative but got {list(self.reward_weights_config.values())}" + ) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def get_reward(self, completions: List[str], answers: List[dict]) -> List[float]: + rewards = [] + reward_weights = [] + for reward_name, reward_function in self.NEEDED_REWARD_FUNCTIONS.items(): + if self.reward_weights_config[reward_name] > 0: + rewards.append(reward_function(completions, answers)) + else: + rewards.append([0.0] * len(completions)) + reward_weights.append(self.reward_weights_config[reward_name]) + + rewards_tensor = torch.tensor(rewards, device=self.device) + weights_tensor = torch.tensor(reward_weights, device=self.device) + + final_rewards = (rewards_tensor * weights_tensor.unsqueeze(1)).sum(dim=0) + + return final_rewards.tolist() diff --git a/rlinf/algorithms/rewards/vqa/format_rewards.py b/rlinf/algorithms/rewards/vqa/format_rewards.py new file mode 100644 index 000000000..205bbe336 --- /dev/null +++ b/rlinf/algorithms/rewards/vqa/format_rewards.py @@ -0,0 +1,67 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List + + +def think_format_reward(completions, answers) -> List[float]: + """ + Think format reward function compatible with GRPO training. + + Reward function that checks if reasoning is enclosed within <think>...</think> tags. + + Args: + completions: List of model completions (text strings) + + Returns: + List of reward scores (1.0 for correct format, 0.0 otherwise) + """ + pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$" + rewards = [] + + for completion in completions: + completion_text = str(completion).strip() + match = re.match(pattern, completion_text, re.DOTALL | re.MULTILINE) + rewards.append(1.0 if match else 0.0) + + return rewards + + +def answer_format_reward(completions, answers) -> List[float]: + """ + Reward function that checks for proper answer formatting. + + Expected format: <answer>X. content</answer> where X is a choice letter.
+ + Args: + completions: List of model completions (text strings) + + Returns: + List of reward scores (1.0 for correct format, 0.0 otherwise) + """ + rewards = [] + + for completion in completions: + completion_text = str(completion).strip() + + # Check for proper answer format: <answer>X. content</answer> + answer_pattern = r"<answer>\s*[A-E]\.\s*.+?\s*</answer>" + has_proper_answer = bool( + re.search(answer_pattern, completion_text, re.DOTALL | re.IGNORECASE) + ) + + rewards.append(1.0 if has_proper_answer else 0.0) + + return rewards diff --git a/rlinf/algorithms/rewards/vqa/qa_rewards.py b/rlinf/algorithms/rewards/vqa/qa_rewards.py new file mode 100644 index 000000000..2bc9540d3 --- /dev/null +++ b/rlinf/algorithms/rewards/vqa/qa_rewards.py @@ -0,0 +1,110 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List + + +def qa_accuracy_reward(completions, answers) -> List[float]: + """ + Reward function that evaluates question-answering accuracy for VQA tasks. + + Based on TRL's accuracy_reward pattern but adapted for multiple choice VQA. + + Args: + completions: List of model completions (text strings) + answers: List of correct answers (dict) + + Returns: + List of reward scores (1.0 for correct, 0.0 for incorrect) + """ + rewards = [] + + for completion, answer in zip(completions, answers): + completion_text = str(completion).strip() + + # Extract answer from completion - look for <answer>X. content</answer> + answer_match = re.search( + r"<answer>\s*([A-E])\.\s*(.*?)\s*</answer>", + completion_text, + re.DOTALL | re.IGNORECASE, + ) + + if not answer_match: + rewards.append(0.0) + continue + + predicted_letter = answer_match.group(1).upper() + predicted_content = answer_match.group(2).strip() + + # Get ground truth from the answer dict + correct_answer = answer.get("correct_answer", None) + choices = answer.get("choices", None) + + if correct_answer is None or choices is None: + rewards.append(0.0) + continue + + # Normalize correct_answer to letter format + if isinstance(correct_answer, int): + correct_letter = chr(65 + correct_answer) # 0->A, 1->B, etc.
+ elif isinstance(correct_answer, str): + correct_letter = correct_answer.strip().upper() + else: + rewards.append(0.0) + continue + + # Parse choices if string format + if isinstance(choices, str): + try: + import ast + + choices = ast.literal_eval(choices) + except (ValueError, SyntaxError): + choices = [str(choices)] + + # Get correct choice content + letter_to_idx = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4} + if correct_letter in letter_to_idx and letter_to_idx[correct_letter] < len( + choices + ): + correct_content = choices[letter_to_idx[correct_letter]].strip() + else: + rewards.append(0.0) + continue + + # Check accuracy: both letter and content must match + letter_match = predicted_letter == correct_letter + content_match = _compare_choice_content(predicted_content, correct_content) + + rewards.append(1.0 if (letter_match and content_match) else 0.0) + + return rewards + + +def _compare_choice_content(predicted: str, correct: str) -> bool: + """Compare predicted choice content with correct content.""" + # Simple normalized comparison + pred_normalized = predicted.lower().strip() + correct_normalized = correct.lower().strip() + + # Direct match + if pred_normalized == correct_normalized: + return True + + # Partial match for more flexibility + if pred_normalized in correct_normalized or correct_normalized in pred_normalized: + return True + + return False diff --git a/rlinf/config.py b/rlinf/config.py index 86198334a..05de15254 100644 --- a/rlinf/config.py +++ b/rlinf/config.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses +import importlib.util import logging import os from dataclasses import asdict @@ -33,16 +34,10 @@ logging.getLogger().setLevel(logging.INFO) -try: - import transformer_engine - - HAVE_TE = True -except ImportError: - transformer_engine = None - HAVE_TE = False - -SUPPORTED_MODEL_ARCHS = ["qwen2.5", "openvla", "openvla_oft"] +SUPPORTED_MODEL_ARCHS = ["qwen2.5", "qwen2.5_vl", "openvla", "openvla_oft"] SUPPORTED_ROLLOUT_BACKENDS = ["sglang", "vllm"] +SUPPORTED_TASK_TYPE = ["embodied", "reasoning", "coding_online_rl"] +SUPPORTED_TRAINING_BACKENDS = ["megatron", "fsdp"] __all__ = ["build_config"] @@ -191,7 +186,7 @@ def validate_vllm_cfg(cfg): def validate_model_cfg_by_hf_config(cfg, hf_model_path): # validate by hf config - hf_config = AutoConfig.from_pretrained(hf_model_path) + hf_config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True) if "Qwen2ForCausalLM" in hf_config.architectures: qkv_bias = True @@ -199,8 +194,16 @@ def validate_model_cfg_by_hf_config(cfg, hf_model_path): qkv_bias = getattr(hf_config, "attention_bias", False) with open_dict(cfg): - if hf_config.rope_scaling is not None: - cfg.model.seq_len_interpolation_factor = hf_config.rope_scaling["factor"] + rs = getattr(hf_config, "rope_scaling", None) + if isinstance(rs, dict): + rtype = rs.get("type", "") + if rtype in {"linear", "dynamic", "ntk", "yarn"}: + f = rs.get("factor") + if f is not None: + cfg.model.seq_len_interpolation_factor = float(f) + else: + # mrope + cfg.model.seq_len_interpolation_factor = None cfg.model.override_vocab_size = hf_config.vocab_size cfg.model.max_position_embeddings = hf_config.max_position_embeddings cfg.model.rotary_base = hf_config.rope_theta @@ -220,6 +223,16 @@ def validate_model_cfg_by_hf_config(cfg, hf_model_path): return cfg +def validate_fsdp_cfg(cfg: DictConfig) -> DictConfig: + OmegaConf.set_struct(cfg, True) + with open_dict(cfg): + cfg.fsdp.forward_prefetch = cfg.fsdp.get("forward_prefetch", False) + 
cfg.fsdp.limit_all_gathers = cfg.fsdp.get("limit_all_gathers", False) + cfg.fsdp.backward_prefetch = cfg.fsdp.get("backward_prefetch", False) + cfg.fsdp.use_orig_params = cfg.fsdp.get("use_orig_params", False) + return cfg + + def validate_megatron_cfg(cfg: DictConfig) -> DictConfig: OmegaConf.set_struct(cfg, True) @@ -527,7 +540,7 @@ def get_robot_control_mode(robot: str): return cfg -def validate_math_cfg(cfg: DictConfig) -> DictConfig: +def validate_reasoning_cfg(cfg: DictConfig) -> DictConfig: assert cfg.rollout.model_arch in SUPPORTED_MODEL_ARCHS, ( f"Model {cfg.rollout.model_arch} is not supported" ) @@ -606,11 +619,14 @@ def validate_coding_online_rl_cfg(cfg: DictConfig) -> DictConfig: def validate_cfg(cfg: DictConfig) -> DictConfig: OmegaConf.set_struct(cfg, True) + assert cfg.runner.task_type in SUPPORTED_TASK_TYPE, ( + f"task_type must be one of {SUPPORTED_TASK_TYPE}" + ) if cfg.runner.task_type == "embodied": cfg = validate_embodied_cfg(cfg) - if cfg.runner.task_type == "math": - cfg = validate_math_cfg(cfg) - if cfg.runner.task_type == "coding_online_rl": + elif cfg.runner.task_type == "reasoning": + cfg = validate_reasoning_cfg(cfg) + elif cfg.runner.task_type == "coding_online_rl": cfg = validate_coding_online_rl_cfg(cfg) if ( @@ -619,13 +635,21 @@ def validate_cfg(cfg: DictConfig) -> DictConfig: ): assert cfg.algorithm.group_size > 1 + assert cfg.actor.training_backend in SUPPORTED_TRAINING_BACKENDS, ( + f"Unsupported training_backend {cfg.actor.training_backend}. Supported training backends are {SUPPORTED_TRAINING_BACKENDS}." + ) + if cfg.actor.training_backend == "megatron": cfg.actor = validate_megatron_cfg(cfg.actor) cfg.actor = validate_model_cfg_by_hf_config(cfg.actor, cfg.rollout.model_dir) + elif cfg.actor.training_backend == "fsdp": + cfg.actor = validate_fsdp_cfg(cfg.actor) if cfg.critic.use_critic_model and cfg.critic.training_backend == "megatron": cfg.critic = validate_megatron_cfg(cfg.critic) - cfg = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) + cfg.critic = validate_model_cfg_by_hf_config(cfg.critic, cfg.rollout.model_dir) + elif cfg.critic.use_critic_model and cfg.critic.training_backend == "fsdp": + cfg.critic = validate_fsdp_cfg(cfg.critic) return cfg @@ -756,7 +780,10 @@ def build_transformer_config(cfg) -> "TransformerConfig": tp_only_amax_red = cfg.get("tp_only_amax_red", False) if cfg.get("enable_cuda_graph", False): - assert HAVE_TE, "Transformer Engine is required for cudagraphs." + if importlib.util.find_spec("transformer_engine") is None: + raise ImportError( + "Can not import transformer_engine, which is required for cudagraphs." + ) assert cfg.get("use_te_rng_tracker", False), ( "Transformer engine's RNG tracker is required for cudagraphs, this can be enabled with \ 'use_te_rng_tracker=True'." diff --git a/rlinf/data/datasets/__init__.py b/rlinf/data/datasets/__init__.py new file mode 100644 index 000000000..f86f3576e --- /dev/null +++ b/rlinf/data/datasets/__init__.py @@ -0,0 +1,150 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, List, Tuple + +import torch +from omegaconf import DictConfig +from torch.utils.data import Dataset +from transformers import AutoTokenizer + +from rlinf.data.datasets.item import DatasetItem +from rlinf.data.datasets.math import MathDataset +from rlinf.data.datasets.vlm import VLMDatasetRegistry + + +def create_rl_dataset( + config: DictConfig, tokenizer: AutoTokenizer +) -> Tuple[Dataset, Dataset]: + """Create rl datasets. + + Arguments: + config: The RLinf config. + tokenizer (Tokenizer): The tokenizer. + + Returns: + train_dataset (Dataset): The training dataset. + + val_dataset (Dataset): The validation dataset. + """ + + if config.data.type == "math": + dataset_cls = MathDataset + elif config.data.type == "vision_language": + # Prefer new factory-based VLM datasets; fallback to legacy if requested + dataset_name = getattr(config.data, "dataset_name", None) + lazy_loading = bool(getattr(config.data, "lazy_loading", False)) + + logging.info( + f"Using VLM dataset: name={dataset_name}, lazy_loading={lazy_loading}" + ) + + train_dataset = VLMDatasetRegistry.create( + dataset_name, + data_paths=config.data.train_data_paths, + config=config, + tokenizer=tokenizer, + ) + val_dataset = VLMDatasetRegistry.create( + dataset_name, + data_paths=config.data.val_data_paths, + config=config, + tokenizer=tokenizer, + ) + return train_dataset, val_dataset + else: + return None, None + + logging.info(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + train_dataset = dataset_cls( + data_paths=config.data.train_data_paths, + config=config, + tokenizer=tokenizer, + ) + + val_dataset = dataset_cls( + data_paths=config.data.val_data_paths, + config=config, + tokenizer=tokenizer, + ) + + return train_dataset, val_dataset + + +def collate_fn(data_list: List["DatasetItem"]) -> Dict[str, Any]: + """ + Collate function for batching dataset items. 
+ """ + prompts = [] + lens = [] + for it in data_list: + p = ( + it.prompt + if isinstance(it.prompt, torch.Tensor) + else torch.as_tensor(it.prompt, dtype=torch.long) + ) + if p.dim() == 2 and p.size(0) == 1: + p = p.squeeze(0) + assert p.dim() == 1, ( + f"DatasetItem.prompt must be 1-D tensor, current shape is: {p.shape}" + ) + prompts.append(p) + lens.append(p.numel()) + + if len(set(lens)) == 1: + target_len = lens[0] + else: + target_len = min(lens) + prompts = [p[-target_len:] if p.numel() > target_len else p for p in prompts] + + batch_prompt = torch.stack(prompts, dim=0) # [B, L] + batch_length = torch.tensor( + [min(int(it.length), target_len) for it in data_list], dtype=torch.long + ) + + batch_idx = torch.tensor([int(it.idx) for it in data_list], dtype=torch.long) + + batch: Dict[str, Any] = { + "prompt": batch_prompt, # [B, L] + "length": batch_length, # [B] + "answer": [it.answer for it in data_list], # List[str] + "idx": batch_idx, # [B] + "solution": [it.solution for it in data_list], # List[Optional[str]] + "image_data": [ + it.image_data for it in data_list + ], # List[Optional[List[bytes|str]]] + "prompt_text": [it.prompt_text for it in data_list], # List[Optional[str]] + "meta": [it.meta for it in data_list], # List[Optional[dict]] + "multi_modal_inputs": [ + it.multi_modal_inputs for it in data_list + ], # List[Optional[dict]] + } + return batch diff --git a/rlinf/data/datasets/item.py b/rlinf/data/datasets/item.py new file mode 100644 index 000000000..e75155dcb --- /dev/null +++ b/rlinf/data/datasets/item.py @@ -0,0 +1,59 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + + +@dataclass +class DatasetItem: + """ + A single item in processed dataset. + + Attributes: + prompt (torch.Tensor): Tokenized prompt input_ids tensor. + length (int): Length of the prompt input_ids. + answer (str | dict): The answer associated with the prompt. + idx (int): Index of the item in the dataset. + solution (Optional[str]): Optional solution text if exists. + prompt_text (Optional[str]): Optional original prompt text before tokenization. + meta (Optional[Dict[str, Any]]): Optional metadata dictionary. + multi_modal_inputs (Optional[Dict[str, Any]]): Optional dictionary for additional multi-modal inputs. 
+ """ + + prompt: torch.Tensor + length: int + answer: str | dict + idx: int + solution: Optional[str] = None + image_data: Optional[List[Union[bytes, str]]] = None + prompt_text: Optional[str] = None + meta: Optional[Dict[str, Any]] = None + multi_modal_inputs: Optional[Dict[str, Any]] = None diff --git a/rlinf/data/datasets.py b/rlinf/data/datasets/math.py similarity index 52% rename from rlinf/data/datasets.py rename to rlinf/data/datasets/math.py index fcce53f47..821074bf6 100644 --- a/rlinf/data/datasets.py +++ b/rlinf/data/datasets/math.py @@ -12,67 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import json import logging import os -from collections import defaultdict -from typing import List +from typing import Any, List, Tuple, Union -import numpy as np import torch +from omegaconf import DictConfig from torch.utils.data import Dataset +from transformers import AutoTokenizer - -def batch_pad_to_fixed_len( - batch: List[torch.Tensor], - max_batch_len: int, - pad_token: int, - left_pad: bool = False, -) -> torch.Tensor: - if left_pad: - batch_pad = torch.stack( - [ - torch.cat( - [ - torch.full( - (max_batch_len - len(seq),), pad_token, dtype=seq.dtype - ), # pad on the left - seq, - ] - ) - for seq in batch - ] - ) - else: - batch_pad = torch.stack( - [ - torch.cat( - [ - seq, - torch.full( - (max_batch_len - len(seq),), pad_token, dtype=seq.dtype - ), - ] - ) - for seq in batch - ] - ) - return batch_pad +from rlinf.data.datasets.item import DatasetItem +from rlinf.data.datasets.utils import batch_pad_to_fixed_len class MathDataset(Dataset): - def __init__(self, data_paths, config, tokenizer): + def __init__( + self, + data_paths: Union[str, List[str]], + config: DictConfig, + tokenizer: AutoTokenizer, + ): super().__init__() self.data_paths = data_paths if isinstance(self.data_paths, str): self.data_paths = [self.data_paths] - self.max_prompt_length = config.max_prompt_length + self.max_prompt_length = config.data.max_prompt_length self.tokenizer = tokenizer - self.prompt_key = config.prompt_key + self.prompt_key = config.data.prompt_key self.data = self._load_data() - if config.get("filter_prompt_by_length", False): + if config.data.get("filter_prompt_by_length", False): total = len(self.data) filtered = [] failed = 0 @@ -97,7 +83,10 @@ def __init__(self, data_paths, config, tokenizer): f"(kept {len(self.data)} / {total})." ) - def _load_data(self): + def _load_data(self) -> List[Any]: + """ + Load and merge data from multiple files(json or jsonl). + """ merged_data = [] for path in self.data_paths: @@ -122,7 +111,10 @@ def _load_data(self): def __len__(self): return len(self.data) - def encode(self, text): + def encode(self, text: str) -> Tuple[List[int], int]: + """ + Use tokenizer to encode the text and return the token ids and length. 
+ """ text_ids = self.tokenizer.encode(text) return text_ids, len(text_ids) @@ -151,78 +143,11 @@ def __getitem__(self, idx): self.tokenizer.eos_token_id, left_pad=True, )[0] - - output = { - "prompt": prompt_tokens_tensor, - "length": prompt_length, - "answer": answer, - "idx": idx, - } + output = DatasetItem( + prompt=prompt_tokens_tensor, + length=prompt_length, + answer=answer, + idx=idx, + image_data=[], + ) return output - - -def create_rl_dataset(data_config, tokenizer): - """Create rl datasets. - - Arguments: - data_config: The data config. - tokenizer (Tokenizer): The tokenizer. - - Returns: - train_dataset (Dataset): The training dataset. - - val_dataset (Dataset): The validation dataset. - """ - - if data_config.type == "math": - dataset_cls = MathDataset - else: - return None, None - - print(f"Using dataset class: {dataset_cls.__name__}") - - # Instantiate the dataset using the determined dataset class - train_dataset = dataset_cls( - data_paths=data_config.train_data_paths, - config=data_config, - tokenizer=tokenizer, - ) - - val_dataset = dataset_cls( - data_paths=data_config.val_data_paths, - config=data_config, - tokenizer=tokenizer, - ) - - return train_dataset, val_dataset - - -def collate_fn(data_list: list[dict]) -> dict: - r""" - Collate a batch of sample dicts into batched tensors and arrays. - - Args: - data_list: List of dicts mapping feature names to torch.Tensor or other values. - - Returns: - Dict where tensor entries are stacked into a torch.Tensor of shape - (batch_size, \*dims) and non-tensor entries are converted to - np.ndarray of dtype object with shape (batch_size,). - """ - tensors = defaultdict(list) - non_tensors = defaultdict(list) - - for data in data_list: - for key, val in data.items(): - if isinstance(val, torch.Tensor): - tensors[key].append(val) - else: - non_tensors[key].append(val) - - for key, val in tensors.items(): - tensors[key] = torch.stack(val, dim=0) - - for key, val in non_tensors.items(): - non_tensors[key] = np.array(val, dtype=object) - - return {**tensors, **non_tensors} diff --git a/rlinf/data/datasets/utils.py b/rlinf/data/datasets/utils.py new file mode 100644 index 000000000..db4dbdb58 --- /dev/null +++ b/rlinf/data/datasets/utils.py @@ -0,0 +1,68 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import torch + + +def batch_pad_to_fixed_len( + batch: List[torch.Tensor], + max_batch_len: int, + pad_token: int, + left_pad: bool = False, +) -> torch.Tensor: + if left_pad: + batch_pad = torch.stack( + [ + torch.cat( + [ + torch.full( + (max_batch_len - len(seq),), pad_token, dtype=seq.dtype + ), # pad on the left + seq, + ] + ) + for seq in batch + ] + ) + else: + batch_pad = torch.stack( + [ + torch.cat( + [ + seq, + torch.full( + (max_batch_len - len(seq),), pad_token, dtype=seq.dtype + ), + ] + ) + for seq in batch + ] + ) + return batch_pad diff --git a/rlinf/data/datasets/vlm.py b/rlinf/data/datasets/vlm.py new file mode 100644 index 000000000..da9e952b9 --- /dev/null +++ b/rlinf/data/datasets/vlm.py @@ -0,0 +1,483 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
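A quick illustration of the helper above (tensors invented for the example):

    import torch
    from rlinf.data.datasets.utils import batch_pad_to_fixed_len

    batch = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    padded = batch_pad_to_fixed_len(batch, max_batch_len=4, pad_token=0, left_pad=True)
    # padded -> tensor([[0, 5, 6, 7],
    #                   [0, 0, 8, 9]])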
+ + +import json +import logging +import os +from io import BytesIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import pandas as pd +import torch +from omegaconf import DictConfig +from PIL import Image +from torch.utils.data import Dataset +from transformers import AutoProcessor, AutoTokenizer + +from rlinf.data.datasets.item import DatasetItem +from rlinf.data.datasets.utils import batch_pad_to_fixed_len + + +class VLMBaseDataset(Dataset): + def __init__( + self, + data_paths: Union[List[str], str], + config: DictConfig, + tokenizer: AutoTokenizer, + ) -> None: + super().__init__() + self.cfg = config + raw_paths = [data_paths] if isinstance(data_paths, str) else list(data_paths) + # Expand directories into file lists recursively (json/jsonl/parquet) + self.data_paths = self._expand_data_paths(raw_paths) + self.tokenizer = tokenizer + # Delay processor creation; only needed when use_chat_template is True + self._processor = None + + self.system_prompt = config.data.get("system_prompt", None) + self.use_chat_template = bool(config.data.use_chat_template) + self.image_keys = list(config.data.image_keys or []) + self.prompt_key = config.data.prompt_key + self.choice_key = config.data.get("choice_key", None) + self.answer_key = config.data.get("answer_key", None) + self.solution_key = config.data.get("solution_key", None) + self.max_prompt_length = int(config.data.max_prompt_length) + self.eos_id = int(self.tokenizer.eos_token_id) + + # Loading mode + self.lazy_loading = bool(getattr(config.data, "lazy_loading", False)) + + self._records = [] + self._indices = [] # (path, fmt, row_index_or_offset) + + if self.lazy_loading: + self._build_lazy_indices() + else: + self._eager_load_all() + + def __len__(self) -> int: + return len(self._indices) if self.lazy_loading else len(self._records) + + def __getitem__(self, idx: int) -> DatasetItem: + if self.lazy_loading: + path, fmt, key = self._indices[idx] + raw = self._load_single_lazy(path, fmt, key) + return self._process_raw_record(raw, idx) + else: + raw = self._records[idx] + return self._process_raw_record(raw, idx) + + # Ensure dataset is picklable for multi-process DataLoader by removing + # unpicklable cache objects like pyarrow.ParquetFile from state. 
+ def __getstate__(self): + state = self.__dict__.copy() + # Drop heavy/unpicklable caches; they will be rebuilt on-demand in workers + for k in ("_parquet_cache", "_parquet_df_cache"): + if k in state: + state[k] = {} + return state + + def __setstate__(self, state): + # Restore state and ensure cache dicts exist + self.__dict__.update(state) + self._parquet_cache = getattr(self, "_parquet_cache", {}) + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + + def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]: + images: List[Union[bytes, str, None]] = [] + for k in self.image_keys: + v = dataitem.get(k, None) + if v is None: + continue + if isinstance(v, Image.Image): + images.append(v) + elif isinstance(v, dict) and "bytes" in v: + images.append(v["bytes"]) + else: + images.append(v) # path or url + if not images: + images = [None] + return images + + def build_prompt_text(self, data_item: Dict[str, Any]) -> str: + # Default: prompt + optional choices rendered inline + q = data_item.get(self.prompt_key, "") + choices = data_item.get(self.choice_key, []) if self.choice_key else [] + if not isinstance(choices, list): + choices = [choices] + if choices: + return f"{q}{choices}\n" + return str(q) + + def encode_prompt( + self, prompt_text: str, images + ) -> Tuple[torch.Tensor, int, Optional[str]]: + """ + Return (token_ids[L], length, prompt_text_used). If using chat template, encode with processor. + Subclasses may override to support alternative prompting. + """ + if self.use_chat_template: + if self._processor is None: + self._processor = AutoProcessor.from_pretrained( + self.cfg.actor.model.model_path + ) + messages = [] + if self.system_prompt is not None: + messages.append( + { + "role": "system", + "content": [{"type": "text", "text": self.system_prompt}], + } + ) + + content: List[Dict[str, Any]] = [] + for _ in range(max(0, len(images))): + content.append({"type": "image"}) + content.append({"type": "text", "text": prompt_text}) + messages.append({"role": "user", "content": content}) + rendered = self._processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + images_inputs = [] + for image in images: + image_obj = None + if isinstance(image, Image.Image): + image_obj = image.convert("RGB") + if isinstance(image, (bytes, bytearray)): + image_obj = Image.open(BytesIO(image)).convert("RGB") + images_inputs.append(image_obj) + + inputs = self._processor( + text=[rendered], images=images_inputs, padding=True, return_tensors="pt" + ) + inputs.pop("attention_mask") + if self.cfg.rollout.rollout_backend == "sglang": + ids = inputs.pop("input_ids") + elif self.cfg.rollout.rollout_backend == "vllm": + inputs.pop("input_ids") + ids = self._processor( + text=[rendered], images=None, padding=True, return_tensors="pt" + )["input_ids"] + else: + raise ValueError( + f"Unsupported rollout backend {self.cfg.rollout.rollout_backend}" + ) + if isinstance(ids, torch.Tensor): + if ids.dim() == 2 and ids.size(0) == 1: + ids = ids.squeeze(0) + ids = ids.to(dtype=torch.long) + else: + ids = torch.tensor(ids, dtype=torch.long) + + multi_modal_inputs = {} + for k, v in inputs.items(): + multi_modal_inputs[k] = v + return ids, int(ids.numel()), rendered, multi_modal_inputs + else: + # fallback: tokenizer only + ids_list = self.tokenizer.encode(prompt_text) + ids = torch.as_tensor(ids_list, dtype=torch.long) + return ids, int(ids.numel()), prompt_text, {} + + def postprocess_dataset_item( + self, item: DatasetItem, raw: Dict[str, 
Any] + ) -> DatasetItem: + return item + + def _expand_data_paths(self, inputs: List[str]) -> List[str]: + exts = {".jsonl", ".json", ".parquet"} + files: List[str] = [] + for p in inputs: + if os.path.isdir(p): + for root, _, fnames in os.walk(p): + for fn in fnames: + ext = os.path.splitext(fn)[1].lower() + if ext in exts: + files.append(os.path.join(root, fn)) + else: + files.append(p) + files = sorted(set(files)) + return files + + def _eager_load_all(self) -> None: + merged: List[Dict[str, Any]] = [] + for path in self.data_paths: + fmt = os.path.splitext(path)[1].lower() + if fmt == ".jsonl": + with open(path, "r", encoding="utf-8") as f: + merged.extend(json.loads(l) for l in f) + elif fmt == ".json": + with open(path, "r", encoding="utf-8") as f: + content = json.load(f) + if isinstance(content, list): + merged.extend(content) + else: + merged.append(content) + elif fmt == ".parquet": + try: + merged.extend(pd.read_parquet(path).to_dict(orient="records")) + except Exception as e: + raise RuntimeError(f"Failed to load parquet eagerly: {path}: {e}") + else: + logging.warning(f"Unsupported format {fmt} for path {path}, skipping.") + self._records = merged + # Build indices for consistency + self._indices = [("", "eager", i) for i in range(len(self._records))] + + def _build_lazy_indices(self) -> None: + self._indices.clear() + for path in self.data_paths: + fmt = os.path.splitext(path)[1].lower() + if fmt == ".jsonl": + # index by byte offsets for each line + offsets: List[int] = [] + with open(path, "rb") as fb: + pos = 0 + for line in fb: + offsets.append(pos) + pos += len(line) + self._indices.extend((path, "jsonl", off) for off in offsets) + elif fmt == ".json": + try: + with open(path, "r", encoding="utf-8") as f: + content = json.load(f) + if not isinstance(content, list): + content = [content] + # store the content to avoid re-reading + # keep perfile cache + self._json_cache = getattr(self, "_json_cache", {}) + self._json_cache[path] = content + self._indices.extend((path, "json", i) for i in range(len(content))) + except Exception as e: + raise RuntimeError(f"Failed to index json lazily: {path}: {e}") + elif fmt == ".parquet": + try: + import pyarrow.parquet as pq # type: ignore + + pf = pq.ParquetFile(path) + num_rows = pf.metadata.num_rows + # file handle cache + self._parquet_cache = getattr(self, "_parquet_cache", {}) + self._parquet_cache[path] = pf + self._indices.extend((path, "parquet", i) for i in range(num_rows)) + except Exception: + df = pd.read_parquet(path) + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + self._parquet_df_cache[path] = df + self._indices.extend( + (path, "parquet_pd", i) for i in range(len(df)) + ) + else: + logging.warning(f"Unsupported format {fmt} for path {path}, skipping.") + + def _load_single_lazy(self, path: str, fmt: str, key: Any) -> Dict[str, Any]: + if fmt == "eager": + return self._records[int(key)] + if fmt == "jsonl": + with open(path, "rb") as fb: + fb.seek(int(key)) + line = fb.readline() + return json.loads(line.decode("utf-8").strip()) + if fmt == "json": + return self._json_cache[path][int(key)] # type: ignore[attr-defined] + if fmt == "parquet": + # Try to use pyarrow lazily; rebuild cache if missing + self._parquet_cache = getattr(self, "_parquet_cache", {}) + pf = self._parquet_cache.get(path) + if pf is None: + try: + import pyarrow.parquet as pq # type: ignore + + pf = pq.ParquetFile(path) + self._parquet_cache[path] = pf + except Exception: + # Fall back to pandas-based cache + self._parquet_df_cache 
= getattr(self, "_parquet_df_cache", {}) + df = self._parquet_df_cache.get(path) + if df is None: + df = pd.read_parquet(path) + self._parquet_df_cache[path] = df + return df.iloc[int(key)].to_dict() + # Walk the row groups to find the one containing this absolute row index + row = int(key) + for rg_idx in range(pf.metadata.num_row_groups): + rg_rows = pf.metadata.row_group(rg_idx).num_rows + if row < rg_rows: + df = pf.read_row_group(rg_idx).to_pandas() + return df.iloc[row].to_dict() + row -= rg_rows + # Fall back to a full read if the row-group walk did not find the row + df_all = pf.read().to_pandas() + return df_all.iloc[int(key)].to_dict() + if fmt == "parquet_pd": + self._parquet_df_cache = getattr(self, "_parquet_df_cache", {}) + df = self._parquet_df_cache.get(path) + if df is None: + df = pd.read_parquet(path) + self._parquet_df_cache[path] = df + return df.iloc[int(key)].to_dict() + raise RuntimeError(f"Unknown lazy fmt {fmt}") + + def _process_raw_record(self, raw: Dict[str, Any], idx: int) -> DatasetItem: + images = self.get_image_list(raw) + prompt_text = self.build_prompt_text(raw) + prompt_ids, plen, rendered_text, multi_modal_inputs = self.encode_prompt( + prompt_text, images + ) + + if plen > self.max_prompt_length: + prompt_ids = prompt_ids[: self.max_prompt_length] + plen = self.max_prompt_length + prompt_ids = batch_pad_to_fixed_len( + [prompt_ids], self.max_prompt_length, self.eos_id, left_pad=True + )[0] + + answer_val = raw.get(self.answer_key, None) if self.answer_key else None + solution_val = raw.get(self.solution_key, None) if self.solution_key else None + item = DatasetItem( + prompt=prompt_ids, + length=plen, + answer=str(answer_val) if answer_val is not None else None, + idx=idx, + image_data=images, + prompt_text=rendered_text or prompt_text, + solution=solution_val, + meta=None, + multi_modal_inputs=multi_modal_inputs, + ) + return self.postprocess_dataset_item(item, raw) + + +class VLMDatasetRegistry: + registry: Dict[str, Callable[..., VLMBaseDataset]] = {} + + @classmethod + def register( + cls, name: str + ) -> Callable[[Callable[..., VLMBaseDataset]], Callable[..., VLMBaseDataset]]: + def decorator(klass: Callable[..., VLMBaseDataset]): + cls.registry[name] = klass + return klass + + return decorator + + @classmethod + def create( + cls, + dataset_name: Optional[str], + *, + data_paths: Union[List[str], str], + config: DictConfig, + tokenizer: AutoTokenizer, + ) -> VLMBaseDataset: + assert dataset_name is not None, "config.data.dataset_name must be set for VLM datasets" + key = dataset_name.lower() + dataset_class = cls.registry.get(key) + assert dataset_class is not None, ( + f"Unknown VLM dataset {dataset_name}; registered: {list(cls.registry.keys())}" + ) + return dataset_class(data_paths=data_paths, config=config, tokenizer=tokenizer) + + +@VLMDatasetRegistry.register("robo2vlm") +class Robo2VLMDataset(VLMBaseDataset): + def __init__( + self, + data_paths: Union[List[str], str], + config: DictConfig, + tokenizer: AutoTokenizer, + ) -> None: + super().__init__(data_paths, config, tokenizer) + self.system_prompt = ( + "You are a helpful robotic vision assistant specialized in " + "answering questions about robotic manipulation tasks. " + "Use <think> </think> tags to show your reasoning process, " + "then provide your final answer in <answer> </answer> tags."
+ ) + + def get_image_list(self, dataitem: Dict[str, Any]) -> List[Union[bytes, str, None]]: + images: List[Any] = [] + if "images" in dataitem: + v = dataitem.get("images") + if isinstance(v, list): + images = list(v) + elif v is not None: + images = [v] + else: + images = [None] + elif "image" in dataitem: + v = dataitem.get("image") + if v is not None: + images = [v] + else: + images = [None] + else: + return super().get_image_list(dataitem) + + normed: List[Union[bytes, str, None]] = [] + for v in images: + if v is None: + continue + if isinstance(v, Image.Image): + normed.append(v) + elif isinstance(v, dict) and "bytes" in v: + normed.append(v["bytes"]) # raw bytes + else: + normed.append(v) # path/uri/string + if not normed: + normed = [None] + return normed + + def build_prompt_text(self, data_item: Dict[str, Any]) -> str: + # Use 'question' and 'choices' if present; else fallback to base using configured prompt/choice keys + question = data_item.get("question", None) + choices = data_item.get("choices", None) + if question is None: + return super().build_prompt_text(data_item) + # normalize choices + if isinstance(choices, str): + try: + import ast + + choices = ast.literal_eval(choices) + except Exception: + choices = [choices] + if not isinstance(choices, list): + choices = [choices] if choices is not None else [] + + text = f"{question}\n" + if choices: + text += "Choices:\n" + for i, c in enumerate(choices): + text += f"{chr(65 + i)}. {c}\n" + return text + + def postprocess_dataset_item( + self, item: DatasetItem, raw: Dict[str, Any] + ) -> DatasetItem: + answer_dict = { + "choices": raw.get("choices", None), + "correct_answer": raw.get("correct_answer", None), + } + item.answer = answer_dict + + return item diff --git a/rlinf/data/io_struct.py b/rlinf/data/io_struct.py index 1c455fa70..2fa0fea8a 100644 --- a/rlinf/data/io_struct.py +++ b/rlinf/data/io_struct.py @@ -13,19 +13,21 @@ # limitations under the License. from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import torch from omegaconf import DictConfig -from vllm.outputs import CompletionOutput -from vllm.outputs import RequestOutput as VllmRequestOutput -from rlinf.data.datasets import batch_pad_to_fixed_len +if TYPE_CHECKING: + from vllm.outputs import CompletionOutput + from vllm.outputs import RequestOutput as VllmRequestOutput + +from rlinf.data.datasets.utils import batch_pad_to_fixed_len from rlinf.utils.data_iter_utils import ( get_iterator_k_split, split_list, ) - +import torch_npu def get_batch_size( batch: Dict[str, torch.Tensor], batch_tensor_key: str = "input_ids" @@ -47,14 +49,16 @@ class RolloutRequest: Attr input_ids: List of input token IDs for rollout n: Number of completions to generate for each input - idx: List of unique identifiers for the requests, used for tracking - input_lengths: List of lengths of the input sequences, corresponding to input_ids + image_data: list of image data (bytes or URLs) for multimodal inputs answers: Optional list of answers for the requests, if available + multi_modal_inputs: list of multi-modal inputs for the requests """ n: int input_ids: List[List[int]] + image_data: Union[List[List[bytes]], List[List[str]]] answers: List[str] + multi_modal_inputs: List[Dict] def repeat(self) -> "RolloutRequest": """Repeat each input in the RolloutRequest a specified number of times. 
@@ -67,10 +71,15 @@ def repeat(self) -> "RolloutRequest": """ assert self.n > 0, "n must be greater than 0" - input_ids, answers = zip( + input_ids, answers, image_data, multi_modal_inputs = zip( *[ - (input_id, answer) - for input_id, answer in zip(self.input_ids, self.answers) + (input_id, answer, image_data, multi_modal_inputs) + for input_id, answer, image_data, multi_modal_inputs in zip( + self.input_ids, + self.answers, + self.image_data, + self.multi_modal_inputs, + ) for _ in range(self.n) ] ) @@ -78,6 +87,8 @@ def repeat(self) -> "RolloutRequest": n=self.n, input_ids=list(input_ids), answers=list(answers), + image_data=list(image_data), + multi_modal_inputs=list(multi_modal_inputs), ) def split(self, num_splits: int) -> List["RolloutRequest"]: @@ -96,15 +107,27 @@ def split(self, num_splits: int) -> List["RolloutRequest"]: input_ids_split_list = split_list(self.input_ids, num_splits) answers_split_list = split_list(self.answers, num_splits) + image_data_split_list = split_list(self.image_data, num_splits) + multi_modal_inputs_split_list = split_list(self.multi_modal_inputs, num_splits) splitted_requests = [] - for input_ids_batch, answers_batch in zip( - input_ids_split_list, answers_split_list + for ( + input_ids_batch, + answers_batch, + image_data_batch, + multi_modal_inputs_batch, + ) in zip( + input_ids_split_list, + answers_split_list, + image_data_split_list, + multi_modal_inputs_split_list, ): request = RolloutRequest( n=self.n, input_ids=input_ids_batch, answers=answers_batch, + image_data=image_data_batch, + multi_modal_inputs=multi_modal_inputs_batch, ) splitted_requests.append(request) @@ -113,14 +136,24 @@ def split(self, num_splits: int) -> List["RolloutRequest"]: def repeat_and_split( self, rollout_batch_size: Optional[int] = None ) -> List["RolloutRequest"]: - input_ids, answers = zip( + input_ids, answers, image_data, multi_modal_inputs = zip( *[ - (input_id, answer) - for input_id, answer in zip(self.input_ids, self.answers) + (input_id, answer, image_data, multi_modal_inputs) + for input_id, answer, image_data, multi_modal_inputs in zip( + self.input_ids, + self.answers, + self.image_data, + self.multi_modal_inputs, + ) for _ in range(self.n) ] ) - input_ids, answers = (list(input_ids), list(answers)) + input_ids, answers, image_data, multi_modal_inputs = ( + list(input_ids), + list(answers), + list(image_data), + list(multi_modal_inputs), + ) # Split input ids based on rollout_batch_size_per_gpu if rollout_batch_size is None: @@ -134,14 +167,26 @@ def repeat_and_split( splitted_requests = [] input_ids_split_list = split_list(input_ids, num_batches) answers_split_list = split_list(answers, num_batches) - - for input_ids_batch, answers_batch in zip( - input_ids_split_list, answers_split_list + image_data_split_list = split_list(image_data, num_batches) + multi_modal_inputs_split_list = split_list(multi_modal_inputs, num_batches) + + for ( + input_ids_batch, + answers_batch, + image_data_batch, + multi_modal_inputs_batch, + ) in zip( + input_ids_split_list, + answers_split_list, + image_data_split_list, + multi_modal_inputs_split_list, ): request = RolloutRequest( n=self.n, input_ids=input_ids_batch, answers=answers_batch, + image_data=image_data_batch, + multi_modal_inputs=multi_modal_inputs_batch, ) splitted_requests.append(request) @@ -256,8 +301,9 @@ class RolloutResult: advantages: Optional[List[float] | torch.Tensor] = None prompt_texts: Optional[List[str]] = None response_texts: Optional[List[str]] = None - answers: Optional[List[str]] = None - + answers: 
Optional[List[str | dict]] = None + image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None + multi_modal_inputs: Optional[List[dict]] = None # Inference # Only set when recompute_logprobs is False rollout_logprobs: Optional[List[List[float]]] = None @@ -308,12 +354,13 @@ def _get_attention_masks_and_position_ids( @staticmethod def from_vllm_results( group_size: int, - results: List[VllmRequestOutput], + results: List["VllmRequestOutput"], answers: Optional[List[str]] = None, + multi_modal_inputs: Optional[List[Dict]] = None, return_logprobs: bool = False, ) -> "RolloutResult": def get_logprobs( - response_ids: List[int], output: CompletionOutput + response_ids: List[int], output: "CompletionOutput" ) -> List[float]: logprobs = [] returned_logprobs = output.logprobs @@ -326,6 +373,13 @@ def get_logprobs( num_sequences = len(results) * group_size + if multi_modal_inputs: + mm_inputs = [] + for mm_input in multi_modal_inputs: + mm_inputs.extend([mm_input] * group_size) + else: + mm_inputs = None + prompt_lengths = [] prompt_ids = [] response_lengths = [] @@ -368,6 +422,7 @@ def get_logprobs( response_ids=response_ids, response_lengths=response_lengths, response_texts=response_texts, + multi_modal_inputs=mm_inputs, is_end=is_end, ) if return_logprobs: @@ -380,6 +435,8 @@ def from_sglang_results( group_size: int, input_ids: List[List[int]], answers: Optional[List[List[int]]] = None, + image_data: Optional[Union[List[List[bytes]], List[List[str]]]] = None, + multi_modal_inputs: Optional[List[Dict]] = None, return_logprobs: bool = False, ) -> "RolloutResult": """Create a MathRolloutResult from the given results and input IDs. @@ -406,6 +463,8 @@ def from_sglang_results( response_lengths=[len(res["output_ids"]) for res in results], response_ids=[res["output_ids"] for res in results], answers=answers, + image_data=image_data, + multi_modal_inputs=multi_modal_inputs, is_end=[ res["meta_info"]["finish_reason"]["type"] == "stop" for res in results ], @@ -507,6 +566,161 @@ def merge_list(dst_list: List, src_list: List): return merged_result + @staticmethod + def split_result_list_by_group( + rollout_results: List["RolloutResult"], + ) -> List["RolloutResult"]: + """ + Split RolloutResult objects by group_size. + + If input has only one RolloutResult, split it into multiple RolloutResult objects by group_size. + If input has multiple RolloutResult objects, split each one and merge the results. + + Args: + rollout_results: List of input RolloutResult objects + + Returns: + List of RolloutResult objects grouped by group_size + """ + assert len(rollout_results) > 0, "No rollout results to split." + + all_split_results = [] + + for rollout_result in rollout_results: + split_results = RolloutResult._split_single_result_by_group(rollout_result) + all_split_results.extend(split_results) + + return all_split_results + + @staticmethod + def _split_single_result_by_group( + rollout_result: "RolloutResult", + ) -> List["RolloutResult"]: + """ + Split a single RolloutResult into multiple RolloutResult objects by group_size. 
+ + Args: + rollout_result: The RolloutResult to be split + + Returns: + List of split RolloutResult objects + """ + group_size = rollout_result.group_size + num_sequence = rollout_result.num_sequence + + assert num_sequence % group_size == 0, ( + f"num_sequence ({num_sequence}) must be divisible by group_size ({group_size})" + ) + + num_groups = num_sequence // group_size + split_results = [] + + # Split list fields + prompt_lengths_split = split_list(rollout_result.prompt_lengths, num_groups) + prompt_ids_split = split_list(rollout_result.prompt_ids, num_groups) + response_lengths_split = split_list(rollout_result.response_lengths, num_groups) + response_ids_split = split_list(rollout_result.response_ids, num_groups) + is_end_split = split_list(rollout_result.is_end, num_groups) + + # Handle optional fields + answers_split = None + if rollout_result.answers is not None: + answers_split = split_list(rollout_result.answers, num_groups) + + image_data_split = None + if rollout_result.image_data is not None: + image_data_split = split_list(rollout_result.image_data, num_groups) + + multi_modal_inputs_split = None + if rollout_result.multi_modal_inputs is not None: + multi_modal_inputs_split = split_list( + rollout_result.multi_modal_inputs, num_groups + ) + + prompt_texts_split = None + if rollout_result.prompt_texts is not None: + prompt_texts_split = split_list(rollout_result.prompt_texts, num_groups) + + response_texts_split = None + if rollout_result.response_texts is not None: + response_texts_split = split_list(rollout_result.response_texts, num_groups) + + rollout_logprobs_split = None + if rollout_result.rollout_logprobs is not None: + rollout_logprobs_split = split_list( + rollout_result.rollout_logprobs, num_groups + ) + + # Handle tensor fields + rewards_split = None + if rollout_result.rewards is not None: + if isinstance(rollout_result.rewards, torch.Tensor): + rewards_split = torch.chunk(rollout_result.rewards, num_groups, dim=0) + else: + rewards_split = split_list(rollout_result.rewards, num_groups) + + advantages_split = None + if rollout_result.advantages is not None: + if isinstance(rollout_result.advantages, torch.Tensor): + advantages_split = torch.chunk( + rollout_result.advantages, num_groups, dim=0 + ) + else: + advantages_split = split_list(rollout_result.advantages, num_groups) + + prev_logprobs_split = None + if rollout_result.prev_logprobs is not None: + prev_logprobs_split = torch.chunk( + rollout_result.prev_logprobs, num_groups, dim=0 + ) + + ref_logprobs_split = None + if rollout_result.ref_logprobs is not None: + ref_logprobs_split = torch.chunk( + rollout_result.ref_logprobs, num_groups, dim=0 + ) + + # Create split RolloutResult objects + for i in range(num_groups): + split_result = RolloutResult( + num_sequence=group_size, + group_size=group_size, + prompt_lengths=prompt_lengths_split[i], + prompt_ids=prompt_ids_split[i], + response_lengths=response_lengths_split[i], + response_ids=response_ids_split[i], + is_end=is_end_split[i], + answers=answers_split[i] if answers_split is not None else None, + image_data=image_data_split[i] + if image_data_split is not None + else None, + multi_modal_inputs=multi_modal_inputs_split[i] + if multi_modal_inputs_split is not None + else None, + prompt_texts=prompt_texts_split[i] + if prompt_texts_split is not None + else None, + response_texts=response_texts_split[i] + if response_texts_split is not None + else None, + rollout_logprobs=rollout_logprobs_split[i] + if rollout_logprobs_split is not None + else None, + 
rewards=rewards_split[i] if rewards_split is not None else None, + advantages=advantages_split[i] + if advantages_split is not None + else None, + prev_logprobs=prev_logprobs_split[i] + if prev_logprobs_split is not None + else None, + ref_logprobs=ref_logprobs_split[i] + if ref_logprobs_split is not None + else None, + ) + split_results.append(split_result) + + return split_results + def to_actor_batch( self, data_seq_length: int, @@ -595,17 +809,23 @@ def to_actor_batch( ) # [B, training_seq_length] batch = { - "input_ids": input_ids.cuda(), - "attention_mask": attention_mask.cuda(), - "is_end": is_end.cuda(), - "position_ids": position_ids.cuda(), - "prompt_lengths": prompt_lengths.cuda(), - "response_lengths": response_lengths.cuda(), + "input_ids": input_ids.npu(), + "attention_mask": attention_mask.npu(), + "is_end": is_end.npu(), + "position_ids": position_ids.npu(), + "prompt_lengths": prompt_lengths.npu(), + "response_lengths": response_lengths.npu(), } + if ( + self.multi_modal_inputs is not None + and self.multi_modal_inputs[0] is not None + ): + batch["multi_modal_inputs"] = self.multi_modal_inputs + if self.advantages is not None: if isinstance(self.advantages, torch.Tensor): - batch["advantages"] = self.advantages.cuda() + batch["advantages"] = self.advantages.npu() else: response_attention_mask = attention_mask[ :, -max_response_len: @@ -613,17 +833,17 @@ def to_actor_batch( advantages = torch.tensor(self.advantages, dtype=torch.float32).reshape( -1, 1 ) # [B, 1] - advantages = response_attention_mask.float().cuda() * advantages.cuda() - batch["advantages"] = advantages.cuda() + advantages = response_attention_mask.float().npu() * advantages.npu() + batch["advantages"] = advantages.npu() if self.prev_logprobs is not None: - batch["prev_logprobs"] = self.prev_logprobs.cuda() + batch["prev_logprobs"] = self.prev_logprobs.npu() if self.ref_logprobs is not None: - batch["ref_logprobs"] = self.ref_logprobs.cuda() + batch["ref_logprobs"] = self.ref_logprobs.npu() if self.rewards is not None: - batch["rewards"] = self.rewards.cuda() + batch["rewards"] = self.rewards.npu() if self.rollout_logprobs is not None: logprobs = batch_pad_to_fixed_len( @@ -634,7 +854,7 @@ def to_actor_batch( max_batch_len=max_response_len, pad_token=pad_token, ) - batch["prev_logprobs"] = logprobs.cuda() + batch["prev_logprobs"] = logprobs.npu() return batch @@ -648,14 +868,16 @@ def merge_batches( return merged_batch if len(batches) == 1: return batches[0] + for key in batches[0].keys(): - assert torch.is_tensor(batches[0][key]), ( - f"Expected tensor for key {key} in batches, got {type(batches[0][key])}" - ) - assert torch.is_tensor(batches[0][key]), ( - f"Expected tensor for key {key} in batches, got {type(batches[0][key])}" - ) - merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0) + if torch.is_tensor(batches[0][key]): + merged_batch[key] = torch.cat([batch[key] for batch in batches], dim=0) + elif isinstance(batches[0][key], list): + merged_batch[key] = [] + for batch in batches: + merged_batch[key].extend(batch[key]) + else: + raise ValueError(f"Unsupported batch key type: {type(batches[0][key])}") return merged_batch diff --git a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py index 8e3b9f16c..690b3d4f5 100644 --- a/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +++ b/rlinf/hybrid_engines/fsdp/fsdp_model_manager.py @@ -17,11 +17,17 @@ import torch import torch.optim as optim from omegaconf import DictConfig +from 
torch.distributed.fsdp import (
+    BackwardPrefetch,
+    MixedPrecision,
+    ShardingStrategy,
+    StateDictType,
+)
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq
 from rlinf.config import torch_dtype_from_precision
+from rlinf.data.tokenizers import hf_tokenizer
 from rlinf.hybrid_engines.fsdp.utils import (
     get_fsdp_wrap_policy,
     init_fn,
@@ -29,6 +35,7 @@ from rlinf.utils.logging import get_logger
 from rlinf.utils.utils import clear_memory
 
+import torch_npu
 
 class FSDPModelManager:
     """
@@ -40,40 +47,55 @@ def __init__(self, cfg: DictConfig):
         self.logger = get_logger()
         self.torch_dtype = torch_dtype_from_precision(self._cfg.model.precision)
-        assert (
-            self.torch_dtype == torch.float16 or self.torch_dtype == torch.bfloat16
-        ), (
-            f"Precision {self._cfg.model.precision} is not supported, only support bf16 and fp16."
-        )
+        self.tokenizer = hf_tokenizer(cfg.tokenizer.tokenizer_model)
 
     def model_provider_func(self) -> torch.nn.Module:
-        if self._cfg.model.get("gptq_model", False):
+        cfg = self._cfg
+        use_gptq = cfg.model.get("gptq_model", False)
+        load_in_8bit = cfg.model.get("load_in_8bit", False)
+
+        use_triton = cfg.get("use_triton", True)
+
+        assert torch.npu.is_available(), "NPU is not available."
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        device = torch.npu.device(f"npu:{local_rank}")
+
+        model_config = AutoConfig.from_pretrained(
+            cfg.model.model_path,
+            trust_remote_code=True,
+            attn_implementation="flash_attention_2",
+        )
+
+        if use_gptq:
             from auto_gptq import AutoGPTQForCausalLM
 
             model_wrapper = AutoGPTQForCausalLM.from_quantized(
-                self._cfg.model.model_path, device="cuda:0", use_triton=True
+                cfg.model.model_path,
+                device=device,
+                use_triton=use_triton,
             )
             model = model_wrapper.model
-        elif self._cfg.model.get("load_in_8bit", False):
+        elif load_in_8bit:
             model = AutoModelForCausalLM.from_pretrained(
-                self._cfg.model.model_path,
-                device_map=self._cfg.model.get("device_map", "auto"),
+                cfg.model.model_path,
+                config=model_config,
                 load_in_8bit=True,
             )
         else:
-            # default load in float16
-            model = AutoModelForCausalLM.from_pretrained(
-                self._cfg.model.model_path,
+            if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():
+                auto_model_class = AutoModelForVision2Seq
+            else:
+                auto_model_class = AutoModelForCausalLM
+
+            model = auto_model_class.from_pretrained(
+                cfg.model.model_path,
                 torch_dtype=self.torch_dtype,
-                device_map=self._cfg.model.get("device_map", "auto"),
+                config=model_config,
                 trust_remote_code=True,
-                use_safetensors=self._cfg.model.get("use_safetensors", False),
             )
 
-        if torch.cuda.is_available():
-            model = model.cuda()
-        if self.torch_dtype == torch.float16:
-            model = model.half()
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
 
         return model
 
     def setup_model_and_optimizer(self):
@@ -108,12 +130,19 @@
         self.model = FSDP(
             module,
             param_init_fn=init_fn,
-            use_orig_params=False,
             auto_wrap_policy=auto_wrap_policy,
             device_id=int(os.environ["LOCAL_RANK"]),
             sharding_strategy=sharding_strategy,  # zero3
             mixed_precision=mixed_precision,
             sync_module_states=True,
+            forward_prefetch=self._cfg.fsdp.forward_prefetch,
+            backward_prefetch=(
+                BackwardPrefetch.BACKWARD_PRE
+                if self._cfg.fsdp.backward_prefetch
+                else None
+            ),
+            limit_all_gathers=self._cfg.fsdp.limit_all_gathers,
+
use_orig_params=self._cfg.fsdp.use_orig_params, ) # NOTE: Currently we assume that only the value head contains "value_head" in its name. @@ -130,7 +159,7 @@ def setup_model_and_optimizer(self): }, ] - if self._cfg.model.vh_mode in ["a", "a0", "a6"]: + if self._cfg.model.get("vh_mode", None) in ["a", "a0", "a6"]: param_groups.append( { "params": [ diff --git a/rlinf/hybrid_engines/fsdp/utils.py b/rlinf/hybrid_engines/fsdp/utils.py index 2334b1fa7..0a9f0054d 100644 --- a/rlinf/hybrid_engines/fsdp/utils.py +++ b/rlinf/hybrid_engines/fsdp/utils.py @@ -58,7 +58,7 @@ def cpu_init_weights(): return init_context -def get_fsdp_wrap_policy(module, config=None, is_lora=False): +def get_fsdp_wrap_policy(module, config=None, is_lora=False, is_vla_model=False): """ FSDP wrap policy that handles both standard transformer models and VLA models. @@ -76,11 +76,8 @@ def get_fsdp_wrap_policy(module, config=None, is_lora=False): if config.get("disable", False): return None - # Check if this is a VLA model by looking for language_model attribute - is_vla_model = hasattr(module, "language_model") - # Get transformer layer classes to wrap - if is_vla_model: + if hasattr(module, "language_model"): # For VLA models, get transformer classes from language_model submodule default_transformer_cls_names_to_wrap = getattr( module.language_model, "_no_split_modules", None diff --git a/rlinf/hybrid_engines/megatron/megatron_model_manager.py b/rlinf/hybrid_engines/megatron/megatron_model_manager.py index 57f34dbf5..6fe4fbfd6 100644 --- a/rlinf/hybrid_engines/megatron/megatron_model_manager.py +++ b/rlinf/hybrid_engines/megatron/megatron_model_manager.py @@ -184,6 +184,7 @@ def model_provider_func(self, pre_process, post_process): return model def optimizer_step(self, increment): + clear_memory() success, grad_norm, num_zeros_in_grad = self.optimizer.step() self.lr_scheduler.step(increment=increment) diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py index 509c56bab..87f736798 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_4/sgl_scheduler.py @@ -108,10 +108,13 @@ def __init__( placement )[(self.get_parent_rank(), self._rank)] + use_presharded_weights = ( + False if self.cfg.actor.training_backend == "fsdp" else True + ) # it's important to use load_weight to load resharded weight from megatron for _, module in self.tp_worker.worker.model_runner.model.named_modules(): if hasattr(module, "use_presharded_weights"): - module.use_presharded_weights = True + module.use_presharded_weights = use_presharded_weights self._logger.info( f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py index ef503527b..9a69b8548 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_6/sgl_scheduler.py @@ -110,10 +110,13 @@ def __init__( self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map( placement )[(self.get_parent_rank(), self._rank)] + use_presharded_weights = ( + False if self.cfg.actor.training_backend == "fsdp" else True + ) # it's important to use load_weight to load resharded weight from megatron for _, module in self.tp_worker.worker.model_runner.model.named_modules(): if hasattr(module, 
"use_presharded_weights"): - module.use_presharded_weights = True + module.use_presharded_weights = use_presharded_weights self._logger.info( f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" diff --git a/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py index a7057a161..1f9beb409 100644 --- a/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py +++ b/rlinf/hybrid_engines/sglang/sglang_0_4_9/sgl_scheduler.py @@ -112,10 +112,13 @@ def __init__( self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map( placement )[(self.get_parent_rank(), self._rank)] + use_presharded_weights = ( + False if self.cfg.actor.training_backend == "fsdp" else True + ) # it's important to use load_weight to load resharded weight from megatron for _, module in self.tp_worker.worker.model_runner.model.named_modules(): if hasattr(module, "use_presharded_weights"): - module.use_presharded_weights = True + module.use_presharded_weights = use_presharded_weights self._logger.info( f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py new file mode 100644 index 000000000..5b365ea1e --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py new file mode 100644 index 000000000..960d40eb0 --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/io_struct.py @@ -0,0 +1,59 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class TaskMethodInput: + method_name: str + args: List[Any] = field(default_factory=list) + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class TaskMethodOutput: + method_name: str + result: Optional[Any] = None + + +@dataclass +class OffloadReqInput: + pass + + +@dataclass +class OffloadReqOutput: + pass + + +@dataclass +class SyncWeightInput: + pass + + +@dataclass +class SyncWeightOutput: + pass + + +@dataclass +class SyncHFWeightInput: + pass + + +@dataclass +class SyncHFWeightOutput: + pass diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py new file mode 100644 index 000000000..e8e05c88b --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_engine.py @@ -0,0 +1,363 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import atexit +import logging +import multiprocessing as mp +import os +import random +import signal +import threading +import time +from typing import Dict, Optional, Tuple + +import uvloop +import zmq +from omegaconf import DictConfig +from sglang.srt.entrypoints.engine import Engine as _Engine +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) +from sglang.srt.managers.detokenizer_manager import run_detokenizer_process +from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter + +# from sglang.srt.managers.scheduler import run_scheduler_process +# from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.managers.template_manager import TemplateManager +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + assert_pkg_version, + configure_logger, + get_bool_env_var, + get_zmq_socket, + is_cuda, + kill_process_tree, + launch_dummy_health_check_server, + prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, + set_ulimit, +) + +from rlinf.scheduler import WorkerAddress +from rlinf.utils.placement import ComponentPlacement + +from .io_struct import OffloadReqInput, SyncHFWeightInput, SyncWeightInput +from .sgl_scheduler import run_scheduler_process +from .tokenizer_manager import TokenizerManager + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +_is_cuda = is_cuda() + + +class Engine(_Engine): + def __init__( + self, + parent_address: WorkerAddress, + placement: ComponentPlacement, + config: DictConfig, + dp_rank: int, + **kwargs, + ): + """ + The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`. + Please refer to `ServerArgs` for the documentation. 
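+        The additional RLinf-specific parameters (parent_address, placement, config,
+        dp_rank) are not ServerArgs fields; they are forwarded to the scheduler
+        subprocesses so that each rank can register itself as an RLinf worker.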
+ """ + if "server_args" in kwargs: + # Directly load server_args + server_args = kwargs["server_args"] + else: + # Construct server_args from kwargs + if "log_level" not in kwargs: + # Do not print logs by default + kwargs["log_level"] = "error" + server_args = ServerArgs(**kwargs) + + # Shutdown the subprocesses automatically when the program exits + atexit.register(self.shutdown) + + # Allocate ports for inter-process communications + self.port_args = PortArgs.init_new(server_args) + + # Launch subprocesses + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses( + parent_address=parent_address, + placement=placement, + config=config, + dp_rank=dp_rank, + server_args=server_args, + port_args=self.port_args, + ) + + self.server_args = server_args + self.tokenizer_manager = tokenizer_manager + self.scheduler_info = scheduler_info + self.template_manager = template_manager + + context = zmq.Context(2) + self.send_to_rpc = get_zmq_socket( + context, zmq.DEALER, self.port_args.rpc_ipc_name, True + ) + + def offload_model_weights(self): + """Offload model weights to meta.""" + obj = OffloadReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.offload_model_weights(obj, None) + ) + + def sync_hf_weight(self): + obj = SyncHFWeightInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.tokenizer_manager.sync_hf_weight(obj)) + + def sync_weight(self): + obj = SyncWeightInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.tokenizer_manager.sync_weight(obj)) + + +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem)) + if not server_args.enable_symm_mem: + os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + os.environ["CUDA_MODULE_LOADING"] = "AUTO" + # flashinfer uses this environment variable for various kernels from MoE to quant kernels + os.environ["TRTLLM_ENABLE_PDL"] = "1" + + # Can also be passed as argument + os.environ["SGLANG_RUN_ID"] = ( + f"sglang-run-{time.time()}-{random.randint(0, 100000000)}" + ) + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer_python", + "0.3.0", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"): + assert_pkg_version( + "sgl-kernel", + "0.3.8", + "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", + ) + + if True: # Keep this check for internal code compatibility + # Register the signal handler. + # The child processes will send SIGQUIT to this process when any error happens + # This process then clean up the whole process tree + # Note: This sigquit handler is used in the launch phase, and may be replaced by + # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched. + def launch_phase_sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child process. It usually means the child failed." 
+ ) + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses( + parent_address: WorkerAddress, + placement: ComponentPlacement, + config: DictConfig, + dp_rank: int, + server_args: ServerArgs, + port_args: Optional[PortArgs] = None, +) -> Tuple[TokenizerManager, TemplateManager, Dict]: + """ + Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + """ + + assert server_args.pp_size == 1, ( + "RLinf currently only supports and validates pp_size=1." + ) + + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports for inter-process communications + if port_args is None: + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # If using model from www.modelscope.cn, first download the model. + server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + scheduler_pipe_readers = [] + + nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) + tp_size_per_node = server_args.tp_size // nnodes_per_tp_group + tp_rank_range = range( + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), + ) + + pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1) + pp_rank_range = range( + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group), + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1), + ) + + for pp_rank in pp_rank_range: + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = ( + server_args.base_gpu_id + + ((pp_rank % pp_size_per_node) * tp_size_per_node) + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step + ) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + proc = mp.Process( + target=run_scheduler_process, + args=( + parent_address, + placement, + config, + server_args.tp_size * server_args.pp_size, + tp_rank + pp_rank * server_args.pp_size, + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + None, + writer, + None, + ), + ) + + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. 
+ return None, None, None + + launch_dummy_health_check_server( + server_args.host, server_args.port, server_args.enable_metrics + ) + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return None, None, None + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + if server_args.tokenizer_worker_num > 1: + # Launch multi-tokenizer router + tokenizer_manager = MultiTokenizerRouter(server_args, port_args) + + # Initialize templates + template_manager = None + else: + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + + # Initialize templates + template_manager = TemplateManager() + template_manager.initialize_templates( + tokenizer_manager=tokenizer_manager, + model_path=server_args.model_path, + chat_template=server_args.chat_template, + completion_template=server_args.completion_template, + ) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, template_manager, scheduler_info diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py new file mode 100644 index 000000000..f6aac9d27 --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/sgl_scheduler.py @@ -0,0 +1,476 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import faulthandler
+import logging
+import os
+import signal
+from typing import Optional
+
+import psutil
+import setproctitle
+import torch
+import torch_npu
+from omegaconf import DictConfig
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+)
+from sglang.srt.managers.io_struct import (
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
+)
+from sglang.srt.managers.scheduler import Scheduler as _Scheduler
+from sglang.srt.managers.scheduler import logger
+from sglang.srt.managers.utils import DPBalanceMeta
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    broadcast_pyobj,
+    configure_logger,
+    get_bool_env_var,
+    kill_itself_when_parent_died,
+    set_gpu_proc_affinity,
+    suppress_other_loggers,
+)
+from sglang.utils import get_exception_traceback
+
+from rlinf.scheduler import Worker, WorkerAddress
+from rlinf.utils.placement import ModelParallelComponentPlacement, PlacementMode
+from rlinf.workers.rollout.utils import (
+    RankMapper,
+    get_module_from_name,
+    rebind_param_attr,
+    swap_tensor_pointer,
+)
+
+from .io_struct import (
+    OffloadReqInput,
+    OffloadReqOutput,
+    SyncHFWeightInput,
+    SyncHFWeightOutput,
+    SyncWeightInput,
+    SyncWeightOutput,
+    TaskMethodInput,
+    TaskMethodOutput,
+)
+
+logger.setLevel(logging.WARNING)
+
+
+def safe_load_weights(model, weights: list):
+    """
+    Safely load weights while accepting both parameter naming conventions:
+    - 'visual.xxx' (Hugging Face Qwen-VL format)
+    - 'model.visual.xxx' (used by some SGLang or custom checkpoints)
+
+    Parameters:
+        model: PyTorch model instance (must provide named_parameters())
+        weights: List of (name, torch.Tensor)
+    """
+    params_dict = dict(model.named_parameters())
+
+    # Build a mapping from normalized key -> actual parameter name.
+    # For example, 'visual.patch_embed.proj.weight' may appear in params_dict
+    # as either 'visual...' or 'model.visual...'.
+    normalized_to_actual = {}
+    for param_name in params_dict.keys():
+        if param_name.startswith("model.visual."):
+            # Map to the canonical name without the 'model.' prefix
+            normalized = param_name[len("model."):]  # "visual.xxx"
+            normalized_to_actual[normalized] = param_name
+            normalized_to_actual[param_name] = param_name  # keep the original name too
+        elif param_name.startswith("visual."):
+            normalized = param_name
+            normalized_to_actual[normalized] = param_name
+            normalized_to_actual["model." + normalized] = param_name  # accept 'model.'-prefixed input names
+        else:
+            # Non-visual parameters map to themselves
+            normalized_to_actual[param_name] = param_name
+
+    # Load each weight
+    for name, loaded_weight in weights:
+        if name in normalized_to_actual:
+            actual_name = normalized_to_actual[name]
+            param = params_dict[actual_name]
+            assert param.shape == loaded_weight.shape, (
+                f"Shape mismatch for {name}: expected {param.shape}, got {loaded_weight.shape}"
+            )
+            param.copy_(loaded_weight)
+        else:
+            # Skip weights that do not exist in the model (e.g. optimizer state or non-model parameters)
+            print(f"[Warning] Skipping weight not in model: {name}")
+            continue
+
+
+class Scheduler(_Scheduler, Worker):
+    """
+    Overridden class of SGLang's TP worker class _Scheduler.
+    A Scheduler is a Task that manages the TP worker, and performs necessary weight synchronization with actor and weight offloading.
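+    In collocated placements, weights arrive from the actor group as rebuild handles
+    and are re-materialized on the local device; in disaggregated placements, tensors
+    are received and loaded directly.
+    """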
+ """ + + def __init__( + self, + parent_address: WorkerAddress, + placement: ModelParallelComponentPlacement, + config: DictConfig, + world_size: int, + rank: int, + server_args: ServerArgs, + port_args: PortArgs, + gpu_id: int, + tp_rank: int, + moe_ep_rank: int, + pp_rank: int, + dp_rank: Optional[int], + dp_balance_meta: Optional[DPBalanceMeta] = None, + ): + Worker.__init__( + self, parent_address=parent_address, world_size=world_size, rank=rank + ) + + # since 0.4.6.post2, pp_rank is added into Scheduler init's parameters + # but we don't use it in our implementation, so we set it to 0 + _Scheduler.__init__( + self, + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + dp_rank, + dp_balance_meta, + ) + # `TpModelWorkerClient` is used when ServerArgs.enable_overlap=True, and it has 'worker' attribute. + # But in early SGLang version, `TpModelWorker` doesn't have 'worker' attribute. + if not hasattr(self.tp_worker, "worker"): + self.tp_worker.worker = self.tp_worker + + self._request_dispatcher._mapping.extend( + [ + (TaskMethodInput, self.run_task_method), + (OffloadReqInput, self.offload_model_weights), + (SyncWeightInput, self.sync_weight), + (SyncHFWeightInput, self.sync_hf_weight), + ] + ) + self.cfg = config + self.binded_attr = {} + + self._actor_group_name = self.cfg.actor.group_name + self.placement_mode = placement.placement_mode + self.actor_weight_rank = RankMapper.get_rollout_rank_to_actor_rank_map( + placement + )[(self.get_parent_rank(), self._rank)] + # it's important to use load_weight to load resharded weight from megatron + for _, module in self.tp_worker.worker.model_runner.model.named_modules(): + if hasattr(module, "use_presharded_weights"): + module.use_presharded_weights = False + + self._logger.info( + f"Running Scheduler dp rank {self.get_parent_rank()}, tp rank {self.tp_rank}, corresponding actor weight rank = {self.actor_weight_rank}" + ) + + def sync_in_tp(self, fn: str = ""): + broadcast_pyobj( + [], self.tp_rank, self.tp_worker.worker.model_runner.tp_group.cpu_group + ) + # logger.info(f"{fn}: Sync in tp success!") + + def cuda_info(self, text: str = ""): + free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info() + free_gpu_memory /= 2**30 + total_gpu_memory /= 2**30 + + memory_allocated = torch.npu.memory_allocated() / 2**30 + memory_reserved = torch.npu.memory_reserved() / 2**30 + + self._logger.info( + f"[dp {self.get_parent_rank()}-tp {self.tp_rank}] {text} " + f"{memory_allocated=:.2f} GiB, {memory_reserved=:.2f} GiB, " + f"{free_gpu_memory=:.2f} GiB, {total_gpu_memory=:.2f} GiB" + ) + + def offload_model_weights(self, recv_req: OffloadReqInput): + use_cudagraph = not self.cfg.rollout.enforce_eager + colocate = self.placement_mode == PlacementMode.COLLOCATED + if not colocate: + assert use_cudagraph, "If not colocate, use_cudagraph must be True now." 
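+        # With CUDA graphs enabled, or in disaggregated mode, rely on SGLang's
+        # built-in release_memory_occupation instead of manually offloading to meta.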
+
+        if use_cudagraph or not colocate:
+            self.release_memory_occupation(ReleaseMemoryOccupationReqInput())
+            # self.cuda_info("After offload Model weights and kv cache")
+            return OffloadReqOutput()
+
+        # manually offload
+        self.named_buffers = {
+            n: buf.clone()
+            for n, buf in self.tp_worker.worker.model_runner.model.named_buffers()
+        }
+
+        self.binded_attr = {
+            name: param.__dict__
+            for name, param in self.tp_worker.worker.model_runner.model.named_parameters()
+        }
+
+        # offload parameters
+        self.tp_worker.worker.model_runner.model.to("meta")
+
+        # offload kv cache
+        self.tp_worker.worker.model_runner.token_to_kv_pool._clear_buffers()
+
+        self.flush_cache()
+        self.sync_in_tp("offload_model_weights")
+        # self.cuda_info("After offload Model weights and kv cache")
+        return OffloadReqOutput()
+
+    def sync_hf_weight(self, recv_req: SyncHFWeightInput):
+        use_cudagraph = not self.cfg.rollout.enforce_eager
+        colocate = self.placement_mode == PlacementMode.COLLOCATED
+
+        assert use_cudagraph, "use_cudagraph must be True now."
+
+        state_dict = self.recv(
+            src_group_name=self._actor_group_name,
+            src_rank=self.actor_weight_rank,
+        )
+
+        model = self.tp_worker.worker.model_runner.model
+
+        if colocate:
+            self.resume_memory_occupation(ResumeMemoryOccupationReqInput())
+            import inspect
+
+            for name, handle in state_dict.items():
+                func, args = handle
+                sig = inspect.signature(func)
+                param_names = list(sig.parameters.keys())
+
+                # Convert the positional args of the rebuild handle into kwargs
+                kwargs = {}
+                args = list(args)
+                for i, param_name in enumerate(param_names):
+                    if i < len(args):
+                        kwargs[param_name] = args[i]
+                    else:
+                        break
+
+                # NOTE: the key is to rewrite the device argument to the current device,
+                # in case the two processes see different device orderings
+                # (assumes the rebuild function names it 'map_location' or 'device')
+                if "map_location" in kwargs:
+                    kwargs["map_location"] = f"npu:{torch.npu.current_device()}"
+                elif "device" in kwargs:
+                    kwargs["device"] = torch.npu.current_device()
+
+                new_weight = func(**kwargs)
+                model.load_weights([(name, new_weight)])
+                # Alternatively, safe_load_weights(model, [(name, new_weight)]) tolerates
+                # checkpoints whose visual weights carry a 'model.' prefix.
+                del new_weight
+        else:
+            # disaggregate mode, recv tensor directly
+            for name, tensor in state_dict.items():
+                model.load_weights([(name, tensor)])
+        self.flush_cache()
+        self.sync_in_tp("sync_hf_weight")
+        return SyncHFWeightOutput()
+
+    def sync_weight(self, recv_req: SyncWeightInput):
+        use_cudagraph = not self.cfg.rollout.enforce_eager
+        colocate = self.placement_mode == PlacementMode.COLLOCATED
+        if not colocate:
+            assert use_cudagraph, "If not colocate, use_cudagraph must be True now."
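+
+        # In collocated mode each value in state_dict is a (rebuild_fn, args) handle
+        # that is re-materialized on the local device below; in disaggregated mode the
+        # received values are tensors that can be loaded directly.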
+ + state_dict = self.recv( + src_group_name=self._actor_group_name, + src_rank=self.actor_weight_rank, + ) + model = self.tp_worker.worker.model_runner.model + + if use_cudagraph and colocate: + self.resume_memory_occupation(ResumeMemoryOccupationReqInput()) + + if colocate: + if use_cudagraph: + for name, handle in state_dict.items(): + func, args = handle + list_args = list(args) + # NOTE: the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = torch.npu.current_device() + new_weight = func(*list_args) + + self.tp_worker.worker.model_runner.update_weights_from_tensor( + [(name, new_weight)], load_format="direct" + ) + del new_weight + + else: + named_params = dict(model.named_parameters()) + for name, handle in state_dict.items(): + rebind_param_attr(model, name, self.binded_attr, materialize=False) + func, args = handle + list_args = list(args) + list_args[6] = torch.npu.current_device() + new_weight = func(*list_args) + vllm_weight = named_params[name] + assert vllm_weight.shape == new_weight.shape, ( + f"{name}: {vllm_weight.shape=}, {new_weight.shape=}" + ) + assert vllm_weight.dtype == new_weight.dtype, ( + f"{name}: {vllm_weight.dtype=}, {new_weight.dtype=}" + ) + + swap_tensor_pointer(vllm_weight, new_weight) + del new_weight + + for name, buffer in self.named_buffers.items(): + vllm_buffer = get_module_from_name(model, name) + assert vllm_buffer.shape == buffer.shape + assert vllm_buffer.dtype == buffer.dtype + swap_tensor_pointer(vllm_buffer, buffer) + + self.named_buffers = {} + + self.tp_worker.worker.model_runner.token_to_kv_pool._create_buffers() + else: + # disaggregate mode, recv tensor directly + named_tensors = [(n, p) for n, p in state_dict.items()] + self.tp_worker.worker.model_runner.update_weights_from_tensor( + named_tensors, load_format="direct" + ) + self.sync_in_tp("sync_weight") + + return SyncWeightOutput() + + def run_task_method(self, obj: TaskMethodInput): + """ + Run a CommTask method with the given name and arguments. + NOTE: will call wait() if async_op is True. 
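+        This is how the TokenizerManager drives RLinf Worker communication methods
+        (dispatched via TaskMethodInput) inside the scheduler process.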
+ """ + result = getattr(self, obj.method_name)(*obj.args, **obj.kwargs) + if "async_op" in obj.kwargs and obj.kwargs["async_op"]: + result = result.wait() + return TaskMethodOutput(method_name=obj.method_name, result=result) + + +def run_scheduler_process( + parent_address: WorkerAddress, + placement: ModelParallelComponentPlacement, + config: DictConfig, + world_size: int, + rank: int, + server_args: ServerArgs, + port_args: PortArgs, + gpu_id: int, + tp_rank: int, + moe_ep_rank: int, + pp_rank: int, + dp_rank: Optional[int], + pipe_writer, + balance_meta: Optional[DPBalanceMeta] = None, +): + # Generate the prefix + prefix = "" + if dp_rank is not None: + prefix += f" DP{dp_rank}" + if server_args.tp_size > 1: + prefix += f" TP{tp_rank}" + if server_args.ep_size > 1: + prefix += f" EP{moe_ep_rank}" + if server_args.pp_size > 1: + prefix += f" PP{pp_rank}" + + # Config the process + setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}") + faulthandler.enable() + kill_itself_when_parent_died() + parent_process = psutil.Process().parent() + + # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var + if dp_rank is None and "SGLANG_DP_RANK" in os.environ: + dp_rank = int(os.environ["SGLANG_DP_RANK"]) + + # Configure the logger + configure_logger(server_args, prefix=prefix) + suppress_other_loggers() + + # Set cpu affinity to this gpu process + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + + # Create a scheduler and run the event loop + try: + scheduler = Scheduler( + parent_address, + placement, + config, + world_size, + rank, + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + dp_rank, + dp_balance_meta=balance_meta, + ) + pipe_writer.send( + { + "status": "ready", + "max_total_num_tokens": scheduler.max_total_num_tokens, + "max_req_input_len": scheduler.max_req_input_len, + } + ) + + disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode + if disaggregation_mode == DisaggregationMode.NULL: + if server_args.pp_size > 1: + scheduler.event_loop_pp() + elif scheduler.enable_overlap: + scheduler.event_loop_overlap() + else: + scheduler.event_loop_normal() + elif disaggregation_mode == DisaggregationMode.PREFILL: + if scheduler.enable_overlap: + scheduler.event_loop_overlap_disagg_prefill() + else: + if server_args.pp_size > 1: + scheduler.event_loop_pp_disagg_prefill() + else: + scheduler.event_loop_normal_disagg_prefill() + + elif disaggregation_mode == DisaggregationMode.DECODE: + if scheduler.enable_overlap: + scheduler.event_loop_overlap_disagg_decode() + else: + scheduler.event_loop_normal_disagg_decode() + + except Exception: + traceback = get_exception_traceback() + logger.error(f"Scheduler hit an exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) diff --git a/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py b/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py new file mode 100644 index 000000000..07b77b200 --- /dev/null +++ b/rlinf/hybrid_engines/sglang/sglang_0_5_2/tokenizer_manager.py @@ -0,0 +1,129 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import fastapi +from sglang.srt.managers.io_struct import AbortReq +from sglang.srt.managers.tokenizer_manager import TokenizerManager as _TokenizerManager +#from sglang.srt.managers.tokenizer_manager import _Communicator +from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator +from sglang.srt.server_args import PortArgs, ServerArgs +from .io_struct import ( + OffloadReqInput, + OffloadReqOutput, + SyncHFWeightInput, + SyncHFWeightOutput, + SyncWeightInput, + SyncWeightOutput, + TaskMethodInput, + TaskMethodOutput, +) + + +# Add two methods and their communicators, input/output structs. +class TokenizerManager(_TokenizerManager): + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + super().__init__( + server_args=server_args, + port_args=port_args, + ) + + self.run_task_method_communicator = _Communicator( + self.send_to_scheduler, + fan_out=server_args.dp_size, + ) + self.offload_model_weights_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.sync_weight_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.sync_hf_weight_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + + self._result_dispatcher._mapping.extend( + [ + ( + TaskMethodOutput, + self.run_task_method_communicator.handle_recv, + ), + ( + OffloadReqOutput, + self.offload_model_weights_communicator.handle_recv, + ), + ( + SyncWeightOutput, + self.sync_weight_communicator.handle_recv, + ), + ( + SyncHFWeightOutput, + self.sync_hf_weight_communicator.handle_recv, + ), + ] + ) + + async def run_task_method( + self, + obj: TaskMethodInput = None, + request: Optional[fastapi.Request] = None, + ): + """ + Run a task method with the given name and arguments. 
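+        A bare string is accepted as shorthand and is wrapped into a TaskMethodInput
+        with no arguments; only the first scheduler's result is returned.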
+ """ + self.auto_create_handle_loop() + if isinstance(obj, str): + obj = TaskMethodInput(method_name=obj) + res: List[TaskMethodOutput] = await self.run_task_method_communicator(obj) + return res[0].result + + async def offload_model_weights( + self, + obj: OffloadReqInput = None, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + if obj is None: + obj = OffloadReqInput() + await self.offload_model_weights_communicator(obj) + + async def sync_hf_weight( + self, + obj: SyncHFWeightInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.sync_hf_weight_communicator(obj) + + async def sync_weight( + self, + obj: SyncWeightInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.sync_weight_communicator(obj) + + def abort_request(self, rid: str): + if rid != "" and rid not in self.rid_to_state: + return + req = AbortReq(rid) + self.send_to_scheduler.send_pyobj(req) + + async def pause_generation(self): + self.abort_request("") diff --git a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py index 3e021e922..519895e49 100644 --- a/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +++ b/rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py @@ -48,6 +48,9 @@ def __init__( ) # rlinf specific self.rlinf_config = rlinf_config + self.using_sharded_weight = ( + False if self.rlinf_config.actor.training_backend == "fsdp" else True + ) self._rlinf_worker = _RLinfWorker( parent_address=parent_address, world_size=vllm_config.parallel_config.world_size, @@ -103,7 +106,7 @@ def sync_hf_weight(self) -> None: def use_sharded_weights(self) -> None: model = self.model_runner.model for _, param in model.named_parameters(): - setattr(param, "is_sharded_weight", True) + setattr(param, "is_sharded_weight", self.using_sharded_weight) def get_dp_rank(self) -> int: return self._rlinf_worker.get_parent_rank() diff --git a/rlinf/models/__init__.py b/rlinf/models/__init__.py index 08207900a..617ac7467 100644 --- a/rlinf/models/__init__.py +++ b/rlinf/models/__init__.py @@ -17,7 +17,6 @@ import torch from omegaconf import DictConfig -from peft import LoraConfig, PeftModel, get_peft_model from transformers import ( AutoConfig, AutoImageProcessor, @@ -172,6 +171,8 @@ def get_model(model_path, cfg: DictConfig, override_config_kwargs=None): model = model.cuda() if cfg.is_lora: + from peft import LoraConfig, PeftModel, get_peft_model + if not hasattr(cfg, "lora_path") or cfg.lora_path is None: lora_config = LoraConfig( r=cfg.lora_rank, diff --git a/rlinf/models/embodiment/model_utils.py b/rlinf/models/embodiment/model_utils.py index 04425cfdc..8e7aebeb1 100644 --- a/rlinf/models/embodiment/model_utils.py +++ b/rlinf/models/embodiment/model_utils.py @@ -15,9 +15,10 @@ from typing import Any, Optional import torch -import torch.nn.functional as F from transformers.generation import TopKLogitsWarper +from rlinf.utils.utils import compute_entropy_from_logits, compute_logprobs_from_logits + def default_logits_processor(logits, action_tokens, vocab_size, n_action_bins): logits = logits.permute(0, 2, 1) # [B, vocab-size, action-dim] @@ -34,28 +35,6 @@ def default_logits_processor(logits, action_tokens, vocab_size, n_action_bins): return ret -def compute_logprobs_from_logits(logits, target): - logprobs = -F.cross_entropy( - logits, target=target, reduction="none" - ) # [B, action-dim] - return logprobs - - -def compute_entropy_from_logits(logits, epsilon=1e-10): - """ - 
Compute entropy by logits. - - Args: - logits: [B, vocab-size, seq-len] - Returns: - entropy: [B, seq-len] - """ - all_probs = F.softmax(logits, dim=1) # [B, vocab-size, seq-len] - all_log_probs = torch.log(all_probs + epsilon) - entropy = -torch.sum(all_probs * all_log_probs, dim=1) # [B, seq-len] - return entropy - - def custom_forward( model, input_ids, diff --git a/rlinf/runners/coding_online_rl_runner.py b/rlinf/runners/coding_online_rl_runner.py index 46be3423a..377667523 100644 --- a/rlinf/runners/coding_online_rl_runner.py +++ b/rlinf/runners/coding_online_rl_runner.py @@ -212,6 +212,7 @@ def run(self): infer_handle: Handle = self.inference.run_inference( input_channel=self.dataloader_channel, output_channel=self.inference_channel, + rollout_channel=None, compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel @@ -221,16 +222,12 @@ def run(self): # Advantages and returns adv_handle: Handle = self.actor.compute_advantages_and_returns( - input_channel=self.inference_channel, + input_channel=inference_channel, output_channel=self.actor_channel, ) # Actor training actor_input_channel = self.actor_channel - if self.is_pipeline: - # In pipeline mode, the rollout already contains the advantages and returns - # So the above two steps are in fact no-ops, and we should directly use the inference channel as the input - actor_input_channel = inference_channel actor_handle: Handle = self.actor.run_training( input_channel=actor_input_channel, ) diff --git a/rlinf/runners/math_runner.py b/rlinf/runners/reasoning_runner.py similarity index 91% rename from rlinf/runners/math_runner.py rename to rlinf/runners/reasoning_runner.py index e3f5b3750..b53010e18 100644 --- a/rlinf/runners/math_runner.py +++ b/rlinf/runners/reasoning_runner.py @@ -25,7 +25,7 @@ from tqdm import tqdm from rlinf.data.io_struct import RolloutRequest -from rlinf.scheduler import Channel, Worker +from rlinf.scheduler import Channel from rlinf.scheduler import WorkerGroupFuncResult as Handle from rlinf.utils.data_iter_utils import split_list from rlinf.utils.distributed import ScopedTimer @@ -35,6 +35,7 @@ from rlinf.utils.timers import Timer from rlinf.workers.actor.megatron_actor_worker import MegatronActor from rlinf.workers.inference.megatron_inference_worker import MegatronInference +from rlinf.workers.reward.reward_worker import RewardWorker if typing.TYPE_CHECKING: from rlinf.workers.rollout.sglang.sglang_worker import SGLangWorker @@ -43,8 +44,8 @@ logging.getLogger().setLevel(logging.INFO) -class MathRunner: - """Runner for math model training.""" +class ReasoningRunner: + """Runner for reasoning task RL training.""" def __init__( self, @@ -55,33 +56,28 @@ def __init__( rollout: Union["SGLangWorker", "VLLMWorker"], inference: Optional[MegatronInference], actor: MegatronActor, - reward: Optional[Worker] = None, + reward: RewardWorker, ): """""" self.cfg = cfg self.component_placement = placement self.is_pipeline = self.component_placement.is_disaggregated self.has_dedicated_inference = inference is not None - self.has_dedicated_reward = reward is not None # Workers self.rollout = rollout self.actor = actor # Collocated mode uses actor as inference self.inference = inference if self.has_dedicated_inference else self.actor - self.reward = reward if self.has_dedicated_reward else self.actor + self.reward = reward # Data channels self.dataloader_channel = Channel.create("DataLoader") self.rollout_channel = Channel.create("Rollout") # Create a local channel (i.e., a channel that is 
different in every process) # if inference is not a dedicated worker - self.inference_channel = Channel.create( - "Inference", local=not self.has_dedicated_inference - ) - self.reward_channel = Channel.create( - "Reward", local=not self.has_dedicated_reward - ) + self.inference_channel = Channel.create("Inference") + self.reward_channel = Channel.create("Reward") self.actor_channel = Channel.create("Actor", local=True) # Configurations @@ -179,8 +175,7 @@ def init_workers(self): self.actor.init_worker().wait() if self.has_dedicated_inference: self.inference.init_worker().wait() - if self.has_dedicated_reward: - self.reward.init_worker().wait() + self.reward.init_worker().wait() if self.cfg.runner.resume_dir is None: return @@ -274,18 +269,26 @@ def epoch(self): def _put_batch(self, batch: Dict[str, torch.Tensor]): prompt_ids = batch["prompt"].tolist() lengths = batch["length"].tolist() - answers = batch["answer"].tolist() - prompts = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)] + answers = batch["answer"] + image_data = batch["image_data"] + multi_modal_inputs = batch["multi_modal_inputs"] + prompt_ids = [ids[-pmp_len:] for ids, pmp_len in zip(prompt_ids, lengths)] rollout_dp_size = self.component_placement.rollout_dp_size - for input_ids, answers in zip( - split_list(prompts, rollout_dp_size, enforce_divisible_batch=False), + for input_ids, answers, image_data, multi_modal_inputs in zip( + split_list(prompt_ids, rollout_dp_size, enforce_divisible_batch=False), split_list(answers, rollout_dp_size, enforce_divisible_batch=False), + split_list(image_data, rollout_dp_size, enforce_divisible_batch=False), + split_list( + multi_modal_inputs, rollout_dp_size, enforce_divisible_batch=False + ), ): request = RolloutRequest( n=self.cfg.algorithm.group_size, input_ids=input_ids, answers=answers, + image_data=image_data, + multi_modal_inputs=multi_modal_inputs, ) self.dataloader_channel.put(request, async_op=True) @@ -327,38 +330,34 @@ def run(self): output_channel=self.rollout_channel, ) + # Rewards + reward_handle: Handle = self.reward.compute_rewards( + input_channel=self.rollout_channel, + output_channel=self.reward_channel, + ) + if self.recompute_logprobs: # Inference prev/ref logprobs infer_handle: Handle = self.inference.run_inference( - input_channel=self.rollout_channel, + input_channel=self.reward_channel, output_channel=self.inference_channel, + rollout_channel=self.rollout_channel, compute_ref_logprobs=self.compute_ref_logprobs, ) inference_channel = self.inference_channel else: infer_handle = None - inference_channel = self.rollout_channel - - # Rewards - reward_handle: Handle = self.reward.compute_rewards( - input_channel=inference_channel, - output_channel=self.reward_channel, - ) + inference_channel = self.reward_channel # Advantages and returns adv_handle: Handle = self.actor.compute_advantages_and_returns( - input_channel=self.reward_channel, + input_channel=inference_channel, output_channel=self.actor_channel, ) # Actor training - actor_input_channel = self.actor_channel - if self.is_pipeline: - # In pipeline mode, the rollout already contains the advantages and returns - # So the above two steps are in fact no-ops, and we should directly use the inference channel as the input - actor_input_channel = inference_channel actor_handle: Handle = self.actor.run_training( - input_channel=actor_input_channel, + input_channel=self.actor_channel, ) metrics = actor_handle.wait() @@ -421,7 +420,7 @@ def run(self): } self.metric_logger.log(training_metrics, logging_steps + i) 
- logging_metrics = time_metrics + logging_metrics = {f"{k}_time": v for k, v in time_metrics.items()} if self.cfg.actor.get("calculate_flops", False): flops_metrics = self._compute_flops_metrics( diff --git a/rlinf/utils/convertor/__init__.py b/rlinf/utils/convertor/__init__.py new file mode 100644 index 000000000..5b365ea1e --- /dev/null +++ b/rlinf/utils/convertor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/rlinf/utils/convertor/utils.py b/rlinf/utils/convertor/utils.py new file mode 100644 index 000000000..761218650 --- /dev/null +++ b/rlinf/utils/convertor/utils.py @@ -0,0 +1,467 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, List, Optional, Tuple + +import torch + + +class TransformType(Enum): + SPLIT_QKV = "split_qkv" + SPLIT_QKV_BIAS = "split_qkv_bias" + SPLIT_FC1 = "split_fc1" + SPLIT_NONE = "split_none" + + +class TransformFunc: + @staticmethod + def _split_gqa_tensor( + tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config + ) -> None: + hidden_size = config.model_config.hidden_size + num_attention_heads = config.model_config.num_attention_heads + num_query_groups = config.model_config.num_query_groups or num_attention_heads + head_dim = hidden_size // num_attention_heads + + target_tp = config.reshard_tp_size + assert num_query_groups % target_tp == 0, ( + "num_query_groups must be divisible by reshard_tp_size" + ) + local_num_query_groups = num_query_groups // target_tp + + # heads per query group + assert num_attention_heads % num_query_groups == 0, ( + "num_attention_heads must be divisible by num_query_groups" + ) + q_heads_per_group = num_attention_heads // num_query_groups + + num_channel_qkv = q_heads_per_group + 2 + + if tensor.ndim == 2: + # Weight: [out_features, in_features] + out_features, in_features = tensor.shape + expected_out = local_num_query_groups * num_channel_qkv * head_dim + assert out_features == expected_out, ( + f"Unexpected fused QKV weight shape {tensor.shape}, expect " + f"[{expected_out}, {in_features}] (local groups={local_num_query_groups})" + ) + + qkv = tensor.view( + local_num_query_groups, num_channel_qkv, head_dim, in_features + ) + q, k, v = torch.split( + qkv, [q_heads_per_group, 1, 1], dim=1 + ) # shapes: [G, qh, D, In], [G,1,D,In], [G,1,D,In] + q_full = q.reshape(-1, in_features).contiguous() + 
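# --- Standalone illustrative sketch: fused QKV row layout ------------------------
# Shape check for the grouped-query split performed by _split_gqa_tensor: per query
# group the fused rows are [q_heads_per_group * head_dim | head_dim (K) | head_dim (V)].
# All sizes below are hypothetical; only the reshape/split pattern mirrors the code.
import torch

hidden_size, num_heads, num_groups, in_features = 64, 8, 4, 64
head_dim = hidden_size // num_heads          # 8
q_per_group = num_heads // num_groups        # 2 query heads share each KV group
fused = torch.randn(num_groups * (q_per_group + 2) * head_dim, in_features)

qkv = fused.view(num_groups, q_per_group + 2, head_dim, in_features)
q, k, v = torch.split(qkv, [q_per_group, 1, 1], dim=1)
assert q.reshape(-1, in_features).shape[0] == num_heads * head_dim   # all query rows
assert k.reshape(-1, in_features).shape[0] == num_groups * head_dim  # one K head per group
assert v.reshape(-1, in_features).shape[0] == num_groups * head_dim  # one V head per group
# ---------------------------------------------------------------------------------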
k_full = k.reshape(-1, in_features).contiguous() + v_full = v.reshape(-1, in_features).contiguous() + else: + # Bias: [out_features] + out_features = tensor.shape[0] + expected_out = local_num_query_groups * num_channel_qkv * head_dim + assert out_features == expected_out, ( + f"Unexpected fused QKV bias shape {tensor.shape}, expect " + f"[{expected_out}] (local groups={local_num_query_groups})" + ) + + qkv = tensor.view(local_num_query_groups, num_channel_qkv, head_dim) + q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1) + q_full = q.reshape(-1).contiguous() + k_full = k.reshape(-1).contiguous() + v_full = v.reshape(-1).contiguous() + + # Save to target names + new_statedict[weight_names[0]] = q_full.clone() + new_statedict[weight_names[1]] = k_full.clone() + new_statedict[weight_names[2]] = v_full.clone() + + @staticmethod + def split_fc1( + linear_fc1: torch.Tensor, new_statedict: dict, weight_names: List[str], config + ) -> None: + assert weight_names is not None and len(weight_names) == 2, ( + f"split_fc1 transform expects two weight names, got {weight_names}" + ) + + tp_size = config.model_config.tensor_model_parallel_size + target_tp = config.reshard_tp_size + split_size = linear_fc1.shape[0] // (tp_size // target_tp) + linear_fc1_slice = torch.split(linear_fc1, split_size, dim=0) + + gate_proj_shards = [] + up_proj_shards = [] + for weight in linear_fc1_slice: + assert weight.shape[0] % 2 == 0, ( + f"linear_fc1 weight shape {weight.shape} is not even along dim 0" + ) + weight_chunk = torch.chunk(weight, 2, dim=0) + gate_proj_shards.append(weight_chunk[0]) + up_proj_shards.append(weight_chunk[1]) + gate_proj = torch.cat(gate_proj_shards, dim=0) + up_proj = torch.cat(up_proj_shards, dim=0) + + new_statedict[weight_names[0]] = gate_proj.clone() + new_statedict[weight_names[1]] = up_proj.clone() + + @staticmethod + def split_none( + tensor: torch.Tensor, new_statedict: dict, weight_names: List[str] + ) -> None: + assert weight_names is not None and len(weight_names) == 1, ( + f"split_none transform expects one weight name, got {weight_names}" + ) + new_statedict[weight_names[0]] = tensor.clone() + + +@dataclass +class ConvertorRule: + pattern: re.Pattern + transform: TransformType + targets: List[str] + post: Optional[Callable] = None + + +class BaseConvertor: + def __init__(self, config, strict: bool = False): + self.cfg = config + self.strict = strict + self.rules = self.build_rules() + + def map_name(self, name: str) -> Optional[Tuple[TransformType, List[str]]]: + def _get_targets_from_match(templates: list[str], m: re.Match) -> list[str]: + gd = m.groupdict() + out = [] + for t in templates: + if "{" in t and "}" in t: + out.append(t.format(**gd)) + else: + out.append(m.expand(t)) + return out + + for r in self.rules: + m = r.pattern.fullmatch(name) + if not m: + continue + targets = r.targets + if r.post: + targets = r.post(targets, m) + full_names = _get_targets_from_match(targets, m) + return r.transform, full_names + return None + + def convert(self, state_dict: Dict) -> Dict: + converted = {} + for k, v in state_dict.items(): + mapped = self.map_name(k) + if mapped is None: + if self.strict: + raise KeyError(f"Unmapped key {k}") + continue + transform, targets = mapped + if transform in (TransformType.SPLIT_QKV, TransformType.SPLIT_QKV_BIAS): + TransformFunc._split_gqa_tensor(v, converted, targets, self.cfg) + elif transform == TransformType.SPLIT_FC1: + TransformFunc.split_fc1(v, converted, targets, self.cfg) + elif transform == TransformType.SPLIT_NONE: + 
TransformFunc.split_none(v, converted, targets) + else: + raise ValueError(f"Unknown transform type {transform}") + return converted + + def build_rules(self) -> List[ConvertorRule]: + """ + Should be implemented in subclass to build the conversion rules. + """ + raise NotImplementedError + + +class Qwen2_5Convertor(BaseConvertor): + def build_rules(self) -> List[ConvertorRule]: + LID = r"(?P\d+)" + WB = r"(?Pweight|bias)" + + return [ + # embeddings + ConvertorRule( + re.compile(r"embedding\.word_embeddings\.weight$"), + TransformType.SPLIT_NONE, + [r"model.embed_tokens.weight"], + ), + # final_layernorm + ConvertorRule( + re.compile(r"decoder\.final_layernorm\.weight$"), + TransformType.SPLIT_NONE, + [r"model.norm.weight"], + ), + # lm_head + ConvertorRule( + re.compile(r"output_layer\.weight$"), + TransformType.SPLIT_NONE, + [r"lm_head.weight"], + ), + # attn qkv norm + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.self_attention\.linear_qkv\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [r"model.layers.\g.input_layernorm.weight"], + ), + # attn qkv weights/bias + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.self_attention\.linear_qkv\.{WB}$" + ), + TransformType.SPLIT_QKV, + [ + r"model.layers.\g.self_attn.q_proj.\g", + r"model.layers.\g.self_attn.k_proj.\g", + r"model.layers.\g.self_attn.v_proj.\g", + ], + ), + # attn o proj + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.self_attention\.linear_proj\.{WB}$" + ), + TransformType.SPLIT_NONE, + [r"model.layers.\g.self_attn.o_proj.\g"], + ), + # mlp fc1 + ConvertorRule( + re.compile(rf"decoder\.layers\.{LID}\.mlp\.linear_fc1\.{WB}$"), + TransformType.SPLIT_FC1, + [ + r"model.layers.\g.mlp.gate_proj.\g", + r"model.layers.\g.mlp.up_proj.\g", + ], + ), + # mlp fc2 + ConvertorRule( + re.compile(rf"decoder\.layers\.{LID}\.mlp\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [r"model.layers.\g.mlp.down_proj.\g"], + ), + # mlp norms + ConvertorRule( + re.compile( + rf"decoder\.layers\.{LID}\.mlp\.linear_fc1\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [r"model.layers.\g.post_attention_layernorm.weight"], + ), + ] + + +class Qwen2_5VLConvertor(BaseConvertor): + def _build_vision_rules(self) -> List[ConvertorRule]: + B = r"(?P\d+)" + WB = r"(?Pweight|bias)" + HF_V_PREFIX = "model.visual" + HF_V_DECODER_PREFIX = f"{HF_V_PREFIX}.blocks" + MG_V_PREFIX = "vision_model" + MG_V_DECODER_PREFIX = rf"{MG_V_PREFIX}\.decoder\.layers" + + vision_rules = [ + # vision patch embed + ConvertorRule( + re.compile(rf"^{MG_V_PREFIX}\.patch_embed\.proj\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_V_PREFIX}.patch_embed.proj.weight"], + ), + # final layer norm + ConvertorRule( + re.compile(rf"^{MG_V_PREFIX}\.decoder\.final_layernorm\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_V_PREFIX}.merger.ln_q.weight"], + ), + # attn norm + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.norm1.weight"], + ), + # attn qkv + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.linear_qkv\.{WB}$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.attn.qkv.\g"], + ), + # attn proj + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.self_attention\.linear_proj\.{WB}$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.attn.proj.\g"], + ), + # mlp fc1 + ConvertorRule( + 
re.compile(rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.{WB}$"), + TransformType.SPLIT_FC1, + [ + f"{HF_V_DECODER_PREFIX}" + r".\g.mlp.gate_proj.\g", + f"{HF_V_DECODER_PREFIX}" + r".\g.mlp.up_proj.\g", + ], + ), + # mlp fc2 + ConvertorRule( + re.compile(rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.mlp.down_proj.\g"], + ), + # mlp norm + ConvertorRule( + re.compile( + rf"^{MG_V_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [f"{HF_V_DECODER_PREFIX}" + r".\g.norm2.weight"], + ), + ] + return vision_rules + + def _build_llm_rules(self) -> List[ConvertorRule]: + B = r"(?P\d+)" + WB = r"(?Pweight|bias)" + HF_LLM_PREFIX = "model.language_model" + MG_LLM_PREFIX = "language_model" + MG_LLM_DECODER_PREFIX = rf"{MG_LLM_PREFIX}\.decoder\.layers" + + llm_rules = [ + # embeddings + ConvertorRule( + re.compile(rf"^{MG_LLM_PREFIX}\.embed_tokens\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}.embedding.weight"], + ), + # final_layernorm + ConvertorRule( + re.compile(rf"^{MG_LLM_PREFIX}\.final_layernorm\.weight$"), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}.norm.weight"], + ), + # attn norm + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.input_layernorm.weight"], + ), + # attn qkv + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.linear_qkv\.{WB}$" + ), + TransformType.SPLIT_QKV, + [ + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.self_attn.q_proj.\g", + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.self_attn.k_proj.\g", + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.self_attn.v_proj.\g", + ], + ), + # attn proj + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.self_attention\.linear_proj\.{WB}$" + ), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.self_attn.o_proj.\g"], + ), + # mlp fc1 + ConvertorRule( + re.compile(rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.{WB}$"), + TransformType.SPLIT_FC1, + [ + f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.mlp.gate_proj.\g", + f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.mlp.up_proj.\g", + ], + ), + # mlp fc2 + ConvertorRule( + re.compile(rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_LLM_PREFIX}" + r".decoder.layers.\g.mlp.down_proj.\g"], + ), + # mlp norm + ConvertorRule( + re.compile( + rf"^{MG_LLM_DECODER_PREFIX}\.{B}\.mlp\.linear_fc1\.layer_norm_weight$" + ), + TransformType.SPLIT_NONE, + [ + f"{HF_LLM_PREFIX}" + + r".decoder.layers.\g.post_attention_layernorm.weight" + ], + ), + ] + return llm_rules + + def _build_projector_rules(self) -> List[ConvertorRule]: + HF_PROJECTOR_PREFIX = "model.visual.merger" + MG_PROJECTOR_PREFIX = "vision_model.protection.encoder" + WB = r"(?Pweight|bias)" + + projector_rules = [ + # projector fc1 + ConvertorRule( + re.compile(rf"^{MG_PROJECTOR_PREFIX}\.linear_fc1\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_PROJECTOR_PREFIX}" + r".mlp.0.\g"], + ), + # projector fc2 + ConvertorRule( + re.compile(rf"^{MG_PROJECTOR_PREFIX}\.linear_fc2\.{WB}$"), + TransformType.SPLIT_NONE, + [f"{HF_PROJECTOR_PREFIX}" + r".mlp.2.\g"], + ), + ] + return projector_rules + + def build_rules(self) -> List[ConvertorRule]: + rules = [] + rules.extend(self._build_vision_rules()) + rules.extend(self._build_llm_rules()) + rules.extend(self._build_projector_rules()) 
+ return rules + + +_MG2HF_CONVERTOR_REGISTRY = {} + + +def register_mg2hf_convertor(model_arch: str, convertor_cls: Callable) -> None: + if model_arch in _MG2HF_CONVERTOR_REGISTRY: + raise ValueError(f"Convertor for {model_arch} already registered") + _MG2HF_CONVERTOR_REGISTRY[model_arch] = convertor_cls + + +register_mg2hf_convertor("qwen2.5", Qwen2_5Convertor) +register_mg2hf_convertor("qwen2.5_vl", Qwen2_5VLConvertor) + + +def get_mg2hf_convertor(model_arch: str, config, strict: bool = False) -> BaseConvertor: + if model_arch not in _MG2HF_CONVERTOR_REGISTRY: + raise ValueError(f"No convertor registered for {model_arch}") + convertor_cls = _MG2HF_CONVERTOR_REGISTRY[model_arch] + return convertor_cls(config=config, strict=strict) diff --git a/rlinf/utils/data_iter_utils.py b/rlinf/utils/data_iter_utils.py index 27bbd7215..31e440201 100644 --- a/rlinf/utils/data_iter_utils.py +++ b/rlinf/utils/data_iter_utils.py @@ -60,13 +60,17 @@ def concat_dict_list(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, Any]: return result -def split_list(inputs, num_chunks, enforce_divisible_batch: Optional[bool] = True): +def split_list( + inputs: List, num_chunks: int, enforce_divisible_batch: Optional[bool] = True +): """ Split a list into equal sized chunks """ if enforce_divisible_batch: chunk_size = len(inputs) // num_chunks - assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!" + assert len(inputs) % chunk_size == 0, ( + f"Issue with batch size configuration! inputs len:{len(inputs)} num_chunks:{num_chunks}" + ) return [inputs[i : i + chunk_size] for i in range(0, len(inputs), chunk_size)] else: k, m = divmod(len(inputs), num_chunks) diff --git a/rlinf/utils/distributed.py b/rlinf/utils/distributed.py index e9da8d6da..1f5ddc9c6 100644 --- a/rlinf/utils/distributed.py +++ b/rlinf/utils/distributed.py @@ -29,11 +29,16 @@ from rlinf.utils.timers import NamedTimer - +import torch_npu def compute_rollout_metrics( - rollout_batch, max_prompt_len, response_len, use_critic=False + rollout_batch, + max_prompt_len, + response_len, + dp_world_size, + dp_group=None, + use_critic=False, ): - device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = torch.device(f"npu:{torch.npu.current_device()}") advantages = rollout_batch["advantages"].to(device=device) mask = rollout_batch["attention_mask"][:, -response_len:].to(device=device) prompt_lengths = rollout_batch["prompt_lengths"].clone().to(device=device) @@ -41,8 +46,6 @@ def compute_rollout_metrics( reward_scores = rollout_batch["rewards"].clone().to(device=device) is_end = rollout_batch["is_end"].clone().float().to(device=device) - dp_world_size = parallel_state.get_data_parallel_world_size() - prompt_lengths_list = [ torch.empty_like(prompt_lengths) for _ in range(dp_world_size) ] @@ -52,12 +55,12 @@ def compute_rollout_metrics( torch.distributed.all_gather( prompt_lengths_list, prompt_lengths, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_gather( decode_lengths_list, response_lengths, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) total_prompt_lengths = torch.cat(prompt_lengths_list, dim=0) @@ -66,22 +69,22 @@ def compute_rollout_metrics( torch.distributed.all_reduce( prompt_lengths, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( response_lengths, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) 
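# --- Standalone illustrative sketch: split_list chunking -------------------------
# Assumed behaviour of split_list from rlinf/utils/data_iter_utils.py above: with
# enforce_divisible_batch=False the first (len % num_chunks) chunks get one extra
# element, so prompts, answers, image_data and multi_modal_inputs stay aligned
# across rollout DP ranks. The uneven branch below is a common divmod-based
# implementation and is an assumption, not a verbatim copy.
def split_list_sketch(inputs, num_chunks, enforce_divisible_batch=True):
    if enforce_divisible_batch:
        chunk_size = len(inputs) // num_chunks
        assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!"
        return [inputs[i : i + chunk_size] for i in range(0, len(inputs), chunk_size)]
    k, m = divmod(len(inputs), num_chunks)
    return [
        inputs[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)]
        for i in range(num_chunks)
    ]


assert split_list_sketch(list(range(10)), 3, enforce_divisible_batch=False) == [
    [0, 1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]
# ---------------------------------------------------------------------------------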
torch.distributed.all_reduce( reward_scores, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( is_end, torch.distributed.ReduceOp.AVG, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) valid_adv = torch.masked_select(advantages, mask) @@ -90,24 +93,24 @@ def compute_rollout_metrics( torch.distributed.all_reduce( n_valid_token, op=torch.distributed.ReduceOp.SUM, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) torch.distributed.all_reduce( adv_sum, op=torch.distributed.ReduceOp.SUM, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) adv_mean = adv_sum / n_valid_token adv_max = torch.max(valid_adv).detach().item() adv_min = torch.min(valid_adv).detach().item() reduce_tensor = torch.as_tensor( - [-adv_min, adv_max], device=torch.cuda.current_device(), dtype=torch.float32 + [-adv_min, adv_max], device=torch.npu.current_device(), dtype=torch.float32 ) torch.distributed.all_reduce( reduce_tensor, torch.distributed.ReduceOp.MAX, - group=parallel_state.get_data_parallel_group(), + group=dp_group, ) adv_min, adv_max = reduce_tensor.tolist() @@ -172,7 +175,7 @@ def from_rollout_batches( dp_group: Optional[ProcessGroup], partitioning_tool: Callable, ) -> Self: - current_device = torch.cuda.current_device() + current_device = torch.npu.current_device() attn_mask = rollout_batches.get("attention_mask") current_num_samples = attn_mask.size(0) @@ -403,12 +406,12 @@ def rebalance_nd_tensor(tensor, group): NOTE: assumes all other (i.e., non-zero) dimensions are equal. """ num_samples = torch.as_tensor( - tensor.size(0), dtype=torch.int64, device=torch.cuda.current_device() + tensor.size(0), dtype=torch.int64, device=torch.npu.current_device() ) batch_num_per_rank = torch.zeros( torch.distributed.get_world_size(group), dtype=torch.int64, - device=torch.cuda.current_device(), + device=torch.npu.current_device(), ) torch.distributed.all_gather_into_tensor( batch_num_per_rank, num_samples, group=group @@ -419,7 +422,7 @@ def rebalance_nd_tensor(tensor, group): indices = batch_num_per_rank.cumsum(dim=0) output_tensor = torch.zeros( - B, *other_dims, dtype=tensor.dtype, device=torch.cuda.current_device() + B, *other_dims, dtype=tensor.dtype, device=torch.npu.current_device() ) # tensor_split is a view we can copy into @@ -451,7 +454,7 @@ def broadcast_tensor( """ if torch.distributed.get_rank() == src: - tensor = tensor.cuda() + tensor = tensor.npu() if dtype: tensor = tensor.to(dtype) @@ -464,7 +467,7 @@ def broadcast_tensor( torch.distributed.broadcast_object_list(metadata, src, group) dtype, input_shape = metadata - tensor = torch.empty(input_shape, dtype=dtype, device="cuda") + tensor = torch.empty(input_shape, dtype=dtype, device="npu") torch.distributed.broadcast(tensor, src, group) return tensor @@ -516,7 +519,7 @@ def broadcast_tensor_within_dp(tensor: torch.Tensor, dtype: torch.dtype): def gather_tensor(tensor, dst, group, dtype=None): """Gather any tensor to the dst rank from every other rank in the given group. 
All the ranks that send or receive data must call this function.""" - tensor = tensor.to(device=torch.cuda.current_device(), dtype=dtype) + tensor = tensor.to(device=torch.npu.current_device(), dtype=dtype) if torch.distributed.get_rank() == dst: gather_list = [ torch.empty_like(tensor) @@ -546,8 +549,8 @@ def normalize_tensor(tensor, mask, group=None): """normalizes a tensor using global mean and std""" dtype = torch.float64 tensor = tensor.to(dtype) - tensor = tensor.to(device=torch.cuda.current_device()) - mask = mask.to(device=torch.cuda.current_device()) + tensor = tensor.to(device=torch.npu.current_device()) + mask = mask.to(device=torch.npu.current_device()) tensor_global_mean, tensor_global_var = masked_global_mean_var( tensor, mask, group=group @@ -586,7 +589,7 @@ def masked_normalization( Normalized x, with the same shape as x. """ dtype = torch.float64 if high_precision else torch.float32 - x = x.to(dtype=dtype).cuda() + x = x.to(dtype=dtype).npu() if not inplace: x = x.clone() if dim is None: @@ -596,7 +599,7 @@ def masked_normalization( np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device ) else: - mask = mask.to(dtype=dtype).cuda() + mask = mask.to(dtype=dtype).npu() assert len(mask.shape) == len(x.shape), (mask.shape, x.shape, dim) for i in range(len(x.shape)): if i in dim: @@ -640,8 +643,8 @@ def masked_global_mean_var(values, mask, group=None): mask and values must have same shape, with mask being {0,1} with 1 being the values we want to keep """ assert values.shape == mask.shape, (values.shape, mask.shape) - values = values.to(device=torch.cuda.current_device()) - mask = mask.to(device=torch.cuda.current_device()) + values = values.to(device=torch.npu.current_device()) + mask = mask.to(device=torch.npu.current_device()) values = values * mask @@ -649,7 +652,7 @@ def masked_global_mean_var(values, mask, group=None): sum_and_count = torch.tensor( [values.sum(), mask.sum()], dtype=torch.float64, - device=torch.cuda.current_device(), + device=torch.npu.current_device(), ) torch.distributed.all_reduce(sum_and_count, group=group) global_sum, global_count = sum_and_count @@ -657,7 +660,7 @@ def masked_global_mean_var(values, mask, group=None): variance_summed = ( (((values - global_mean) ** 2) * mask) .sum() - .to(device=torch.cuda.current_device(), dtype=torch.float64) + .to(device=torch.npu.current_device(), dtype=torch.float64) ) torch.distributed.all_reduce(variance_summed, group=group) @@ -666,12 +669,12 @@ def masked_global_mean_var(values, mask, group=None): def report_device_info(info_str): - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info() free_gpu_memory /= 2**30 total_gpu_memory /= 2**30 - memory_allocated = torch.cuda.memory_allocated() / 2**30 - memory_reserved = torch.cuda.memory_reserved() / 2**30 + memory_allocated = torch.npu.memory_allocated() / 2**30 + memory_reserved = torch.npu.memory_reserved() / 2**30 print( f"[Rank {torch.distributed.get_rank()}] {info_str}, {free_gpu_memory=:.2f} GiB, {total_gpu_memory=:.2f} GiB, {memory_allocated=:.2f} GiB, {memory_reserved=:.2f} GiB" @@ -722,7 +725,7 @@ def all_reduce_dict( ): keys = sorted(dictionary) tensor = torch.as_tensor( - [dictionary[k] for k in keys], dtype=dtype, device=torch.cuda.current_device() + [dictionary[k] for k in keys], dtype=dtype, device=torch.npu.current_device() ) torch.distributed.all_reduce(tensor, op=op, group=group) return dict(zip(keys, tensor.tolist())) diff --git a/rlinf/utils/placement.py 
b/rlinf/utils/placement.py index 7b707894e..6ecea767f 100644 --- a/rlinf/utils/placement.py +++ b/rlinf/utils/placement.py @@ -202,6 +202,7 @@ def __init__(self, config: DictConfig, cluster: Cluster): self._actor_gpus = self._component_gpu_map.get("actor", None) self._inference_gpus = self._component_gpu_map.get("inference", None) self._rollout_gpus = self._component_gpu_map.get("rollout", None) + self._reward_gpus = self._component_gpu_map.get("reward", None) assert self._actor_gpus is not None, ( "Actor GPUs must be specified in the component_placement config." ) @@ -224,11 +225,9 @@ def __init__(self, config: DictConfig, cluster: Cluster): len(self._inference_gpus) if self._inference_gpus else 0 ) self._rollout_num_gpus = len(self._rollout_gpus) + self._reward_num_gpus = len(self._reward_gpus) if self._reward_gpus else 0 if self._is_collocated(): - assert self.actor_tp_size >= self.rollout_tp_size, ( - f"Actor TP size {self.actor_tp_size} must be greater or equal to Rollout TP size {self.rollout_tp_size}." - ) assert self._inference_gpus is None, ( "Inference GPUs must not be specified in collocated mode." ) @@ -280,21 +279,24 @@ def _generate_placements(self): self._actor_gpus[0], self._actor_gpus[-1] ) - actor_tp_size = self._config.actor.model.tensor_model_parallel_size - rollout_tp_size = self._config.rollout.tensor_parallel_size - assert actor_tp_size >= rollout_tp_size, ( - f"Actor TP size ({actor_tp_size}) must be greater or equal to Rollout TP size ({rollout_tp_size})" - ) - assert actor_tp_size % rollout_tp_size == 0, ( - f"Actor TP size ({actor_tp_size}) must be divisible by Rollout TP size ({rollout_tp_size})" + if self.actor_tp_size > self.rollout_tp_size: + assert self.actor_tp_size % self.rollout_tp_size == 0, ( + f"Actor TP size ({self.actor_tp_size}) must be divisible by Rollout TP size ({self.rollout_tp_size})" + ) + stride = ( + self.actor_tp_size // self.rollout_tp_size + if self.actor_tp_size > self.rollout_tp_size + else 1 ) - stride = actor_tp_size // rollout_tp_size self._placements["rollout"] = PackedPlacementStrategy( self._rollout_gpus[0], self._rollout_gpus[-1], - num_accelerators_per_process=rollout_tp_size, + num_accelerators_per_process=self.rollout_tp_size, stride=stride, ) + self._placements["reward"] = PackedPlacementStrategy( + self._reward_gpus[0], self._reward_gpus[-1] + ) elif self._placement_mode == PlacementMode.DISAGGREGATED: # Generate continuous placement strategies for components in a cluster. 
num_gpus_per_rollout_dp = len(self._rollout_gpus) // self.rollout_dp_size @@ -310,6 +312,9 @@ def _generate_placements(self): self._placements["actor"] = PackedPlacementStrategy( self._actor_gpus[0], self._actor_gpus[-1] ) + self._placements["reward"] = PackedPlacementStrategy( + self._reward_gpus[0], self._reward_gpus[-1] + ) @property def is_disaggregated(self): @@ -325,18 +330,18 @@ def has_dedicated_inference(self): @property def actor_dp_size(self) -> int: return self._actor_num_gpus // ( - self._config.actor.model.tensor_model_parallel_size - * self._config.actor.model.context_parallel_size - * self._config.actor.model.pipeline_model_parallel_size + self._config.actor.model.get("tensor_model_parallel_size", 1) + * self._config.actor.model.get("context_parallel_size", 1) + * self._config.actor.model.get("pipeline_model_parallel_size", 1) ) @property def actor_tp_size(self) -> int: - return self._config.actor.model.tensor_model_parallel_size + return self._config.actor.model.get("tensor_model_parallel_size", 1) @property def actor_pp_size(self) -> int: - return self._config.actor.model.pipeline_model_parallel_size + return self._config.actor.model.get("pipeline_model_parallel_size", 1) @property def actor_world_size(self) -> int: @@ -349,7 +354,7 @@ def inference_tp_size(self) -> int: and hasattr(self._config.inference, "model") and hasattr(self._config.inference.model, "tensor_model_parallel_size") ): - return self._config.inference.model.tensor_model_parallel_size + return self._config.inference.model.get("tensor_model_parallel_size", 1) else: return self.actor_tp_size @@ -360,7 +365,7 @@ def inference_pp_size(self) -> int: and hasattr(self._config.inference, "model") and hasattr(self._config.inference.model, "pipeline_model_parallel_size") ): - return self._config.inference.model.pipeline_model_parallel_size + return self._config.inference.model.get("pipeline_model_parallel_size", 1) else: return self.actor_pp_size @@ -377,14 +382,18 @@ def inference_world_size(self) -> int: @property def rollout_dp_size(self) -> int: return self._rollout_num_gpus // ( - self._config.rollout.tensor_parallel_size - * self._config.rollout.pipeline_parallel_size + self._config.rollout.get("tensor_parallel_size", 1) + * self._config.rollout.get("pipeline_parallel_size", 1) ) @property def rollout_tp_size(self) -> int: - return self._config.rollout.tensor_parallel_size + return self._config.rollout.get("tensor_parallel_size", 1) @property def rollout_world_size(self) -> int: return self._rollout_num_gpus + + @property + def reward_world_size(self) -> int: + return self._reward_num_gpus diff --git a/rlinf/utils/resharding/mcore_weight_reshard.py b/rlinf/utils/resharding/mcore_weight_reshard.py index 9ec44eba0..90d8277fa 100644 --- a/rlinf/utils/resharding/mcore_weight_reshard.py +++ b/rlinf/utils/resharding/mcore_weight_reshard.py @@ -183,6 +183,6 @@ def get_layer_num(param_name): ) if self.config.convert_fn is not None: - model_state_dict = self.config.convert_fn(model_state_dict, self.config) + model_state_dict = self.config.convert_fn(model_state_dict) return model_state_dict diff --git a/rlinf/utils/resharding/reshard_config.py b/rlinf/utils/resharding/reshard_config.py index 2b493a839..79e089bcd 100644 --- a/rlinf/utils/resharding/reshard_config.py +++ b/rlinf/utils/resharding/reshard_config.py @@ -17,7 +17,9 @@ from megatron.core.transformer import TransformerConfig -from .utils import get_convert_fn, get_pp_reshard_fn, get_tp_reshard_fn +from rlinf.utils.convertor.utils import get_mg2hf_convertor + 
+from .utils import get_pp_reshard_fn, get_tp_reshard_fn @dataclass @@ -37,7 +39,7 @@ class ReshardConfig: """Resharding pp size.""" convert_fn: Callable = None - """Convert function to use for converting the model parameters' weight and name from training engine to rollout engine.""" + """Function to convert the model weights from megatron format to HuggingFace format.""" tp_reshard_fn: Callable = None """Resharding function to use for resharding the model parallelism from tensor_model_parallel_size to reshard_tp_size.""" @@ -59,7 +61,8 @@ def __post_init__(self): ) if self.convert_fn is None and self.reshard_weights_format != "mcore": - self.convert_fn = get_convert_fn(self.model_arch) + self._convertor = get_mg2hf_convertor(self.model_arch, self, strict=True) + self.convert_fn = self._convertor.convert if self.tp_reshard_fn is None: self.tp_reshard_fn = get_tp_reshard_fn(self.model_arch) diff --git a/rlinf/utils/resharding/utils.py b/rlinf/utils/resharding/utils.py index 1fae2b05a..d7a4af231 100644 --- a/rlinf/utils/resharding/utils.py +++ b/rlinf/utils/resharding/utils.py @@ -12,21 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re -from enum import Enum -from typing import List, Tuple import torch -from megatron.core import parallel_state - - -def get_convert_fn(model_arch: str): - if model_arch == "qwen2.5": - return TransformFunc.convert_mega_qwen2_5_to_hf - else: - raise NotImplementedError( - f"get_convert_fn for model_arch {model_arch} is not implemented" - ) def get_tp_reshard_fn(model_arch: str): @@ -47,212 +34,6 @@ def get_pp_reshard_fn(model_arch: str): ) -########################### -# convert fn implementation -########################### - - -class TransformType(Enum): - SPLIT_QKV = "split_qkv" - SPLIT_QKV_BIAS = "split_qkv_bias" - SPLIT_FC1 = "split_fc1" - SPLIT_NONE = "split_none" - - -class TransformFunc: - @staticmethod - def _split_gqa_tensor( - tensor: torch.Tensor, new_statedict: dict, weight_names: List[str], config - ) -> None: - hidden_size = config.model_config.hidden_size - num_attention_heads = config.model_config.num_attention_heads - num_query_groups = config.model_config.num_query_groups or num_attention_heads - head_dim = hidden_size // num_attention_heads - - target_tp = config.reshard_tp_size - assert num_query_groups % target_tp == 0, ( - "num_query_groups must be divisible by reshard_tp_size" - ) - local_num_query_groups = num_query_groups // target_tp - - # heads per query group - assert num_attention_heads % num_query_groups == 0, ( - "num_attention_heads must be divisible by num_query_groups" - ) - q_heads_per_group = num_attention_heads // num_query_groups - - num_channel_qkv = q_heads_per_group + 2 - - if tensor.ndim == 2: - # Weight: [out_features, in_features] - out_features, in_features = tensor.shape - expected_out = local_num_query_groups * num_channel_qkv * head_dim - assert out_features == expected_out, ( - f"Unexpected fused QKV weight shape {tensor.shape}, expect " - f"[{expected_out}, {in_features}] (local groups={local_num_query_groups})" - ) - - qkv = tensor.view( - local_num_query_groups, num_channel_qkv, head_dim, in_features - ) - q, k, v = torch.split( - qkv, [q_heads_per_group, 1, 1], dim=1 - ) # shapes: [G, qh, D, In], [G,1,D,In], [G,1,D,In] - q_full = q.reshape(-1, in_features).contiguous() - k_full = k.reshape(-1, in_features).contiguous() - v_full = v.reshape(-1, in_features).contiguous() - else: - # Bias: [out_features] - out_features = tensor.shape[0] - 
expected_out = local_num_query_groups * num_channel_qkv * head_dim - assert out_features == expected_out, ( - f"Unexpected fused QKV bias shape {tensor.shape}, expect " - f"[{expected_out}] (local groups={local_num_query_groups})" - ) - - qkv = tensor.view(local_num_query_groups, num_channel_qkv, head_dim) - q, k, v = torch.split(qkv, [q_heads_per_group, 1, 1], dim=1) - q_full = q.reshape(-1).contiguous() - k_full = k.reshape(-1).contiguous() - v_full = v.reshape(-1).contiguous() - - # Save to target names - new_statedict[weight_names[0]] = q_full.clone() - new_statedict[weight_names[1]] = k_full.clone() - new_statedict[weight_names[2]] = v_full.clone() - - @staticmethod - def split_fc1( - linear_fc1: torch.Tensor, new_statedict: dict, weight_names: List[str], config - ) -> None: - assert weight_names is not None and len(weight_names) == 2, ( - f"split_fc1 transform expects two weight names, got {weight_names}" - ) - - tp_size = config.model_config.tensor_model_parallel_size - target_tp = config.reshard_tp_size - split_size = linear_fc1.shape[0] // (tp_size // target_tp) - linear_fc1_slice = torch.split(linear_fc1, split_size, dim=0) - - gate_proj_shards = [] - up_proj_shards = [] - for weight in linear_fc1_slice: - assert weight.shape[0] % 2 == 0, ( - f"linear_fc1 weight shape {weight.shape} is not even along dim 0" - ) - weight_chunk = torch.chunk(weight, 2, dim=0) - gate_proj_shards.append(weight_chunk[0]) - up_proj_shards.append(weight_chunk[1]) - gate_proj = torch.cat(gate_proj_shards, dim=0) - up_proj = torch.cat(up_proj_shards, dim=0) - - new_statedict[weight_names[0]] = gate_proj.clone() - new_statedict[weight_names[1]] = up_proj.clone() - - @staticmethod - def split_none( - tensor: torch.Tensor, new_statedict: dict, weight_names: List[str] - ) -> None: - assert weight_names is not None and len(weight_names) == 1, ( - f"split_none transform expects one weight name, got {weight_names}" - ) - new_statedict[weight_names[0]] = tensor.clone() - - @staticmethod - def mega_name_qwen2_5_to_hf(name: str) -> Tuple[TransformType, List[str]]: - """ - Convert qwen2_5 model weight megatron name to hf name and do shape transform if needed. 
- - Args: - name (str): megatron model weight name - - Returns: - (TransformType, List[str]): transform type and the corresponding hf model weight name - """ - if "embedding.word_embeddings.weight" in name: - return (TransformType.SPLIT_NONE, ["model.embed_tokens.weight"]) - if "decoder.final_layernorm.weight" in name: - return (TransformType.SPLIT_NONE, ["model.norm.weight"]) - if "output_layer.weight" in name: - return (TransformType.SPLIT_NONE, ["lm_head.weight"]) - layer_id, suffix = TransformFunc.extract_layer_info(name) - assert layer_id is not None, f"Cannot extract layer info from {name}" - result_pattern = "model.layers.{}.{}" - nmap = { - "self_attention.linear_proj.weight": ( - TransformType.SPLIT_NONE, - ["self_attn.o_proj.weight"], - ), - "self_attention.linear_qkv.layer_norm_weight": ( - TransformType.SPLIT_NONE, - ["input_layernorm.weight"], - ), - "self_attention.linear_qkv.weight": ( - TransformType.SPLIT_QKV, - [ - "self_attn.q_proj.weight", - "self_attn.k_proj.weight", - "self_attn.v_proj.weight", - ], - ), - "self_attention.linear_qkv.bias": ( - TransformType.SPLIT_QKV_BIAS, - [ - "self_attn.q_proj.bias", - "self_attn.k_proj.bias", - "self_attn.v_proj.bias", - ], - ), - "mlp.linear_fc1.layer_norm_weight": ( - TransformType.SPLIT_NONE, - ["post_attention_layernorm.weight"], - ), - "mlp.linear_fc1.weight": ( - TransformType.SPLIT_FC1, - ["mlp.gate_proj.weight", "mlp.up_proj.weight"], - ), - "mlp.linear_fc2.weight": ( - TransformType.SPLIT_NONE, - ["mlp.down_proj.weight"], - ), - } - - assert suffix in nmap, f"Cannot find mapping for {suffix}" - - transform_type, suffixes = nmap[suffix] - return ( - transform_type, - [result_pattern.format(layer_id, suffix) for suffix in suffixes], - ) - - @staticmethod - def convert_mega_qwen2_5_to_hf(model_state_dict: dict, config) -> dict: - new_statedict = {} - for name, param in model_state_dict.items(): - transform_type, hf_names = TransformFunc.mega_name_qwen2_5_to_hf(name) - if transform_type == TransformType.SPLIT_QKV: - TransformFunc._split_gqa_tensor(param, new_statedict, hf_names, config) - elif transform_type == TransformType.SPLIT_QKV_BIAS: - TransformFunc._split_gqa_tensor(param, new_statedict, hf_names, config) - elif transform_type == TransformType.SPLIT_FC1: - TransformFunc.split_fc1(param, new_statedict, hf_names, config) - elif transform_type == TransformType.SPLIT_NONE: - TransformFunc.split_none(param, new_statedict, hf_names) - else: - raise NotImplementedError( - f"Transform type {transform_type} not implemented" - ) - return new_statedict - - @staticmethod - def extract_layer_info(s): - pattern = r"layers\.(\d+)\.(.+)" - match = re.search(pattern, s) - if match: - return match.group(1), match.group(2) - return None, None - - ############################## # tp reshard fn implementation ############################## @@ -314,6 +95,8 @@ def _gather_pp_group_tensor_and_reshard( def pp_reshard_fn_qwen2_5(model_state_dict, pp_group, dtype): + from megatron.core import parallel_state + pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank() pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank() diff --git a/rlinf/utils/utils.py b/rlinf/utils/utils.py index 449412f8b..b840538cc 100644 --- a/rlinf/utils/utils.py +++ b/rlinf/utils/utils.py @@ -20,13 +20,14 @@ from functools import partial, wraps import torch - +import torch.nn.functional as F +import torch_npu def clear_memory(sync=True): if sync: - torch.cuda.synchronize() + torch.npu.synchronize() gc.collect() - torch.cuda.empty_cache() + 
torch.npu.empty_cache() def apply_func_to_dict(func, dictionary): @@ -53,7 +54,7 @@ def retrieve_model_state_dict_in_cpu(model): cpu_dict[name] = item - torch.cuda.synchronize() + torch.npu.synchronize() return cpu_dict @@ -124,6 +125,65 @@ def seq_mean_token_mean(values, mask): return loss +def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): + #from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + + #output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) + #assert isinstance(output, tuple), ( + # "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + #) + #return -output[0] + import torch.nn.functional as F + + # 数值稳定的 log_softmax + log_probs = F.log_softmax(logits, dim=-1) + + # 提取标签对应的 logprob + labels = labels.unsqueeze(-1) + return torch.gather(log_probs, -1, labels).squeeze(-1) + + +def compute_logprobs_from_logits(logits, target, task_type="embodied"): + if task_type == "embodied": + logprobs = -F.cross_entropy( + logits, target=target, reduction="none" + ) # [B, action-dim] + return logprobs + batch_dim = logits.shape[:-1] + last_dim = logits.shape[-1] + logits = logits.reshape(-1, last_dim) + labels = target.reshape(-1) + logprobs = logprobs_from_logits_flash_attn( + logits, labels=labels, inplace_backward=False + ) + logprobs = logprobs.view(*batch_dim) + return logprobs + + +def entropy_from_logits(logits: torch.Tensor): + """Calculate entropy from logits.""" + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + return entropy + + +def compute_entropy_from_logits(logits, epsilon=1e-10, task_type="embodied"): + """ + Compute entropy by logits. + + Args: + logits: [B, vocab-size, seq-len] + Returns: + entropy: [B, seq-len] + """ + if task_type == "embodied": + all_probs = F.softmax(logits, dim=1) # [B, vocab-size, seq-len] + all_log_probs = torch.log(all_probs + epsilon) + entropy = -torch.sum(all_probs * all_log_probs, dim=1) # [B, seq-len] + return entropy + return entropy_from_logits(logits=logits) + + class DualOutput: def __init__(self, file, terminal): self.file = file diff --git a/rlinf/workers/actor/__init__.py b/rlinf/workers/actor/__init__.py index 5b365ea1e..2d315469e 100644 --- a/rlinf/workers/actor/__init__.py +++ b/rlinf/workers/actor/__init__.py @@ -11,3 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
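# --- Standalone illustrative sketch: logprob and entropy equivalences ------------
# Sanity check for the two logprob paths in rlinf/utils/utils.py above: gathering
# from log_softmax (the non-embodied path) agrees with -cross_entropy (the embodied
# path), and the logsumexp-based entropy matches the naive -sum(p * log p).
# Tensor shapes here are hypothetical.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 5, 7)              # [batch, seq, vocab]
labels = torch.randint(0, 7, (2, 5))

gathered = torch.gather(F.log_softmax(logits, dim=-1), -1, labels.unsqueeze(-1)).squeeze(-1)
via_ce = -F.cross_entropy(logits.reshape(-1, 7), labels.reshape(-1), reduction="none").view(2, 5)
assert torch.allclose(gathered, via_ce, atol=1e-5)

probs = F.softmax(logits, dim=-1)
entropy_lse = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)
entropy_naive = -torch.sum(probs * torch.log(probs), dim=-1)
assert torch.allclose(entropy_lse, entropy_naive, atol=1e-5)
# ---------------------------------------------------------------------------------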
+ +from omegaconf import DictConfig + +from rlinf.scheduler.worker.worker import Worker + + +def get_actor_worker(cfg: DictConfig) -> Worker: + if cfg.actor.training_backend == "fsdp": + from .fsdp_actor_worker import FSDPActor + + return FSDPActor + elif cfg.actor.training_backend == "megatron": + from .megatron_actor_worker import MegatronActor + + return MegatronActor + else: + raise ValueError(f"Unsupported training backend: {cfg.actor.training_backend}") diff --git a/rlinf/workers/actor/fsdp_actor_worker.py b/rlinf/workers/actor/fsdp_actor_worker.py index 61b05c9b4..1b3b015bb 100644 --- a/rlinf/workers/actor/fsdp_actor_worker.py +++ b/rlinf/workers/actor/fsdp_actor_worker.py @@ -14,31 +14,518 @@ import gc import os +from contextlib import nullcontext +from typing import Dict, Tuple import numpy as np import torch from omegaconf import DictConfig from torch.distributed.device_mesh import init_device_mesh +from torch.multiprocessing.reductions import reduce_tensor from tqdm import tqdm import rlinf.algorithms # noqa: F401 from rlinf.algorithms.registry import actor_loss, calculate_adv_and_returns -from rlinf.algorithms.utils import preprocess_advantages_inputs, preprocess_loss_inputs +from rlinf.algorithms.utils import ( + kl_penalty, + preprocess_advantages_inputs, + preprocess_loss_inputs, +) +from rlinf.data.io_struct import RolloutResult from rlinf.hybrid_engines.fsdp.fsdp_model_manager import ( FSDPModelManager, ) from rlinf.models import get_model from rlinf.models.embodiment.model_utils import custom_forward -from rlinf.scheduler import Cluster, Worker +from rlinf.scheduler import Channel, Cluster, Worker from rlinf.utils.data_iter_utils import get_iterator_k_split from rlinf.utils.distributed import all_reduce_dict +from rlinf.utils.distributed import ( + compute_rollout_metrics as compute_math_rollout_metrics, +) from rlinf.utils.metric_utils import ( append_to_dict, compute_loss_mask, compute_rollout_metrics, compute_split_num, ) -from rlinf.utils.placement import HybridComponentPlacement +from rlinf.utils.placement import ( + HybridComponentPlacement, + ModelParallelComponentPlacement, +) +from rlinf.utils.utils import ( + compute_logprobs_from_logits, + cpu_weight_swap, + masked_mean, + retrieve_model_state_dict_in_cpu, + seq_mean_token_mean, + seq_mean_token_sum, +) +from rlinf.workers.rollout.utils import RankMapper +import torch_npu + +class FSDPActor(FSDPModelManager, Worker): + def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): + Worker.__init__(self) + super().__init__(cfg.actor) + + self.cfg = cfg + + self.response_len = ( + cfg.actor.model.encoder_seq_length - cfg.data.max_prompt_length + ) + self.calculate_entropy = self.cfg.algorithm.calculate_entropy + self.calculate_entropy_loss = ( + self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy + ) + self.kl_beta = self.cfg.algorithm.kl_beta + self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type + + self.total_batch_size_per_dp = ( + self.cfg.data.rollout_batch_size + * self.cfg.algorithm.group_size + // self._world_size + ) + + torch.npu.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.npu.current_device() + world_size = self._world_size + self.device_mesh = init_device_mesh( + "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) + + self._rollout_group_name = cfg.rollout.group_name + self._component_placement = placement + self.is_data_io_rank = True + self.is_pipeline = self._component_placement.is_disaggregated + self.ref_policy_state_dict = None + 
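# --- Standalone illustrative sketch: loss aggregation modes ----------------------
# Assumed semantics of the three aggregation functions selected just below
# (token-mean, seq-mean-token-sum, seq-mean-token-mean). The real helpers live in
# rlinf.utils.utils; these definitions are illustrative stand-ins.
import torch


def token_mean(values, mask):
    return (values * mask).sum() / mask.sum().clamp(min=1)


def seq_mean_token_sum(values, mask):
    return (values * mask).sum(dim=-1).mean()


def seq_mean_token_mean(values, mask):
    return ((values * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)).mean()


loss = torch.tensor([[1.0, 1.0, 0.0], [2.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
# token-mean weights every valid token equally; the seq-mean variants weight each
# sequence equally regardless of how many response tokens it has.
assert abs(token_mean(loss, mask).item() - 4 / 3) < 1e-6
assert seq_mean_token_sum(loss, mask).item() == 2.0
assert seq_mean_token_mean(loss, mask).item() == 1.5
# ---------------------------------------------------------------------------------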
+ if self.cfg.algorithm.loss_agg_func == "token-mean": + self.loss_agg_func = masked_mean + elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-sum": + self.loss_agg_func = seq_mean_token_sum + elif self.cfg.algorithm.loss_agg_func == "seq-mean-token-mean": + self.loss_agg_func = seq_mean_token_mean + else: + raise NotImplementedError( + f"algorithm.loss_agg_func={self.cfg.algorithm.loss_agg_func} is not supported!" + ) + + def init_worker(self) -> None: + self.setup_model_and_optimizer() + if self.cfg.algorithm.kl_beta > 0 and self.cfg.actor.get( + "combine_reference_model", True + ): + self.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(self.model) + + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + self.offload_fsdp_optimizer() + self._setup_rollout_weight_dst_ranks() + + def _setup_rollout_weight_dst_ranks(self) -> None: + """Setup destination ranks for token and weight communication.""" + rank_map = RankMapper.get_actor_rank_to_rollout_rank_map( + self._component_placement + ) + self._weight_dst_rank_in_rollout = rank_map[self._rank] + self.log_info( + f"Actor rank {self._rank} will send weights to {self._weight_dst_rank_in_rollout}" + ) + + def del_reshard_state_dict(self) -> None: + if hasattr(self, "rollout_state_dict"): + del self.rollout_state_dict + + def sync_model_to_rollout(self) -> None: + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_optimizer() + + if next(self.model.parameters()).is_cpu: + self.load_fsdp_param_and_grad(self.device) + self.rollout_state_dict = self.get_model_state_dict() + + has_visual = any("visual." in k for k in self.rollout_state_dict.keys()) + + state_dict = {} + + if self._weight_dst_rank_in_rollout is not None: + for k, v in self.rollout_state_dict.items(): + name = k + if has_visual: + if name.startswith("model.language_model."): + name = "model." 
+ name[21:] + # NOTE: + # if transformers version is 4.56.1 or older(not tested), + # the following line should be uncommented + + # elif name.startswith("model."): + # name = name[6:] + state_dict[name] = reduce_tensor(v) + + self.send( + state_dict, self._rollout_group_name, self._weight_dst_rank_in_rollout + ) + + if self.cfg.actor.get("enable_offload", False): + self.offload_fsdp_param_and_grad() + + def compute_logprobs(self) -> None: + self.model.eval() + self.rollout_batch["logprob"] = self.rollout_batch["prev_logprobs"] + + def get_batch( + self, channel: Channel + ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: + result: RolloutResult = channel.get() + + batch = result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + return batch, result + + def put_result(self, result: RolloutResult, channel: Channel) -> None: + if channel.is_local: + # Local channel, every process will put its own data locally + # No need to broadcast + channel.put(result) + else: + if self.is_data_io_rank: + channel.put(result) + + def _load_weight_and_optimizer(self, channel: Channel) -> None: + # Acquire the GPUs to ensure that no one is using them before loading models + # Otherwise, it may lead to OOM + with channel.device_lock: + if self.cfg.actor.get("enable_offload", False): + self.load_fsdp_param_and_grad(self.device) + self.load_fsdp_optimizer(self.device) + + @torch.no_grad() + def inference_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + self.model.eval() + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] + + multi_modal_inputs = {} + if "multi_modal_inputs" in batch.keys(): + for key in batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in batch["multi_modal_inputs"]], + dim=0, + ).npu() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + **multi_modal_inputs, + ) + + logits = outputs.logits + logits = logits[:, -self.response_len - 1 : -1, :] + logits = logits / self.cfg.algorithm.sampling_params.temperature + + responses = input_ids[:, -self.response_len :] + logprobs = compute_logprobs_from_logits( + logits, responses, task_type=self.cfg.runner.task_type + ) + return logprobs + + def run_inference( + self, + input_channel: Channel, + output_channel: Channel, + rollout_channel: Channel, + compute_ref_logprobs: bool, + ) -> None: + """ + Compute prev/ref logprobs using the actor Model's forward. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + rollout_channel: get the rollout channel's device lock in case of collision. + compute_ref_logprobs: Whether to compute reference logprobs. 
+ """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + self._load_weight_and_optimizer( + input_channel if self.is_pipeline else rollout_channel + ) + num_splits = ( + rollout_result.num_sequence + // self.cfg.algorithm.logprob_forward_micro_batch_size + ) + micro_batches_iter = get_iterator_k_split( + batch, + num_splits=num_splits, + ) + micro_batches = list(micro_batches_iter) + + prev_logprobs = [] + with self.worker_timer(): + for micro_batch in micro_batches: + prev_logprobs.append(self.inference_step(micro_batch).cpu()) + rollout_result.prev_logprobs = torch.cat(prev_logprobs) + if compute_ref_logprobs: + assert self.ref_policy_state_dict is not None, ( + "Reference policy state dict is None but compute_ref_logprobs is True" + ) + ref_logprobs = [] + with cpu_weight_swap(self.model, self.ref_policy_state_dict): + for micro_batch in micro_batches: + ref_logprobs.append(self.inference_step(micro_batch).cpu()) + rollout_result.ref_logprobs = torch.cat(ref_logprobs) + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + def run_training(self, input_channel: Channel) -> Tuple[Dict, list]: + # Get all batches for this DP + batches = [] + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + batches.append(batch) + recv_batch_size += rollout_result.num_sequence + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + batch = RolloutResult.merge_batches(batches) + # Must be called after batch is retrieved, which is when rollout has stopped + # Otherwise, loading model might cause OOM + self._load_weight_and_optimizer(input_channel) + + global_batches = get_iterator_k_split( + batch, + num_splits=self.cfg.algorithm.n_minibatches, + shuffle=self.cfg.algorithm.get("shuffle_rollout", True), + shuffle_seed=self.cfg.actor.seed, + ) + + self.model.train() + assert ( + self.cfg.actor.global_batch_size + % (self.cfg.actor.micro_batch_size * self._world_size) + == 0 + ) + + training_metrics_list = [] + # Global batch iterations + with self.worker_timer(): + for global_batch in global_batches: + train_global_batch_size = global_batch["input_ids"].shape[0] + + assert train_global_batch_size % self.cfg.actor.micro_batch_size == 0, ( + f"{train_global_batch_size=}, {self.cfg.actor.micro_batch_size=}" + ) + + self.gradient_accumulation = ( + train_global_batch_size // self.cfg.actor.micro_batch_size + ) + # split batch into micro_batches + train_micro_batches = get_iterator_k_split( + global_batch, + train_global_batch_size // self.cfg.actor.micro_batch_size, + ) + + self.optimizer.zero_grad() + metrics = {} + for idx, m_batch in enumerate(train_micro_batches): + backward_ctx = ( + self.model.no_sync() + if idx < self.gradient_accumulation - 1 + else nullcontext() + ) + for k, v in m_batch.items(): + m_batch[k] = v.npu() if isinstance(v, torch.Tensor) else v + + multi_modal_inputs = {} + if "multi_modal_inputs" in m_batch.keys(): + for key in m_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [ + inputs[key] + for inputs in m_batch["multi_modal_inputs"] + ], + dim=0, + ).npu() + + input_ids = 
m_batch["input_ids"] + attention_mask = m_batch["attention_mask"] + position_ids = m_batch["position_ids"] + prev_logprobs = m_batch["prev_logprobs"] + advantages = m_batch["advantages"] + ref_logprobs = None + if "ref_logprobs" in m_batch: + ref_logprobs = m_batch["ref_logprobs"] + + loss_mask = m_batch["attention_mask"][:, -self.response_len :] + output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + ) + + logits = output.logits + + logits.div_(self.cfg.algorithm.sampling_params.temperature) + + responses = input_ids[:, -self.response_len :] + logits = logits[ + :, -self.response_len - 1 : -1, : + ] # (bsz, response_length, vocab_size) + logprobs = compute_logprobs_from_logits( + logits, responses, task_type=self.cfg.runner.task_type + ) + + clip_ratio = self.cfg.algorithm.ratio_clip_eps + clip_ratio_low = ( + self.cfg.algorithm.clip_ratio_low + if self.cfg.algorithm.clip_ratio_low is not None + else clip_ratio + ) + clip_ratio_high = ( + self.cfg.algorithm.clip_ratio_high + if self.cfg.algorithm.clip_ratio_high is not None + else clip_ratio + ) + clip_ratio_c = self.cfg.algorithm.get("clip_ratio_c", 3.0) + + loss, mbs_metrics_data = actor_loss( + loss_type=self.cfg.algorithm.loss_type, + loss_agg_func=self.loss_agg_func, + logprobs=logprobs, + old_logprobs=prev_logprobs, + advantages=advantages, + clip_ratio_low=clip_ratio_low, + clip_ratio_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_mask=loss_mask, + ) + + entropy_loss = torch.tensor(0.0, device=torch.npu.current_device()) + if self.calculate_entropy: + entropy = output["entropy"][ + :, -self.response_len - 1 : -1 + ].contiguous() + entropy_loss = self.loss_agg_func(entropy, mask=loss_mask) + if self.calculate_entropy_loss: + loss = ( + loss - self.cfg.algorithm.entropy_bonus * entropy_loss + ) + + kl_loss = torch.tensor(0.0, device=torch.npu.current_device()) + if self.kl_beta > 0 and ref_logprobs is not None: + kld = kl_penalty(ref_logprobs, logprobs, self.kl_penalty_type) + kl_loss = self.loss_agg_func(kld, loss_mask) + loss = loss + kl_loss * self.kl_beta + + # add to log + # scale loss for gradient accumulation and backprop + loss = loss / self.gradient_accumulation + with backward_ctx: + loss.backward() + + mbs_metrics_data.update( + { + "final_loss": loss.detach(), + "entropy_loss": entropy_loss.detach(), + "kl_loss": kl_loss.detach(), + } + ) + + append_to_dict(metrics, mbs_metrics_data) + # apply gradient clipping and optimizer step at the end of a global batch + grad_norm = self.model.clip_grad_norm_( + max_norm=self.cfg.actor.optim.clip_grad + ) + if not torch.isfinite(grad_norm).all(): + self.log_warning( + "grad norm is not finite, skip this optimizer step." 
+ ) + else: + self.optimizer.step() + self.optimizer.zero_grad() + + # aggregate metrics across micro-batches + mean_metric_dict = { + key: torch.mean(torch.stack(value)) + for key, value in metrics.items() + } + mean_metric_dict = all_reduce_dict( + mean_metric_dict, op=torch.distributed.ReduceOp.AVG + ) + # add optimizer stats + if torch.is_tensor(grad_norm): + mean_metric_dict["actor/grad_norm"] = float( + grad_norm.detach().item() + ) + else: + mean_metric_dict["actor/grad_norm"] = float(grad_norm) + lr = self.optimizer.param_groups[0]["lr"] + mean_metric_dict["actor/lr"] = torch.as_tensor(lr).float().cpu() + training_metrics_list.append(mean_metric_dict) + + # Rollout metrics + rollout_metrics, _, _ = compute_math_rollout_metrics( + batch, self.cfg.data.max_prompt_length, self.response_len, self._world_size + ) + + return rollout_metrics, training_metrics_list + + def save_checkpoint(self, save_base_path: str, step: int) -> None: + torch.distributed.barrier() + model_state = self.get_model_state_dict() + optim_state = self.get_optimizer_state_dict() + if self._rank == 0: + os.makedirs(save_base_path, exist_ok=True) + torch.save(model_state, os.path.join(save_base_path, "model.pt")) + torch.save(optim_state, os.path.join(save_base_path, "optim.pt")) + torch.distributed.barrier() + + # Advantages and returns + def compute_advantages_and_returns( + self, input_channel: Channel, output_channel: Channel + ) -> None: + """Compute the advantages and returns. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. + """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + batch, rollout_result = self.get_batch(input_channel) + recv_batch_size += rollout_result.num_sequence + + with self.worker_timer(): + if rollout_result.advantages is None: + mask = batch["attention_mask"][:, -self.response_len :] + advantages, returns = calculate_adv_and_returns( + adv_type=self.cfg.algorithm.adv_type, + reward_scores=batch["rewards"].npu(), + mask=mask.npu(), + num_responses=self.cfg.algorithm.group_size, + ) + rollout_result.advantages = advantages.cpu() + + self.put_result(rollout_result, output_channel) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) class EmbodiedFSDPActor(FSDPModelManager, Worker): @@ -47,11 +534,11 @@ def __init__(self, cfg: DictConfig): super().__init__(cfg.actor) self.cfg = cfg - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - self.device = torch.cuda.current_device() + torch.npu.set_device(int(os.environ["LOCAL_RANK"])) + self.device = torch.npu.current_device() world_size = self._world_size self.device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + "npu", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] ) self._env_group_name = cfg.env.group_name @@ -79,9 +566,9 @@ def init_worker(self): if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() - torch.cuda.synchronize() + torch.npu.synchronize() gc.collect() - torch.cuda.empty_cache() + torch.npu.empty_cache() def model_provider_func(self): model = get_model(self.cfg.actor.checkpoint_load_path, self.cfg.actor.model) @@ -102,10 +589,10 @@ def sync_model_to_rollout(self): if self.cfg.actor.get("enable_offload", False): self.offload_fsdp_param_and_grad() self.offload_fsdp_optimizer() - torch.cuda.synchronize() + 
torch.npu.synchronize() del state_dict gc.collect() - torch.cuda.empty_cache() + torch.npu.empty_cache() async def recv_rollout_batch(self): send_num = self._component_placement.get_world_size("rollout") * self.stage_num @@ -389,7 +876,7 @@ def run_training(self): metrics_data["loss"] = loss.detach().item() append_to_dict(metrics, metrics_data) - torch.cuda.empty_cache() + torch.npu.empty_cache() grad_norm = self.model.clip_grad_norm_( max_norm=self.cfg.actor.optim.clip_grad @@ -411,9 +898,9 @@ def run_training(self): ) self.optimizer.zero_grad() - torch.cuda.synchronize() + torch.npu.synchronize() torch.distributed.barrier() - torch.cuda.empty_cache() + torch.npu.empty_cache() return mean_metric_dict diff --git a/rlinf/workers/actor/megatron_actor_worker.py b/rlinf/workers/actor/megatron_actor_worker.py index 54376e1a5..40289ddf0 100644 --- a/rlinf/workers/actor/megatron_actor_worker.py +++ b/rlinf/workers/actor/megatron_actor_worker.py @@ -14,7 +14,7 @@ import copy from functools import partial -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.distributed @@ -31,7 +31,6 @@ from rlinf.algorithms.registry import ( actor_loss, calculate_adv_and_returns, - get_reward_fn, ) from rlinf.algorithms.utils import kl_penalty from rlinf.data.io_struct import ( @@ -53,7 +52,6 @@ ) from rlinf.utils.distributed import ( RolloutDataBalance, - broadcast_tensor_within_mp, broadcast_tensor_within_pp, compute_rollout_metrics, masked_normalization, @@ -80,7 +78,6 @@ seq_mean_token_sum, ) from rlinf.workers.rollout.utils import RankMapper -from toolkits import register_rewards class MegatronActor(MegatronModelManager, Worker): @@ -103,6 +100,10 @@ def __init__( self.cfg = cfg self.component_placement = placement + # check placement validity when actor backend is megatron + assert placement.rollout_tp_size <= placement.actor_tp_size, ( + f" rollout tensor parallel size {placement.rollout_tp_size} must be less than or equal to actor tensor parallel size {placement.actor_tp_size}." 
+ ) # Data configurations self.response_len = ( role_cfg.model.encoder_seq_length - cfg.data.max_prompt_length @@ -115,10 +116,21 @@ def __init__( self.calculate_entropy_loss = ( self.cfg.algorithm.entropy_bonus > 0 and self.calculate_entropy ) - self.ratio_eps = self.cfg.algorithm.ratio_clip_eps + clip_ratio = self.cfg.algorithm.ratio_clip_eps + self.clip_ratio_low = ( + self.cfg.algorithm.get("clip_ratio_low") + if self.cfg.algorithm.get("clip_ratio_low") is not None + else clip_ratio + ) + self.clip_ratio_high = ( + self.cfg.algorithm.get("clip_ratio_high") + if self.cfg.algorithm.get("clip_ratio_high") is not None + else clip_ratio + ) self.logprob_forward_micro_batch_size = ( self.cfg.algorithm.logprob_forward_micro_batch_size ) + self.kl_beta = self.cfg.algorithm.kl_beta self.kl_penalty_type = self.cfg.algorithm.kl_penalty_type self.clip_ratio_c = self.cfg.algorithm.clip_ratio_c @@ -144,11 +156,6 @@ def __init__( self.ref_policy_state_dict = None self.is_pipeline = self.component_placement.is_disaggregated - # Reward configurations - if not self.cfg.reward.use_reward_model: - register_rewards() - self.reward_fn = get_reward_fn(self.cfg.reward.reward_type) - # Rollout configurations self.rollout_group_name = self.cfg.rollout.group_name @@ -382,7 +389,8 @@ def loss_func(output): logprobs=curr_logprobs, old_logprobs=prev_logprobs, advantages=advantages, - eps_clip=self.ratio_eps, + clip_ratio_low=self.clip_ratio_low, + clip_ratio_high=self.clip_ratio_high, loss_mask=mask, ) @@ -830,23 +838,29 @@ def run_inference( self, input_channel: Channel, output_channel: Channel, + rollout_channel: Optional[Channel], compute_ref_logprobs: bool, ): - """Compute prev/ref logprobs using the actor Model's forward. + """ + Compute prev/ref logprobs using the actor Model's forward. Args: input_channel: The input channel to read from. output_channel: The output channel to send results to. + rollout_channel: get the rollout channel's device lock in case of collision. compute_ref_logprobs: Whether to compute reference logprobs. """ recv_batch_size = 0 while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence - # Must be called after batch is retrieved, suggesting that rollout has stopped # Otherwise, loading model might cause OOM in the collocated mode - self._load_weight_and_optimizer(input_channel) + self._load_weight_and_optimizer( + input_channel + if self.is_pipeline or rollout_channel is None + else rollout_channel + ) # Prev logprobs with self.worker_timer(): @@ -859,91 +873,12 @@ def run_inference( with cpu_weight_swap(self.model[0], self.ref_policy_state_dict): ref_logprobs = self.inference_step(batch) rollout_result.ref_logprobs = ref_logprobs.cpu() - - self.put_result(rollout_result, output_channel) - - assert recv_batch_size == self.total_batch_size_per_dp, ( - f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" - ) - - # Rewards - def compute_rewards(self, input_channel: Channel, output_channel: Channel): - """Compute rewards. - - Args: - input_channel: The input channel to read from. - output_channel: The output channel to send results to. 
- """ - assert self.reward_fn is not None, "reward_fn is not set" - if self.is_pipeline: - # In pipeline mode, rewards are computed in the rollout - with self.worker_timer(): - return - recv_batch_size = 0 - while recv_batch_size < self.total_batch_size_per_dp: - batch, rollout_result = self.get_batch(input_channel) - recv_batch_size += rollout_result.num_sequence - - # Compute rule-based reward - with self.worker_timer(): - if rollout_result.rewards is None: - rollout_result.rewards = self._compute_batch_rewards( - batch, rollout_result.answers - ).cpu() - self.put_result(rollout_result, output_channel) assert recv_batch_size == self.total_batch_size_per_dp, ( f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" ) - def _compute_batch_rewards( - self, batch: Dict[str, torch.Tensor], answers: List[str] - ): - """Reward computation using non-model based reward.""" - all_reward_scores = [] - texts = [] - for response, response_len in zip( - batch["input_ids"], - batch["response_lengths"], - ): - response = response[ - self.cfg.data.max_prompt_length : self.cfg.data.max_prompt_length - + response_len - ] - texts.append( - self.tokenizer.decode(response.tolist(), skip_special_tokens=True) - ) - - if torch.distributed.get_rank() == parallel_state.get_model_parallel_src_rank(): - rewards = self.reward_fn(texts, answers) - if self.cfg.reward.reward_type == "math": - reward_scores = [ - self.cfg.reward.reward_scale - if reward == 1 - else -self.cfg.reward.reward_scale - for reward in rewards - ] - else: - reward_scores = rewards - - all_reward_scores.extend(reward_scores) - - if len(all_reward_scores) > 0: - new_all_rewards = [] - - for response in all_reward_scores: - if response is None: - response = 0.0 - new_all_rewards.append(response) - - all_reward_scores = torch.as_tensor( - new_all_rewards, - dtype=torch.float, - device=torch.cuda.current_device(), - ).view(-1, 1) - return broadcast_tensor_within_mp(all_reward_scores).flatten().to("cpu") - # Advantages and returns def compute_advantages_and_returns( self, input_channel: Channel, output_channel: Channel @@ -954,16 +889,11 @@ def compute_advantages_and_returns( input_channel: The input channel to read from. output_channel: The output channel to send results to. """ - if self.is_pipeline: - # In pipeline mode, advantages are computed in the rollout - with self.worker_timer(): - return clear_memory() recv_batch_size = 0 while recv_batch_size < self.total_batch_size_per_dp: batch, rollout_result = self.get_batch(input_channel) recv_batch_size += rollout_result.num_sequence - with self.worker_timer(): if rollout_result.advantages is None: mask = batch["attention_mask"][:, -self.response_len :] @@ -1033,7 +963,11 @@ def sync_model_to_rollout(self): def _compute_rollout_metrics(self, batch): rollout_metrics, total_prompt_lengths, total_decode_lengths = ( compute_rollout_metrics( - batch, self.cfg.data.max_prompt_length, self.response_len + batch, + self.cfg.data.max_prompt_length, + self.response_len, + self._world_size, + dp_group=parallel_state.get_data_parallel_group(), ) ) diff --git a/rlinf/workers/reward/reward_worker.py b/rlinf/workers/reward/reward_worker.py new file mode 100644 index 000000000..88be65ddc --- /dev/null +++ b/rlinf/workers/reward/reward_worker.py @@ -0,0 +1,104 @@ +# Copyright 2025 The RLinf Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple + +import torch +from omegaconf import DictConfig + +from rlinf.algorithms.rewards import get_reward_class +from rlinf.data.io_struct import RolloutResult +from rlinf.data.tokenizers import hf_tokenizer +from rlinf.scheduler import Channel, Worker +from rlinf.utils.placement import ModelParallelComponentPlacement + + +class RewardWorker(Worker): + def __init__(self, cfg: DictConfig, placement: ModelParallelComponentPlacement): + Worker.__init__(self) + self.cfg = cfg + self.component_placement = placement + self.tokenizer = hf_tokenizer(cfg.reward.tokenizer.tokenizer_model) + self.total_batch_size_per_dp = ( + self.cfg.data.rollout_batch_size + * self.cfg.algorithm.get("group_size", 1) + // self._world_size + ) + + def init_worker(self): + if self.cfg.reward.use_reward_model: + raise NotImplementedError("Reward model is not implemented yet.") + else: + self.reward = get_reward_class(self.cfg.reward.reward_type)(self.cfg.reward) + + def get_batch( + self, channel: Channel + ) -> Tuple[Dict[str, torch.Tensor], RolloutResult]: + result: RolloutResult = channel.get() + batch = result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + return batch, result + + def compute_rewards(self, input_channel: Channel, output_channel: Channel): + """Compute rewards. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. 
+ """ + recv_batch_size = 0 + while recv_batch_size < self.total_batch_size_per_dp: + rollout_result: RolloutResult = input_channel.get() + recv_batch_size += rollout_result.num_sequence + with self.worker_timer(): + if rollout_result.rewards is None: + if self.cfg.reward.use_reward_model: + with input_channel.device_lock: + batch = rollout_result.to_actor_batch( + self.cfg.data.max_prompt_length, + self.cfg.actor.model.encoder_seq_length, + self.tokenizer.eos_token_id, + ) + rollout_result.rewards = ( + self.compute_batch_rewards_with_model(batch) + ) + else: + rollout_result.rewards = self._compute_rule_based_rewards( + rollout_result + ) + + output_channel.put(rollout_result) + + assert recv_batch_size == self.total_batch_size_per_dp, ( + f"Expected {self.total_batch_size_per_dp} sequences from channel, but got {recv_batch_size}" + ) + + def _compute_rule_based_rewards(self, rollout_result: RolloutResult): + # Decode only the generated tokens; response_ids are already the post-prompt tokens + texts = self.tokenizer.batch_decode( + rollout_result.response_ids, skip_special_tokens=True + ) + + scores = self.reward.get_reward(texts, rollout_result.answers) + return ( + torch.as_tensor(scores, dtype=torch.float, device=torch.device("cpu")) + .view(-1, 1) + .flatten() + ) + + def compute_batch_rewards_with_model(self, batch: Dict[str, torch.Tensor]): + raise NotImplementedError("Reward model is not implemented yet.") diff --git a/rlinf/workers/rollout/sglang/__init__.py b/rlinf/workers/rollout/sglang/__init__.py index 5e0fb219f..cc78162b9 100644 --- a/rlinf/workers/rollout/sglang/__init__.py +++ b/rlinf/workers/rollout/sglang/__init__.py @@ -49,6 +49,12 @@ def get_version(pkg): from rlinf.hybrid_engines.sglang.sglang_0_4_9.sgl_engine import ( Engine, ) +elif package_version >= parse("0.5.0") and package_version < parse("0.5.3"): + sglang_version = package_version + from rlinf.hybrid_engines.sglang.sglang_0_5_2 import io_struct + from rlinf.hybrid_engines.sglang.sglang_0_5_2.sgl_engine import ( + Engine, + ) else: raise ValueError(f"sglang version {package_version} not supported") diff --git a/rlinf/workers/rollout/sglang/sglang_worker.py b/rlinf/workers/rollout/sglang/sglang_worker.py index 8d4a15cb7..a24b297f2 100644 --- a/rlinf/workers/rollout/sglang/sglang_worker.py +++ b/rlinf/workers/rollout/sglang/sglang_worker.py @@ -18,7 +18,6 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import torch from omegaconf import DictConfig from sglang.srt.server_args import ServerArgs from transformers import AutoTokenizer @@ -35,7 +34,6 @@ from rlinf.workers.rollout.utils import ( print_sglang_outputs, ) -from toolkits.math_verifier.verify import MathRewardModel, math_verify_call class SGLangWorker(Worker): @@ -115,6 +113,7 @@ def _init_engine(self): log_level="info", max_running_requests=self._cfg.rollout.max_running_requests, dist_init_addr=f"127.0.0.1:{str(Cluster.find_free_port())}", + device="npu", ) self.log_on_first_rank(f"{server_args=}") @@ -169,7 +168,6 @@ def sync_model_from_actor(self): def rollout(self, input_channel: Channel, output_channel: Channel): request: RolloutRequest = input_channel.get() - # Repeat prompts based on the group_size config requests = request.repeat_and_split(self._rollout_batch_size) @@ -181,6 +179,8 @@ def rollout(self, input_channel: Channel, output_channel: Channel): with self.worker_timer(): results = self._engine.generate( input_ids=request.input_ids, + # 0.4.4 has modality bug,can't pass non-None image_data + image_data=request.image_data 
if any(request.image_data) else None, sampling_params=self._sampling_params, return_logprob=self._return_logprobs, ) @@ -191,6 +191,8 @@ def rollout(self, input_channel: Channel, output_channel: Channel): request.n, request.input_ids, request.answers, + request.image_data, + request.multi_modal_inputs, self._return_logprobs, ) rollout_results.append(rollout_result) @@ -205,8 +207,9 @@ def rollout(self, input_channel: Channel, output_channel: Channel): self._stop() # Release the GPUs once the engine has offloaded output_channel.device_lock.release() - rollout_result = RolloutResult.merge_result_list(rollout_results) - output_channel.put(rollout_result) + rollout_result_list = RolloutResult.split_result_list_by_group(rollout_results) + for rollout_result in rollout_result_list: + output_channel.put(rollout_result) def all_floats_equal(float_list: list[float], epsilon: float = 1e-9) -> bool: @@ -229,7 +232,6 @@ def __init__(self, config: DictConfig, placement: ComponentPlacement): self._rollout_end_event = asyncio.Event() self._sync_weight_end_event = asyncio.Event() - self._reward_model = MathRewardModel(scale=self._cfg.reward.reward_scale) assert self._rollout_batch_size is None, ( "rollout_batch_size_per_gpu is not supported in AsyncSGLangWorker" ) @@ -258,29 +260,6 @@ async def init_worker(self): if self._cfg.rollout.validate_weight: await self._validate_weight_at_first() - async def _compute_reward_and_advantage( - self, engine_results: List[Dict], answer: str - ): - answers = [answer] * len(engine_results) - texts: List[str] = [] - for res in engine_results: - if hasattr(res, "text"): - texts.append(res["text"]) - else: - texts.append( - self._tokenizer.decode(res["output_ids"], skip_special_tokens=True) - ) - - results = math_verify_call(texts, answers) - rewards = [(1 if r else -1) * self._reward_model.scale for r in results] - rewards_tensor = torch.tensor(rewards, dtype=torch.float) - - mean = rewards_tensor.mean() - std = rewards_tensor.std() - advantages = (rewards_tensor - mean) / (std + 1e-6) - - return rewards, advantages.tolist() - async def _async_generate( self, raw_id: int, input_ids: List[int], sampling_params: dict ): @@ -319,7 +298,6 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): total_reqs = len(rollout_tasks) required_reqs = total_reqs // self._cfg.algorithm.max_num_gen_batches - droped_reqs = 0 finished_reqs = 0 abort_flag = False @@ -330,32 +308,17 @@ async def rollout(self, input_channel: Channel, output_channel: Channel): if self._completion_info.is_completed(hash_id): results = self._completion_info.get_results(hash_id) - ( - rewards, - advantages, - ) = await self._compute_reward_and_advantage( - results, - self._current_request.answers[raw_id], - ) - if ( - all_floats_equal(rewards) - and self._cfg.algorithm.get("max_num_gen_batches", 1) > 1 - ): - if (total_reqs - droped_reqs) > required_reqs: - droped_reqs += rollout_request.n - continue input_ids = [input_ids] * len(results) + answers = [rollout_request.answers[raw_id]] * len(results) rollout_result = RolloutResult.from_sglang_results( results, rollout_request.n, input_ids, + answers=answers, return_logprobs=self._return_logprobs, ) - rollout_result.rewards = torch.tensor( - rewards, dtype=torch.float32 - ).reshape(-1, 1) - rollout_result.advantages = advantages + return_tasks.append( asyncio.create_task( self._put_result(rollout_result, output_channel) diff --git a/rlinf/workers/rollout/utils.py b/rlinf/workers/rollout/utils.py index f3845ef2f..f92e48caa 100644 --- 
a/rlinf/workers/rollout/utils.py +++ b/rlinf/workers/rollout/utils.py @@ -376,6 +376,12 @@ def get_actor_rank_to_rollout_rank_map( """ Get the global mapping from actor 1D rank to rollout 2D rank as dict. """ + # rank -> (dp, tp) + if actor_tp_size == 1: + return { + rank: (rank // rollout_tp_size, rank % rollout_tp_size) + for rank in range(actor_world_size) + } rank_map = {} for actor_rank in range(actor_world_size): rank_map[actor_rank] = cls._get_actor_rank_to_rollout_rank( diff --git a/rlinf/workers/rollout/vllm/vllm_worker.py b/rlinf/workers/rollout/vllm/vllm_worker.py index d5ecd11fc..7324fe841 100644 --- a/rlinf/workers/rollout/vllm/vllm_worker.py +++ b/rlinf/workers/rollout/vllm/vllm_worker.py @@ -19,9 +19,8 @@ from typing import AsyncGenerator, List, Optional, Union import requests -import torch from omegaconf import DictConfig -from PIL.Image import Image +from PIL import Image from transformers import AutoTokenizer from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -36,7 +35,6 @@ from rlinf.scheduler import Channel, Worker from rlinf.utils.placement import ComponentPlacement from rlinf.workers.rollout.utils import print_vllm_outputs -from toolkits.math_verifier.verify import MathRewardModel, math_verify_call from . import VLLMExecutor @@ -69,7 +67,6 @@ def __init__(self, config: DictConfig, placement: ComponentPlacement): "The capital of France is", "The future of AI is", ] - self._reward_model = MathRewardModel(self._cfg.reward.reward_scale) self.request_counter = Counter() def _prepare_vllm_environment(self) -> None: @@ -227,6 +224,23 @@ async def generate( Union[List[List[Union[bytes, str]]], List[Union[bytes, str]]] ] = None, ) -> List[RequestOutput]: + """ + Do Generate Task using the vllm async engine. + + Args: + input_ids: The input token ids to generate. It can be a list of list of int, + or a list of int (single prompt). + sampling_params: The sampling parameters to use for generation. + prompt_texts: The input prompt texts to generate. It can be a list of strings + or a single string. If provided, it will be used instead of input_ids. + image_data: The input multi-modal data to generate. It can be a list of list + of bytes or image paths (local or URL), or a list of bytes or image paths + (single prompt). + + Returns: + List[RequestOutput]: A list of RequestOutput from vllm engine. + """ + def check_input_ids() -> List[List[int]]: assert isinstance(input_ids, list), ( "input_ids should be a list or list of list of int." @@ -249,8 +263,8 @@ def check_prompt_text() -> Optional[List[str]]: assert len(prompt_texts) > 0, "prompt_text should not be empty." return prompt_texts - def check_image_data() -> Optional[List[List[Image]]]: - if image_data is None: + def check_image_data() -> Optional[List[List[Image.Image]]]: + if image_data is None or not any(image_data): return None assert isinstance(image_data, list), "image_data should be a list." 
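# --- Illustrative sketch (not part of the patch) -------------------------------
# The check_image_data rewrite above lets requests whose image lists are all
# empty fall back to text-only prompts instead of erroring. A standalone
# version of that normalization, with a hypothetical helper name and shapes
# assumed from the surrounding code:
from typing import List, Optional, Union

def normalize_image_data(
    image_data: Optional[Union[List[str], List[List[str]]]],
) -> Optional[List[List[str]]]:
    # None or an all-empty container means "no multi-modal input".
    if image_data is None or not any(image_data):
        return None
    assert isinstance(image_data, list), "image_data should be a list."
    # A flat list holds a single prompt's images; wrap it to keep the
    # per-prompt nesting that the engine expects.
    if not isinstance(image_data[0], list):
        return [image_data]
    return image_data

print(normalize_image_data(["a.png", "b.png"]))  # [['a.png', 'b.png']]
print(normalize_image_data([[], []]))            # None
# --------------------------------------------------------------------------------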
if isinstance(image_data[0], list): @@ -267,19 +281,22 @@ def check_image_data() -> Optional[List[List[Image]]]: if prompt_texts is not None: for i, prompt_text in enumerate(prompt_texts): if image_list is not None: - image_list = self._process_image_data(image_data=image_list[i]) + images = self._process_image_data(image_data=image_list[i]) inputs.append( - TextPrompt(prompt=prompt_text, multi_modal_data=image_list) + TextPrompt( + prompt=prompt_text, multi_modal_data={"image": images} + ) ) else: inputs.append(TextPrompt(prompt=prompt_text)) else: for i, input_id in enumerate(input_ids): if image_list is not None: - image_list = self._process_image_data(image_data=image_list[i]) + images = self._process_image_data(image_data=image_list[i]) inputs.append( TokensPrompt( - prompt_token_ids=input_id, multi_modal_data=image_list + prompt_token_ids=input_id, + multi_modal_data={"image": images}, ) ) else: @@ -303,7 +320,8 @@ def check_image_data() -> Optional[List[List[Image]]]: async def init_worker(self) -> None: """ Use EngineArgs and VllmConfig to initialize VLLM async engine. - Then offload the model weights, ready to use weights sent from actor. + If mode is collocated, it will additionally offload model weights, + ready to use parameters sent from actor. """ engine_args: EngineArgs = EngineArgs( model=self._cfg.rollout.model_dir, @@ -319,7 +337,8 @@ async def init_worker(self) -> None: trust_remote_code=self._cfg.actor.tokenizer.trust_remote_code, max_model_len=self._cfg.runner.seq_length, max_num_seqs=self._cfg.rollout.max_running_requests, - enable_sleep_mode=True, # it enables offload weights + enable_sleep_mode=False, + device="npu", ) vllm_config: VllmConfig = engine_args.create_engine_config() @@ -350,7 +369,21 @@ async def init_worker(self) -> None: await self.offload_model_weights() async def _put_result(self, result: RolloutResult, output_channel: Channel) -> None: - await output_channel.put(result, async_op=True).async_wait() + """ + Helper function to put the result to output channel. + + Args: + result: The RolloutResult to put to the channel. + output_channel: The output channel to send results to. + """ + # NOTE: + # To fit reward worker and actor workers' expected input count, + # currently we can only split result into groups. + splited_results = RolloutResult.split_result_list_by_group([result]) + put_tasks = [ + output_channel.put(r, async_op=True).async_wait() for r in splited_results + ] + await asyncio.gather(*put_tasks) async def _stop(self) -> None: """ @@ -363,50 +396,43 @@ async def _stop(self) -> None: if not self._placement.is_disaggregated: await self.offload_model_weights() - async def _compute_reward_and_advantage(self, rollout_result: RolloutResult): - """ - Compute rewards and advantages for the rollout result using math verification. 
- """ - answers = rollout_result.answers - outputs = rollout_result.response_texts - num_sequence = rollout_result.num_sequence - assert len(answers) == len(outputs), ( - f"Answers length {len(answers)} != outputs length {len(outputs)}" - ) - assert len(answers) == num_sequence, ( - f"Answers length {len(answers)} != num_sequence {num_sequence}" - ) - - math_verify_results = math_verify_call(outputs, answers) - rewards = [ - (1 if r else -1) * self._reward_model.scale for r in math_verify_results - ] - rewards_tensor = torch.tensor(rewards, dtype=torch.float) - rollout_result.rewards = rewards_tensor.reshape(-1, 1) - - mean = rewards_tensor.mean() - std = rewards_tensor.std(unbiased=False) - advantages = (rewards_tensor - mean) / (std + 1e-6) - rollout_result.advantages = advantages.tolist() - async def rollout_and_return( self, request: RolloutRequest, output_channel: Channel ): + """ + Helper function to rollout for a single RolloutRequest and build RolloutResult then + put it to output channel. + + Args: + request: The RolloutRequest to process. + output_channel: The output channel to send results to. + """ vllm_results: List[RequestOutput] = await self.generate( - input_ids=request.input_ids, sampling_params=self._sampling_params + input_ids=request.input_ids, + image_data=request.image_data, + sampling_params=self._sampling_params, ) rollout_result: RolloutResult = RolloutResult.from_vllm_results( group_size=self._cfg.algorithm.group_size, results=vllm_results, answers=request.answers, + multi_modal_inputs=request.multi_modal_inputs, return_logprobs=self._return_logprobs, ) - if self._placement.is_disaggregated: - await self._compute_reward_and_advantage(rollout_result) - + if self._cfg.rollout.print_outputs: + print_vllm_outputs(outputs=vllm_results) await self._put_result(result=rollout_result, output_channel=output_channel) async def rollout(self, input_channel: Channel, output_channel: Channel) -> None: + """ + Perform rollout using vllm engine. + It will read `RolloutRequest` from input_channel and put `RolloutResult` to output_channel. + If the input request is None, it will stop the rollout. + + Args: + input_channel: The input channel to read from. + output_channel: The output channel to send results to. 
+ """ rollout_request: RolloutRequest = await input_channel.get( async_op=True ).async_wait() diff --git a/test.py b/test.py new file mode 100644 index 000000000..485f51f99 --- /dev/null +++ b/test.py @@ -0,0 +1,3 @@ +from safetensors import safe_open +with safe_open("/home/weight/Qwen2.5-VL-3B-Instruct/model-00001-of-00002.safetensors", framework="pt") as f: + print(f.keys()) diff --git a/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml b/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml index 6555cb9bf..d1c65161b 100644 --- a/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml +++ b/tests/e2e_tests/auto_placement/qwen2.5-1.5b-grpo.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: ${runner.output_dir}/${runner.experiment_name} project_name: rlinf @@ -119,6 +119,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 1024 filter_prompt_by_length: True rollout_batch_size: 128 @@ -256,13 +257,20 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} critic: use_critic_model: false + profile_data: actor_cost: 95.7 inference_cost: 30.8 diff --git a/tests/e2e_tests/auto_placement/run_auto_placement.sh b/tests/e2e_tests/auto_placement/run.sh similarity index 100% rename from tests/e2e_tests/auto_placement/run_auto_placement.sh rename to tests/e2e_tests/auto_placement/run.sh diff --git a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml index 1de266648..f34ee4cae 100644 --- a/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml +++ b/tests/e2e_tests/coding_online_rl/qwen2.5-1.5b-ppo.yaml @@ -12,6 +12,7 @@ cluster: rollout: 0-3 inference: 4-5 actor: 6-7 + reward: 0-3 runner: task_type: coding_online_rl @@ -282,7 +283,7 @@ actor: reward: use_reward_model: False - reward_type: fim_verify_call + reward_type: code reward_scale: 5.0 critic: diff --git a/tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh b/tests/e2e_tests/coding_online_rl/run.sh similarity index 100% rename from tests/e2e_tests/coding_online_rl/run_coding_online_rl.sh rename to tests/e2e_tests/coding_online_rl/run.sh diff --git a/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml index 2365b5a86..2672dcd01 100644 --- a/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml +++ b/tests/e2e_tests/embodied/libero_130_grpo_openvlaoft.yaml @@ -157,6 +157,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: use_reward_model: False diff --git a/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml index ddfe4a500..6dc7893d3 100644 --- a/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml +++ b/tests/e2e_tests/embodied/libero_goal_grpo_openvlaoft.yaml @@ -156,6 +156,12 @@ actor: trust_remote_code: True padding_side: "right" + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git 
a/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml b/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml index ab384947a..f26ce7c3d 100644 --- a/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml +++ b/tests/e2e_tests/embodied/maniskill_grpo_openvlaoft.yaml @@ -155,6 +155,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 10.0 + fsdp: + forward_prefetch: False + limit_all_gathers: False + backward_prefetch: False + use_orig_params: False + reward: use_reward_model: False diff --git a/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml b/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml index 2063ab483..5d9258743 100644 --- a/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml +++ b/tests/e2e_tests/embodied/maniskill_ppo_openvla.yaml @@ -165,6 +165,12 @@ actor: adam_eps: 1.0e-05 clip_grad: 1.0 + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + reward: use_reward_model: False diff --git a/tests/e2e_tests/math/sglang/run_collocated.sh b/tests/e2e_tests/math/sglang/run_collocated.sh deleted file mode 100644 index 1610f8fac..000000000 --- a/tests/e2e_tests/math/sglang/run_collocated.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-collocated" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/run_pipeline.sh b/tests/e2e_tests/math/sglang/run_pipeline.sh deleted file mode 100644 index 85e2e5c2d..000000000 --- a/tests/e2e_tests/math/sglang/run_pipeline.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-pipeline" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/sglang --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/run_collocated.sh b/tests/e2e_tests/math/vllm/run_collocated.sh deleted file mode 100644 index b4e924b1d..000000000 --- a/tests/e2e_tests/math/vllm/run_collocated.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! /bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-collocated" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/run_pipeline.sh b/tests/e2e_tests/math/vllm/run_pipeline.sh deleted file mode 100644 index 0a21368f6..000000000 --- a/tests/e2e_tests/math/vllm/run_pipeline.sh +++ /dev/null @@ -1,17 +0,0 @@ -#! 
/bin/bash -set -x - -tabs 4 -export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM=false - -export PYTHONPATH=${REPO_PATH}:$PYTHONPATH - -if [ -z "$1" ]; then - CONFIG_NAME="qwen2.5-1.5b-grpo-pipeline" -else - CONFIG_NAME=$1 -fi - -python ${REPO_PATH}/examples/math/main_math.py --config-path $REPO_PATH/tests/e2e_tests/math/vllm --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml new file mode 100644 index 000000000..86c6f362c --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl-rollout-logprobs.yaml @@ -0,0 +1,219 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. 
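# --- Illustrative sketch (not part of the patch) -------------------------------
# These reasoning configs rely on custom OmegaConf resolvers: ${subtract:...}
# for sampling_params.max_new_tokens and ${not:...} for rollout.return_logprobs.
# The registrations below are assumptions for illustration only; RLinf registers
# its own resolvers elsewhere in the codebase.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("subtract", lambda a, b: int(a) - int(b))
OmegaConf.register_new_resolver("not", lambda x: not bool(x))

cfg = OmegaConf.create({
    "runner": {"seq_length": 1024},
    "data": {"max_prompt_length": 256},
    "algorithm": {
        "recompute_logprobs": True,
        "sampling_params": {
            "max_new_tokens": "${subtract:${runner.seq_length}, ${data.max_prompt_length}}",
        },
    },
    "rollout": {"return_logprobs": "${not:${algorithm.recompute_logprobs}}"},
})
print(cfg.algorithm.sampling_params.max_new_tokens)  # 768 = 1024 - 256
print(cfg.rollout.return_logprobs)                   # False
# --------------------------------------------------------------------------------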
+ + rollout_backend: sglang # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + 
+critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml new file mode 100644 index 000000000..dbf7925f2 --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-sgl.yaml @@ -0,0 +1,219 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: True + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: sglang # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. 
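# --- Illustrative sketch (not part of the patch) -------------------------------
# The fsdp: blocks added to the actor sections of these reasoning configs (and
# to the embodied configs earlier in this patch) expose standard torch FSDP
# wrapper options. One plausible way to pass them through; the bool-to-enum
# mapping for backward_prefetch is an assumption, not necessarily how
# FSDPModelManager consumes these flags.
import torch
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_actor_with_fsdp(module: torch.nn.Module, fsdp_cfg: dict) -> FSDP:
    # Map the YAML booleans onto FSDP constructor arguments.
    return FSDP(
        module,
        forward_prefetch=fsdp_cfg.get("forward_prefetch", False),
        limit_all_gathers=fsdp_cfg.get("limit_all_gathers", False),
        backward_prefetch=(
            BackwardPrefetch.BACKWARD_PRE
            if fsdp_cfg.get("backward_prefetch", False)
            else None
        ),
        use_orig_params=fsdp_cfg.get("use_orig_params", False),
    )
# --------------------------------------------------------------------------------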
+ enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml new file mode 100644 index 000000000..e85fd4146 --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm-rollout-logprobs.yaml @@ -0,0 +1,219 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . 
+ output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: vllm # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. 
If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml new file mode 100644 index 000000000..1aab4cb7b --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-fsdp-vllm.yaml @@ -0,0 +1,219 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . 
+ output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: /workspace/results/ + project_name: rlinf + experiment_name: "ci-test" + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 3 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 1024 + + enable_dynamic_batch_size: True + max_tokens_per_mbs: 1024 + + resume_dir: null + experiment_name: grpo-1.5b + output_dir: /workspace/results +algorithm: + group_size: 2 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: True + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + model_arch: qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: vllm # [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. + validate_save_dir: null # the directory to save the weights for comparison. 
If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: math + dataset_name: boba + max_prompt_length: 256 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + shuffle: True + validation_shuffle: True + seed: 1 + train_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + val_data_paths: ["/workspace/dataset/boba_106k_0319_prompt_1024.jsonl"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + + optim: + optimizer: adam + bf16: True + fp16: False + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'math' + reward_scale: 5.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml similarity index 93% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml index 7e2c5164a..4edc14979 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl-rollout-logprobs.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -84,11 +84,11 @@ rollout: model_dir: /workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B model_arch: qwen2.5 - enforce_eager: False # if False, vllm will capture cuda 
graph, which will take more time to initialize. + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. distributed_executor_backend: mp # ray or mp disable_log_stats: False detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. - padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for vllm rollout + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine eos: null # will be tokenizer.eos_token_id if null. rollout_backend: sglang # [sglang, vllm] @@ -114,13 +114,14 @@ rollout: validate_weight: False # whether to send all weights at first for weight comparison. validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. - print_outputs: False # whether to print the outputs (token ids, texts, etc.) of inference engine. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. - max_running_requests: 64 # the maximum number of running requests in the inference engine. + max_running_requests: 64 # the maximum number of running requests in the rollout engine. cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -259,9 +260,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml similarity index 95% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml index 1dfe47aeb..1854fd8c3 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-collocated.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-sgl.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -107,7 +107,6 @@ rollout: max_num_batched_tokens: null # the maximum number of tokens to be batched together in vllm. If set to null, vllm will use its default value. torch_profiler_dir: null # if not null, vllm will enable torch profiler and save the result to the specified directory. - return_logprobs: ${not:${algorithm.recompute_logprobs}} tensor_parallel_size: 1 @@ -117,15 +116,12 @@ rollout: validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. - sglang_decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. 
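Editor's note: the padding/eos comments in these rollout sections state that a null value falls back to the tokenizer's own pad/eos token ids. A hedged sketch of that fallback, assuming a Hugging Face tokenizer loaded from the configured model directory (the helper name is illustrative, not an RLinf API):

from transformers import AutoTokenizer

# Hedged sketch of the fallback the padding/eos comments describe: null in the
# config means "use the tokenizer's pad/eos token ids". The helper and the
# checkpoint path are illustrative, not RLinf internals.
def resolve_special_token_ids(model_dir, padding=None, eos=None):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    pad_id = padding if padding is not None else tokenizer.pad_token_id
    eos_id = eos if eos is not None else tokenizer.eos_token_id
    return pad_id, eos_id

# e.g. with the checkpoint used by these test configs
pad_id, eos_id = resolve_special_token_ids("/workspace/dataset/DeepSeek-R1-Distill-Qwen-1.5B")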
max_running_requests: 64 # the maximum number of running requests in the rollout engine. cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. - use_torch_compile: False # enable torch_compile in SGLang for rollout. - torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. - data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -264,9 +260,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml similarity index 96% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml index fe61ab16c..edeaee9c3 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm-rollout-logprobs.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -121,6 +121,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -259,9 +260,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml similarity index 96% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml index 099fe7268..09df84ca8 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-collocated.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-collocated-mg-vllm.yaml @@ -9,10 +9,10 @@ hydra: cluster: num_nodes: 1 component_placement: - actor,rollout: all + actor,rollout,reward: all runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -122,6 +122,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 filter_prompt_by_length: True rollout_batch_size: 8 @@ -260,9 +261,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at 
end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml similarity index 96% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml index 25344d0bf..34bcff492 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl-rollout-logprobs.yaml @@ -11,9 +11,10 @@ cluster: component_placement: rollout: 0-3 actor: 4-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -124,6 +125,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 rollout_batch_size: 8 val_rollout_batch_size: null @@ -257,11 +259,17 @@ actor: schedule_repeat: 1 # inference and training will repeat such times # schedule_wait: it will be set at runtime - reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml similarity index 96% rename from tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml index d5cd2b4a4..b48eb6057 100644 --- a/tests/e2e_tests/math/sglang/qwen2.5-1.5b-grpo-pipeline.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-sgl.yaml @@ -12,9 +12,10 @@ cluster: rollout: 0-3 inference: 4-5 actor: 6-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -138,6 +139,7 @@ rollout: data: type: math + dataset_name: boba max_prompt_length: 256 rollout_batch_size: 8 val_rollout_batch_size: null @@ -273,9 +275,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml similarity index 96% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml index 62d0c5247..111969ff4 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline-rollout-logprobs.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm-rollout-logprobs.yaml @@ -12,9 +12,10 @@ cluster: component_placement: rollout: 0-3 actor: 4-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -66,7 +67,7 @@ algorithm: calculate_entropy: True clip_ratio_c: null # 3.0 - adv_type: grpo + adv_type: math_grpo 
normalize_advantages: False early_stop_imp_ratio: 5.0 use_valid_token_scale: True @@ -260,9 +261,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml similarity index 96% rename from tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml rename to tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml index 3f7821587..705757c31 100644 --- a/tests/e2e_tests/math/vllm/qwen2.5-1.5b-grpo-pipeline.yaml +++ b/tests/e2e_tests/reasoning/qwen2.5-1.5b-grpo-pipeline-mg-vllm.yaml @@ -13,9 +13,10 @@ cluster: rollout: 0-3 inference: 4-5 actor: 6-7 + reward: 0-3 runner: - task_type: math + task_type: reasoning logger: log_path: /workspace/results/ project_name: rlinf @@ -67,7 +68,7 @@ algorithm: calculate_entropy: True clip_ratio_c: null # 3.0 - adv_type: grpo + adv_type: math_grpo normalize_advantages: False early_stop_imp_ratio: 5.0 use_valid_token_scale: True @@ -274,9 +275,16 @@ actor: reward: + group_name: "RewardGroup" use_reward_model: false reward_type: 'math' reward_scale: 5.0 + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + critic: use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml new file mode 100644 index 000000000..ddc087f99 --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-sgl.yaml @@ -0,0 +1,228 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . + output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 2 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-qwen2.5-vl-3b + output_dir: /workspace/results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. 
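Editor's note: the cluster.component_placement entries in these configs either collocate components on every GPU ("actor,rollout,reward: all") or pin them to index ranges ("rollout: 0-3", "actor: 4-7", "reward: 0-3"). A hedged sketch of how such entries can be expanded into per-component GPU lists; the helper is illustrative and is not how RLinf parses placements internally:

# Expand component_placement entries into per-component GPU index lists.
# Keys may name several components at once ("actor,rollout,reward") and
# values are "all", a single index, or a range like "0-3" (assumed syntax).
def parse_component_placement(placement, num_gpus):
    gpu_map = {}
    for key, value in placement.items():
        value = str(value)
        if value == "all":
            gpus = list(range(num_gpus))
        elif "-" in value:
            start, end = (int(x) for x in value.split("-"))
            gpus = list(range(start, end + 1))
        else:
            gpus = [int(value)]
        for component in key.split(","):
            gpu_map[component.strip()] = gpus
    return gpu_map

# e.g. the pipeline-mode placement used in the configs above
print(parse_component_placement({"rollout": "0-3", "actor": "4-7", "reward": "0-3"}, 8))
# {'rollout': [0, 1, 2, 3], 'actor': [4, 5, 6, 7], 'reward': [0, 1, 2, 3]}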
+
+  # mbs to do log prob inference, can be set to
+  # lower than rollout_batch_size_per_gpu to reduce
+  # memory usage
+  logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu}
+
+  # val rollout mbs
+  val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu}
+
+  recompute_logprobs: False
+  shuffle_rollout: False
+
+  # GRPO loss params
+  loss_type: math_ppo_actor
+  loss_agg_func: "token-mean"
+  kl_beta: 0.0 # 0.001
+  kl_penalty_type: low_var_kl
+  ratio_clip_eps: 0.2
+  entropy_bonus: 0.0
+  calculate_entropy: False
+  clip_ratio_c: null # 3.0
+  clip_ratio_low: null
+  clip_ratio_high: null
+
+  adv_type: math_grpo
+  normalize_advantages: True
+  early_stop_imp_ratio: 5.0
+  use_valid_token_scale: False
+
+  # params for rollout
+  sampling_params:
+    use_greedy: False
+    temperature: 1.0
+    top_k: 1000000
+    top_p: 1.0
+    repetition_penalty: 1.0
+    max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
+    min_new_tokens: 1
+
+rollout:
+  group_name: "RolloutGroup"
+
+  gpu_memory_utilization: 0.55
+
+  model_dir: /workspace/dataset/Qwen2.5-VL-3B-Instruct
+  model_arch: qwen2.5_vl #qwen2.5
+  enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize.
+  distributed_executor_backend: mp # ray or mp
+  disable_log_stats: False
+  detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging.
+  padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine
+  eos: null # will be tokenizer.eos_token_id if null.
+
+  rollout_backend: sglang # which backend to use for rollout; supported: [sglang, vllm]
+
+  sglang:
+    attention_backend: triton # [flashinfer, triton] for more, see sglang's doc
+    decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats.
+    use_torch_compile: False # enable torch_compile in SGLang for rollout.
+    torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used.
+
+  vllm:
+    attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc
+    enable_chunked_prefill: True # enable vllm to use chunked_prefill.
+    enable_prefix_caching: True # enable vllm to use prefix_caching.
+    enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling.
+
+  return_logprobs: ${not:${algorithm.recompute_logprobs}}
+
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+
+  validate_weight: False # whether to send all weights at first for weight comparison.
+  validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison.
+  print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine.
+
+  max_running_requests: 64 # the maximum number of running requests in the rollout engine.
+  cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used.
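Editor's note: sampling_params.max_new_tokens and rollout.return_logprobs above rely on ${subtract:...} and ${not:...} interpolations. OmegaConf does not ship these resolvers, so the entry point has to register them before the config is resolved; the snippet below is a minimal, self-contained illustration, and where RLinf performs the registration is an assumption here.

from omegaconf import OmegaConf

# Register the custom resolvers used by the configs above (assumed to be done
# once at startup before any config access).
OmegaConf.register_new_resolver("subtract", lambda a, b: int(a) - int(b))
OmegaConf.register_new_resolver("not", lambda x: not x)

yaml_snippet = """
runner:
  seq_length: 2048
data:
  max_prompt_length: 1024
algorithm:
  recompute_logprobs: False
  sampling_params:
    max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}}
rollout:
  return_logprobs: ${not:${algorithm.recompute_logprobs}}
"""

cfg = OmegaConf.create(yaml_snippet)
print(cfg.algorithm.sampling_params.max_new_tokens)  # 1024
print(cfg.rollout.return_logprobs)                   # True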
+ +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/workspace/dataset/robo2vlm-1/data/train-00000-of-00262.parquet"] + val_data_paths: ["/workspace/dataset/robo2vlm-1/data/test-00000-of-00003.parquet"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml new file mode 100644 index 000000000..86c64ebe0 --- /dev/null +++ b/tests/e2e_tests/reasoning/qwen2.5-vl-3b-grpo-collocated-fsdp-vllm.yaml @@ -0,0 +1,228 @@ +defaults: + - override hydra/job_logging: stdout + +hydra: + run: + dir: . 
+ output_subdir: null + +cluster: + num_nodes: 1 + component_placement: + actor,rollout,reward: all + +runner: + task_type: reasoning + logger: + log_path: ${runner.output_dir}/${runner.experiment_name} + project_name: rlinf + experiment_name: ${runner.experiment_name} + logger_backends: ["tensorboard"] # wandb, swanlab + + max_epochs: 1 + max_steps: 2 + + val_check_interval: 1 + save_interval: -1 + + seq_length: 2048 + + enable_dynamic_batch_size: False + max_tokens_per_mbs: 28672 + + resume_dir: null + experiment_name: grpo-qwen2.5-vl-3b + output_dir: /workspace/results + +algorithm: + group_size: 8 + + n_minibatches: 4 + training_batch_size_per_gpu: 1 # micro batch size + rollout_batch_size_per_gpu: null # If set to null, rollout_batch_size will be evenly divided across all inference instances. You can reduce this parameter if inference consumes too much GPU memory. + + # mbs to do log prob inference, can be set to + # lower than rollout_batch_size_per_gpu to reduce + # memory usage + logprob_forward_micro_batch_size: 1 # ${.rollout_batch_size_per_gpu} + + # val rollout mbs + val_rollout_batch_size_per_gpu: 4 # ${.rollout_batch_size_per_gpu} + + recompute_logprobs: False + shuffle_rollout: False + + # GRPO loss params + loss_type: math_ppo_actor + loss_agg_func: "token-mean" + kl_beta: 0.0 # 0.001 + kl_penalty_type: low_var_kl + ratio_clip_eps: 0.2 + entropy_bonus: 0.0 + calculate_entropy: False + clip_ratio_c: null # 3.0 + clip_ratio_low: null + clip_ratio_high: null + + adv_type: math_grpo + normalize_advantages: True + early_stop_imp_ratio: 5.0 + use_valid_token_scale: False + + # params for rollout + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 1000000 + top_p: 1.0 + repetition_penalty: 1.0 + max_new_tokens: ${subtract:${runner.seq_length}, ${data.max_prompt_length}} + min_new_tokens: 1 + +rollout: + group_name: "RolloutGroup" + + gpu_memory_utilization: 0.55 + + model_dir: /workspace/dataset/Qwen2.5-VL-3B-Instruct + model_arch: qwen2.5_vl #qwen2.5 + enforce_eager: False # if False, rollout engine will capture cuda graph, which will take more time to initialize. + distributed_executor_backend: mp # ray or mp + disable_log_stats: False + detokenize: False # Whether to detokenize the output. During RL we actually don't need to detokenize it. Can be set to True for debugging. + padding: null # will be tokenizer.pad_token_id if null. it is used to filter megatron's padding for rollout engine + eos: null # will be tokenizer.eos_token_id if null. + + rollout_backend: vllm # here choose which backend to rollout,support [sglang, vllm] + + sglang: + attention_backend: triton # [flashinfer, triton] for more, see sglang's doc + decode_log_interval: 500000 # the interval for SGLang to log the decode time and other stats. + use_torch_compile: False # enable torch_compile in SGLang for rollout. + torch_compile_max_bs: 128 # the maximum batch size for torch compile. If the batch size is larger than this, torch compile will not be used. + + vllm: + attention_backend: FLASH_ATTN #[FLASH_ATTN,XFORMERS] for more, see vllm's doc + enable_chunked_prefill: True # enable vllm to use chunked_prefill. + enable_prefix_caching: True # enable vllm to use prefix_caching. + enable_flash_infer_sampler: True # if True, vllm will use flashinfer to do sampling. + + return_logprobs: ${not:${algorithm.recompute_logprobs}} + + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + + validate_weight: False # whether to send all weights at first for weight comparison. 
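Editor's note: as the rollout_batch_size_per_gpu comment in these configs explains, a null value means the global rollout batch is divided evenly across rollout (inference) instances. A small sketch of that fallback under stated assumptions; the function and argument names are illustrative, not RLinf internals:

from typing import Optional

def per_instance_rollout_batch(
    rollout_batch_size: int,
    num_rollout_instances: int,
    per_gpu_override: Optional[int] = None,
) -> int:
    # An explicit rollout_batch_size_per_gpu wins; otherwise split evenly.
    if per_gpu_override is not None:
        return per_gpu_override
    assert rollout_batch_size % num_rollout_instances == 0, (
        "rollout_batch_size must divide evenly across rollout instances"
    )
    return rollout_batch_size // num_rollout_instances

# e.g. data.rollout_batch_size: 16 with 8 collocated GPUs and
# tensor_parallel_size: 2 gives 4 rollout instances -> 4 prompts each
print(per_instance_rollout_batch(16, 4))  # 4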
+ validate_save_dir: null # the directory to save the weights for comparison. If validate_weight is True, this will be used to save the weights for comparison. + print_outputs: False # whether to print the outputs (token ids, texts, etc.) of rollout engine. + + max_running_requests: 64 # the maximum number of running requests in the rollout engine. + cuda_graph_max_bs: 128 # the maximum batch size for cuda graph. If the batch size is larger than this, cuda graph will not be used. + +data: + type: vision_language + dataset_name: robo2vlm + max_prompt_length: 1024 + filter_prompt_by_length: True + rollout_batch_size: 16 + val_rollout_batch_size: null + num_workers: 2 + prompt_key: prompt + image_keys: ["image"] # some vlm datasets may have multiple image columns + choice_key: "choices" + answer_key: "answer" + solution_key: "solution" + use_chat_template: True + lazy_loading: True + shuffle: True + validation_shuffle: True + seed: 1234 + train_data_paths: ["/workspace/dataset/robo2vlm-1/data/train-00000-of-00262.parquet"] + val_data_paths: ["/workspace/dataset/robo2vlm-1/data/test-00000-of-00003.parquet"] + +actor: + group_name: "ActorGroup" + training_backend: fsdp + mcore_gpt: True + spec_name: decoder_gpt + + enable_offload: True + checkpoint_load_path: null + + global_batch_size: 8 + micro_batch_size: 1 + + enable_dp_load_balance: False + + calculate_flops: False + + seed: 1234 + + model: + precision: bf16 + sharding_strategy: full_shard + is_lora: False + + seq_length: ${runner.seq_length} + encoder_seq_length: ${runner.seq_length} + model_path: /workspace/dataset/Qwen2.5-VL-3B-Instruct + + model_arch: ${rollout.model_arch} + + optim: + optimizer: adam + bf16: True #False + fp16: False #True + lr: 2e-05 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-05 + min_lr: 2.0e-6 + weight_decay: 0.05 + use_distributed_optimizer: True + overlap_grad_reduce: False + overlap_param_gather: False + optimizer_enable_pin: false + overlap_param_gather_with_optimizer_step: False + clip_grad: 0.8 + loss_scale: 65536 + + lr_sched: + lr_warmup_fraction: 0.01 + lr_warmup_init: 0.0 + lr_warmup_iters: 0 + max_lr: 2.0e-5 + min_lr: 0.0 + lr_decay_style: constant + lr_decay_iters: 10 + + tokenizer: + tokenizer_model: /workspace/dataset/Qwen2.5-VL-3B-Instruct + use_fast: False + trust_remote_code: True + padding_side: 'right' + + fsdp: + forward_prefetch: True + limit_all_gathers: True + backward_prefetch: True + use_orig_params: True + +reward: + group_name: "RewardGroup" + use_reward_model: false + reward_type: 'vqa' + reward_scale: 1.0 + reward_weights: + qa_accuracy: 1.0 + think_format: 0.0 + answer_format: 0.0 + + tokenizer: + tokenizer_model: ${actor.tokenizer.tokenizer_model} + use_fast: ${actor.tokenizer.use_fast} + trust_remote_code: ${actor.tokenizer.trust_remote_code} + padding_side: ${actor.tokenizer.padding_side} + +critic: + use_critic_model: false \ No newline at end of file diff --git a/tests/e2e_tests/reasoning/run.sh b/tests/e2e_tests/reasoning/run.sh new file mode 100644 index 000000000..92e43866f --- /dev/null +++ b/tests/e2e_tests/reasoning/run.sh @@ -0,0 +1,16 @@ +#! /bin/bash +set -x + +tabs 4 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM=false + +export PYTHONPATH=${REPO_PATH}:$PYTHONPATH + +if [ -z "$1" ]; then + echo "Please provide a config name as the first argument." 
+ exit 1 +else + CONFIG_NAME=$1 +fi +python ${REPO_PATH}/examples/reasoning/main_grpo.py --config-path $REPO_PATH/tests/e2e_tests/reasoning/ --config-name $CONFIG_NAME \ No newline at end of file diff --git a/tests/unit_tests/test_auto_placement.py b/tests/unit_tests/test_auto_placement.py index a70b01c9d..559229763 100644 --- a/tests/unit_tests/test_auto_placement.py +++ b/tests/unit_tests/test_auto_placement.py @@ -598,7 +598,7 @@ def test_scheduler_task_initialization(self, mock_validate): """Test SchedulerTask initialization.""" # Create a mock config mock_cfg = MagicMock() - mock_cfg.runner.task_type = "math" + mock_cfg.runner.task_type = "reasoning" mock_cfg.actor.model.tensor_model_parallel_size = 2 mock_cfg.actor.model.pipeline_model_parallel_size = 1 mock_cfg.rollout.tensor_parallel_size = 1 @@ -620,7 +620,7 @@ def test_scheduler_task_initialization(self, mock_validate): scheduler_task = SchedulerTask(mock_cfg, mock_cluster) - assert scheduler_task.is_math is True + assert scheduler_task.is_reasoning is True assert scheduler_task.total_gpus == 8 assert scheduler_task.group_size == 4 assert "actor" in scheduler_task.components_config diff --git a/tests/unit_tests/test_placement.py b/tests/unit_tests/test_placement.py index c16ff7fb9..22c75ab68 100644 --- a/tests/unit_tests/test_placement.py +++ b/tests/unit_tests/test_placement.py @@ -1087,34 +1087,6 @@ def test_model_parallel_component_placement_init_missing_rollout_gpus(self): cluster = mock_cluster(num_nodes=1, num_accelerators_per_node=4) ModelParallelComponentPlacement(config, cluster) - def test_model_parallel_component_placement_init_collocated_mode_invalid_tp_sizes( - self, - ): - """Test ModelParallelComponentPlacement raises error when actor TP size < rollout TP size in collocated mode.""" - config = DictConfig( - { - "cluster": { - "num_nodes": 1, - "component_placement": {"actor,rollout": "0-3"}, - }, - "actor": { - "model": { - "tensor_model_parallel_size": 2, - "context_parallel_size": 1, - "pipeline_model_parallel_size": 1, - } - }, - "rollout": {"tensor_parallel_size": 4, "pipeline_parallel_size": 1}, - } - ) - - with pytest.raises( - AssertionError, - match="Actor TP size 2 must be greater or equal to Rollout TP size 4", - ): - cluster = mock_cluster(num_nodes=1, num_accelerators_per_node=4) - ModelParallelComponentPlacement(config, cluster) - def test_model_parallel_component_placement_init_collocated_mode_with_inference_gpus( self, ): diff --git a/toolkits/__init__.py b/toolkits/__init__.py index 8b6f0114a..5b365ea1e 100644 --- a/toolkits/__init__.py +++ b/toolkits/__init__.py @@ -11,22 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - -from rlinf.algorithms.registry import get_reward_fn - - -def register_rewards(): - try: - from toolkits.code_verifier.verify import fim_verify_call - - assert get_reward_fn("fim_verify_call") == fim_verify_call - except ImportError: - pass - - try: - from toolkits.math_verifier.verify import math_verify_call - - assert get_reward_fn("math") == math_verify_call - except ImportError: - pass diff --git a/toolkits/auto_placement/scheduler_task.py b/toolkits/auto_placement/scheduler_task.py index 00d8ea8aa..b3be46012 100644 --- a/toolkits/auto_placement/scheduler_task.py +++ b/toolkits/auto_placement/scheduler_task.py @@ -31,8 +31,10 @@ def __init__( workflow_graph: Optional[Dict[ComponentNode, List[ComponentNode]]] = None, ): self.cfg = cfg - self.is_math = cfg.runner.task_type == "math" - assert self.is_math, "Only math task is supported" + self.is_reasoning = cfg.runner.task_type == "reasoning" + assert self.is_reasoning, ( + f"Only reasoning task is supported, current task type: {cfg.runner.task_type}" + ) self.components_config = { "actor": { @@ -71,7 +73,7 @@ def __init__( self.global_step_batch_size = self.rollout_batch_size * self.group_size if workflow_graph is None: - if self.is_math: + if self.is_reasoning: actor = ComponentNode("actor") inference = ComponentNode("inference") rollout = ComponentNode("rollout") @@ -179,7 +181,7 @@ def parse_partition_allocation_to_cfg( def time_division_multiplexing(self) -> List[Dict[str, Workflow]]: partitions: List[Dict[str, Workflow]] = get_workflow_partition(self.workflow) - if self.is_math: + if self.is_reasoning: valid_partitions = [ i for i in partitions if len(i) in [1, len(self.components_config)] ] diff --git a/toolkits/code_verifier/verify.py b/toolkits/code_verifier/verify.py index 0ad756d02..5017b9e54 100644 --- a/toolkits/code_verifier/verify.py +++ b/toolkits/code_verifier/verify.py @@ -21,10 +21,8 @@ except ImportError: fuzz = None FUZZY_AVAILABLE = False -from rlinf.algorithms.registry import register_reward_fn -@register_reward_fn("fim_verify_call") def fim_verify_call( responses: List[str], references: List[str], diff --git a/toolkits/math_verifier/verify.py b/toolkits/math_verifier/verify.py index 80bf4b552..8d0cbdb11 100644 --- a/toolkits/math_verifier/verify.py +++ b/toolkits/math_verifier/verify.py @@ -14,7 +14,13 @@ import multiprocessing import re -from concurrent.futures import ProcessPoolExecutor, as_completed +from concurrent.futures import ( + ProcessPoolExecutor, + as_completed, +) +from concurrent.futures import ( + TimeoutError as FuturesTimeoutError, +) from typing import List, Union import regex @@ -23,7 +29,6 @@ from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr -from rlinf.algorithms.registry import register_reward_fn from toolkits.math_verifier.parser import extract_answer global_executor = ProcessPoolExecutor(max_workers=40) @@ -348,22 +353,22 @@ def process_results(answer, solution): extracted_solution = extract_answer(solution, "math", use_last_number=True) if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]: - retval = 0 + retval = -1 elif extracted_solution is None or extracted_solution.strip() in [ "None", "none", "", ]: - retval = 0 + retval = -1 elif math_equal(extracted_answer, extracted_solution, timeout=False): # elif call_with_timeout(math_equal, extracted_answer, extracted_solution): retval = 1 else: - retval = 0 + retval = -1 return retval, (extracted_answer, extracted_solution) except Exception: - return 0, ("None", "None") 
+ return -1, ("None", "None") def process_results_process(a, b, output_queue): @@ -383,7 +388,6 @@ def verify_math_solution(answer: str, solution: str): return process_results(answer, solution)[0] -@register_reward_fn("math") def math_verify_call( responses: List[str], references: List[str], @@ -403,7 +407,7 @@ def math_verify_call( jobs.append(job) all_jobs.append(jobs) - labels = [] + labels: List[int] = [] has_timeout = False for jobs in all_jobs: label = 0 @@ -411,40 +415,18 @@ def math_verify_call( for job in as_completed(jobs, timeout=timeout): x = job.result() label = label or x - except TimeoutError: + except FuturesTimeoutError: has_timeout = True - labels.append(label) + for job in jobs: + job.cancel() + finally: + labels.append(label) if has_timeout: reset_global_process_pool() return labels -class MathRewardModel: - def __init__(self, scale: float): - self.scale = scale - - def get_reward( - self, response: List[str], reference: List[List[str]] - ) -> List[float]: - """ - Calculates reward scores for a list of responses compared to corresponding lists of reference answers. - For each response, the function checks if it matches any of the provided references using the `process_results` function. - The reward for each response is computed as the first element of the result (converted to float) multiplied by `self.scale`. - Args: - response (List[str]): A list of response strings to be evaluated. - reference (List[List[str]]): A list where each element is a list of reference strings corresponding to each response. - Returns: - List[float]: A list of reward scores, one for each response. - """ - - results = [] - for resp, refs in zip(response, reference): - result = any(process_results(resp, ref)[0] for ref in refs) - results.append((1 if result else -1) * self.scale) - return results - - if __name__ == "__main__": sample = { "answers": ["\\boxed{-\\frac{2}{3}}"],