Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ class DxilMDHelper {
public:
// Utility functions.
static bool IsKnownNamedMetaData(const llvm::NamedMDNode &Node);
static bool IsKnownGeneratedMetaData(const llvm::NamedMDNode &Node);
static bool IsKnownMetadataID(llvm::LLVMContext &Ctx, unsigned ID);
static void GetKnownMetadataIDs(llvm::LLVMContext &Ctx,
llvm::SmallVectorImpl<unsigned> *pIDs);
Expand Down
2 changes: 2 additions & 0 deletions include/dxc/DXIL/DxilOperations.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class OP {
static bool CheckOpCodeTable();
static bool IsDxilOpFuncName(llvm::StringRef name);
static bool IsDxilOpFunc(const llvm::Function *F);
static bool IsDxilOpLinAlgFuncName(llvm::StringRef Name);
static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
static bool IsDxilOpWave(OpCode C);
Expand Down Expand Up @@ -286,6 +287,7 @@ class OP {
static const char *m_NamePrefix;
static const char *m_TypePrefix;
static const char *m_MatrixTypePrefix;
static const char *m_LinAlgNamePrefix;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dx.types version of this was put into DxilUtil in this PR https://github.com/microsoft/DirectXShaderCompiler/pull/8186/changes

Might make sense to have them next to each other?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is probably fine in either place, maybe a better fit to leave it here. The reason being that all of these are called *Prefix, they are all used in DXIL ops, m_NamePrefix is equal to "dx.op", so IMHO m_LinAlgNamePrefix with "dx.op.linAlg" fits here quite well.

static unsigned GetTypeSlot(llvm::Type *pType);
static const char *GetOverloadTypeName(unsigned TypeSlot);
static llvm::StringRef GetTypeName(llvm::Type *Ty,
Expand Down
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ llvm::Type *GetHLSLHitObjectType(llvm::Module *M);
bool IsHLSLHitObjectType(llvm::Type *Ty);
bool IsHLSLLinAlgMatrixType(llvm::Type *Ty);
llvm::StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty);
bool IsHLSLKnownTargetType(llvm::Type *Ty);
bool IsHLSLResourceDescType(llvm::Type *Ty);
bool IsResourceSingleComponent(llvm::Type *Ty);
uint8_t GetResourceComponentCount(llvm::Type *Ty);
Expand Down
3 changes: 3 additions & 0 deletions include/dxc/HLSL/DxilGenerationPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,7 @@ void initializeDxilSimpleGVNEliminateRegionPass(llvm::PassRegistry &);
ModulePass *createDxilModuleInitPass();
void initializeDxilModuleInitPass(llvm::PassRegistry &);

ModulePass *createDxilTrimTargetTypesPass();
void initializeDxilTrimTargetTypesPass(llvm::PassRegistry &);

} // namespace llvm
5 changes: 5 additions & 0 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3338,6 +3338,11 @@ bool DxilMDHelper::IsKnownNamedMetaData(const llvm::NamedMDNode &Node) {
return false;
}

// Returns true when Node is a known DXIL named metadata node other than the
// target-types node (kDxilTargetTypesMDName).
// NOTE(review): presumably the target-types node is excluded because it is
// not regenerated from the module state like the other known nodes — confirm.
bool DxilMDHelper::IsKnownGeneratedMetaData(const llvm::NamedMDNode &Node) {
  if (!IsKnownNamedMetaData(Node))
    return false;

  const bool IsTargetTypesNode =
      Node.getName() == DxilMDHelper::kDxilTargetTypesMDName;
  return !IsTargetTypesNode;
}

bool DxilMDHelper::IsKnownMetadataID(LLVMContext &Ctx, unsigned ID) {
SmallVector<unsigned, 2> IDs;
GetKnownMetadataIDs(Ctx, &IDs);
Expand Down
5 changes: 5 additions & 0 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3040,6 +3040,7 @@ const char *OP::m_OverloadTypeName[TS_BasicCount] = {
const char *OP::m_NamePrefix = "dx.op.";
const char *OP::m_TypePrefix = "dx.types.";
const char *OP::m_MatrixTypePrefix = "class.matrix."; // Allowed in library
// Name prefix for LinAlg DXIL op functions. Unlike the prefixes above it has
// no trailing '.', so a prefix test matches every name that begins with
// "dx.op.linAlg".
const char *OP::m_LinAlgNamePrefix = "dx.op.linAlg";

// Keep sync with DXIL::AtomicBinOpCode
static const char *AtomicBinOpCodeName[] = {
Expand Down Expand Up @@ -3306,6 +3307,10 @@ bool OP::IsDxilOpFuncName(StringRef name) {
return name.startswith(OP::m_NamePrefix);
}

// Returns true when Name begins with the LinAlg DXIL op name prefix
// (OP::m_LinAlgNamePrefix).
bool OP::IsDxilOpLinAlgFuncName(StringRef Name) {
  const StringRef LinAlgPrefix(OP::m_LinAlgNamePrefix);
  return Name.startswith(LinAlgPrefix);
}

bool OP::IsDxilOpFunc(const llvm::Function *F) {
// Test for null to allow IsDxilOpFunc(Call.getCalledFunc()) to be resilient
// to indirect calls
Expand Down
5 changes: 5 additions & 0 deletions lib/DXIL/DxilUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,11 @@ StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty) {
return Ty->getStructName().substr(strlen(DXIL::kDxLinAlgMatrixTypePrefix));
}

// Returns true when Ty is one of the known HLSL target types.
bool IsHLSLKnownTargetType(llvm::Type *Ty) {
  // LinAlgMatrix is the only kind of target type at present.
  const bool IsLinAlgMatrix = IsHLSLLinAlgMatrixType(Ty);
  return IsLinAlgMatrix;
}

bool IsHLSLResourceDescType(llvm::Type *Ty) {
if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
if (!ST->hasName())
Expand Down
3 changes: 2 additions & 1 deletion lib/HLSL/DxilLinker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ void DxilLinkJob::LinkNamedMDNodes(Module *pM, ValueToValueMapTy &vmap) {
if (&NMD == pSrcModFlags)
continue;
// Skip dxil metadata which will be regenerated.
if (DxilMDHelper::IsKnownNamedMetaData(NMD))
if (DxilMDHelper::IsKnownGeneratedMetaData(NMD))
continue;
NamedMDNode *DestNMD = pM->getOrInsertNamedMetadata(NMD.getName());
// Add Src elements into Dest node.
Expand Down Expand Up @@ -1293,6 +1293,7 @@ void DxilLinkJob::RunPreparePass(Module &M) {
PM.add(createComputeViewIdStatePass());
PM.add(createDxilDeadFunctionEliminationPass());
PM.add(createNoPausePassesPass());
PM.add(createDxilTrimTargetTypesPass());
PM.add(createDxilEmitMetadataPass());
PM.add(createDxilFinalizePreservesPass());

Expand Down
104 changes: 104 additions & 0 deletions lib/HLSL/DxilPreparePasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,110 @@ INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit",

namespace {

// DxilTrimTargetTypes pass makes sure the !dx.targetTypes metadata only
// contains types that are actually used by the shader.

class DxilTrimTargetTypes : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DxilTrimTargetTypes() : ModulePass(ID) {}

// Human-readable name of this pass.
StringRef getPassName() const override {
  return StringRef("HLSL DXIL Trim Target Types");
}

// Maps a target type to a pair made of the metadata tuple that describes the
// type (first) and a flag recording whether a use of the type has been seen
// (second).
using TargetTypesUsageMap =
SmallDenseMap<llvm::Type *, std::pair<MDTuple *, bool>, 16>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: do we actually need to track whether or not types are used, or should we just collect the list of types that are and replace the metadata?

It seems to me that re-computing the set of used types is probably faster than looking at the existing set and determining if it is correct.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we just collect the list of types that are used, we will need to parse the type name to reconstruct the metadata constants. I was under the impression that that's what we want to avoid.


// Sets the usage flag of Ty in Map.
//
// Every used target type is expected to already have an entry in Map. A
// missing entry trips the assert in asserts-enabled builds; when asserts are
// compiled out the call is a no-op rather than a write through the end
// iterator.
void markTargetTypeAsUsed(TargetTypesUsageMap &Map, llvm::Type *Ty) {
  auto It = Map.find(Ty);
  assert(It != Map.end() &&
         "used target type is not in dx.targetTypes metadata list");
  if (It == Map.end())
    return;
  It->second.second = true;
}

// Rewrites the module's target-types named metadata so that it lists only the
// target types that appear as the return type or as an argument type of a
// LinAlg DXIL op declaration.
//
// Surviving entries keep their original relative order so the emitted
// metadata does not depend on the iteration order of the usage map. Returns
// true when the metadata was modified, false when the module has no
// target-types metadata or nothing had to be trimmed.
bool runOnModule(Module &M) override {
  NamedMDNode *TargetTypesMDNode =
      M.getNamedMetadata(DxilMDHelper::kDxilTargetTypesMDName);
  if (!TargetTypesMDNode)
    return false;

  // Add all target types from the "dx.targetTypes" metadata to the map to
  // track their usage. DeclaredTypes records the order of first appearance,
  // because iterating the dense map directly yields an unspecified order.
  TargetTypesUsageMap TargetTypesMap;
  SmallVector<llvm::Type *, 16> DeclaredTypes;
  for (MDNode *Node : TargetTypesMDNode->operands()) {
    MDTuple *TypeMD = dyn_cast<MDTuple>(Node);
    if (!TypeMD || TypeMD->getNumOperands() == 0)
      continue;

    ConstantAsMetadata *ConstMD =
        dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0).get());
    if (!ConstMD)
      continue;

    // The first operand is a constant whose type is the target type.
    Constant *TypeUndefPtr = ConstMD->getValue();
    llvm::Type *Ty = TypeUndefPtr->getType();
    if (TargetTypesMap.try_emplace(Ty, std::make_pair(TypeMD, false)).second)
      DeclaredTypes.push_back(Ty);
  }

  // Scan all LinAlgMatrix functions and check the return type and argument
  // types to find all used target types.
  for (const llvm::Function &F : M.functions()) {
    if (!F.isDeclaration())
      continue;

    // Currently only LinAlgMatrix ops use target types.
    if (!OP::IsDxilOpLinAlgFuncName(F.getName()))
      continue;

    llvm::Type *RetTy = F.getReturnType();
    if (dxilutil::IsHLSLKnownTargetType(RetTy))
      markTargetTypeAsUsed(TargetTypesMap, RetTy);

    for (const auto &Arg : F.args()) {
      llvm::Type *Ty = Arg.getType();
      if (dxilutil::IsHLSLKnownTargetType(Ty))
        markTargetTypeAsUsed(TargetTypesMap, Ty);
    }
  }

  // Collect the metadata tuples of the used types in their original order.
  SmallVector<MDTuple *, 16> UsedNodes;
  for (llvm::Type *Ty : DeclaredTypes) {
    const std::pair<MDTuple *, bool> Entry = TargetTypesMap.lookup(Ty);
    if (Entry.second)
      UsedNodes.push_back(Entry.first);
  }

  // If every operand is a distinct, well-formed and used entry, the rebuilt
  // list would be identical to the existing one: leave the module untouched.
  if (!UsedNodes.empty() &&
      UsedNodes.size() == TargetTypesMDNode->getNumOperands())
    return false;

  // Remove old metadata node from the module.
  TargetTypesMDNode->eraseFromParent();

  // Create a new one with the used target types. When no target type is
  // used, the module is left without the named metadata node.
  if (!UsedNodes.empty()) {
    NamedMDNode *NewTargetTypesMDNode =
        M.getOrInsertNamedMetadata(DxilMDHelper::kDxilTargetTypesMDName);
    for (MDTuple *Node : UsedNodes)
      NewTargetTypesMDNode->addOperand(Node);
  }

  return true;
}
};

} // namespace

char DxilTrimTargetTypes::ID = 0;

// Factory for the DxilTrimTargetTypes pass; returns a newly allocated pass
// instance.
ModulePass *llvm::createDxilTrimTargetTypesPass() {
  ModulePass *TrimPass = new DxilTrimTargetTypes();
  return TrimPass;
}

INITIALIZE_PASS(DxilTrimTargetTypes, "hlsl-trim-target-types",
"HLSL DXIL Trim Target Types", false, false)

///////////////////////////////////////////////////////////////////////////////

namespace {

const StringRef UniNoWaveSensitiveGradientErrMsg =
"Gradient operations are not affected by wave-sensitive data or control "
"flow.";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// REQUIRES: dxil-1-10
// RUN: %dxc -T cs_6_10 %s | FileCheck %s

// This test is using 2 LinAlgMatrix operations:
// - __builtin_LinAlg_FillMatrix - has the matrix as a return value
// - __builtin_LinAlg_MatrixLength - has the matrix as an argument
// This is done to verify that target types are correctly collected from both
// return values and arguments of LinAlgMatrix operations.

uint useMatrix1() {
// Matrix<ComponentType::I32, 4, 5, MatrixUse::A, MatrixScope::Thread> m;
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 4, 5, 0, 0)]] mat1;
// mat1 = Matrix::Splat(5);
__builtin_LinAlg_FillMatrix(mat1, 5);

// Matrix<ComponentType::U32, 3, 3, MatrixUse::A, MatrixScope::Thread> m;
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 3, 3, 0, 0)]] mat2;
// return mat2.Length();
return __builtin_LinAlg_MatrixLength(mat2);
}

