Commit 0d80d09

fix(training): harden model against NaN/Inf instability in fp16
- ssm.py: Clamp the `delta` projection (max=5.0) to prevent SSM state explosion.
- trainer.py: Extend the post-clip gradient check to detect and skip `Inf` gradients (overflow), not just `NaN`.
- model.py: Add a runtime fail-safe that patches layer outputs with `nan_to_num` if corruption is detected during the forward pass.

This addresses the loss divergence observed at step 11k.
1 parent f01c9a2 commit 0d80d09

3 files changed: 8 additions & 7 deletions

aetheris/model.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None) -> Dict[
         x = self.embedding(input_ids)
         total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

-        for layer in self.layers:
+        for i, layer in enumerate(self.layers):
             if self.gradient_checkpointing and self.training:
                 # Checkpoint ALL layers for maximum memory savings
                 if isinstance(layer, SparseMoELayer):
@@ -66,7 +66,7 @@ def moe_forward(module, inp):

             # Add gradient clipping per layer to catch issues early
             if self.training and torch.isnan(x).any():
-                print(f"WARNING: NaN detected in layer output!")
+                print(f"WARNING: NaN detected in layer {i} ({layer.__class__.__name__}) output!")
                 x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)

         x = self.final_norm(x)
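
As a quick reference for what the fail-safe does (a minimal sketch, using the same replacement values as the patch above):

import torch

# NaN -> 0.0, +inf -> 1.0, -inf -> -1.0; finite values pass through untouched
x = torch.tensor([float("nan"), float("inf"), float("-inf"), 0.5])
print(torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0))
# tensor([ 0.0000,  1.0000, -1.0000,  0.5000])

Note that this masks corruption rather than curing it; the ssm.py and trainer.py changes below target the cause.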

aetheris/modules/ssm.py

Lines changed: 2 additions & 2 deletions
@@ -74,8 +74,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_conv = x_conv[:, :, :-2].transpose(1, 2)
         x_conv = self.act(x_conv)

-        # Add small epsilon to prevent numerical issues
-        delta = F.softplus(self.delta_proj(x_conv)) + 1e-4
+        # Add small epsilon to prevent numerical issues and clamp max value
+        delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
         B_ssm = self.B_proj(x_conv)
         C_ssm = self.C_proj(x_conv)
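
A back-of-envelope illustration of why the clamp matters in fp16 (a toy sketch, not the repo's actual scan; the magnitudes are hypothetical). fp16 overflows above roughly 65504, and in a selective-SSM-style update the input injection scales as delta times B(x) times u, i.e. delta multiplied by two activation-scale terms:

import torch
import torch.nn.functional as F

act = torch.tensor(100.0, dtype=torch.float16)  # hypothetical activation magnitude
raw = torch.tensor(400.0)                       # hypothetical delta_proj pre-activation

unclamped = F.softplus(raw).to(torch.float16)                               # ~400
clamped = (torch.clamp(F.softplus(raw), max=5.0) + 1e-4).to(torch.float16)  # ~5

print(unclamped * act * act)  # 400 * 100 * 100 = 4e6 -> inf in fp16
print(clamped * act * act)    # ~5 * 100 * 100 = 5e4, still representable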

aetheris/trainer/trainer.py

Lines changed: 4 additions & 3 deletions
@@ -72,10 +72,11 @@ def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Train
         # Gradient clipping
         grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

-        if torch.isnan(grad_norm):
-            print(f"WARNING: NaN gradient at step {global_step}, skipping update")
+        if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+            print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")
+        else:
+            self.scaler.step(self.optimizer)

-        self.scaler.step(self.optimizer)
         self.scaler.update()

         global_step += 1
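
One caveat worth flagging (a sketch of the standard torch.cuda.amp recipe with assumed names model, optimizer, scaler, loss; not the repo's trainer): under GradScaler, gradients are still multiplied by the loss scale when clip_grad_norm_ runs unless scaler.unscale_ is called first, so the canonical ordering unscales before clipping, and torch.isfinite covers both NaN and Inf in one check:

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients are now in true (unscaled) units
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
if torch.isfinite(grad_norm):
    scaler.step(optimizer)  # GradScaler also skips internally on inf/nan grads
scaler.update()             # lowers the loss scale after a skipped step
optimizer.zero_grad(set_to_none=True)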
