diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index eb5b2a2ceb..7726a79d9a 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -6738,6 +6738,13 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { if (existF->getFunctionType() != pFT) return nullptr; F = existF; + // Ensure attributes are set on existing functions. + if (OpProps.FuncAttr != Attribute::None && + !F->hasFnAttribute(OpProps.FuncAttr)) + F->addFnAttr(OpProps.FuncAttr); + // Mark wave ops as convergent since they depend on the active lane set. + if (IsDxilOpWave(opCode) && !F->hasFnAttribute(Attribute::Convergent)) + F->addFnAttr(Attribute::Convergent); UpdateCache(opClass, pOverloadType, F); return F; } @@ -6749,6 +6756,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { F->addFnAttr(Attribute::NoUnwind); if (OpProps.FuncAttr != Attribute::None) F->addFnAttr(OpProps.FuncAttr); + // Mark wave ops as convergent since they depend on the active lane set. + if (IsDxilOpWave(opCode)) + F->addFnAttr(Attribute::Convergent); return F; } diff --git a/lib/HLSL/DxilConvergent.cpp b/lib/HLSL/DxilConvergent.cpp index a96af39fd8..4848533c6b 100644 --- a/lib/HLSL/DxilConvergent.cpp +++ b/lib/HLSL/DxilConvergent.cpp @@ -44,14 +44,27 @@ class DxilConvergentMark : public ModulePass { bool runOnModule(Module &M) override { const ShaderModel *SM = M.GetOrCreateHLModule().GetShaderModel(); + + bool Updated = false; + + // Mark wave-sensitive HL functions as convergent. + // This prevents optimizer passes (especially JumpThreading) from + // restructuring control flow around wave ops, which would change + // the set of active lanes at wave op call sites. + for (Function &F : M.functions()) { + if (F.isDeclaration() && IsHLWaveSensitive(&F) && + !F.hasFnAttribute(Attribute::Convergent)) { + F.addFnAttr(Attribute::Convergent); + Updated = true; + } + } + // Can skip if in a shader and version that doesn't support derivatives. if (!SM->IsPS() && !SM->IsLib() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS()))) - return false; + return Updated; SupportsVectors = SM->IsSM69Plus(); - bool bUpdated = false; - for (Function &F : M.functions()) { if (F.isDeclaration()) continue; @@ -66,13 +79,13 @@ class DxilConvergentMark : public ModulePass { if (PropagateConvergent(V, &F, PDR)) { // TODO: emit warning here. } - bUpdated = true; + Updated = true; } } } } - return bUpdated; + return Updated; } private: diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index e4757b472e..e097f375ae 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -11,7 +11,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -24,6 +23,7 @@ #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/CallSite.h" // HLSL Change - for convergent call detection #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" @@ -33,6 +33,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -1388,6 +1389,47 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB, return false; } + // HLSL Change Begin - Don't thread through loop latch blocks when the loop + // body contains convergent calls (e.g., wave intrinsics). Threading through + // a latch can restructure the loop so that convergent calls that were inside + // the loop end up outside it, changing which lanes are active at those call + // sites on SIMT hardware. + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) { + BasicBlock *Header = *SI; + if (!LoopHeaders.count(Header)) + continue; + // BB is a loop latch (has a back-edge to Header). Walk backward from BB + // to find all blocks in the loop body and check for convergent calls. + SmallVector Worklist; + SmallPtrSet InLoop; + InLoop.insert(Header); // Seed to prevent going above header. + Worklist.push_back(BB); + bool HasConvergent = false; + while (!Worklist.empty() && !HasConvergent) { + BasicBlock *WBB = Worklist.pop_back_val(); + if (!InLoop.insert(WBB).second) + continue; + for (auto &I : *WBB) { + if (auto CS = CallSite(&I)) { + if (CS.hasFnAttr(Attribute::Convergent)) { + HasConvergent = true; + break; + } + } + } + if (!HasConvergent) + for (pred_iterator PI = pred_begin(WBB), PE = pred_end(WBB); PI != PE; + ++PI) + Worklist.push_back(*PI); + } + if (HasConvergent) { + DEBUG(dbgs() << " Not threading across loop latch BB '" << BB->getName() + << "' - loop body has convergent calls\n"); + return false; + } + } + // HLSL Change End + unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold); if (JumpThreadCost > BBDupThreshold) { DEBUG(dbgs() << " Not threading BB '" << BB->getName() diff --git a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl new file mode 100644 index 0000000000..a108c9b29d --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/convergent/wave-in-loop-not-sunk.hlsl @@ -0,0 +1,52 @@ +// RUN: %dxc -T cs_6_6 -E main %s | FileCheck %s + +// Regression test for a bug where the optimizer (JumpThreading) would +// restructure a while-loop containing wave intrinsics, moving +// WaveActiveCountBits outside the loop. This changes the set of active +// lanes at the wave op call site, producing incorrect results on SIMT +// hardware. + +// Verify that WaveAllBitCount (opcode 135) appears BEFORE the loop's +// back-edge phi, ensuring it stays inside the loop body. + +// CHECK: call i32 @dx.op.waveReadLaneFirst +// CHECK: call i32 @dx.op.waveAllOp +// CHECK: call i1 @dx.op.waveIsFirstLane +// CHECK: phi i32 +// CHECK: br i1 + +RWStructuredBuffer Output : register(u1); + +cbuffer Constants : register(b0) { + uint Width; + uint Height; + uint NumMaterials; +}; + +[numthreads(32, 1, 1)] +void main(uint3 DTid : SV_DispatchThreadID) { + uint x = DTid.x; + uint y = DTid.y; + + if (x >= Width || y >= Height) + return; + + // Compute a material ID per lane (simple hash). + uint materialID = ((x * 7) + (y * 13)) % NumMaterials; + + // Binning loop: each iteration peels off one material group. + // WaveReadLaneFirst picks a material, matching lanes count themselves + // with WaveActiveCountBits, and the first lane in the group writes + // the count. Non-matching lanes loop back for the next material. + bool go = true; + while (go) { + uint firstMat = WaveReadLaneFirst(materialID); + if (firstMat == materialID) { + uint count = WaveActiveCountBits(true); + if (WaveIsFirstLane()) { + InterlockedAdd(Output[firstMat], count); + } + go = false; + } + } +}