Skip to content

Add converter for aten::_grouped_mm.default#2805

Open
Sid-V5 wants to merge 2 commits intomicrosoft:mainfrom
Sid-V5:add-grouped-mm-2795
Open

Add converter for aten::_grouped_mm.default#2805
Sid-V5 wants to merge 2 commits intomicrosoft:mainfrom
Sid-V5:add-grouped-mm-2795

Conversation

@Sid-V5
Copy link

@Sid-V5 Sid-V5 commented Feb 12, 2026

Implemented the converter for aten::_grouped_mm.default to address #2795.

Changes

  • Added aten_grouped_mm function in onnxscript/function_libs/torch_lib/ops/core.py

Implementation Details

The aten::_grouped_mm operator performs grouped matrix multiplication. This implementation handles the batch/dense mode (when offs is None), where groups are implicit in the batch dimension:

  • self: (G, M, K), mat2: (G, K, N) → result: (G, M, N)
  • Uses op.MatMul for the core computation
  • Supports optional bias addition via op.Add
  • Supports optional out_dtype casting via op.Cast

The offset-based mode (when offs is provided) raises NotImplementedError, as it requires segment-level matrix multiplications that are not directly expressible with standard ONNX operators.

Testing

The function follows the same patterns as other converters in core.py (e.g., aten_bmm, aten_mm) and uses the @torch_op decorator for automatic registration.

Fixes #2795

Implements the converter for aten::_grouped_mm.default to address issue microsoft#2795. Handles the batch/dense mode where groups are implicit in the batch dimension using MatMul, with optional bias addition and dtype casting.
@Sid-V5
Copy link
Author

Sid-V5 commented Feb 12, 2026

@microsoft-github-policy-service agree

@codecov
Copy link

codecov bot commented Feb 16, 2026

Codecov Report

❌ Patch coverage is 20.00000% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 60.42%. Comparing base (e6f79e1) to head (c540859).

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 20.00% 8 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (e6f79e1) and HEAD (c540859). Click for more details.

HEAD has 21 uploads less than BASE
Flag BASE (e6f79e1) HEAD (c540859)
40 19
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2805       +/-   ##
===========================================
- Coverage   70.52%   60.42%   -10.10%     
===========================================
  Files         228      219        -9     
  Lines       27135    25620     -1515     
  Branches     2727     2633       -94     
===========================================
- Hits        19137    15481     -3656     
- Misses       7065     9355     +2290     
+ Partials      933      784      -149     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

Missing converter for OpOverload(op='aten._grouped_mm', overload='default')

2 participants