Add FP8 support for SALMAutomodel#15754
Conversation
|
/ok to test 614e4b4 |
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
6c20005 to
23761ed
Compare
|
/ok to test 23761ed |
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
|
/ok to test 329f23a |
KunalDhawan
left a comment
There was a problem hiding this comment.
Great work @pzelasko, added minor comments below, other than that LGTM
| """Return the minimal sequence-length multiple so B*T is divisible by 8.""" | ||
| if batch_size <= 0: | ||
| raise ValueError(f"batch_size must be positive; got {batch_size}.") | ||
| return 8 // gcd(batch_size, 8) |
There was a problem hiding this comment.
This only ensures B*T % 8 == 0 and ignores tp_size, which is inconsistent with the THD helper (8 * cp_size * tp_size). Under BSHD + TP + TE-FP8 I think that breaks in two ways:
prepare_inputstruncates the seq dim to a multiple oftp_sizeso sequence parallelism doesn't silently reshape the input (salm_automodel.py ~L269), but thenforwardappendspad = (-T) % seq_multipletokens, so the padded length is no longer guaranteed divisible bytp_size→ SP shape break.- With SP the local TE Linear sees
M = B*T/tp_size, so FP8 actually needsB*T % (8*tp_size) == 0, not just% 8.
Could we either thread tp_size through here (note 8*tp_size alone isn't enough — e.g. B=16, tp=4 → multiple of 2, still not divisible by 4 — so probably needs an explicit lcm(tp_size, ...)), or add a validate_fp8_config rejection for BSHD + TP + TE-FP8 pointing folks at the THD packed path? A BSHD analogue of test_maybe_pad_thd_..._accounts_for_cp_and_tp would lock it down. This combo wasn't in the 2-GPU run (dp=2 ep=2, no TP), so it's currently untested.
| def backward(self, *args, **kwargs): | ||
| self._setup_moe_fsdp_sync() | ||
| with loss_parallel(): | ||
| with loss_parallel(), te_fp8_context(self.cfg.get("automodel_backend", None)): |
There was a problem hiding this comment.
Quick question on wrapping backward in te_fp8_context too: standard TE usage only wraps the forward, and the backward consumes the FP8 metadata captured during the forward's fp8_autocast. Re-entering fp8_autocast here can, for history/delayed-scaling recipes, trigger an extra amax/scale update (and a second amax all-reduce) at context exit. Probably harmless for block/current (which is what the run used), but it's an easy source of subtle scale drift on other recipes. Is it deliberate / needed for something specific? If not, dropping it from backward seems safer.
|
/ok to test 17e203f |
Summary
Testing