2 changes: 1 addition & 1 deletion tests/engine/test_dense_train_engine.py
@@ -90,7 +90,7 @@ def warmup_fn(x):
             seq_ctx = seq_ctx_list[0]
             loss_ctx = loss_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            loss_log = engine.train_step(engine_input)["logs_info"]
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
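The change repeated throughout these tests: engine.train_step no longer returns a (loss_log, other_log) tuple; it returns a single dict, and the per-step loss metrics now live under its "logs_info" key. A minimal sketch of the updated step loop, assuming an engine, inputs, and scheduler built as in the test fixtures above (the key names are taken from this diff):

    # Sketch of one optimizer step under the new dict-returning contract.
    # `engine`, `engine_input`, and `lr_scheduler` are assumed to be set up
    # as in the tests above.
    def run_one_step(engine, engine_input, lr_scheduler):
        logs_info = engine.train_step(engine_input)["logs_info"]  # was: loss_log, _ = engine.train_step(...)
        grad_norm = engine.clip_grad_norm()
        engine.step_optimizer(grad_norm)
        lr_scheduler.step()
        return logs_info["reduced_llm_loss"]  # reduced loss, as read by the float8 test below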
4 changes: 2 additions & 2 deletions tests/engine/test_moe_train_engine.py
@@ -99,7 +99,7 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            loss_log = engine.train_step(engine_input)["logs_info"]
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
@@ -190,7 +190,7 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            loss_log = engine.train_step(engine_input)["logs_info"]
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
8 changes: 4 additions & 4 deletions tests/engine/test_moe_train_engine_float8.py
@@ -93,7 +93,7 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            loss_log = engine.train_step(engine_input)["logs_info"]
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
@@ -171,7 +171,7 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            loss_log = engine.train_step(engine_input)["logs_info"]
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
@@ -270,11 +270,11 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            logs_info = engine.train_step(engine_input)["logs_info"]
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
-            losses.append(loss_log["reduced_llm_loss"])
+            losses.append(logs_info["reduced_llm_loss"])
         losses_ref = torch.tensor([2.41, 2.41, 2.47, 2.42, 2.44, 2.44, 2.42, 2.38, 2.31, 2.30])
         losses = torch.tensor(losses)
         self._check_loss_curve(losses, losses_ref)
7 changes: 3 additions & 4 deletions tests/model/test_qwen3_tile_embedding.py
@@ -19,7 +19,6 @@
 from xtuner.v1.loss.ce_loss import CELossConfig
 from xtuner.v1.config import FSDPConfig, LRConfig, AdamWConfig
 from xtuner.v1.engine.train_engine import TrainEngine
-from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from torch.optim.lr_scheduler import LambdaLR
 from xtuner.v1.utils import pad_to_max_length
 from xtuner.v1.utils.device import get_device
@@ -85,7 +84,7 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            engine.train_step(engine_input)
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
@@ -116,7 +115,7 @@ def test_qwen3vl_tie_embedding(self, device, tp_size):
            cpu_offload=False,
            tp_size=tp_size
        )
-        engine = VisionComposeTrainEngine(
+        engine = TrainEngine(
            model_cfg=dense_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg
        )
        engine.from_hf(hf_path=QWEN3_VL_DENSE_PATH)
@@ -160,7 +159,7 @@ def warmup_fn(x):
             loss_ctx = loss_ctx_list[0]
             seq_ctx = seq_ctx_list[0]
             engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
-            loss_log, _ = engine.train_step(engine_input)
+            engine.train_step(engine_input)
             grad_norm = engine.clip_grad_norm()
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
17 changes: 5 additions & 12 deletions tests/train/test_trainer.py
@@ -26,14 +26,14 @@
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
-from xtuner.v1.engine.train_engine import LossLog, OtherLog
 from xtuner.v1.loss import CELossConfig
 from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
 from xtuner.v1.train.trainer import XTunerMeta, ExpInfo, ExpHistory, GitInfo
 from xtuner.v1.utils.device import get_device
 from xtuner.v1.datasets.dataloader import Dataloader
 from torch.optim.lr_scheduler import SequentialLR
+from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
 
 
 DEVICE = get_device()
@@ -55,10 +55,8 @@ def grad_accumulation_steps(self, *args, **kwargs):
 
     def train_step(self, *args, **kwargs):
         self.train_step_calls += 1
-        return (
-            {"local_loss": 1.0, "reduced_llm_loss": 0.8},
-            {"step_consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5}
-        )
+        return {"total_loss": 1.8, "step_consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5, "logs_info": {"local_loss": 1.0, "reduced_llm_loss": 0.8}, "extra_info": ModelForwardExtraLogInfo()}
 
     def save_hf(self, hf_path):
         self.save_hf_calls.append(hf_path)
@@ -647,14 +645,12 @@ def test_hooks_config(self):
         self.create_pg(DEVICE)
         checkpoint_function_call_times = 0
         train_step_function_call_times = 0
-        losslog_adapater = TypeAdapter(LossLog)
-        otherlog_adapter = TypeAdapter(OtherLog)
 
         def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
             nonlocal checkpoint_function_call_times
             checkpoint_function_call_times += 1
 
-        def train_step_hook(loss_log, other_log, step, epoch, total_step, total_epoch):
+        def train_step_hook(train_step_info, step, epoch, total_step, total_epoch):
             nonlocal train_step_function_call_times
             train_step_function_call_times += 1
@@ -673,10 +669,7 @@ def connect_trainer(self, trainer: Trainer):
     def __init__(self) -> None:
         self.count = 0
 
-    def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
-        losslog_adapater.validate_python(loss_log)
-        otherlog_adapter.validate_python(other_log)
-
+    def __call__(self, train_step_info, step, epoch, total_step, total_epoch):
         assert self.trainer().cur_step == step
         assert self.trainer().cur_epoch == epoch
         assert self.trainer().total_step == total_step
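The mock above also pins down the shape that train-step hooks now receive: a single train_step_info argument replaces the old (loss_log, other_log) pair. A sketch of a custom hook under that assumed contract (the key names mirror the mock's return value; treat them as assumptions outside these tests):

    def my_train_step_hook(train_step_info, step, epoch, total_step, total_epoch):
        # train_step_info is assumed to be the dict returned by train_step,
        # as the mock engine above suggests.
        logs_info = train_step_info["logs_info"]            # {"local_loss": ..., "reduced_llm_loss": ...}
        tokens = train_step_info["step_consumed_tokens"]    # e.g. 100 in the mock
        grad_norm = train_step_info["grad_norm"]            # a scalar torch.Tensor
        print(f"step {step}/{total_step}: loss={logs_info['reduced_llm_loss']:.3f}, "
              f"tokens={tokens}, grad_norm={grad_norm.item():.3f}")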
6 changes: 1 addition & 5 deletions xtuner/v1/engine/__init__.py
@@ -1,13 +1,9 @@
 from xtuner.v1.engine.config import EngineConfig
 
-from .train_engine import LossLog, OtherLog, TrainEngine
-from .vision_compose_train_engine import VisionComposeTrainEngine
+from .train_engine import TrainEngine
 
 
 __all__ = [
     "TrainEngine",
     "EngineConfig",
-    "VisionComposeTrainEngine",
-    "LossLog",
-    "OtherLog",
 ]
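Downstream imports of the dropped names need a matching update; a before/after sketch:

    # Before this change (now raises ImportError):
    # from xtuner.v1.engine import LossLog, OtherLog, VisionComposeTrainEngine

    # After: only these names remain exported from the package.
    from xtuner.v1.engine import EngineConfig, TrainEngine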
7 changes: 1 addition & 6 deletions xtuner/v1/engine/config.py
@@ -4,9 +4,7 @@
 
 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.engine.train_engine import TrainEngine
-from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from xtuner.v1.model.base import BaseModel, ConfigDict
-from xtuner.v1.model.compose.base import BaseComposeConfig
 
 
 @runtime_checkable
@@ -27,7 +25,4 @@ class EngineConfig(PydanticBaseModel):
     model_cfg: ModelConfigProto
 
     def build(self):
-        if isinstance(self.model_cfg, BaseComposeConfig):
-            return VisionComposeTrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
-        else:
-            return TrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
+        return TrainEngine(model_cfg=self.model_cfg, optim_cfg=self.optim_cfg, fsdp_cfg=self.fsdp_cfg)
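With the BaseComposeConfig branch removed, build() is unconditional: every model config, compose (vision) or not, now yields a plain TrainEngine. A usage sketch, with the concrete config objects left as placeholders since they depend on the model being trained:

    # `model_cfg`, `optim_cfg`, and `fsdp_cfg` are placeholder config objects.
    engine_cfg = EngineConfig(model_cfg=model_cfg, optim_cfg=optim_cfg, fsdp_cfg=fsdp_cfg)
    engine = engine_cfg.build()  # always a TrainEngine after this change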