
Commit 3ae3fe6

[Refactor] Unify TrainEngine by moving model-specific logic to model layer
Previously, we had two separate train engines:
- `TrainEngine` for regular models
- `VisionComposeTrainEngine` for vision-language models

This duplication led to:
- Code maintenance overhead (242 lines of duplicated logic)
- Tight coupling between the engine and model-specific details
- Difficulty in extending to new model types

Changes:
- **Remove** `VisionComposeTrainEngine` entirely (242 lines deleted)
- **Add** `pre_micro_batch_forward()` and `post_micro_batch_forward()` hooks to `BaseModel`
  - `pre_micro_batch_forward()`: compute data batch statistics before the forward pass
  - `post_micro_batch_forward()`: aggregate micro-batch results and compute metrics
- **Unify** `TrainEngine` to handle all model types through the new hook system

- **BaseModel**:
  - Add `DataBatchInfo` and `BatchForwardInfo` TypedDicts for return types
  - Implement default `pre_micro_batch_forward()` to compute token statistics
  - Implement default `post_micro_batch_forward()` to aggregate losses and extra info
  - Add overload type hints for `__call__` to improve type inference
- **MoE Model**:
  - Override `post_micro_batch_forward()` to handle MoE-specific logic:
    - Compute maxvio for router load balancing
    - Update router bias based on expert load
  - Add `need_update_bias` property for cleaner code
  - Properly scale balancing_loss and z_loss by batch_size
- **ComposeModel**:
  - Override `pre_micro_batch_forward()` to compute image token statistics
  - Add `ComposeDataBatchInfo` with `step_consumed_img_tokens` field
- **TrainEngine**:
  - Simplify `train_step()` to delegate statistics to model hooks
  - Replace `LossLog` and `OtherLog` with a unified `TrainStepInfo`
  - Add `_get_total_loss()` to aggregate all losses (with a TODO for a future refactor)
  - Remove all model-specific branching logic
- **EngineConfig**:
  - Remove conditional logic for VisionComposeTrainEngine
  - Use a single TrainEngine build() path
- Update callers to use `TrainStepInfo` instead of separate `LossLog` and `OtherLog`
- Simplify hook signatures (from 2 params to 1)
- Remove conditional engine instantiation logic
- Replace `VisionComposeTrainEngine` imports with `TrainEngine`
- Update test assertions to use the new `TrainStepInfo` structure
- Remove TypeAdapter validation for deprecated types

Currently, `TrainEngine._get_total_loss()` aggregates losses by iterating through ModelOutputs fields. This is pragmatic but not ideal:
- **Pros**: avoids large-scale changes to model forward() logic
- **Cons**: the engine knows about loss field names (coupling)
- **Future**: the model should return total_loss directly (see the TODO comment)

`loss_ctx.batch_size` represents the full gradient-accumulation batch size, not intra_layer_micro_batch. It is correctly set in `CELossContext.build_batches()` and is used for scaling balancing_loss and z_loss.

The pre/post hooks provide clean extension points:
- Subclasses can override them to add model-specific logic
- Default implementations in BaseModel handle the common cases
- No conditional logic is needed in the engine layer

Benefits:
1. **Code Reduction**: -242 lines (VisionComposeTrainEngine removed)
2. **Better Separation of Concerns**: the engine focuses on training orchestration; models handle their own statistics
3. **Extensibility**: new model types can override the hooks without changing the engine
4. **Type Safety**: a unified TrainStepInfo with clear field definitions
5. **Maintainability**: a single engine implementation to maintain

Known issues:
- Loss reduction logic still needs clarification (minor issue, doesn't affect training)
- TODO added for a future refactor: move total_loss aggregation to the model layer
- All format issues (extra blank lines, class formatting) fixed

ghstack-source-id: 9ce752d
Pull-Request: #1518
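For orientation, the hook contract described above can be pictured roughly as follows. This is a minimal sketch inferred from the commit message and the `train_step()` diff below; the concrete fields of `DataBatchInfo` and `BatchForwardInfo`, and the default implementations, are assumptions for illustration rather than the actual xtuner code.

```python
# Minimal sketch of the pre/post micro-batch hooks, assuming simplified signatures.
# Field names below are illustrative; the real TypedDicts carry more statistics
# (extra_info, efficient_attn_ratio, reduced balancing/z losses, ...).
from typing_extensions import TypedDict

import torch


class DataBatchInfo(TypedDict):
    step_consumed_tokens: int  # assumed field: token count gathered before the forward pass


class BatchForwardInfo(TypedDict):
    reduced_llm_loss: float  # assumed field: LLM loss aggregated over micro-batches


class BaseModelSketch:
    def pre_micro_batch_forward(self, data_batches) -> DataBatchInfo:
        # Default hook: compute data-batch statistics before any forward pass.
        consumed = sum(int(item["seq_ctx"].mask.sum()) for item in data_batches)
        return {"step_consumed_tokens": consumed}

    def post_micro_batch_forward(self, micro_batch_results) -> BatchForwardInfo:
        # Default hook: aggregate the per-micro-batch outputs into step-level metrics.
        # The real implementation also all-reduces losses across data-parallel ranks.
        llm_loss = torch.stack([out["loss"].detach() for out in micro_batch_results]).sum()
        return {"reduced_llm_loss": llm_loss.item()}
```

A vision-compose model would override `pre_micro_batch_forward()` to add `step_consumed_img_tokens`, and the MoE model overrides `post_micro_batch_forward()` for maxvio and router-bias updates, as the commit message notes.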
1 parent da4e589 commit 3ae3fe6

17 files changed

Lines changed: 311 additions & 530 deletions

tests/engine/test_dense_train_engine.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def warmup_fn(x)
         seq_ctx = seq_ctx_list[0]
         loss_ctx = loss_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        loss_log = engine.train_step(engine_input)["logs_info"]
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()

tests/engine/test_moe_train_engine.py

Lines changed: 2 additions & 2 deletions
@@ -99,7 +99,7 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        loss_log = engine.train_step(engine_input)["logs_info"]
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()
@@ -190,7 +190,7 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        loss_log = engine.train_step(engine_input)["logs_info"]
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()

tests/engine/test_moe_train_engine_float8.py

Lines changed: 4 additions & 4 deletions
@@ -93,7 +93,7 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        loss_log = engine.train_step(engine_input)["logs_info"]
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()
@@ -171,7 +171,7 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        loss_log = engine.train_step(engine_input)["logs_info"]
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()
@@ -270,11 +270,11 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        logs_info = engine.train_step(engine_input)["logs_info"]
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()
-        losses.append(loss_log["reduced_llm_loss"])
+        losses.append(logs_info["reduced_llm_loss"])
         losses_ref = torch.tensor([2.41, 2.41, 2.47, 2.42, 2.44, 2.44, 2.42, 2.38, 2.31, 2.30])
         losses = torch.tensor(losses)
         self._check_loss_curve(losses, losses_ref)

tests/model/test_qwen3_tile_embedding.py

Lines changed: 3 additions & 4 deletions
@@ -19,7 +19,6 @@
 from xtuner.v1.loss.ce_loss import CELossConfig
 from xtuner.v1.config import FSDPConfig, LRConfig, AdamWConfig
 from xtuner.v1.engine.train_engine import TrainEngine
-from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from torch.optim.lr_scheduler import LambdaLR
 from xtuner.v1.utils import pad_to_max_length
 from xtuner.v1.utils.device import get_device
@@ -85,7 +84,7 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        engine.train_step(engine_input)
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()
@@ -116,7 +115,7 @@ def test_qwen3vl_tie_embedding(self, device, tp_size):
             cpu_offload=False,
             tp_size=tp_size
         )
-        engine = VisionComposeTrainEngine(
+        engine = TrainEngine(
             model_cfg=dense_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg
         )
         engine.from_hf(hf_path=QWEN3_VL_DENSE_PATH)
@@ -160,7 +159,7 @@ def warmup_fn(x)
         loss_ctx = loss_ctx_list[0]
         seq_ctx = seq_ctx_list[0]
         engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-        loss_log, _ = engine.train_step(engine_input)
+        engine.train_step(engine_input)
         grad_norm = engine.clip_grad_norm()
         engine.step_optimizer(grad_norm)
         lr_scheduler.step()

tests/train/test_trainer.py

Lines changed: 0 additions & 6 deletions
@@ -26,7 +26,6 @@
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
-from xtuner.v1.engine.train_engine import LossLog, OtherLog
 from xtuner.v1.loss import CELossConfig
 from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
@@ -647,8 +646,6 @@ def test_hooks_config(self):
         self.create_pg(DEVICE)
         checkpoint_function_call_times = 0
         train_step_function_call_times = 0
-        losslog_adapater = TypeAdapter(LossLog)
-        otherlog_adapter = TypeAdapter(OtherLog)

         def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
             nonlocal checkpoint_function_call_times
@@ -674,9 +671,6 @@ def __init__(self) -> None:
                 self.count = 0

             def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
-                losslog_adapater.validate_python(loss_log)
-                otherlog_adapter.validate_python(other_log)
-
                 assert self.trainer().cur_step == step
                 assert self.trainer().cur_epoch == epoch
                 assert self.trainer().total_step == total_step

xtuner/v1/engine/__init__.py

Lines changed: 1 addition & 5 deletions
@@ -1,13 +1,9 @@
 from xtuner.v1.engine.config import EngineConfig

-from .train_engine import LossLog, OtherLog, TrainEngine
-from .vision_compose_train_engine import VisionComposeTrainEngine
+from .train_engine import TrainEngine


 __all__ = [
     "TrainEngine",
     "EngineConfig",
-    "VisionComposeTrainEngine",
-    "LossLog",
-    "OtherLog",
 ]
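With the trimmed `__all__`, downstream code that imported the removed names needs updating. A small before/after sketch (hypothetical caller code, not from this diff):

```python
# Before (no longer works after this commit):
# from xtuner.v1.engine import VisionComposeTrainEngine, LossLog, OtherLog

# After: only the unified engine and its config are exported.
from xtuner.v1.engine import EngineConfig, TrainEngine
```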

xtuner/v1/engine/config.py

Lines changed: 1 addition & 6 deletions
@@ -4,9 +4,7 @@

 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.engine.train_engine import TrainEngine
-from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from xtuner.v1.model.base import BaseModel, ConfigDict
-from xtuner.v1.model.compose.base import BaseComposeConfig


 @runtime_checkable
@@ -27,7 +25,4 @@ class EngineConfig(PydanticBaseModel):
     model_cfg: ModelConfigProto

     def build(self):
-        if isinstance(self.model_cfg, BaseComposeConfig):
-            return VisionComposeTrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
-        else:
-            return TrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
+        return TrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
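With the conditional removed, building an engine no longer depends on the model config type. A hedged usage sketch (variable names hypothetical, field names taken from the diff above):

```python
# Whether model_cfg is a dense, MoE, or vision-compose config, the result is a TrainEngine;
# model-specific behaviour now lives in the model's pre/post micro-batch hooks.
engine_cfg = EngineConfig(model_cfg=model_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg)
engine = engine_cfg.build()
```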

xtuner/v1/engine/train_engine.py

Lines changed: 40 additions & 128 deletions
@@ -8,7 +8,6 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
-from pydantic import ConfigDict
 from safetensors import safe_open
 from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
@@ -21,42 +20,33 @@
 from torch.utils._foreach_utils import (
     _device_has_foreach_support,
 )
-from typing_extensions import NotRequired, TypedDict

 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto.sequence_context import SequenceContext
-from xtuner.v1.model.base import BaseModel, ModelItem, XTunerBaseModelConfig
-from xtuner.v1.model.utils import ModelForwardExtraLogInfo
-from xtuner.v1.module.router import NoAuxRouterConfig
+from xtuner.v1.model.base import (
+    BaseModel,
+    BatchForwardInfo,
+    DataBatchInfo,
+    ModelItem,
+    ModelOutputs,
+    XTunerBaseModelConfig,
+)
 from xtuner.v1.profiler.prober import ProberList
 from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory
 from xtuner.v1.utils.grad_norm import cal_grad_norm


+class TrainStepInfo(DataBatchInfo, BatchForwardInfo):
+    total_loss: float
+
+
 logger = get_logger()
 DEVICE = get_device()
 DEVICE_MODULE = get_torch_device_module()

 threading_lock = threading.Lock()


-class LossLog(TypedDict):
-    __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
-    local_loss: float
-    reduced_llm_loss: float
-    reduced_balancing_loss: NotRequired[float]
-    reduced_z_loss: NotRequired[float]
-
-
-class OtherLog(TypedDict):
-    __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
-    maxvio: NotRequired[float]
-    step_consumed_tokens: int
-    step_consumed_img_tokens: NotRequired[int]
-    extra_info: ModelForwardExtraLogInfo
-    efficient_attn_ratio: float
-
-
 class CPUThreadTaskCoordinator:
     def __init__(self, futures, callback):
         self.futures = futures
@@ -206,66 +196,36 @@ def grad_accumulation_steps(self, data_batches_len: int):
         intra_layer_micro_batch = self.intra_layer_micro_batch
         return data_batches_len // intra_layer_micro_batch

-    def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
+    def train_step(self, data_batches: list[ModelItem]) -> TrainStepInfo:
         """Perform a training step with the given data batches and mesh.

         Args:
             data_batches (List[Dict]): The input data batches for the training step.
         """
         self._maybe_precompute_float8_dynamic_scale_for_fsdp()

-        loss_log: LossLog = {}  # type: ignore[typeddict-item]
-        other_log: OtherLog = {}  # type: ignore[typeddict-item]
         intra_layer_micro_batch = self.intra_layer_micro_batch
         assert len(data_batches) % intra_layer_micro_batch == 0, (
             f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
         )
         iters_per_step = self.grad_accumulation_steps(len(data_batches))

-        moe_need_update_bias = (
-            isinstance(getattr(self.model_cfg, "router", None), NoAuxRouterConfig)
-            and self.model_cfg.router.router_bias_update_speed > 0
-        )
-        moe_need_log_maxvio = getattr(self.model_cfg, "router", None) is not None
-
-        if moe_need_log_maxvio:
-            tokens_per_expert_global_for_bias = torch.zeros(
-                self.model_cfg.num_hidden_layers - self.model_cfg.first_k_dense_replace,
-                self.model_cfg.n_routed_experts,
-                dtype=torch.int64,
-                device=DEVICE,
-            )
-
-        step_loss = torch.tensor(0.0, device=DEVICE)
-        step_llm_loss = torch.tensor(0.0, device=DEVICE)
-        step_balancing_loss: torch.Tensor | None = None
-        step_z_loss: torch.Tensor | None = None
-        step_consumed_tokens = torch.tensor(0, device=DEVICE)
-
         if self._count == 0:
             logger.info(f"grad_accumulation_steps: {iters_per_step}")
         self._count += 1

-        train_engine_extra_info = ModelForwardExtraLogInfo()
         micro_batch_iter = 0
-        efficient_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
-        total_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
+        micro_batch_results = []
+
+        data_batch_info = self.model.pre_micro_batch_forward(data_batches)
+        total_loss = torch.tensor(0.0, device=DEVICE)
+
         for i in range(0, len(data_batches), intra_layer_micro_batch):
             ProberList.set_micro_batch_iter(micro_batch_iter)
             micro_batch_iter += 1
             data_batch = data_batches[i : i + intra_layer_micro_batch]
-            seq_ctx_list = []
-            loss_ctx_list = []
-            for data in data_batch:
-                seq_ctx = data["seq_ctx"]
-                loss_ctx = data["loss_ctx"]
-                seq_ctx_list.append(seq_ctx)
-                loss_ctx_list.append(loss_ctx)
-                step_consumed_tokens += seq_ctx.mask.sum()
-
-                num_tokens = seq_ctx.cu_seq_lens_k[1:] - seq_ctx.cu_seq_lens_k[:-1]
-                efficient_forward_tokens += (num_tokens.long() ** 2).sum()
-                total_forward_tokens += (num_tokens.long().sum()) ** 2
+            seq_ctx_list = [i["seq_ctx"] for i in data_batch]
+            loss_ctx_list = [i["loss_ctx"] for i in data_batch]

             if self.intra_layer_micro_batch == 1:
                 output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
@@ -278,78 +238,16 @@ def train_step(data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
                 )
             output.free_nongrad_feature()

-            # llm loss has been global averaged
-            llm_loss = output["loss"]
-            step_llm_loss += llm_loss.detach().clone()
-
-            loss = llm_loss
-            if "extra_info" in output:
-                train_engine_extra_info.append(output["extra_info"])
-
-            if "balancing_loss" in output:
-                balancing_loss = output["balancing_loss"] / iters_per_step
-                loss = loss + balancing_loss
-                if step_balancing_loss is None:
-                    step_balancing_loss = balancing_loss
-                else:
-                    step_balancing_loss += balancing_loss
-
-            if "z_loss" in output:
-                z_loss = output["z_loss"] / iters_per_step
-                loss = loss + z_loss
+            micro_batch_results.append(output)

-                if step_z_loss is None:
-                    step_z_loss = z_loss
-                else:
-                    step_z_loss += z_loss
-
-            if moe_need_log_maxvio:
-                assert "tokens_per_expert_global" in output, "tokens_per_expert_global is required for bias update."
-                tokens_per_expert_global_for_bias += output["tokens_per_expert_global"]
-
-            del output
+            loss = self._get_total_loss(output)
             loss.backward()
+            total_loss += loss.detach()
             # call dump_forward_records after backward to record the recomputed activations
             ProberList.after_micro_iter_forward()
-            step_loss += loss.detach().clone()
-
-        if moe_need_log_maxvio:
-            avg_count_load = tokens_per_expert_global_for_bias.float().mean(1)
-            max_load_i, _ = torch.max(tokens_per_expert_global_for_bias, dim=1)
-            maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load
-            maxvio = maxvio_all_layers.mean()
-            if moe_need_update_bias:
-                self.model.update_bias(tokens_per_expert_global_for_bias, avg_count_load)  # type: ignore
-            other_log["maxvio"] = maxvio.item()
-
-        reduced_llm_loss = step_llm_loss
-        dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))
-
-        loss_log["local_loss"] = step_loss.item()
-        loss_log["reduced_llm_loss"] = reduced_llm_loss.item()
-        if step_balancing_loss is not None:
-            reduced_balancing_loss = step_balancing_loss
-            dist.all_reduce(reduced_balancing_loss.div_(dist.get_world_size()))
-            loss_log["reduced_balancing_loss"] = reduced_balancing_loss.item()
-        if step_z_loss is not None:
-            reduced_z_loss = step_z_loss
-            dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
-            loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["step_consumed_tokens"] = int(step_consumed_tokens.item())
-        other_log["extra_info"] = train_engine_extra_info
-        other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
-
-        extra_info = other_log.get("extra_info", {})  # type: ignore
-
-        # TODO: @duanyanhui `extra_info` should be redesigned.
-        if not isinstance(extra_info, ModelForwardExtraLogInfo):
-            extra_info = ModelForwardExtraLogInfo(extra_info)
-        loss_log.update(extra_info.get())
-
-        if "maxvio" in other_log:
-            loss_log["maxvio"] = other_log["maxvio"]  # type: ignore
-        loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]  # type: ignore
-        return loss_log, other_log
+
+        batch_forward_info = self.model.post_micro_batch_forward(micro_batch_results)
+        return TrainStepInfo(total_loss=total_loss.item(), **data_batch_info, **batch_forward_info)

     def from_hf(self, hf_path: str | Path, strict: bool = False):
         self.model.from_hf(hf_path=hf_path, strict=strict)
@@ -529,3 +427,17 @@ def _maybe_precompute_float8_dynamic_scale_for_fsdp(self):
         for model in self.model.modules():
             if isinstance(model, BaseModel) and model.float8_handler is not None:
                 model.float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+
+    def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor:
+        # TODO: This logic should be moved into the model layer. The model should be responsible
+        # for aggregating all losses (CE loss, balancing loss, z loss, etc.) and returning a
+        # single total_loss. The engine should only call model.forward() and use the returned
+        # total_loss directly, rather than iterating through fields to sum losses here.
+        # This would provide better separation of concerns and make the loss computation logic
+        # more explicit and maintainable.
+        loss = torch.tensor(0.0, device=DEVICE)
+        for key in model_outputs.model_fields:
+            value = getattr(model_outputs, key)
+            if "loss" in key and isinstance(value, torch.Tensor):
+                loss += value
+        return loss
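The maxvio/bias-update branch deleted from the engine above is, per the commit message, now owned by the MoE model's `post_micro_batch_forward()` override. Below is a rough sketch of what that override could look like, assuming it receives the list of per-micro-batch outputs and that `tokens_per_expert_global` is still reported in them; the subclass name and exact signatures are assumptions, not the actual xtuner code.

```python
import torch

from xtuner.v1.model.base import BaseModel


class MoEModelSketch(BaseModel):  # hypothetical subclass for illustration
    def post_micro_batch_forward(self, micro_batch_results):
        info = super().post_micro_batch_forward(micro_batch_results)

        # Accumulate global expert load over all micro-batches of this step.
        tokens_per_expert = sum(out["tokens_per_expert_global"] for out in micro_batch_results)

        # maxvio: per-layer (max expert load - mean load) / mean load, averaged over layers,
        # matching the formula removed from the old engine code above.
        avg_load = tokens_per_expert.float().mean(dim=1)
        max_load, _ = torch.max(tokens_per_expert, dim=1)
        info["maxvio"] = ((max_load - avg_load) / avg_load).mean().item()

        # Update the no-aux router bias only when configured to do so.
        if self.need_update_bias:
            self.update_bias(tokens_per_expert, avg_load)
        return info
```

Here `need_update_bias` would encapsulate the old `NoAuxRouterConfig` / `router_bias_update_speed > 0` check that the engine previously performed.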
