@@ -38,10 +38,10 @@ using Row = ck::tensor_layout::gemm::RowMajor;
3838using Col = ck::tensor_layout::gemm::ColumnMajor;
3939using Bypass = ck::tensor_layout::BypassLayoutVerification;
4040
41- using A0DataType = F8;
42- using A1DataType = F32;
43- using B0DataType = F8;
44- using B1DataType = F32;
41+ using A0DataType = F8;
42+ using A1DataType = F32;
43+ using B0DataType = F8;
44+ using B1DataType = F32;
4545using EDataType = F32;
4646using AccDataType = F32;
4747using CShuffleDataType = EDataType;
@@ -56,7 +56,6 @@ using D1Layout = Col;
5656using D2Layout = ELayout;
5757using DsLayout = ck::Tuple<D2Layout>;
5858
59-
6059struct MulABScaleExpertWeight
6160{
6261 template <typename E, typename C, typename D2>
@@ -113,30 +112,30 @@ static constexpr ck::index_t Scale_Block_M = 1;
113112static constexpr ck::index_t Scale_Block_N = 128 ;
114113static constexpr ck::index_t Scale_Block_K = 128 ;
115114
116- static constexpr ck::index_t Nswizzle = false ;
117- static constexpr ck::index_t IsInputGemm= true ; // splitk gemm1 goes to gemm2 pipeline.
118- static constexpr ck::index_t IsSplitK = true ; // splitk gemm1
119- static constexpr ck::index_t ActOP = 0 ; // 0: gelu_and_mul, 1: silu_and_mul
120- static constexpr bool MulRoutedWeight = false ; // splitk gemm1 does not do routedWeight.
115+ static constexpr ck::index_t Nswizzle = false ;
116+ static constexpr ck::index_t IsInputGemm = true ; // splitk gemm1 goes to gemm2 pipeline.
117+ static constexpr ck::index_t IsSplitK = true ; // splitk gemm1
118+ static constexpr ck::index_t ActOP = 0 ; // 0: gelu_and_mul, 1: silu_and_mul
119+ static constexpr bool MulRoutedWeight = false ; // splitk gemm1 does not do routedWeight.
121120
122121#if 1
123- static constexpr ck::index_t MPerBlock = 32 ;
124- static constexpr ck::index_t NPerBlock = 128 ;
125- static constexpr ck::index_t MNPerXDL = 16 ;
126- static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1 );
127- static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4 );
122+ static constexpr ck::index_t MPerBlock = 32 ;
123+ static constexpr ck::index_t NPerBlock = 128 ;
124+ static constexpr ck::index_t MNPerXDL = 16 ;
125+ static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1 );
126+ static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4 );
128127static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
129128static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
130- static constexpr ck::index_t BLOCKSIZE = 256 ;
129+ static constexpr ck::index_t BLOCKSIZE = 256 ;
131130
132- static constexpr ck::index_t KPerBlock = 128 / sizeof (A0DataType);
133- static constexpr ck::index_t AK1 = 16 / sizeof (A0DataType);
134- static constexpr ck::index_t BK1 = 16 / sizeof (B0DataType);
135- static constexpr ck::index_t EVec = 16 / sizeof (EDataType);
136- static constexpr ck::index_t D0Vec = 1 ;
137- static constexpr ck::index_t D1Vec = 1 ;
131+ static constexpr ck::index_t KPerBlock = 128 / sizeof (A0DataType);
132+ static constexpr ck::index_t AK1 = 16 / sizeof (A0DataType);
133+ static constexpr ck::index_t BK1 = 16 / sizeof (B0DataType);
134+ static constexpr ck::index_t EVec = 16 / sizeof (EDataType);
135+ static constexpr ck::index_t D0Vec = 1 ;
136+ static constexpr ck::index_t D1Vec = 1 ;
138137
139- using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
138+ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
140139 // clang-format off
141140 < Row, Col, DsLayout, ELayout,
142141 A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
@@ -183,8 +182,8 @@ int main(int argc, char* argv[])
183182 bool time_kernel = true ;
184183#if 1
185184 // GEMM shape
186- ck::index_t N = 4096 ;
187- ck::index_t K = 6144 ;
185+ ck::index_t N = 4096 ;
186+ ck::index_t K = 6144 ;
188187 // ck::index_t N = 128;
189188 // ck::index_t K = 512;
190189 ck::index_t experts = 8 ;
@@ -397,30 +396,29 @@ int main(int argc, char* argv[])
397396
398397 b0_device_buf.ToDevice (b0_preshuffled.mData .data ());
399398
400- auto invoker = device_op.MakeInvoker ();
401- auto argument =
402- device_op.MakeArgument (sorted_token_ids_dev.GetDeviceBuffer (),
403- expert_ids_dev.GetDeviceBuffer (),
404- max_token_id_dev.GetDeviceBuffer (),
405- a0_device_buf.GetDeviceBuffer (),
406- b0_device_buf.GetDeviceBuffer (),
407- std::array<const void *, NumDTensor>{nullptr },
408- e_device_buf.GetDeviceBuffer (),
409- tokens,
410- topk,
411- sorted_size,
412- N,
413- K,
414- StrideA,
415- StrideB,
416- StrideDs,
417- StrideE,
418- a1_device_buf.GetDeviceBuffer (),
419- b1_device_buf.GetDeviceBuffer (),
420- KBatch,
421- a_element_op,
422- b_element_op,
423- cde_element_op);
399+ auto invoker = device_op.MakeInvoker ();
400+ auto argument = device_op.MakeArgument (sorted_token_ids_dev.GetDeviceBuffer (),
401+ expert_ids_dev.GetDeviceBuffer (),
402+ max_token_id_dev.GetDeviceBuffer (),
403+ a0_device_buf.GetDeviceBuffer (),
404+ b0_device_buf.GetDeviceBuffer (),
405+ std::array<const void *, NumDTensor>{nullptr },
406+ e_device_buf.GetDeviceBuffer (),
407+ tokens,
408+ topk,
409+ sorted_size,
410+ N,
411+ K,
412+ StrideA,
413+ StrideB,
414+ StrideDs,
415+ StrideE,
416+ a1_device_buf.GetDeviceBuffer (),
417+ b1_device_buf.GetDeviceBuffer (),
418+ KBatch,
419+ a_element_op,
420+ b_element_op,
421+ cde_element_op);
424422
425423 if (!device_op.IsSupportedArgument (argument))
426424 {
0 commit comments