From f0672f4cff44ed0ee2f3e3e213caf8c0e27c06ec Mon Sep 17 00:00:00 2001 From: David Burton Date: Tue, 5 May 2026 18:31:15 -0400 Subject: [PATCH] feat(finetune): add bf16 mixed-precision via amp_dtype config Wraps the forward + loss in torch.autocast for both the tokenizer and predictor training loops, gated by a new Config.amp_dtype field. Setting "bfloat16" enables bf16 autocast; None (default) keeps the existing FP32 path bit-exact. bf16 has the same exponent range as FP32, so AdamW master weights need no scaling and no GradScaler is wired in. --- finetune/config.py | 6 ++++++ finetune/train_predictor.py | 37 +++++++++++++++++++------------- finetune/train_tokenizer.py | 29 +++++++++++++++---------- finetune/utils/training_utils.py | 26 ++++++++++++++++++++++ 4 files changed, 72 insertions(+), 26 deletions(-) diff --git a/finetune/config.py b/finetune/config.py index 04cc3ee3..7b08dd97 100644 --- a/finetune/config.py +++ b/finetune/config.py @@ -61,6 +61,12 @@ def __init__(self): # Gradient accumulation to simulate a larger batch size. self.accumulation_steps = 1 + # Mixed-precision training. Set to "bfloat16" on Ampere-class + # or newer GPUs (RTX 30/40-series, A100, H100) to reduce step + # time and activation memory. ``None`` keeps full FP32 training + # and is the default for parity with earlier runs. + self.amp_dtype = None + # AdamW optimizer parameters. self.adam_beta1 = 0.9 self.adam_beta2 = 0.95 diff --git a/finetune/train_predictor.py b/finetune/train_predictor.py index 47eddc91..ca9acf29 100644 --- a/finetune/train_predictor.py +++ b/finetune/train_predictor.py @@ -22,7 +22,8 @@ cleanup_ddp, set_seed, get_model_size, - format_time + format_time, + resolve_amp_dtype, ) @@ -68,6 +69,10 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_ train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) + amp_dtype, amp_enabled = resolve_amp_dtype(config.get('amp_dtype')) + if rank == 0 and amp_enabled: + print(f"AMP enabled: autocast on cuda with dtype={amp_dtype}.") + optimizer = torch.optim.AdamW( model.parameters(), lr=config['predictor_learning_rate'], @@ -96,17 +101,18 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_ batch_x = batch_x.to(device, non_blocking=True) batch_x_stamp = batch_x_stamp.to(device, non_blocking=True) - # Tokenize input data on-the-fly - with torch.no_grad(): - token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled): + # Tokenize input data on-the-fly + with torch.no_grad(): + token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) - # Prepare inputs and targets for the language model - token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] - token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] + # Prepare inputs and targets for the language model + token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] + token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] - # Forward pass and loss calculation - logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) - loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) + # Forward pass and loss calculation + logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) + loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) # Backward pass and optimization optimizer.zero_grad() @@ -140,12 +146,13 @@ def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_ batch_x = batch_x.to(device, non_blocking=True) batch_x_stamp = batch_x_stamp.to(device, non_blocking=True) - token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) - token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] - token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled): + token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) + token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] + token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] - logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) - val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) + logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) + val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) tot_val_loss_sum_rank += val_loss.item() val_batches_processed_rank += 1 diff --git a/finetune/train_tokenizer.py b/finetune/train_tokenizer.py index 60186e1e..36913e2f 100644 --- a/finetune/train_tokenizer.py +++ b/finetune/train_tokenizer.py @@ -26,6 +26,7 @@ set_seed, get_model_size, format_time, + resolve_amp_dtype, ) @@ -95,6 +96,10 @@ def train_model(model, device, config, save_dir, logger, rank, world_size): train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) + amp_dtype, amp_enabled = resolve_amp_dtype(config.get('amp_dtype')) + if rank == 0 and amp_enabled: + print(f"AMP enabled: autocast on cuda with dtype={amp_dtype}.") + optimizer = torch.optim.AdamW( model.parameters(), lr=config['tokenizer_learning_rate'], @@ -133,15 +138,16 @@ def train_model(model, device, config, save_dir, logger, rank, world_size): end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps']) batch_x = ori_batch_x[start_idx:end_idx] - # Forward pass - zs, bsq_loss, _, _ = model(batch_x) - z_pre, z = zs + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled): + # Forward pass + zs, bsq_loss, _, _ = model(batch_x) + z_pre, z = zs - # Loss calculation - recon_loss_pre = F.mse_loss(z_pre, batch_x) - recon_loss_all = F.mse_loss(z, batch_x) - recon_loss = recon_loss_pre + recon_loss_all - loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1 + # Loss calculation + recon_loss_pre = F.mse_loss(z_pre, batch_x) + recon_loss_all = F.mse_loss(z, batch_x) + recon_loss = recon_loss_pre + recon_loss_all + loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1 loss_scaled = loss / config['accumulation_steps'] current_batch_total_loss += loss.item() @@ -177,9 +183,10 @@ def train_model(model, device, config, save_dir, logger, rank, world_size): with torch.no_grad(): for ori_batch_x, _ in val_loader: ori_batch_x = ori_batch_x.to(device, non_blocking=True) - zs, _, _, _ = model(ori_batch_x) - _, z = zs - val_loss_item = F.mse_loss(z, ori_batch_x) + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled): + zs, _, _, _ = model(ori_batch_x) + _, z = zs + val_loss_item = F.mse_loss(z, ori_batch_x) tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) val_sample_count_rank += ori_batch_x.size(0) diff --git a/finetune/utils/training_utils.py b/finetune/utils/training_utils.py index 8756322a..9bad8d0a 100644 --- a/finetune/utils/training_utils.py +++ b/finetune/utils/training_utils.py @@ -116,3 +116,29 @@ def format_time(seconds: float) -> str: +def resolve_amp_dtype(amp_dtype): + """ + Resolves the configured AMP dtype string into the arguments expected by + `torch.autocast`. + + Currently only "bfloat16" is supported. Passing ``None`` disables + mixed-precision training; the returned dtype is then irrelevant and + autocast becomes a no-op. + + Args: + amp_dtype (str | None): "bfloat16" or None. + + Returns: + tuple[torch.dtype, bool]: (dtype, enabled) suitable for + ``torch.autocast(device_type=..., dtype=dtype, enabled=enabled)``. + + Raises: + ValueError: If ``amp_dtype`` is set to an unsupported value. + """ + if amp_dtype is None: + return torch.float32, False + if amp_dtype == "bfloat16": + return torch.bfloat16, True + raise ValueError( + f"Unsupported amp_dtype {amp_dtype!r}; expected 'bfloat16' or None." + )