Skip to content

Commit 7675308

Browse files
committed
Add overloads of MakeA/B/C/DBlockWindows that accept descriptors
This adds overloaded versions of the block window creation functions that allow the caller to specify explicit descriptors instead of the default ones, and reimplements the existing definitions by calling the new ones using default descriptors.
1 parent fc38355 commit 7675308

1 file changed

Lines changed: 81 additions & 49 deletions

File tree

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -775,14 +775,12 @@ struct UniversalGemmKernel
775775
CK_TILE_DEVICE static auto
776776
MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride)
777777
{
778-
// TODO: enable vector write for C in ColMajor
779778
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
780779
{
781-
return make_naive_tensor_descriptor(
782-
make_tuple(M, N), // arguments not matching with flatmm.
783-
make_tuple(stride, 1),
784-
number<EpiloguePipeline::GetVectorSizeC()>{},
785-
number<1>{});
780+
return make_naive_tensor_descriptor(make_tuple(M, N),
781+
make_tuple(stride, 1),
782+
number<EpiloguePipeline::GetVectorSizeC()>{},
783+
number<1>{});
786784
}
787785
else
788786
{
@@ -791,26 +789,18 @@ struct UniversalGemmKernel
791789
}
792790
}
793791

792+
template <typename AsTensorDesc>
794793
CK_TILE_DEVICE static auto
795794
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
796-
const KernelArgs& kargs,
797-
const index_t k_size,
795+
const AsTensorDesc& as_desc,
798796
const index_t i_m)
799797
{
800-
// Step 1: Create tensor descriptors for A tensors
801-
const auto& as_tensor_desc = generate_tuple(
802-
[&](auto i) {
803-
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
804-
return MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size);
805-
},
806-
number<NumATensor>{});
807-
808798
// Step 1: Create tensor views
809799
const auto& as_tensor_view = generate_tuple(
810800
[&](auto i) {
811801
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
812802
return make_tensor_view<address_space_enum::global>(
813-
static_cast<const AiDataType*>(as_ptr[i]), as_tensor_desc[i]);
803+
static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
814804
},
815805
number<NumATensor>{});
816806

@@ -860,30 +850,38 @@ struct UniversalGemmKernel
860850
}
861851

862852
CK_TILE_DEVICE static auto
863-
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
853+
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
864854
const KernelArgs& kargs,
865855
const index_t k_size,
866-
const index_t i_n)
856+
const index_t i_m)
867857
{
868-
// Step 1: Create tensor descriptors for B tensors
869-
const auto& bs_tensor_desc = generate_tuple(
858+
// Step 1: Create tensor descriptors for A tensors
859+
const auto& as_tensor_desc = generate_tuple(
870860
[&](auto i) {
871-
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
872-
return MakeDefaultBTensorDescriptor<BiLayout>(
873-
kargs.N, kargs.K, kargs.stride_Bs[i], k_size);
861+
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
862+
return MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size);
874863
},
875-
number<NumBTensor>{});
864+
number<NumATensor>{});
876865

877-
// Step 2: Create tensor views
866+
return MakeABlockWindows(as_ptr, as_tensor_desc, i_m);
867+
}
868+
869+
template <typename BsTensorDesc>
870+
CK_TILE_DEVICE static auto
871+
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
872+
const BsTensorDesc& bs_desc,
873+
const index_t i_n)
874+
{
875+
// Step 1: Create tensor views
878876
const auto& bs_tensor_view = generate_tuple(
879877
[&](auto i) {
880878
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
881879
return make_tensor_view<address_space_enum::global>(
882-
static_cast<const BiDataType*>(bs_ptr[i]), bs_tensor_desc[i]);
880+
static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
883881
},
884882
number<NumBTensor>{});
885883

886-
// Step 3: Create padded views
884+
// Step 2: Create padded views
887885
const auto& bs_pad_view = generate_tuple(
888886
[&](auto i) {
889887
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
@@ -904,7 +902,7 @@ struct UniversalGemmKernel
904902
},
905903
number<NumBTensor>{});
906904

907-
// Step 4: Create tile windows
905+
// Step 3: Create tile windows
908906
const auto& bs_block_window = generate_tuple(
909907
[&](auto i) {
910908
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
@@ -940,30 +938,39 @@ struct UniversalGemmKernel
940938
return bs_block_window;
941939
}
942940

943-
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
944-
const KernelArgs& kargs,
945-
const index_t i_m,
946-
const index_t i_n)
941+
CK_TILE_DEVICE static auto
942+
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
943+
const KernelArgs& kargs,
944+
const index_t k_size,
945+
const index_t i_n)
947946
{
948-
// Step 1: Create tensor descriptors for D tensors
949-
const auto& ds_tensor_desc = generate_tuple(
947+
const auto& bs_tensor_desc = generate_tuple(
950948
[&](auto i) {
951-
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
952-
return MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
953-
kargs.M, kargs.N, kargs.stride_Ds[i]);
949+
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
950+
return MakeDefaultBTensorDescriptor<BiLayout>(
951+
kargs.N, kargs.K, kargs.stride_Bs[i], k_size);
954952
},
955-
number<NumDTensor>{});
953+
number<NumBTensor>{});
956954