// Not called from main: the matrix type declared here is the one the test
// expects to be missing from the target types metadata.
uint useMatrix2() {
// Matrix<ComponentType::F64, 2, 2, MatrixUse::B, MatrixScope::Wave> m;
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10, 2, 2, 1, 1)]] mat3;
// mat3 = Matrix::Splat(5);
__builtin_LinAlg_FillMatrix(mat3, 5);
// return mat3.Length();
return __builtin_LinAlg_MatrixLength(mat3);
}

RWBuffer<uint> Out;

[numthreads(4,1,1)]
void main() {
Out[0] = useMatrix1();
}

// CHECK: !dx.targetTypes = !{!{{[0-9]+}}, !{{[0-9]+}}}
// CHECK: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0}
// CHECK: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC5M3N3U0S0 undef, i32 5, i32 3, i32 3, i32 0, i32 0}
// CHECK-NOT: !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// REQUIRES: dxil-1-10
// RUN: %dxc -T lib_6_x -DLIB1 %s | FileCheck %s --check-prefix=LIB1
// RUN: %dxc -T lib_6_x -DLIB1 -Fo %t.lib1.dxbc %s
// RUN: %dxc -T lib_6_x -DLIB2 %s | FileCheck %s --check-prefix=LIB2
// RUN: %dxc -T lib_6_x -DLIB2 -Fo %t.lib2.dxbc %s
// RUN: %dxl -T cs_6_10 -E CSMain1 "%t.lib1.dxbc;%t.lib2.dxbc" | FileCheck %s --check-prefix=CSMAIN1
// RUN: %dxl -T cs_6_10 -E CSMain2 "%t.lib1.dxbc;%t.lib2.dxbc" | FileCheck %s --check-prefix=CSMAIN2
// RUN: %dxl -T cs_6_10 -E CSMain3 "%t.lib1.dxbc;%t.lib2.dxbc" | FileCheck %s --check-prefix=CSMAIN3

uint useMatrix1();
uint useMatrix2();
void useMatrix3();

// This test is using 2 LinAlgMatrix operations:
// - __builtin_LinAlg_FillMatrix - has the matrix as a return value
// - __builtin_LinAlg_MatrixLength - has the matrix as an argument
// This is done to verify that target types are correctly collected from both
// return values and arguments of LinAlgMatrix operations.

// ---- lib1 source code --- //
#ifdef LIB1

uint useMatrix1() {
// Matrix<ComponentType::I32, 4, 5, MatrixUse::A, MatrixScope::Thread> m;
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 4, 5, 0, 0)]] mat1;
// mat1 = Matrix::Splat(5);
__builtin_LinAlg_FillMatrix(mat1, 5);
// return mat1.Length();
return __builtin_LinAlg_MatrixLength(mat1);
}

uint useMatrix2() {
// Matrix<ComponentType::F64, 2, 2, MatrixUse::B, MatrixScope::Wave> m;
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10, 2, 2, 1, 1)]] mat2;
// return mat2.Length();
return __builtin_LinAlg_MatrixLength(mat2);
}

#endif

// ---- lib2 source code --- //
#ifdef LIB2

void useMatrix3() {
//Matrix<ComponentType::U32, 6, 6, MatrixUse::Accumulator, MatrixScope::ThreadGroup> m;
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 6, 6, 2, 2)]] mat3;
// mat3 = Matrix::Splat(5);
__builtin_LinAlg_FillMatrix(mat3, 5);
}

RWBuffer<uint> Out;

[shader("compute")]
[numthreads(4,1,1)]
void CSMain1() {
// no matrix used
}

[shader("compute")]
[numthreads(4,1,1)]
void CSMain2() {
Out[0] = useMatrix1();
}

[shader("compute")]
[numthreads(4,1,1)]
void CSMain3() {
Out[0] = useMatrix2();
useMatrix3();
}

#endif

// Target types in lib1
// LIB1: !dx.targetTypes = !{![[TT1:.*]], ![[TT2:.*]]}
// LIB1: ![[TT1]] = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0}
// LIB1: ![[TT2]] = !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1}

// Target types in lib2
// LIB2: !dx.targetTypes = !{![[TT3:.*]]}
// LIB2: ![[TT3]] = !{%dx.types.LinAlgMatrixC5M6N6U2S2 undef, i32 5, i32 6, i32 6, i32 2, i32 2}

// Target types in final module (should be only those that are used)

// CSMain1 doesn't use any matrix, so the target types should be filtered out from the final module.
// CSMAIN1-NOT: !dx.targetTypes

// CSMain2 uses one type of matrix
// CSMAIN2: !dx.targetTypes = !{!{{[0-9]+}}}
// CSMAIN2: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0}

// CSMain3 uses two types of matrices
// CSMAIN3: !dx.targetTypes = !{!{{[0-9]+}}, !{{[0-9]+}}}
// CSMAIN3-DAG: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1}
// CSMAIN3-DAG: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC5M6N6U2S2 undef, i32 5, i32 6, i32 6, i32 2, i32 2}
// CSMAIN3-NOT: !{%dx.types.LinAlgMatrixC4M4N5U0S0