diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3d9dfba9a16b..113193263d86 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3834,8 +3834,7 @@ def _impl_v17(cls, bb, inputs, attr, params): gamma_shape = get_const_tuple(scale.struct_info.shape) if bias is None: - seq_len = data.struct_info.shape[1].value - bias = relax.const([0.0] * seq_len, dtype="float32") + bias = relax.const(_np.zeros(gamma_shape), dtype=scale.struct_info.dtype) else: beta_shape = get_const_tuple(bias.struct_info.shape) if gamma_shape != beta_shape: diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8d8c1bc54b9b..7f77dac0c876 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2330,6 +2330,72 @@ def test_layer_norm(): model = helper.make_model(graph, producer_name="layer_norm_test") check_correctness(model) + # No bias with a non-square input where data.shape[1] differs from the scale + # shape, see https://github.com/apache/tvm/issues/19691. + layer_norm_node = helper.make_node( + "LayerNormalization", ["input", "scale"], ["Y"], axis=-1, epsilon=1e-12 + ) + + graph = helper.make_graph( + [layer_norm_node], + "layer_norm_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 8]), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [8]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4, 8]), + ], + ) + + model = helper.make_model(graph, producer_name="layer_norm_test") + check_correctness(model) + + # No bias with a non-square fp16 input. The synthesized zero bias must match + # the scale dtype, otherwise layer_norm rejects the float32 bias, see + # https://github.com/apache/tvm/issues/19691. + layer_norm_node = helper.make_node( + "LayerNormalization", ["input", "scale"], ["Y"], axis=-1, epsilon=1e-12 + ) + + graph = helper.make_graph( + [layer_norm_node], + "layer_norm_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT16, [2, 3, 4, 8]), + helper.make_tensor_value_info("scale", TensorProto.FLOAT16, [8]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 3, 4, 8]), + ], + ) + + model = helper.make_model(graph, producer_name="layer_norm_test") + check_correctness(model, opset=17, atol=1e-2, rtol=1e-2) + + # Same no-bias path for bf16. ONNX Runtime's CPU provider has no bf16 + # LayerNormalization kernel, so this only checks the importer builds the + # graph with a bf16 zero bias (the dtype the fix derives from the scale). + layer_norm_node = helper.make_node( + "LayerNormalization", ["input", "scale"], ["Y"], axis=-1, epsilon=1e-12 + ) + + graph = helper.make_graph( + [layer_norm_node], + "layer_norm_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.BFLOAT16, [2, 3, 4, 8]), + helper.make_tensor_value_info("scale", TensorProto.BFLOAT16, [8]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.BFLOAT16, [2, 3, 4, 8]), + ], + ) + + model = helper.make_model(graph, producer_name="layer_norm_test") + model.opset_import[0].version = 17 + from_onnx(model, opset=17, keep_params_in_input=True) + def test_layer_norm_with_nd_gamma_beta(): layer_norm_node = helper.make_node(