
add auto tma transpose scheduler#6018

Open
liqiangxl wants to merge 13 commits into main from llu/transpose_output_smem_auto

Conversation

@liqiangxl
Collaborator

To reduce the number of transpose ops, is_output_smem_transpose is added to control input/output transpose:

1. When there are more inputs than outputs, is_output_smem_transpose = True
TMA load without swizzle, TMA store with swizzle, transpose at regs -> output cached smem

2. When there are fewer inputs than outputs, is_output_smem_transpose = False
TMA load with swizzle, register store, transpose at input cached smem -> regs

Current performance is in this doc.
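The selection rule above can be sketched as a tiny standalone function. This is an illustrative stand-in, not the actual heuristic in csrc/scheduler/transpose_tma.cpp; the struct and function names below are hypothetical:

```cpp
#include <cassert>

// Hypothetical sketch of the input/output selection rule described above.
// The real heuristic considers more factors; names are illustrative only.
struct TmaTransposeChoice {
  bool is_output_smem_transpose;
  bool use_tma_load;
  bool use_tma_store;
};

TmaTransposeChoice chooseTransposeSide(int n_input, int n_output) {
  TmaTransposeChoice c{};
  // More inputs than outputs: swizzle the (fewer) output smem buffers,
  // so the TMA store performs the swizzled write and loads stay plain.
  c.is_output_smem_transpose = n_input > n_output;
  // TMA load is used on both paths; TMA store only when the output smem
  // holds the transposed (swizzled) data.
  c.use_tma_load = true;
  c.use_tma_store = c.is_output_smem_transpose;
  return c;
}
```

Tying use_tma_store to is_output_smem_transpose matches the default heuristic described later in this thread; the combination is_output_smem_transpose=true with use_tma_store=false is exactly the one the reviewers flag as crash-prone.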

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl marked this pull request as ready for review February 27, 2026 15:40
@greptile-apps
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR implements the auto TMA transpose scheduler, adding is_output_smem_transpose to choose between two swizzle strategies based on tensor counts: swizzle the input smem when there are fewer inputs (TMA load + swizzled read), or swizzle the output smem when there are fewer outputs (TMA store + swizzled write). The feature is opt-in via the new EnableOption::TmaTranspose flag and falls back to the non-TMA scheduler when disabled.

Key changes:

  • transpose_tma.cpp: Full implementation of getTransposeHeuristics and scheduleTranspose for both input-smem and output-smem transpose paths.
  • transpose_heuristic.h: Four new params (use_tma_store, is_output_smem_transpose, chunks_per_thread, elements_per_chunk) with correct sameAs/hash/toString updates.
  • transpose.cpp: TMA path gated behind isOptionEnabled(TmaTranspose); dispatch extended to route on use_tma_store in addition to use_tma_load.
  • options.h/cpp: TmaTranspose added to EnableOption; all option enums switched to std::uint8_t underlying type.
  • tma.cpp: Batching eligibility check now skips trivial extent-1 IDs, avoiding spurious exclusions from TMA load batching.
  • Tests cover parameterized dtype/dim combinations, bank-conflict validation, and direct param-override combinations.

Confidence Score: 3/5

  • Mostly safe to merge, but a latent crash path exists when is_output_smem_transpose=true and use_tma_store=false that should be guarded before this lands.
  • The core heuristic and scheduling logic is well-structured and backed by parameterized tests. However, tma_store_tvs.at(0) is called unconditionally when is_output_smem_transpose=true without verifying that tma_store_tvs is non-empty — a combination the test infrastructure explicitly allows overriding. This is an unguarded crash path. The dead code in the bank-conflict test also produces compiler warnings. These issues lower confidence despite the otherwise solid implementation.
  • csrc/scheduler/transpose_tma.cpp — specifically the tma_store_tvs.at(0) access at line 260 needs a guard or assertion.

Important Files Changed

  • csrc/scheduler/transpose_tma.cpp: Core implementation of the new TMA transpose scheduler (~340 lines added). Contains a latent crash when is_output_smem_transpose=true but use_tma_store=false (unchecked tma_store_tvs.at(0)), and the tile-size heuristic always divides by n_input regardless of which side is transposed.
  • tests/cpp/test_transpose.cpp: Good breadth of parameterized tests for the new TMA scheduler (dtype, dims, param combinations). Minor issue: OutputTransposeBankconflict contains dead code with unused structured bindings (read_ways/write_ways) that will produce compiler warnings.
  • csrc/scheduler/transpose_heuristic.h: Adds use_tma_store, is_output_smem_transpose, chunks_per_thread, and elements_per_chunk fields; correctly updates sameAs, hash, and toString accordingly.
  • csrc/scheduler/transpose.cpp: TMA path is now gated behind isOptionEnabled(EnableOption::TmaTranspose), and the dispatch condition now also triggers on use_tma_store. Both changes are correct.
  • csrc/options.h: Adds TmaTranspose to EnableOption; changes all option enums to use std::uint8_t (all enums have far fewer than 256 entries, so no overflow risk); fixes the copy constructor initialization order.
  • csrc/device_lower/analysis/tma.cpp: Tightens the TMA batching check by filtering out trivial (extent=1) IDs before checking for thread-dim or serial parallelism. Intentional refinement: extent=1 dims don't cause mbarrier issues.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
    B -- No --> C[non_tma::getTransposeHeuristics]
    B -- Yes --> D[tma::getTransposeHeuristics]
    D --> E{n_input > n_output?}
    E -- Yes --> F["is_output_smem_transpose=true\nuse_tma_load=true\nuse_tma_store=true"]
    E -- No --> G["is_output_smem_transpose=false\nuse_tma_load=true\nuse_tma_store=false"]
    F --> H[scheduleTranspose]
    G --> H
    C --> H

    H --> I[cacheInputs + cacheAndForkOutputs]
    I --> J{use_tma_load?}
    J -- Yes --> K["input → smem_cache (TMA) → reg_cache"]
    J -- No --> L[skip]
    I --> M{use_tma_store?}
    M -- Yes --> N["reg_cache → smem_cache (TMA) → output"]
    M -- No --> O[register store]

    K --> P[Step 1: Tile both transpose dims, propagate BIDx]
    L --> P
    N --> P
    O --> P
    P --> Q{is_output_smem_transpose?}
    Q -- Yes --> R["Step 2: scheduleTMAStoreForMmaOutput\n(swizzled output smem)"]
    Q -- No --> S["Step 3: applyMmaSwizzleForTMALoad\n(swizzled input smem)"]
    R --> T[Step 4: Register scheduling + TIDx propagation]
    S --> T
    T --> U[Vectorize smem reads/writes]
    U --> V[inlineMost]

Comments Outside Diff (2)

  1. csrc/scheduler/transpose_tma.cpp, line 257-262 (link)

    Crash when is_output_smem_transpose=true but use_tma_store=false

    tma_store_tvs.at(0) is called unconditionally whenever is_output_smem_transpose is true, but tma_store_tvs is only populated when use_tma_store is true. If a caller sets is_output_smem_transpose = true with use_tma_store = false (a combination not prevented by the heuristic or scheduler), this throws a std::out_of_range exception at runtime.

    The default heuristic always sets use_tma_store = is_output_smem_transpose, so the normal path is safe, but the TmaTransposeParamsTestP test explicitly overrides these params independently, and a future test or user could hit this combination. An assertion or guard should be added here:

      if (tparams->is_output_smem_transpose) {
        NVF_ERROR(
            !tma_store_tvs.empty(),
            "is_output_smem_transpose requires use_tma_store to be true");
        MmaInputSmemSwizzle swizzle =
            mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0));
  2. csrc/scheduler/transpose_tma.cpp, line 88-90 (link)

    Tile-size heuristic always uses n_input, ignoring n_output for the output-smem path

    The tile-size-1 heuristic divides by n_input unconditionally:

    const int64_t bytes_per_tile = bytes_per_cta / n_input;

    When is_output_smem_transpose = true the swizzled side is the output smem, and n_input > n_output by definition. The non-swizzled tile (tile1) spans the inputs, so using n_input is conceptually correct for the "data loaded per CTA" target. However, the output smem footprint scales with n_output, not n_input. If balancing smem usage across outputs is also a goal, dividing by max(n_input, n_output) (or a weighted combination) would be more symmetric. As a minimum a comment clarifying why n_input is the right divisor in both branches would help maintainability.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
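For concreteness, the divisor question can be played out with the H100-style numbers quoted elsewhere in this review (2048 threads per SM, cta_per_sm = 8, bytes_per_cta = 8192). The smem budget and threads-per-CTA below are illustrative assumptions, not values taken from the PR:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Numeric sketch of the tile-size divisor. bytes_per_sm and threads_per_cta
// are assumed values; the division chain mirrors the excerpt quoted later in
// this review. 'symmetric' models the suggested max(n_input, n_output)
// variant alongside the PR's divide-by-n_input behavior.
std::int64_t bytesPerTile(std::int64_t n_input, std::int64_t n_output,
                          bool symmetric) {
  const std::int64_t bytes_per_sm = 64 * 1024;             // assumed smem budget
  const std::int64_t threads_per_cta = 256;                // assumed CTA size
  const std::int64_t cta_per_sm = 2048 / threads_per_cta;  // 2048 threads/SM -> 8
  const std::int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;  // 8192
  const std::int64_t divisor =
      symmetric ? std::max(n_input, n_output) : n_input;
  return bytes_per_cta / divisor;
}
```

With these numbers, a 1-input/2-output fusion gets the full 8192 bytes per tile under the current divide-by-n_input rule, but only 4096 under the symmetric variant, which is the asymmetry the comment is pointing at.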

Last reviewed commit: 0068fe5


@greptile-apps greptile-apps bot left a comment


7 files reviewed, 2 comments


NVF_ERROR(grouped_inputs_outputs.size() >= 2);

// When there are more inputs than outputs, output smem transpose should be
// used, however, if it is not, then input smem tranpose will be used, to

tranpose should be transpose

const int64_t cta_per_sm =
    dev_props->maxThreadsPerMultiProcessor / threads_per_cta;
const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
const int64_t bytes_per_tile = bytes_per_cta / n_input;

Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.

Suggested change
const int64_t bytes_per_tile = bytes_per_cta / n_input;
NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose");
const int64_t bytes_per_tile = bytes_per_cta / n_input;


@liqiangxl liqiangxl requested a review from rdspring1 February 27, 2026 17:24
@github-actions

github-actions bot commented Mar 2, 2026

Review updated until commit bc772db

Description

  • Implements automatic TMA (Tensor Memory Accelerator) transpose scheduler with two paths: input smem transpose (swizzle on input) and output smem transpose (swizzle on output)

  • Adds new TmaTranspose enable option to toggle the feature; scheduler falls back to non-TMA when disabled

  • Introduces new parameters: use_tma_store, is_output_smem_transpose, chunks_per_thread, elements_per_chunk for flexible TMA configuration

  • Adds comprehensive tests covering different dtypes, transpose dimensions, and TMA parameter combinations


PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Potential TMA load restriction

The new code filters loop domains to only include non-trivial IDs (extent > 1 or non-const) before checking for thread/serial dims.
This is more restrictive than the original which checked all loop domains. This could potentially exclude valid TMA loads
where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break
existing TMA use cases.

auto non_trivial_ids =
    tv->getLoopDomain() | std::views::filter([](const IterDomain* id) {
      return !id->extent()->isConstScalar() ||
          id->extent()->evaluate().as<int64_t>() > 1;
    });
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
  return {};
}
Missing null check

In scheduleTranspose, when setting up TMA store (lines 165-172), the code accesses fusion->outputs()[output_idx] without
checking if output_idx is within bounds. While cached_outputs should correspond to outputs, a bounds check would be safer.

for (auto [cached_output, output_idx] : cached_outputs) {
  auto output = fusion->outputs()[output_idx]->as<TensorView>();
  output->definition()->as<LoadStoreOp>()->setOpType(
      LoadStoreOpType::CpAsyncBulkTensorTile);
  cached_output->setMemoryType(MemoryType::Shared);
  cached_output->cacheBefore();
  tma_store_tvs.push_back(cached_output);
}
Thread safety consideration

The copy constructor was modified to use a lambda that captures other.mutex_ and returns other.options_. While this appears
correct, the original implementation directly assigned options_. The new approach should be verified to maintain the same
thread-safety semantics under concurrent access patterns.

Options(const Options& other)
    : options_([&other]() {
        std::lock_guard<std::mutex> lock_other(other.mutex_);
        return other.options_;
      }()) {}
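The idiom in question, copy-under-lock via an immediately invoked lambda in the member initializer list, can be checked in isolation with a minimal mock of Options (members simplified to a string-to-int map; illustrative only, not the real class):

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <string>

// Minimal mock of the Options copy constructor pattern: the lambda runs in
// the member initializer list, locks the source object's mutex just long
// enough to copy its map, and returns the copy by value.
class Options {
 public:
  Options() = default;
  Options(const Options& other)
      : options_([&other]() {
          std::lock_guard<std::mutex> lock(other.mutex_);
          return other.options_;
        }()) {}

  void set(const std::string& key, int value) {
    std::lock_guard<std::mutex> lock(mutex_);
    options_[key] = value;
  }

  int get(const std::string& key) const {
    std::lock_guard<std::mutex> lock(mutex_);
    return options_.at(key);
  }

 private:
  mutable std::mutex mutex_;  // mutable so const methods and the copy can lock
  std::map<std::string, int> options_;
};
```

Semantically this matches a direct assignment of options_ under a lock, but doing it in the initializer list avoids a default-construct-then-assign and keeps the lock scoped to exactly the duration of the copy.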

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (2)

csrc/scheduler/transpose_tma.cpp, line 106
Infinite loop when estimated_tile_size1 starts at zero

If bytes_per_tile < kTmaSwizzleBytes (line 91-92), integer division yields estimated_tile_size1 = 0. The while loop (line 104) then spins forever because 0 * 2 == 0 and get_chunks_per_thread() (line 98-102) stays at 0, which is always less than min_chunks_per_thread = 4.

On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when n_input > 64. Add an initialization guard before the loop:

  // Ensure we start from at least 1 to avoid multiplying 0 forever.
  if (estimated_tile_size1 == 0) {
    estimated_tile_size1 = 1;
  }
  while (get_chunks_per_thread() < min_chunks_per_thread) {
    estimated_tile_size1 *= 2;
  }

tests/cpp/test_transpose.cpp, line 1947
Unconditional debug output will pollute test logs

The std::cout block (lines 1945–1947) prints every bank conflict unconditionally. This makes test runner output noisy, especially since the BFloat16 path is expected to have bank conflicts. Consider wrapping the print in a debug flag or removing it:

      if (auto* ke = dynamic_cast<KernelExecutor*>(executor.get())) {
        auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
        if (dtype == DataType::Float) {
          EXPECT_TRUE(bank_conflicts.empty());
        } else {
          // TODO: update to EXPECT_TRUE once bf16 bank conflicts are resolved.
          EXPECT_FALSE(bank_conflicts.empty());
        }
      }

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (4)

csrc/scheduler/transpose_tma.cpp, line 107
Potential infinite loop when estimated_tile_size1 initializes to zero

If bytes_per_tile < kTmaSwizzleBytes (128), integer division yields estimated_tile_size1 = 0. The while loop then evaluates get_chunks_per_thread() as 0 (because the numerator is 0 * tile_size2 = 0) and multiplies: 0 * 2 = 0 — the loop never terminates.

This happens when bytes_per_cta / n_input < 128. With an SM90 GPU (maxThreadsPerMultiProcessor = 2048), cta_per_sm = 8, giving bytes_per_cta = 8192. So the loop infinite-hangs when n_input > 64.

While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise estimated_tile_size1 to at least 1:

int64_t estimated_tile_size1 =
    std::max(int64_t(1), bytes_per_tile / kTmaSwizzleBytes);

csrc/scheduler/transpose_tma.cpp, line 267
Missing guard before accessing tma_store_tvs when use_tma_store may be false

tma_store_tvs is only populated when tparams->use_tma_store == true (lines 164–173), but this block checks only tparams->is_output_smem_transpose. If is_output_smem_transpose = true but use_tma_store = false, then tma_store_tvs will be empty and .at(0) throws std::out_of_range.

Note the asymmetry: Step 3 already guards the analogous constraint with an explicit NVF_ERROR(tparams->use_tma_load, ...) at line 286-288. Adding the same guard here would be consistent:

if (tparams->is_output_smem_transpose) {
    NVF_ERROR(
        tparams->use_tma_store,
        "TMA store must be used when output smem is transposed");
    MmaInputSmemSwizzle swizzle =
        mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0));

tests/cpp/test_transpose.cpp, line 1949
Debug std::cout in test code — use GTest facilities instead

These std::cout lines will only fire when bank conflicts are detected (when the test is already failing). However, raw std::cout in tests is unconventional — GTest's ADD_FAILURE() / SCOPED_TRACE or just the EXPECT_TRUE failure message would be more idiomatic:

      for (auto& [expr, ways] : bank_conflicts) {
        auto [read_ways, write_ways] = ways;
        ADD_FAILURE() << "Bank conflict in: " << expr->toString()
                      << "  read=" << read_ways << "-way"
                      << ", write=" << write_ways << "-way";
      }



tests/cpp/test_transpose.cpp, line 1969
Typo "tranapose" should be "transpose" in multiple lines

// Test different combinations of TMA transpose parameters:
// (is_output_smem, use_tma_load, use_tma_store)
//   (false, true, false)  - input smem transpose, TMA load only
//   (false, true, true)   - input smem transpose, TMA load + TMA store
//   (true,  true, true)   - output smem transpose, TMA load + TMA store
//   (true,  false, true)  - output smem transpose, TMA store only

if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
Collaborator Author

A trivial optimization for multiple TMA loads; it doesn't have to be in this PR.


// When not using output smem transpose but inputs > outputs, swap groups
// so group 2 remains the swizzled side.
if (!tparams->is_output_smem_transpose &&
Collaborator Author

This branch is not used in the current heuristics, but may be used in future tuning.

@liqiangxl
Copy link
Collaborator Author

!test

2 similar comments
@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
for (auto& [expr, ways] : bank_conflicts) {
  auto [read_ways, write_ways] = ways;
  std::cout << " Bank conflict: " << expr->toString()
Collaborator

Is std::cout necessary?

Collaborator Author

removed


@rdspring1 rdspring1 left a comment


Are the "Infinite loop in heuristics" and "Out-of-range crash with inconsistent params" greptile concerns valid?

@liqiangxl
Collaborator Author

Are the "Infinite loop in heuristics" and "Out-of-range crash with inconsistent params" greptile concerns valid?

Yes, they may happen in theory, added checks.
