
Commit 5241ddf

[Enhance] Make ModelOutputs pydantic BaseModel
1 parent 3b9c8d4 commit 5241ddf

5 files changed

Lines changed: 38 additions & 12 deletions
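
In short: this commit converts ModelOutputs (and its MoE subclass) from a TypedDict into a pydantic BaseModel, adds a free() method for explicitly dropping tensor references, and keeps dict-style access working through two TODO-marked compatibility shims. It also adjusts the Claude review workflow so fork PRs still resolve a PR number.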

File tree
- .github/workflows/claude.yml
- xtuner/v1/loss/base_loss_ctx.py
- xtuner/v1/model/base.py
- xtuner/v1/model/dense/dense.py
- xtuner/v1/model/moe/moe.py

.github/workflows/claude.yml

Lines changed: 4 additions & 1 deletion
```diff
@@ -9,6 +9,9 @@ on:
     types: [opened, assigned]
   pull_request_review:
     types: [submitted]
+  pull_request_target:
+    types: [opened, synchronize]
+    branches: [main]
 
 jobs:
   claude:
@@ -38,7 +41,7 @@ jobs:
           # Prompt A workaround for claude code action bug of `Fork` PR
           prompt: |
             REPO: ${{ github.repository }}
-            PR NUMBER: ${{ github.event.pull_request.number }}
+            PR NUMBER: ${{ github.event.pull_request.number || github.event.issue.number}}
 
             Please review this pull request.
 
```
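
Note: in GitHub Actions expressions, `||` evaluates to its first truthy operand, so the prompt now falls back to the issue number when the triggering event carries no `pull_request` payload. A minimal Python sketch of the same fallback; the `event` dicts below are hypothetical stand-ins for the `${{ github.event }}` webhook payload:

```python
# Hypothetical webhook payloads, illustrating the two event shapes.
pr_event = {"pull_request": {"number": 123}}
issue_event = {"issue": {"number": 456}}  # e.g. a comment-driven trigger on a fork PR

def pr_number(event: dict) -> int | None:
    # Mirrors `pull_request.number || issue.number`: first truthy value wins.
    return event.get("pull_request", {}).get("number") or event.get("issue", {}).get("number")

assert pr_number(pr_event) == 123
assert pr_number(issue_event) == 456
```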

xtuner/v1/loss/base_loss_ctx.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 from typing_extensions import Self
 
 from xtuner.v1.loss.utils import sp_split
+from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
 
 from .chunk_loss import ChunkLoss
 
@@ -195,6 +196,10 @@ def forward(
         else:
             loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
 
+        # TODO: yanhuida, should be removed
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+
         extra_info["local_base_loss"] = loss.detach().clone()
 
         # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
```
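
The TODO-guarded isinstance check is a transitional shim: some callers still hand the loss context a plain dict, so it is normalized into ModelForwardExtraLogInfo before `local_base_loss` is written. A minimal sketch of that normalize-at-the-boundary pattern, with a hypothetical stand-in class (only dict-style construction and item assignment are assumed from this diff; the real class lives in xtuner.v1.model.utils.misc):

```python
class ExtraLogInfo:
    """Hypothetical stand-in for ModelForwardExtraLogInfo."""

    def __init__(self, data: dict | None = None):
        self._data = dict(data or {})

    def __setitem__(self, key, value):
        self._data[key] = value

def normalize(extra_info) -> ExtraLogInfo:
    # Accept either the wrapper or a raw dict, as the TODO-guarded branch does.
    if not isinstance(extra_info, ExtraLogInfo):
        extra_info = ExtraLogInfo(extra_info)
    return extra_info

info = normalize({"step": 1})  # "step" is an illustrative key, not from the repo
info["local_base_loss"] = 0.25
```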

xtuner/v1/model/base.py

Lines changed: 18 additions & 4 deletions
```diff
@@ -194,11 +194,25 @@ def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
     ]
 
 
-class ModelOutputs(TypedDict):
-    hidden_states: NotRequired[list[torch.Tensor]]
-    logits: NotRequired[torch.Tensor]
+class ModelOutputs(PydanticBaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    hidden_states: list[torch.Tensor] | None = None
+    logits: torch.Tensor | None = None
     loss: torch.Tensor
-    extra_info: ModelForwardExtraLogInfo
+    extra_info: ModelForwardExtraLogInfo | None = None
+
+    def free(self):
+        self.hidden_states = None
+        self.logits = None
+        self.extra_info = None
+
+    # TODO: Only for avoid BC. Should be removed later.
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    # TODO: Only for avoid BC. Should be removed later.
+    def __contains__(self, key):
+        return key in self.model_fields_set
 
 
 def _is_float8_available():
```
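
The move from TypedDict to pydantic keeps old dict-style call sites working through the two shims while adding free() for releasing tensors once the outputs have been consumed. A runnable sketch of the new contract, assuming PydanticBaseModel is pydantic v2's BaseModel and substituting a plain dict for ModelForwardExtraLogInfo:

```python
import torch
from pydantic import BaseModel, ConfigDict

class ModelOutputs(BaseModel):
    # arbitrary_types_allowed lets pydantic carry torch.Tensor fields unvalidated.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    hidden_states: list[torch.Tensor] | None = None
    logits: torch.Tensor | None = None
    loss: torch.Tensor
    extra_info: dict | None = None  # stand-in for ModelForwardExtraLogInfo

    def free(self):
        self.hidden_states = None
        self.logits = None
        self.extra_info = None

    def __getitem__(self, key):  # BC shim: out["loss"] keeps working
        return getattr(self, key)

    def __contains__(self, key):  # BC shim: `"logits" in out` keeps working
        return key in self.model_fields_set

out = ModelOutputs(loss=torch.tensor(0.5), logits=torch.randn(2, 8))
assert out["loss"] is out.loss
assert "logits" in out              # set explicitly at construction
assert "hidden_states" not in out   # defaulted fields are not in model_fields_set
out.free()                          # drop tensor references after consumption
```

The `__contains__` shim mirrors the old NotRequired semantics: a key counts as present only if it was explicitly set, not merely defaulted.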

xtuner/v1/model/dense/dense.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -107,7 +107,7 @@ def forward(
         output["loss"] = loss
         output["logits"] = logits
         output["extra_info"] = extra_info
-        return ModelOutputs(**output)  # type: ignore[typeddict-item]
+        return ModelOutputs(**output)
 
     def build_embeddings(self, config: TransformerConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
```

xtuner/v1/model/moe/moe.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -79,10 +79,14 @@
 
 
 class MoEModelOutputs(ModelOutputs):
-    router_logits: NotRequired[dict[str, torch.Tensor]]
-    balancing_loss: NotRequired[torch.Tensor]
-    z_loss: NotRequired[torch.Tensor]
-    tokens_per_expert_global: NotRequired[torch.Tensor]
+    router_logits: dict[str, torch.Tensor] | None = None
+    balancing_loss: torch.Tensor | None = None
+    z_loss: torch.Tensor | None = None
+    tokens_per_expert_global: torch.Tensor
+
+    def free(self):
+        super().free()
+        self.router_logits = None
 
 
 class BalancingLossConfig(PydanticBaseModel):
@@ -482,7 +486,7 @@ def _micro_batch_forward(
 
         output["router_logits"] = router_logits_dict
 
-        return MoEModelOutputs(**output, logits=logits)  # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output, logits=logits)
 
     def _forward(
         self,
@@ -583,7 +587,7 @@ def _forward(
         else:
             output["router_logits"] = None
 
-        return MoEModelOutputs(**output)  # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output)
 
     def build_embeddings(self, config: MoEConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
```
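
Two details worth noting in the subclass: free() cooperates with the base class via super().free() and additionally clears router_logits, and tokens_per_expert_global is declared without a default, so pydantic treats it as a required field (it was NotRequired in the TypedDict version). A condensed sketch of the override pattern, under the same pydantic assumptions as above:

```python
import torch
from pydantic import BaseModel, ConfigDict

class Outputs(BaseModel):  # condensed stand-in for ModelOutputs
    model_config = ConfigDict(arbitrary_types_allowed=True)
    logits: torch.Tensor | None = None
    loss: torch.Tensor

    def free(self):
        self.logits = None

class MoEOutputs(Outputs):  # condensed stand-in for MoEModelOutputs
    router_logits: dict[str, torch.Tensor] | None = None
    tokens_per_expert_global: torch.Tensor  # no default, hence required

    def free(self):
        super().free()             # clear the base-class tensors first
        self.router_logits = None  # then the MoE-specific ones

out = MoEOutputs(
    loss=torch.tensor(0.1),
    logits=torch.randn(2, 8),
    router_logits={"layers.0": torch.randn(2, 4)},
    tokens_per_expert_global=torch.zeros(4),
)
out.free()
assert out.logits is None and out.router_logits is None
```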
