Skip to content

Commit d771e11

Browse files
fix: Reduce LR to 5e-5, fix NaN handling, and track skipped batches in dataset state
1 parent becdbe4 commit d771e11

3 files changed

Lines changed: 23 additions & 7 deletions

File tree

.github/workflows/train.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ jobs:
142142
TOTAL_STEPS: '100000'
143143
GRAD_ACCUM: '4'
144144
BATCH_SIZE: '2'
145+
LEARNING_RATE: '5e-5'
145146
BLOCK_SIZE: '512'
146147
USE_EWC: '1'
147148
GRADIENT_CHECKPOINTING: '1'

meridian/training/trainer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class TrainingConfig:
4242
total_steps: int = 100_000
4343

4444
# Optimizer
45-
learning_rate: float = 3e-4
45+
learning_rate: float = 5e-5
4646
weight_decay: float = 0.1
4747
max_grad_norm: float = 1.0
4848
warmup_ratio: float = 0.06
@@ -92,6 +92,7 @@ def __init__(
9292
# State
9393
self.global_step = 0
9494
self.run_step = 0
95+
self.processed_batches = 0
9596
self.best_loss = float("inf")
9697

9798
# EWC for continual learning
@@ -193,6 +194,7 @@ def train(self) -> None:
193194
# Get batch
194195
try:
195196
batch = next(data_iter)
197+
self.processed_batches += 1
196198
except StopIteration:
197199
print("[INFO] Dataset exhausted. Ending training.")
198200
break
@@ -327,6 +329,14 @@ def train(self) -> None:
327329

328330
def save_checkpoint(self, path: str) -> None:
329331
"""Save model + optimizer + trainer state."""
332+
# Sanity check for NaN weights
333+
for name, param in self.model.named_parameters():
334+
if torch.isnan(param).any():
335+
print(
336+
f" [CRITICAL] NaN detected in parameter '{name}'. Aborting checkpoint save to protect repo."
337+
)
338+
return
339+
330340
os.makedirs(path, exist_ok=True)
331341
print(f" [SAVE] Checkpoint → {path}")
332342

train.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,17 @@ def main():
218218
tokenizer.save_pretrained(checkpoint_path)
219219

220220
# Update dataset state
221-
batches_this_run = (
222-
(trainer.global_step - initial_global_step)
223-
* train_config.batch_size
224-
* train_config.gradient_accumulation_steps
225-
)
226-
new_processed = processed_items + batches_this_run
221+
# Use actual processed batches from trainer (includes skipped ones)
222+
if hasattr(trainer, "processed_batches"):
223+
batches_processed = trainer.processed_batches
224+
else:
225+
# Fallback for backward compatibility
226+
batches_processed = (
227+
trainer.global_step - initial_global_step
228+
) * train_config.gradient_accumulation_steps
229+
230+
items_processed = batches_processed * train_config.batch_size
231+
new_processed = processed_items + items_processed
227232

228233
for sp in [state_path, os.path.join(checkpoint_path, "dataset_state.json")]:
229234
with open(sp, "w") as f:

0 commit comments

Comments (0)