
Add blocked matmul (linalg.mmt4d) op support #265

Draft
MaheshRavishankar wants to merge 1 commit into iree-org:main from
MaheshRavishankar:users/MaheshRavishankar/blockedMatmul

Conversation

@MaheshRavishankar (Contributor)

Add BlockedMatmulNode for tiled matrix multiplication that lowers to linalg.mmt4d via torch_c casts:
LHS logical [M0, K0, M1, K1] x RHS logical [K0, N0, K1, N1]
-> OUT [M0, N0, M1, N1]

RHS must be specified with transposed strides (physical [N0, K0, N1, K1]) matching linalg.mmt4d's expected layout. Non-transposed RHS returns a NotImplemented error.

The emitter casts torch tensors to builtin tensors (torch_c), applies linalg.fill + linalg.mmt4d, and casts the result back. No permute ops are needed since the physical layout is used directly.

New files:

  • BlockedMatmulAttr (attributes)
  • BlockedMatmulNode (node with validation and shape inference)
  • ASM emitter with getBuiltinTensorTypeAsm helper
  • Lit test verifying MLIR structure, flow compilation, single dispatch
  • E2E sample verifying numerical correctness on CPU


Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@MaheshRavishankar force-pushed the users/MaheshRavishankar/blockedMatmul branch from 7966694 to 5cb5565 on March 24, 2026 at 17:59
@sjain-stanford (Member) left a comment:

I finally took time to think about the design change (of directly emitting linalg dialect ops) in the case for mmt4d and looked at some alternatives. Unfortunately, none of these alternatives seem reasonable FWICT:

  1. There is no torch dialect op that captures mmt4d semantics (torch.aten.matmul can't express mmt4d's contraction pattern).
  2. torch.aten.einsum could in theory express the contraction string (ijkl,jmln->imkn for the logical layouts), but I don't think torch-mlir lowers this to linalg.mmt4d, and IREE likely wouldn't recover the tiling structure from a generic einsum (correct me if I'm mistaken).
  3. Reshape + torch.aten.mm + reshape, permuting [M0, K0, M1, K1] -> [M0, M1, K0, K1] and collapsing to [M0*M1, K0*K1]. This might work semantically but defeats the whole point of blocked matmul: preserving the tile structure.
  4. CustomOp -> this would be less ergonomic and loses validation, shape inference and the structured API that the BlockedMatmulNode provides.
  5. Add torch.aten.mmt4d to upstream torch-mlir - the cleanest option, but mmt4d is not a PyTorch operation, so it likely doesn't belong in that dialect.

Given that we've exhausted these options, I think this change is justified.

My asks to get this in shape for landing are as follows:

  1. Explicitly document, in the docstrings for blocked_matmul_node.h or the emitter, why this is implemented as a direct linalg lowering, so that future contributors don't casually proliferate linalg ops where alternatives exist. This should be the outlier, not the norm.
  2. Fix CI failures / rebase.
  3. Address comments.

Comment on lines +110 to +118
// RHS must be transposed: logical [K0, N0, K1, N1] must have physical
// layout [N0, K0, N1, K1] for linalg.mmt4d. This corresponds to
// logical-to-physical permutation [1, 0, 3, 2].
std::vector<int64_t> rhsPerm = rhsT->getLogicalToPhysicalPermuteOrder();
std::vector<int64_t> expectedPerm = {1, 0, 3, 2};
FUSILLI_RETURN_ERROR_IF(
    rhsPerm != expectedPerm, ErrorCode::NotImplemented,
    "BlockedMatmul only supports RHS with transposed physical layout "
    "[N0, K0, N1, K1]. Non-transposed RHS is not yet supported");
Since the emitter directly emits physical dimensions, and since linalg.mmt4d expects LHS in physical layout [M0, K0, M1, K1], add a check that LHS is contiguous. You already check that RHS has the correct permutation; we need the same for LHS. Please also add unit tests for when this condition is not satisfied.

// LHS must be contiguous (identity permutation) for linalg.mmt4d.
std::vector<int64_t> lhsPerm = lhsT->getLogicalToPhysicalPermuteOrder();
std::vector<int64_t> identityPerm = {0, 1, 2, 3};
FUSILLI_RETURN_ERROR_IF(
    lhsPerm != identityPerm, ErrorCode::NotImplemented,
    "BlockedMatmul only supports contiguous LHS (identity permutation). "
    "Non-contiguous LHS is not yet supported");

@@ -0,0 +1,111 @@
// Copyright 2025 Advanced Micro Devices, Inc.
Suggested change
// Copyright 2025 Advanced Micro Devices, Inc.
// Copyright 2026 Advanced Micro Devices, Inc.

@@ -0,0 +1,94 @@
// Copyright 2025 Advanced Micro Devices, Inc.
Suggested change
// Copyright 2025 Advanced Micro Devices, Inc.
// Copyright 2026 Advanced Micro Devices, Inc.

@@ -0,0 +1,170 @@
// Copyright 2025 Advanced Micro Devices, Inc.
Suggested change
// Copyright 2025 Advanced Micro Devices, Inc.
// Copyright 2026 Advanced Micro Devices, Inc.

@@ -0,0 +1,48 @@
// Copyright 2025 Advanced Micro Devices, Inc.
Suggested change
// Copyright 2025 Advanced Micro Devices, Inc.
// Copyright 2026 Advanced Micro Devices, Inc.

constexpr std::string_view schema = R"(
%{0}_lhs_builtin = torch_c.to_builtin_tensor {1} : {2} -> {3}
%{0}_rhs_builtin = torch_c.to_builtin_tensor {4} : {5} -> {6}
%{0}_cst = arith.constant 0.000000e+00 : {7}

This hardcodes a floating-point zero literal. Do we ever plan on supporting integer blocked matmul? If not, this is fine.

#include "fusilli/node/custom_op_node.h"
#include "fusilli/node/layernorm_node.h"
#include "fusilli/node/pointwise_node.h"
#include "fusilli/node/rmsnorm_node.h"
Include blocked_matmul_node.h here as well.
