- fix the type signature for the `mma_kind` argument;
- check that there are exactly 2 mma kinds;
- don't assume divisibility in constraints

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Pull request overview
Updates the vanilla attention kernel template to improve mfma_variant typing/validation and to relax tiling constraints that previously assumed divisibility.
Changes:
- Change the `mfma_variant` parameter type to a sequence and validate that it has exactly 2 entries.
- Replace the `BLOCK_M / 4` wave constraint with a ceiling-based expression to avoid assuming divisibility.
| def get_vanilla_attention_kernel( | ||
| shape: AttentionShape, | ||
| mfma_variant: MMAType, | ||
| mfma_variant: Sequence[MMAType], |
If callers must pass exactly two MMA variants (as enforced below), it’s clearer to encode that in the type signature instead of a generic Sequence. Consider using a fixed-length tuple type (e.g., tuple[MMAType, MMAType] / Tuple[MMAType, MMAType]) so type-checkers can catch invalid call sites without relying on runtime validation.
I agree, a tuple would be better
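A minimal sketch of the fixed-length tuple suggestion follows. `MMAType` here is a hypothetical stand-in enum (the real type lives in the kernel library), and `describe_variants` is an illustrative helper, not part of the template.

```python
from enum import Enum


class MMAType(Enum):
    """Hypothetical stand-in for the library's MMAType enum."""

    VARIANT_A = "variant_a"
    VARIANT_B = "variant_b"


def describe_variants(mfma_variant: tuple[MMAType, MMAType]) -> str:
    """Illustrative helper: the annotation fixes the length at two."""
    # Unpacking documents the two-element contract; a static type checker
    # can flag call sites that pass a tuple of a different length.
    first, second = mfma_variant
    return f"{first.name}, {second.name}"


summary = describe_variants((MMAType.VARIANT_A, MMAType.VARIANT_B))
```

The annotation is not enforced at runtime, so a runtime length check remains useful for callers that are not type-checked.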
| "Sliding window is only supported for causal attention." | ||
| ) | ||
|
|
||
| assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes" |
Using assert for argument validation is brittle because assertions can be stripped with Python optimizations (-O), which would remove this check and could lead to downstream index errors (e.g., mfma_variant[1]). Prefer raising a ValueError/TypeError with a helpful message instead.
| assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes" | |
| if len(mfma_variant) != 2: | |
| raise ValueError("mfma_variant must be a sequence of two MMATypes") |
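To illustrate the concern about stripped assertions, the sketch below compiles the same function twice using the built-in `compile` with its `optimize` argument, which mirrors the interpreter's `-O` level. The function `second_variant` and the helpers are hypothetical names chosen for this example.

```python
SOURCE = '''
def second_variant(mfma_variant):
    assert len(mfma_variant) == 2, "mfma_variant must have two entries"
    return mfma_variant[1]
'''


def load(optimize: int):
    """Compile SOURCE at the given optimization level and return the function."""
    namespace = {}
    exec(compile(SOURCE, "<demo>", "exec", optimize=optimize), namespace)
    return namespace["second_variant"]


def error_name(func, arg) -> str:
    """Call func(arg) and report the name of the exception raised, if any."""
    try:
        func(arg)
    except Exception as exc:
        return type(exc).__name__
    return "no error"


# Level 0 keeps the assert; level 1 (like `python -O`) removes it,
# so the bad input only fails later, at the index expression.
with_assert = error_name(load(0), ["only_one"])
without_assert = error_name(load(1), ["only_one"])
```

With the assertion removed, the failure surfaces as an index error far from the actual cause, which is the scenario an explicit `raise` avoids.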
| "Sliding window is only supported for causal attention." | ||
| ) | ||
|
|
||
| assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes" |
The error message could be more actionable and consistent with the type name (MMAType vs MMATypes). Consider including the received length/value (e.g., expected 2, got {len(mfma_variant)}) to make debugging easier.
| assert len(mfma_variant) == 2, "mfma_variant must be a sequence of two MMATypes" | |
| assert len(mfma_variant) == 2, ( | |
| f"mfma_variant must be a sequence of two MMAType elements; " | |
| f"expected length 2, got {len(mfma_variant)}" | |
| ) |
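One possible way to combine the two suggestions on this line (an explicit exception plus the received length in the message) is sketched below. `validate_mfma_variant` is a hypothetical helper introduced only for illustration; the PR itself keeps the check inline.

```python
from collections.abc import Sequence


def validate_mfma_variant(mfma_variant: Sequence[object]) -> None:
    """Hypothetical helper: reject anything that is not exactly two entries."""
    if len(mfma_variant) != 2:
        # Reporting the received length makes the failing call site easier to find.
        raise ValueError(
            "mfma_variant must be a sequence of two MMAType elements; "
            f"expected length 2, got {len(mfma_variant)}"
        )


try:
    validate_mfma_variant(["a", "b", "c"])
    message = ""
except ValueError as exc:
    message = str(exc)
```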
| constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] | ||
| constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] | ||
| constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] | ||
| constraints += [tkw.WaveConstraint(M, sympy.ceiling(BLOCK_M / 4))] |
If BLOCK_M is a Python int, BLOCK_M / 4 produces a float, which can introduce floating-point rounding before sympy.ceiling(...) is applied. To keep this exact, use an integer-safe form (e.g., (BLOCK_M + 3) // 4 when BLOCK_M is integral) or a SymPy rational (e.g., sympy.ceiling(sympy.Rational(BLOCK_M, 4))) when BLOCK_M is symbolic/integer-like.
| constraints += [tkw.WaveConstraint(M, sympy.ceiling(BLOCK_M / 4))] | |
| constraints += [tkw.WaveConstraint(M, sympy.ceiling(sympy.Rational(BLOCK_M, 4)))] |
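The rounding concern can be sketched with a deliberately large integer, where float division loses the fractional part before the ceiling is taken. The value `n` and the variable names are assumptions made for this example; the symbolic case uses a plain `sympy.Symbol` as a stand-in for however `BLOCK_M` is defined in the kernel.

```python
import math

import sympy

# Integer case: a value large enough that n / 4 is not exactly representable.
n = 2**55 + 1
float_based = math.ceil(n / 4)       # float division rounds before the ceiling
integer_based = (n + 3) // 4         # exact integer ceiling division
rational_based = sympy.ceiling(sympy.Rational(n, 4))  # exact rational, then ceiling

# Symbolic case: dividing a SymPy symbol stays exact, and the ceiling
# is evaluated only once a concrete value is substituted.
BLOCK_M = sympy.Symbol("BLOCK_M")
symbolic = sympy.ceiling(BLOCK_M / 4)
at_ten = symbolic.subs(BLOCK_M, 10)
```

For realistic block sizes the float path is unlikely to misround, so this mainly matters as a matter of exactness. As far as I know, `sympy.Rational` expects numeric arguments, so the `Rational` form fits the integer case, while a symbolic `BLOCK_M` can be divided directly.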
martin-luecke left a comment
LGTM
As a follow-up, we could add some documentation to the template and switch `mfma_variant` to a type with an explicit length, rather than using `Sequence` plus `assert len == 2`.