Skip to content

Commit 3067930

Browse files
committed
[Enhance] Make ModelOutputs pydantic BaseModel
ghstack-source-id: 067a5e1 Pull-Request: InternLM#1516
1 parent 3d57b53 commit 3067930

5 files changed

Lines changed: 50 additions & 11 deletions

File tree

xtuner/v1/engine/train_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
293293
seq_ctx=seq_ctx_list,
294294
loss_ctx=loss_ctx_list,
295295
)
296+
output.free_nongrad_feature()
296297

297298
# llm loss has been global averaged
298299
llm_loss = output["loss"]

xtuner/v1/loss/base_loss_ctx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing_extensions import Self
1313

1414
from xtuner.v1.loss.utils import sp_split
15+
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
1516

1617
from .chunk_loss import ChunkLoss
1718

@@ -195,6 +196,10 @@ def forward(
195196
else:
196197
loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
197198

199+
# TODO: yanhuida, should be removed
200+
if not isinstance(extra_info, ModelForwardExtraLogInfo):
201+
extra_info = ModelForwardExtraLogInfo(extra_info)
202+
198203
extra_info["local_base_loss"] = loss.detach().clone()
199204

200205
# Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support

xtuner/v1/model/base.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,32 @@ def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
194194
]
195195

196196

197-
class ModelOutputs(TypedDict):
198-
hidden_states: NotRequired[list[torch.Tensor]]
199-
logits: NotRequired[torch.Tensor]
197+
class ModelOutputs(PydanticBaseModel):
198+
model_config = ConfigDict(arbitrary_types_allowed=True)
199+
hidden_states: list[torch.Tensor] | None = None
200+
logits: torch.Tensor | None = None
200201
loss: torch.Tensor
201-
extra_info: ModelForwardExtraLogInfo
202+
extra_info: ModelForwardExtraLogInfo | None = None
203+
204+
def free_nongrad_feature(self):
205+
"""Release large intermediate tensors not needed for backward or
206+
logging.
207+
208+
This method is called immediately after forward() in the micro-batch loop.
209+
It releases large tensors (logits, hidden_states) while keeping:
210+
- loss: needed for backward pass
211+
- extra_info: lightweight logging info needed by post_micro_batch_forward()
212+
"""
213+
self.hidden_states = None
214+
self.logits = None
215+
216+
# TODO: Only to avoid breaking backward compatibility (BC). Should be removed later.
217+
def __getitem__(self, key):
218+
return getattr(self, key)
219+
220+
# TODO: Only to avoid breaking backward compatibility (BC). Should be removed later.
221+
def __contains__(self, key):
222+
return key in self.model_fields_set
202223

203224

204225
def _is_float8_available():

xtuner/v1/model/dense/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def forward(
107107
output["loss"] = loss
108108
output["logits"] = logits
109109
output["extra_info"] = extra_info
110-
return ModelOutputs(**output) # type: ignore[typeddict-item]
110+
return ModelOutputs(**output)
111111

112112
def build_embeddings(self, config: TransformerConfig):
113113
return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

xtuner/v1/model/moe/moe.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,22 @@
7979

8080

8181
class MoEModelOutputs(ModelOutputs):
82-
router_logits: NotRequired[dict[str, torch.Tensor]]
83-
balancing_loss: NotRequired[torch.Tensor]
84-
z_loss: NotRequired[torch.Tensor]
85-
tokens_per_expert_global: NotRequired[torch.Tensor]
82+
router_logits: dict[str, torch.Tensor] | None = None
83+
balancing_loss: torch.Tensor | None = None
84+
z_loss: torch.Tensor | None = None
85+
tokens_per_expert_global: torch.Tensor
86+
87+
def free_nongrad_feature(self):
88+
"""Release large intermediate tensors not needed for backward or
89+
logging.
90+
91+
This method is called immediately after forward() in the micro-batch loop.
92+
It releases large tensors (logits, hidden_states) while keeping:
93+
- loss: needed for backward pass
94+
- extra_info: lightweight logging info needed by post_micro_batch_forward()
95+
"""
96+
super().free_nongrad_feature()
97+
self.router_logits = None
8698

8799

88100
class BalancingLossConfig(PydanticBaseModel):
@@ -482,7 +494,7 @@ def _micro_batch_forward(
482494

483495
output["router_logits"] = router_logits_dict
484496

485-
return MoEModelOutputs(**output, logits=logits) # type: ignore[typeddict-item]
497+
return MoEModelOutputs(**output, logits=logits)
486498

487499
def _forward(
488500
self,
@@ -583,7 +595,7 @@ def _forward(
583595
else:
584596
output["router_logits"] = None
585597

586-
return MoEModelOutputs(**output) # type: ignore[typeddict-item]
598+
return MoEModelOutputs(**output)
587599

588600
def build_embeddings(self, config: MoEConfig):
589601
return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

0 commit comments

Comments
 (0)