Skip to content

Commit d132131

Browse files
committed
clang-format for a8w8_moe_blk_gemm1 splitk change
1 parent eed3c5a commit d132131

6 files changed

Lines changed: 207 additions & 179 deletions

File tree

example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
171171
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
172172
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
173173
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
174-
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
174+
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>;
175175
#else
176176
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
177177
Row, Col, DsLayout, ELayout,
@@ -185,7 +185,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
185185
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
186186
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
187187
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
188-
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
188+
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>;
189189
#endif
190190
// clang-format on
191191

example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ using Row = ck::tensor_layout::gemm::RowMajor;
3838
using Col = ck::tensor_layout::gemm::ColumnMajor;
3939
using Bypass = ck::tensor_layout::BypassLayoutVerification;
4040

41-
using A0DataType = F8;
42-
using A1DataType = F32;
43-
using B0DataType = F8;
44-
using B1DataType = F32;
41+
using A0DataType = F8;
42+
using A1DataType = F32;
43+
using B0DataType = F8;
44+
using B1DataType = F32;
4545
using EDataType = F32;
4646
using AccDataType = F32;
4747
using CShuffleDataType = EDataType;
@@ -56,7 +56,6 @@ using D1Layout = Col;
5656
using D2Layout = ELayout;
5757
using DsLayout = ck::Tuple<D2Layout>;
5858

59-
6059
struct MulABScaleExpertWeight
6160
{
6261
template <typename E, typename C, typename D2>
@@ -113,30 +112,30 @@ static constexpr ck::index_t Scale_Block_M = 1;
113112
static constexpr ck::index_t Scale_Block_N = 128;
114113
static constexpr ck::index_t Scale_Block_K = 128;
115114

116-
static constexpr ck::index_t Nswizzle = false;
117-
static constexpr ck::index_t IsInputGemm= true; //splitk gemm1 goes to gemm2 pipeline.
118-
static constexpr ck::index_t IsSplitK = true; //splitk gemm1
119-
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
120-
static constexpr bool MulRoutedWeight = false; //splitk gemm1 does not do routedWeight.
115+
static constexpr ck::index_t Nswizzle = false;
116+
static constexpr ck::index_t IsInputGemm = true; // splitk gemm1 goes to gemm2 pipeline.
117+
static constexpr ck::index_t IsSplitK = true; // splitk gemm1
118+
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
119+
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
121120

122121
#if 1
123-
static constexpr ck::index_t MPerBlock = 32;
124-
static constexpr ck::index_t NPerBlock = 128;
125-
static constexpr ck::index_t MNPerXDL = 16;
126-
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
127-
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
122+
static constexpr ck::index_t MPerBlock = 32;
123+
static constexpr ck::index_t NPerBlock = 128;
124+
static constexpr ck::index_t MNPerXDL = 16;
125+
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
126+
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
128127
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
129128
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
130-
static constexpr ck::index_t BLOCKSIZE = 256;
129+
static constexpr ck::index_t BLOCKSIZE = 256;
131130

132-
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
133-
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
134-
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
135-
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
136-
static constexpr ck::index_t D0Vec = 1;
137-
static constexpr ck::index_t D1Vec = 1;
131+
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
132+
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
133+
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
134+
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
135+
static constexpr ck::index_t D0Vec = 1;
136+
static constexpr ck::index_t D1Vec = 1;
138137

139-
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
138+
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
140139
// clang-format off
141140
< Row, Col, DsLayout, ELayout,
142141
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
@@ -183,8 +182,8 @@ int main(int argc, char* argv[])
183182
bool time_kernel = true;
184183
#if 1
185184
// GEMM shape
186-
ck::index_t N = 4096;
187-
ck::index_t K = 6144;
185+
ck::index_t N = 4096;
186+
ck::index_t K = 6144;
188187
// ck::index_t N = 128;
189188
// ck::index_t K = 512;
190189
ck::index_t experts = 8;
@@ -397,30 +396,29 @@ int main(int argc, char* argv[])
397396

398397
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
399398

400-
auto invoker = device_op.MakeInvoker();
401-
auto argument =
402-
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
403-
expert_ids_dev.GetDeviceBuffer(),
404-
max_token_id_dev.GetDeviceBuffer(),
405-
a0_device_buf.GetDeviceBuffer(),
406-
b0_device_buf.GetDeviceBuffer(),
407-
std::array<const void*, NumDTensor>{nullptr},
408-
e_device_buf.GetDeviceBuffer(),
409-
tokens,
410-
topk,
411-
sorted_size,
412-
N,
413-
K,
414-
StrideA,
415-
StrideB,
416-
StrideDs,
417-
StrideE,
418-
a1_device_buf.GetDeviceBuffer(),
419-
b1_device_buf.GetDeviceBuffer(),
420-
KBatch,
421-
a_element_op,
422-
b_element_op,
423-
cde_element_op);
399+
auto invoker = device_op.MakeInvoker();
400+
auto argument = device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
401+
expert_ids_dev.GetDeviceBuffer(),
402+
max_token_id_dev.GetDeviceBuffer(),
403+
a0_device_buf.GetDeviceBuffer(),
404+
b0_device_buf.GetDeviceBuffer(),
405+
std::array<const void*, NumDTensor>{nullptr},
406+
e_device_buf.GetDeviceBuffer(),
407+
tokens,
408+
topk,
409+
sorted_size,
410+
N,
411+
K,
412+
StrideA,
413+
StrideB,
414+
StrideDs,
415+
StrideE,
416+
a1_device_buf.GetDeviceBuffer(),
417+
b1_device_buf.GetDeviceBuffer(),
418+
KBatch,
419+
a_element_op,
420+
b_element_op,
421+
cde_element_op);
424422

