
Commit 5ad84d7

[Enhance] Make ModelOutputs pydantic BaseModel

ghstack-source-id: 2ad7bfb
Pull-Request: InternLM#1516

1 parent 3d57b53 commit 5ad84d7

6 files changed: 53 additions & 13 deletions

File tree
  xtuner/v1/engine/train_engine.py
  xtuner/v1/loss/base_loss_ctx.py
  xtuner/v1/model/base.py
  xtuner/v1/model/dense/dense.py
  xtuner/v1/model/moe/moe.py
  xtuner/v1/utils/internal_metrics.py

xtuner/v1/engine/train_engine.py

Lines changed: 1 addition & 0 deletions

@@ -293,6 +293,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
             seq_ctx=seq_ctx_list,
             loss_ctx=loss_ctx_list,
         )
+        output.free_nongrad_feature()
 
         # llm loss has been global averaged
         llm_loss = output["loss"]
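
For context, a hedged sketch of what calling free_nongrad_feature() right after forward buys: once the output object drops its reference to a large, non-grad tensor (and no other reference survives), that memory can be reclaimed before backward runs. The Output class and tensors below are illustrative stand-ins, not the engine's actual code.

# Minimal stand-in mirroring the timing of the change above: forward produces an
# output, free_nongrad_feature() is called immediately, then backward still works.
import weakref
import torch


class Output:
    def __init__(self, loss: torch.Tensor, logits: torch.Tensor):
        self.loss = loss
        self.logits = logits

    def free_nongrad_feature(self):
        self.logits = None  # drop only the large tensor not needed for backward


hidden = torch.randn(2, 16, requires_grad=True)
loss = hidden.sum()                    # keeps its autograd graph alive
logits = torch.empty(2, 16, 50_000)    # large, detached, logging-only stand-in
ref = weakref.ref(logits)

output = Output(loss=loss, logits=logits)
del logits                             # the output now holds the last reference
output.free_nongrad_feature()          # called right after "forward", as in the diff
print(ref() is None)                   # True: the big tensor is collectible early
output.loss.backward()                 # loss and its graph are untouched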

xtuner/v1/loss/base_loss_ctx.py

Lines changed: 6 additions & 0 deletions

@@ -186,6 +186,8 @@ def forward(
         head_weight: torch.Tensor,
         head_bias: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
+        from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
+
         assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
         if head_bias is not None:
             raise NotImplementedError("Loss does not support head_bias yet.")
@@ -195,6 +197,10 @@ def forward(
         else:
             loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
 
+        # TODO: yanhuida, should be removed
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+
         extra_info["local_base_loss"] = loss.detach().clone()
 
         # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support

xtuner/v1/model/base.py

Lines changed: 25 additions & 4 deletions

@@ -194,11 +194,32 @@ def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
         ]
 
 
-class ModelOutputs(TypedDict):
-    hidden_states: NotRequired[list[torch.Tensor]]
-    logits: NotRequired[torch.Tensor]
+class ModelOutputs(PydanticBaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    hidden_states: list[torch.Tensor] | None = None
+    logits: torch.Tensor | None = None
     loss: torch.Tensor
-    extra_info: ModelForwardExtraLogInfo
+    extra_info: ModelForwardExtraLogInfo | None = None
+
+    def free_nongrad_feature(self):
+        """Release large intermediate tensors not needed for backward or
+        logging.
+
+        This method is called immediately after forward() in the micro-batch loop.
+        It releases large tensors (logits, hidden_states) while keeping:
+        - loss: needed for backward pass
+        - extra_info: lightweight logging info needed by post_micro_batch_forward()
+        """
+        self.hidden_states = None
+        self.logits = None
+
+    # TODO: Only for avoid BC. Should be removed later.
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    # TODO: Only for avoid BC. Should be removed later.
+    def __contains__(self, key):
+        return key in self.model_fields_set
 
 
 def _is_float8_available():
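
A minimal, self-contained sketch of how the new pydantic-backed ModelOutputs behaves. The class body below mirrors the diff, but extra_info is simplified to a plain dict (the real field is ModelForwardExtraLogInfo) and the tensors are purely illustrative.

import torch
from pydantic import BaseModel as PydanticBaseModel, ConfigDict


class ModelOutputs(PydanticBaseModel):
    # arbitrary_types_allowed lets pydantic accept torch.Tensor fields
    model_config = ConfigDict(arbitrary_types_allowed=True)
    hidden_states: list[torch.Tensor] | None = None
    logits: torch.Tensor | None = None
    loss: torch.Tensor
    extra_info: dict | None = None  # simplified stand-in for ModelForwardExtraLogInfo

    def free_nongrad_feature(self):
        # Release large tensors; keep loss (backward) and extra_info (logging).
        self.hidden_states = None
        self.logits = None

    # BC shims so existing dict-style call sites keep working.
    def __getitem__(self, key):
        return getattr(self, key)

    def __contains__(self, key):
        return key in self.model_fields_set


output = ModelOutputs(loss=torch.tensor(1.25), logits=torch.randn(2, 8, 32))
print(output["loss"])          # dict-style read still works via __getitem__
print("logits" in output)      # True: only explicitly set fields are "contained"
output.free_nongrad_feature()  # drops logits/hidden_states, keeps loss
print(output.logits)           # None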

xtuner/v1/model/dense/dense.py

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ def forward(
         output["loss"] = loss
         output["logits"] = logits
         output["extra_info"] = extra_info
-        return ModelOutputs(**output)  # type: ignore[typeddict-item]
+        return ModelOutputs(**output)
 
     def build_embeddings(self, config: TransformerConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

xtuner/v1/model/moe/moe.py

Lines changed: 18 additions & 6 deletions

@@ -79,10 +79,22 @@
 
 
 class MoEModelOutputs(ModelOutputs):
-    router_logits: NotRequired[dict[str, torch.Tensor]]
-    balancing_loss: NotRequired[torch.Tensor]
-    z_loss: NotRequired[torch.Tensor]
-    tokens_per_expert_global: NotRequired[torch.Tensor]
+    router_logits: dict[str, torch.Tensor] | None = None
+    balancing_loss: torch.Tensor | None = None
+    z_loss: torch.Tensor | None = None
+    tokens_per_expert_global: torch.Tensor
+
+    def free_nongrad_feature(self):
+        """Release large intermediate tensors not needed for backward or
+        logging.
+
+        This method is called immediately after forward() in the micro-batch loop.
+        It releases large tensors (logits, hidden_states) while keeping:
+        - loss: needed for backward pass
+        - extra_info: lightweight logging info needed by post_micro_batch_forward()
+        """
+        super().free_nongrad_feature()
+        self.router_logits = None
 
 
 class BalancingLossConfig(PydanticBaseModel):
@@ -482,7 +494,7 @@ def _micro_batch_forward(
 
         output["router_logits"] = router_logits_dict
 
-        return MoEModelOutputs(**output, logits=logits)  # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output, logits=logits)
 
     def _forward(
         self,
@@ -583,7 +595,7 @@ def _forward(
         else:
             output["router_logits"] = None
 
-        return MoEModelOutputs(**output)  # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output)
 
     def build_embeddings(self, config: MoEConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

xtuner/v1/utils/internal_metrics.py

Lines changed: 2 additions & 2 deletions

@@ -198,7 +198,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):
 
         if (
             self.internal_metrics_cfg.monitor_moe_load_balance_stats
-            and (cur_tokens_per_expert := output.get("tokens_per_expert_global")) is not None
+            and (cur_tokens_per_expert := output.tokens_per_expert_global) is not None
         ):
             # At this point, tokens_per_expert_global is already all-reduced into current rank.
             # [num_layers, num_experts]
@@ -209,7 +209,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):
 
         if (
             self.internal_metrics_cfg.monitor_moe_router_logits_stats
-            and (cur_router_logits := output.get("router_logits")) is not None
+            and (cur_router_logits := output.router_logits) is not None
        ):
             for layer_name, router_logits in cur_router_logits.items():
                 # [bsz, packed_len, num_experts]
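
These call sites switch from .get(...) to attribute access because a pydantic BaseModel does not expose dict-style .get(). A small hedged sketch, using a cut-down stand-in for MoEModelOutputs rather than the real class:

import torch
from pydantic import BaseModel, ConfigDict


class MoEModelOutputsSketch(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    loss: torch.Tensor
    router_logits: dict[str, torch.Tensor] | None = None


output = MoEModelOutputsSketch(loss=torch.tensor(0.7))
# Old TypedDict-era style: output.get("router_logits") -- no such method on pydantic.
if (cur_router_logits := output.router_logits) is not None:
    for layer_name, router_logits in cur_router_logits.items():
        print(layer_name, router_logits.shape)
else:
    print("no router logits recorded")  # taken here: field left at its None default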
