diff --git a/include/dxc/DXIL/DxilMetadataHelper.h b/include/dxc/DXIL/DxilMetadataHelper.h index 76df69555c..98c4c7898a 100644 --- a/include/dxc/DXIL/DxilMetadataHelper.h +++ b/include/dxc/DXIL/DxilMetadataHelper.h @@ -648,6 +648,7 @@ class DxilMDHelper { public: // Utility functions. static bool IsKnownNamedMetaData(const llvm::NamedMDNode &Node); + static bool IsKnownGeneratedMetaData(const llvm::NamedMDNode &Node); static bool IsKnownMetadataID(llvm::LLVMContext &Ctx, unsigned ID); static void GetKnownMetadataIDs(llvm::LLVMContext &Ctx, llvm::SmallVectorImpl *pIDs); diff --git a/include/dxc/DXIL/DxilOperations.h b/include/dxc/DXIL/DxilOperations.h index fabf07ee14..bab4bffc6e 100644 --- a/include/dxc/DXIL/DxilOperations.h +++ b/include/dxc/DXIL/DxilOperations.h @@ -145,6 +145,7 @@ class OP { static bool CheckOpCodeTable(); static bool IsDxilOpFuncName(llvm::StringRef name); static bool IsDxilOpFunc(const llvm::Function *F); + static bool IsDxilOpLinAlgFuncName(llvm::StringRef Name); static bool IsDxilOpFuncCallInst(const llvm::Instruction *I); static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode); static bool IsDxilOpWave(OpCode C); @@ -286,6 +287,7 @@ class OP { static const char *m_NamePrefix; static const char *m_TypePrefix; static const char *m_MatrixTypePrefix; + static const char *m_LinAlgNamePrefix; static unsigned GetTypeSlot(llvm::Type *pType); static const char *GetOverloadTypeName(unsigned TypeSlot); static llvm::StringRef GetTypeName(llvm::Type *Ty, diff --git a/include/dxc/DXIL/DxilUtil.h b/include/dxc/DXIL/DxilUtil.h index 82e3a1c16b..815bce8bbf 100644 --- a/include/dxc/DXIL/DxilUtil.h +++ b/include/dxc/DXIL/DxilUtil.h @@ -166,6 +166,7 @@ llvm::Type *GetHLSLHitObjectType(llvm::Module *M); bool IsHLSLHitObjectType(llvm::Type *Ty); bool IsHLSLLinAlgMatrixType(llvm::Type *Ty); llvm::StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty); +bool IsHLSLKnownTargetType(llvm::Type *Ty); bool IsHLSLResourceDescType(llvm::Type *Ty); bool IsResourceSingleComponent(llvm::Type *Ty); uint8_t GetResourceComponentCount(llvm::Type *Ty); diff --git a/include/dxc/HLSL/DxilGenerationPass.h b/include/dxc/HLSL/DxilGenerationPass.h index 7348bbd4d6..9f929e9d8b 100644 --- a/include/dxc/HLSL/DxilGenerationPass.h +++ b/include/dxc/HLSL/DxilGenerationPass.h @@ -151,4 +151,7 @@ void initializeDxilSimpleGVNEliminateRegionPass(llvm::PassRegistry &); ModulePass *createDxilModuleInitPass(); void initializeDxilModuleInitPass(llvm::PassRegistry &); +ModulePass *createDxilTrimTargetTypesPass(); +void initializeDxilTrimTargetTypesPass(llvm::PassRegistry &); + } // namespace llvm diff --git a/lib/DXIL/DxilMetadataHelper.cpp b/lib/DXIL/DxilMetadataHelper.cpp index a656b78f71..9598fa3da4 100644 --- a/lib/DXIL/DxilMetadataHelper.cpp +++ b/lib/DXIL/DxilMetadataHelper.cpp @@ -3338,6 +3338,11 @@ bool DxilMDHelper::IsKnownNamedMetaData(const llvm::NamedMDNode &Node) { return false; } +bool DxilMDHelper::IsKnownGeneratedMetaData(const llvm::NamedMDNode &Node) { + return IsKnownNamedMetaData(Node) && + Node.getName() != DxilMDHelper::kDxilTargetTypesMDName; +} + bool DxilMDHelper::IsKnownMetadataID(LLVMContext &Ctx, unsigned ID) { SmallVector IDs; GetKnownMetadataIDs(Ctx, &IDs); diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 07e749cb64..89e0962de2 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -3040,6 +3040,7 @@ const char *OP::m_OverloadTypeName[TS_BasicCount] = { const char *OP::m_NamePrefix = "dx.op."; const char *OP::m_TypePrefix = "dx.types."; const char *OP::m_MatrixTypePrefix = "class.matrix."; // Allowed in library +const char *OP::m_LinAlgNamePrefix = "dx.op.linAlg"; // Keep sync with DXIL::AtomicBinOpCode static const char *AtomicBinOpCodeName[] = { @@ -3306,6 +3307,10 @@ bool OP::IsDxilOpFuncName(StringRef name) { return name.startswith(OP::m_NamePrefix); } +bool OP::IsDxilOpLinAlgFuncName(StringRef Name) { + return Name.startswith(OP::m_LinAlgNamePrefix); +} + bool OP::IsDxilOpFunc(const llvm::Function *F) { // Test for null to allow IsDxilOpFunc(Call.getCalledFunc()) to be resilient // to indirect calls diff --git a/lib/DXIL/DxilUtil.cpp b/lib/DXIL/DxilUtil.cpp index fb7d68d73a..c8b332107a 100644 --- a/lib/DXIL/DxilUtil.cpp +++ b/lib/DXIL/DxilUtil.cpp @@ -631,6 +631,11 @@ StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty) { return Ty->getStructName().substr(strlen(DXIL::kDxLinAlgMatrixTypePrefix)); } +bool IsHLSLKnownTargetType(llvm::Type *Ty) { + // Currently only LinAlgMatrix types are target types. + return IsHLSLLinAlgMatrixType(Ty); +} + bool IsHLSLResourceDescType(llvm::Type *Ty) { if (llvm::StructType *ST = dyn_cast(Ty)) { if (!ST->hasName()) diff --git a/lib/HLSL/DxilLinker.cpp b/lib/HLSL/DxilLinker.cpp index ea10b23a4c..1ac0ccf106 100644 --- a/lib/HLSL/DxilLinker.cpp +++ b/lib/HLSL/DxilLinker.cpp @@ -575,7 +575,7 @@ void DxilLinkJob::LinkNamedMDNodes(Module *pM, ValueToValueMapTy &vmap) { if (&NMD == pSrcModFlags) continue; // Skip dxil metadata which will be regenerated. - if (DxilMDHelper::IsKnownNamedMetaData(NMD)) + if (DxilMDHelper::IsKnownGeneratedMetaData(NMD)) continue; NamedMDNode *DestNMD = pM->getOrInsertNamedMetadata(NMD.getName()); // Add Src elements into Dest node. @@ -1293,6 +1293,7 @@ void DxilLinkJob::RunPreparePass(Module &M) { PM.add(createComputeViewIdStatePass()); PM.add(createDxilDeadFunctionEliminationPass()); PM.add(createNoPausePassesPass()); + PM.add(createDxilTrimTargetTypesPass()); PM.add(createDxilEmitMetadataPass()); PM.add(createDxilFinalizePreservesPass()); diff --git a/lib/HLSL/DxilPreparePasses.cpp b/lib/HLSL/DxilPreparePasses.cpp index 68da520984..01fc2bacc3 100644 --- a/lib/HLSL/DxilPreparePasses.cpp +++ b/lib/HLSL/DxilPreparePasses.cpp @@ -1641,6 +1641,110 @@ INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit", namespace { +// DxilTrimTargetTypes pass makes sure the !dx.targetTypes metadata only +// contains types that are actually used by the shader. + +class DxilTrimTargetTypes : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + explicit DxilTrimTargetTypes() : ModulePass(ID) {} + + StringRef getPassName() const override { + return "HLSL DXIL Trim Target Types"; + } + + // Map of target type to its metadata node and usage flag. + using TargetTypesUsageMap = + SmallDenseMap, 16>; + + void markTargetTypeAsUsed(TargetTypesUsageMap &Map, llvm::Type *Ty) { + auto It = Map.find(Ty); + assert(It != Map.end() && + "used target type is not in dx.targetTypes metadata list"); + (*It).second.second = true; + } + + bool runOnModule(Module &M) override { + NamedMDNode *TargetTypesMDNode = + M.getNamedMetadata(DxilMDHelper::kDxilTargetTypesMDName); + if (!TargetTypesMDNode) + return false; + + // Add all target types that from "dx.targetTypes" metadata to the map + // to track their usage. + TargetTypesUsageMap TargetTypesMap; + for (MDNode *Node : TargetTypesMDNode->operands()) { + MDTuple *TypeMD = dyn_cast(Node); + if (!TypeMD || TypeMD->getNumOperands() == 0) + continue; + + ConstantAsMetadata *ConstMD = + dyn_cast(TypeMD->getOperand(0).get()); + if (!ConstMD) + continue; + + Constant *TypeUndefPtr = ConstMD->getValue(); + llvm::Type *Ty = TypeUndefPtr->getType(); + TargetTypesMap.try_emplace(Ty, std::make_pair(TypeMD, false)); + } + + // Scan all LinAlgMatrix functions and check the return type and argument + // types to find all used target types. + for (const llvm::Function &F : M.functions()) { + if (!F.isDeclaration()) + continue; + + // Currently only LinAlgMatrix ops use target types. + if (!OP::IsDxilOpLinAlgFuncName(F.getName())) + continue; + + llvm::Type *RetTy = F.getReturnType(); + if (dxilutil::IsHLSLKnownTargetType(RetTy)) + markTargetTypeAsUsed(TargetTypesMap, RetTy); + + for (const auto &Arg : F.args()) { + llvm::Type *Ty = Arg.getType(); + if (dxilutil::IsHLSLKnownTargetType(Ty)) + markTargetTypeAsUsed(TargetTypesMap, Ty); + } + } + + // Remove old metadata node from the module. + TargetTypesMDNode->eraseFromParent(); + + // Create a new one with the used target types. + NamedMDNode *NewTargetTypesMDNode = + M.getOrInsertNamedMetadata(DxilMDHelper::kDxilTargetTypesMDName); + for (auto &Entry : TargetTypesMap) { + MDTuple *Node = Entry.second.first; + bool IsUsed = Entry.second.second; + if (IsUsed) + NewTargetTypesMDNode->addOperand(Node); + } + + // If no target type is used, remove the new metadata node from module. + if (NewTargetTypesMDNode->getNumOperands() == 0) + NewTargetTypesMDNode->eraseFromParent(); + + return true; + } +}; + +} // namespace + +char DxilTrimTargetTypes::ID = 0; + +ModulePass *llvm::createDxilTrimTargetTypesPass() { + return new DxilTrimTargetTypes(); +} + +INITIALIZE_PASS(DxilTrimTargetTypes, "hlsl-trim-target-types", + "HLSL DXIL Trim Target Types", false, false) + +/////////////////////////////////////////////////////////////////////////////// + +namespace { + const StringRef UniNoWaveSensitiveGradientErrMsg = "Gradient operations are not affected by wave-sensitive data or control " "flow."; diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/trim-target-types-metadata-compute.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/trim-target-types-metadata-compute.hlsl new file mode 100644 index 0000000000..4d3f653749 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/trim-target-types-metadata-compute.hlsl @@ -0,0 +1,40 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T cs_6_10 %s | FileCheck %s + +// This test is using 2 LinAlgMatrix operations: +// - __builtin_LinAlg_FillMatrix - has the matrix as a return value +// - __builtin_LinAlg_MatrixLength - has the matrix as an argument +// This is done to verify that target types are correctly collected from both +// return values and arguments of LinAlgMatrix operations. + +uint useMatrix1() { + // Matrix m; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 4, 5, 0, 0)]] mat1; + // mat1 = Matrix::Splat(5); + __builtin_LinAlg_FillMatrix(mat1, 5); + + // Matrix m; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 3, 3, 0, 0)]] mat2; + // mat2 = Matrix::Splat(5); + return __builtin_LinAlg_MatrixLength(mat2); +} + +uint useMatrix2() { + // Matrix m; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10, 2, 2, 1, 1)]] mat3; + // mat3 = Matrix::Splat(5); + __builtin_LinAlg_FillMatrix(mat3, 5); + return __builtin_LinAlg_MatrixLength(mat3); +} + +RWBuffer Out; + +[numthreads(4,1,1)] +void main() { + Out[0] = useMatrix1(); +} + +// CHECK: !dx.targetTypes = !{!{{[0-9]+}}, !{{[0-9]+}}} +// CHECK: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0} +// CHECK: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC5M3N3U0S0 undef, i32 5, i32 3, i32 3, i32 0, i32 0} +// CHECK-NOT: !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/trim-target-types-metadata-lib.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/trim-target-types-metadata-lib.hlsl new file mode 100644 index 0000000000..a6a9fa82c2 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/trim-target-types-metadata-lib.hlsl @@ -0,0 +1,96 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T lib_6_x -DLIB1 %s | FileCheck %s --check-prefix=LIB1 +// RUN: %dxc -T lib_6_x -DLIB1 -Fo %t.lib1.dxbc %s +// RUN: %dxc -T lib_6_x -DLIB2 %s | FileCheck %s --check-prefix=LIB2 +// RUN: %dxc -T lib_6_x -DLIB2 -Fo %t.lib2.dxbc %s +// RUN: %dxl -T cs_6_10 -E CSMain1 "%t.lib1.dxbc;%t.lib2.dxbc" | FileCheck %s --check-prefix=CSMAIN1 +// RUN: %dxl -T cs_6_10 -E CSMain2 "%t.lib1.dxbc;%t.lib2.dxbc" | FileCheck %s --check-prefix=CSMAIN2 +// RUN: %dxl -T cs_6_10 -E CSMain3 "%t.lib1.dxbc;%t.lib2.dxbc" | FileCheck %s --check-prefix=CSMAIN3 + +uint useMatrix1(); +uint useMatrix2(); +void useMatrix3(); + +// This test is using 2 LinAlgMatrix operations: +// - __builtin_LinAlg_FillMatrix - has the matrix as a return value +// - __builtin_LinAlg_MatrixLength - has the matrix as an argument +// This is done to verify that target types are correctly collected from both +// return values and arguments of LinAlgMatrix operations. + +// ---- lib1 source code --- // +#ifdef LIB1 + +uint useMatrix1() { + // Matrix m; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 4, 5, 0, 0)]] mat1; + // mat1 = Matrix::Splat(5); + __builtin_LinAlg_FillMatrix(mat1, 5); + // return mat1.Length(); + return __builtin_LinAlg_MatrixLength(mat1); +} + +uint useMatrix2() { + // Matrix m; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10, 2, 2, 1, 1)]] mat2; + // return mat2.Length(); + return __builtin_LinAlg_MatrixLength(mat2); +} + +#endif + +// ---- lib2 source code --- // +#ifdef LIB2 + +void useMatrix3() { + //Matrix m; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 6, 6, 2, 2)]] mat3; + // mat3 = Matrix::Splat(5); + __builtin_LinAlg_FillMatrix(mat3, 5); +} + +RWBuffer Out; + +[shader("compute")] +[numthreads(4,1,1)] +void CSMain1() { + // no matrix used +} + +[shader("compute")] +[numthreads(4,1,1)] +void CSMain2() { + Out[0] = useMatrix1(); +} + +[shader("compute")] +[numthreads(4,1,1)] +void CSMain3() { + Out[0] = useMatrix2(); + useMatrix3(); +} + +#endif + +// Target types in lib1 +// LIB1: !dx.targetTypes = !{![[TT1:.*]], ![[TT2:.*]]} +// LIB1: ![[TT1]] = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0} +// LIB1: ![[TT2]] = !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1} + +// Target types in lib2 +// LIB2: !dx.targetTypes = !{![[TT3:.*]]} +// LIB2: ![[TT3]] = !{%dx.types.LinAlgMatrixC5M6N6U2S2 undef, i32 5, i32 6, i32 6, i32 2, i32 2} + +// Target types in final module (should be only those that are used) + +// CSMain1 doesn't use any matrix, so the target types should be filtered out from the final module. +// CSMAIN1-NOT: !dx.targetTypes + +// CSMain2 uses one type of matrix +// CSMAIN2: !dx.targetTypes = !{!{{[0-9]+}}} +// CSMAIN2: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC4M4N5U0S0 undef, i32 4, i32 4, i32 5, i32 0, i32 0} + +// CSMain3 uses two types of matrices +// CSMAIN3: !dx.targetTypes = !{!{{[0-9]+}}, !{{[0-9]+}}} +// CSMAIN3-DAG: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC10M2N2U1S1 undef, i32 10, i32 2, i32 2, i32 1, i32 1} +// CSMAIN3-DAG: !{{[0-9]+}} = !{%dx.types.LinAlgMatrixC5M6N6U2S2 undef, i32 5, i32 6, i32 6, i32 2, i32 2} +// CSMAIN3-NOT: !{%dx.types.LinAlgMatrixC10M2N2U1S1