957-
// Step 2: Create tensor views
955+
return MakeBBlockWindows(bs_ptr, bs_tensor_desc, i_n);
956+
}
957+
958+
template <typename DsTensorDesc>
959+
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
960+
const DsTensorDesc& ds_desc,
961+
const index_t i_m,
962+
const index_t i_n)
963+
{
964+
// Step 1: Create tensor views
958965
const auto& ds_tensor_view = generate_tuple(
959966
[&](auto i) {
960967
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
961968
return make_tensor_view<address_space_enum::global>(
962-
static_cast<const DDataType_*>(ds_ptr[i]), ds_tensor_desc[i]);
969+
static_cast<const DDataType_*>(ds_ptr[i]), ds_desc[i]);
963970
},
964971
number<NumDTensor>{});
965972

966-
// Step 3: Create padded views
973+
// Step 2: Create padded views
967974
const auto& ds_pad_view = generate_tuple(
968975
[&](auto i) {
969976
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -984,7 +991,7 @@ struct UniversalGemmKernel
984991
},
985992
number<NumDTensor>{});
986993

987-
// Step 4: Create tile windows
994+
// Step 3: Create tile windows
988995
const auto& ds_block_window = generate_tuple(
989996
[&](auto i) {
990997
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -1008,18 +1015,32 @@ struct UniversalGemmKernel
10081015
return ds_block_window;
10091016
}
10101017

1011-
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1012-
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
1018+
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
10131019
const KernelArgs& kargs,
10141020
const index_t i_m,
10151021
const index_t i_n)
10161022
{
1017-
// Step 1: Create tensor descriptor for E/C tensor
1018-
const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E);
1023+
const auto& ds_tensor_desc = generate_tuple(
1024+
[&](auto i) {
1025+
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
1026+
return MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
1027+
kargs.M, kargs.N, kargs.stride_Ds[i]);
1028+
},
1029+
number<NumDTensor>{});
1030+
1031+
return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n);
1032+
}
10191033

1020-
// Step 1: Create tensor view
1034+
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename ETensorDesc>
1035+
CK_TILE_DEVICE static auto MakeCBlockWindows(
1036+
EDataType* e_ptr,
1037+
const index_t i_m,
1038+
const index_t i_n,
1039+
const ETensorDesc& e_desc) // Argument order differs from A,B,D to disambiguate overloads
1040+
{
1041+
// Step 1: Create tensor view for E/C tensor
10211042
const auto& e_tensor_view =
1022-
make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_tensor_desc);
1043+
make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_desc);
10231044

10241045
// Step 2: Create padded view
10251046
const auto& e_pad_view = [&]() {
@@ -1048,6 +1069,17 @@ struct UniversalGemmKernel
10481069
return e_block_window;
10491070
}
10501071

1072+
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1073+
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
1074+
const KernelArgs& kargs,
1075+
const index_t i_m,
1076+
const index_t i_n)
1077+
{
1078+
1079+
const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E);
1080+
return MakeCBlockWindows<DstInMemOp>(e_ptr, i_m, i_n, e_tensor_desc);
1081+
}
1082+
10511083
/**
10521084
* @brief Runs single GEMM problem cooperatively by whole workgroup.
10531085
*

0 commit comments

Comments
 (0)