Skip to content

Tool: evaluate layer-wise numerical-error propagation#525

Open
jlamypoirier wants to merge 43 commits into
mainfrom
jlp_evaluate_precision
Open

Tool: evaluate layer-wise numerical-error propagation#525
jlamypoirier wants to merge 43 commits into
mainfrom
jlp_evaluate_precision

Conversation

@jlamypoirier
Copy link
Copy Markdown
Collaborator

@jlamypoirier jlamypoirier commented May 26, 2026

Summary

  • New tools/evaluate_precision.py — inherits PretrainedGPTModelConfig (so model: and pretrained: are real typed Config fields) and adds variants:, output_dir:, and a few optional knobs. Runs a fp32 reference plus one trainer invocation per variant in-process; captures per-layer forward activations and input gradients via the standard tensor-logs pipeline; emits per-tensor RMS / max diffs as a console table + precision_report.json.
  • Variants aren't dtype-only: each is a flat dict of dotted-path overrides (same syntax as Fast-LLM CLI key=value args) so a variant can sweep any config knob — attention implementation, optimizer dtype, fused vs unfused, etc.
  • Per-variant trainer configs are built with TrainerConfig.get_subclass(...).from_dict(base, fp32_dtypes, variant_updates, tool_overrides). Tuple-keyed updates compose in precedence order: forced fp32 → variant overrides (which can re-override fp32) → tool-required debug-logging overrides (which always win).
  • Training, optimizer, and data sections of the trainer config are hardcoded inside the tool (single iteration, no checkpoint save, random tokens, LR 0, fp32 optimization dtype). Only knobs that affect the propagation measurement are user-facing: model, pretrained, variants, output_dir, num_samples, micro_batch_size, sequence_length.
  • Moves compare_tensor_logs.py from tests/utils/ into fast_llm/engine/config_utils/ so it's importable from tools/, and factors a _compute_diff helper out of CompareConfig.compare_tensors so the tool can extract numbers for every tensor — not only those that breach a tolerance. Three test callers updated; behaviour unchanged.
  • Fills in the HF metadata allowlist (fast_llm/engine/checkpoint/huggingface.py) with the generic PretrainedConfig keys newer transformers versions serialize: generation defaults, encoder-decoder flags, family markers, torchscript, is_decoder, etc. Without this, loading any modern HF Llama checkpoint trips the coverage walker. None are architecture knobs Fast-LLM consumes.

Usage

python -m tools.evaluate_precision -c tool.yaml
pretrained:
  path: /path/to/local/hf/snapshot
  format: llama
output_dir: /tmp/precision_eval
variants:
  bf16:
    model.distributed.compute_dtype: bfloat16
  bf16_sdpa:
    model.distributed.compute_dtype: bfloat16
    model.base_model.decoder.block.mixer.implementation: sdpa

Fast-LLM's HF loader reads weights from a local directory, so HF Hub IDs need to be snapshot_download'd first. model: and pretrained: can also be combined — pretrained provides architecture+weights, model: overrides individual fields.

Test plan

  • Cluster smoke test on a real HF checkpoint (SmolLM2-135M, snapshot via huggingface_hub). Reference fp32 + bf16 variant ran end-to-end; per-layer RMS/max table populated for all 30 decoder layers + embeddings + head, fw + bw; JSON artifact round-trips through json.load. Output shows propagated error growing with depth, with sharp jumps at layers where activation magnitude regime changes (e.g. ref_scale 6 → 777 around block 11, bf16 RMS rel 10% → 0.7% → back up to 13% at block 28).
  • Existing layer-comparison tests still pass with the moved compare_tensor_logs.py and the refactored compare_tensors.

🤖 Generated with Claude Code

jlamypoirier and others added 4 commits May 27, 2026 15:07
A new `tools/evaluate_precision.py` (`RunnableConfig`) drives a fp32
reference run plus one one-iteration trainer run per named variant from
a Fast-LLM training YAML, then extracts per-layer forward activations
and input gradients from the saved tensor logs and reports per-tensor
RMS and max diffs (absolute and scaled). Variants are flat dicts of
dotted-path overrides, the same syntax as Fast-LLM CLI key=value args,
so they can sweep arbitrary configuration knobs (dtype, attention
implementation, optimizer dtype, etc.) — not just compute_dtype.

Also moves `compare_tensor_logs.py` into the `fast_llm` package so it
is importable from `tools/` (the test tree isn't on sys.path for
script entry points), and factors a `_compute_diff` helper out of
`CompareConfig.compare_tensors` so the tool can extract numbers for
every tensor rather than only those that breach a tolerance. Existing
test callers are unaffected.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool now takes a single YAML containing `pretrained:` (the
checkpoint that defines the model architecture + weights), `variants:`,
`output_dir:` and a few optional knobs (`model_type`, `num_samples`,
`micro_batch_size`, `sequence_length`). The training/optimizer/data
sections of the underlying training config are hardcoded — they have
no bearing on the propagation measurement (1 iteration, no checkpoint
save, random tokens, dummy learning rate, optimization dtype forced to
float32 alongside compute dtype). A variant can still override any of
the hardcoded fields via the dotted-path mechanism if needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool's input mirrors the trainer config's top-level shape: both
`model:` (FastLLMModelConfig dict) and `pretrained:` are user-facing,
and either or both may be set. Pretrained-from-HF is one config choice
among many — a user can also specify the architecture inline, or load
from HF and override individual fields.

The forced fp32 dtypes and tool-required debug levels are now applied
as overrides on top of whatever the user supplies, instead of being
hardcoded into the model section.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool now inherits from `PretrainedGPTModelConfig` so `model` and
`pretrained` are typed `FastLLMModelConfig` / `CheckpointLoadConfig`
fields rather than loose dicts — validated, autocompleted, and
introspectable like any other Fast-LLM config block.

Per-variant trainer configs are built with `TrainerConfig.get_subclass(...)
.from_dict(base, *updates)` instead of mutating a dict and round-tripping
through YAML. Updates use tuple-keyed dotted paths so forced-fp32,
variant overrides, and tool-required debug-logging overrides compose
cleanly in the right precedence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier force-pushed the jlp_evaluate_precision branch from 6431307 to 4c444d8 Compare May 27, 2026 19:16
jlamypoirier and others added 21 commits May 27, 2026 15:46
`transformers.PretrainedConfig.to_dict()` serializes a growing set of
generic defaults (generation knobs, family markers, encoder-decoder
flags). The Fast-LLM allowlist covered only a subset, so loading any
modern HF Llama checkpoint via `pretrained.format: llama` tripped the
coverage walker on keys like `torchscript`, `is_decoder`,
`is_llama_config`, `rope_interleaved`, and the full set of generation
defaults.

Fill in the missing entries, grouped by category. None of them are
architecture knobs that Fast-LLM consumes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop step / shape / max_rel columns, shorten the tensor name to the
description after the colon, reorder to Tensor / Kind / Relative /
Absolute / Max / Scale, format Relative as percent and the rest with
`.3g`. The JSON report keeps every field.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the separate Kind column and append `(fw)` / `(bw)` to the
shortened tensor name. Switch numeric formatting to fixed precision:
Relative shows `.2f` percent, Absolute / Max / Scale show `.2e`
scientific. Every column now lines up on a consistent digit count.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Scientific notation was overkill for values that mostly land between
0.01 and a few hundred. `.3f` is more readable while keeping the
per-column digit count consistent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Fast-LLM's `Run.__init__` picks the next free `runs/<n>` subdirectory
based on what already exists, but `_artifact_path` reads `runs/0`
unconditionally. Without this wipe, re-running the tool against the
same `output_dir` reads stale artifacts from the first invocation and
silently reports old numbers — even though the trainer correctly ran
with the new config.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a `data_path` field to the tool. When set, the tool lazily
generates a tokenized memmap dataset with random advantages and
old_logprobs at the given path (via the test helper
`tests/utils/dataset._get_test_dataset`) and uses it as the training
input. Required for policy-gradient losses like GSPO/GRPO that consume
those fields. Without it, the tool falls back to the random token
generator as before.

Console table now formats numeric columns with `.4g` so 1e-7-scale
GSPO gradients aren't rounded to zero while normal CE-magnitude
values still read as fixed-point numbers.

Rename `download_santacoder_tokenizer` to `download_test_tokenizer` —
it actually downloads the GPT-2 tokenizer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After the per-tensor tables, emit a short summary block per variant
showing first/last/max/median for forward and backward separately.
Aggregates over the intermediate layers per metric column (max and
median are computed per-column, so each row is a per-metric envelope
of the intermediate band rather than the metrics of any single layer).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Single compact table with one row per variant and columns for fw/bw
first/last/max/median Relative %. Max/median are over intermediate
layers (excluding first/last) when there is at least one intermediate
row.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Rename `max`/`median` columns to `mid max`/`mid med` and add a header
note (`mid = excluding first/last`) so it's clear the aggregation
excludes the boundary layers. Also fix a column-collision bug where
labels at exactly the cell width touched without separator.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each variant now occupies two rows in the summary (fw on the first,
bw on the second), with the metric columns shared. Reads more
naturally and keeps the table half as wide. Percent precision goes
from .2f to .3f so single-digit-percent differences between variants
are visible.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Top header line groups columns under `fw` / `bw`; the second line
lists the per-pass aggregations. Aggregations are ordered
chronologically along the pass — first → mid med → mid max → last —
so reading left to right traces the propagation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an `fp32_lm_head` field on `LanguageModelHeadConfig`. When `True`,
the LM head linear's input and weight are upcast to FP32 before the
matmul, matching vLLM's `bf16_last_layer_fp32` quantization. This lets
the trainer compute log-probabilities at the same numerical precision
as the actor's sampling, so the importance-sampling ratio starts near
1.0 instead of being artificially inflated by a trainer/actor precision
mismatch.

The detached FP32 weight has `requires_grad=False`, which makes
`output_parallel_linear_backward` skip the weight-grad path. The FSDP
gradient contract is restored by computing `grad_weight = grad.t() @ saved_input`
explicitly and accumulating into the original BF16 param's `grad_buffer`
via `accumulate_gradient`.

Off by default — disabled path is byte-identical to before.

