[Relax][ONNX] Fix LayerNormalization no-bias zero tensor shape and dtype#19772
[Relax][ONNX] Fix LayerNormalization no-bias zero tensor shape and dtype#19772javierdejesusda wants to merge 2 commits into
Conversation
When the optional bias input of LayerNormalization is omitted, the zero bias was built from data.struct_info.shape[1] and hardcoded to float32 instead of following the scale (gamma) tensor. For a non-square input such as [2, 3, 4, 8] with scale [8], this produced a bias of shape (3,) while gamma is (8,), so relax.op.nn.layer_norm raised an InternalError on the size mismatch. For a half-precision model with no bias, the float32 bias was rejected because gamma, beta, and data must share one dtype. Synthesize the zero bias from gamma_shape and the scale dtype, matching ONNX semantics where an omitted B is treated as zeros shaped and typed like the scale. Add non-square no-bias regression cases: an fp16 case checked end to end and a bf16 case checked through the importer, since ONNX Runtime's CPU provider has no bf16 LayerNormalization kernel. Fixes apache#19691
There was a problem hiding this comment.
Code Review
This pull request fixes an issue in the ONNX frontend's LayerNormalization importer when no bias is provided. It updates the fallback bias creation to use the shape and data type of the scale (gamma_shape and scale.struct_info.dtype) instead of hardcoding a float32 array based on the input's second dimension. It also adds corresponding unit tests for non-square inputs, float16, and bfloat16 data types. A review comment correctly points out that using bfloat16 directly in _np.zeros will raise a TypeError in standard NumPy environments, and suggests constructing the NumPy array as float32 and letting relax.const handle the conversion to the target dtype.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
np.zeros rejects TVM dtype strings that NumPy lacks natively, so np.zeros(gamma_shape, dtype="bfloat16") raises "data type 'bfloat16' not understood". relax.const imports ml_dtypes and casts internally, but its np.zeros argument is evaluated first, so that import is too late. Build the zeros array with a native dtype and pass the target dtype to relax.const, matching the existing torch frontend convention.
Root cause
In the ONNX
LayerNormalizationspec the biasBis optional; when omitted it should behave aszeros shaped and typed like the scale
W. InLayerNormalization._impl_v17, the synthesized zerobias instead took its shape from
data.struct_info.shape[1](an unrelated data dim) and hardcodeddtype="float32". For input[2, 3, 4, 8]with scale[8]andaxis=-1this builds a bias ofshape
(3,)while gamma is(8,), sorelax.op.nn.layer_normraises a size-mismatchInternalError. The float32 hardcode also breaks fp16/bf16 no-bias models, since gamma, beta, anddata must share a dtype. PyTorch's
nn.LayerNorm(..., bias=False)exports exactly this no-bias form.Fix
Derive both the shape and dtype of the synthesized zero bias from the scale, matching the ONNX
semantics for an omitted
Band the existing torch frontend(
relax.const(np.zeros(shape), x.struct_info.dtype)):gamma_shapeand the_np/get_const_tupleimports are already present. Deriving the dtype fromthe scale (rather than the issue's float32-only suggestion) is what also fixes the fp16/bf16 case.
Test plan
Added non-square no-bias regression cases to
test_frontend_onnx.py::test_layer_norm(the previousno-bias case was square, which masked the bug): float32
[2,3,4,8]/scale[8]and float16 withfull
check_correctness, plus a bf16 importer-only case (ORT's CPU provider has no bf16LayerNormalization kernel).
Fixes #19691