
fix attention template #840

Merged
ftynse merged 2 commits into main from users/ftynse/harden-attention
Feb 9, 2026

Conversation

Contributor

@ftynse ftynse commented Feb 6, 2026

- fix the type signature for the `mma_kind` argument;
- check that there are exactly 2 mma kinds;
- don't assume divisibility in constraints

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Copilot AI left a comment


Pull request overview

Updates the vanilla attention kernel template to improve mfma_variant typing/validation and to relax tiling constraints that previously assumed divisibility.

Changes:

  • Change mfma_variant parameter type to a sequence and validate it has exactly 2 entries.
  • Replace BLOCK_M / 4 wave constraint with a ceiling-based expression to avoid assuming divisibility.


 def get_vanilla_attention_kernel(
     shape: AttentionShape,
-    mfma_variant: MMAType,
+    mfma_variant: Sequence[MMAType],

Copilot AI Feb 6, 2026


If callers must pass exactly two MMA variants (as enforced below), it’s clearer to encode that in the type signature instead of a generic Sequence. Consider using a fixed-length tuple type (e.g., tuple[MMAType, MMAType] / Tuple[MMAType, MMAType]) so type-checkers can catch invalid call sites without relying on runtime validation.

Contributor


I agree, a tuple would be better
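As a hedged sketch of the suggested direction (the stand-in `MMAType` enum and the function body below are illustrative, not the actual wave_lang code), a fixed-length tuple type lets a type checker reject bad call sites before runtime:

```python
from enum import Enum, auto


class MMAType(Enum):
    # Minimal stand-in for wave_lang's MMAType enum; variant names are illustrative.
    F32_16x16x16_F16 = auto()
    F32_32x32x8_F16 = auto()


def get_vanilla_attention_kernel_sketch(
    mfma_variant: tuple[MMAType, MMAType],
) -> tuple[MMAType, MMAType]:
    # With tuple[MMAType, MMAType], mypy/pyright flag a list argument or a
    # wrong-length tuple at the call site, without relying on runtime checks.
    qk_kind, pv_kind = mfma_variant  # unpacking documents the two roles
    return qk_kind, pv_kind
```

A call such as `get_vanilla_attention_kernel_sketch((MMAType.F32_16x16x16_F16,))` would then be flagged statically, while any runtime check remains only as a backstop for untyped callers.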

"Sliding window is only supported for causal attention."
)

assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes"

Copilot AI Feb 6, 2026


Using assert for argument validation is brittle because assertions can be stripped with Python optimizations (-O), which would remove this check and could lead to downstream index errors (e.g., mfma_variant[1]). Prefer raising a ValueError/TypeError with a helpful message instead.

Suggested change
-    assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes"
+    if len(mfma_variant) != 2:
+        raise ValueError("mfma_variant must be a sequence of two MMATypes")
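A small sketch of why the explicit raise is more robust (`validate_mfma_variant` is a hypothetical helper, not part of the PR): under `python -O`, `assert` statements are compiled out entirely, whereas an explicit `raise` always runs:

```python
def validate_mfma_variant(mfma_variant):
    """Hypothetical helper: validate and normalize the two MMA kinds."""
    # Unlike `assert`, this check survives `python -O` and raises a
    # well-typed, catchable exception instead of a later IndexError
    # on mfma_variant[1].
    if len(mfma_variant) != 2:
        raise ValueError(
            f"mfma_variant must be a sequence of two MMAType elements; "
            f"got length {len(mfma_variant)}"
        )
    return tuple(mfma_variant)
```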

"Sliding window is only supported for causal attention."
)

assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes"

Copilot AI Feb 6, 2026


The error message could be more actionable and consistent with the type name (MMAType vs MMATypes). Consider including the received length/value (e.g., expected 2, got {len(mfma_variant)}) to make debugging easier.

Suggested change
-    assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes"
+    assert len(mfma_variant) == 2, (
+        f"mfma_variant must be a sequence of two MMAType elements; "
+        f"expected length 2, got {len(mfma_variant)}"
+    )

     constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
     constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
-    constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
+    constraints += [tkw.WaveConstraint(M, sympy.ceiling(BLOCK_M / 4))]

Copilot AI Feb 6, 2026


If BLOCK_M is a Python int, BLOCK_M / 4 produces a float, which can introduce floating-point rounding before sympy.ceiling(...) is applied. To keep this exact, use an integer-safe form (e.g., (BLOCK_M + 3) // 4 when BLOCK_M is integral) or a SymPy rational (e.g., sympy.ceiling(sympy.Rational(BLOCK_M, 4))) when BLOCK_M is symbolic/integer-like.

Suggested change
-    constraints += [tkw.WaveConstraint(M, sympy.ceiling(BLOCK_M / 4))]
+    constraints += [tkw.WaveConstraint(M, sympy.ceiling(sympy.Rational(BLOCK_M, 4)))]
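The rounding hazard is easy to demonstrate (the `BLOCK_M` value below is an artificial extreme; real tile sizes are small enough that the float path happens to stay exact, but the integer-safe forms avoid relying on that):

```python
import sympy

BLOCK_M = 2**60 + 1  # artificially large integral tile size to expose rounding

# Python int / int yields a float; 2**58 + 0.25 is not representable as a
# double, so it rounds to 2**58 before sympy.ceiling ever sees it.
rounded = sympy.ceiling(BLOCK_M / 4)

# Integer-safe alternatives stay exact:
exact_int = (BLOCK_M + 3) // 4                         # classic ceil-div for ints
exact_rat = sympy.ceiling(sympy.Rational(BLOCK_M, 4))  # also works symbolically

assert exact_int == 2**58 + 1
assert exact_rat == 2**58 + 1
assert rounded == 2**58  # off by one: the true ceiling is 2**58 + 1
```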

Contributor

@martin-luecke martin-luecke left a comment


LGTM.
To improve, we could add some documentation to the template, and move mfma_variant to a type with an explicit length (e.g. a two-element tuple) instead of Sequence plus an assert that len == 2.

 def get_vanilla_attention_kernel(
     shape: AttentionShape,
-    mfma_variant: MMAType,
+    mfma_variant: Sequence[MMAType],
Contributor


I agree, a tuple would be better

Comment thread: wave_lang/kernel/wave/templates/vanilla_attention.py
Comment thread: wave_lang/kernel/wave/templates/vanilla_attention.py (Outdated)
Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse merged commit a7ae36f into main Feb 9, 2026
15 checks passed
@ftynse ftynse deleted the users/ftynse/harden-attention branch February 9, 2026 17:37
nirmie pushed a commit to nirmie/wave that referenced this pull request Mar 9, 2026
- fix the type signature for the `mma_kind` argument;
- check that there are exactly 2 mma kinds;
- don't assume divisibility in constraints

---------

Signed-off-by: Alex Zinenko <git@ozinenko.com>

3 participants