-
Notifications
You must be signed in to change notification settings - Fork 422
【CI】update rl cases #1835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【CI】update rl cases #1835
Changes from all commits
2620c5b
aa512a4
71cdcef
2192551
c357e11
5a0673b
a10a582
4f11574
d90d677
20e7aa8
e387118
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,4 +206,5 @@ | |
| work_dir=work_dir, | ||
| seed=123, | ||
| debug_rollout=False, | ||
| exp_tracker="jsonl", | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,57 +1,56 @@ | ||||||||||||||||||||||||||||||||||||||
| """RL Colocate Trainer 示例配置(GRPO + GSM8K)。 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| 用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 | ||||||||||||||||||||||||||||||||||||||
| 需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH | ||||||||||||||||||||||||||||||||||||||
| 可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||
| from copy import deepcopy | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.data_proto.rl_data import SampleParams | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.advantage import GRPOAdvantageConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense8BConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.advantage import GRPOAdvantageConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.agent_loop_manager import ( | ||||||||||||||||||||||||||||||||||||||
| AgentLoopManagerConfig, | ||||||||||||||||||||||||||||||||||||||
| SamplerConfig, | ||||||||||||||||||||||||||||||||||||||
| SyncProduceStrategyConfig, | ||||||||||||||||||||||||||||||||||||||
| TaskSpecConfig, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.evaluator import EvaluatorConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.utils import AcceleratorResourcesConfig, CPUResourcesConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.rollout.worker import RolloutConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.judger import GEO3KJudgerConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.loss import GRPOLossConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.rollout.worker import RolloutConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.trainer import WorkerConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.utils import AcceleratorResourcesConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.evaluator import EvaluatorConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.rl.loss import GRPOLossConfig | ||||||||||||||||||||||||||||||||||||||
| from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # env | ||||||||||||||||||||||||||||||||||||||
| work_dir = os.environ["WORK_DIR"] | ||||||||||||||||||||||||||||||||||||||
| model_path = os.environ["MODEL_PATH"] | ||||||||||||||||||||||||||||||||||||||
| data_path = os.environ["DATA_PATH"] | ||||||||||||||||||||||||||||||||||||||
| eval_data_path = os.environ["EVAL_DATA_PATH"] | ||||||||||||||||||||||||||||||||||||||
| enable_evaluate = eval_data_path != "" | ||||||||||||||||||||||||||||||||||||||
| enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") | ||||||||||||||||||||||||||||||||||||||
| NNODE = int(os.environ.get("WORLD_SIZE", "1")) | ||||||||||||||||||||||||||||||||||||||
| media_root = os.environ["MEDIA_ROOT"] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # basic settings | ||||||||||||||||||||||||||||||||||||||
| experimental_name = "grpo_geo3k" | ||||||||||||||||||||||||||||||||||||||
| total_epochs = 15 | ||||||||||||||||||||||||||||||||||||||
| global_batch_size = 64 | ||||||||||||||||||||||||||||||||||||||
| total_train_steps = 15 # TODO: total_epoch | ||||||||||||||||||||||||||||||||||||||
| evaluate_step = 15 | ||||||||||||||||||||||||||||||||||||||
| train_optimizer_steps = 4 | ||||||||||||||||||||||||||||||||||||||
| train_batch_size = 1024 | ||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个太大了,可以参考以前值 64 |
||||||||||||||||||||||||||||||||||||||
| prompt_repeat_k = 5 | ||||||||||||||||||||||||||||||||||||||
| rollout_tp_size = 2 | ||||||||||||||||||||||||||||||||||||||
| rollout_tp_size = 1 | ||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 测试 tp2,能测的多一点 |
||||||||||||||||||||||||||||||||||||||
| rollout_ep_size = 1 | ||||||||||||||||||||||||||||||||||||||
| max_prompt_length = 1024 | ||||||||||||||||||||||||||||||||||||||
| max_response_length = 2048 | ||||||||||||||||||||||||||||||||||||||
| pack_max_length = 32768 | ||||||||||||||||||||||||||||||||||||||
| train_optimizer_steps = 4 | ||||||||||||||||||||||||||||||||||||||
| hf_interval = 30 | ||||||||||||||||||||||||||||||||||||||
| enable_initial_evaluate = True | ||||||||||||||||||||||||||||||||||||||
| evaluate_step = 15 | ||||||||||||||||||||||||||||||||||||||
| pack_max_length = 32 * 1024 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 1. resources | ||||||||||||||||||||||||||||||||||||||
| resources = AcceleratorResourcesConfig( | ||||||||||||||||||||||||||||||||||||||
| accelerator="GPU", | ||||||||||||||||||||||||||||||||||||||
| num_workers=8, | ||||||||||||||||||||||||||||||||||||||
| num_workers=8 * NNODE, | ||||||||||||||||||||||||||||||||||||||
| num_cpus_per_worker=12, | ||||||||||||||||||||||||||||||||||||||
| cpu_memory_per_worker=16 * 1024**3, | ||||||||||||||||||||||||||||||||||||||
| cpu_memory_per_worker=16 * 1024**3, # 16 GB | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 2. rollout | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -62,109 +61,134 @@ | |||||||||||||||||||||||||||||||||||||
| dtype="bfloat16", | ||||||||||||||||||||||||||||||||||||||
| tensor_parallel_size=rollout_tp_size, | ||||||||||||||||||||||||||||||||||||||
| expert_parallel_size=rollout_ep_size, | ||||||||||||||||||||||||||||||||||||||
| gpu_memory_utilization=0.75, | ||||||||||||||||||||||||||||||||||||||
| gpu_memory_utilization=0.8, | ||||||||||||||||||||||||||||||||||||||
| context_length=max_response_length + max_prompt_length, | ||||||||||||||||||||||||||||||||||||||
| enable_return_routed_experts=(enable_return_routed_experts == "1"), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # sampling params | ||||||||||||||||||||||||||||||||||||||
| training_sample_params = SampleParams(max_tokens=max_response_length) | ||||||||||||||||||||||||||||||||||||||
| evaluation_sample_params = deepcopy(training_sample_params) | ||||||||||||||||||||||||||||||||||||||
| evaluation_sample_params.top_p = 1.0 | ||||||||||||||||||||||||||||||||||||||
| evaluation_sample_params.temperature = 0.0 | ||||||||||||||||||||||||||||||||||||||
| evaluation_sample_params.top_k = 1 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 3. datasets | ||||||||||||||||||||||||||||||||||||||
| train_dataset_cfg = [ | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| "dataset": DatasetConfig( | ||||||||||||||||||||||||||||||||||||||
| name="geo3k", | ||||||||||||||||||||||||||||||||||||||
| anno_path=data_path, | ||||||||||||||||||||||||||||||||||||||
| class_name="VLMJsonlDataset", | ||||||||||||||||||||||||||||||||||||||
| media_root=media_root, | ||||||||||||||||||||||||||||||||||||||
| sample_ratio=1.0, | ||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||
| "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, max_length=max_prompt_length), | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
| eval_dataset_cfg = [ | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| "dataset": DatasetConfig( | ||||||||||||||||||||||||||||||||||||||
| name="geo3k", | ||||||||||||||||||||||||||||||||||||||
| anno_path=eval_data_path if enable_evaluate else data_path, | ||||||||||||||||||||||||||||||||||||||
| class_name="VLMJsonlDataset", | ||||||||||||||||||||||||||||||||||||||
| media_root=media_root, | ||||||||||||||||||||||||||||||||||||||
| sample_ratio=1.0, | ||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||
| "tokenize_fn": RLQwen3VLTokenizeFnConfig( | ||||||||||||||||||||||||||||||||||||||
| processor_path=model_path, | ||||||||||||||||||||||||||||||||||||||
| max_length=max_prompt_length, | ||||||||||||||||||||||||||||||||||||||
| ignore_multimodal_info=True, | ||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
| dataloader_cfg = DataloaderConfig( | ||||||||||||||||||||||||||||||||||||||
| dataset_config_list=train_dataset_cfg, | ||||||||||||||||||||||||||||||||||||||
| pack_max_length=pack_max_length, | ||||||||||||||||||||||||||||||||||||||
| collator="fake_collator", | ||||||||||||||||||||||||||||||||||||||
| pack_level="none", | ||||||||||||||||||||||||||||||||||||||
| num_workers=8, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| eval_dataloader_cfg = DataloaderConfig( | ||||||||||||||||||||||||||||||||||||||
| dataset_config_list=eval_dataset_cfg, | ||||||||||||||||||||||||||||||||||||||
| pack_max_length=pack_max_length, | ||||||||||||||||||||||||||||||||||||||
| collator="fake_collator", | ||||||||||||||||||||||||||||||||||||||
| pack_level="none", | ||||||||||||||||||||||||||||||||||||||
| num_workers=8, | ||||||||||||||||||||||||||||||||||||||
| # 3. judger | ||||||||||||||||||||||||||||||||||||||
| judger_config = GEO3KJudgerConfig( | ||||||||||||||||||||||||||||||||||||||
| cpu_resources=CPUResourcesConfig(num_workers=1, num_cpus_per_worker=1), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 4. judger | ||||||||||||||||||||||||||||||||||||||
| judger_config = GEO3KJudgerConfig() | ||||||||||||||||||||||||||||||||||||||
| # 4. train worker | ||||||||||||||||||||||||||||||||||||||
| lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) | ||||||||||||||||||||||||||||||||||||||
| fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 5. train worker | ||||||||||||||||||||||||||||||||||||||
| # TODO: support get_model_config_from_hf | ||||||||||||||||||||||||||||||||||||||
| model_cfg = Qwen3VLDense8BConfig(freeze_vision=True, freeze_projector=True) | ||||||||||||||||||||||||||||||||||||||
| optim_cfg = AdamWConfig(lr=1e-6, foreach=False) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if hasattr(model_cfg.text_config, "balancing_loss_cfg"): | ||||||||||||||||||||||||||||||||||||||
| model_cfg.text_config.balancing_loss_cfg = None | ||||||||||||||||||||||||||||||||||||||
| if hasattr(model_cfg.text_config, "z_loss_cfg"): | ||||||||||||||||||||||||||||||||||||||
| model_cfg.text_config.z_loss_cfg = None | ||||||||||||||||||||||||||||||||||||||
| optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) | ||||||||||||||||||||||||||||||||||||||
| loss_cfg = GRPOLossConfig( | ||||||||||||||||||||||||||||||||||||||
| policy_loss_cfg=dict( | ||||||||||||||||||||||||||||||||||||||
| cliprange_high=0.2, | ||||||||||||||||||||||||||||||||||||||
| cliprange_high=0.28, | ||||||||||||||||||||||||||||||||||||||
| cliprange_low=0.2, | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+82
to
89
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Inconsistent indentation — the inner dict entries use 12 spaces while the outer list uses 4. This looks like it was intended to be 8-space aligned (matching the GSM8K config style). Consider reformatting for consistency:
Suggested change
Also, per project style (CLAUDE.md), use double quotes
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ignore |
||||||||||||||||||||||||||||||||||||||
| loss_type="vanilla", | ||||||||||||||||||||||||||||||||||||||
| loss_type=os.environ.get("LOSS_TYPE", "vanilla"), | ||||||||||||||||||||||||||||||||||||||
| clip_ratio_c=10.0, | ||||||||||||||||||||||||||||||||||||||
| log_prob_diff_min=-20.0, | ||||||||||||||||||||||||||||||||||||||
| log_prob_diff_max=20.0, | ||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||
| ignore_idx=-100, | ||||||||||||||||||||||||||||||||||||||
| use_kl_loss=True, | ||||||||||||||||||||||||||||||||||||||
| kl_loss_coef=0.001, | ||||||||||||||||||||||||||||||||||||||
| use_kl_loss=False, | ||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个也不要改,测试覆盖率更高
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 其实这个配置是examples/v1/config/rl_grpo_geo3k_judge.py中的 |
||||||||||||||||||||||||||||||||||||||
| kl_loss_coef=0.0, | ||||||||||||||||||||||||||||||||||||||
| kl_loss_type="low_var_kl", | ||||||||||||||||||||||||||||||||||||||
| mode="chunk", | ||||||||||||||||||||||||||||||||||||||
| mode=os.environ.get("LOSS_MODE", "chunk"), | ||||||||||||||||||||||||||||||||||||||
| chunk_size=512, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) | ||||||||||||||||||||||||||||||||||||||
| fsdp_cfg = FSDPConfig(cpu_offload=False) | ||||||||||||||||||||||||||||||||||||||
| train_worker_cfg: WorkerConfig = WorkerConfig( | ||||||||||||||||||||||||||||||||||||||
| train_worker_cfg = WorkerConfig( | ||||||||||||||||||||||||||||||||||||||
| model_cfg=model_cfg, | ||||||||||||||||||||||||||||||||||||||
| load_from=model_path, | ||||||||||||||||||||||||||||||||||||||
| optim_cfg=optim_cfg, | ||||||||||||||||||||||||||||||||||||||
| loss_cfg=loss_cfg, | ||||||||||||||||||||||||||||||||||||||
| lr_cfg=lr_cfg, | ||||||||||||||||||||||||||||||||||||||
| fsdp_cfg=fsdp_cfg, | ||||||||||||||||||||||||||||||||||||||
| sp_size=1, | ||||||||||||||||||||||||||||||||||||||
| sp_size=int(os.environ.get("SP_SIZE", "1")), | ||||||||||||||||||||||||||||||||||||||
| optimizer_steps=train_optimizer_steps, | ||||||||||||||||||||||||||||||||||||||
| pack_max_length=pack_max_length, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 6. agent loop managers | ||||||||||||||||||||||||||||||||||||||
| # 5. train agent loop manager | ||||||||||||||||||||||||||||||||||||||
| train_dataset_cfg = [ | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| "dataset": DatasetConfig(name="geo3k", | ||||||||||||||||||||||||||||||||||||||
| anno_path=data_path, | ||||||||||||||||||||||||||||||||||||||
| class_name='VLMJsonlDataset', | ||||||||||||||||||||||||||||||||||||||
| media_root=media_root, | ||||||||||||||||||||||||||||||||||||||
| sample_ratio=1.0), | ||||||||||||||||||||||||||||||||||||||
| "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, | ||||||||||||||||||||||||||||||||||||||
| max_length=max_prompt_length), | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| dataloader_cfg = DataloaderConfig( | ||||||||||||||||||||||||||||||||||||||
| dataset_config_list=train_dataset_cfg, | ||||||||||||||||||||||||||||||||||||||
| pack_max_length=pack_max_length, | ||||||||||||||||||||||||||||||||||||||
| collator="fake_collator", | ||||||||||||||||||||||||||||||||||||||
| pack_level="none", | ||||||||||||||||||||||||||||||||||||||
| num_workers=8, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| sampler_config = SamplerConfig( | ||||||||||||||||||||||||||||||||||||||
| dataloader_cfg=dataloader_cfg, | ||||||||||||||||||||||||||||||||||||||
| prompt_repeat_k=prompt_repeat_k, | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+127
to
+136
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Same formatting issues here — mixed indentation and single quotes. These should match the project's double-quote convention (
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ignore |
||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| training_sample_params = SampleParams( | ||||||||||||||||||||||||||||||||||||||
| max_tokens=max_response_length, | ||||||||||||||||||||||||||||||||||||||
| top_k=0, | ||||||||||||||||||||||||||||||||||||||
| top_p=1.0, | ||||||||||||||||||||||||||||||||||||||
| temperature=1.0, | ||||||||||||||||||||||||||||||||||||||
| min_tokens=0, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| agent_loop_config = SingleTurnAgentLoopConfig( | ||||||||||||||||||||||||||||||||||||||
| hf_checkpoint=model_path, | ||||||||||||||||||||||||||||||||||||||
| sample_params=training_sample_params, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| produce_strategy_config = SyncProduceStrategyConfig() | ||||||||||||||||||||||||||||||||||||||
| agent_loop_manager_cfg = AgentLoopManagerConfig( | ||||||||||||||||||||||||||||||||||||||
| tasks=TaskSpecConfig( | ||||||||||||||||||||||||||||||||||||||
| task_name="train_task", | ||||||||||||||||||||||||||||||||||||||
| agent_loop_config=agent_loop_config, | ||||||||||||||||||||||||||||||||||||||
| judger_config=judger_config, | ||||||||||||||||||||||||||||||||||||||
| produce_strategy_config=SyncProduceStrategyConfig(), | ||||||||||||||||||||||||||||||||||||||
| sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=prompt_repeat_k), | ||||||||||||||||||||||||||||||||||||||
| produce_strategy_config=produce_strategy_config, | ||||||||||||||||||||||||||||||||||||||
| sampler_config=sampler_config, | ||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 6. eval agent loop manager | ||||||||||||||||||||||||||||||||||||||
| eval_dataset_cfg = [ | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| "dataset": DatasetConfig(name="geo3k", | ||||||||||||||||||||||||||||||||||||||
| anno_path=eval_data_path, | ||||||||||||||||||||||||||||||||||||||
| class_name='VLMJsonlDataset', | ||||||||||||||||||||||||||||||||||||||
| media_root=media_root, | ||||||||||||||||||||||||||||||||||||||
| sample_ratio=1.0), | ||||||||||||||||||||||||||||||||||||||
| "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, | ||||||||||||||||||||||||||||||||||||||
| max_length=max_prompt_length, | ||||||||||||||||||||||||||||||||||||||
| ignore_multimodal_info=True), | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| eval_dataloader_cfg = DataloaderConfig( | ||||||||||||||||||||||||||||||||||||||
| dataset_config_list=eval_dataset_cfg, | ||||||||||||||||||||||||||||||||||||||
| pack_max_length=pack_max_length, | ||||||||||||||||||||||||||||||||||||||
| collator="fake_collator", | ||||||||||||||||||||||||||||||||||||||
| pack_level="none", | ||||||||||||||||||||||||||||||||||||||
| num_workers=8, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| eval_sampler_config = SamplerConfig( | ||||||||||||||||||||||||||||||||||||||
| dataloader_cfg=eval_dataloader_cfg, | ||||||||||||||||||||||||||||||||||||||
| prompt_repeat_k=1, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| evaluation_sample_params = SampleParams( | ||||||||||||||||||||||||||||||||||||||
| max_tokens=max_response_length, | ||||||||||||||||||||||||||||||||||||||
| top_k=1, | ||||||||||||||||||||||||||||||||||||||
| top_p=1.0, | ||||||||||||||||||||||||||||||||||||||
| temperature=0.0, | ||||||||||||||||||||||||||||||||||||||
| min_tokens=0, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| eval_agent_loop_config = SingleTurnAgentLoopConfig( | ||||||||||||||||||||||||||||||||||||||
| hf_checkpoint=model_path, | ||||||||||||||||||||||||||||||||||||||
| sample_params=evaluation_sample_params, | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -174,28 +198,32 @@ | |||||||||||||||||||||||||||||||||||||
| task_name="eval_task", | ||||||||||||||||||||||||||||||||||||||
| agent_loop_config=eval_agent_loop_config, | ||||||||||||||||||||||||||||||||||||||
| judger_config=judger_config, | ||||||||||||||||||||||||||||||||||||||
| sampler_config=SamplerConfig(dataloader_cfg=eval_dataloader_cfg, prompt_repeat_k=1), | ||||||||||||||||||||||||||||||||||||||
| sampler_config=eval_sampler_config, | ||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 7. trainer | ||||||||||||||||||||||||||||||||||||||
| # 7. evaluator | ||||||||||||||||||||||||||||||||||||||
| evaluator_config = EvaluatorConfig(compute_metric_func=None) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # 8. RL Colocate Trainer Config(CLI 通过 config["trainer"].build() 得到 Trainer) | ||||||||||||||||||||||||||||||||||||||
| trainer = RLColocateTrainerConfig( | ||||||||||||||||||||||||||||||||||||||
| resources=resources, | ||||||||||||||||||||||||||||||||||||||
| train_worker_cfg=train_worker_cfg, | ||||||||||||||||||||||||||||||||||||||
| train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config | ||||||||||||||||||||||||||||||||||||||
| rollout_config=rollout_config, | ||||||||||||||||||||||||||||||||||||||
| tokenizer_path=model_path, | ||||||||||||||||||||||||||||||||||||||
| replay_buffer_config=SyncReplayBufferConfig(), | ||||||||||||||||||||||||||||||||||||||
| agent_loop_manager_cfg=agent_loop_manager_cfg, | ||||||||||||||||||||||||||||||||||||||
| eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, | ||||||||||||||||||||||||||||||||||||||
| evaluator_config=EvaluatorConfig(compute_metric_func=None), | ||||||||||||||||||||||||||||||||||||||
| evaluator_config=evaluator_config, | ||||||||||||||||||||||||||||||||||||||
| load_from=model_path, | ||||||||||||||||||||||||||||||||||||||
| total_epochs=total_epochs, | ||||||||||||||||||||||||||||||||||||||
| train_batch_size=global_batch_size, | ||||||||||||||||||||||||||||||||||||||
| total_train_steps=total_train_steps, | ||||||||||||||||||||||||||||||||||||||
| train_batch_size=train_batch_size, | ||||||||||||||||||||||||||||||||||||||
| advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), | ||||||||||||||||||||||||||||||||||||||
| enable_evaluate=enable_evaluate, | ||||||||||||||||||||||||||||||||||||||
| enable_initial_evaluate=enable_evaluate and enable_initial_evaluate, | ||||||||||||||||||||||||||||||||||||||
| enable_evaluate=True, | ||||||||||||||||||||||||||||||||||||||
| enable_initial_evaluate=False, | ||||||||||||||||||||||||||||||||||||||
| evaluate_step=evaluate_step, | ||||||||||||||||||||||||||||||||||||||
| work_dir=work_dir, | ||||||||||||||||||||||||||||||||||||||
| hf_interval=hf_interval, | ||||||||||||||||||||||||||||||||||||||
| seed=123, | ||||||||||||||||||||||||||||||||||||||
| debug_rollout=False, | ||||||||||||||||||||||||||||||||||||||
| exp_tracker="jsonl", | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得 autotest 要不直接测试 rl_qwen3p5_vl_35B_grpo_mixdata.py 这个配置。覆盖率更高,而且是 qwen3.5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
第一批先加已有的吧