Greptile Summary

This PR extends the TMA outer-reduction scheduler to support multi-input fusions by computing a per-input shared memory budget and shrinking TMA tiles to fit, and fixes a canonicalization propagation issue in `reduction_outer_tma`.
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[mayUseTmaOuter called] --> B{SM arch >= 9?}
    B -- No --> REJECT[return false]
    B -- Yes --> C{outer reduction?}
    C -- No --> REJECT
    C -- Yes --> D{total_reduction_bytes >= 16 KB?}
    D -- No --> REJECT
    D -- Yes --> E{min_tile smem fits?\nmin = 16×32×dtype × n_inputs\n<= sharedMemPerBlockOptin}
    E -- No --> REJECT
    E -- Yes --> F{vectorize_factor > 1\nand 128-bit aligned?}
    F -- No --> REJECT
    F -- Yes --> G{WelfordOp present?}
    G -- Yes --> REJECT
    G -- No --> ACCEPT[return true]
    ACCEPT --> H[getReductionHeuristics]
    H --> I[Compute smem_per_input\n= sharedMemPerBlockOptin / n_inputs]
    I --> J[tma_tile_r = 128\ntma_tile_i = 128]
    J --> K{tile_r × tile_i × dtype\n> smem_per_input\nand tile_r > 16?}
    K -- Yes --> L[tma_tile_r /= 2]
    L --> K
    K -- No --> M{tile_r × tile_i × dtype\n> smem_per_input\nand tile_i > 32?}
    M -- Yes --> N[tma_tile_i /= 2]
    N --> M
    M -- No --> O[iter_unroll_factor\n= tma_tile_i / bdimx]
    O --> P[scheduleReduction]
    P --> Q[cacheInputs → tma_tvs]
    Q --> R[canonicalizeReduction\nreorder to R,I\npropagate to ALL TVs]
    R --> S[Apply TMA splits to tma_tvs 0\npropagate via spanning tree]
    S --> T[rFactorHelper\nfor grid reduction]
    T --> U[Propagate parallelization\nto non-TMA TVs]
    U --> V[inlineMost + refineCachePolicy]
```
Last reviewed commit: 09fd6c5
```cpp
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_r > bdimy) {
  tma_tile_r /= 2;
}
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_i > bdimx) {
  tma_tile_i /= 2;
}
```
Degenerate `iter_unroll_factor = 1` edge case
When the smem budget is very tight (requiring the tile to shrink to both tma_tile_r = bdimy = 16 and tma_tile_i = bdimx = 32), iter_unroll_factor = tma_tile_i / bdimx = 32 / 32 = 1. The subsequent redu_tv->split(4, 1) at line 176 then creates a size-1 innermost axis that is assigned ParallelType::Vectorize. A vectorize factor of 1 is typically invalid in NVFuser and will likely trigger a validation error at runtime.
This path is reachable: mayUseTmaOuter accepts fusions with up to sharedMemPerBlockOptin / (16 * 32 * dtype_bytes) inputs (e.g., ~48 float inputs or ~96 float16 inputs on a device with 96 KB optin smem). None of the new test cases exercise this boundary, so it would be a silent latent failure.
Consider adding a guard after the shrinking loops:

```cpp
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_r > bdimy) {
  tma_tile_r /= 2;
}
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
       tma_tile_i > bdimx) {
  tma_tile_i /= 2;
}
// If tma_tile_i == bdimx, iter_unroll_factor == 1 which is invalid for
// Vectorize; reject this case in mayUseTmaOuter instead.
NVF_ERROR(
    tma_tile_i > bdimx,
    "tma_tile_i shrank to bdimx, iter_unroll_factor would be 1");
```

Alternatively, strengthen the mayUseTmaOuter rejection check to min_tile_i > bdimx (i.e., 64 not 32) so that iter_unroll_factor >= 2 always holds.
```cpp
auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t smem_bytes = (int64_t)dev_prop->sharedMemPerBlockOptin;
const int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
const int64_t n_inputs = std::max(props.n_tensor_inputs, (int64_t)1);
const int64_t smem_per_input = smem_bytes / n_inputs;
```
Smem budget ignores non-input consumers
The per-input smem budget (smem_bytes / n_inputs) allocates the entire device shared memory to TMA input tiles. In practice, the kernel also uses shared memory for the rfactor/reduction accumulator TVs and the cached outputs. If those require significant smem (e.g., an output tile), the total smem usage can exceed sharedMemPerBlockOptin, causing a CUDA kernel launch failure at runtime.
mayUseTmaOuter uses the same full-smem assumption when checking whether minimum tiles fit (line 298 of reduction.cpp), so both sides are consistent — but neither accounts for the overhead. Consider subtracting a conservative smem overhead before dividing, e.g.:

```cpp
// Reserve a fraction of smem for accumulators / output caches
const int64_t smem_overhead = 4096; // conservative 4 KB
const int64_t smem_available = std::max(smem_bytes - smem_overhead, (int64_t)0);
const int64_t smem_per_input = smem_available / n_inputs;
```

```cpp
testing::Values(
    // Size sweep with single input, float
    TmaOuterReductionTestParams{256, 256, 1, DataType::Float},
    TmaOuterReductionTestParams{256, 4096, 1, DataType::Float},
    TmaOuterReductionTestParams{256, 65536, 1, DataType::Float},
    TmaOuterReductionTestParams{4096, 256, 1, DataType::Float},
    TmaOuterReductionTestParams{4096, 4096, 1, DataType::Float},
    TmaOuterReductionTestParams{4096, 65536, 1, DataType::Float},
    TmaOuterReductionTestParams{65536, 256, 1, DataType::Float},
    TmaOuterReductionTestParams{65536, 4096, 1, DataType::Float},
    TmaOuterReductionTestParams{65536, 65536, 1, DataType::Float},
    // Multi-input (exercises smem budget / tile shrinking)
    TmaOuterReductionTestParams{4096, 1024, 2, DataType::Float},
    TmaOuterReductionTestParams{4096, 1024, 3, DataType::Float},
    TmaOuterReductionTestParams{4096, 1024, 5, DataType::Float},
    TmaOuterReductionTestParams{4096, 1024, 2, DataType::Half},
    TmaOuterReductionTestParams{4096, 1024, 3, DataType::Half},
    TmaOuterReductionTestParams{4096, 1024, 5, DataType::Half}),
```
Single-input test coverage reduced
The original parameterization used testing::Combine with outer/iter sizes {256, 1024, 4096, 16384, 65536}, producing 25 single-input combinations. The new explicit list only covers sizes {256, 4096, 65536} (9 single-input cases), silently dropping the 1024 and 16384 points. This removes coverage for the intermediate sizes that are most likely to exercise tile-shrinking logic for single-input fusions (e.g., {256, 1024} and {4096, 16384}).
Consider either restoring the missing intermediate sizes, or documenting that the reduction is intentional for compile-time reasons.
Fix a canonicalization issue with reduction_outer_tma, and support fusions with multiple inputs. Update TmaOuterReductionTest to cover these new cases. Also disabled the Welford op for inner+outer TMA since supporting it is complicated.