Cherry-picked from #526 to unblock the precision-evaluation tool's
GSPO smoke test, which compares fp32_lm_head=true vs false.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Instead of generic `first` / `last` headers in the summary, use the
actual layer name pulled from the matching tensor's `Global <layer>
<kind>:` prefix. For the SmolLM2 smoke run that surfaces as
`embeddings` / `head` on fw and `head` / `decoder.0` on bw — directly
showing which layer the boundary values come from rather than making
the reader guess.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…den_states

Previously the only way to get a non-layer-output tensor (e.g. the LM head's
logits) into `tensor_logs` was to crank `model_debug_level`, which logs every
single `_debug`-emitted tensor (~700 per step for a 30-layer model).

Add a `MultiStageConfig.debug_hidden_states_log: list[str]` field — regex
patterns that get appended to each model input's `output_hidden_states` set.
Matching tensors are still populated into `kwargs[hidden_states]` (existing
contract for the HF inference wrapper); now they're also written to
`tensor_logs` so the precision tool can compare them across variants.

`_debug` already had the `output_hidden_state`-matched branch but only used it
to populate `kwargs[hidden_states]`. Extending it to also call
`log_distributed_tensor` at a fixed verbosity (13, matching the test
convention so samples are recorded) is a small gating change.

Plumbed through `GPTModel.get_preprocessing_config` →
`LanguageModelBatchPreprocessingConfig.output_hidden_states` →
`LanguageModelBatch.get_model_inputs`, which compiles the patterns and unions
them into each `LanguageModelInput.output_hidden_states`.

The precision tool now sets `[r"head\.logits"]` and surfaces logits as a
dedicated `logits` column on the fw side of the summary table.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The head's `logits` tensor has `requires_grad=False` (output of a
custom-autograd Function), so the existing `_debug(logits, ...)` could
only capture the forward value. Add a second `_debug(grad, "logits.grad",
...)` call right after the loss returns the explicit `dL/d_logits` so
the gradient is captured at the same fidelity. With the precision tool's
`output_hidden_states` pattern `r"head\.logits"`, both `head.logits`
and `head.logits.grad` end up in tensor_logs.

Tool summary surfaces both via dedicated `logits` columns — placed at
end-of-fw and start-of-bw chronologically. For GSPO the bw-logits column
reveals that the dL/dlogits computation itself is extremely precise
(~0.001% relative error), and the apparent backward noise actually
enters through the head matmul further downstream.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…alues

`.3f%` was rounding the bw-logits values down to 0.001%-0.000%, hiding
real signal. Switch to `.4g%` so values across 5 orders of magnitude
(0.0001% to ~20%) all render with meaningful precision; large values
keep 4 significant figures, tiny ones spell out their leading non-zero
digits or fall back to scientific.

Bw column order is now first / logits / mid med / mid max / last so
`logits` sits right after `head` (the first bw row) — semantically
the gradient at logits is what the head's backward consumes before
producing the gradient at its input.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Keep the prior `.3f%` default in the summary so most columns still
show `0.000%` / `12.672%` style values, but compute a per-column
decimal count based on the smallest non-zero value in that column —
bumping up just enough that every cell carries at least two
significant figures. Decimal count is uniform within a column.

For the GSPO run, only the bw-logits column hits the threshold and
gets bumped from 3 to 5 decimals, surfacing values like `0.00095%`
that previously rounded to `0.001%` or worse.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cell width drops from `max_label + 1` to `max_label`, inter-cell sep
from two spaces to one, group sep from four spaces to three. About 18
chars narrower on the GSPO smoke run with no loss of alignment or
readability.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lets `pretrained.path: org/model-id` resolve via huggingface_hub.snapshot_download
when not a local directory, matching transformers' from_pretrained behavior.
Local paths pass through unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two ready-to-run configs for tools/evaluate_precision: smol.yaml sweeps
precision-stability features (full_precision_gradients, full_precision_residual,
fp32_lm_head) on SmolLM2-135M; smol_gspo.yaml repeats the sweep with the GSPO
policy-gradient loss enabled.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A single forward+backward pass with micro_batch_size=1 has no gradient
accumulation, so toggling full_precision_gradients produces bit-identical
results to the bf16 baseline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

jlamypoirier commented May 28, 2026

Caution

These numbers are invalid and have been superseded. This run was on SmolLM2-135M and, as the follow-up comment explains, hit the micro_batch_size-as-sequence-length bug — every run here was on 1-token inputs. The layer-wise measurement is redone correctly (Qwen2.5-0.5B, realistic MATH-500 text, forward+backward, averaged over 32 sequences) in this comment.


Sample precision-evaluation runs

Output of python -m tools.evaluate_precision -c examples/evaluate_precision/<config> on SmolLM2-135M, sequence length 128, single forward+backward step. Numbers are RMS relative diff vs the forced-fp32 reference, in %.

smol.yaml — pretrained HF weights

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits  bw mid med  bw mid max  bw decoder.0
bf16               0.000%      1.192%      19.904%     43.910%    4.597%      20.375%  14.710%    17.426%     22.495%     14.725%
bf16_fp32_lm_head  0.000%      1.192%      19.904%     43.901%    4.673%      19.559%  15.259%    16.797%     22.118%     14.375%
bf16_fp32_residual 0.000%      0.260%      4.569%      5.348%     0.568%      5.768%   4.132%     4.298%      8.025%      4.401%
bf16_max_precision 0.000%      0.260%      4.569%      5.347%     0.353%      6.653%   4.909%     4.959%      7.643%      4.381%

smol.yaml — random init (pretrained.model_weights=False)

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits  bw mid med  bw mid max  bw decoder.0
bf16               0.168%      1.739%      2.334%      2.425%     0.160%      0.284%   0.621%     2.120%      2.861%      3.188%
bf16_fp32_lm_head  0.168%      1.739%      2.334%      2.421%     0.160%      0.284%   0.617%     2.139%      2.937%      3.196%
bf16_fp32_residual 0.168%      1.372%      1.603%      1.689%     0.040%      0.295%   0.447%     1.435%      2.179%      2.321%
bf16_max_precision 0.168%      1.372%      1.603%      1.686%     0.041%      0.295%   0.434%     1.437%      2.180%      2.232%

smol_gspo.yaml — pretrained HF weights

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits  bw mid med  bw mid max  bw decoder.0
bf16               0.000%      0.242%      12.672%     11.312%    10.852%     0.107%   0.00095%   0.125%      0.340%      5.407%
bf16_fp32_lm_head  0.000%      0.242%      12.672%     11.296%    10.800%     0.105%   0.00176%   0.123%      0.336%      5.403%
bf16_fp32_residual 0.000%      0.227%      6.190%      6.593%     2.798%      0.042%   0.00573%   0.031%      0.068%      0.994%
bf16_max_precision 0.000%      0.227%      6.190%      6.587%     4.650%      0.049%   0.00730%   0.053%      0.128%      2.123%

smol_gspo.yaml — random init (pretrained.model_weights=False)

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits   bw mid med  bw mid max  bw decoder.0
bf16               0.173%      1.783%      2.222%      2.152%     2.296%      0.0133%  0.000095%   0.023%      0.158%      4.387%
bf16_fp32_lm_head  0.173%      1.783%      2.222%      2.143%     2.301%      0.0134%  0.000095%   0.023%      0.163%      4.626%
bf16_fp32_residual 0.173%      1.356%      1.625%      1.566%     0.860%      0.0048%  0.000044%   0.011%      0.081%      2.210%
bf16_max_precision 0.173%      1.356%      1.625%      1.560%     0.939%      0.0057%  0.000045%   0.012%      0.091%      2.611%

Observations

  • Pretrained weights produce much larger forward-pass errors than random init — particularly visible at fw mid max (single worst intermediate layer) and at head.logits. The CE loss config peaks around 20-44%, GSPO around 11-13%. Random init keeps everything under ~3%.
  • full_precision_residual is the dominant stability lever — it cuts the worst forward-pass numbers in roughly half (pretrained CE) or by a smaller fraction (random / GSPO).
  • fp32_lm_head (Add fp32_lm_head flag for vLLM precision parity #526) has limited effect on its own; it visibly helps only when combined with full_precision_residual and even there mostly on the absolute head output (not on logits).
  • GSPO backward errors are far smaller than CE backward errors (e.g. bw logits at 1e-3-1e-5% vs ~15% for CE), consistent with the GSPO loss producing much smaller logit gradients.

jlamypoirier and others added 3 commits May 28, 2026 14:49
Enables debug_all_param_gradients so every parameter's reduced gradient is
captured in tensor_logs alongside the existing layer activations and input
gradients. New rows are tagged with kind 'grad' and appear in the per-variant
table but stay out of the fw/bw summary table.

Also makes the per-variant table's Tensor column width fit the longest name
(parameter gradients can be 40+ chars) and bumps the Relative column to
adaptive precision (capped at 5 decimals) so legitimately tiny values stay
legible.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Group rows in the per-variant tables by display group with blank lines
between fw, bw, and grad. The reduce_gradients hook emits parameter
gradients chronologically interleaved with the backward pass, which made
the previous table hard to scan. Display grouping is independent of `kind`
so the summary aggregation is unaffected — head.logits.grad just moves
to the bw block visually.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each pass gets its own self-contained Variant x columns table with
labels picked from the actual first/last logged tensor. Weight gradients
get a head/mid med/mid max/embeddings layout mirroring the bw structure;
the grad table makes large norm_1 outliers (>200% relative) immediately
visible at a glance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the chronological first/last columns in the grad table with
named lookups (lm_head / embeddings) and split the intermediate
aggregation by category: linear weights, norm weights, biases. The
bias columns appear only when biases exist. lm_head shows n/a when
the LM head weight is tied to the embedding (e.g. SmolLM2), since the
combined gradient is recorded under the embedding parameter.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add `sample_level_overrides: dict[str, int]` (regex pattern -> level) to
`TensorLogsConfig`. `log_tensor` raises the effective level for any
tensor whose logged name matches a pattern, so callers can collect more
samples for specific tensors without changing the default. Useful for
sparsely-non-zero tensors like embedding-weight gradients, where the
default uniform stride misses every non-zero row.

evaluate_precision: switch `num_samples` to actually drive the level
(was only cropping the text log), bump default to 8192, default
sequence length to 2048 in the example yamls, and add a 1M-sample
override for `Global gradient: embeddings.*` to make embedding-grad
errors measurable on small batches.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

jlamypoirier commented May 29, 2026

Note

Superseded (small-model). These SmolLM2-135M within-engine chosen-logprob results are correct, but have been superseded by the realistic-text Qwen2.5-0.5B runs — the per-sequence / GSPO cross-engine comment and the layer-wise reproduction, which confirm these findings at a larger scale.


Within-engine precision: chosen-token log π

Added per-position log_softmax(logits)[label] as a diagnostic signal — the scalar that policy-gradient importance ratios actually depend on. Wired through as a no-grad chosen_logprob LM loss; the tool auto-adds it and surfaces a dedicated summary with bias, correlation, slope, and residual-after-linear-fit.

Bug fix bundled in: Fast-LLM's data.micro_batch_size is the per-sample sequence length, not the batch dim. The tool was passing 1 thinking it controlled batch size — so every previous run in this PR was on 1-token inputs. All numbers I posted earlier in this PR are invalid; the tables below replace them.

Smol (random labels)

Variant RMS rel Bias rel Corr Slope
bf16 0.473% -0.005% 0.99950 +0.99848
bf16_fp32_lm_head 0.371% -0.005% 0.99970 +0.99839
bf16_fp32_residual 0.358% -0.014% 0.99972 +1.00001
bf16_max_precision 0.216% -0.004% 0.99990 +0.99990
bf16_in_fp32_out (diagnostic) 0.371% -0.005% 0.99970 +0.99839
bf16_reduced_reduction (diagnostic) 0.473% -0.005% 0.99950 +0.99848

Smol_GSPO (RL data)

Variant RMS rel Bias rel Corr Slope
bf16 0.931% +0.047% 0.99982 +1.00095
bf16_fp32_lm_head 0.876% +0.047% 0.99984 +1.00081
bf16_fp32_residual 0.644% -0.032% 0.99992 +1.00108
bf16_max_precision 0.568% -0.017% 0.99994 +1.00109
bf16_in_fp32_out (diagnostic) 0.876% +0.047% 0.99984 +1.00081
bf16_reduced_reduction (diagnostic) 0.931% +0.047% 0.99982 +1.00095

Reference scale is ~11–12 nats, so 1% relative ≈ 0.1 nats absolute. Per-token bias sits at 0.005–0.014 nats.

Findings

1. Within-engine bf16 vs fp32 is small across the board. All RMS values < 1%, all bias values < 0.05% (≈0.005 nats), all correlations ≥0.9995, all slopes ≈1.000. There's no systematic distortion — just per-token decorrelated noise plus negligible mean shift. The intrinsic precision of a single bf16 engine compared to its fp32 equivalent is not on a scale that would cause RL collapse on its own.

2. fp32_lm_head's entire effect is the output dtype, not the matmul precision. Tested by adding bf16_in_fp32_out (= fp32_lm_head=True + matmul_precision='medium', which routes the matmul back through bf16 Tensor Cores while keeping logits fp32). Result matches standard fp32_lm_head to 5 sig figs on every metric. So the gain is purely from skipping the bf16 round on the output logits, not from running the matmul in fp32. The flag's name is misleading; it could be implemented as a bf16-in / fp32-out kernel with half the matmul memory bandwidth.

3. fp32_lm_head is downstream of the bias source. bf16 and bf16+fp32_lm_head have identical bias on GSPO data; only fp32_residual moves it. So whatever asymmetric rounding is producing the (very small) bias lives upstream of the LM head — bf16 residual stream / weight quantization. Useful structural fact, but the magnitudes involved are too small to matter intrinsically.

4. allow_bf16_reduced_precision_reduction = True has no observable effect at our matmul size (576 × 49152). cuBLAS likely doesn't pick a split-K kernel here, so the flag is moot.

Caveats and future work

This is a single-engine, single-step, small-model measurement (SmolLM2-135M, 1 fwd+bwd). The literature reports vLLM-vs-trainer log-prob mismatches of 2–24 nats per sequence and RL training collapse without precision fixes; our largest within-engine bias is ~0.005 nats. The gap is real, but many things differ between the two settings, any combination of which could be responsible:

  • Cross-engine alignment — rollouts in vLLM vs gradients in the trainer go through different kernel implementations (attention, matmul, fused ops). Our single-engine setup can't see this.
  • Model scale — different reduction depth, larger logit magnitudes, larger vocab.
  • Training dynamics — small per-step biases compound over thousands of steps; we measure one step.
  • Algorithm/data distribution — real RL rollouts (high-entropy regions, importance-ratio outliers) stress the head differently from our synthetic GRPO data.
  • Software stack — different RL frameworks (TRL, open-instruct, ScaleRL) have different numerical paths even within the trainer side.

The most direct follow-up — and the one closest to what the literature actually measures — is a vLLM-vs-trainer chosen-logprob comparison at matched scale. That would isolate the cross-engine factor specifically; combining it with the within-engine measurements here would let us decompose the literature's 2–24 nat mismatch into engine-mismatch vs other factors. Out of scope for this PR.

…riants

- New `chosen_logprob` LM loss: logs `log_softmax(logits)[label]` per position with
  no gradient contribution. Tool auto-adds it and surfaces a dedicated summary with
  bias, correlation, slope, and residual-after-linear-fit.
- `_compute_diff` reports bias_abs/rel, correlation, slope, residual_rms_abs/rel —
  the linear decomposition separates systematic shift/scale from per-position noise.
- Per-variant auto-calibrated power-of-2 gradient scale: each variant runs a
  calibration pass at scale=1 to measure max unscaled gradient, then the real run
  picks the largest power-of-2 scale that fits within fp16 range (with a small
  safety factor for fused-kernel partial sums). `_compare` unscales per variant.
- Tool: backend-override mechanism (`_torch_backend.*`) and `_torch_matmul_precision`
  variant keys for diagnostic variants. New variants: `bf16_in_fp32_out` (probes
  whether `fp32_lm_head`'s gain is from output dtype vs matmul precision),
  `bf16_reduced_reduction` (probes the split-K reduction path), and a full fp16
  sweep mirroring the bf16 variants.
- Fix: `data.micro_batch_size` in Fast-LLM is the per-sample sequence length, not
  the batch dim. Tool was passing 1 → every prior run was on 1-token inputs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

jlamypoirier commented May 29, 2026

Note

Superseded (small-model). This SmolLM2-135M fp16-vs-bf16 sweep stands, but the same ~5–8× fp16 precision gain is reproduced on Qwen2.5-0.5B with realistic text in the layer-wise reproduction.


FP16 vs BF16 within-engine: ~8× precision improvement, no bias

Added an FP16 sweep (fp16, fp16_fp32_residual, fp16_fp32_lm_head, fp16_max_precision) and per-variant auto-calibrated gradient scaling. Each variant runs a calibration pass at constant=1.0 to measure max unscaled gradient, then picks the largest power-of-2 scale that keeps scale × max_unscaled × headroom < fp16_max. _compare unscales per variant. Auto-picked scales for this setup: bf16/fp32 variants → 2^11=2048; fp16 base / fp16_fp32_residual → 2^9=512; fp16 + fp32_lm_head variants → 2^11=2048 (the cleaner head reduces the backward-grad max).

Chosen-token log π — smol (random labels)

Variant RMS rel Bias rel Resid rel Corr
bf16 0.473% -0.005% 0.472% 0.99950
bf16_fp32_lm_head 0.371% -0.005% 0.370% 0.99970
bf16_max_precision 0.216% -0.004% 0.216% 0.99990
fp16 0.057% +0.0005% 0.056% 0.99999
fp16_fp32_lm_head 0.044% +0.0007% 0.044% 1.00000
fp16_max_precision 0.028% -0.0004% 0.027% 1.00000

Chosen-token log π — smol_GSPO (RL data)

Variant RMS rel Bias rel Resid rel Corr
bf16 0.931% +0.047% 0.928% 0.99982
bf16_max_precision 0.568% -0.017% 0.565% 0.99994
fp16 0.113% -0.004% 0.110% 1.00000
fp16_max_precision 0.104% -0.003% 0.101% 1.00000

Gradients — smol_GSPO (RL data)

Variant linear_med linear_max norm_med norm_max embeddings
bf16 0.000236% 0.0376% 0.0174% 0.292% 0.000483%
bf16_max_precision 0.000186% 0.0346% 0.0153% 0.608% 0.000355%
fp16 0.000028% 0.0061% 0.0023% 0.109% 0.000057%
fp16_max_precision 0.000028% 0.0073% 0.0022% 0.155% 0.000058%

Findings

1. FP16 gives ~8× precision reduction across the board. Per-token chosen_logprob: 0.473%→0.057% (smol) and 0.931%→0.113% (gspo) — both 8.2-8.3× reduction, matching the 7→10 mantissa-bit ratio. Gradients show the same ~8× ratio across linear/norm/embedding weights. Bias collapses too: from ~0.05% in bf16 to ~0.003-0.005% in fp16.

2. FP16 + fp32_lm_head / fp32_residual add little on top. Unlike bf16 where fp32_residual was meaningful (0.93%→0.64% on gspo), in fp16 the residual is already precise enough that adding fp32 to the residual stream barely moves anything (0.113%→0.111%). fp32_lm_head is similarly minor (0.113%→0.107%). The stability features that matter for bf16 are essentially superfluous at fp16.

3. Correlation slope is exactly 1.0 for fp16. Bf16 had slope deviations from 1.0 in the 0.0008-0.0011 range (very small systematic distortion); fp16's slope is 0.99979-1.00000 — effectively unity. No structural distortion at all.

4. The remaining caveats from the prior bf16 analysis still stand. This is single-engine, single-step, SmolLM2-135M. The literature's reported RL collapse on bf16 still isn't visible at this scale, and we can't probe cross-engine alignment with this tool. What this commit does settle: FP16 has the expected ~8× intrinsic precision over BF16 in our setting, which is consistent with the precision-based mechanism that the FP16 paper (Liu et al. 2510.26788) invokes.

The natural next step to actually attribute the literature's RL-stability claim is still a vLLM-vs-trainer log-prob diff — out of scope for this PR.

Committed in 312343e7.

jlamypoirier and others added 4 commits June 1, 2026 12:37
Replace the Trainer + data-loading path in tools/evaluate_precision.py with a lean
forward+backward runner (InferenceRunner-style: model + ScheduleRunner + lr-0 optimizer,
training-phase schedule) fed a fixed, already-preprocessed input. This lets the model see an
exact token tensor (the data pipeline would re-randomize the model input via shuffle/packing)
and drops the training/data-loading infrastructure the tool doesn't need — which also fixes the
GPU-memory accumulation that OOM'd on larger models (each run's model+optimizer is now freed).

The input is built once (configurable input_text_file -> tokenized, or uniform-random) and saved
to output_dir/input_ids.pt so the DeepSpeed-side tool can consume byte-identical model input.

Add tools/evaluate_precision_deepspeed.py: the HF-transformers + DeepSpeed counterpart, mirroring
PipelineRL's proven fp32-lm-head and log-pi computation, reporting the same chosen-logprob and
categorized-gradient metrics so Fast-LLM's bf16 precision pattern can be benchmarked against
DeepSpeed's. fp16 gradients use loss scaling to avoid underflow.

Add examples/evaluate_precision/qwen.yaml and sample_text.txt for the Qwen2.5-0.5B comparison.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Lean runner now honors pretrained.model_weights (initialize_weights when not loading), matching the
trainer's branch; the DeepSpeed harness gains --random-init (build from config). Note: random init is
a poor cross-engine test — the two engines use different init schemes (different models), and HF's
from_config init yields near-uniform untrained logits where bf16 noise dominates (correlation ~0).
The pretrained comparison is the meaningful one.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Subprocess-per-variant log-prob precision sweep mirroring the trainer-side
tools: feeds a fixed prompt, reads vLLM prompt_logprobs (chosen-token log-pi
aligned with the trainers), and compares each precision variant against the
fp32 reference. Forces a single attention backend across variants to isolate
precision from the kernel.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The bf16_last_layer_fp32 quant matches its fp32 head by layer-name suffix.
vLLM names the tied head embed_tokens (lm_head = embed_tokens), so the
production lm_head prefix silently runs a bf16 head on tied models. Default
--fp32-head-prefix auto now picks embed_tokens when embeddings are tied so the
fp32 head genuinely binds (text bf16_fp32_head 1.05% -> 0.79%, matching the
trainers); pass lm_head for the literal production setting.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

Claude Opus 4.8 note — within-engine log-prob precision, all three engines.

Cross-engine precision: Fast-LLM vs DeepSpeed vs vLLM

Per-token chosen-token log π precision on Qwen2.5-0.5B, each variant measured against that engine's own fp32 reference (RMS relative error), on byte-identical inputs (seed-0 random and a fixed ~2K-token text; fp32 scale matches across engines: 13.52 / 3.611). Trainers run forward+backward; vLLM is forward-only.

Realistic text (|log π| scale ≈ 3.61):

variant Fast-LLM DeepSpeed vLLM
bf16, bf16 head 1.048% 1.075% 1.053%
bf16, fp32 head 0.753% 0.830% 0.794%
fp16, fp16 head 0.136% 0.132% 0.146%
fp16, fp32 head 0.101% 0.105% — ¹

Random tokens (|log π| scale ≈ 13.52):

variant Fast-LLM DeepSpeed vLLM
bf16, bf16 head 0.265% 0.307% 0.280%
bf16, fp32 head 0.265% 0.306% 0.280% ²
fp16, fp16 head 0.032% 0.040% 0.049%
fp16, fp32 head 0.032% 0.040% — ¹

Takeaways

  • Fast-LLM's bf16 precision matches the proven engines. bf16 lands at ~1.05% (text) / ~0.27–0.31% (random) on all three; fp16 is ~7–8× tighter everywhere. Correlation ≥ 0.9999 in every cell.
  • fp32 lm head is regime-dependent, identically across engines: negligible on high-entropy random input, but a ~23–28% error reduction on realistic text (FLLM −28%, DS −23%, vLLM −25%). This is the within-engine effect; its main role in the stack is cross-engine matching (vLLM emits fp32 logits, trainers set fp32_lm_head to match).
  • The residual engine-to-engine spread (≤ ~0.05 pp) is consistent with attention-kernel differences, not a precision regression.

¹ vLLM has no fp16+fp32-head variant

vLLM's bf16_last_layer_fp32 quant rejects an fp16 body (supports bf16/fp32 only).

² vLLM fp32 head silently no-ops on tied-embedding models

For tied models (Qwen2.5-0.5B/1.5B/3B) vLLM sets lm_head = model.embed_tokens (qwen2.py:548), so the quant's lm_head-prefix match misses and the head runs in bf16 — bf16_fp32_head came out bitwise-identical to plain bf16. The text-table value above is the corrected run (prefix retargeted to embed_tokens, which reproduces the expected −25%). On tied small models, the as-shipped vLLM quant uses a bf16 head while the trainers use fp32 — a latent cross-engine mismatch. Untied 7B is unaffected.

Method: DeepSpeed forces sdpa and vLLM forces TRITON_ATTN across all dtypes (kernel held fixed so the diff reflects precision); Fast-LLM uses its native kernels. Tools: tools/evaluate_precision{,_deepspeed,_vllm}.py.

jlamypoirier and others added 3 commits June 2, 2026 11:53
Add tools/evaluate_precision_cross_engine.py: loads each engine's per-token
log π vectors and reports the cross-engine log-ratio δ = log π_A − log π_B
(mean/RMS/max/clip), plus the error-correlation decomposition δ = floor +
(e_A − e_B) with ρ = corr(e_A, e_B) — the quantity that explains why fp32-head
matters across engines (removes the uncorrelated head-rounding component) while
being small within one.

Persist the full per-token log π (token order, aligned 1:1 with vLLM's
prompt_logprobs[1:]) from the two trainer tools so the comparison can consume
plain tensors: evaluate_precision.py extracts the chosen_logprob vector from the
run artifacts; evaluate_precision_deepspeed.py gains --output-dir.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Generalize the comparison over precisions {bf16, fp16}: each gets a matched
(fp32 head both sides) and mismatched (A fp32 head, B body-dtype head) group,
both now spanning every engine pair — the mismatched group previously only
covered vLLM pairs, so Fast-LLM−DeepSpeed was missing. vLLM has no fp16+fp32
head (its quant rejects an fp16 body), so fp16 matched is trainer-only; fp16
mismatched still covers all pairs with vLLM as the body-head side.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tion

Enumerate every combination of {fp32 head, body-dtype head} on each side per
engine pair, rather than one matched + one mismatched row. This adds the
production-relevant direction that was missing: a body-dtype head on the
training side against vLLM's fp32 head (vLLM always emits fp32 logits in
production), the prior single mismatched row had it reversed. Per-side head
columns make each row's config explicit; the decomposition mirrors the same
precision/head/pair combinations.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

Claude Opus 4.8 note — cross-engine log-probability comparison (Fast-LLM / DeepSpeed / vLLM).

Cross-engine RL log-prob agreement

Follow-up to the within-engine tables above. This measures the gap across engines on byte-identical input: the per-token log-ratio δ = log π_trainer − log π_vLLM over the chosen tokens. When the trainer recomputes log π on tokens vLLM sampled, δ is the log of the RL importance ratio exp(log π_train − log π_old) that multiplies the advantage — so δ is the quantity that actually perturbs the gradient, and the gap the literature quotes in nats.

Setup: Qwen2.5-0.5B, single forward (vLLM forward-only), byte-identical input (seed-0 random and a fixed ~2K-token text; fp32 log-prob scale matches across engines). vLLM emits fp32 logits in production, so the columns are the two trainer-vs-vLLM comparisons; rows sweep body precision × {fp32 head, body-dtype head} on each side (labeled trainer head / vLLM head). Cells are RMS δ in nats.

Realistic text (|log π| ≈ 3.61):

config (trainer head / vLLM head) Fast-LLM / vLLM DeepSpeed / vLLM
fp32 (floor) 0.0022 0.0022
bf16 — fp32 / fp32 — production 0.0306 0.0321
bf16 — fp32 / bf16 0.0390 0.0408
bf16 — bf16 / fp32 — prod mismatch (fp32_lm_head off) 0.0407 0.0404
bf16 — bf16 / bf16 0.0468 0.0468
fp16 — fp32 / fp16 0.0050 0.0050
fp16 — fp16 / fp16 0.0060 0.0057

Random tokens (|log π| ≈ 13.52):

config (trainer head / vLLM head) Fast-LLM / vLLM DeepSpeed / vLLM
fp32 (floor) 0.0041 0.0041
bf16 — fp32 / fp32 — production 0.0373 0.0409
bf16 — fp32 / bf16 0.0374 0.0409
bf16 — bf16 / fp32 — prod mismatch 0.0373 0.0409
bf16 — bf16 / bf16 0.0374 0.0410
fp16 — fp32 / fp16 0.0048 0.0053
fp16 — fp16 / fp16 0.0048 0.0054

Takeaways

  • Matching the trainer head to vLLM's fp32 logits (fp32_lm_head) is the production config and the tightest bf16 cell. On text, trainer-fp32/vLLM-fp32 = 0.031; letting the trainer head fall back to bf16 (the realistic mismatch — fp32_lm_head off, vLLM still fp32) widens it to ~0.040 (+25–33%); both sides on a bf16 head is worst (~0.047). On random the head column is flat — head precision is irrelevant at high entropy. Same regime-dependence as the within-engine tables.
  • It's a partial fix, not the dominant term. Even fully matched, the trainer↔vLLM bf16 gap is ~0.031 nats; that residual is bf16 body rounding that decorrelates across engines (per-token error correlation ρ ≈ 0.4, so the errors add rather than cancel). fp32_lm_head removes only the head component (~⅓ of the gap); the body ⅔ would need an fp32 body, which isn't done in training.
  • Fast-LLM matches DeepSpeed. The two trainers are fp32-identical (their fp32 floor is 0.0000), and Fast-LLM's bf16 gap vs vLLM is at or slightly below DeepSpeed's in every cell. Adopting Fast-LLM as the trainer adds no sampler↔trainer mismatch beyond the proven stack.
  • fp16 is ~6–8× tighter than bf16 across the board, but vLLM cannot run an fp16 fp32-head (its quant rejects an fp16 body), so an fp16 trainer can't be head-matched to vLLM. Moot in practice — the stack runs bf16.
  • Absolute magnitude is small in this regime: the per-token gap is ~0.03–0.04 nats (typical reweight exp(δ) ≈ 1.03, worst token ≈ 1.3), ~1 nat per sequence — well below the literature's 2–24 nats. That range comes from long sampled generations on larger models with policy drift; this is a single 0.5B prefill at init, i.e. the clean numerical floor. mean(δ) ≈ 0 in every cell, so the gap is variance (numerical noise), not a systematic bias.

Rows requiring a vLLM fp32 head at fp16 are dropped (unavailable). Method: kernel held fixed across dtypes (vLLM forced to one attention backend, DeepSpeed forced to sdpa, Fast-LLM native kernels) so the diff reflects precision. Tools: tools/evaluate_precision{,_deepspeed,_vllm,_cross_engine}.py.

jlamypoirier and others added 3 commits June 2, 2026 18:17
Runs a single forward in `StageMode.inference` (no optimizer, no gradient
buffers) instead of forward+backward, so large models (e.g. 7B in fp32) fit
where forward+backward+Adam would OOM.

The LM head skips all losses in eval mode, so after setup the head(s) are
forced back into train mode directly; `run_step`'s per-step `train(False)`
is a guarded no-op once `_training` is False, keeping the head trained so
`chosen_logprob` still logs. Only `chosen_logprob` is configured (no
grad-producing loss), so no backward ever touches the absent gradient
buffers. Uses a validation-phase schedule (forward-only but still produces
labels, unlike inference phase). Verified the forward-only log π is
bitwise-identical to the forward+backward path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
`--forward-only` initializes the DeepSpeed engine without an optimizer (no
fp32 master copy or Adam state) and runs a single eval()+no_grad forward,
for large models (e.g. 7B in fp32) where forward+backward+Adam would OOM.

The engine is kept rather than bypassed for a plain HF forward: DeepSpeed's
bf16/fp16 forward is not bit-identical to a plain HF forward in the same
dtype — measured ~0.032 nats mean / 0.22 max on Qwen2.5-0.5B bf16,
comparable to the cross-engine signal itself — so bypassing it would shift
the log π. The no-optimizer engine forward is bitwise-identical to the full
forward+backward path across all variants (fp32/bf16/fp16, head on/off).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
7B is untied, so the fp32 LM head genuinely changes the logits (unlike the
tied 0.5B). Forward-only so the fp32 reference fits in memory.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

Claude Opus 4.8 note — cross-engine log-probability comparison at 7B (Qwen2.5-7B).

Cross-engine RL log-prob agreement — Qwen2.5-7B

Follow-up to the Qwen2.5-0.5B cross-engine tables above, now on Qwen2.5-7B. The 7B model has untied embeddings, so the LM head is a real parameter and the fp32-head upcast genuinely binds in vLLM (bf16_last_layer_fp32) — the actual production configuration, which the tied 0.5B could only emulate. Same metric: per-token δ = log π_trainer − log π_vLLM over the chosen tokens (the log of the RL importance ratio), in nats.

Setup: single forward per engine, forward-only (the fp32 reference doesn't fit forward+backward at 7B); byte-identical input; one attention backend held fixed per engine. vLLM emits fp32 logits in production, so the columns are the two trainer-vs-vLLM comparisons; rows sweep body precision × {fp32 head, body-dtype head}. Cells are RMS δ in nats.

Realistic text (|log π| ≈ 2.85):

config (trainer head / vLLM head) Fast-LLM / vLLM DeepSpeed / vLLM
fp32 (floor) 0.0017 0.0017
bf16 — fp32 / fp32 — production 0.0227 0.0222
bf16 — fp32 / bf16 0.0320 0.0319
bf16 — bf16 / fp32 — prod mismatch (fp32_lm_head off) 0.0329 0.0319
bf16 — bf16 / bf16 0.0392 0.0393
fp16 — fp32 / fp16 0.0042 0.0041
fp16 — fp16 / fp16 0.0050 0.0051

Random tokens (|log π| ≈ 13.88):

config (trainer head / vLLM head) Fast-LLM / vLLM DeepSpeed / vLLM
fp32 (floor) 0.0024 0.0024
bf16 — fp32 / fp32 — production 0.0259 0.0356
bf16 — fp32 / bf16 0.0262 0.0359
bf16 — bf16 / fp32 — prod mismatch 0.0263 0.0361
bf16 — bf16 / bf16 0.0267 0.0365
fp16 — fp32 / fp16 0.0034 0.0039
fp16 — fp16 / fp16 0.0034 0.0039

Takeaways

  • fp32_lm_head helps a bit more on the untied 7B, but it's a modest bump, not a different regime. On text, matching the trainer head to vLLM's fp32 logits tightens 0.0329 → 0.0227 (~31%, vs ~25% on the tied 0.5B; within Fast-LLM, 1.11% → 0.73%). The untied model's real contribution is qualitative — the vLLM quant genuinely binds, so this is the production path rather than an emulation — not a large quantitative jump. On random the head column is flat (head precision is irrelevant at high entropy), same as 0.5B.
  • It's still a partial fix. Even fully matched, the trainer↔vLLM bf16 text gap is ~0.023 nats; the residual is bf16 body rounding that decorrelates across engines (per-token error correlation ρ ≈ 0.39), which the head can't touch (~⅔ of the gap).
  • Fast-LLM matches DeepSpeed on text and is closer to vLLM on random. The fp32 floor between the two trainers is 0.0000 (fp32-identical). On text the bf16 gap vs vLLM is within noise (0.0227 vs 0.0222). On random Fast-LLM is clearly closer to vLLM than DeepSpeed is (0.026 vs 0.036): DeepSpeed's bf16 is noisier on random tokens (within-engine RMS error 0.049 vs Fast-LLM's 0.040), including one outlier token at |δ| ≈ 1.1 nats that Fast-LLM and vLLM agree on. Adopting Fast-LLM adds no sampler↔trainer mismatch beyond the proven stack.
  • fp16 is ~5–8× tighter than bf16 across the board, but vLLM can't run an fp16 fp32-head (its quant rejects an fp16 body), so an fp16 trainer can't be head-matched to vLLM. Moot in practice — the stack runs bf16.
  • Absolute magnitude is small: the per-token gap is ~0.02–0.04 nats (typical reweight exp(δ) ≈ 1.02–1.04), well below the literature's 2–24 nats — that range is long sampled generations on larger models with policy drift; this is a single prefill at init. mean(δ) ≈ 0 in every cell, so the gap is variance, not bias.

Rows requiring a vLLM fp32 head at fp16 are dropped (unavailable). Forward-only: Fast-LLM uses StageMode.inference, DeepSpeed initializes its engine without an optimizer — both verified bitwise-identical to their forward+backward log π. Tools: tools/evaluate_precision{,_deepspeed,_vllm,_cross_engine}.py.

Extend the precision tools to run many independent sequences and report the
across-sequence distribution of the length-normalized log-ratio (per-sequence
mean delta over completion tokens) — the GSPO-relevant quantity — alongside the
existing per-token (GRPO) view.

- evaluate_precision.py: input_dataset/num_sequences fields build one sequence
  per dataset example (chat-templated prompt + reference solution, prompt length
  recorded), saved to inputs.pt. Forward-only, one forward per sequence on a
  resident model, flushed per sequence so the per-sequence chosen_logprob vectors
  stay separable; saved as a list. Layer/hidden debug logs disabled in this mode.
- evaluate_precision_deepspeed.py / _vllm.py: --inputs-file consumes the shared
  sequence set and saves per-sequence log-prob lists.
- evaluate_precision_cross_engine.py: --inputs-file slices the completion region
  per sequence; adds an across-sequence table (mean/std/RMS/max of per-sequence
  mean delta). Loads logprobs as list-or-tensor (single-sequence back-compat).
- examples/evaluate_precision/qwen_multi_seq.yaml: Qwen2.5-0.5B + MATH-500, 256
  sequences.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

Claude Opus 4.8 note — per-sequence (GSPO) cross-engine log-probability comparison.

Per-sequence (length-normalized / GSPO) cross-engine agreement — Qwen2.5-0.5B

The earlier cross-engine tables are per-token δ = log π_A − log π_B over the chosen tokens — the GRPO-relevant quantity (per-token variance). GSPO instead uses a length-normalized sequence ratio: the importance ratio for a whole sequence is exp(δ̄) where δ̄ = mean δ over the sequence's completion tokens. Averaging over the sequence damps the per-token variance (~1/√T) and leaves the systematic part. So the GSPO-relevant question is the across-sequence distribution of δ̄, which this run measures directly.

Setup: 256 independent sequences, each a chat-templated math prompt (system + problem) followed by the reference solution, from HuggingFaceH4/MATH-500; δ̄ is taken over the completion tokens only. Byte-identical token ids across engines, one attention backend held fixed per engine, forward-only. vLLM emits fp32 logits in production, so the rows below are both bf16 bodies with fp32 LM heads on both sides (the stack default). All values in nats (δ̄ is a log-ratio; exp(δ̄) is the reweighting factor).

pair (A − B), bf16 / fp32 heads mean δ̄ std RMS max|δ̄|
fp32 floor — Fast-LLM − DeepSpeed −0.00000 0.00000 0.00000 0.0000
fp32 floor — trainer − vLLM +0.00016 0.00102 0.00103 0.0094
Fast-LLM − vLLM (production) −0.00016 0.00432 0.00432 0.0323
Fast-LLM − DeepSpeed −0.00009 0.00400 0.00399 0.0238
DeepSpeed − vLLM −0.00007 0.00337 0.00336 0.0129
fp16 (fp32 head) — Fast-LLM − vLLM +0.00009 0.00054 0.00054 0.0043

For context, the matching per-token (GRPO) RMS δ on the same run: bf16 Fast-LLM − vLLM 0.0364, DeepSpeed − vLLM 0.0341, Fast-LLM − DeepSpeed 0.0369; fp32 floor between the trainers 0.0000, trainer − vLLM 0.0043; fp16 ≈ 0.005.

Takeaways

  • No systematic bias. Across 256 sequences |mean δ̄| < 0.0005 nats for every bf16 pair — the engine swap adds no directional drift to the length-normalized ratio. The residual is sequence-to-sequence variance, which averages down further over a training batch. Bias is the part that wouldn't average away, and it's effectively zero here.
  • Length-normalization shrinks the gap ~8–9×. Per-token bf16 RMS 0.036 → per-sequence RMS 0.0043 (the expected ~1/√T from averaging over the completion). Even the worst single sequence is max|δ̄| ≈ 0.03 nats, i.e. a reweight of exp(0.03) ≈ 1.03.
  • Fast-LLM tracks DeepSpeed. Against vLLM the per-sequence RMS is 0.0043 (Fast-LLM) vs 0.0034 (DeepSpeed) — same ballpark — and the fp32 floor between the two trainers is exactly 0 (fp32-identical). Adopting Fast-LLM introduces no sampler↔trainer mismatch beyond the proven stack.
  • Scope / caveats. This is a single forward at initialization (no sampling, no policy drift), so it's the numerical floor of the cross-engine gap, not the trained-policy gap the literature quotes. The model is the tied 0.5B, where vLLM's bf16_last_layer_fp32 quant no-ops on the head — so fp32-head matching barely moves the per-sequence numbers here; the untied 7B is where that knob genuinely binds and would be the real test of the head's per-sequence contribution.

Tools: tools/evaluate_precision{,_deepspeed,_vllm,_cross_engine}.py gain a multi-sequence mode (input_dataset / --inputs-file); each engine runs the sequences independently and saves per-sequence log π, and the cross-engine tool reports the across-sequence δ̄ distribution alongside the per-token table. Config: examples/evaluate_precision/qwen_multi_seq.yaml.

@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

jlamypoirier commented Jun 3, 2026

Claude Opus 4.8 note — layer-wise precision over multiple real sequences (redoing the first comment properly).

Layer-wise bf16/fp16-vs-fp32 precision — Qwen2.5-0.5B, real text, averaged over 32 sequences

The first comment in this PR reported per-layer forward / backward / gradient RMS-relative diffs against an fp32 reference, but on SmolLM2-135M and — as the next comment flagged — on accidental 1-token inputs, so those numbers are void. This redoes the measurement properly.

Setup: Qwen2.5-0.5B (pretrained), 32 independent MATH-500 sequences (chat-templated problem + reference solution), one full forward+backward per sequence with the per-layer debug logs on. Each per-tensor metric is averaged across the 32 sequences — the length-normalized / GSPO-relevant view, where per-token noise damps (~1/√(N·T)) and systematic per-layer bias persists. All values are RMS relative diff vs the forced-fp32 reference, in %. mid excludes the first/last layer of the group.

Forward activations

Variant embeddings mid med mid max logits head
bf16 0.000% 1.598% 2.746% 2.847% 0.339%
bf16_fp32_lm_head 0.000% 1.598% 2.746% 2.842% 0.350%
bf16_fp32_residual 0.000% 1.418% 2.166% 2.248% 0.302%
bf16_max_precision 0.000% 1.418% 2.166% 2.242% 0.257%
fp16 0.0000025% 0.247% 0.517% 0.535% 0.061%
fp16_fp32_lm_head 0.0000025% 0.247% 0.517% 0.535% 0.063%

Backward activations

Variant head logits mid med mid max decoder.0
bf16 2.262% 0.042% 8.054% 11.717% 14.072%
bf16_fp32_lm_head 2.110% 0.051% 7.992% 11.674% 14.034%
bf16_fp32_residual 1.789% 0.027% 7.060% 11.155% 12.805%
bf16_max_precision 1.595% 0.030% 6.963% 11.105% 12.759%
fp16 0.387% 0.016% 1.479% 2.216% 3.360%
fp16_fp32_lm_head 0.374% 0.011% 1.475% 2.212% 3.370%

Parameter gradients

Variant lm head linear med linear max norm med norm max bias med bias max embeddings
bf16 n/a 9.016% 20.426% 9.777% 22.512% 8.600% 20.056% 7.063%
bf16_fp32_lm_head n/a 8.911% 20.840% 9.712% 22.700% 8.477% 20.081% 7.020%
bf16_fp32_residual n/a 8.180% 19.922% 8.566% 16.784% 7.628% 19.393% 6.773%
bf16_max_precision n/a 8.078% 20.058% 8.556% 17.225% 7.488% 19.382% 6.598%
fp16 n/a 1.627% 7.802% 1.825% 10.690% 1.743% 4.827% 1.153%
fp16_fp32_lm_head n/a 1.628% 7.810% 1.815% 10.732% 1.740% 4.928% 1.140%

Chosen-token log π (per-token)

Variant RMS rel Bias rel Resid rel Corr Slope Max abs
bf16 3.086% +0.139% 2.975% 0.99943 +0.99114 1.652
bf16_fp32_lm_head 3.038% +0.141% 2.925% 0.99945 +0.99114 1.652
bf16_fp32_residual 1.479% −0.127% 1.461% 0.99986 +1.00102 0.4634
bf16_max_precision 1.368% −0.110% 1.353% 0.99988 +1.00078 0.4619
fp16 0.472% −0.026% 0.464% 0.99999 +1.00083 0.2433
fp16_fp32_lm_head 0.465% −0.027% 0.457% 0.99999 +1.00086 0.2429

Findings

  • Forward error is moderate and grows with depth. bf16 peaks at the logits (~2.8%), embeddings are exactly fp32 (the input rows are identical), and error accumulates through the stack. This is nothing like the 20–44% the first comment reported — that was the 1-token bug compounded by pathological inputs, not a real depth-propagation signal.
  • full_precision_residual is the dominant lever. It cuts the forward error and the per-token log π RMS roughly in half (logits 2.85% → 2.25%, log π 3.09% → 1.48%). fp32_lm_head on its own barely moves the forward pass (logits 2.847% → 2.842%) — consistent with the earlier finding that its effect is the output dtype, not the matmul.
  • fp16 is ≈ 5–8× more precise than bf16 across forward, backward, gradient, and log π — the expected mantissa-bit gain, and the only knob here that pulls the per-token log π RMS below ~0.5%.
  • Per-token log π shows no systematic distortion even on real text: bias < 0.15%, correlation ≥ 0.9994, slope ≈ 1.0. The diff is decorrelated per-token noise plus a negligible mean shift, and full_precision_residual is the only variant that visibly moves the (already tiny) bias.

Caveats

  • Single forward+backward at initialization (pretrained weights, no optimizer step, no training dynamics) — this is the per-step numerical floor, not the drift accumulated over a training run.
  • 0.5B is weight-tied, so there is no separate LM-head gradient (lm head = n/a) and fp32_lm_head has little to bind to; on an untied larger model the head columns would behave differently.
  • The large gradient relative errors (norm ~22%, bias ~20%) are on small-magnitude parameters where relative RMS is naturally inflated; the absolute diffs are tiny.
  • These are within-engine (bf16/fp16 vs fp32 in the same trainer), not the cross-engine vs-vLLM gap measured in the per-sequence/GSPO comment.

Tooling: tools/evaluate_precision.py gains a forward+backward multi-sequence mode — input_dataset with forward_only: false runs the full per-layer debug per sequence and averages each metric across sequences (the report also keeps the across-sequence std/max of each tensor's relative RMS). Config: examples/evaluate_precision/qwen_multi_seq_layerwise.yaml.

@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

jlamypoirier commented Jun 4, 2026

Claude Opus 4.8 note — are the double-digit gradient errors a problem? Precision noise vs. gradient sampling noise, and the systematic bias.

Gradient precision: noise vs. systematic bias — Qwen2.5-0.5B, 256 sequences

The layer-wise tables show double-digit relative-RMS errors on the parameter gradients (bf16 vs fp32: linear ~9% median, norm/bias up to ~20–23% worst-layer). This comment checks whether that actually matters, by comparing the precision error to the gradient's own sample-to-sample fluctuation and isolating the part that survives batch-averaging.

Setup: Qwen2.5-0.5B, 256 independent MATH-500 sequences, one forward+backward each, plain cross-entropy. Per parameter, across the 256 sequences:

  • true gradient = the average over all sequences (best estimate of the real gradient),
  • sampling fluctuation = how much a single sequence's gradient deviates from that average,
  • precision noise = bf16/fp16 gradient minus fp32 gradient, same sequence,
  • systematic bias = the average of that bf16−fp32 difference — the part that points the same way every sequence, so it does not cancel when averaged over a batch.

Table 1 — gradient noise (linear params; norm/bias within ~1 pt)

Config per-sample spread (% of true grad) precision noise vs sampling noise — median precision noise vs sampling noise — worst layer systematic bias (% of true grad)
bf16 154% 11% 36% 7.6%
bf16 + fp32 head 154% 11% 36% 7.5%
bf16 + fp32 residual 154% 10% 24% 6.9%
bf16 + both 154% 10% 24% 6.9%
fp16 154% 2% 14% 1.8%
fp16 + fp32 head 154% 2% 14% 1.8%

Per-sample spread is identical across configs — it's a property of the fp32 gradient, not the precision.

Read-out: per-sample gradients are noise-dominated — the sample-to-sample spread (~154%) is larger than the true gradient itself. So the scary double-digit "relative errors" were precision error measured against a per-sample gradient that is mostly sampling noise. Measured against that noise, the bf16 precision error is only ~11% (fp16: ~2%), and like sampling noise it averages away over a batch. The part that does not average away is the systematic bias: ~7.6% of the true gradient for bf16, ~1.8% for fp16.

Is the systematic bias real? (256 sequences)

Config measured bias noise floor (spurious bias under H0) true bias (floor-removed) bias ÷ floor split-half agreement
bf16 7.6% 1.0% 7.5% 7.7 0.97
fp16 1.8% 0.2% 1.8% 9.2 0.98

Even with zero true bias, averaging 256 finite-sample errors fakes a ~1% bias (the floor). The measured bias is 7.7× that floor, and splitting the 256 sequences in half gives bias estimates that agree at 0.97 correlation — a reproducible fixed pattern, not a sampling artifact. (At 32 sequences this was only 2.3× / 0.69, which is why we went to 256.) The floor-corrected bias is stable across sample size (7.4% at N=32 → 7.5% at N=256).

What the bias costs: the single-step bias–noise crossover

The bias is a floor that batch size cannot beat: more samples drive the sampling noise to zero, but the bf16 batch gradient stays offset by the bias. The single-step bias–noise crossover — the batch size at which, in one optimizer step, the sampling noise has shrunk to the bias level — is (per-sample spread ÷ bias)²:

Not to be confused with the critical batch size of McCandlish et al. (2018): same gradient-noise machinery, but that crossover is gradient noise vs. signal (a training-efficiency limit), whereas this is gradient noise vs. precision bias (a numerical-accuracy limit). And as the name says, this is a single-step quantity — see caveats.

Config systematic bias crossover batch (sequences) crossover batch (tokens, ~317/seq)
bf16 7.6% ~410 ~130k
bf16 + fp32 residual 6.9% ~500 ~160k
fp16 1.8% ~7,300 ~2.3M

Below the crossover, bf16's bias is buried under sampling noise the optimizer already tolerates (bf16 ≈ fp32). Above it, the bias dominates and more batch stops buying accuracy. bf16's crossover (~130k tokens) lands around typical training batch sizes; fp16 has ~17× more headroom.

Findings

  • The double-digit gradient errors are not a real signal — they're precision noise measured against noise-dominated per-sample gradients. The random precision noise sits ~9× below the gradient's own sampling noise and averages away over a batch.
  • The genuine, confirmed effect is a ~7.5% systematic bf16 gradient-direction bias (fp16: ~1.8%) that survives batch-averaging.
  • fp32_residual trims the bias ~9% (extends the crossover ~20%); fp32_lm_head is a no-op here; fp16 is the real lever (~4× lower bias, ~17× more batch headroom).

Caveats / open questions

  • Single forward+backward at initialization (pretrained weights), tied 0.5B (so fp32_lm_head has nothing to bind to; an untied larger model would differ).
  • The crossover above is a single-step quantity. A systematic bias accumulates coherently across steps while noise random-walks, so over a full run the bias can bite at smaller batches — but only if the per-step bias is correlated across steps, which we have not measured (it would take a multi-step / paired bf16-vs-fp32 trajectory experiment).
  • The crossover batch likely shrinks for larger models (deeper → more rounding accumulation, larger activation outliers, untied head), i.e. bf16 a bigger relative concern at scale — directly testable at 7B.

Prior work — a known effect, quantified here in a new place

The underlying mechanism is standard numerical analysis: under round-to-nearest, the rounding errors in a long sum or dot product are correlated (systematically the same sign once the inputs share a distribution), so they accumulate in proportion to the number of terms — a genuine bias — whereas stochastic rounding makes them zero-mean, so they grow only with the square root of the number of terms (Higham & Mary, probabilistic rounding-error analysis). Three recent LLM-training papers each touch a different facet of what we measure:

  • Small-batch training (arXiv 2507.07101) — the batch-size companion to the crossover above, not a precision result: it argues gradient accumulation is wasteful because small batches train stably once Adam's hyperparameters are scaled to batch size (holding the second-moment half-life fixed in tokens), so there is no need to fake a large batch. Relevant to the batch-size axis our crossover sits on, not the rounding bias.
  • Defeating the training-inference mismatch via FP16 (arXiv 2510.26788) — in the RL setting, attributes the train/inference log-probability mismatch to bf16's 7 mantissa bits versus fp16's 10, and reports roughly 24x smaller mismatch under fp16. Matches our "fp16 is the lever" result, measured on log-probabilities rather than gradients.
  • Why low-precision transformer training fails (arXiv 2510.04212) — tracks the cumulative low-minus-high-precision difference across training steps and finds it stays one-signed, i.e. the bias compounds coherently instead of cancelling (directly relevant to the compounding caveat above), through an attention/flash-attention-specific trigger.

What this comment adds is a per-parameter measurement of the bf16 gradient-computation bias on a real model, separated from sampling noise by the noise-floor and split-half checks, and expressed as a batch-size (token) crossover. The effect itself is well established; this particular quantification we have not found stated elsewhere, and offer it as such in case there is prior art we missed.

Stochastic rounding vs fp16

Stochastic rounding attacks the bias at its source rather than just shrinking it: each rounding becomes zero-mean, so the dominant first-order accumulation term (above) cancels in expectation. It does not remove the bias altogether — residual bias survives from nonlinearities (which turn SR's added per-sample variance into a second-order offset, via the curvature of softmax / the loss / activations) and from any rounding SR doesn't cover, and the variance it adds only pays off above the crossover. fp16 is an orthogonal lever, not superseded by SR: more mantissa bits shrink the per-rounding error (our measured ~4x, deterministic, no RNG), and it stacks with SR (fp16+SR injects less variance than bf16+SR). Crucially we have not measured SR here — our numbers are fp16-vs-bf16, both round-to-nearest — so SR-vs-fp16 is theory only for now; a real number means adding the SR cast in the matmul epilogue (software SR runs on current hardware; a hardware cast instruction exists only on Blackwell) and re-running this same decomposition.

Tooling: the forward+backward multi-sequence layer-wise mode (input_dataset + forward_only: false); the noise/bias decomposition is a small offline pass over the saved per-sequence gradient logs (element-aligned across sequences, since parameter-gradient shapes are sequence-length-independent).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier
Copy link
Copy Markdown
Collaborator Author

Claude Opus 4.8 note — does the flash-attention low-precision failure mode apply here?

TL;DR: the failure mode in arXiv 2510.04212 is a property of the paper's own attention reimplementation (which accumulates the attention output in bf16), not the official flash_attn kernel (fp32 accumulator). Their training code runs on that reimplementation rather than flash_attn; the maintainer notes the same; and we checked the official kernel directly.

  • Trigger present, output not biased. The paper's condition (a softmax row with ≥2 entries where bf16(exp(S−max))==1.0) fires on 4.0% of rows (layers 0–2). But on the official flash_attn kernel (bf16) vs fp32 on identical Q/K/V, the attention-output error is unbiased noise even on trigger rows (systematic share of the error ≈ 0.002); only layer 0 shows a faint one-signed residual (~0.5% of the signal). The overflow they describe needs a bf16 accumulator; the official kernel accumulates in fp32.
  • Where the bf16 error is. Per op, attention injects ~2× the relative bf16 error of MLP (~0.8% vs ~0.4%), but both are unbiased noise — so the systematic gradient bias above is a general all-matmul effect, not attention-specific.

To be clear, the paper's mechanism is sound for bf16-accumulating attention; it just doesn't carry over to the fp32-accumulating official kernel, so it isn't a concern for this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant