Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3834,8 +3834,7 @@ def _impl_v17(cls, bb, inputs, attr, params):
gamma_shape = get_const_tuple(scale.struct_info.shape)

if bias is None:
seq_len = data.struct_info.shape[1].value
bias = relax.const([0.0] * seq_len, dtype="float32")
bias = relax.const(_np.zeros(gamma_shape), dtype=scale.struct_info.dtype)
else:
beta_shape = get_const_tuple(bias.struct_info.shape)
if gamma_shape != beta_shape:
Expand Down
66 changes: 66 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,6 +2330,72 @@ def test_layer_norm():
model = helper.make_model(graph, producer_name="layer_norm_test")
check_correctness(model)

# No bias with a non-square input where data.shape[1] differs from the scale
# shape, see https://github.com/apache/tvm/issues/19691.
layer_norm_node = helper.make_node(
"LayerNormalization", ["input", "scale"], ["Y"], axis=-1, epsilon=1e-12
)

graph = helper.make_graph(
[layer_norm_node],
"layer_norm_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 8]),
helper.make_tensor_value_info("scale", TensorProto.FLOAT, [8]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4, 8]),
],
)

model = helper.make_model(graph, producer_name="layer_norm_test")
check_correctness(model)

# No bias with a non-square fp16 input. The synthesized zero bias must match
# the scale dtype, otherwise layer_norm rejects the float32 bias, see
# https://github.com/apache/tvm/issues/19691.
layer_norm_node = helper.make_node(
"LayerNormalization", ["input", "scale"], ["Y"], axis=-1, epsilon=1e-12
)

graph = helper.make_graph(
[layer_norm_node],
"layer_norm_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT16, [2, 3, 4, 8]),
helper.make_tensor_value_info("scale", TensorProto.FLOAT16, [8]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 3, 4, 8]),
],
)

model = helper.make_model(graph, producer_name="layer_norm_test")
check_correctness(model, opset=17, atol=1e-2, rtol=1e-2)

# Same no-bias path for bf16. ONNX Runtime's CPU provider has no bf16
# LayerNormalization kernel, so this only checks the importer builds the
# graph with a bf16 zero bias (the dtype the fix derives from the scale).
layer_norm_node = helper.make_node(
"LayerNormalization", ["input", "scale"], ["Y"], axis=-1, epsilon=1e-12
)

graph = helper.make_graph(
[layer_norm_node],
"layer_norm_test",
inputs=[
helper.make_tensor_value_info("input", TensorProto.BFLOAT16, [2, 3, 4, 8]),
helper.make_tensor_value_info("scale", TensorProto.BFLOAT16, [8]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.BFLOAT16, [2, 3, 4, 8]),
],
)

model = helper.make_model(graph, producer_name="layer_norm_test")
model.opset_import[0].version = 17
from_onnx(model, opset=17, keep_params_in_input=True)


def test_layer_norm_with_nd_gamma_beta():
layer_norm_node = helper.make_node(
Expand Down
Loading