@@ -775,14 +775,12 @@ struct UniversalGemmKernel
775775 CK_TILE_DEVICE static auto
776776 MakeDefaultETensorDescriptor (const index_t M, const index_t N, const index_t stride)
777777 {
778- // TODO: enable vector write for C in ColMajor
779778 if constexpr (std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
780779 {
781- return make_naive_tensor_descriptor (
782- make_tuple (M, N), // arguments not matching with flatmm.
783- make_tuple (stride, 1 ),
784- number<EpiloguePipeline::GetVectorSizeC ()>{},
785- number<1 >{});
780+ return make_naive_tensor_descriptor (make_tuple (M, N),
781+ make_tuple (stride, 1 ),
782+ number<EpiloguePipeline::GetVectorSizeC ()>{},
783+ number<1 >{});
786784 }
787785 else
788786 {
@@ -791,26 +789,18 @@ struct UniversalGemmKernel
791789 }
792790 }
793791
792+ template <typename AsTensorDesc>
794793 CK_TILE_DEVICE static auto
795794 MakeABlockWindows (const std::array<const ADataType*, NumATensor>& as_ptr,
796- const KernelArgs& kargs,
797- const index_t k_size,
795+ const AsTensorDesc& as_desc,
798796 const index_t i_m)
799797 {
800- // Step 1: Create tensor descriptors for A tensors
801- const auto & as_tensor_desc = generate_tuple (
802- [&](auto i) {
803- using AiLayout = remove_cvref_t <std::tuple_element_t <i.value , AsLayout>>;
804- return MakeDefaultATensorDescriptor<AiLayout>(kargs.M , kargs.stride_As [i], k_size);
805- },
806- number<NumATensor>{});
807-
808798 // Step 1: Create tensor views
809799 const auto & as_tensor_view = generate_tuple (
810800 [&](auto i) {
811801 using AiDataType = remove_cvref_t <std::tuple_element_t <i.value , AsDataType>>;
812802 return make_tensor_view<address_space_enum::global>(
813- static_cast <const AiDataType*>(as_ptr[i]), as_tensor_desc [i]);
803+ static_cast <const AiDataType*>(as_ptr[i]), as_desc [i]);
814804 },
815805 number<NumATensor>{});
816806
@@ -860,30 +850,38 @@ struct UniversalGemmKernel
860850 }
861851
// Diff hunk: the '+' and context lines form the KernelArgs-based MakeABlockWindows
// overload; the '-' lines are the former MakeBBlockWindows preamble that this change
// removes from this position.
//
// MakeABlockWindows(as_ptr, kargs, k_size, i_m):
//   1. For each of the NumATensor entries, builds a descriptor with
//      MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size),
//      where AiLayout is the i-th element of AsLayout.
//   2. Forwards (as_ptr, as_tensor_desc, i_m) to the MakeABlockWindows overload that
//      takes already-built descriptors and returns its result.
862852 CK_TILE_DEVICE static auto
863- MakeBBlockWindows (const std::array<const BDataType *, NumBTensor >& bs_ptr ,
853+ MakeABlockWindows (const std::array<const ADataType *, NumATensor >& as_ptr ,
864854 const KernelArgs& kargs,
865855 const index_t k_size,
866- const index_t i_n )
856+ const index_t i_m )
867857 {
868- // Step 1: Create tensor descriptors for B tensors
869- const auto & bs_tensor_desc = generate_tuple (
858+ // Step 1: Create tensor descriptors for A tensors
859+ const auto & as_tensor_desc = generate_tuple (
870860 [&](auto i) {
871- using BiLayout = remove_cvref_t <std::tuple_element_t <i.value , BsLayout>>;
872- return MakeDefaultBTensorDescriptor<BiLayout>(
873- kargs.N , kargs.K , kargs.stride_Bs [i], k_size);
861+ using AiLayout = remove_cvref_t <std::tuple_element_t <i.value , AsLayout>>;
862+ return MakeDefaultATensorDescriptor<AiLayout>(kargs.M , kargs.stride_As [i], k_size);
874863 },
875- number<NumBTensor >{});
864+ number<NumATensor >{});
876865
877- // Step 2: Create tensor views
// Views, padding and tile windows are no longer built here; the call below hands the
// descriptors to the descriptor-taking overload.
866+ return MakeABlockWindows (as_ptr, as_tensor_desc, i_m);
867+ }
868+
869+ template <typename BsTensorDesc>
870+ CK_TILE_DEVICE static auto
871+ MakeBBlockWindows (const std::array<const BDataType*, NumBTensor>& bs_ptr,
872+ const BsTensorDesc& bs_desc,
873+ const index_t i_n)
874+ {
875+ // Step 1: Create tensor views
878876 const auto & bs_tensor_view = generate_tuple (
879877 [&](auto i) {
880878 using BiDataType = remove_cvref_t <std::tuple_element_t <i.value , BsDataType>>;
881879 return make_tensor_view<address_space_enum::global>(
882- static_cast <const BiDataType*>(bs_ptr[i]), bs_tensor_desc [i])
880+ static_cast <const BiDataType*>(bs_ptr[i]), bs_desc [i]);
883881 },
884882 number<NumBTensor>{});
885883
886- // Step 3 : Create padded views
884+ // Step 2 : Create padded views
887885 const auto & bs_pad_view = generate_tuple (
888886 [&](auto i) {
889887 using BiLayout = remove_cvref_t <std::tuple_element_t <i.value , BsLayout>>;
@@ -904,7 +902,7 @@ struct UniversalGemmKernel
904902 },
905903 number<NumBTensor>{});
906904
907- // Step 4 : Create tile windows
905+ // Step 3 : Create tile windows
908906 const auto & bs_block_window = generate_tuple (
909907 [&](auto i) {
910908 using BiLayout = remove_cvref_t <std::tuple_element_t <i.value , BsLayout>>;
@@ -940,30 +938,39 @@ struct UniversalGemmKernel
940938 return bs_block_window;
941939 }
942940
// Diff hunk: the '+' and context lines form the KernelArgs-based MakeBBlockWindows
// overload; the '-' lines are the former MakeDBlockWindows signature and descriptor
// setup that this change removes from this position.
//
// MakeBBlockWindows(bs_ptr, kargs, k_size, i_n):
//   1. For each of the NumBTensor entries, builds a descriptor with
//      MakeDefaultBTensorDescriptor<BiLayout>(kargs.N, kargs.K, kargs.stride_Bs[i], k_size),
//      where BiLayout is the i-th element of BsLayout.
//   2. Forwards (bs_ptr, bs_tensor_desc, i_n) to the MakeBBlockWindows overload that
//      takes already-built descriptors and returns its result.
943- CK_TILE_DEVICE static auto MakeDBlockWindows (const std::array<const void *, NumDTensor>& ds_ptr,
944- const KernelArgs& kargs,
945- const index_t i_m,
946- const index_t i_n)
941+ CK_TILE_DEVICE static auto
942+ MakeBBlockWindows (const std::array<const BDataType*, NumBTensor>& bs_ptr,
943+ const KernelArgs& kargs,
944+ const index_t k_size,
945+ const index_t i_n)
947946 {
948- // Step 1: Create tensor descriptors for D tensors
949- const auto & ds_tensor_desc = generate_tuple (
947+ const auto & bs_tensor_desc = generate_tuple (
950948 [&](auto i) {
951- using DiLayout = remove_cvref_t <std::tuple_element_t <i.value , DsLayout >>;
952- return MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD (i) >(
953- kargs.M , kargs.N , kargs.stride_Ds [i]);
949+ using BiLayout = remove_cvref_t <std::tuple_element_t <i.value , BsLayout >>;
950+ return MakeDefaultBTensorDescriptor<BiLayout >(
951+ kargs.N , kargs.K , kargs.stride_Bs [i], k_size );
954952 },
955- number<NumDTensor >{});
953+ number<NumBTensor >{});
956954
957- // Step 2: Create tensor views
// Views, padding and tile windows are no longer built here; the call below hands the
// descriptors to the descriptor-taking overload.
955+ return MakeBBlockWindows (bs_ptr, bs_tensor_desc, i_n);
956+ }
957+
958+ template <typename DsTensorDesc>
959+ CK_TILE_DEVICE static auto MakeDBlockWindows (const std::array<const void *, NumDTensor>& ds_ptr,
960+ const DsTensorDesc& ds_desc,
961+ const index_t i_m,
962+ const index_t i_n)
963+ {
964+ // Step 1: Create tensor views
958965 const auto & ds_tensor_view = generate_tuple (
959966 [&](auto i) {
960967 using DDataType_ = remove_cvref_t <std::tuple_element_t <i.value , DsDataType>>;
961968 return make_tensor_view<address_space_enum::global>(
962- static_cast <const DDataType_*>(ds_ptr[i]), ds_tensor_desc [i]);
969+ static_cast <const DDataType_*>(ds_ptr[i]), ds_desc [i]);
963970 },
964971 number<NumDTensor>{});
965972
966- // Step 3 : Create padded views
973+ // Step 2 : Create padded views
967974 const auto & ds_pad_view = generate_tuple (
968975 [&](auto i) {
969976 using DiLayout = remove_cvref_t <std::tuple_element_t <i.value , DsLayout>>;
@@ -984,7 +991,7 @@ struct UniversalGemmKernel
984991 },
985992 number<NumDTensor>{});
986993
987- // Step 4 : Create tile windows
994+ // Step 3 : Create tile windows
988995 const auto & ds_block_window = generate_tuple (
989996 [&](auto i) {
990997 using DiLayout = remove_cvref_t <std::tuple_element_t <i.value , DsLayout>>;
@@ -1008,18 +1015,32 @@ struct UniversalGemmKernel
10081015 return ds_block_window;
10091016 }
10101017
// Diff hunk: the '+' and context lines form the KernelArgs-based MakeDBlockWindows
// overload; the '-' lines are the former MakeCBlockWindows template header and E/C
// descriptor setup that this change removes from this position.
//
// MakeDBlockWindows(ds_ptr, kargs, i_m, i_n):
//   1. For each of the NumDTensor entries, builds a descriptor with
//      MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
//      kargs.M, kargs.N, kargs.stride_Ds[i]), where DiLayout is the i-th element of DsLayout.
//   2. Forwards (ds_ptr, ds_tensor_desc, i_m, i_n) to the MakeDBlockWindows overload that
//      takes already-built descriptors and returns its result.
1011- template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1012- CK_TILE_DEVICE static auto MakeCBlockWindows (EDataType* e_ptr,
1018+ CK_TILE_DEVICE static auto MakeDBlockWindows (const std::array<const void *, NumDTensor>& ds_ptr,
10131019 const KernelArgs& kargs,
10141020 const index_t i_m,
10151021 const index_t i_n)
10161022 {
1017- // Step 1: Create tensor descriptor for E/C tensor
1018- const auto & e_tensor_desc = MakeDefaultETensorDescriptor (kargs.M , kargs.N , kargs.stride_E );
1023+ const auto & ds_tensor_desc = generate_tuple (
1024+ [&](auto i) {
1025+ using DiLayout = remove_cvref_t <std::tuple_element_t <i.value , DsLayout>>;
1026+ return MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD (i)>(
1027+ kargs.M , kargs.N , kargs.stride_Ds [i]);
1028+ },
1029+ number<NumDTensor>{});
1030+
// Hand the per-D descriptors to the descriptor-taking overload.
1031+ return MakeDBlockWindows (ds_ptr, ds_tensor_desc, i_m, i_n);
1032+ }
10191033
1020- // Step 1: Create tensor view
1034+ template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename ETensorDesc>
1035+ CK_TILE_DEVICE static auto MakeCBlockWindows (
1036+ EDataType* e_ptr,
1037+ const index_t i_m,
1038+ const index_t i_n,
1039+ const ETensorDesc& e_desc) // Argument order differs from A,B,D to disambiguate overloads
1040+ {
1041+ // Step 1: Create tensor view for E/C tensor
10211042 const auto & e_tensor_view =
1022- make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_tensor_desc );
1043+ make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_desc );
10231044
10241045 // Step 2: Create padded view
10251046 const auto & e_pad_view = [&]() {
@@ -1048,6 +1069,17 @@ struct UniversalGemmKernel
10481069 return e_block_window;
10491070 }
10501071
// Added by this change: KernelArgs-based MakeCBlockWindows overload.
//
// MakeCBlockWindows<DstInMemOp>(e_ptr, kargs, i_m, i_n):
//   1. Builds the E/C descriptor with
//      MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E).
//   2. Forwards to the descriptor-taking MakeCBlockWindows<DstInMemOp> overload and
//      returns its result. DstInMemOp defaults to memory_operation_enum::set.
1072+ template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1073+ CK_TILE_DEVICE static auto MakeCBlockWindows (EDataType* e_ptr,
1074+ const KernelArgs& kargs,
1075+ const index_t i_m,
1076+ const index_t i_n)
1077+ {
1078+
1079+ const auto & e_tensor_desc = MakeDefaultETensorDescriptor (kargs.M , kargs.N , kargs.stride_E );
// Note the argument order: the descriptor-taking overload receives the descriptor last,
// i.e. (e_ptr, i_m, i_n, e_tensor_desc).
1080+ return MakeCBlockWindows<DstInMemOp>(e_ptr, i_m, i_n, e_tensor_desc);
1081+ }
1082+
10511083 /* *
10521084 * @brief Runs single GEMM problem cooperatively by whole workgroup.
10531085 *
0 commit comments