425423
if(!device_op.IsSupportedArgument(argument))
426424
{

example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
165165
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
166166
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
167167
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
168-
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
168+
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>;
169169

170170
#else
171171
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
@@ -180,7 +180,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
180180
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
181181
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
182182
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
183-
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
183+
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>;
184184
#endif
185185
// clang-format on
186186

include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ struct DeviceMoeGemmBlockScale
203203
}
204204

205205
index_t gdx, gdy, gdz;
206-
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
206+
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
207+
arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
207208

208209
float ave_time = 0;
209210

@@ -236,7 +237,7 @@ struct DeviceMoeGemmBlockScale
236237
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
237238
});
238239
ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
239-
DsDataType>
240+
DsDataType>
240241
rotating_mem(arg_,
241242
stream_config.rotating_count,
242243
size_a_buffer,
@@ -253,7 +254,8 @@ struct DeviceMoeGemmBlockScale
253254
// if(arg_.KBatch > 1)
254255
// hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
255256
// 0,
256-
// arg_.M * arg_.N * sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1),
257+
// arg_.M * arg_.N * sizeof(CDataType)
258+
// * (IsInputGemm && IsSplitK ? 2 : 1),
257259
// stream_config.stream_id_));
258260
};
259261

@@ -271,7 +273,8 @@ struct DeviceMoeGemmBlockScale
271273
// if(arg.KBatch > 1)
272274
// hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
273275
// 0,
274-
// arg.M * arg.N * sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1),
276+
// arg.M * arg.N * sizeof(CDataType) *
277+
// (IsInputGemm && IsSplitK ? 2 : 1),
275278
// stream_config.stream_id_));
276279

277280
ave_time = launch_and_time_kernel(
@@ -290,8 +293,9 @@ struct DeviceMoeGemmBlockScale
290293

291294
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
292295

293-
constexpr auto MemoryDataOp =
294-
(IsInputGemm && !IsSplitK) ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
296+
constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
297+
? InMemoryDataOperationEnum::Set
298+
: InMemoryDataOperationEnum::AtomicAdd;
295299

296300
if(has_main_k_block_loop)
297301
{
@@ -442,7 +446,7 @@ struct DeviceMoeGemmBlockScale
442446
{
443447
return false;
444448
}
445-
if (arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
449+
if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
446450
{
447451
// K padding is not supported when KBatch > 1
448452
return false;

0 commit comments

Comments
 (0)