Skip to content

[1/N] refactor grpo/gkd#9580

Draft
hjh0119 wants to merge 8 commits into
modelscope:mainfrom
hjh0119:refactor-adv
Draft

[1/N] refactor grpo/gkd#9580
hjh0119 wants to merge 8 commits into
modelscope:mainfrom
hjh0119:refactor-adv

Conversation

@hjh0119

@hjh0119 hjh0119 commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the GKD and GRPO training pipelines to use a unified, sample-based data structure (OnPolicySample, GRPOSample, GKDSample) and standardizes batch collation across backends, while also deprecating the unsupported seq_kd mode. The code review identified several critical bugs and issues, including a short-circuiting bug in GKD trainer's sample processing, a NameError due to a missing excluded variable in GRPO trainer, a missing deep-copy implementation in OnPolicySample.from_row, potential AttributeErrors in null_ref_context and lazy-loaded utility imports, and minor documentation typos and tensor-indexing performance overheads.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

encoded_batch = to_device(template.data_collator(encoded_list, padding_to=padding_to), self.device)
student_encoded_list = []
teacher_encoded_list = []
has_opsd = any(s.build_teacher_view() for s in samples)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

由于 any() 具有短路求值特性,一旦 samples 中的某个样本在调用 build_teacher_view() 时返回 True,后续的样本将不会再执行 build_teacher_view()。由于 build_teacher_view() 具有填充 self.teacher_messages 的副作用,这些被跳过的样本的 self.teacher_messages 将保持为 None,从而在后续的编码阶段引发错误。建议使用列表推导式 any([s.build_teacher_view() for s in samples]) 来强制评估所有元素。

Suggested change
has_opsd = any(s.build_teacher_view() for s in samples)
has_opsd = any([s.build_teacher_view() for s in samples])

Comment on lines 1690 to 1692
def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]:
"""Filters inputs to create model_inputs, removing GRPO-specific and template extra keys."""
excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS
return {k: v for k, v in inputs.items() if k not in excluded}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

在此次重构中,局部变量 excluded 的定义被删除了,但它仍然在 return 语句中被使用。这会导致在调用 _prepare_model_inputs 时立即引发 NameError: name 'excluded' is not defined 错误,从而导致训练中断。建议重新定义 excluded 变量。

Suggested change
def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]:
"""Filters inputs to create model_inputs, removing GRPO-specific and template extra keys."""
excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS
return {k: v for k, v in inputs.items() if k not in excluded}
def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]:
"""Filters inputs to create model_inputs, removing GRPO-specific and template extra keys."""
excluded = GRPO_NON_MODEL_KEYS | FILTERED_KEYS
return {k: v for k, v in inputs.items() if k not in excluded}

### Mode 3: Off-Policy 学习
- 触发条件:其他情况
- 数据来源:数据集中的标注响应
| `--global_batch_size` | 全局批次大小:`micro_batch_size × dp_size × gradient_accumulation_steps` |s

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

发现表格行末尾有一个多余的字符 s,这应该是一个拼写错误。

Comment thread swift/rl_core/data.py
Comment on lines +111 to +117
for key, value in row.items():
if key in field_names:
standard[key] = value
elif key == 'is_truncated':
continue # derived from finish_reason
else:
extra[key] = value

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

文档字符串中说明 row 中的值会被深拷贝(deep-copied)以确保样本拥有其数据,并防止在微步(micro-steps)之间共享引用导致数据损坏。然而,当前的实现并没有进行任何拷贝操作。建议按照设计意图对值进行深拷贝。

Suggested change
for key, value in row.items():
if key in field_names:
standard[key] = value
elif key == 'is_truncated':
continue # derived from finish_reason
else:
extra[key] = value
import copy
for key, value in row.items():
copied_value = copy.deepcopy(value)
if key in field_names:
standard[key] = copied_value
elif key == 'is_truncated':
continue # derived from finish_reason
else:
extra[key] = copied_value

Comment on lines +160 to +162
unique_indices = torch.tensor(sorted(seen.values()), device=device)
unique_request_ids = [request_ids[i] for i in unique_indices.cpu()]
unique_prompt_ids = [prompt_ids[i] for i in unique_indices.cpu()]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

使用 PyTorch 张量元素来索引 Python 列表(例如 request_ids[i],其中 i 是遍历 unique_indices.cpu() 得到的 0-D 张量)会带来不必要的性能开销,因为每次索引都需要将张量元素转换为 Python 整数。由于 sorted(seen.values()) 本身已经是一个 Python 整数列表,建议直接使用它来索引 Python 列表,仅在需要索引 PyTorch 张量 rewards 时才创建对应的设备张量。

Suggested change
unique_indices = torch.tensor(sorted(seen.values()), device=device)
unique_request_ids = [request_ids[i] for i in unique_indices.cpu()]
unique_prompt_ids = [prompt_ids[i] for i in unique_indices.cpu()]
unique_idxs_list = sorted(seen.values())
unique_request_ids = [request_ids[i] for i in unique_idxs_list]
unique_prompt_ids = [prompt_ids[i] for i in unique_idxs_list]
unique_indices = torch.tensor(unique_idxs_list, device=device)

Comment on lines +188 to +194
with self.accelerator.unwrap_model(self.model).disable_adapter() if is_peft_model(
self.model) and not self.ref_adapter_name else nullcontext():
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or 'default')

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

由于 null_ref_context 被移动到了 RLHFTrainerMixin 中,并被其他不一定定义了 ref_adapter_namemodel_adapter_name 的训练器(如 GRPOTrainer)所调用,直接访问 self.ref_adapter_name 可能会引发 AttributeError。建议使用 getattr 进行防御性保护。

Suggested change
with self.accelerator.unwrap_model(self.model).disable_adapter() if is_peft_model(
self.model) and not self.ref_adapter_name else nullcontext():
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or 'default')
ref_adapter_name = getattr(self, 'ref_adapter_name', None)
with self.accelerator.unwrap_model(self.model).disable_adapter() if is_peft_model(
self.model) and not ref_adapter_name else nullcontext():
if ref_adapter_name:
self.model.set_adapter(ref_adapter_name)
yield
if ref_adapter_name:
self.model.set_adapter(getattr(self, 'model_adapter_name', None) or 'default')

Comment thread swift/utils/__init__.py
Comment on lines +133 to +134
'parse_args',
'patch_getattr',

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

公共工具函数 parse_args_from_dict 在重构为延迟加载时,被遗漏在 _SUBMOD_ATTRS['utils'] 之外。这会导致任何从 swift.utils 导入 parse_args_from_dict 的外部代码引发 AttributeError。建议将其添加回 _SUBMOD_ATTRS['utils'] 中。

Suggested change
'parse_args',
'patch_getattr',
'parse_args',
'parse_args_from_dict',
'patch_getattr'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant