Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6738,6 +6738,13 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
if (existF->getFunctionType() != pFT)
return nullptr;
F = existF;
// Ensure attributes are set on existing functions.
if (OpProps.FuncAttr != Attribute::None &&
!F->hasFnAttribute(OpProps.FuncAttr))
F->addFnAttr(OpProps.FuncAttr);
// Mark wave ops as convergent since they depend on the active lane set.
if (IsDxilOpWave(opCode) && !F->hasFnAttribute(Attribute::Convergent))
F->addFnAttr(Attribute::Convergent);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change modifies attributes that show up in final DXIL. We should be cautious with changes like that, since it could break drivers not expecting it. We need to decide whether we want to expose the convergent attribute in next DXIL version and probably filter attributes on DXIL ops for prior shader models in DxilFinalizeModule (defined in DxilPreparePasses.cpp).

UpdateCache(opClass, pOverloadType, F);
return F;
}
Expand All @@ -6749,6 +6756,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
F->addFnAttr(Attribute::NoUnwind);
if (OpProps.FuncAttr != Attribute::None)
F->addFnAttr(OpProps.FuncAttr);
// Mark wave ops as convergent since they depend on the active lane set.
if (IsDxilOpWave(opCode))
F->addFnAttr(Attribute::Convergent);

return F;
}
Expand Down
23 changes: 18 additions & 5 deletions lib/HLSL/DxilConvergent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,27 @@ class DxilConvergentMark : public ModulePass {

bool runOnModule(Module &M) override {
const ShaderModel *SM = M.GetOrCreateHLModule().GetShaderModel();

bool Updated = false;

// Mark wave-sensitive HL functions as convergent.
// This prevents optimizer passes (especially JumpThreading) from
// restructuring control flow around wave ops, which would change
// the set of active lanes at wave op call sites.
for (Function &F : M.functions()) {
if (F.isDeclaration() && IsHLWaveSensitive(&F) &&
!F.hasFnAttribute(Attribute::Convergent)) {
F.addFnAttr(Attribute::Convergent);
Updated = true;
}
}

// Can skip if in a shader and version that doesn't support derivatives.
if (!SM->IsPS() && !SM->IsLib() &&
(!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
return false;
return Updated;
SupportsVectors = SM->IsSM69Plus();

bool bUpdated = false;

for (Function &F : M.functions()) {
if (F.isDeclaration())
continue;
Expand All @@ -66,13 +79,13 @@ class DxilConvergentMark : public ModulePass {
if (PropagateConvergent(V, &F, PDR)) {
// TODO: emit warning here.
}
bUpdated = true;
Updated = true;
}
}
}
}

return bUpdated;
return Updated;
}

private:
Expand Down
44 changes: 43 additions & 1 deletion lib/Transforms/Scalar/JumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
Expand All @@ -24,6 +23,7 @@
#include "llvm/Analysis/LazyValueInfo.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/CallSite.h" // HLSL Change - for convergent call detection
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
Expand All @@ -33,6 +33,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
Expand Down Expand Up @@ -1388,6 +1389,47 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
return false;
}

// HLSL Change Begin - Don't thread through loop latch blocks when the loop
// body contains convergent calls (e.g., wave intrinsics). Threading through
// a latch can restructure the loop so that convergent calls that were inside
// the loop end up outside it, changing which lanes are active at those call
// sites on SIMT hardware.
for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) {
BasicBlock *Header = *SI;
if (!LoopHeaders.count(Header))
continue;
// BB is a loop latch (has a back-edge to Header). Walk backward from BB
// to find all blocks in the loop body and check for convergent calls.
SmallVector<BasicBlock *, 16> Worklist;
SmallPtrSet<BasicBlock *, 16> InLoop;
InLoop.insert(Header); // Seed to prevent going above header.
Worklist.push_back(BB);
bool HasConvergent = false;
while (!Worklist.empty() && !HasConvergent) {
BasicBlock *WBB = Worklist.pop_back_val();
if (!InLoop.insert(WBB).second)
continue;
for (auto &I : *WBB) {
if (auto CS = CallSite(&I)) {
if (CS.hasFnAttr(Attribute::Convergent)) {
HasConvergent = true;
break;
}
}
}
if (!HasConvergent)
for (pred_iterator PI = pred_begin(WBB), PE = pred_end(WBB); PI != PE;
++PI)
Worklist.push_back(*PI);
}
if (HasConvergent) {
DEBUG(dbgs() << " Not threading across loop latch BB '" << BB->getName()
<< "' - loop body has convergent calls\n");
return false;
}
}
// HLSL Change End

unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
if (JumpThreadCost > BBDupThreshold) {
DEBUG(dbgs() << " Not threading BB '" << BB->getName()
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a pass test (under DXC/Passes) if you're testing the JumpThreading pass. Plus this test is under the older TAEF test suite, and isn't run as a lit shell test. It would be better to keep additions under the lit shell test locations. For a test compiling HLSL to DXIL, I think that would be under CodeGenDXIL.

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: %dxc -T cs_6_6 -E main %s | FileCheck %s

// Regression test for a bug where the optimizer (JumpThreading) would
// restructure a while-loop containing wave intrinsics, moving
// WaveActiveCountBits outside the loop. This changes the set of active
// lanes at the wave op call site, producing incorrect results on SIMT
// hardware.

// Verify that WaveAllBitCount (opcode 135) appears BEFORE the loop's
// back-edge phi, ensuring it stays inside the loop body.

// CHECK: call i32 @dx.op.waveReadLaneFirst
// CHECK: call i32 @dx.op.waveAllOp
// CHECK: call i1 @dx.op.waveIsFirstLane
// CHECK: phi i32
// CHECK: br i1

RWStructuredBuffer<uint> Output : register(u1);

cbuffer Constants : register(b0) {
uint Width;
uint Height;
uint NumMaterials;
};

[numthreads(32, 1, 1)]
void main(uint3 DTid : SV_DispatchThreadID) {
uint x = DTid.x;
uint y = DTid.y;

if (x >= Width || y >= Height)
return;

// Compute a material ID per lane (simple hash).
uint materialID = ((x * 7) + (y * 13)) % NumMaterials;

// Binning loop: each iteration peels off one material group.
// WaveReadLaneFirst picks a material, matching lanes count themselves
// with WaveActiveCountBits, and the first lane in the group writes
// the count. Non-matching lanes loop back for the next material.
bool go = true;
while (go) {
uint firstMat = WaveReadLaneFirst(materialID);
if (firstMat == materialID) {
uint count = WaveActiveCountBits(true);
if (WaveIsFirstLane()) {
InterlockedAdd(Output[firstMat], count);
}
go = false;
}
}
}