Conversation

@timmoon10
Collaborator

Description

This PR adds ops needed for the grouped MLP block in Mixture-of-Experts models. In particular, it adds a grouped linear op (similar to the GroupedLinear module) and a ScaledSwiGLU op. It is the same as #2622, but doesn't include the fused ops with experimental kernels. Closes #2560.
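
For orientation, here is a minimal usage sketch of the new grouped linear op. It is a hypothetical example, not code from this PR: the constructor arguments are assumed to mirror the existing GroupedLinear module, and the forward(x, split_sizes) call follows the sequence diagram later in this thread, so exact argument names may differ.

import torch
import transformer_engine.pytorch.ops as te_ops

# Hypothetical construction: one weight matrix per group, assumed to mirror
# the GroupedLinear module's (num_gemms, in_features, out_features) layout.
num_groups, in_features, out_features = 4, 128, 256
op = te_ops.GroupedLinear(num_groups, in_features, out_features, device="cuda")

x = torch.randn(64, in_features, device="cuda")
split_sizes = [16, 16, 16, 16]  # rows routed to each group; should sum to x.shape[0]

y = op(x, split_sizes)  # one GEMM per group, dispatched through general_grouped_gemm
print(y.shape)          # torch.Size([64, 256])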

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add grouped linear op
  • Add scaled SwiGLU op
  • Handle edge cases in noop_cat function

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the "enhancement" (New feature or request) label on Feb 9, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 9, 2026

Greptile Overview

Greptile Summary

This PR adds new PyTorch fusible ops required for MoE grouped MLP blocks: a new GroupedLinear basic op backed by general_grouped_gemm (with support for FP8/MXFP8 quantized compute and quantized weights), and new SwiGLU variants in a dedicated swiglu.py module including optional gate interleaving and a ScaledSwiGLU op that applies per-row scaling. It also updates noop_cat to better handle a split_quantize edge case where manually-constructed subviews can confuse PyTorch’s storage bounds checks.
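
For intuition, the following is a pure-PyTorch reference of what the ScaledSwiGLU op is described as computing, assuming TE's usual gated-activation convention (SiLU on the first half of the last dimension, multiplied elementwise by the second half) and one scale per row; the interleaved variant presumably reads the gate and linear columns in alternating order rather than contiguous halves. This is a hedged sketch, not the op's kernel.

import torch
import torch.nn.functional as F

def scaled_swiglu_reference(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference for ScaledSwiGLU, assuming a [rows, 2 * hidden]
    # input split into gate and linear halves along the last dimension and a
    # [rows] tensor of per-row scales (e.g. MoE router weights).
    gate, value = x.chunk(2, dim=-1)
    y = F.silu(gate) * value         # standard SwiGLU
    return y * scales.unsqueeze(-1)  # per-row scaling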

The changes integrate by exporting the new ops in transformer_engine/pytorch/ops/basic/__init__.py and adding targeted unit tests covering grouped linear, scaled swiglu, interleaved swiglu, and an end-to-end grouped MLP pipeline.

Confidence Score: 3/5

  • This PR is mergeable after addressing a couple of correctness/robustness issues around split-size handling and test assumptions.
  • Core wiring and the new ops are integrated and tested, but GroupedLinear lacks an explicit validation that the split sizes sum to the input's leading dimension, and the new unit test currently always includes a zero-sized split (which may be unsupported depending on the kernel path). Fixing or clarifying these will reduce the risk of runtime failures; a sketch of such a check follows this list.
  • Files to focus on: transformer_engine/pytorch/ops/basic/grouped_linear.py, tests/pytorch/test_fusible_ops.py
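
A minimal sketch of the split-size check mentioned above, as a suggestion rather than code from this PR (the helper name and call site are hypothetical; it would run before split_quantize and the grouped GEMM):

def _check_split_sizes(x, split_sizes, num_groups):
    # Suggested guard: validate split sizes before slicing or launching GEMMs.
    if len(split_sizes) != num_groups:
        raise ValueError(f"Expected {num_groups} split sizes, got {len(split_sizes)}")
    if any(s < 0 for s in split_sizes):
        raise ValueError(f"Split sizes must be non-negative, got {split_sizes}")
    if sum(split_sizes) != x.shape[0]:
        raise ValueError(
            f"Split sizes sum to {sum(split_sizes)}, "
            f"but the input has {x.shape[0]} rows"
        )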

Important Files Changed

  • tests/pytorch/test_fusible_ops.py: Adds tests for GroupedLinear, ScaledSwiGLU, interleaved SwiGLU, and a grouped MLP pipeline; test_grouped_linear always includes a zero-sized split, which may be unsupported and should be clarified or handled.
  • transformer_engine/pytorch/module/_common.py: Updates noop_cat to use as_strided and adds a storage-size guard that falls back to torch.cat for the split_quantize edge case; the change is localized and intended to prevent out-of-bounds view reconstruction (see the sketch after this list).
  • transformer_engine/pytorch/ops/basic/__init__.py: Exports the new GroupedLinear op and imports the SwiGLU variants from the new swiglu module; the import wiring looks consistent.
  • transformer_engine/pytorch/ops/basic/activation.py: Removes SwiGLU and ClampedSwiGLU, which move into swiglu.py; the other activations are unchanged.
  • transformer_engine/pytorch/ops/basic/grouped_linear.py: Introduces the new fusible GroupedLinear op with support for quantized compute and quantized weights; forward validates only the number of splits, not that they sum to the input's leading dimension, so bad splits can cause runtime errors or mis-slicing.
  • transformer_engine/pytorch/ops/basic/swiglu.py: Adds the new fusible SwiGLU variants, including optional gate interleaving and ScaledSwiGLU; the logic aligns with the tests and the earlier discussion of vecdot gradient behavior.
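
The noop_cat change can be pictured with the following hedged sketch; it is not TE's actual _common.noop_cat, just an illustration of the described behavior: rebuild the dim-0 concatenation as an as_strided view when the inputs are adjacent slices of one buffer, and fall back to torch.cat when the storage-size guard fails (as it can for manually constructed split_quantize subviews).

import math
import torch

def noop_cat_sketch(tensors: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical illustration of the logic described above, not TE's code.
    first = tensors[0]
    storage = first.untyped_storage()

    # Inputs must be contiguous views into the same buffer, laid out back-to-back.
    expected = first.storage_offset()
    adjacent = True
    for t in tensors:
        if (
            not t.is_contiguous()
            or t.untyped_storage().data_ptr() != storage.data_ptr()
            or t.storage_offset() != expected
        ):
            adjacent = False
            break
        expected += t.numel()

    out_shape = (sum(t.size(0) for t in tensors),) + tuple(first.shape[1:])
    out_numel = math.prod(out_shape)

    # Storage-size guard: the reconstructed view must fit inside the reported
    # storage, otherwise as_strided would read out of bounds (the split_quantize
    # edge case, where a manually built subview reports a truncated storage).
    offset = first.storage_offset()
    fits = (offset + out_numel) * first.element_size() <= storage.nbytes()

    if adjacent and fits:
        return torch.as_strided(first, out_shape, first.stride(), offset)
    return torch.cat(tensors, dim=0)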

Sequence Diagram

sequenceDiagram
participant T as Test
participant GL as GroupedLinear
participant SS as ScaledSwiGLU
participant TE as tex
participant GG as general_grouped_gemm
T->>GL: forward(x, split_sizes)
GL->>TE: split_quantize(x) (fp8 only)
GL->>GG: grouped_gemm fprop
GG-->>GL: out
T->>SS: forward(out, scales)
SS->>TE: swiglu
T->>SS: backward(dy)
SS->>TE: dswiglu
T->>GL: backward(dy)
GL->>GG: grouped_gemm dgrad/wgrad

greptile-apps[bot]

This comment was marked as resolved.

greptile-apps[bot]

This comment was marked as resolved.

@timmoon10
Collaborator Author

/te-ci pytorch


Labels

enhancement (New feature or request)

Development

Successfully merging this pull request may close these issues.

[PyTorch] Support grouped linear op in te.Sequential

1 participant