
Commit 4808395

fix: enforce float32 in ssm scan for numerical stability
1 parent 4c5d2ca commit 4808395

2 files changed: 48 additions & 20 deletions
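
At a glance, the fix keeps the sequential SSM recurrence in float32 regardless of the dtype the model otherwise runs in: inputs are upcast on entry, the hidden state is accumulated in float32, and the output is cast back at the end. Below is a minimal sketch of that pattern. The diffs further down only show fragments of the repo's loop body, so the update rule used here (the standard selective-scan discretization, h <- exp(dt*A)*h + dt*B*u, y = C.h + D*u) is an assumption, not the repo's exact code:

import torch

# Hedged sketch of the pattern this commit applies (assumed recurrence).
# Shapes inferred from the diff: u, delta: (B_size, L, D_inner);
# A: (D_inner, D_state); B, C: (B_size, L, D_state); D: (D_inner,)
def scan_in_float32(u, delta, A, B, C, D):
    original_dtype = u.dtype
    # Upcast everything the recurrence touches
    u, delta, A, B, C, D = (t.float() for t in (u, delta, A, B, C, D))

    B_size, L, D_inner = u.shape
    D_state = A.shape[-1]
    # State is created and accumulated in float32
    h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=torch.float32)

    ys = []
    for l in range(L):
        dt = delta[:, l, :].unsqueeze(-1)                          # (B_size, D_inner, 1)
        h = torch.exp(dt * A) * h + dt * B[:, l, None, :] * u[:, l, :, None]
        ys.append((h * C[:, l, None, :]).sum(-1))                  # (B_size, D_inner)
    y = torch.stack(ys, dim=1) + u * D                             # skip connection
    return y.to(dtype=original_dtype)                              # restore caller's dtype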


aetheris/cli/main.py

Lines changed: 30 additions & 18 deletions
@@ -111,27 +111,39 @@ def generate_command(args):
     history_ids = set(input_ids[0].tolist())
 
     print("-" * 50)
-    print(f"Prompt: {prompt}")
-    print("Generated Continuation:")
-
-    for _ in range(max_new_tokens):
-        # Check if we should use autocast (skip if model uses float32)
-        use_autocast = True
-        if config.torch_dtype == torch.float32:
-            use_autocast = False
-
-        if use_autocast:
-            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
+    print(f"Prompt: {prompt}")
+    print("Generated Continuation:")
+
+    for step in range(max_new_tokens):
+        # Check if we should use autocast (skip if model uses float32)
+        use_autocast = True
+        if config.torch_dtype == torch.float32:
+            use_autocast = False
+
+        if use_autocast:
+            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
+                outputs = model(generated_ids)
+                logits = outputs['logits']
+                next_token_logits = logits[:, -1, :]
+        else:
             outputs = model(generated_ids)
             logits = outputs['logits']
             next_token_logits = logits[:, -1, :]
-        else:
-            outputs = model(generated_ids)
-            logits = outputs['logits']
-            next_token_logits = logits[:, -1, :]
-
-        # Repetition penalty
-        for token_id in history_ids:
+
+        # --- DEBUG: Print Top Predictions for First Step ---
+        if step == 0:
+            probs = F.softmax(next_token_logits, dim=-1)
+            top_probs, top_indices = torch.topk(probs, 5)
+            print("\n[DEBUG] Step 0 Top-5 Predictions:")
+            for i in range(5):
+                token_idx = top_indices[0, i].item()
+                prob = top_probs[0, i].item()
+                token_str = tokenizer.decode([token_idx])
+                print(f"  {i+1}. '{token_str}' ({prob:.4f})")
+            print("-----------------------------------")
+        # ---------------------------------------------------
+
+        # Repetition penalty
+        for token_id in history_ids:
             if token_id < next_token_logits.size(-1):
                 logit = next_token_logits[0, token_id].item()
                 if logit > 0:
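
The hunk is truncated inside the repetition-penalty loop. For reference, here is a minimal sketch of the standard rule that loop shape suggests (the CTRL-style penalty also used by Hugging Face's generate): positive logits are divided by the penalty and negative ones multiplied, so every previously generated token becomes less likely. The function name and default value are illustrative, not this repo's code:

import torch

def apply_repetition_penalty(next_token_logits: torch.Tensor,
                             history_ids: set,
                             penalty: float = 1.2) -> torch.Tensor:
    # CTRL-style rule: shrink positive logits, amplify negative ones,
    # so seen tokens lose probability mass either way.
    for token_id in history_ids:
        if token_id < next_token_logits.size(-1):  # guard against out-of-vocab ids
            logit = next_token_logits[0, token_id].item()
            if logit > 0:
                next_token_logits[0, token_id] = logit / penalty
            else:
                next_token_logits[0, token_id] = logit * penalty
    return next_token_logits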

aetheris/modules/ssm.py

Lines changed: 18 additions & 2 deletions
@@ -8,10 +8,23 @@ def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
     """Memory-efficient scan with reduced intermediate tensors."""
     B_size, L, D_inner = u.shape
     D_state = A.shape[-1]
+
+    # Save original dtype
+    original_dtype = u.dtype
 
     # Use in-place operations where possible
-    h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=u.dtype)
+    # FORCE FLOAT32 for state to prevent underflow/overflow in long sequences
+    h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=torch.float32)
     ys = []
+
+    # Cast inputs to float32 for the scan
+    # Note: This increases memory usage slightly but is critical for stability
+    u = u.float()
+    delta = delta.float()
+    A = A.float()
+    B = B.float()
+    C = C.float()
+    D = D.float()
 
     for l in range(L):
         dt = delta[:, l, :].unsqueeze(-1)
@@ -28,7 +41,10 @@ def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
         ys.append(y_t)
 
     y = torch.stack(ys, dim=1)
-    return y + u * D
+    y = y + u * D
+
+    # Cast back to original dtype
+    return y.to(dtype=original_dtype)
 
 class SSMBlock(nn.Module):
     """Memory-optimized State Space Model with stability improvements."""
