Skip to content

refactor(loss): migrate DPOLossConfig, DistillationLossConfig, DraftC…#2520

Open
NolenLiang wants to merge 4 commits into
mainfrom
nliang/typeddict-to-basemodel-loss
Open

refactor(loss): migrate DPOLossConfig, DistillationLossConfig, DraftC…#2520
NolenLiang wants to merge 4 commits into
mainfrom
nliang/typeddict-to-basemodel-loss

Conversation

@NolenLiang
Copy link
Copy Markdown
Contributor

…rossEntropyLossConfig to BaseModel

Convert 3 TypedDict loss config classes to pydantic BaseModel with extra="allow". Update dict-style access (cfg["key"]) to attribute access (cfg.key) in their init methods.

  • DPOLossConfig: 5 required fields, used by DPOLossFn
  • DistillationLossConfig: 3 required fields, used by DistillationLossFn
  • DraftCrossEntropyLossConfig: 1 optional field (arbitrary_types_allowed for ProcessGroup)

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

…rossEntropyLossConfig to BaseModel

Convert 3 TypedDict loss config classes to pydantic BaseModel with
extra="allow". Update dict-style access (cfg["key"]) to attribute
access (cfg.key) in their __init__ methods.

- DPOLossConfig: 5 required fields, used by DPOLossFn
- DistillationLossConfig: 3 required fields, used by DistillationLossFn
- DraftCrossEntropyLossConfig: 1 optional field (arbitrary_types_allowed
  for ProcessGroup)

Signed-off-by: nliang <nliang@nvidia.com>
@NolenLiang NolenLiang requested a review from a team as a code owner May 18, 2026 14:30
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@NolenLiang NolenLiang added the CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) label May 18, 2026
@NolenLiang
Copy link
Copy Markdown
Contributor Author

/ok to test c340a3b

Callers (dpo.py, tests) pass DPOConfig dict or plain dict to
DPOLossFn, not DPOLossConfig BaseModel. Add isinstance check to
auto-convert dict to BaseModel, maintaining backward compatibility.

Same fix for DistillationLossFn.

Signed-off-by: nliang <nliang@nvidia.com>
@NolenLiang
Copy link
Copy Markdown
Contributor Author

/ok to test b5427ff

@NolenLiang
Copy link
Copy Markdown
Contributor Author

/ok to test b5427ff

@NolenLiang NolenLiang added CI:L1 Run doctests, unit tests, and functional tests and removed CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) labels May 19, 2026
Add defaults to DPOLossConfig and DistillationLossConfig fields
matching the reference configs (dpo.yaml, distillation_math.yaml).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: nliang <nliang@nvidia.com>
@NolenLiang
Copy link
Copy Markdown
Contributor Author

/ok to test 857f01e

@NolenLiang
Copy link
Copy Markdown
Contributor Author

/ok to test 5178546

Comment on lines +865 to +867
def __init__(self, cfg: DPOLossConfig | dict, use_linear_ce_fusion: bool = False):
if isinstance(cfg, dict):
cfg = DPOLossConfig(**cfg)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not use this tricky way and fix the places that fail because of this.

Suggested change
def __init__(self, cfg: DPOLossConfig | dict, use_linear_ce_fusion: bool = False):
if isinstance(cfg, dict):
cfg = DPOLossConfig(**cfg)
def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False):

Copy link
Copy Markdown
Contributor Author

@NolenLiang NolenLiang May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On this branch dpo.py is not modified — DPOLossFn(master_config.dpo) at dpo.py:270 still passes a plain dict (DPOConfig is still a TypedDict here). Removing the guard now would break the L1 functional test. Will remove it once DPO PR #2524 merges.

Comment on lines +972 to +974
def __init__(self, cfg: DistillationLossConfig | dict):
if isinstance(cfg, dict):
cfg = DistillationLossConfig(**cfg)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Removed the guard and updated all callers in test_loss_functions.py and test_distillation.py to pass DistillationLossConfig(...) directly.

Comment on lines +28 to +30
class DraftCrossEntropyLossConfig(BaseModel, extra="allow"):
model_config = {"arbitrary_types_allowed": True}
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think let's not change this for now, actually DraftCrossEntropyLossConfig is not used and not set from the config.

Suggested change
class DraftCrossEntropyLossConfig(BaseModel, extra="allow"):
model_config = {"arbitrary_types_allowed": True}
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None
class DraftCrossEntropyLossConfig(TypedDict):
vocab_parallel_group: Optional[torch.distributed.ProcessGroup]

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Reverted to TypedDict since it is unused and not loaded from configuration files.

… DistillationLossFn guard

1. Revert DraftCrossEntropyLossConfig to TypedDict (unused, not loaded from config)
2. Remove isinstance(cfg, dict) guard from DistillationLossFn.__init__
   and update all callers to pass DistillationLossConfig directly
3. Keep DPOLossFn guard for now (dpo.py still passes dict on this branch;
   will remove after DPO PR #2524 merges)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: nliang <nliang@nvidia.com>
@NolenLiang NolenLiang force-pushed the nliang/typeddict-to-basemodel-loss branch from 5178546 to fd0af2d Compare May 20, 2026 05:14
@NolenLiang NolenLiang requested a review from a team as a code owner May 20, 2026 05:14
@NolenLiang
Copy link
Copy Markdown
Contributor Author

/ok to test fd0af2d

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants