@@ -79,10 +79,22 @@
 
 
 class MoEModelOutputs(ModelOutputs):
-    router_logits: NotRequired[dict[str, torch.Tensor]]
-    balancing_loss: NotRequired[torch.Tensor]
-    z_loss: NotRequired[torch.Tensor]
-    tokens_per_expert_global: NotRequired[torch.Tensor]
+    router_logits: dict[str, torch.Tensor] | None = None
+    balancing_loss: torch.Tensor | None = None
+    z_loss: torch.Tensor | None = None
+    tokens_per_expert_global: torch.Tensor | None = None
+
+    def free_nongrad_feature(self):
+        """Release large intermediate tensors not needed for backward or
+        logging.
+
+        This method is called immediately after forward() in the micro-batch loop.
+        It releases large tensors (logits, hidden_states) while keeping:
+        - loss: needed for the backward pass
+        - extra_info: lightweight logging info needed by post_micro_batch_forward()
+        """
+        super().free_nongrad_feature()
+        self.router_logits = None
 
 
 class BalancingLossConfig(PydanticBaseModel):
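Since the fields above now carry `| None = None` defaults, `ModelOutputs` is evidently a regular class (e.g. a dataclass) rather than a TypedDict, which is also why the `# type: ignore[typeddict-item]` comments on the constructor calls in the hunks below can be dropped. A minimal sketch of what the base class plausibly provides, assuming a dataclass and taking the field names (loss, logits, hidden_states, extra_info) from the docstring above:

```python
from dataclasses import dataclass
from typing import Any

import torch


# Sketch only: the real base class may differ. Field names follow the
# free_nongrad_feature docstring (loss, logits, hidden_states, extra_info).
@dataclass
class ModelOutputs:
    loss: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    hidden_states: torch.Tensor | None = None
    extra_info: dict[str, Any] | None = None

    def free_nongrad_feature(self):
        # Keep loss (needed for backward) and extra_info (needed for logging);
        # drop the large tensors.
        self.logits = None
        self.hidden_states = None
```

Under that reading, the subclass override first delegates to the base class and then additionally clears `router_logits`, the one large MoE-specific tensor.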
@@ -482,7 +494,7 @@ def _micro_batch_forward(
 
         output["router_logits"] = router_logits_dict
 
-        return MoEModelOutputs(**output, logits=logits)  # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output, logits=logits)
 
     def _forward(
         self,
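The hunk above returns from `_micro_batch_forward`; per the docstring, the caller invokes `free_nongrad_feature()` right after the forward pass. A hedged sketch of that loop, where everything except `free_nongrad_feature()` and `post_micro_batch_forward()` (both named in the docstring) is an illustrative assumption, not code from this PR:

```python
# Hypothetical trainer-side usage. The loop structure, `model`, and the
# gradient-accumulation scaling are assumptions for illustration only.
def run_micro_batches(model, micro_batches, post_micro_batch_forward):
    for micro_batch in micro_batches:
        outputs = model(**micro_batch)        # -> MoEModelOutputs
        outputs.free_nongrad_feature()        # drop logits / router_logits early
        loss = outputs.loss / len(micro_batches)  # scale for accumulation
        loss.backward()                       # loss was kept for this
        post_micro_batch_forward(outputs)     # reads the lightweight extra_info
```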
@@ -583,7 +595,7 @@ def _forward(
         else:
             output["router_logits"] = None
 
-        return MoEModelOutputs(**output)  # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output)
 
     def build_embeddings(self, config: MoEConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
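One note on `build_embeddings` in the hunk above: `nn.Embedding`'s third positional parameter is `padding_idx`, so passing `config.pad_token_id` positionally zero-initializes the pad row and keeps its gradient from accumulating. A small self-contained check (the toy sizes are illustrative):

```python
import torch
import torch.nn as nn

# nn.Embedding(num_embeddings, embedding_dim, padding_idx): the row at
# padding_idx starts as all zeros and its gradient is not accumulated.
emb = nn.Embedding(10, 4, 0)      # vocab_size=10, hidden_size=4, pad_token_id=0
tokens = torch.tensor([[0, 3, 5]])
print(emb(tokens)[0, 0])          # zeros: the embedding of the pad token
```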