1 change: 1 addition & 0 deletions xtuner/v1/engine/train_engine.py
@@ -275,6 +275,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
             seq_ctx=seq_ctx_list,
             loss_ctx=loss_ctx_list,
         )
+        output.free_nongrad_feature()

         # llm loss has been global averaged
         llm_loss = output["loss"]
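The single added line frees the forward output's large tensors before the backward pass. Below is a minimal sketch of the intended call pattern; `run_micro_batch`, `model`, `seq_ctx_list`, and `loss_ctx_list` are illustrative stand-ins, and the real `train_step` additionally handles gradient accumulation, logging, and multiple micro-batches.

```python
# Hypothetical wrapper around the pattern above; the real train_step does more.
import torch


def run_micro_batch(model, seq_ctx_list, loss_ctx_list) -> torch.Tensor:
    output = model(
        seq_ctx=seq_ctx_list,
        loss_ctx=loss_ctx_list,
    )
    # Drop logits / hidden_states right away: they are not needed for backward,
    # so releasing them here lowers peak memory for the rest of the step.
    output.free_nongrad_feature()

    # llm loss has been global averaged
    llm_loss = output["loss"]
    llm_loss.backward()
    return llm_loss.detach()
```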
6 changes: 6 additions & 0 deletions xtuner/v1/loss/base_loss_ctx.py
Expand Up @@ -186,6 +186,8 @@ def forward(
head_weight: torch.Tensor,
head_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo

assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
if head_bias is not None:
raise NotImplementedError("Loss does not support head_bias yet.")
@@ -195,6 +197,10 @@
         else:
             loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)

+        # TODO: yanhuida, should be removed
+        if not isinstance(extra_info, ModelForwardExtraLogInfo):
+            extra_info = ModelForwardExtraLogInfo(extra_info)
+
         extra_info["local_base_loss"] = loss.detach().clone()

         # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
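The new guard ensures `extra_info` is always a `ModelForwardExtraLogInfo` before `local_base_loss` is attached to it. The sketch below illustrates that normalization pattern with a stand-in class; the real `ModelForwardExtraLogInfo` lives in `xtuner.v1.model.utils.misc` and its actual definition may differ.

```python
import torch


# Stand-in only: the real ModelForwardExtraLogInfo is defined in
# xtuner.v1.model.utils.misc; it is modelled here as a dict subclass just to
# show what the new isinstance guard accomplishes.
class ModelForwardExtraLogInfo(dict):
    """Lightweight container for per-forward logging values (assumed shape)."""


def attach_base_loss(loss: torch.Tensor, extra_info) -> "ModelForwardExtraLogInfo":
    # Some loss paths still return a plain dict; wrap it so downstream logging
    # code can rely on a single type.
    if not isinstance(extra_info, ModelForwardExtraLogInfo):
        extra_info = ModelForwardExtraLogInfo(extra_info)
    extra_info["local_base_loss"] = loss.detach().clone()
    return extra_info


print(attach_base_loss(torch.tensor(1.5), {"num_tokens": 128}))
```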
29 changes: 25 additions & 4 deletions xtuner/v1/model/base.py
@@ -196,11 +196,32 @@ def layers_type(self) -> list[Literal["full_attention", "sliding_attention", "li
         ]


-class ModelOutputs(TypedDict):
-    hidden_states: NotRequired[list[torch.Tensor]]
-    logits: NotRequired[torch.Tensor]
+class ModelOutputs(PydanticBaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    hidden_states: list[torch.Tensor] | None = None
+    logits: torch.Tensor | None = None
     loss: torch.Tensor
-    extra_info: ModelForwardExtraLogInfo
+    extra_info: ModelForwardExtraLogInfo | None = None
+
+    def free_nongrad_feature(self):
+        """Release large intermediate tensors not needed for backward or
+        logging.
+
+        This method is called immediately after forward() in the micro-batch loop.
+        It releases large tensors (logits, hidden_states) while keeping:
+        - loss: needed for backward pass
+        - extra_info: lightweight logging info needed by post_micro_batch_forward()
+        """
+        self.hidden_states = None
+        self.logits = None
+
+    # TODO: Only for avoid BC. Should be removed later.
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    # TODO: Only for avoid BC. Should be removed later.
+    def __contains__(self, key):
+        return key in self.model_fields_set


 def _is_float8_available():
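Moving `ModelOutputs` from a `TypedDict` to a Pydantic model is what allows it to carry the `free_nongrad_feature` method (a `TypedDict` is a plain `dict` at runtime and cannot hold behavior), while `__getitem__` and `__contains__` keep existing dict-style call sites working. A small usage sketch with illustrative shapes and values:

```python
import torch

from xtuner.v1.model.base import ModelOutputs

out = ModelOutputs(
    loss=torch.tensor(2.31, requires_grad=True),
    logits=torch.randn(1, 16, 1000),  # tiny here; [bsz, seq_len, vocab] in training
    extra_info=None,
)

assert out["loss"] is out.loss  # __getitem__ keeps dict-style access working
assert "logits" in out          # __contains__ checks fields that were explicitly set
assert "hidden_states" not in out

out.free_nongrad_feature()      # drops the references to logits / hidden_states
assert out.logits is None and out.loss is not None
```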
2 changes: 1 addition & 1 deletion xtuner/v1/model/dense/dense.py
@@ -115,7 +115,7 @@ def forward(
         output["loss"] = loss
         output["logits"] = logits
         output["extra_info"] = extra_info
-        return ModelOutputs(**output) # type: ignore[typeddict-item]
+        return ModelOutputs(**output)

     def build_embeddings(self, config: TransformerConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
24 changes: 18 additions & 6 deletions xtuner/v1/model/moe/moe.py
@@ -82,10 +82,22 @@


 class MoEModelOutputs(ModelOutputs):
-    router_logits: NotRequired[dict[str, torch.Tensor]]
-    balancing_loss: NotRequired[torch.Tensor]
-    z_loss: NotRequired[torch.Tensor]
-    tokens_per_expert_global: NotRequired[torch.Tensor]
+    router_logits: dict[str, torch.Tensor] | None = None
+    balancing_loss: torch.Tensor | None = None
+    z_loss: torch.Tensor | None = None
+    tokens_per_expert_global: torch.Tensor
+
+    def free_nongrad_feature(self):
+        """Release large intermediate tensors not needed for backward or
+        logging.
+
+        This method is called immediately after forward() in the micro-batch loop.
+        It releases large tensors (logits, hidden_states) while keeping:
+        - loss: needed for backward pass
+        - extra_info: lightweight logging info needed by post_micro_batch_forward()
+        """
+        super().free_nongrad_feature()
+        self.router_logits = None


 class BalancingLossConfig(PydanticBaseModel):
@@ -486,7 +498,7 @@ def _micro_batch_forward(

         output["router_logits"] = router_logits_dict

-        return MoEModelOutputs(**output, logits=logits) # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output, logits=logits)

     def _forward(
         self,
@@ -587,7 +599,7 @@ def _forward(
         else:
             output["router_logits"] = None

-        return MoEModelOutputs(**output) # type: ignore[typeddict-item]
+        return MoEModelOutputs(**output)

     def build_embeddings(self, config: MoEConfig):
         return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
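The MoE output overrides `free_nongrad_feature` so the per-layer router logits are released together with the logits and hidden states, while the loss (needed for backward) and the load-balancing statistics remain available. A hedged usage sketch with made-up shapes:

```python
import torch

from xtuner.v1.model.moe.moe import MoEModelOutputs

out = MoEModelOutputs(
    loss=torch.tensor(2.05, requires_grad=True),
    logits=torch.randn(1, 8, 100),                    # [bsz, packed_len, vocab] in practice
    router_logits={"layers.0": torch.randn(8, 64)},   # per layer: [tokens, num_experts]
    balancing_loss=torch.tensor(0.01),
    z_loss=torch.tensor(0.002),
    tokens_per_expert_global=torch.zeros(24, 64),     # [num_layers, num_experts]
)

out.free_nongrad_feature()

# The large activations are released ...
assert out.logits is None and out.hidden_states is None and out.router_logits is None
# ... while the loss for backward and the load-balance stats survive.
assert out.loss is not None and out.tokens_per_expert_global is not None
```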
4 changes: 2 additions & 2 deletions xtuner/v1/utils/internal_metrics.py
@@ -198,7 +198,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):

         if (
             self.internal_metrics_cfg.monitor_moe_load_balance_stats
-            and (cur_tokens_per_expert := output.get("tokens_per_expert_global")) is not None
+            and (cur_tokens_per_expert := output.tokens_per_expert_global) is not None
         ):
             # At this point, tokens_per_expert_global is already all-reduced into current rank.
             # [num_layers, num_experts]
@@ -209,7 +209,7 @@

         if (
             self.internal_metrics_cfg.monitor_moe_router_logits_stats
-            and (cur_router_logits := output.get("router_logits")) is not None
+            and (cur_router_logits := output.router_logits) is not None
         ):
             for layer_name, router_logits in cur_router_logits.items():
                 # [bsz, packed_len, num_experts]
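With `ModelOutputs` now a Pydantic model, the metrics code reads optional fields by attribute instead of `dict.get`. A standalone sketch of the same walrus-operator guard; `cfg`, `output`, and `log` are hypothetical stand-ins for `self.internal_metrics_cfg`, the per-step `MoEModelOutputs`, and whatever sink `pop_metrics` writes to:

```python
# cfg, output and log are hypothetical stand-ins; see lead-in above.
def maybe_log_router_stats(cfg, output, log) -> None:
    if (
        cfg.monitor_moe_router_logits_stats
        and (cur_router_logits := output.router_logits) is not None
    ):
        for layer_name, router_logits in cur_router_logits.items():
            # [bsz, packed_len, num_experts]
            log(f"{layer_name}/router_logits_max", router_logits.max().item())
```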