-
Notifications
You must be signed in to change notification settings - Fork 288
Expand file tree
/
Copy pathuniversal_gemm_kernel.hpp
More file actions
1301 lines (1200 loc) · 59.2 KB
/
universal_gemm_kernel.hpp
File metadata and controls
1301 lines (1200 loc) · 59.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/// @brief The Universal GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref UniversalGemmKernel "UniversalGemmKernel" when creating
/// kernel arguments object. It contains all necessary information required to build proper
/// kernel arguments and launch the kernel on GPU. This structure defines the GEMM problem
/// configuration by stating all required information like M,N,K sizes and respective strides.
/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required).
/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required).
/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not
/// required).
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct UniversalGemmHostArgs
{
CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: as_ptr(as_ptr_),
bs_ptr(bs_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_As(stride_As_),
stride_Bs(stride_Bs_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const std::array<const void*, NumATensor> as_ptr;
const std::array<const void*, NumBTensor> bs_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
const std::array<index_t, NumATensor> stride_As;
const std::array<index_t, NumBTensor> stride_Bs;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
/// @brief The GEMM kernel device arguments.
///
/// Plain aggregate; UniversalGemmKernel::MakeKernelArgs brace-initializes it positionally
/// from UniversalGemmHostArgs, so the member order below is part of the contract.
///
/// @tparam NumATensor Number of A tensors.
/// @tparam NumBTensor Number of B tensors.
/// @tparam NumDTensor Number of D tensors.
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct UniversalGemmKernelArgs
{
/// @brief The As input tensor's pointer to device memory.
const std::array<const void*, NumATensor> as_ptr;
/// @brief The Bs input tensor's pointer to device memory.
const std::array<const void*, NumBTensor> bs_ptr;
/// @brief The Ds input tensor's pointer to device memory.
const std::array<const void*, NumDTensor> ds_ptr;
/// @brief The E output tensor's pointer to device memory.
void* e_ptr;
/// @brief GEMM's M dimension size.
index_t M;
/// @brief GEMM's N dimension size.
index_t N;
/// @brief GEMM's K dimension size.
index_t K;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of As tensor.
std::array<index_t, NumATensor> stride_As;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of Bs tensor.
std::array<index_t, NumBTensor> stride_Bs;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of Ds tensor.
std::array<index_t, NumDTensor> stride_Ds;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of E tensor.
index_t stride_E;
/// @brief Number of parts the K dimension is split into; see
/// UniversalGemmKernel::SplitKBatchOffset for how K is divided between them.
index_t k_batch;
};
/// @brief The Universal GEMM kernel template.
///
/// @paragraph Overview Overview
/// This class provides the generic matrix multiplication kernel template. By semantic
/// division of GEMM algorithm into following parts we achieve flexible, versatile
/// and robust kernel implementation.
///
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
/// "function call operator" which determines the work scope of each workgroup.
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
/// This is the place where each workgroup is loading data from global memory and
/// carrying out dot products.
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
/// responsible for storing results to global memory. This is also the place where
/// any additional operator fusion may take place.
///
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
/// internal details of those functional parts. You can think of it as both the gemm and
/// epilogue pipelines providing the control-flow logic, which is steered by policies. Moreover
/// the policy is responsible for definition of all necessary data layouts and thread's
/// work distribution.
///
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the
/// output data tile to be calculated. It determines the workgroup to
/// data relationship (or in other words - which data would be
/// processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
/// multiplication. This class should provide implementation of data
/// loading from global memory and performing block-wise matrix
/// multiplication. You can think of it as a work done by single
/// workgroup point of view.
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output E tensor in global memory.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct UniversalGemmKernel
{
// Template parameters with cv/ref qualifiers stripped.
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
// Whether the pipelines already expose their A/B/D data types and layouts as tuples.
// These flags select between the two alternatives of the aliases that follow.
static constexpr bool ADataTypeIsTuple =
is_detected<is_tuple, typename GemmPipeline::AsDataType>::value;
static constexpr bool BDataTypeIsTuple =
is_detected<is_tuple, typename GemmPipeline::BsDataType>::value;
static constexpr bool DDataTypeIsTuple =
is_detected<is_tuple, typename EpiloguePipeline::DsDataType>::value;
static constexpr bool ALayoutIsTuple =
is_detected<is_tuple, typename GemmPipeline::AsLayout>::value;
static constexpr bool BLayoutIsTuple =
is_detected<is_tuple, typename GemmPipeline::BsLayout>::value;
static constexpr bool DLayoutIsTuple =
is_detected<is_tuple, typename EpiloguePipeline::DsLayout>::value;
// Normalized layouts: always a tuple. When the pipeline does not provide a tuple, the
// single ALayout / BLayout (or, for D, the non-tuple DsLayout itself) is wrapped into a
// one-element tuple.
using AsLayout = std::conditional_t<ALayoutIsTuple,
remove_cvref_t<typename GemmPipeline::AsLayout>,
remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
using BsLayout = std::conditional_t<BLayoutIsTuple,
remove_cvref_t<typename GemmPipeline::BsLayout>,
remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
using DsLayout = std::conditional_t<DLayoutIsTuple,
remove_cvref_t<typename EpiloguePipeline::DsLayout>,
remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
// Normalized data types: same tuple-or-wrap scheme as the layouts above.
using AsDataType = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<typename GemmPipeline::AsDataType>,
remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
using BsDataType = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<typename GemmPipeline::BsDataType>,
remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
using DsDataType =
std::conditional_t<DDataTypeIsTuple,
remove_cvref_t<typename EpiloguePipeline::DsDataType>,
remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
// Output layout comes from the GEMM pipeline; output element type from the epilogue.
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
// Element-wise functor types exposed by the GEMM pipeline for A and B.
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
// Block size as declared by the pipeline; BlockSize() halves it when is_wave32() is true.
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
struct has_persistent_kernel
{
template <typename T>
using has_persistent_type = decltype(T::UsePersistentKernel);
static constexpr bool value = []() {
if constexpr(is_detected<has_persistent_type, GemmPipeline>{})
return GemmPipeline::UsePersistentKernel;
else
return false;
}();
};
static constexpr bool PersistentKernel = has_persistent_kernel::value;
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
struct has_tile_partitioner_output_offset_impl
{
template <typename T, typename KernelArgs>
using has_get_output_offset_t =
decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
static constexpr bool value = []() {
if constexpr(is_detected<has_get_output_offset_t, TilePartitioner>{})
return true;
else
return false;
}();
};
static constexpr bool has_tile_partitioner_output_offset =
has_tile_partitioner_output_offset_impl::value;
// Compile-time indices used to address the (A, B, D, E) tuples built by the helpers below.
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};

// Number of A / B / D tensors, taken from the normalized data-type tuples.
static constexpr index_t NumATensor = AsDataType::size();
static constexpr index_t NumBTensor = BsDataType::size();
static constexpr index_t NumDTensor = DsDataType::size();

// Element types of the first A and first B tensor.
using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;

// Each tensor needs exactly one layout and one data type.
static_assert(AsLayout::size() == AsDataType::size(),
              "The size of AsLayout and AsDataType should be the same");
static_assert(BsLayout::size() == BsDataType::size(),
              "The size of BsLayout and BsDataType should be the same");
static_assert(DsLayout::size() == DsDataType::size(),
              "The size of DsLayout and DsDataType should be the same");

// Device-side argument pack; the asserts above guarantee the tensor counts equal the
// layout tuple sizes.
using KernelArgs = UniversalGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>;
/// @brief Build the kernel's name.
/// @return "gemm", the A/B precision string and the pipeline's own name, joined with '_'.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    const auto precision_tag = gemm_prec_str<ADataType, BDataType>();
    const auto pipeline_name = GemmPipeline::GetName();
    return concat('_', "gemm", precision_tag, pipeline_name);
}
/// @brief Launch grid for the non-persistent kernel.
/// @param M      GEMM M dimension size.
/// @param N      GEMM N dimension size.
/// @param KBatch Split-K factor; becomes the grid's z extent.
/// @return dim3 with x = TilePartitioner::GridSize(M, N), y = 1 and z = KBatch.
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
    const auto grid_x = TilePartitioner::GridSize(M, N);
    return dim3(grid_x, 1, KBatch);
}
/// @brief Grid size that fills the current device for the persistent kernel.
///
/// Asks `hipOccupancyMaxActiveBlocksPerMultiprocessor` how many blocks of this kernel
/// (at BlockSize().x threads and no dynamic shared memory) fit per multiprocessor and
/// multiplies that by get_available_compute_units(s).
///
/// @param s Stream configuration passed to get_available_compute_units.
/// @return One-dimensional grid (y = z = 1).
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
    using Kernel            = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
    const auto kernel_entry = kentry<1, Kernel, KernelArgs>;

    int blocks_per_cu = 0;
    ck_tile::hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_cu, kernel_entry, BlockSize().x, 0));

    const int total_blocks = get_available_compute_units(s) * blocks_per_cu;
    return dim3(total_blocks, 1, 1);
}
/// @brief Workgroup size used to launch the kernel.
/// @return kBlockSize, or half of it when ck_tile::is_wave32() reports true.
CK_TILE_HOST static auto BlockSize()
{
    return ck_tile::is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
/// @brief Convert host arguments into the device-side argument pack.
/// @param hostArgs GEMM problem description supplied by the caller.
/// @return KernelArgs whose members are copied one-to-one from @p hostArgs.
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
{
    // Positional aggregate initialization: the order must follow KernelArgs' members.
    return {// tensor pointers
            hostArgs.as_ptr,
            hostArgs.bs_ptr,
            hostArgs.ds_ptr,
            hostArgs.e_ptr,
            // problem sizes
            hostArgs.M,
            hostArgs.N,
            hostArgs.K,
            // strides
            hostArgs.stride_As,
            hostArgs.stride_Bs,
            hostArgs.stride_Ds,
            hostArgs.stride_E,
            // split-K factor
            hostArgs.k_batch};
}
/// @brief Shared-memory size required by the kernel.
/// @return The larger of the GEMM pipeline's and the epilogue pipeline's GetSmemSize().
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    const auto gemm_smem     = GemmPipeline::GetSmemSize();
    const auto epilogue_smem = EpiloguePipeline::GetSmemSize();
    return max(gemm_smem, epilogue_smem);
}
/// @brief Per-split-K-batch element offsets into the A/B tensors and the K extent
///        handled by that batch.
struct SplitKBatchOffset
{
    /// @param kargs Kernel arguments (K, k_batch and the A/B strides are read).
    /// @param k_id  Index of the split-K batch; defaults to blockIdx.z.
    __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
    {
        // K extent of one warp tile.
        constexpr auto warp_tile_k = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
        // K covered by one warp tile in every batch.
        const index_t k_all_batches = amd_wave_read_first_lane(kargs.k_batch * warp_tile_k);
        // K assigned to each batch except the last: ceil(K / k_all_batches) warp tiles.
        const index_t k_per_batch = amd_wave_read_first_lane(
            (kargs.K + k_all_batches - 1) / k_all_batches * warp_tile_k);

        // A: stepping along K moves by one element when row-major and by stride_As[i]
        // elements when column-major.
        static_for<0, NumATensor, 1>{}([&](auto index) {
            using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
            if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
            {
                as_k_split_offset[index] =
                    amd_wave_read_first_lane(k_id * k_per_batch * kargs.stride_As[index]);
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
            {
                as_k_split_offset[index] = amd_wave_read_first_lane(k_id * k_per_batch);
            }
        });

        // B: the mirror image of A - stride_Bs[i] when row-major, one element when
        // column-major.
        static_for<0, NumBTensor, 1>{}([&](auto index) {
            using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
            if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
            {
                bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * k_per_batch);
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
            {
                bs_k_split_offset[index] =
                    amd_wave_read_first_lane(k_id * k_per_batch * kargs.stride_Bs[index]);
            }
        });

        // Every batch but the last gets k_per_batch; the last takes what is left of K.
        const bool is_last_batch = !(k_id < static_cast<uint32_t>(kargs.k_batch - 1));
        splitted_k               = amd_wave_read_first_lane(
            is_last_batch ? kargs.K - k_per_batch * (kargs.k_batch - 1) : k_per_batch);
    }

    /// Element offset of this batch's first K index, one entry per A tensor.
    std::array<index_t, NumATensor> as_k_split_offset;
    /// Element offset of this batch's first K index, one entry per B tensor.
    std::array<index_t, NumBTensor> bs_k_split_offset;
    /// K extent processed by this batch.
    index_t splitted_k;
};
/// @brief Host-side check that this kernel instance can handle the given problem.
///
/// Verifies, per tensor and according to its layout, that the problem sizes are compatible
/// with the block tile sizes (unless the matching padding flag is enabled) and with the
/// vector access sizes of the pipelines. When the CK_TILE_LOGGING environment switch is
/// enabled, every rejected condition is reported through CK_TILE_ERROR.
///
/// @param kargs Kernel arguments to validate.
/// @return true when all checks pass, false otherwise.
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
    // Evaluated lazily, so the environment is only consulted when something is rejected.
    const auto logging_enabled = []() {
        return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING));
    };

    // k_batch is used as a divisor below (K % (KPerBlock * k_batch)); reject values below
    // one up front instead of dividing by zero.
    if(kargs.k_batch < 1)
    {
        if(logging_enabled())
        {
            CK_TILE_ERROR("k_batch must be at least 1!");
        }
        return false;
    }

    // Split-K is rejected for fp16/bf16 outputs when the C vector size is odd.
    if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                 is_any_of<EDataType, fp16_t, bf16_t>::value)
    {
        if(kargs.k_batch != 1)
        {
            if(logging_enabled())
            {
                CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
            }
            return false;
        }
    }

    // ---- A tensors -----------------------------------------------------------------
    const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
                                         : GemmPipeline::template GetVectorSizeA<false>();
    bool as_tensors_valid  = true;
    static_for<0, NumATensor, 1>{}([&](auto index) {
        using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
        if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
        {
            // Row-major A is contiguous along K.
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR(
                        "Can't support K that is not a multiple of k_batch * KPerBlock "
                        "without padding!");
                }
                as_tensors_valid = false;
            }
            if(kargs.K % vectorSizeA != 0)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
                }
                as_tensors_valid = false;
            }
        }
        else
        {
            // Otherwise A is contiguous along M.
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR(
                        "Can't support M that is not a multiple of MPerBlock without padding!");
                }
                as_tensors_valid = false;
            }
            if(kargs.M % vectorSizeA != 0)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
                }
                as_tensors_valid = false;
            }
        }
    });

    // ---- B tensors -----------------------------------------------------------------
    bool bs_tensors_valid  = true;
    const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
                                         : GemmPipeline::template GetVectorSizeB<false>();
    static_for<0, NumBTensor, 1>{}([&](auto index) {
        using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
        if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
        {
            // Row-major B is contiguous along N.
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR(
                        "Can't support N that is not a multiple of NPerBlock without padding!");
                }
                bs_tensors_valid = false;
            }
            if(kargs.N % vectorSizeB != 0)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
                }
                bs_tensors_valid = false;
            }
        }
        else
        {
            // Otherwise B is contiguous along K.
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR(
                        "Can't support K that is not a multiple of k_batch * KPerBlock "
                        "without padding!");
                }
                bs_tensors_valid = false;
            }
            if(kargs.K % vectorSizeB != 0)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
                }
                bs_tensors_valid = false;
            }
        }
    });

    // ---- D tensors -----------------------------------------------------------------
    bool ds_tensors_valid = true;
    static_for<0, NumDTensor, 1>{}([&](auto index) {
        using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
        // Every D tensor must use the same layout as the output tensor.
        if constexpr(!std::is_same_v<DiLayout, CLayout>)
        {
            if(logging_enabled())
            {
                CK_TILE_ERROR("D tensor layout does not match the C tensor layout!");
            }
            ds_tensors_valid = false;
        }
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                                  "NPerBlock without padding!");
                }
                ds_tensors_valid = false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
                }
                ds_tensors_valid = false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                                  "MPerBlock without padding!");
                }
                ds_tensors_valid = false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
            {
                if(logging_enabled())
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
                }
                ds_tensors_valid = false;
            }
        }
    });

    // ---- C / E tensor (rejects immediately) ------------------------------------------
    if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            if(logging_enabled())
            {
                CK_TILE_ERROR(
                    "Can't support N that is not a multiple of NPerBlock without padding!");
            }
            return false;
        }
        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            if(logging_enabled())
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
            }
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            if(logging_enabled())
            {
                CK_TILE_ERROR(
                    "Can't support M that is not a multiple of MPerBlock without padding!");
            }
            return false;
        }
        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            if(logging_enabled())
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
            }
            return false;
        }
    }

    return as_tensors_valid && bs_tensors_valid && ds_tensors_valid;
}
/// @brief Build the global-memory tensor views for all A, B and D inputs and the E output.
///
/// @tparam DstInMemOp Memory operation forwarded to the E view (defaults to set).
/// @param as_ptr  Pointers to the A tensors.
/// @param bs_ptr  Pointers to the B tensors.
/// @param ds_ptr  Pointers to the D tensors.
/// @param e_ptr   Pointer to the E output tensor.
/// @param kargs   Kernel arguments (M, N, K and the strides are read).
/// @param k_size  K extent used for the A and B views
///                (presumably the split-K share of this workgroup - confirm with callers).
/// @return tuple(A views, B views, D views, E view); the first three are tuples themselves.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const index_t k_size)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
// A views: lengths (M, k_size) for row-major, (k_size, M) otherwise; in both cases the
// outer stride is stride_As[i] and the inner stride is 1.
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_As[i], 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_As[i], 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
},
number<NumATensor>{});
// B views: the shape depends on the layout and on the PermuteB / Preshuffle options.
const auto& bs_tensor_view = generate_tuple(
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
// Permuted B is described as (K0, N, K1) with strides (N * K1, K1, 1), where
// K1 = GetSmemPackB() and K0 = k_size / K1. K0 and K1 are merged back into one
// K dimension, giving a 2-D view with merged K as dim 0 and N as dim 1.
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB =
std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
// NOTE: despite its name, this descriptor is ordered (K, N) in this branch.
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
}
else
{
// Plain row-major B: lengths (k_size, N), strides (stride_Bs[i], 1).
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(k_size, kargs.N),
make_tuple(kargs.stride_Bs[i], 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
// Same (K0, N, K1) source descriptor as in the row-major branch, but the output
// dimensions are swapped: N becomes dim 0 and merged K becomes dim 1.
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB =
std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
}
else
{
if constexpr(GemmPipeline::Preshuffle)
{
// Pre-shuffled B is viewed as a packed 2-D (kFlatN, kFlatK) tensor with row
// stride kFlatK.
// NOTE(review): kFlatK is derived from k_size while kFlatN uses the full
// kargs.K - confirm this is intended when k_size != kargs.K.
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
// Plain non-row-major B: lengths (N, k_size), strides (stride_Bs[i], 1).
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(kargs.N, k_size),
make_tuple(kargs.stride_Bs[i], 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
},
number<NumBTensor>{});
// D views: lengths (M, N) for row-major, (N, M) otherwise; strides (stride_Ds[i], 1).
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
},
number<NumDTensor>{});
// E view: always indexed as (M, N). Row-major uses strides (stride_E, 1) and the
// epilogue's C vector size; otherwise strides are (1, stride_E) with vector size 1.
// TODO: enable vector write for C in ColMajor
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<1>{});
}
}();
return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
}
/// @brief Wrap the raw tensor views into padded views sized by the block tile.
///
/// Padding is only ever requested on the second dimension of a view (via kPadK / kPadN /
/// kPadM, depending on layout), except for a non-row-major E view, where kPadM applies to
/// the first dimension. For Preshuffle pipelines the B views are passed through unpadded.
///
/// @param views tuple(A views, B views, D views, E view).
/// @return tuple(A pad views, B pad views, D pad views, E pad view).
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
    // Block tile extents as compile-time constants.
    using MPerBlockT = number<TilePartitioner::MPerBlock>;
    using NPerBlockT = number<TilePartitioner::NPerBlock>;
    using KPerBlockT = number<TilePartitioner::KPerBlock>;

    const auto& a_views = views.at(I0);
    const auto& b_views = views.at(I1);
    const auto& d_views = views.at(I2);
    const auto& e_view  = views.at(I3);

    const auto& as_pad_view = generate_tuple(
        [&](auto i) {
            using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
            constexpr bool a_is_row_major =
                std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>;
            if constexpr(a_is_row_major)
            {
                // (M, K) view: pad along K.
                return pad_tensor_view(a_views[i],
                                       make_tuple(MPerBlockT{}, KPerBlockT{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                // (K, M) view: pad along M.
                return pad_tensor_view(a_views[i],
                                       make_tuple(KPerBlockT{}, MPerBlockT{}),
                                       sequence<false, GemmPipeline::kPadM>{});
            }
        },
        number<NumATensor>{});

    // Used as-is by Preshuffle pipelines.
    const auto& b_flat_pad_view = b_views;

    const auto& bs_pad_view = generate_tuple(
        [&](auto i) {
            using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
            constexpr bool b_is_col_major =
                std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>;
            if constexpr(b_is_col_major)
            {
                // (N, K) view: pad along K.
                return pad_tensor_view(b_views[i],
                                       make_tuple(NPerBlockT{}, KPerBlockT{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                // (K, N) view: pad along N.
                return pad_tensor_view(b_views[i],
                                       make_tuple(KPerBlockT{}, NPerBlockT{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
        },
        number<NumBTensor>{});

    const auto& ds_pad_view = generate_tuple(
        [&](auto i) {
            using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
            constexpr bool d_is_row_major =
                std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>;
            if constexpr(d_is_row_major)
            {
                // (M, N) view: pad along N.
                return pad_tensor_view(d_views[i],
                                       make_tuple(MPerBlockT{}, NPerBlockT{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                // (N, M) view: pad along M.
                return pad_tensor_view(d_views[i],
                                       make_tuple(NPerBlockT{}, MPerBlockT{}),
                                       sequence<false, GemmPipeline::kPadM>{});
            }
        },
        number<NumDTensor>{});

    // TODO vector write in for C in ColMajor
    const auto& e_pad_view = [&]() {
        constexpr bool e_is_row_major = std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>;
        if constexpr(e_is_row_major)
        {
            return pad_tensor_view(e_view,
                                   make_tuple(MPerBlockT{}, NPerBlockT{}),
                                   sequence<false, GemmPipeline::kPadN>{});
        }
        else
        {
            // The view is still indexed (M, N); here the padded dimension is M.
            return pad_tensor_view(e_view,
                                   make_tuple(MPerBlockT{}, NPerBlockT{}),
                                   sequence<GemmPipeline::kPadM, false>{});
        }
    }();

    if constexpr(!GemmPipeline::Preshuffle)
    {
        return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
    }
    else
    {
        // For flatmm, we need to use the flat B tensor view
        return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
    }
}
/// @brief Create the block-level tile windows over the padded views.
///
/// @param views tuple(A pad views, B pad views, D pad views, E pad view).
/// @param i_m   Origin of the block tile along M.
/// @param i_n   Origin of the block tile along N.
/// @return tuple(A windows, B windows, D windows, E window).
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
    // Block tile extents as compile-time constants.
    using MPerBlockT = number<TilePartitioner::MPerBlock>;
    using NPerBlockT = number<TilePartitioner::NPerBlock>;
    using KPerBlockT = number<TilePartitioner::KPerBlock>;

    const auto& a_padded = views.at(I0);
    const auto& b_padded = views.at(I1);
    const auto& d_padded = views.at(I2);
    const auto& e_padded = views.at(I3);

    // A windows start at K = 0 and at row/column i_m, depending on the layout.
    const auto& as_block_window = generate_tuple(
        [&](auto i) {
            using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
            constexpr bool a_is_row_major =
                std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>;
            if constexpr(a_is_row_major)
            {
                return make_tile_window(
                    a_padded[i], make_tuple(MPerBlockT{}, KPerBlockT{}), {i_m, 0});
            }
            else
            {
                return make_tile_window(
                    a_padded[i], make_tuple(KPerBlockT{}, MPerBlockT{}), {0, i_m});
            }
        },
        number<NumATensor>{});

    // B windows: Preshuffle pipelines use the flat per-warp window, the others a regular
    // block tile starting at K = 0 and at i_n.
    const auto& bs_block_window = generate_tuple(
        [&](auto i) {
            using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
            if constexpr(GemmPipeline::Preshuffle)
            {
                // The flat view is addressed in units of the warp tile's N extent.
                return make_tile_window(
                    b_padded[i],
                    make_tuple(number<GemmPipeline::BlockGemmShape::flatNPerWarp>{},
                               number<GemmPipeline::BlockGemmShape::flatKPerWarp>{}),
                    {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
                     0});
            }
            else if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return make_tile_window(
                    b_padded[i], make_tuple(NPerBlockT{}, KPerBlockT{}), {i_n, 0});
            }
            else
            {
                return make_tile_window(
                    b_padded[i], make_tuple(KPerBlockT{}, NPerBlockT{}), {0, i_n});
            }
        },
        number<NumBTensor>{});

    // D windows cover the same output tile as E, transposed when D is not row-major.
    const auto ds_block_window = generate_tuple(
        [&](auto i) {
            using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
            constexpr bool d_is_row_major =
                std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>;
            if constexpr(d_is_row_major)
            {
                return make_tile_window(
                    d_padded[i], make_tuple(MPerBlockT{}, NPerBlockT{}), {i_m, i_n});
            }
            else
            {
                return make_tile_window(
                    d_padded[i], make_tuple(NPerBlockT{}, MPerBlockT{}), {i_n, i_m});
            }
        },
        number<NumDTensor>{});

    // E window: always (MPerBlock, NPerBlock) at (i_m, i_n).
    auto e_block_window =
        make_tile_window(e_padded, make_tuple(MPerBlockT{}, NPerBlockT{}), {i_m, i_n});

    return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
}
// Version of RunGemm using descriptors
// FIXME: Currently Templated to XsList to allow both arrays and tuples for convenience, which
// doesn't enforce same size nor matching types (as with arrays)
template <typename AsList,
typename BsList,
typename DsList,
typename AGridDescs,
typename BGridDescs,
typename DGridDescs,
typename EGridDesc,
bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunGemmDesc(const AsList& as_ptr,
const BsList& bs_ptr,
const DsList& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_0,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n,
const AGridDescs& as_desc,
const BGridDescs& bs_desc,
const DGridDescs& ds_desc,
const EGridDesc& e_desc)
{
// Create tensor views from descriptors (supports arbitrary stride patterns)
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
},
number<NumATensor>{});
const auto& bs_tensor_view = generate_tuple(
[&](auto i) {
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
},
number<NumBTensor>{});
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const DiDataType*>(ds_ptr[i]), ds_desc[i]);
},
number<NumDTensor>{});
auto e_tensor_view =
make_tensor_view<address_space_enum::global>(static_cast<EDataType*>(e_ptr), e_desc);
const auto& gemm_tensors_views_tuple =
make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensors_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.