[1/N] refactor grpo/gkd #9580
Code Review
This pull request refactors the GKD and GRPO training pipelines to use a unified, sample-based data structure (OnPolicySample, GRPOSample, GKDSample) and standardizes batch collation across backends, while also deprecating the unsupported seq_kd mode. The code review identified several critical bugs and issues, including a short-circuiting bug in GKD trainer's sample processing, a NameError due to a missing excluded variable in GRPO trainer, a missing deep-copy implementation in OnPolicySample.from_row, potential AttributeErrors in null_ref_context and lazy-loaded utility imports, and minor documentation typos and tensor-indexing performance overheads.
| encoded_batch = to_device(template.data_collator(encoded_list, padding_to=padding_to), self.device) | ||
| student_encoded_list = [] | ||
| teacher_encoded_list = [] | ||
| has_opsd = any(s.build_teacher_view() for s in samples) |
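Because any() short-circuits, once build_teacher_view() returns True for one sample in samples, it is never called on the remaining samples. build_teacher_view() has the side effect of populating self.teacher_messages, so the skipped samples keep self.teacher_messages as None, which causes errors in the later encoding stage. Consider using a list comprehension, any([s.build_teacher_view() for s in samples]), to force evaluation of every element.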
| has_opsd = any(s.build_teacher_view() for s in samples) | |
| has_opsd = any([s.build_teacher_view() for s in samples]) |
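The short-circuit behavior described in this comment can be reproduced in isolation. The following is a minimal sketch, not the PR's code: the `Sample` class and the `count_populated` helper are hypothetical stand-ins, and only the names `build_teacher_view` and `teacher_messages` come from the review.

```python
class Sample:
    """Hypothetical stand-in for the PR's sample class."""

    def __init__(self, has_teacher_prompt):
        self.has_teacher_prompt = has_teacher_prompt
        self.teacher_messages = None

    def build_teacher_view(self):
        # Side effect: fills teacher_messages each time it is called.
        self.teacher_messages = ['teacher view']
        return self.has_teacher_prompt


def count_populated(samples):
    """Count samples whose teacher_messages has been filled in."""
    return sum(s.teacher_messages is not None for s in samples)


lazy = [Sample(True), Sample(True), Sample(False)]
any(s.build_teacher_view() for s in lazy)  # generator: stops at the first True
lazy_populated = count_populated(lazy)

eager = [Sample(True), Sample(True), Sample(False)]
any([s.build_teacher_view() for s in eager])  # list: every sample is visited first
eager_populated = count_populated(eager)

print(lazy_populated, eager_populated)
```

With the generator form only the first sample is populated, while the list form populates all three. An explicit loop that calls build_teacher_view() on every sample before computing the flag would likely make the side effect easier to spot than the bracketed form.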
| def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]: | ||
| """Filters inputs to create model_inputs, removing GRPO-specific and template extra keys.""" | ||
| excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS | ||
| return {k: v for k, v in inputs.items() if k not in excluded} |
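This refactor removed the definition of the local variable excluded, but the return statement still uses it. Calling _prepare_model_inputs will therefore immediately raise NameError: name 'excluded' is not defined and abort training. Consider restoring the definition of excluded.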
| def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]: | |
| """Filters inputs to create model_inputs, removing GRPO-specific and template extra keys.""" | |
| excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS | |
| return {k: v for k, v in inputs.items() if k not in excluded} | |
| def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]: | |
| """Filters inputs to create model_inputs, removing GRPO-specific and template extra keys.""" | |
| excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS | |
| return {k: v for k, v in inputs.items() if k not in excluded} |
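For illustration, the failure mode and the fix can be sketched outside the trainer. The contents of the two key sets below are assumptions made up for this example (the real GRPO_NON_MODEL_KEYS and FILTERED_KEYS are defined in the PR), and the functions are free-standing stand-ins for the method.

```python
# Hypothetical key sets; the real constants are defined in the PR.
GRPO_NON_MODEL_KEYS = {'advantages', 'old_per_token_logps'}
FILTERED_KEYS = {'labels'}


def prepare_model_inputs_broken(inputs):
    # 'excluded' is not defined anywhere, so evaluating the filter
    # condition for the first key raises NameError.
    return {k: v for k, v in inputs.items() if k not in excluded}


def prepare_model_inputs(inputs):
    # Restored local definition: union of both key sets.
    excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS
    return {k: v for k, v in inputs.items() if k not in excluded}


batch = {'input_ids': [1, 2, 3], 'advantages': [0.5], 'labels': [1, 2, 3]}

try:
    prepare_model_inputs_broken(batch)
    raised = False
except NameError:
    raised = True

model_inputs = prepare_model_inputs(batch)
print(raised, model_inputs)
```

Note that the error only appears once the comprehension evaluates its condition, so an empty inputs dict would not trigger it.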
| ### Mode 3: Off-Policy Learning | ||
| - Trigger condition: all other cases | ||
| - Data source: annotated responses in the dataset | ||
| | `--global_batch_size` | Global batch size: `micro_batch_size × dp_size × gradient_accumulation_steps` |s |
| for key, value in row.items(): | ||
| if key in field_names: | ||
| standard[key] = value | ||
| elif key == 'is_truncated': | ||
| continue # derived from finish_reason | ||
| else: | ||
| extra[key] = value |
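The docstring states that the values in row are deep-copied so that the sample owns its data and shared references cannot corrupt data across micro-steps. The current implementation, however, performs no copying at all. Consider deep-copying the values as the design intends.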
| for key, value in row.items(): | |
| if key in field_names: | |
| standard[key] = value | |
| elif key == 'is_truncated': | |
| continue # derived from finish_reason | |
| else: | |
| extra[key] = value | |
| import copy | |
| for key, value in row.items(): | |
| copied_value = copy.deepcopy(value) | |
| if key in field_names: | |
| standard[key] = copied_value | |
| elif key == 'is_truncated': | |
| continue # derived from finish_reason | |
| else: | |
| extra[key] = copied_value |
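The difference between storing references and storing deep copies can be shown with plain dicts. This is a minimal sketch under assumptions: `from_row_shallow` and `from_row_deep` are hypothetical helpers that return dicts, not the real OnPolicySample.from_row, and the row contents are made up.

```python
import copy


def from_row_shallow(row):
    # Stores the same objects that the row holds (no copying).
    return {key: value for key, value in row.items()}


def from_row_deep(row):
    # Stores independent copies, so later changes to the row are not visible.
    return {key: copy.deepcopy(value) for key, value in row.items()}


row = {'messages': [{'role': 'user', 'content': 'hi'}]}
shallow = from_row_shallow(row)
deep = from_row_deep(row)

# Simulate a later micro-step mutating the original row in place.
row['messages'].append({'role': 'assistant', 'content': 'hello'})

shallow_len = len(shallow['messages'])
deep_len = len(deep['messages'])
print(shallow_len, deep_len)
```

The shallow version sees the appended message because it shares the list with the row, while the deep-copied version keeps its original single message.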
| unique_indices = torch.tensor(sorted(seen.values()), device=device) | ||
| unique_request_ids = [request_ids[i] for i in unique_indices.cpu()] | ||
| unique_prompt_ids = [prompt_ids[i] for i in unique_indices.cpu()] |
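Indexing a Python list with PyTorch tensor elements (for example request_ids[i], where i is a 0-D tensor obtained by iterating over unique_indices.cpu()) adds unnecessary overhead, because each lookup must convert the tensor element to a Python integer. Since sorted(seen.values()) is already a list of Python integers, consider indexing the Python lists with it directly and creating the device tensor only where it is needed to index the PyTorch tensor rewards.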
| unique_indices = torch.tensor(sorted(seen.values()), device=device) | |
| unique_request_ids = [request_ids[i] for i in unique_indices.cpu()] | |
| unique_prompt_ids = [prompt_ids[i] for i in unique_indices.cpu()] | |
| unique_idxs_list = sorted(seen.values()) | |
| unique_request_ids = [request_ids[i] for i in unique_idxs_list] | |
| unique_prompt_ids = [prompt_ids[i] for i in unique_idxs_list] | |
| unique_indices = torch.tensor(unique_idxs_list, device=device) |
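The list-side half of this suggestion can be sketched without PyTorch. The sample data is made up, and the way `seen` is built here (mapping each request id to the index of its first occurrence) is an assumption, since the review does not show how the PR constructs it.

```python
request_ids = ['a', 'b', 'a', 'c', 'b']
prompt_ids = [[1], [2], [1], [3], [2]]

# Assumption: 'seen' maps each request id to the index of its first occurrence.
seen = {}
for idx, rid in enumerate(request_ids):
    seen.setdefault(rid, idx)

# Plain Python ints index Python lists directly, with no tensor conversion.
unique_idxs_list = sorted(seen.values())
unique_request_ids = [request_ids[i] for i in unique_idxs_list]
unique_prompt_ids = [prompt_ids[i] for i in unique_idxs_list]

# Only the tensor lookup needs a tensor, built once from the same list, e.g.:
# unique_indices = torch.tensor(unique_idxs_list, device=device)
print(unique_idxs_list, unique_request_ids, unique_prompt_ids)
```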
| with self.accelerator.unwrap_model(self.model).disable_adapter() if is_peft_model( | ||
| self.model) and not self.ref_adapter_name else nullcontext(): | ||
| if self.ref_adapter_name: | ||
| self.model.set_adapter(self.ref_adapter_name) | ||
| yield | ||
| if self.ref_adapter_name: | ||
| self.model.set_adapter(self.model_adapter_name or 'default') |
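Because null_ref_context was moved into RLHFTrainerMixin and is now called by other trainers that do not necessarily define ref_adapter_name or model_adapter_name (such as GRPOTrainer), accessing self.ref_adapter_name directly may raise an AttributeError. Consider guarding the access with getattr.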
| with self.accelerator.unwrap_model(self.model).disable_adapter() if is_peft_model( | |
| self.model) and not self.ref_adapter_name else nullcontext(): | |
| if self.ref_adapter_name: | |
| self.model.set_adapter(self.ref_adapter_name) | |
| yield | |
| if self.ref_adapter_name: | |
| self.model.set_adapter(self.model_adapter_name or 'default') | |
| ref_adapter_name = getattr(self, 'ref_adapter_name', None) | |
| with self.accelerator.unwrap_model(self.model).disable_adapter() if is_peft_model( | |
| self.model) and not ref_adapter_name else nullcontext(): | |
| if ref_adapter_name: | |
| self.model.set_adapter(ref_adapter_name) | |
| yield | |
| if ref_adapter_name: | |
| self.model.set_adapter(getattr(self, 'model_adapter_name', None) or 'default') |
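The getattr guard can be exercised with a toy trainer. This is a rough sketch under assumptions: `FakeModel` and both trainer classes are hypothetical, the PEFT disable_adapter() branch is left out, and the try/finally around the yield is an addition of this sketch rather than part of the suggested change.

```python
from contextlib import contextmanager


class FakeModel:
    """Hypothetical model that only records the active adapter name."""

    def __init__(self):
        self.active_adapter = 'default'

    def set_adapter(self, name):
        self.active_adapter = name


class TrainerWithoutRefAdapter:
    """Hypothetical trainer that never defines ref_adapter_name."""

    def __init__(self):
        self.model = FakeModel()

    @contextmanager
    def null_ref_context(self):
        # getattr with a default avoids AttributeError on trainers
        # that do not define the attribute.
        ref_adapter_name = getattr(self, 'ref_adapter_name', None)
        if ref_adapter_name:
            self.model.set_adapter(ref_adapter_name)
        try:
            yield
        finally:
            if ref_adapter_name:
                self.model.set_adapter(
                    getattr(self, 'model_adapter_name', None) or 'default')


class TrainerWithRefAdapter(TrainerWithoutRefAdapter):
    """Hypothetical trainer that does define a reference adapter."""
    ref_adapter_name = 'ref'


plain = TrainerWithoutRefAdapter()
with plain.null_ref_context():
    plain_adapter = plain.model.active_adapter

with_ref = TrainerWithRefAdapter()
with with_ref.null_ref_context():
    inside_adapter = with_ref.model.active_adapter
restored_adapter = with_ref.model.active_adapter
print(plain_adapter, inside_adapter, restored_adapter)
```

The trainer without the attribute passes through the context untouched, and the one with a reference adapter switches to it inside the block and back to 'default' afterwards.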
| 'parse_args', | ||
| 'patch_getattr', |