!test |
Greptile Summary
This PR implements the auto TMA transpose scheduler, adding
Key changes:
Confidence Score: 3/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
B -- No --> C[non_tma::getTransposeHeuristics]
B -- Yes --> D[tma::getTransposeHeuristics]
D --> E{n_input > n_output?}
E -- Yes --> F["is_output_smem_transpose=true\nuse_tma_load=true\nuse_tma_store=true"]
E -- No --> G["is_output_smem_transpose=false\nuse_tma_load=true\nuse_tma_store=false"]
F --> H[scheduleTranspose]
G --> H
C --> H
H --> I[cacheInputs + cacheAndForkOutputs]
I --> J{use_tma_load?}
J -- Yes --> K["input → smem_cache (TMA) → reg_cache"]
J -- No --> L[skip]
I --> M{use_tma_store?}
M -- Yes --> N["reg_cache → smem_cache (TMA) → output"]
M -- No --> O[register store]
K --> P[Step 1: Tile both transpose dims, propagate BIDx]
L --> P
N --> P
O --> P
P --> Q{is_output_smem_transpose?}
Q -- Yes --> R["Step 2: scheduleTMAStoreForMmaOutput\n(swizzled output smem)"]
Q -- No --> S["Step 3: applyMmaSwizzleForTMALoad\n(swizzled input smem)"]
R --> T[Step 4: Register scheduling + TIDx propagation]
S --> T
T --> U[Vectorize smem reads/writes]
U --> V[inlineMost]
|
csrc/scheduler/transpose_tma.cpp
Outdated
| NVF_ERROR(grouped_inputs_outputs.size() >= 2); | ||
|
|
||
| // When there are more inputs than outputs, output smem transpose should be | ||
| // used, however, if it is not, then input smem tranpose will be used, to |
tranpose should be transpose
| const int64_t cta_per_sm = | ||
| dev_props->maxThreadsPerMultiProcessor / threads_per_cta; | ||
| const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm; | ||
| const int64_t bytes_per_tile = bytes_per_cta / n_input; |
Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.
| const int64_t bytes_per_tile = bytes_per_cta / n_input; | |
| NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose"); | |
| const int64_t bytes_per_tile = bytes_per_cta / n_input; |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
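The tile-size arithmetic and the suggested guard can be sketched as a standalone function. This is only an illustration: `bytesPerTile` and its parameter list are hypothetical names, exceptions stand in for `NVF_ERROR`, and the 65536-byte per-SM budget used in the usage note is an assumed value chosen so that `bytes_per_cta` matches the 8192 quoted for H100 elsewhere in this review.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical helper mirroring the three divisions in the diff.
// Throws instead of calling NVF_ERROR so the sketch stays self-contained.
std::int64_t bytesPerTile(
    std::int64_t bytes_per_sm,
    std::int64_t max_threads_per_sm,
    std::int64_t threads_per_cta,
    std::int64_t n_input) {
  // Guard every divisor before it is used.
  if (threads_per_cta <= 0 || threads_per_cta > max_threads_per_sm) {
    throw std::invalid_argument("threads_per_cta out of range");
  }
  if (n_input <= 0) {
    throw std::invalid_argument(
        "Expected at least one TensorView input for transpose");
  }
  const std::int64_t cta_per_sm = max_threads_per_sm / threads_per_cta;
  const std::int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
  return bytes_per_cta / n_input;
}
```

With 2048 threads per SM and 256 threads per CTA, `cta_per_sm` is 8; under the assumed 65536-byte budget, `bytesPerTile(65536, 2048, 256, 2)` gives 4096 bytes per input tile, and `n_input == 0` is rejected before the division.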
|
Review updated until commit bc772db
Description
|
| Relevant files |
|---|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Potential TMA load restriction
This is more restrictive than the original which checked all loop domains. This could potentially exclude valid TMA loads where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break existing TMA use cases. |
Additional Comments (2)
If On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when
The |
Additional Comments (4)
If This happens when While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise
Note the asymmetry: Step 3 already guards the analogous constraint with an explicit
These
|
| if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) { | ||
| return id->isThreadDim() || | ||
| id->getParallelType() == ParallelType::Serial; | ||
| })) { |
trivial optimization of multiple-tma loads, doesn't have to be in this PR.
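For readers without the codebase at hand, the shape of this predicate can be sketched with mock types. Everything here is a stand-in: the `IterDomain` struct, the `ParallelType` enumerators, and the rule that an ID is "non-trivial" when its extent is not 1 are assumptions for illustration, not nvFuser's actual definitions. The sketch uses `std::any_of` rather than `std::ranges::any_of` so that it also compiles as C++17.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Mock stand-ins for the real IR types (hypothetical, for illustration only).
enum class ParallelType { Serial, TIDx, BIDx, Bulk };

struct IterDomain {
  std::int64_t extent;
  ParallelType ptype;
  bool isThreadDim() const { return ptype == ParallelType::TIDx; }
};

// Returns true if any ID with extent != 1 is thread-parallel or serial.
// Extent-1 IDs are skipped, which is the assumed meaning of "non-trivial".
bool anyNonTrivialThreadOrSerial(const std::vector<IterDomain>& loop_domain) {
  return std::any_of(
      loop_domain.begin(), loop_domain.end(), [](const IterDomain& id) {
        if (id.extent == 1) {
          return false; // trivial ID, ignored by the check
        }
        return id.isThreadDim() || id.ptype == ParallelType::Serial;
      });
}
```

Under these assumptions, a domain made of an extent-1 serial ID plus a bulk ID does not trigger the predicate, while a domain containing a thread-parallel ID of extent 32 does.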
|
|
||
| // When not using output smem transpose but inputs > outputs, swap groups | ||
| // so group 2 remains the swizzled side. | ||
| if (!tparams->is_output_smem_transpose && |
This branch is not used in the current heuristics, but may be used in future tuning.
|
!test |
tests/cpp/test_transpose.cpp
Outdated
| auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel()); | ||
| for (auto& [expr, ways] : bank_conflicts) { | ||
| auto [read_ways, write_ways] = ways; | ||
| std::cout << " Bank conflict: " << expr->toString() |
rdspring1
left a comment
Are the Infinite loop in heuristics and Out-of-range crash with inconsistent params greptile concerns valid?
Yes, they may happen in theory; added checks.
To reduce the number of transpose ops, is_output_smem_transpose is added to control input/output transpose:
1. When there are more inputs than outputs, is_output_smem_transpose = True: TMA load without swizzle, TMA store with swizzle, transpose at regs --> output cached smem.
2. When there are fewer inputs than outputs, is_output_smem_transpose = False: TMA load with swizzle, register store, transpose at input cached smem --> regs.
Current performance is in this doc.
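The selection rule can be written down as a small sketch. It follows the `n_input > n_output` branch shown in the Greptile flowchart for this PR; the struct and function names are hypothetical and are not the actual `TransposeParams` type or heuristics entry point.

```cpp
#include <cstdint>

// Hypothetical mirror of the three flags named in the flowchart.
struct TransposeParamsSketch {
  bool is_output_smem_transpose;
  bool use_tma_load;
  bool use_tma_store;
};

// More inputs than outputs: transpose on the output side (TMA store, swizzled
// output smem). Otherwise, including the equal case as drawn in the flowchart,
// transpose on the input side (swizzled input smem, register store).
TransposeParamsSketch chooseTmaTransposeParams(
    std::int64_t n_input,
    std::int64_t n_output) {
  const bool output_side = n_input > n_output;
  return TransposeParamsSketch{
      /*is_output_smem_transpose=*/output_side,
      /*use_tma_load=*/true,
      /*use_tma_store=*/output_side};
}
```

For example, a fusion with two inputs and one output would take the output-side path, while one input feeding two outputs would take the input-side path with a register store.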