Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions swift/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
from abc import ABC, abstractmethod
from typing import Literal

from swift.utils import get_current_device, get_logger

Expand Down Expand Up @@ -72,7 +73,7 @@ def compute(self):

class MeanMetric(Metric):

def __init__(self, nan_value=0, device=None, group=None):
def __init__(self, nan_value=0, device=None, group=None, reduction: Literal['sum', 'mean'] = 'mean'):
super().__init__()
self.nan_value = nan_value
self.add_state('state', default=0.)
Expand All @@ -81,6 +82,7 @@ def __init__(self, nan_value=0, device=None, group=None):
device = get_current_device()
self.device = device
self.group = group
self.reduction = reduction

def update(self, state: torch.Tensor):
if isinstance(state, (torch.Tensor, np.ndarray)):
Expand All @@ -104,10 +106,11 @@ def compute(self):
tensor = torch.tensor([self.state, self.count], device=self.device)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.group)
self.state, self.count = tensor[0].item(), int(tensor[1].item())
if self.count == 0:
value = self.nan_value
if reduction == 'sum':

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The variable reduction is not defined in the compute method. It should be accessed via self.reduction as initialized in __init__.

Suggested change
if reduction == 'sum':
if self.reduction == 'sum':

value = self.state
else:
value = self.state / self.count
return {
'value': value,
}
if self.count == 0:
value = self.nan_value
else:
value = self.state / self.count
return {'value': value}
2 changes: 2 additions & 0 deletions swift/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,7 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
if self.packing and isinstance(batch[0], list):
batch = sum(batch, start=[])
num_samples = len(batch)
num_tokens = sum(sum([b['lengths'] for b in batch], start=[]))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using sum(..., start=[]) to flatten lists is highly inefficient ($O(N^2)$ complexity, because every step copies the accumulated list) and can become a performance bottleneck with larger batch sizes. Additionally, if any batch element is missing the 'lengths' key, it will raise a KeyError. A more efficient and robust approach is to sum each element's lengths inside a generator expression, falling back to an empty list when the key is absent.

Suggested change
num_tokens = sum(sum([b['lengths'] for b in batch], start=[]))
num_tokens = sum(sum(b.get('lengths', [])) for b in batch)

if self.task_type == 'causal_lm':
if self.mode in {'transformers', 'train'}:
res = self._data_collator(batch, padding_to=padding_to)
Expand Down Expand Up @@ -1657,6 +1658,7 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
num_samples = res.pop('num_samples')
if self.use_megatron:
res['num_samples'] = num_samples
res['num_tokens'] = num_tokens
return res

@staticmethod
Expand Down
13 changes: 13 additions & 0 deletions swift/trainers/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def add_train_message(logs, state, start_time, start_step) -> None:
if state.max_memory:
logs['memory(GiB)'] = round(state.max_memory, 2)
logs['train_speed(s/it)'] = round(train_speed, 6)
num_tokens = getattr(state, 'num_tokens', None)
if num_tokens is not None:
num_tokens = float(num_tokens)
if dist.is_initialized():
num_tokens = torch.tensor(num_tokens)
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)
tps = num_tokens / elapsed
logs['num_input_tokens_seen'] = round(num_tokens, 6)
logs['train_speed(tokens/s)'] = round(tps, 6)
Comment on lines +31 to +39

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are several issues here:
1. torch and dist are not imported in this file, which will cause a NameError at runtime.
2. torch.tensor(num_tokens) creates a CPU tensor by default. In distributed training with the NCCL backend, all_reduce on CPU tensors is not supported and will fail.
3. num_tokens becomes a torch.Tensor after all_reduce, and calling round() on it or dividing it might cause issues or keep it as a tensor. It should be converted back to a float using .item().

    num_tokens = getattr(state, 'num_tokens', None)
    if num_tokens is not None:
        import torch
        import torch.distributed as dist
        from swift.utils import get_current_device
        num_tokens = float(num_tokens)
        if dist.is_initialized():
            device = get_current_device()
            num_tokens_tensor = torch.tensor(num_tokens, device=device)
            dist.all_reduce(num_tokens_tensor, op=dist.ReduceOp.SUM)
            num_tokens = num_tokens_tensor.item()
        tps = num_tokens / elapsed
        logs['num_input_tokens_seen'] = round(num_tokens, 6)
        logs['train_speed(tokens/s)'] = round(tps, 6)



class ProgressCallbackNew(ProgressCallback):
Expand All @@ -50,6 +59,10 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader

def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
add_train_message(logs, state, self.start_time, self.start_step)
n_steps = state.global_step - self.current_step
num_tokens = logs.pop('num_tokens', None)
if num_tokens is not None and n_steps > 0:
logs['']

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This line is incomplete and will raise a KeyError at runtime because '' is not a valid key in logs. Please complete the logging logic or remove this line.

if not is_pai_training_job() and state.is_world_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, logs)
Expand Down
3 changes: 3 additions & 0 deletions swift/trainers/seq2seq_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not '
'take effect, potentially leading to increased GPU memory consumption.')
labels = inputs.pop('labels')
num_tokens = inputs.pop('num_tokens', None)
outputs = model(**inputs)
mode = 'train' if self.model.training else 'eval'
if num_tokens is not None:
self.state.num_tokens += num_tokens
Comment on lines +142 to +143

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

self.state (which is a TrainerState from transformers) does not have a num_tokens attribute by default. This will raise an AttributeError on the first training step. It should be initialized safely.

        if num_tokens is not None:
            self.state.num_tokens = getattr(self.state, 'num_tokens', 0) + num_tokens

if getattr(outputs, 'aux_loss', None) is not None:
self.custom_metrics[mode]['aux_loss'].update(outputs.aux_loss)
# Save past state if it exists
Expand Down
Loading