Skip to content

Add FP8 support for SALMAutomodel#15754

Open
pzelasko wants to merge 4 commits into
mainfrom
codex/salm-automodel-fp8
Open

Add FP8 support for SALMAutomodel#15754
pzelasko wants to merge 4 commits into
mainfrom
codex/salm-automodel-fp8

Conversation

@pzelasko
Copy link
Copy Markdown
Collaborator

@pzelasko pzelasko commented Jun 4, 2026

Summary

  • add SpeechLM2 FP8 helpers for TorchAO config creation, TE FP8 autocast/patching, FSDP scale precompute, and TE FP8 padding/alignment
  • wire SALMAutomodel through those helpers without cached FP8 state and keep the forward/backward flow minimal
  • support TE FP8 padding for BSHD and THD packed inputs, including context-parallel alignment metadata
  • document TE FP8 and TorchAO FP8 config examples and add focused unit/runtime coverage

Testing

  • git diff --cached --check
  • pytest tests/collections/speechlm2/test_fp8.py tests/collections/speechlm2/test_salm_packed_sequences.py -q
  • in nemo-speech:cu13-h100plus: pytest tests/collections/speechlm2/test_salm_automodel.py -q
  • in nemo-speech:cu13-h100plus: 2-GPU LibriSpeech training with Nemotron 3 Nano + Canary-1B-Flash, FSDP2 dp=2 ep=2, TE FP8 recipe=block, completed 2000/2000 steps in 1:18:02 with final logged loss 0.0586

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Jun 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@pzelasko
Copy link
Copy Markdown
Collaborator Author

pzelasko commented Jun 4, 2026

/ok to test 614e4b4

pzelasko added 2 commits June 4, 2026 11:14
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
@pzelasko pzelasko force-pushed the codex/salm-automodel-fp8 branch from 6c20005 to 23761ed Compare June 4, 2026 18:14
@pzelasko
Copy link
Copy Markdown
Collaborator Author

pzelasko commented Jun 4, 2026

/ok to test 23761ed

Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
@pzelasko pzelasko requested a review from a team as a code owner June 4, 2026 21:42
@pzelasko
Copy link
Copy Markdown
Collaborator Author

pzelasko commented Jun 4, 2026

/ok to test 329f23a

Copy link
Copy Markdown
Collaborator

@KunalDhawan KunalDhawan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @pzelasko, added minor comments below, other than that LGTM

Comment thread nemo/collections/speechlm2/parts/fp8.py Outdated
"""Return the minimal sequence-length multiple so B*T is divisible by 8."""
if batch_size <= 0:
raise ValueError(f"batch_size must be positive; got {batch_size}.")
return 8 // gcd(batch_size, 8)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only ensures B*T % 8 == 0 and ignores tp_size, which is inconsistent with the THD helper (8 * cp_size * tp_size). Under BSHD + TP + TE-FP8 I think that breaks in two ways:

  • prepare_inputs truncates the seq dim to a multiple of tp_size so sequence parallelism doesn't silently reshape the input (salm_automodel.py ~L269), but then forward appends pad = (-T) % seq_multiple tokens, so the padded length is no longer guaranteed divisible by tp_size → SP shape break.
  • With SP the local TE Linear sees M = B*T/tp_size, so FP8 actually needs B*T % (8*tp_size) == 0, not just % 8.

Could we either thread tp_size through here (note 8*tp_size alone isn't enough — e.g. B=16, tp=4 → multiple of 2, still not divisible by 4 — so probably needs an explicit lcm(tp_size, ...)), or add a validate_fp8_config rejection for BSHD + TP + TE-FP8 pointing folks at the THD packed path? A BSHD analogue of test_maybe_pad_thd_..._accounts_for_cp_and_tp would lock it down. This combo wasn't in the 2-GPU run (dp=2 ep=2, no TP), so it's currently untested.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch

def backward(self, *args, **kwargs):
self._setup_moe_fsdp_sync()
with loss_parallel():
with loss_parallel(), te_fp8_context(self.cfg.get("automodel_backend", None)):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question on wrapping backward in te_fp8_context too: standard TE usage only wraps the forward, and the backward consumes the FP8 metadata captured during the forward's fp8_autocast. Re-entering fp8_autocast here can, for history/delayed-scaling recipes, trigger an extra amax/scale update (and a second amax all-reduce) at context exit. Probably harmless for block/current (which is what the run used), but it's an easy source of subtle scale drift on other recipes. Is it deliberate / needed for something specific? If not, dropping it from backward seems safer.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch too

@pzelasko
Copy link
Copy Markdown
Collaborator Author

pzelasko commented Jun 5, 2026

/ok to test 17e203f

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants