diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 78fd0ed487..2abd9824b6 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -283,6 +283,7 @@ def test_cross_attn( DISTRIBUTED_SCORE_MOD_DATA_SHAPES = { "L0": [], "L1": [(4, 16, 4, 64)], + "L2": [], }