Skip to content

Add L2 score mod distributed attention shape#3147

Open
vcherepanov-nv wants to merge 1 commit into
NVIDIA:mainfrom
vcherepanov-nv:fix-jax-l2-tests
Open

Add L2 score mod distributed attention shape#3147
vcherepanov-nv wants to merge 1 commit into
NVIDIA:mainfrom
vcherepanov-nv:fix-jax-l2-tests

Conversation

@vcherepanov-nv

Copy link
Copy Markdown
Collaborator

Description

Add L2 score mod distributed attention shape

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add L2 shape to fix L2 tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds the \"L2\" key to DISTRIBUTED_SCORE_MOD_DATA_SHAPES in the distributed fused-attention test file. Before this change, running the test suite with NVTE_JAX_UNITTEST_LEVEL=L2 would raise a ValueError because the key was absent; the addition prevents that crash.

  • The new entry is \"L2\": [], an empty list, which resolves to zero parametrized test cases at L2 — so while the ValueError is avoided, no score-mod distributed-attention tests actually run at that level.
  • Every other shape dictionary in this file (DISTRIBUTED_SELF_ATTN_DATA_SHAPES, DISTRIBUTED_CROSS_ATTN_DATA_SHAPES) provides a distinct, non-empty shape tuple for L2, suggesting at least one concrete shape should be supplied here as well.

Confidence Score: 4/5

The change prevents a crash when running at L2 level, but leaves L2 with no actual test coverage for the score-mod attention path.

The single-line fix resolves the ValueError from the missing L2 key, but "L2": [] means zero test cases are parametrized when the suite runs at L2. The stated goal — fixing L2 tests — is not achieved: no score-mod distributed-attention tests execute at that level.

tests/jax/test_distributed_fused_attn.py — the L2 entry in DISTRIBUTED_SCORE_MOD_DATA_SHAPES needs at least one shape tuple.

Important Files Changed

Filename Overview
tests/jax/test_distributed_fused_attn.py Adds "L2": [] to DISTRIBUTED_SCORE_MOD_DATA_SHAPES; while this prevents the ValueError that occurred when L2 was missing from the dict, the empty list means no score-mod tests are parametrized or executed at L2 level.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Test run starts\n(NVTE_JAX_UNITTEST_LEVEL=L2)"] --> B["pytest_parametrize_wrapper\ncalled with DISTRIBUTED_SCORE_MOD_DATA_SHAPES"]
    B --> C["get_parameters_for_test_level\nlooks up 'L2' key"]
    C --> D{"Key exists?"}
    D -- "Before PR\n(key missing)" --> E["ValueError:\nUnsupported test level"]
    D -- "After PR\n(key = [])" --> F["returns empty list []"]
    F --> G["pytest.mark.parametrize\n('data_shape', [])"]
    G --> H["0 test cases collected\nor pytest warning/skip"]
    H --> I["No score-mod L2 tests run"]
    style E fill:#f88,stroke:#c00
    style I fill:#fa8,stroke:#c60
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["Test run starts\n(NVTE_JAX_UNITTEST_LEVEL=L2)"] --> B["pytest_parametrize_wrapper\ncalled with DISTRIBUTED_SCORE_MOD_DATA_SHAPES"]
    B --> C["get_parameters_for_test_level\nlooks up 'L2' key"]
    C --> D{"Key exists?"}
    D -- "Before PR\n(key missing)" --> E["ValueError:\nUnsupported test level"]
    D -- "After PR\n(key = [])" --> F["returns empty list []"]
    F --> G["pytest.mark.parametrize\n('data_shape', [])"]
    G --> H["0 test cases collected\nor pytest warning/skip"]
    H --> I["No score-mod L2 tests run"]
    style E fill:#f88,stroke:#c00
    style I fill:#fa8,stroke:#c60
Loading

Reviews (2): Last reviewed commit: "Add L2 score mod distributed attention s..." | Re-trigger Greptile

Comment thread tests/jax/test_distributed_fused_attn.py Outdated
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [(4, 16, 4, 64)],

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be (assuming you want this to run as L1 test):

DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
      "L0": [],
      "L1": [(4, 16, 4, 64)],
      "L2": [],
  }

What you have will run the same tests for L1 and L2 there by duplicating effort

Please urgently launch a pipeline with a JAX build manually for L0, L1 and L2 levels and confirm that it runs successfully before merging

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [],

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 L2 level resolves to zero test cases

"L2": [] is passed to pytest_parametrize_wrapper, which calls get_parameters_for_test_level and returns the empty list. That list is forwarded directly to pytest.mark.parametrize("data_shape", []). With an empty parametrize set, pytest either skips the test entirely or raises a collection error depending on the --empty-parameter-set-mark config, so when NVTE_JAX_UNITTEST_LEVEL=L2 is used in CI no TestDistributedScoreModSelfAttn cases will execute. The PR description says this change "fixes L2 tests", but the fix needs at least one concrete shape tuple — the same pattern used by DISTRIBUTED_SELF_ATTN_DATA_SHAPES where L2 carries [(32, 512, 12, 64)].

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants