
TMA outer-reduction: Support multi-input fusions#6036

Open
tbqh wants to merge 1 commit into main from tbqh/multi_input_tma

Conversation


@tbqh tbqh commented Mar 17, 2026

Fix a canonicalization issue with reduction_outer_tma and support fusions with multiple inputs. Update TmaOuterReductionTest to cover these new cases. Also disable the Welford op for inner+outer TMA, since supporting it is complicated.

@tbqh tbqh requested a review from liqiangxl March 17, 2026 10:52

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR extends the TMA outer-reduction scheduler to support multi-input fusions by computing a per-input shared memory budget and shrinking TMA tiles to fit. It also fixes a canonicalization propagation issue that prevented the [R, I] reorder from reaching all tensors, switches from rFactor to ir_utils::rFactorHelper to handle multi-output reduction definitions correctly, and disables Welford ops on TMA paths until they are properly supported.

Key changes:

  • Multi-input smem budgeting (reduction_outer_tma.cpp lines 38-57): tile sizes start at 128×128 and halve (reduction first, then iteration) until they fit in smem_bytes / n_inputs. A matching minimum-tile check is added to mayUseTmaOuter so fusions that can't fit even the 16×32 minimum are rejected early.
  • Canonicalization propagation fix (reduction_outer_tma.cpp lines 131-133): a new TransformPropagator pass runs immediately after the [I, R]→[R, I] reorder so that all tensors (including additional TMA cache TVs) share the canonical form before Phase 2 tiling begins.
  • rFactorHelper switch (reduction_outer_tma.cpp line 205): when the reduction definition has multiple outputs, rFactor(axes) only handles the first; rFactorHelper correctly passes all output TVs.
  • Welford guard (reduction.cpp lines 257-259, 317-319): both mayUseTma and mayUseTmaOuter now return false for fusions containing WelfordOp.
  • Test refactoring: TmaOuterReductionTestParams gains n_inputs and dtype fields; 6 new multi-input cases (2/3/5 inputs × float/half) are added, though the single-input size sweep is reduced from 25 combinations to 9.

Confidence Score: 3/5

  • Mostly safe for the tested input counts, but a latent Vectorize(1) edge case and unaccounted smem overhead need attention before merging.
  • The core logic (tile shrinking, canonicalization fix, rFactorHelper) is well-structured and correctly guarded. However, two issues lower confidence: (1) when tiles shrink to the absolute minimum (tma_tile_i = bdimx = 32), iter_unroll_factor becomes 1, producing an invalid Vectorize(1) axis — a path that is not covered by any test case; (2) the smem budget does not reserve space for non-input smem consumers (accumulators, output caches), risking CUDA launch failures at runtime with many inputs. These are edge cases under normal workloads but real failure modes.
  • Pay close attention to csrc/scheduler/reduction_outer_tma.cpp — specifically the tile-shrinking loop (lines 50-57) and the corresponding smem budget calculation (lines 40-44).

Important Files Changed

Filename Overview
csrc/scheduler/reduction_outer_tma.cpp Adds smem-budget-aware TMA tile shrinking for multi-input fusions, a pre-transform canonicalization propagation pass, and switches to rFactorHelper for multi-output reductions. The tile-shrinking logic is sound but has a latent edge case when tiles hit the minimum size (tma_tile_i = bdimx = 32), causing iter_unroll_factor = 1 and an invalid Vectorize(1) axis. The smem budget also does not reserve space for non-input smem consumers.
csrc/scheduler/reduction.cpp Replaces the blanket single-input guard with an smem-capacity check consistent with the heuristic's minimum tile sizes. Adds Welford rejection to both mayUseTma and mayUseTmaOuter. The logic is clean and matches the heuristic, though like the heuristic it does not reserve smem for non-input consumers.
tests/cpp/test_reduction.cpp Extends TmaOuterReductionTest to accept n_inputs and dtype parameters, adds multi-input test cases for 2/3/5 float and half inputs, and updates expectOuterTmaUsed to mirror the new smem guard. The refactoring is clean, though single-input coverage was quietly reduced from 25 to 9 test cases.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[mayUseTmaOuter called] --> B{SM arch >= 9?}
    B -- No --> REJECT[return false]
    B -- Yes --> C{outer reduction?}
    C -- No --> REJECT
    C -- Yes --> D{total_reduction_bytes >= 16 KB?}
    D -- No --> REJECT
    D -- Yes --> E{min_tile smem fits?\nmin = 16×32×dtype × n_inputs\n<= sharedMemPerBlockOptin}
    E -- No --> REJECT
    E -- Yes --> F{vectorize_factor > 1\nand 128-bit aligned?}
    F -- No --> REJECT
    F -- Yes --> G{WelfordOp present?}
    G -- Yes --> REJECT
    G -- No --> ACCEPT[return true]

    ACCEPT --> H[getReductionHeuristics]
    H --> I[Compute smem_per_input\n= sharedMemPerBlockOptin / n_inputs]
    I --> J[tma_tile_r = 128\ntma_tile_i = 128]
    J --> K{tile_r × tile_i × dtype\n> smem_per_input\nand tile_r > 16?}
    K -- Yes --> L[tma_tile_r /= 2]
    L --> K
    K -- No --> M{tile_r × tile_i × dtype\n> smem_per_input\nand tile_i > 32?}
    M -- Yes --> N[tma_tile_i /= 2]
    N --> M
    M -- No --> O[iter_unroll_factor\n= tma_tile_i / bdimx]
    O --> P[scheduleReduction]

    P --> Q[cacheInputs → tma_tvs]
    Q --> R[canonicalizeReduction\nreorder to R,I\npropagate to ALL TVs]
    R --> S[Apply TMA splits to tma_tvs 0\npropagate via spanning tree]
    S --> T[rFactorHelper\nfor grid reduction]
    T --> U[Propagate parallelization\nto non-TMA TVs]
    U --> V[inlineMost + refineCachePolicy]

Last reviewed commit: 09fd6c5

Comment on lines +50 to +56
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_r > bdimy) {
  tma_tile_r /= 2;
}
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_i > bdimx) {
  tma_tile_i /= 2;

P1 Degenerate iter_unroll_factor = 1 edge case

When the smem budget is very tight (requiring the tile to shrink to both tma_tile_r = bdimy = 16 and tma_tile_i = bdimx = 32), iter_unroll_factor = tma_tile_i / bdimx = 32 / 32 = 1. The subsequent redu_tv->split(4, 1) at line 176 then creates a size-1 innermost axis that is assigned ParallelType::Vectorize. A vectorize factor of 1 is typically invalid in NVFuser and will likely trigger a validation error at runtime.

This path is reachable: mayUseTmaOuter accepts fusions with up to sharedMemPerBlockOptin / (16 * 32 * dtype_bytes) inputs (e.g., ~48 float inputs or ~96 float16 inputs on a device with 96 KB optin smem). None of the new test cases exercise this boundary, so it would be a silent latent failure.

Consider adding a guard after the shrinking loops:

while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_r > bdimy) {
  tma_tile_r /= 2;
}
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_i > bdimx) {
  tma_tile_i /= 2;
}
// If tma_tile_i == bdimx, iter_unroll_factor == 1 which is invalid for
// Vectorize; reject this case in mayUseTmaOuter instead.
NVF_ERROR(
    tma_tile_i > bdimx,
    "tma_tile_i shrank to bdimx, iter_unroll_factor would be 1");

Alternatively, strengthen the mayUseTmaOuter rejection check to min_tile_i > bdimx (i.e., 64 not 32) so that iter_unroll_factor >= 2 always holds.

Comment on lines +40 to +44
auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t smem_bytes = (int64_t)dev_prop->sharedMemPerBlockOptin;
const int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
const int64_t n_inputs = std::max(props.n_tensor_inputs, (int64_t)1);
const int64_t smem_per_input = smem_bytes / n_inputs;

P1 Smem budget ignores non-input consumers

The per-input smem budget (smem_bytes / n_inputs) allocates the entire device shared memory to TMA input tiles. In practice, the kernel also uses shared memory for the rfactor/reduction accumulator TVs and the cached outputs. If those require significant smem (e.g., an output tile), the total smem usage can exceed sharedMemPerBlockOptin, causing a CUDA kernel launch failure at runtime.

mayUseTmaOuter uses the same full-smem assumption when checking whether minimum tiles fit (line 298 of reduction.cpp), so both sides are consistent — but neither accounts for the overhead. Consider subtracting a conservative smem overhead before dividing, e.g.:

// Reserve a fraction of smem for accumulators / output caches
const int64_t smem_overhead = 4096; // conservative 4 KB
const int64_t smem_available = std::max(smem_bytes - smem_overhead, (int64_t)0);
const int64_t smem_per_input = smem_available / n_inputs;

Comment on lines +3197 to +3214
testing::Values(
    // Size sweep with single input, float
    TmaOuterReductionTestParams{256, 256, 1, DataType::Float},
    TmaOuterReductionTestParams{256, 4096, 1, DataType::Float},
    TmaOuterReductionTestParams{256, 65536, 1, DataType::Float},
    TmaOuterReductionTestParams{4096, 256, 1, DataType::Float},
    TmaOuterReductionTestParams{4096, 4096, 1, DataType::Float},
    TmaOuterReductionTestParams{4096, 65536, 1, DataType::Float},
    TmaOuterReductionTestParams{65536, 256, 1, DataType::Float},
    TmaOuterReductionTestParams{65536, 4096, 1, DataType::Float},
    TmaOuterReductionTestParams{65536, 65536, 1, DataType::Float},
    // Multi-input (exercises smem budget / tile shrinking)
    TmaOuterReductionTestParams{4096, 1024, 2, DataType::Float},
    TmaOuterReductionTestParams{4096, 1024, 3, DataType::Float},
    TmaOuterReductionTestParams{4096, 1024, 5, DataType::Float},
    TmaOuterReductionTestParams{4096, 1024, 2, DataType::Half},
    TmaOuterReductionTestParams{4096, 1024, 3, DataType::Half},
    TmaOuterReductionTestParams{4096, 1024, 5, DataType::Half}),

P2 Single-input test coverage reduced

The original parameterization used testing::Combine with outer/iter sizes {256, 1024, 4096, 16384, 65536}, producing 25 single-input combinations. The new explicit list only covers sizes {256, 4096, 65536} (9 single-input cases), silently dropping the 1024 and 16384 points. This removes coverage for the intermediate sizes that are most likely to exercise tile-shrinking logic for single-input fusions (e.g., {256, 1024} and {4096, 16384}).

Consider either restoring the missing intermediate sizes, or documenting that the reduction is intentional for compile-time reasons.
