refactor(loss): migrate DPOLossConfig, DistillationLossConfig, DraftC…#2520
refactor(loss): migrate DPOLossConfig, DistillationLossConfig, DraftC…#2520NolenLiang wants to merge 4 commits into
Conversation
…rossEntropyLossConfig to BaseModel Convert 3 TypedDict loss config classes to pydantic BaseModel with extra="allow". Update dict-style access (cfg["key"]) to attribute access (cfg.key) in their __init__ methods. - DPOLossConfig: 5 required fields, used by DPOLossFn - DistillationLossConfig: 3 required fields, used by DistillationLossFn - DraftCrossEntropyLossConfig: 1 optional field (arbitrary_types_allowed for ProcessGroup) Signed-off-by: nliang <nliang@nvidia.com>
|
/ok to test c340a3b |
Callers (dpo.py, tests) pass DPOConfig dict or plain dict to DPOLossFn, not DPOLossConfig BaseModel. Add isinstance check to auto-convert dict to BaseModel, maintaining backward compatibility. Same fix for DistillationLossFn. Signed-off-by: nliang <nliang@nvidia.com>
|
/ok to test b5427ff |
|
/ok to test b5427ff |
Add defaults to DPOLossConfig and DistillationLossConfig fields matching the reference configs (dpo.yaml, distillation_math.yaml). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: nliang <nliang@nvidia.com>
|
/ok to test 857f01e |
|
/ok to test 5178546 |
| def __init__(self, cfg: DPOLossConfig | dict, use_linear_ce_fusion: bool = False): | ||
| if isinstance(cfg, dict): | ||
| cfg = DPOLossConfig(**cfg) |
There was a problem hiding this comment.
let's not use this tricky way and fix the places that fail because of this.
| def __init__(self, cfg: DPOLossConfig | dict, use_linear_ce_fusion: bool = False): | |
| if isinstance(cfg, dict): | |
| cfg = DPOLossConfig(**cfg) | |
| def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False): |
There was a problem hiding this comment.
On this branch dpo.py is not modified — DPOLossFn(master_config.dpo) at dpo.py:270 still passes a plain dict (DPOConfig is still a TypedDict here). Removing the guard now would break the L1 functional test. Will remove it once DPO PR #2524 merges.
| def __init__(self, cfg: DistillationLossConfig | dict): | ||
| if isinstance(cfg, dict): | ||
| cfg = DistillationLossConfig(**cfg) |
There was a problem hiding this comment.
There was a problem hiding this comment.
Done. Removed the guard and updated all callers in test_loss_functions.py and test_distillation.py to pass DistillationLossConfig(...) directly.
| class DraftCrossEntropyLossConfig(BaseModel, extra="allow"): | ||
| model_config = {"arbitrary_types_allowed": True} | ||
| vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None |
There was a problem hiding this comment.
I think let's not change this for now, actually DraftCrossEntropyLossConfig is not used and not set from the config.
| class DraftCrossEntropyLossConfig(BaseModel, extra="allow"): | |
| model_config = {"arbitrary_types_allowed": True} | |
| vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None | |
| class DraftCrossEntropyLossConfig(TypedDict): | |
| vocab_parallel_group: Optional[torch.distributed.ProcessGroup] |
There was a problem hiding this comment.
Done. Reverted to TypedDict since it is unused and not loaded from configuration files.
… DistillationLossFn guard 1. Revert DraftCrossEntropyLossConfig to TypedDict (unused, not loaded from config) 2. Remove isinstance(cfg, dict) guard from DistillationLossFn.__init__ and update all callers to pass DistillationLossConfig directly 3. Keep DPOLossFn guard for now (dpo.py still passes dict on this branch; will remove after DPO PR #2524 merges) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: nliang <nliang@nvidia.com>
5178546 to
fd0af2d
Compare
|
/ok to test fd0af2d |
…rossEntropyLossConfig to BaseModel
Convert 3 TypedDict loss config classes to pydantic BaseModel with extra="allow". Update dict-style access (cfg["key"]) to attribute access (cfg.key) in their init methods.
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information