diff --git a/swift/metrics/utils.py b/swift/metrics/utils.py index e2e0f53c32..594c643d7f 100644 --- a/swift/metrics/utils.py +++ b/swift/metrics/utils.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist from abc import ABC, abstractmethod +from typing import Literal from swift.utils import get_current_device, get_logger @@ -72,7 +73,7 @@ def compute(self): class MeanMetric(Metric): - def __init__(self, nan_value=0, device=None, group=None): + def __init__(self, nan_value=0, device=None, group=None, reduction: Literal['sum', 'mean'] = 'mean'): super().__init__() self.nan_value = nan_value self.add_state('state', default=0.) @@ -81,6 +82,7 @@ def __init__(self, nan_value=0, device=None, group=None): device = get_current_device() self.device = device self.group = group + self.reduction = reduction def update(self, state: torch.Tensor): if isinstance(state, (torch.Tensor, np.ndarray)): @@ -104,10 +106,11 @@ def compute(self): tensor = torch.tensor([self.state, self.count], device=self.device) dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.group) self.state, self.count = tensor[0].item(), int(tensor[1].item()) - if self.count == 0: - value = self.nan_value + if reduction == 'sum': + value = self.state else: - value = self.state / self.count - return { - 'value': value, - } + if self.count == 0: + value = self.nan_value + else: + value = self.state / self.count + return {'value': value} diff --git a/swift/template/base.py b/swift/template/base.py index 267f2afb04..5b1e6c4212 100644 --- a/swift/template/base.py +++ b/swift/template/base.py @@ -1629,6 +1629,7 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int if self.packing and isinstance(batch[0], list): batch = sum(batch, start=[]) num_samples = len(batch) + num_tokens = sum(sum([b['lengths'] for b in batch], start=[])) if self.task_type == 'causal_lm': if self.mode in {'transformers', 'train'}: res = self._data_collator(batch, padding_to=padding_to) @@ -1657,6 +1658,7 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int num_samples = res.pop('num_samples') if self.use_megatron: res['num_samples'] = num_samples + res['num_tokens'] = num_tokens return res @staticmethod diff --git a/swift/trainers/patcher.py b/swift/trainers/patcher.py index a18071647e..ec233ab480 100644 --- a/swift/trainers/patcher.py +++ b/swift/trainers/patcher.py @@ -28,6 +28,15 @@ def add_train_message(logs, state, start_time, start_step) -> None: if state.max_memory: logs['memory(GiB)'] = round(state.max_memory, 2) logs['train_speed(s/it)'] = round(train_speed, 6) + num_tokens = getattr(state, 'num_tokens', None) + if num_tokens is not None: + num_tokens = float(num_tokens) + if dist.is_initialized(): + num_tokens = torch.tensor(num_tokens) + dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) + tps = num_tokens / elapsed + logs['num_input_tokens_seen'] = round(num_tokens, 6) + logs['train_speed(tokens/s)'] = round(tps, 6) class ProgressCallbackNew(ProgressCallback): @@ -50,6 +59,10 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): add_train_message(logs, state, self.start_time, self.start_step) + n_steps = state.global_step - self.current_step + num_tokens = logs.pop('num_tokens', None) + if num_tokens is not None and n_steps > 0: + logs[''] if not is_pai_training_job() and state.is_world_process_zero: jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') append_to_jsonl(jsonl_path, logs) diff --git a/swift/trainers/seq2seq_trainer.py b/swift/trainers/seq2seq_trainer.py index fe0089ea30..ea0710ce7e 100644 --- a/swift/trainers/seq2seq_trainer.py +++ b/swift/trainers/seq2seq_trainer.py @@ -136,8 +136,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not ' 'take effect, potentially leading to increased GPU memory consumption.') labels = inputs.pop('labels') + num_tokens = inputs.pop('num_tokens', None) outputs = model(**inputs) mode = 'train' if self.model.training else 'eval' + if num_tokens is not None: + self.state.num_tokens += num_tokens if getattr(outputs, 'aux_loss', None) is not None: self.custom_metrics[mode]['aux_loss'].update(outputs.aux_loss) # Save past state if it exists