Add blocked matmul (linalg.mmt4d) op support #265
MaheshRavishankar wants to merge 1 commit into iree-org:main
Conversation
Add BlockedMatmulNode for tiled matrix multiplication that lowers to linalg.mmt4d via torch_c casts:

LHS logical [M0, K0, M1, K1] x RHS logical [K0, N0, K1, N1] -> OUT [M0, N0, M1, N1]

RHS must be specified with transposed strides (physical [N0, K0, N1, K1]) matching linalg.mmt4d's expected layout. Non-transposed RHS returns a NotImplemented error.

The emitter casts torch tensors to builtin tensors (torch_c), applies linalg.fill + linalg.mmt4d, and casts the result back. No permute ops are needed since the physical layout is used directly.

New files:
- BlockedMatmulAttr (attributes)
- BlockedMatmulNode (node with validation and shape inference)
- ASM emitter with getBuiltinTensorTypeAsm helper
- Lit test verifying MLIR structure, flow compilation, single dispatch
- E2E sample verifying numerical correctness on CPU

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Force-pushed from 7966694 to 5cb5565
sjain-stanford left a comment:
I finally took time to think about the design change (of directly emitting linalg dialect ops) in the mmt4d case and looked at some alternatives. Unfortunately, none of these alternatives seem reasonable FWICT:

- There is no torch dialect op that captures mmt4d semantics (torch.aten.matmul can't express mmt4d's contraction pattern). torch.aten.einsum could in theory express the contraction string (ijkl,mjln->imkn), but I don't think we handle lowering this to linalg.mmt4d in torch-mlir, and moreover IREE might not recover the tiling structure from a generic einsum (correct me if I'm mistaken).
- Reshape + torch.aten.mm + reshape, collapsing [M0, K0, M1, K1] -> [M0K0, M1K1]. This might work semantically but defeats the whole point of blocked matmul: preserving the tile structure.
- CustomOp: this would be less ergonomic and loses the validation, shape inference, and structured API that BlockedMatmulNode provides.
- Add torch.aten.mmt4d to upstream torch-mlir: cleanest, but it is not a PyTorch operation, so it might not make sense to add it to the dialect.

In this scenario, considering we've exhausted all options, I think this change is justified.

My asks to get this in shape for landing are as follows:

- Explicitly document, in the docstrings for blocked_matmul_node.h or the emitter, why this is implemented as a direct linalg lowering, to prevent future contributors from casually proliferating linalg ops where alternatives exist. This should be an outlier, not the norm.
- Fix CI failures / rebase.
- Address comments.
// RHS must be transposed: logical [K0, N0, K1, N1] must have physical
// layout [N0, K0, N1, K1] for linalg.mmt4d. This corresponds to
// logical-to-physical permutation [1, 0, 3, 2].
std::vector<int64_t> rhsPerm = rhsT->getLogicalToPhysicalPermuteOrder();
std::vector<int64_t> expectedPerm = {1, 0, 3, 2};
FUSILLI_RETURN_ERROR_IF(
    rhsPerm != expectedPerm, ErrorCode::NotImplemented,
    "BlockedMatmul only supports RHS with transposed physical layout "
    "[N0, K0, N1, K1]. Non-transposed RHS is not yet supported");
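To see where the [1, 0, 3, 2] permutation comes from, here is a minimal sketch (permFromStrides is a hypothetical stand-in for getLogicalToPhysicalPermuteOrder, not the FUSILLI implementation): for a dense tensor, sorting the logical dimensions by stride in descending order recovers the logical-to-physical permutation, since the outermost physical dimension has the largest stride.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical stand-in for getLogicalToPhysicalPermuteOrder():
// argsort of logical dims by stride, descending. Assumes a dense
// (non-broadcast, non-overlapping) layout.
std::vector<int64_t> permFromStrides(const std::vector<int64_t> &strides) {
  std::vector<int64_t> perm(strides.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    return strides[a] > strides[b];
  });
  return perm;
}
```

For example, a logical [K0, N0, K1, N1] = [2, 3, 4, 5] tensor stored physically as [N0, K0, N1, K1] has logical strides {20, 40, 1, 4}, which sort to {1, 0, 3, 2}; a contiguous logical layout (strides {60, 20, 5, 1}) gives the identity {0, 1, 2, 3}.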
Since the emitter directly emits physical dimensions, and since linalg.mmt4d expects LHS in physical layout [M0, K0, M1, K1], add a check that the LHS is contiguous. You are already checking that the RHS has the correct permutation; we need the same for the LHS. Please also add unit tests for when this condition is not satisfied.
// LHS must be contiguous (identity permutation) for linalg.mmt4d.
std::vector<int64_t> lhsPerm = lhsT->getLogicalToPhysicalPermuteOrder();
std::vector<int64_t> identityPerm = {0, 1, 2, 3};
FUSILLI_RETURN_ERROR_IF(
    lhsPerm != identityPerm, ErrorCode::NotImplemented,
    "BlockedMatmul only supports contiguous LHS (identity permutation). "
    "Non-contiguous LHS is not yet supported");

@@ -0,0 +1,111 @@
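The negative-path unit tests requested above could look roughly like the following self-contained sketch. It mocks the checks rather than using the real FUSILLI macros and graph API (ErrorCode, checkRhsTransposed, and checkLhsContiguous are stand-ins here, not the actual library names):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for the real FUSILLI error handling (hypothetical names).
enum class ErrorCode { OK, NotImplemented };

// Mirrors the RHS permutation check from the emitter: the RHS must have
// the transposed physical layout [N0, K0, N1, K1], i.e. perm [1, 0, 3, 2].
ErrorCode checkRhsTransposed(const std::vector<int64_t> &rhsPerm) {
  const std::vector<int64_t> expectedPerm = {1, 0, 3, 2};
  return rhsPerm == expectedPerm ? ErrorCode::OK : ErrorCode::NotImplemented;
}

// Mirrors the suggested LHS contiguity check: the LHS must be stored
// in its logical order [M0, K0, M1, K1], i.e. the identity permutation.
ErrorCode checkLhsContiguous(const std::vector<int64_t> &lhsPerm) {
  const std::vector<int64_t> identityPerm = {0, 1, 2, 3};
  return lhsPerm == identityPerm ? ErrorCode::OK : ErrorCode::NotImplemented;
}
```

Each check has one accepting layout and rejects everything else with NotImplemented, so a test needs at least one positive and one negative case per operand.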
// Copyright 2025 Advanced Micro Devices, Inc.

Suggested change:
- // Copyright 2025 Advanced Micro Devices, Inc.
+ // Copyright 2026 Advanced Micro Devices, Inc.
@@ -0,0 +1,94 @@
// Copyright 2025 Advanced Micro Devices, Inc.

Suggested change:
- // Copyright 2025 Advanced Micro Devices, Inc.
+ // Copyright 2026 Advanced Micro Devices, Inc.
@@ -0,0 +1,170 @@
// Copyright 2025 Advanced Micro Devices, Inc.

Suggested change:
- // Copyright 2025 Advanced Micro Devices, Inc.
+ // Copyright 2026 Advanced Micro Devices, Inc.
@@ -0,0 +1,48 @@
// Copyright 2025 Advanced Micro Devices, Inc.

Suggested change:
- // Copyright 2025 Advanced Micro Devices, Inc.
+ // Copyright 2026 Advanced Micro Devices, Inc.
constexpr std::string_view schema = R"(
  %{0}_lhs_builtin = torch_c.to_builtin_tensor {1} : {2} -> {3}
  %{0}_rhs_builtin = torch_c.to_builtin_tensor {4} : {5} -> {6}
  %{0}_cst = arith.constant 0.000000e+00 : {7}
This hardcodes a floating-point zero literal. Do we ever plan on supporting integer blocked matmul? If not, this is fine.
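If integer element types ever need to be supported, the fill constant's literal would have to follow the element type. A minimal sketch of how the emitter could pick it (zeroLiteralFor is a hypothetical helper, not part of this PR, and the type detection is deliberately simplified):

```cpp
#include <string>
#include <string_view>

// Hypothetical helper: pick the arith.constant literal for the
// accumulator fill based on the element type's MLIR ASM name.
// Simplified: treats names starting with 'i' (i8, i32, ...) as
// integers and everything else (f16, f32, bf16, ...) as floats;
// ignores the si/ui spellings.
std::string zeroLiteralFor(std::string_view elemType) {
  bool isInt = !elemType.empty() && elemType.front() == 'i';
  return isInt ? "0" : "0.000000e+00";
}
```

The emitter would then substitute this string in place of the hardcoded 0.000000e+00 in the schema.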
#include "fusilli/node/custom_op_node.h"
#include "fusilli/node/layernorm_node.h"
#include "fusilli/node/pointwise_node.h"
#include "fusilli/node/rmsnorm_node.h"
Include blocked_matmul_node.h here as well.