[TIRX] Bind parallel loops to GPU threads before VerifyMemory #19363
zhils wants to merge 4 commits into apache:main from
Conversation
`VerifyMemory` on GPU targets treats direct accesses outside thread environments as illegal. In the ScatterValue CUDA lowering path, `topi.scatter_elements` emits `ForKind::kParallel` loops without explicit thread bindings, which triggers false host-memory access failures (e.g. "Did you forget to bind?") during TIR verification. This change adds a new `tirx` pass (`BindParallelLoopsToThreads`) and inserts it before `VerifyMemory` in the `s_tir` pipelines (including adreno). The pass rewrites parallel loops into `blockIdx.x/threadIdx.x` thread-extent regions, substitutes loop vars with global thread indices, and adds bounds checks for non-divisible extents. This preserves correctness while ensuring GPU kernels pass memory verification for this path.
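The rewrite described above can be sketched in plain Python (this is an illustration, not actual TIR; `THREADS_PER_BLOCK` and the helper names are assumed for the example, not taken from the pass):

```python
# Sketch of the index mapping that BindParallelLoopsToThreads produces:
# a parallel loop of extent N becomes a blockIdx.x/threadIdx.x grid, with
# a bounds check guarding the body for non-divisible extents.
import math

THREADS_PER_BLOCK = 256  # assumed threadIdx.x extent for illustration

def simulate_bound_loop(extent, body):
    """Run body(i) for i in [0, extent) the way the rewritten kernel would."""
    num_blocks = math.ceil(extent / THREADS_PER_BLOCK)  # blockIdx.x extent
    for block_idx in range(num_blocks):
        for thread_idx in range(THREADS_PER_BLOCK):
            # Loop var is substituted with the global thread index:
            global_idx = block_idx * THREADS_PER_BLOCK + thread_idx
            # Bounds check inserted because extent need not divide evenly:
            if global_idx < extent:
                body(global_idx)

visited = []
simulate_bound_loop(1000, visited.append)  # 1000 is not divisible by 256
assert visited == list(range(1000))        # every index visited exactly once
```

The guard is what makes the rewrite safe when `extent % THREADS_PER_BLOCK != 0`: trailing threads in the last block simply do nothing.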
Code Review
This pull request introduces the `BindParallelLoopsToThreads` pass, which converts `ForKind::kParallel` loops into GPU block and thread bindings, and integrates this pass into the S-TIR pipelines. Additionally, it provides a configuration option to allow unsupported host compilers for NVCC on Windows and adds a functional test for scatter operations on CUDA. Review feedback identifies a critical issue regarding the handling of nested parallel loops, which could lead to invalid GPU register bindings, an inconsistency in GPU device type definitions between files, and a minor code redundancy in the loop variable substitution logic.
Fix three correctness/configuration issues in the GPU parallel-loop binding path used before `VerifyMemory`. First, preserve non-zero loop mins by mapping parallel indices as `min + global_idx` instead of `global_idx`. Second, avoid rewriting parallel loops when already inside a thread environment, to prevent invalid nested bindings. Third, register `cuda.nvcc_allow_unsupported_compiler` as a valid `PassContext` key so the NVCC workaround can be enabled via config without raising "Invalid config option".

Made-with: Cursor
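The first fix can be sketched in plain Python (illustrative only; the function name is ours, not the pass's):

```python
# Sketch of the min-preserving index mapping: a parallel loop with a
# non-zero min must visit [loop_min, loop_min + extent), so the loop
# variable is substituted with loop_min + global_idx, not global_idx.
def mapped_indices(loop_min, extent):
    """Indices a thread-bound loop should visit."""
    return [loop_min + global_idx for global_idx in range(extent)]

# A parallel loop over i in [5, 13) must still start at 5:
assert mapped_indices(5, 8) == [5, 6, 7, 8, 9, 10, 11, 12]
# With the bug (dropping the min), the kernel would wrongly visit [0, 8).
```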
- Add `kDLWebGPU` to `IsGPUDevice` in `verify_memory.cc`
- Remove redundant `Var` wrapper in `loop_partition.cc`
- Fix nested parallel loop handling in `bind_parallel_loops_to_threads.cc`
tlopex
left a comment
A few things to address here:
- Please remove `.tmp_scatter_cuda_check.py`. This looks like a local debugging script and should not be committed. It should be replaced with a proper pytest test instead.
- This PR currently mixes several unrelated changes: the NVCC `-allow-unsupported-compiler` workaround, the `loop_partition.cc` `Var{}` fix, and the `kDLWebGPU` change in `verify_memory.cc` are all independent and should be split into separate PRs.
- Please clarify the semantics for nested parallel loops. When `in_parallel_loop_` is already true, inner parallel loops are left as `kParallel`, but once they are inside a GPU kernel they effectively become serial. If that is intentional, it would be good to document it explicitly; otherwise it may be better to reject this case.
- The `else` branch in `IfThenElse` looks unnecessary. `Evaluate(IntImm(DataType::Int(32), 0))` is a no-op, and TIR already supports `IfThenElse` without an `else`, so this can just be `IfThenElse(global_idx < extent, mapped_body)`.
- For target-less `PrimFunc`s, please avoid falling back to `Target::Current()`. If the function has no target attribute, it is better to leave it unchanged rather than guessing from ambient context. That would also be consistent with how other recent passes handle target-less functions.
- This PR also needs proper pytest coverage for the new pass.