gridwise_moe_mx_gemm_bpreshuffle.hpp Source File

gridwise_moe_mx_gemm_bpreshuffle.hpp Source File#

Composable Kernel: gridwise_moe_mx_gemm_bpreshuffle.hpp Source File
gridwise_moe_mx_gemm_bpreshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
16
20
21#define DEBUG_LOG 0
22
23namespace ck {
24
25// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26// pipelines in the same kernel function. Blockers:
27// 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
28// operate on two LDS chunks.
29// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not
30// use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
31
33{
34 gelu_and_mul = 0,
35 silu_and_mul = 1
36};
37
// Kernel entry point for the single-LDS-buffer path: it allocates one statically sized
// __shared__ byte buffer and forwards everything to GridwiseGemm::Run.
//
// Template parameters visible here:
//   GridwiseGemm               - supplies Argument, SplitKBatchOffset, Run and the LDS size.
//   HasMainKBlockLoop          - forwarded unchanged to Run.
//   CGlobalMemoryDataOperation - forwarded unchanged to Run.
//   MinimumOccupancy           - second argument of __launch_bounds__ when CK_USE_LAUNCH_BOUNDS is set.
// NOTE(review): `TailNum` is used below but its declaration is on a template-parameter line
// that is not visible in this listing.
//
// karg is taken by value; SplitKBatchOffset receives it by non-const reference and overwrites
// karg.K with the K extent of this split-K batch before Run sees it.
38template <typename GridwiseGemm,
39 bool HasMainKBlockLoop,
40 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41 index_t MinimumOccupancy = 1,
43__global__ void
44#if CK_USE_LAUNCH_BOUNDS
45__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46#endif
47 // __attribute__((amdgpu_waves_per_eu(1, 1)))
48 kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49{
// The body is only compiled when __gfx9__ is defined; otherwise the kernel is an empty stub.
50#if defined(__gfx9__)
// Single LDS buffer, sized at compile time by the gridwise GEMM.
51 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52
// blockIdx.z is used as the split-K batch id (the `k_id` constructor argument).
53 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54
// A, B and both scale pointers are advanced by this batch's offsets; the sorted-token,
// expert-id, Ds and C pointers are passed through unmodified.
55 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56 karg.p_sorted_token_ids,
57 karg.p_sorted_expert_ids,
58 karg.p_max_token_id,
59 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
61 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
62 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
63 karg.p_ds_grid,
64 karg.p_c_grid,
65 p_shared,
66 karg,
67 karg.a_element_op,
68 karg.b_element_op,
69 karg.c_element_op);
70#else
// Marks karg as used when the body above is compiled out.
71 ignore = karg;
72#endif // end of if (defined(__gfx9__))
73}
74
// Kernel entry point for the double-LDS-buffer path: it declares two separate __shared__ byte
// buffers of identical size and forwards both to GridwiseGemm::Run_2Lds.
//
// Template parameters visible here have the same roles as in kernel_moe_mxgemm.
// NOTE(review): `TailNum` is used below but its declaration is on a template-parameter line
// that is not visible in this listing.
//
// karg is taken by value; SplitKBatchOffset receives it by non-const reference and overwrites
// karg.K with the K extent of this split-K batch before Run_2Lds sees it.
75template <typename GridwiseGemm,
76 bool HasMainKBlockLoop,
77 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78 index_t MinimumOccupancy = 1,
80__global__ void
81#if CK_USE_LAUNCH_BOUNDS
82__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
83#endif
84 // __attribute__((amdgpu_waves_per_eu(1, 1)))
85 kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
86{
// The body is only compiled when __gfx9__ is defined; otherwise the kernel is an empty stub.
87#if defined(__gfx9__)
// Nothing is instantiated when the gridwise GEMM reports this combination as invalid.
88 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
89 {
// Two distinct __shared__ declarations, each sized by the gridwise GEMM.
90 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92
// blockIdx.z is used as the split-K batch id (the `k_id` constructor argument).
93 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
94
// A, B and both scale pointers are advanced by this batch's offsets; the remaining pointers
// are passed through unmodified.
95 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
96 karg.p_sorted_token_ids,
97 karg.p_sorted_expert_ids,
98 karg.p_max_token_id,
99 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
100 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
101 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
103 karg.p_ds_grid,
104 karg.p_c_grid,
105 p_shared_0,
106 p_shared_1,
107 karg,
108 karg.a_element_op,
109 karg.b_element_op,
110 karg.c_element_op);
111 }
112#else
// Marks karg as used when the body above is compiled out.
113 ignore = karg;
114#endif // end of if (defined(__gfx9__))
115}
116
117template <typename ALayout,
118 typename BLayout,
119 typename DsLayout,
120 typename CLayout,
121 typename ADataType,
122 typename AScaleDataType,
123 typename BDataType,
124 typename BScaleDataType,
125 typename AccDataType,
126 typename CShuffleDataType,
127 typename DsDataType,
128 typename CDataType,
129 typename AElementwiseOperation,
130 typename BElementwiseOperation,
131 typename CElementwiseOperation,
133 index_t ScaleBlockSize, // Scaling block size
134 index_t BlockSize, // Thread block size
135 index_t MPerBlock,
136 index_t NPerBlock,
137 index_t KPerBlock,
138 index_t AK1Value,
139 index_t BK1Value,
140 index_t MPerXdl,
141 index_t NPerXdl,
142 index_t MXdlPerWave,
143 index_t NXdlPerWave,
144 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
145 typename ABlockTransferThreadClusterArrangeOrder,
146 typename ABlockTransferSrcAccessOrder,
147 index_t ABlockTransferSrcVectorDim,
148 index_t ABlockTransferSrcScalarPerVector,
149 index_t ABlockTransferDstScalarPerVector_AK1,
150 bool AThreadTransferSrcResetCoordinateAfterRun,
151 index_t ABlockLdsExtraM,
152 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
153 typename BBlockTransferThreadClusterArrangeOrder,
154 typename BBlockTransferSrcAccessOrder,
155 index_t BBlockTransferSrcVectorDim,
156 index_t BBlockTransferSrcScalarPerVector,
157 index_t BBlockTransferDstScalarPerVector_BK1,
158 bool BThreadTransferSrcResetCoordinateAfterRun,
159 index_t BBlockLdsExtraN,
160 index_t CShuffleMXdlPerWavePerShuffle,
161 index_t CShuffleNXdlPerWavePerShuffle,
162 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
163 typename CDEShuffleBlockTransferScalarPerVectors,
166 index_t ActivationOperation = 0,
167 bool NSwizzle = false,
168 bool IsInputGemm = true,
169 bool MulRoutedWeight = true,
170 typename IndexType = index_t,
171 typename ComputeTypeA = ADataType,
172 typename ComputeTypeB = BDataType>
174{
175 using LDSTypeA = ADataType;
176 using LDSTypeB = BDataType;
177
178 static constexpr auto I0 = Number<0>{};
179 static constexpr auto I1 = Number<1>{};
180 static constexpr auto I2 = Number<2>{};
181 static constexpr auto I3 = Number<3>{};
182 static constexpr auto I4 = Number<4>{};
183 static constexpr auto I5 = Number<5>{};
184 static constexpr auto I6 = Number<6>{};
185 static constexpr auto I7 = Number<7>{};
186 static constexpr auto I8 = Number<8>{};
187 static constexpr auto I9 = Number<9>{};
188
190 CDEShuffleBlockTransferScalarPerVectors{}[I0];
191 // K1 should be Number<...>
192 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
193 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
194 static constexpr auto AK1Number = Number<AK1Value>{};
195 static constexpr auto BK1Number = Number<BK1Value>{};
196
197 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
198 static constexpr bool is_single_rate_mfma = false;
199 static constexpr auto is_scale_mfma = true;
200
201 static constexpr index_t NumDTensor = DsDataType::Size();
202
203 static constexpr auto MXdlPack = 2;
204 static constexpr auto NXdlPack = 2;
205 static constexpr auto KXdlPack = 2;
206
207 //> KPack is at least the k_per_blk of selected mfma
208 //
209 // Should be a multiple of k_per_blk.
210 // TODO: Move this to blockwise pipeline base
211 // KPack in packed data types for pk A/B
212
215
216 using mfma_selector = MfmaSelector<ComputeTypeA,
217 MPerXdl,
218 NPerXdl,
219 ComputeTypeB,
222 static constexpr index_t KPack =
224
225 static constexpr index_t NLane = NPerXdl;
226 static constexpr index_t KLane = 64 / NLane;
227 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
228 static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
229
230 // static constexpr index_t NumTokens = 1;
231 static constexpr index_t SortedTileSize = MPerBlock;
232
234 static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
235 static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
236 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
237 "A scale pack data type too large!");
238 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
239 "B scale pack data type too large!");
240
// Builds a tuple of null `const DDataType*` pointers, taking each DDataType from the
// corresponding element of DsDataType. The result is consumed through decltype to define
// DsGridPointer, so only the pointer types matter.
// NOTE(review): the element-count argument of generate_tuple is on a line not visible in this
// listing - presumably Number<NumDTensor>{}; confirm against the original header.
241 static constexpr auto MakeDsGridPointer()
242 {
243 return generate_tuple(
244 [&](auto i) {
245 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
246
// Typed null pointer: a placeholder carrying only the element type.
247 return static_cast<const DDataType*>(nullptr);
248 },
250 }
251
252 using DsGridPointer = decltype(MakeDsGridPointer());
253
255
256 __host__ static auto CalculateGridSize(index_t M, index_t N)
257 {
258 const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
259 const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
260 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
261 const index_t gridy = NSwizzle ? 1 : mblock;
262
263 return std::make_tuple(gridx, gridy, 1);
264 }
265
266 __host__ static auto CalculateMPadded(index_t M)
267 {
268 return math::integer_least_multiple(M, MPerBlock);
269 }
270
271 __host__ static auto CalculateNPadded(index_t N)
272 {
273 return math::integer_least_multiple(N, NPerBlock);
274 }
275
276 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
277 {
279 }
280 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
281 {
283 }
284
285 __host__ static auto CalculateKPadded(index_t K)
286 {
287 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
288 }
289
290 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
291 {
292 auto K_t = K_Batch * KPerBlock;
293 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
294 }
295
296 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
297 {
298 auto K_t = K_Batch * KPerBlock;
299 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
300 }
301
302 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
303 {
304 auto K_t = K_Batch * KPerBlock;
305 return (K + K_t - 1) / K_t * KPerBlock;
306 }
307
308 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
309 {
310 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
311 auto K_t = K_Batch * KReadVec;
312 return (K + K_t - 1) / K_t * KReadVec;
313 }
314
315 __host__ static auto CalculateMBlock(index_t M)
316 {
317 return math::integer_divide_ceil(M, MPerBlock);
318 }
319
320 __host__ static auto CalculateNBlock(index_t N)
321 {
322 return math::integer_divide_ceil(N, NPerBlock);
323 }
324
325 template <index_t MNXdlPerWave,
326 index_t MNWaves,
327 index_t MNXdlPack,
328 index_t MNPerXdl,
329 bool IsXor,
330 typename TileDesc_K0_MN_K1>
331 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
332 {
333 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
334 constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
335 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
336
337 if constexpr(IsXor)
338 {
339 constexpr auto permuted_desc = transform_tensor_descriptor(
340 TileDesc_K0_MN_K1{},
345
347 permuted_desc,
353 Number<MNPerXdl>{}))),
356 }
357 else
358 {
360 TileDesc_K0_MN_K1{},
366 Number<MNPerXdl>{}))),
369 }
370 }
371
372 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
373 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
374 {
375 const auto a_grid_desc_mraw_kraw = [&]() {
377 {
378 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
379 }
381 {
382 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
383 }
384 }();
385
386 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
387
388 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
389 GemmSpec == GemmSpecialization::MNKPadding)
390 {
391 // pad both M and K
392 const auto a_grid_desc_m_k =
393 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
395 make_right_pad_transform(K, KPad - K)),
398
399 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
400 a_grid_desc_m_k,
405
406 return a_grid_desc_ak0_m_ak1;
407 }
408 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
409 GemmSpec == GemmSpecialization::MNPadding)
410 {
411 // pad M, but not K
412 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
413 a_grid_desc_mraw_kraw,
415 make_right_pad_transform(M, MPad - M)),
418
419 return a_grid_desc_ak0_m_ak1;
420 }
421 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
422 GemmSpec == GemmSpecialization::NKPadding)
423 {
424 // pad K, but not M
425 const auto a_grid_desc_m_k = transform_tensor_descriptor(
426 a_grid_desc_mraw_kraw,
430
431 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
432 a_grid_desc_m_k,
437
438 return a_grid_desc_ak0_m_ak1;
439 }
440 else
441 {
442 // not pad M or K
443 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
444 a_grid_desc_mraw_kraw,
445 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
449
450 const auto a_grid_desc_permuted = transform_tensor_descriptor(
451 a_grid_desc_ak0_m_ak1,
457
458 const auto a_grid_desc = transform_tensor_descriptor(
459 a_grid_desc_permuted,
466
467 return a_grid_desc;
468 }
469 }
470
// Describes the pre-shuffled B grid with five lengths:
// (N0 / NWave / NXdlPack, NWave, NXdlPack, K0, WaveSize * KPack).
// N0 and K0 are supplied by the caller (Run passes the CalculateBN0Shuffled /
// CalculateBK0Shuffled results).
// NOTE(review): the line that opens the return expression is not visible in this listing -
// presumably a packed naive tensor descriptor; confirm against the original header.
471 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
472 {
// Waves along M in a block, then the threads left per wave once M and N waves are factored out.
473 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
474 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// Innermost length: one KPack-sized chunk per thread of a wave.
475 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
477 make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
478 }
479
480 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
481 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
482 {
483 const auto b_grid_desc_nraw_kraw = [&]() {
485 {
486 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
487 }
489 {
490 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
491 }
492 }();
493
494 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
495
497 GemmSpec != GemmSpecialization::Default),
498 "pk_i4_t does not support padding");
500 GemmSpec != GemmSpecialization::Default),
501 "f4x2_pk_t does not support padding");
502
503 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
504 GemmSpec == GemmSpecialization::MNKPadding)
505 {
506 // pad both N and K
507 const auto b_grid_desc_n_k =
508 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
510 make_right_pad_transform(K, KPad - K)),
513
514 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
515 b_grid_desc_n_k,
520
521 return b_grid_desc_bk0_n_bk1;
522 }
523 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
524 GemmSpec == GemmSpecialization::MNPadding)
525 {
526 // pad N, but not K
527 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
528 b_grid_desc_nraw_kraw,
530 make_right_pad_transform(N, NPad - N)),
533
534 return b_grid_desc_bk0_n_bk1;
535 }
536 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
537 GemmSpec == GemmSpecialization::MKPadding)
538 {
539 // pad K, but not N
540 const auto b_grid_desc_n_k = transform_tensor_descriptor(
541 b_grid_desc_nraw_kraw,
545
546 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
547 b_grid_desc_n_k,
552
553 return b_grid_desc_bk0_n_bk1;
554 }
555 else
556 {
557 // not pad N or K
558 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
559 b_grid_desc_nraw_kraw,
560 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
564
565 const auto b_grid_desc_permuted = transform_tensor_descriptor(
566 b_grid_desc_bk0_n_bk1,
572
573 const auto b_grid_desc = transform_tensor_descriptor(
574 b_grid_desc_permuted,
581
582 return b_grid_desc;
583 }
584 }
585
586 template <typename ABlockDesc_AK0_M_AK1>
587 __host__ __device__ static constexpr auto
588 MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
589 {
590 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
591
593 ABlockDesc_AK0_M_AK1{});
594 }
595
596 template <typename BBlockDesc_BK0_N_BK1>
597 __host__ __device__ static constexpr auto
598 MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
599 {
600 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
601
603 BBlockDesc_BK0_N_BK1{});
604 }
605
606 template <typename ELayout>
607 __host__ __device__ static auto MakeCGridDescriptor_M_N(
608 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
609 {
610 const auto c_grid_desc_mraw_nraw = [&]() {
612 {
613 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
614 }
616 {
617 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
618 }
619 }();
620
621 // pad M and N
622 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
624 make_right_pad_transform(N, NPad - N)),
627 }
628
629 template <typename DLayout>
630 __host__ __device__ static auto
632 {
633 const auto c_grid_desc_mraw_nraw = [&]() {
635 {
636 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
637 }
639 {
640 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
641 }
642 }();
643
644 // pad M and N
645 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
647 make_right_pad_transform(N, NPad - N)),
650 }
651
652 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
653 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
654 {
655 return generate_tuple(
656 [&](auto i) {
657 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
658 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
659 },
661 }
662
663 template <typename DsGridDesc>
665 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
666 {
667 return generate_tuple(
668 [&](auto i) {
670 ds_grid_desc_m_n[i], MBlock, NBlock);
671 },
673 }
674
// Host-side description of one GEMM problem: the raw sizes and strides supplied by the caller
// plus K-related quantities derived from them.
// NOTE(review): most member declarations and a few initializers are on lines not visible in
// this listing; only StrideDs is declared in view.
675 struct Problem
676 {
// Stores every argument verbatim, then derives the split-K quantities from K_ and KBatch_.
677 __host__ Problem(index_t NumTokens_,
678 index_t TopK_,
679 index_t M_,
680 index_t N_,
681 index_t K_,
682 index_t StrideA_,
683 index_t StrideScaleA_,
684 index_t StrideB_,
685 index_t StrideScaleB_,
686 std::array<index_t, NumDTensor> StrideDs_,
687 index_t StrideC_,
688 index_t KBatch_)
689 : NumTokens{NumTokens_},
690 TopK{TopK_},
691 M{M_},
692 N{N_},
693 K{K_},
694 StrideA{StrideA_},
695 StrideScaleA{StrideScaleA_},
696 StrideB{StrideB_},
697 StrideScaleB{StrideScaleB_},
698 StrideDs{StrideDs_},
699 StrideC{StrideC_},
700 KBatch{KBatch_},
// Derived values: per-batch read length, padded K, and the AK0 / BK0 extents.
703 KRead{CalculateKRead(K_, KBatch_)},
704 KPadded{CalculateKPadded(K_, KBatch_)},
705 AK0{CalculateAK0Padded(K_, KBatch_)},
706 BK0{CalculateBK0Padded(K_, KBatch_)},
709 {
710 }
711
// Writes a one-line summary of the problem to std::cout. StrideDs and KBatch are not printed.
712 __host__ void Print() const
713 {
714 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
715 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
716 << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
717 << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
718 << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
719 << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
720 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
721 << ", " << "NBlock: " << NBlock << "}" << std::endl;
722 }
723
// One stride per D tensor.
733 std::array<index_t, NumDTensor> StrideDs;
744 };
745
746 // Argument
748 {
749 __host__ Argument(const index_t* p_sorted_token_ids_,
750 const index_t* p_sorted_expert_ids_,
751 const index_t* p_max_token_id_,
752 const ADataType* p_a_grid_,
753 const AScaleDataType* p_a_scale_grid_,
754 const BDataType* p_b_grid_,
755 const BScaleDataType* p_b_scale_grid_,
756 std::array<const void*, NumDTensor> p_ds_grid_,
757 CDataType* p_c_grid_,
758 index_t NumTokens_,
759 index_t TopK_,
760 index_t M_,
761 index_t N_,
762 index_t K_,
763 index_t StrideA_,
764 index_t StrideScaleA_,
765 index_t StrideB_,
766 index_t StrideScaleB_,
767 std::array<index_t, NumDTensor> StrideDs_,
768 index_t StrideC_,
769 index_t k_batch_,
770 AElementwiseOperation a_element_op_,
771 BElementwiseOperation b_element_op_,
772 CElementwiseOperation c_element_op_)
773 : Problem{NumTokens_,
774 TopK_,
775 M_,
776 N_,
777 K_ / APackedSize,
778 StrideA_ / APackedSize,
779 StrideScaleA_,
780 StrideB_ / BPackedSize,
781 StrideScaleB_,
782 StrideDs_,
783 StrideC_,
784 k_batch_},
785 p_sorted_token_ids{p_sorted_token_ids_},
786 p_sorted_expert_ids{p_sorted_expert_ids_},
787 p_max_token_id{p_max_token_id_},
788 p_a_grid{p_a_grid_},
789 p_a_scale_grid{p_a_scale_grid_},
790 p_b_grid{p_b_grid_},
791 p_b_scale_grid{p_b_scale_grid_},
792 p_ds_grid{},
793 p_c_grid{p_c_grid_},
794 a_element_op{a_element_op_},
795 b_element_op{b_element_op_},
796 c_element_op{c_element_op_}
797 {
798
799 // populate pointer, desc for Ds
800 static_for<0, NumDTensor, 1>{}([&](auto i) {
801 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
802
803 // D pointer
804 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
805 });
806 }
807
811 const ADataType* p_a_grid;
812 const AScaleDataType* p_a_scale_grid;
813 const BDataType* p_b_grid;
814 const BScaleDataType* p_b_scale_grid;
816 CDataType* p_c_grid;
817
818 const AElementwiseOperation a_element_op;
819 const BElementwiseOperation b_element_op;
820 const CElementwiseOperation c_element_op;
821 };
822
824 {
825 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
826 {
828 {
829 a_k_split_offset = k_id * karg.KRead;
830 }
832 {
833 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
834 }
835
837 {
838 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
839 }
841 {
842 // KPack * NLane * KLane * K0 * N0
843 b_k_split_offset = k_id * karg.KRead * NPerXdl;
844 }
845
846 // Calculate A scale offset
847 a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
848 MPerXdl / scale_pack_size_a;
849
850 // Calculate B scale offset
851 b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
852 NPerXdl / scale_pack_size_b;
853
854 if(k_id < karg.KBatch - 1)
855 {
856 karg.K = karg.KRead;
857 }
858 else
859 {
860 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
861 }
862 }
863
868 };
869
870 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
871 {
872 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
873 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
874
875 // A matrix in LDS memory, dst of blockwise copy
876 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
877 {
878 // contiguous in LDS
882 }
883 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
884 // in some cases.
886 {
887 constexpr auto a_lds_block_desc =
890
891 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
892 a_lds_block_desc,
898
899 return a_lds_block_desc_permuted;
900 }
901 else // ColumnMajor A
902 {
903 // kfold and mpair dimension is not always required.
904 // more dimension in merge_transform increase the difficulty of generating immarg offset
905 // for compiler.
906 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
907 constexpr auto M1 = MPerBlock / M0;
908
909 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
910 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
911 constexpr auto KThreadRead = WaveSize / MPerXdl;
912 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
913
914 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
915 ? 1
916 : 128 / (AK1Number * M0 * sizeof(ADataType));
917 constexpr auto KThreadReadPerm =
918 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
919 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
920 : KThreadRead;
921
922 // 1<=mpair<=n0
923 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
924 ? 1
925 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
926 ? M0
927 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
928
929 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
933 Number<kfold * M0 / mpair>{},
935 AK1Number));
936
937 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
938 a_lds_block_desc,
943 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
950
951 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
952 a_lds_block_desc_permuted,
961 Sequence<1>{},
962 Sequence<2>{},
963 Sequence<3>{},
964 Sequence<4>{},
965 Sequence<5>{}),
967 Sequence<2>{},
970 Sequence<6>{},
971 Sequence<7>{}));
972
973 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
974 a_lds_block_desc_unmerged,
977 Number<KThreadWrite / kfold / KThreadReadPerm>{},
985
986 return a_lds_block_desc_ak0_m_ak1;
987 }
988 }
989
990 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
991 {
992 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
994 I1,
998 }
999
1001 {
1002 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1003
1004 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1006 make_tuple(I1,
1008 I1,
1010
1011 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1012 }
1013
1016 BlkGemmPipelineVer,
1017 BlkGemmPipeSched,
1018 BlockSize,
1019 ScaleBlockSize,
1020 ADataType,
1021 AScaleDataType,
1022 BDataType,
1023 BScaleDataType,
1024 ComputeTypeA,
1025 AccDataType,
1032 ABlockTransferSrcScalarPerVector,
1033 BBlockTransferSrcScalarPerVector,
1034 MPerBlock,
1035 NPerBlock,
1036 KPerBlock,
1037 MPerXdl,
1038 NPerXdl,
1039 MXdlPerWave,
1040 NXdlPerWave,
1041 KPack,
1042 IsInputGemm>())>;
1043
// Returns the number of LDS bytes one buffer needs: the larger of the aligned A-tile footprint
// and the C-shuffle tile footprint. The two are combined with max rather than summed.
// Only the A tile contributes on the input side; no B term appears in this function.
1044 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1045 {
1046 // LDS allocation for A and B: be careful of alignment
1047 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1048 // lds max alignment
1049 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1050
// A-tile element count rounded up to a multiple of lcm(AK1, BK1).
1051 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1052 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1053
1054 // LDS allocation for C shuffle in LDS
// NOTE(review): the initializer of this descriptor is on a line not visible in this listing.
1055 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1057
1058 constexpr auto c_block_size =
1059 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1060
// Element counts are converted to bytes with the respective element sizes before comparing.
1061 return math::max(a_block_space_size_aligned * sizeof(ADataType),
1062 c_block_size * sizeof(CShuffleDataType));
1063 }
1064
1066
1067 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// Host-side validation of a kernel argument against the compile-time tile configuration.
// Returns true when every visible check passes and false on the first failing one. When the
// CK_LOGGING environment switch is enabled, a failing check also prints a diagnostic to
// std::cout; the return value never depends on whether logging is enabled.
//
// Checks, in order: M / N / K divisibility by the block tile (skipped for the padding
// specializations named in the if constexpr conditions), the split-K remainder, and the
// vector-width divisibility for the A, B and C transfers.
// NOTE(review): several condition lines are not visible in this listing; the comments below
// describe only what the visible lines establish.
1068 __host__ static constexpr bool CheckValidity(const Argument& karg)
1069 {
// Compile-time sanity: the block tile must be a whole number of per-wave XDL tiles.
1070 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1071 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1072 "Invalid tuning param!");
1073
1074 static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1075 "KPerBlock should be multiple of ScaleBlockSize");
1076
// M must be a multiple of MPerBlock unless GemmSpec is one of the listed padding modes.
1077 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1082 {
1083 if(!(karg.M % MPerBlock == 0))
1084 {
1085 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1086 {
1087 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1088 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1089 << std::endl;
1090 }
1091 return false;
1092 }
1093 }
1094
// N must be a multiple of NPerBlock unless GemmSpec is one of the listed padding modes.
1095 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1100 {
1101 if(!(karg.N % NPerBlock == 0))
1102 {
1103 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1104 {
1105 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1106 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1107 << std::endl;
1108 }
1109 return false;
1110 }
1111 }
1112
// Without K padding, K must be a multiple of KBatch * KPerBlock.
1113 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1117 {
1118 auto K_t = karg.KBatch * KPerBlock;
1119 if(!(karg.K % K_t == 0))
1120 {
1121 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1122 {
1123 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1124 << karg.K << " " << __FILE__ << ":" << __LINE__
1125 << ", in function: " << __func__ << std::endl;
1126 }
1127 return false;
1128 }
1129 }
1130 else
1131 {
// With K padding: reject when the first KBatch - 1 batches already cover all of K, which would
// leave the last batch with no K to process. This branch rejects without logging.
1132 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1133 auto K_t = karg.KBatch * KReadVec;
1134 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1135 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1136 {
1137 return false;
1138 }
1139 }
1140
// A-side vector width: the first branch checks K, the else branch checks M.
// (The selecting condition is on a line not visible in this listing.)
1142 {
1143 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1144 {
1145 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1146 {
1147 std::cout << "Arg K (" << karg.K
1148 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1149 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1150 << __LINE__ << ", in function: " << __func__ << std::endl;
1151 }
1152 return false;
1153 }
1154 }
1155 else
1156 {
1157 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1158 {
1159 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1160 {
1161 std::cout << "Arg M (" << karg.M
1162 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1163 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1164 << __LINE__ << ", in function: " << __func__ << std::endl;
1165 }
1166 return false;
1167 }
1168 }
1169
// B-side vector width: the first branch checks N, the else branch checks K.
// (The selecting condition is on a line not visible in this listing.)
1171 {
1172 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1173 {
1174 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1175 {
1176 std::cout << "Arg N (" << karg.N
1177 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1178 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1179 << __LINE__ << ", in function: " << __func__ << std::endl;
1180 }
1181 return false;
1182 }
1183 }
1184 else
1185 {
1186 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1187 {
1188 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189 {
1190 std::cout << "Arg K (" << karg.K
1191 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1192 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1193 << __LINE__ << ", in function: " << __func__ << std::endl;
1194 }
1195 return false;
1196 }
1197 }
1198
// C-side vector width: per the messages, the first branch concerns N and the else branch M.
// (The selecting condition and both modulo tests are on lines not visible in this listing.)
1200 {
1202 {
1203 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1204 {
1205 std::cout << "Arg N (" << karg.N
1206 << ") value is not a multiple of "
1207 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1209 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1210 << std::endl;
1211 }
1212 return false;
1213 }
1214 }
1215 else
1216 {
1218 {
1219 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1220 {
1221 std::cout << "Arg M (" << karg.M
1222 << ") value is not a multiple of "
1223 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1225 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1226 << std::endl;
1227
1228 }
// Returned outside the CK_LOGGING guard so the rejection does not depend on whether logging
// is enabled, matching every other check in this function.
1229 return false;
1230 }
1231 }
1232
1233 // check gridwise gemm pipeline
1234#if 0
1235 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1236
1237 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1238 {
1239 return false;
1240 }
1241#endif
1242 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1243 return true;
1244 }
1245
1246 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1247 {
1248 const index_t num_loop = K / KPerBlock;
1249 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1250 }
1251
1252 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1253 {
1254 const index_t num_loop = K / KPerBlock;
1255
1256 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1257 }
1258
1259 template <typename CGridDesc>
1260 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1261 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1262 {
1263 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1264 c_grid_desc_m_n,
1269
1270 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1271 }
1272
1273 // return block_id to C matrix tile idx (m0, n0) mapping
1274 // if arch = gfx942
1275 // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1276 // NPerBlock>;
1277
// Device entry point of the fused MoE MX-GEMM that uses a single LDS allocation (`p_shared`)
// for both the A tile and, later, the C-shuffle staging buffer.
//
// Visible flow:
//   1. Build the A, preshuffled-B, C, A-scale and B-scale grid descriptors from `problem`.
//   2. Read max_token_id = p_max_token_id[0]; return early when this block's expert tile starts
//      at or beyond it.
//   3. Read the expert id from p_sorted_expert_ids and derive (block_n_id, block_m_id), either
//      directly from blockIdx or through the NSwizzle remapping.
//   4. Decode per-row gather offsets from p_sorted_token_ids (low 24 bits: token id,
//      bits 24 and up: top-k slot) and build the A / B / scale copy objects.
//   5. Run the blockwise GEMM pipeline. When IsInputGemm is set a second ("up") B operand and
//      B-scale stream are passed and a second accumulator (c_thread_buf_up) is produced.
//   6. Combine the accumulators into c_thread_buf_fp32 (optionally multiplied by routed
//      weights), then shuffle through LDS and scatter rows to p_c_grid.
//
// a_element_op and b_element_op are explicitly ignored; c_element_op is forwarded to the
// C/Ds block copy.
//
// NOTE(review): this listing lost a number of original lines (visible as jumps in the embedded
// line numbers, e.g. 1322 -> 1324, 1386 -> 1388, 1986 -> 1988). They are deliberately NOT
// reconstructed here; consult the full source before compiling or editing further.
1278 template <bool HasMainKBlockLoop,
1279 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1280 TailNumber TailNum = TailNumber::Odd>
1281 __device__ static void Run(const index_t* p_sorted_token_ids,
1282 const index_t* p_sorted_expert_ids,
1283 const index_t* p_max_token_id,
1284 const ADataType* p_a_grid,
1285 const AScaleDataType* p_a_scale_grid,
1286 const BDataType* p_b_grid,
1287 const BScaleDataType* p_b_scale_grid,
1288 DsGridPointer& p_ds_grid,
1289 CDataType* p_c_grid,
1290 void* p_shared,
1291 const Problem& problem,
1292 AElementwiseOperation a_element_op,
1293 BElementwiseOperation b_element_op,
1294 CElementwiseOperation c_element_op)
1295 {
// The A/B element-wise operations are not applied anywhere in this function.
1296 ignore = a_element_op;
1297 ignore = b_element_op;
1298 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1299 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
// A is gathered per token: its row count is NumTokens for the input GEMM and NumTokens * TopK
// otherwise. C uses the opposite choice (see c_grid_desc_m_n below).
1300 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1301 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1302 problem.MPadded,
1303 problem.K,
1304 problem.KPadded,
1305 problem.StrideA,
1306 problem.AK0);
1307 const auto b_grid_desc_bpreshuffled =
1308 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1309 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1310 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1311 problem.MPadded,
1312 problem.N,
1313 problem.NPadded,
1314 problem.StrideC);
1315
1316 // We pad the M unconditionally for Scale
1317 const auto Padded_Scale_M =
1318 math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
// NOTE(review): listing lines 1323 and 1327 (innermost length / middle stride) are missing.
1319 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1320 make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
1321 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1322 (KXdlPack * 64 / MPerXdl),
1324 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
1325 (ScaleBlockSize / APackedSize)) *
1326 MPerXdl * MXdlPack / scale_pack_size_a,
1328 1));
1329
// NOTE(review): listing lines 1334 and 1338 are missing from the B-scale descriptor as well.
1330 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1331 make_tuple(problem.N / (NXdlPack * NPerXdl),
1332 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1333 (KXdlPack * 64 / NPerXdl),
1335 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
1336 (ScaleBlockSize / BPackedSize)) *
1337 NPerXdl * NXdlPack / scale_pack_size_b,
1339 1));
1340
1341 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1343 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1344
// p_max_token_id[0] bounds the valid positions in p_sorted_token_ids; blocks whose first row
// lies at or past it have nothing to do.
1345 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1346 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1347 if(expert_block_id * MPerBlock >= max_token_id)
1348 return;
1349 const index_t expert_id =
1350 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
// Returns {block_n_id, block_m_id}. With NSwizzle, p_max_token_id[1 + e] and [2 + e] are read
// as M-block prefix counts for expert e and e + 1; blockIdx.x is re-mapped inside this
// expert's range in groups of 8 N-blocks.
1351 const auto block_mn = [&]() -> std::pair<int, int> {
1352 if constexpr(NSwizzle)
1353 {
1354 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1355 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1356 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1357 const index_t expert_swizzle =
1358 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1359 const index_t bid_new = blockIdx.x - prefix_block;
1360 const index_t nid = __builtin_amdgcn_readfirstlane(
1361 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1362 const index_t mid =
1363 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1364 return {nid, mid};
1365 }
1366 else
1367 {
1368 return {blockIdx.x, blockIdx.y};
1369 }
1370 }();
1371
1372 const index_t block_n_id = block_mn.first;
1373 const index_t block_m_id = block_mn.second;
// Token id (low 24 bits) of the first sorted row handled by this M block.
1374 const index_t token0 =
1375 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1376
1377 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1378 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1379 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1380 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1381 constexpr auto AKThreads = AK0Threads * AK1Threads;
1382 constexpr auto AMRepeats = MPerBlock / AMThreads;
1383 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
1384
// NOTE(review): token_pos depends on threadIdx.x, so this return is evaluated per thread --
// confirm that max_token_id is always block-aligned (or that no later block-wide barrier
// relies on every thread arriving).
1385 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1386 return;
// NOTE(review): listing line 1387 (presumably the declaration of gather_offsets) is missing.
// Each entry is the element offset of one gathered A row: token id for the input GEMM,
// token id * TopK + top-k slot (bits 24 and up of the fused id) otherwise, times K.
1388 static_for<0, AMRepeats, 1>{}([&](auto m0) {
1389 const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
1390 index_t token_offset = fused_token & 0xffffff;
1391 if constexpr(!IsInputGemm)
1392 {
1393 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1394 }
1395 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1396 });
1397
// Per-expert strides; doubled for the input GEMM, where each expert stores two weight halves
// (the second half is addressed below as the "up" operand).
1398 const index_t expert_stride =
1399 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1400 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1401 problem.N * (IsInputGemm ? 2 : 1) *
1402 math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1403
1404 // N0, K0, Blocksize*KPack
1405 const index_t n_block_data_idx_on_grid =
1406 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
1407
1408 // Grid buffer creation
1409 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1410 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1411 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1412 p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1413
1414 // A, B scale buffer
1415 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1416 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
// expert_scale_stride is divided by sizeof(BScaleDataType) before being used as an element
// offset into p_b_scale_grid.
1417 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1418 p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1419 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1420
1421 // A matrix in LDS memory, dst of blockwise copy
1422 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1423
1424 // B matrix in LDS memory, dst of blockwise copy
1425 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1426
1427 // A matrix blockwise direct to LDS copy
// NOTE(review): listing lines 1428-1430 (declaration of a_blockwise_copy and the first
// template arguments) are missing.
1431 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1432 ABlockTransferThreadClusterArrangeOrder,
1433 ADataType,
1434 ADataType,
1435 decltype(a_grid_desc_ak0_m_ak1),
1436 decltype(a_block_desc_ak0_m_ak1),
1437 ABlockTransferSrcAccessOrder,
1438 ABlockTransferSrcVectorDim,
1439 2,
1440 ABlockTransferSrcScalarPerVector,
1441 IndexType,
1442 1>(a_grid_desc_ak0_m_ak1,
1443 make_multi_index(0, 0, 0),
1444 a_block_desc_ak0_m_ak1,
1445 make_multi_index(0, 0, 0),
1446 gather_offsets);
1447
1448 // Thread-wise copy
1449 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
// NOTE(review): listing lines 1450 and 1452 (declarations of b_block_buf_ping / _pong) are
// missing.
1451 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1453 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1454 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1455
// NOTE(review): listing lines 1457, 1463-1466 and 1473 of this copy object are missing.
1456 auto b_blockwise_copy =
1458 BDataType,
1459 decltype(b_grid_desc_bpreshuffled),
1460 decltype(b_block_desc_bk0_n_bk1),
1461 Sequence<Number<NXdlPerWave / NXdlPack>{},
1462 I1,
1467 4,
1468 BBlockTransferSrcScalarPerVector,
1469 BThreadTransferSrcResetCoordinateAfterRun,
1470 true>(
1471 b_grid_desc_bpreshuffled,
1472 make_multi_index(n_block_data_idx_on_grid,
1474 0,
1475 0,
1476 KPack * (get_thread_local_1d_id() % WarpSize)));
1477
1478 // LDS allocation for A and B: be careful of alignment
1479 // Cast after lds
// NOTE(review): listing line 1480 (declaration of a_block_buf) is missing.
1481 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1482
1483 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1484 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
1485
1486 // Blockwise GEMM pipeline
1487 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1488 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1489 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
// Second accumulator of the same type; only passed to the pipeline when IsInputGemm is set.
1490 decltype(c_thread_buf) c_thread_buf_up;
1491
// NOTE(review): listing line 1492 (the buffer template name) is missing. c_thread_buf_fp32
// receives the combined float results that the epilogue writes out.
1493 float,
1494 c_thread_buf.num_of_v_,
1495 c_thread_buf.s_per_v,
1496 true>
1497 c_thread_buf_fp32;
1498
1499 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1500 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1501 KPerBlock);
1502
1503 // a and b scale processing
1504 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1505 const auto waveId_m = wave_idx[I0];
1506 const auto waveId_n = wave_idx[I1];
1507
1508 auto thread_offset_shuffled =
1509 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1510
1511 auto a_thread_offset_m = waveId_m;
1512
1513 // get each thread's offset in the scale tensor
// Same guard as above, but using the block's first row instead of the per-thread token_pos.
1514 const index_t token_scale_pos = block_m_id * MPerBlock;
1515 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1516 return;
1517
1518 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1519 AScaleDataType,
1520 AScaleDataType,
1521 decltype(a_scale_grid_desc_am_ak),
1522 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1523 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1524 Sequence<0, 1, 2>, // DimAccessOrder
1525 2, // SrcVectorDim
1526 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1527 1, // SrcScalarStrideInVector
1528 true>(a_scale_grid_desc_am_ak,
1529 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1530 0,
1531 thread_offset_shuffled / scale_pack_size_a));
1532
1533 // B scale load
1534 auto b_thread_offset_n = waveId_n;
1535
1536 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1537 BScaleDataType,
1538 BScaleDataType,
1539 decltype(b_scale_grid_desc_bn_ak),
1540 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1541 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1542 Sequence<0, 1, 2>, // DimAccessOrder
1543 2, // SrcVectorDim
1544 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1545 1, // SrcScalarStrideInVector
1546 true>(b_scale_grid_desc_bn_ak,
1547 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1548 0,
1549 thread_offset_shuffled / scale_pack_size_b));
1550
1551 if constexpr(IsInputGemm)
1552 {
// "Up" operand: same descriptors as the primary B operand, with the base pointers advanced
// by half of the per-expert stride.
1553 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1554 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1555 p_b_grid_up + expert_id * expert_stride,
1556 b_grid_desc_bpreshuffled.GetElementSpaceSize());
// NOTE(review): listing lines 1558, 1564-1567 and 1574 of this copy object are missing.
1557 auto b_blockwise_copy_up =
1559 BDataType,
1560 decltype(b_grid_desc_bpreshuffled),
1561 decltype(b_block_desc_bk0_n_bk1),
1562 Sequence<Number<NXdlPerWave / NXdlPack>{},
1563 I1,
1568 4,
1569 BBlockTransferSrcScalarPerVector,
1570 BThreadTransferSrcResetCoordinateAfterRun,
1571 true>(
1572 b_grid_desc_bpreshuffled,
1573 make_multi_index(n_block_data_idx_on_grid,
1575 0,
1576 0,
1577 KPack * (get_thread_local_1d_id() % WarpSize)));
1578 const BScaleDataType* p_b_scale_grid_up =
1579 p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1580 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1581 p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1582 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1583
1584 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1585 BScaleDataType,
1586 BScaleDataType,
1587 decltype(b_scale_grid_desc_bn_ak),
1588 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1589 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1590 Sequence<0, 1, 2>, // DimAccessOrder
1591 2, // SrcVectorDim
// Vector width must follow the B-side pack factor (NXdlPack), matching SliceLengths above and
// b_scale_thread_copy; the previous MXdlPack was a copy-paste from the A-scale copy.
1592 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1593 1, // SrcScalarStrideInVector
1594 true>(
1595 b_scale_grid_desc_bn_ak,
1596 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1597 0,
1598 thread_offset_shuffled / scale_pack_size_b));
1599
1600 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1601 // A
1602 a_grid_desc_ak0_m_ak1,
1603 a_block_desc_ak0_m_ak1,
1604 a_blockwise_copy,
1605 a_grid_buf,
1606 a_block_buf,
1607 a_block_slice_copy_step,
1608 // Gate and Up
1609 b_grid_desc_bpreshuffled,
1610 b_block_desc_bk0_n_bk1,
1611 b_blockwise_copy,
1612 b_blockwise_copy_up,
1613 b_grid_buf,
1614 b_grid_buf_up,
1615 b_block_bufs,
1616 b_block_slice_copy_step,
1617 // C
1618 c_thread_buf,
1619 c_thread_buf_up,
1620 // A scale
1621 a_scale_grid_desc_am_ak,
1622 a_scale_thread_copy,
1623 a_scale_grid_buf,
1624 // B scale
1625 b_scale_grid_desc_bn_ak,
1626 b_scale_thread_copy,
1627 b_scale_thread_copy_up,
1628 b_scale_grid_buf,
1629 b_scale_grid_buf_up,
1630 num_k_block_main_loop);
1631 }
1632 else
1633 {
1634 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1635 a_grid_desc_ak0_m_ak1, // A
1636 a_block_desc_ak0_m_ak1,
1637 a_blockwise_copy,
1638 a_grid_buf,
1639 a_block_buf,
1640 a_block_slice_copy_step,
1641 b_grid_desc_bpreshuffled, // B
1642 b_block_desc_bk0_n_bk1,
1643 b_blockwise_copy,
1644 b_grid_buf,
1645 b_block_bufs,
1646 b_block_slice_copy_step,
1647 c_thread_buf, // C
1648 a_scale_grid_desc_am_ak, // A scale
1649 a_scale_thread_copy,
1650 a_scale_grid_buf,
1651 b_scale_grid_desc_bn_ak, // B scale
1652 b_scale_thread_copy,
1653 b_scale_grid_buf,
1654 num_k_block_main_loop);
1655 }
1656
1657 // shuffle C and write out
1658 {
1659 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1660 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1661 "wrong!");
1662 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1663 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1664 "wrong!");
1665
1666 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1667
1668 // TODO: hacky, fix it!
1669 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1670 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1671
1672 // TODO: hacky, fix it!
1673 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1674 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1675 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1676
1677 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1678 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1679 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1680 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1681 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1682 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1683 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1684 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1685 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1686 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1687
1688 // mul scales
1689
1690 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
// M5 == 4 matches the 4-wide topk_weights vector read below.
1691 static_assert(M5 == 4);
1692 const index_t m1 = get_warp_local_1d_id() / NWave;
1693 const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1694
1695 vector_type<float, 4> topk_weights; // for gemm2 only
1696 static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1697 static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1698 static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1699 static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1700 static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
// Sorted-row position of the first of the M5 consecutive rows this thread holds for (m0, m3).
1701 const index_t m_pos = block_m_id * MPerBlock +
1702 m0 * M2 * M1 * M3 * M4 * M5 +
1703 m1 * M2 * M3 * M4 * M5 +
1704 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1705 if constexpr(MulRoutedWeight)
1706 {
// NOTE(review): listing line 1708 (the load expression) is missing; the weights are read
// from p_ds_grid[I2] at offset m_pos.
1707 topk_weights =
1709 p_ds_grid[I2] + m_pos);
1710 }
1711 static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1712 constexpr index_t c_offset =
1713 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1714 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1715 constexpr auto cidx = Number<c_offset>{};
1716
1717 if constexpr(IsInputGemm) // gu fusion
1718 {
1719 if constexpr(ActivationOperation ==
1720 Activation::silu_and_mul)
1721 {
1722 float gate = c_thread_buf[cidx];
1723 float up = c_thread_buf_up[cidx];
1724 if constexpr(MulRoutedWeight)
1725 {
1726 gate = gate * topk_weights.AsType<float>()[m5];
1727 up = up * topk_weights.AsType<float>()[m5];
1728 }
// NOTE(review): listing line 1729 is missing; presumably the SiLU activation is applied to
// `gate` there -- confirm against the full source.
1730 c_thread_buf_fp32(cidx) = gate * up;
1731 }
// `constexpr` added for consistency with the silu branch: ActivationOperation is compared
// against an enumerator exactly as in the `if constexpr` above.
1732 else if constexpr(ActivationOperation == Activation::gelu_and_mul)
1733 {
1734 float gate = c_thread_buf[cidx];
1735 float up = c_thread_buf_up[cidx];
1736 if constexpr(MulRoutedWeight)
1737 {
1738 gate = gate * topk_weights.AsType<float>()[m5];
1739 up = up * topk_weights.AsType<float>()[m5];
1740 }
// NOTE(review): listing line 1741 is missing; presumably the GELU activation is applied to
// `gate` there -- confirm against the full source.
1742 c_thread_buf_fp32(cidx) = gate * up;
1743 }
1744 }
1745 else
1746 {
1747 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1748 if constexpr(MulRoutedWeight)
1749 {
1750 c_thread_buf_fp32(cidx) =
1751 topk_weights.AsType<float>()[m5] *
1752 c_thread_buf_fp32[cidx];
1753 }
1754 }
1755 });
1756 });
1757 });
1758 });
1759 });
1760 });
1761
// NOTE(review): listing line 1763 (the initializer of this descriptor) is missing.
1762 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1764
// The C-shuffle staging buffer starts at the same p_shared address that backed the A tile.
1765 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1766 static_cast<CShuffleDataType*>(p_shared),
1767 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1768
// NOTE(review): listing lines 1772-1773, 1781-1783, 1788-1790 and 1792 of this transform are
// missing.
1769 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1770 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1771 make_tuple(
1774 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
1775 // shuffle
1776 M1, // M1 = MWave
1777 M2, // M2 * M3 * M4 = MPerXdl
1778 M3,
1779 M4,
1780 M5)),
1784 // per shuffle
1785 N1, // N1 = NWave
1786 N2, // N2 = NXdlPack
1787 N3))), // N3 = NPerXdl
1791 Sequence<>{},
1793
1794 // calculate origin of thread output tensor on global memory
1795 // blockwise GEMM c matrix starting index
1796 const auto c_thread_mtx_on_block =
1797 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1798
1799 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1800 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1801
// NOTE(review): listing lines 1803 and 1805-1806 of this adaptor are missing.
1802 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1804 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1807
1808 const auto m_thread_data_on_block_idx =
1809 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1810 make_multi_index(m_thread_data_on_block));
1811
// NOTE(review): listing lines 1813 and 1815-1816 of this adaptor are missing.
1812 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1814 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1817
1818 const auto n_thread_data_on_block_idx =
1819 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1820 make_multi_index(n_thread_data_on_block));
1821
1822 // shuffle: threadwise copy C from VGPR to LDS
// NOTE(review): listing lines 1828, 1839, 1842, 1845 and 1855 of this copy object are missing.
1823 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1824 AccDataType,
1825 CShuffleDataType,
1826 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1827 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1829 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1830 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1831 I1,
1832 I1,
1833 M2,
1834 N2,
1835 M3,
1836 I1,
1837 M5,
1838 I1>,
1840 9,
1841 1,
1843 1,
1844 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1846 0,
1847 m_thread_data_on_block_idx[I1],
1848 n_thread_data_on_block_idx[I1],
1849 m_thread_data_on_block_idx[I2],
1850 n_thread_data_on_block_idx[I2],
1851 m_thread_data_on_block_idx[I3],
1852 m_thread_data_on_block_idx[I4],
1853 m_thread_data_on_block_idx[I5],
1854 n_thread_data_on_block_idx[I3]),
1856
1857 using EDataType = CDataType;
1858
1859 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1860 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1861
1862 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1864 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1865
// NOTE(review): listing lines 1868 and 1871 of this generate_tuple call are missing.
1866 const auto ds_grid_buf = generate_tuple(
1867 [&](auto i) {
1869 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1870 },
1872
1873 // tuple of reference to C/Ds tensor descriptors
1874 const auto c_ds_desc_refs = concat_tuple_of_reference(
1875 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1876 generate_tie([&](auto i) -> const auto& // return type should be reference
1877 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1879
1880 // tuple of reference to C/Ds tensor buffers
1881 const auto c_ds_buf_refs = concat_tuple_of_reference(
1882 tie(c_shuffle_block_buf),
1883 generate_tie([&](auto i) -> const auto& // return type should be reference
1884 { return ds_grid_buf[i]; },
1886
1887 // tuple of starting index of C/Ds blockwise copy
1888 const auto idx_c_ds_block_begin =
1891 [&](auto) {
1892 return make_multi_index(block_m_id, 0, block_n_id, 0);
1893 // return make_multi_index(block_work_idx[I0], 0,
1894 // block_work_idx[I1], 0);
1895 },
1897
1898 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1899 c_grid_desc_mblock_mperblock_nblock_nperblock;
1900
1901 using CDEBlockTransferCluster =
1902 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1903 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1904 constexpr index_t scatter_weight_idx = 3; // hack fix felix
// NOTE(review): listing lines 1906, 1908 and 1926-1929 of this copy object are missing.
// The destination origin uses M-block index 0: destination rows are supplied per access via
// scatter_offsets in the loop below.
1905 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1907 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1909 decltype(c_ds_desc_refs),
1910 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1911 CElementwiseOperation,
1912 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1913 // Sequence support
1914 // arbitrary type
1915 Sequence<1,
1916 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1917 1,
1918 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1919 CDEBlockTransferCluster,
1920 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1921 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1922 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1923 3, // index_t SrcVectorDim,
1924 3, // index_t DstVectorDim,
1925 CDEShuffleBlockTransferScalarPerVectors,
1930 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1931 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1932 IndexType,
1933 1, // ScatterDim
1934 true, // OutputScatter: false, only use scatter weights
1935 scatter_weight_idx // ScatterWeightIdx: ascale
1936 >{c_ds_desc_refs,
1937 idx_c_ds_block_begin,
1938 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1939 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1940 c_element_op};
1941
// NOTE(review): listing line 1942 (declaration of c_grid_buf) is missing.
1943 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1944
1945 constexpr auto sfc_c_vgpr =
1946 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
1947 NXdlPerWave / NXdlPack,
1948 1,
1949 1,
1950 MXdlPack,
1951 NXdlPack,
1952 M2,
1953 1,
1954 M4,
1955 1>,
1957 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1958 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1959 1,
1960 1,
1961 MXdlPack,
1962 NXdlPack,
1963 M2,
1964 1,
1965 M4,
1966 1>>{};
1967
1968 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1969
1970 // space filling curve for shuffled blockwise C/D/E
1971 constexpr auto sfc_cde_block =
1974 Sequence<1,
1975 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1976 1,
1977 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1978
// The VGPR-side and block-side curves must step in lock-step, one shuffle per access.
1979 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1980 constexpr auto EMThreads =
1981 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1982 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1983 constexpr auto ENThreads =
1984 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1985 static_for<0, num_access, 1>{}([&](auto access_id) {
1986 // make sure it's safe to write to LDS
// NOTE(review): listing line 1987 (the statement this comment refers to) is missing.
1988
1989 auto dstidx = sfc_cde_block.GetIndex(access_id);
1990 const index_t c_token_pos =
1991 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
// Output row offsets: the mirror image of the A gather above -- the top-k slot is folded in
// for the input GEMM only, and the row is scaled by N instead of K.
1992 static_for<0, EMRepeats, 1>{}([&](auto m0) {
1993 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1994 IndexType token_offset = fused_token & 0xffffff;
1995 if constexpr(IsInputGemm)
1996 {
1997 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1998 }
1999 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2000 });
2001
2003
2004 // each thread write its data from VGPR to LDS
2005 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2006 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2007 c_thread_buf_fp32,
2008 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2009 c_shuffle_block_buf);
2010
2011 // make sure it's safe to read from LDS
// NOTE(review): listing line 2012 (the statement this comment refers to) is missing.
2013
2014 // each block copy its data from LDS to global
2015 cde_block_copy_lds_and_global.Run(
2016 c_ds_desc_refs,
2017 c_ds_buf_refs,
2018 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2019 tie(c_grid_buf),
2020 scatter_offsets);
2021
2022 if constexpr(access_id < num_access - 1)
2023 {
2024 constexpr auto cde_lds_and_global_step =
2025 sfc_cde_block.GetForwardStep(access_id);
2026
2027 // move on Ds
2028 static_for<0, NumDTensor, 1>{}([&](auto i) {
2029 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2030 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2031 });
2032
2033 // move on E
2034 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2035 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2036 I0,
2037 cde_lds_and_global_step);
2038 }
2039 });
2040 }
2041 }
2042
2043 template <bool HasMainKBlockLoop,
2044 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2045 TailNumber TailNum = TailNumber::Odd>
2046 __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2047 const index_t* p_sorted_expert_ids,
2048 const index_t* p_max_token_id,
2049 const ADataType* p_a_grid,
2050 const AScaleDataType* p_a_scale_grid,
2051 const BDataType* p_b_grid,
2052 const BScaleDataType* p_b_scale_grid,
2053 DsGridPointer& p_ds_grid,
2054 CDataType* p_c_grid,
2055 void* p_shared_0,
2056 void* p_shared_1,
2057 const Problem& problem,
2058 AElementwiseOperation a_element_op,
2059 BElementwiseOperation b_element_op,
2060 CElementwiseOperation c_element_op)
2061 {
2062 ignore = a_element_op;
2063 ignore = b_element_op;
2064 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
2065 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
2066 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2067 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2068 problem.MPadded,
2069 problem.K,
2070 problem.KPadded,
2071 problem.StrideA,
2072 problem.AK0);
2073 const auto b_grid_desc_bpreshuffled =
2074 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
2075 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2076 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2077 problem.MPadded,
2078 problem.N,
2079 problem.NPadded,
2080 problem.StrideC);
2081
2082 // We pad the M unconditionaly for Scale
2083 const auto Padded_Scale_M =
2084 math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
2085 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2086 make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2087 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2088 (KXdlPack * 64 / MPerXdl),
2090 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2091 (ScaleBlockSize / APackedSize)) *
2092 MPerXdl * MXdlPack / scale_pack_size_a,
2094 1));
2095
2096 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2097 make_tuple(problem.N / (NXdlPack * NPerXdl),
2098 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2099 (KXdlPack * 64 / NPerXdl),
2101 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2102 (ScaleBlockSize / BPackedSize)) *
2103 NPerXdl * NXdlPack / scale_pack_size_b,
2105 1));
2106
2107 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2109 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2110
2111 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2112 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2113 if(expert_block_id * MPerBlock >= max_token_id)
2114 return;
2115 const index_t expert_id =
2116 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2117 const auto block_mn = [&]() -> std::pair<int, int> {
2118 if constexpr(NSwizzle)
2119 {
2120 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2121 const index_t prefix_block = ecnt_prefix * problem.NBlock;
2122 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2123 const index_t expert_swizzle =
2124 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2125 const index_t bid_new = blockIdx.x - prefix_block;
2126 const index_t nid = __builtin_amdgcn_readfirstlane(
2127 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2128 const index_t mid =
2129 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2130 return {nid, mid};
2131 }
2132 else
2133 {
2134 return {blockIdx.x, blockIdx.y};
2135 }
2136 }();
2137
2138 const index_t block_n_id = block_mn.first;
2139 const index_t block_m_id = block_mn.second;
2140 const index_t token0 =
2141 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2142
2143 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2144 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2145 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2146 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2147 constexpr auto AKThreads = AK0Threads * AK1Threads;
2148 constexpr auto AMRepeats = MPerBlock / AMThreads;
2149 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
2150
2151 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2152 return;
2154 static_for<0, AMRepeats, 1>{}([&](auto m0) {
2155 const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
2156 index_t token_offset = fused_token & 0xffffff;
2157 if constexpr(!IsInputGemm)
2158 {
2159 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2160 }
2161 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2162 });
2163
2164 const index_t expert_stride =
2165 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2166 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2167 problem.N * (IsInputGemm ? 2 : 1) *
2168 math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2169
2170 // N0, K0, Blocksize*KPack
2171 const index_t n_block_data_idx_on_grid =
2172 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
2173
2174 // Gride buffer creation
2175 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2176 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2177 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2178 p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2179
2180 // A, B scale buffer
2181 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2182 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2183 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2184 p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2185 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2186
2187 // A matrix in LDS memory, dst of blockwise copy
2188 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2189
2190 // B matrix in LDS memory, dst of blockwise copy
2191 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2192
2193 // A matrix blockwise direct to LDS copy
2197 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2198 ABlockTransferThreadClusterArrangeOrder,
2199 ADataType,
2200 ADataType,
2201 decltype(a_grid_desc_ak0_m_ak1),
2202 decltype(a_block_desc_ak0_m_ak1),
2203 ABlockTransferSrcAccessOrder,
2204 ABlockTransferSrcVectorDim,
2205 2,
2206 ABlockTransferSrcScalarPerVector,
2207 IndexType,
2208 1>(a_grid_desc_ak0_m_ak1,
2209 make_multi_index(0, 0, 0),
2210 a_block_desc_ak0_m_ak1,
2211 make_multi_index(0, 0, 0),
2212 gather_offsets);
2213
2214 // Thread-wise copy
2215 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2217 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2219 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2220 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2221
2222 auto b_blockwise_copy =
2224 BDataType,
2225 decltype(b_grid_desc_bpreshuffled),
2226 decltype(b_block_desc_bk0_n_bk1),
2227 Sequence<Number<NXdlPerWave / NXdlPack>{},
2228 I1,
2233 4,
2234 BBlockTransferSrcScalarPerVector,
2235 BThreadTransferSrcResetCoordinateAfterRun,
2236 true>(
2237 b_grid_desc_bpreshuffled,
2238 make_multi_index(n_block_data_idx_on_grid,
2240 0,
2241 0,
2242 KPack * (get_thread_local_1d_id() % WarpSize)));
2243
2244 // LDS allocation for A and B: be careful of alignment
2245 // Cast after lds
2246 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2247 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2248 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2249 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2250 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2251
2252 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2253 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2254
2255 // Blockwise GEMM pipeline
2256 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2257 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2258 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2259 decltype(c_thread_buf) c_thread_buf_up;
2260
2262 float,
2263 c_thread_buf.num_of_v_,
2264 c_thread_buf.s_per_v,
2265 true>
2266 c_thread_buf_fp32;
2267
2268 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2269 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2270 KPerBlock);
2271
2272 // a and b scale processing
2273 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2274 const auto waveId_m = wave_idx[I0];
2275 const auto waveId_n = wave_idx[I1];
2276
2277 auto thread_offset_shuffled =
2278 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2279
2280 auto a_thread_offset_m = waveId_m;
2281
2282 // get each thread's offset in the scale tensor
2283 const index_t token_scale_pos = block_m_id * MPerBlock;
2284 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2285 return;
2286
2287 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2288 AScaleDataType,
2289 AScaleDataType,
2290 decltype(a_scale_grid_desc_am_ak),
2291 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2292 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2293 Sequence<0, 1, 2>, // DimAccessOrder
2294 2, // SrcVectorDim
2295 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2296 1, // SrcScalarStrideInVector
2297 true>(a_scale_grid_desc_am_ak,
2298 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2299 0,
2300 thread_offset_shuffled / scale_pack_size_a));
2301
2302 // B scale load
2303 auto b_thread_offset_n = waveId_n;
2304
2305 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2306 BScaleDataType,
2307 BScaleDataType,
2308 decltype(b_scale_grid_desc_bn_ak),
2309 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2310 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2311 Sequence<0, 1, 2>, // DimAccessOrder
2312 2, // SrcVectorDim
2313 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2314 1, // SrcScalarStrideInVector
2315 true>(b_scale_grid_desc_bn_ak,
2316 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2317 0,
2318 thread_offset_shuffled / scale_pack_size_b));
2319
2320 if constexpr(IsInputGemm)
2321 {
2322 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
2323 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2324 p_b_grid_up + expert_id * expert_stride,
2325 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2326 auto b_blockwise_copy_up =
2328 BDataType,
2329 decltype(b_grid_desc_bpreshuffled),
2330 decltype(b_block_desc_bk0_n_bk1),
2331 Sequence<Number<NXdlPerWave / NXdlPack>{},
2332 I1,
2337 4,
2338 BBlockTransferSrcScalarPerVector,
2339 BThreadTransferSrcResetCoordinateAfterRun,
2340 true>(
2341 b_grid_desc_bpreshuffled,
2342 make_multi_index(n_block_data_idx_on_grid,
2344 0,
2345 0,
2346 KPack * (get_thread_local_1d_id() % WarpSize)));
2347 const BScaleDataType* p_b_scale_grid_up =
2348 p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
2349 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2350 p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
2351 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2352
2353 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2354 BScaleDataType,
2355 BScaleDataType,
2356 decltype(b_scale_grid_desc_bn_ak),
2357 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2358 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2359 Sequence<0, 1, 2>, // DimAccessOrder
2360 2, // SrcVectorDim
2361 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2362 1, // SrcScalarStrideInVector
2363 true>(
2364 b_scale_grid_desc_bn_ak,
2365 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2366 0,
2367 thread_offset_shuffled / scale_pack_size_b));
2368
2369 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2370 // A
2371 a_grid_desc_ak0_m_ak1,
2372 a_block_desc_ak0_m_ak1,
2373 a_blockwise_copy,
2374 a_grid_buf,
2375 a_block_bufs,
2376 a_block_slice_copy_step,
2377 // Gate and Up
2378 b_grid_desc_bpreshuffled,
2379 b_block_desc_bk0_n_bk1,
2380 b_blockwise_copy,
2381 b_blockwise_copy_up,
2382 b_grid_buf,
2383 b_grid_buf_up,
2384 b_block_bufs,
2385 b_block_slice_copy_step,
2386 // C
2387 c_thread_buf,
2388 c_thread_buf_up,
2389 // A scale
2390 a_scale_grid_desc_am_ak,
2391 a_scale_thread_copy,
2392 a_scale_grid_buf,
2393 // B scale
2394 b_scale_grid_desc_bn_ak,
2395 b_scale_thread_copy,
2396 b_scale_thread_copy_up,
2397 b_scale_grid_buf,
2398 b_scale_grid_buf_up,
2399 num_k_block_main_loop);
2400 }
2401 else
2402 {
2403 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2404 a_grid_desc_ak0_m_ak1, // A
2405 a_block_desc_ak0_m_ak1,
2406 a_blockwise_copy,
2407 a_grid_buf,
2408 a_block_bufs,
2409 a_block_slice_copy_step,
2410 b_grid_desc_bpreshuffled, // B
2411 b_block_desc_bk0_n_bk1,
2412 b_blockwise_copy,
2413 b_grid_buf,
2414 b_block_bufs,
2415 b_block_slice_copy_step,
2416 c_thread_buf, // C
2417 a_scale_grid_desc_am_ak, // A scale
2418 a_scale_thread_copy,
2419 a_scale_grid_buf,
2420 b_scale_grid_desc_bn_ak, // B scale
2421 b_scale_thread_copy,
2422 b_scale_grid_buf,
2423 num_k_block_main_loop);
2424 }
2425
2426 // shuffle C and write out
2427 {
2428 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2429 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2430 "wrong!");
2431 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2432 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2433 "wrong!");
2434
2435 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2436
2437 // TODO: hacky, fix it!
2438 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2439 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2440
2441 // TODO: hacky, fix it!
2442 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2443 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2444 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2445
2446 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2447 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2448 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2449 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2450 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2451 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2452 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2453 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2454 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2455 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2456
2457 // mul scales
2458
2459 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
2460 static_assert(M5 == 4);
2461 const index_t m1 = get_warp_local_1d_id() / NWave;
2462 const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
2463
2464 vector_type<float, 4> topk_weights; // for gemm2 only
2465 static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
2466 static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
2467 static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
2468 static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
2469 static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
2470 const index_t m_pos = block_m_id * MPerBlock +
2471 m0 * M2 * M1 * M3 * M4 * M5 +
2472 m1 * M2 * M3 * M4 * M5 +
2473 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
2474 if constexpr(MulRoutedWeight)
2475 {
2476 topk_weights =
2478 p_ds_grid[I2] + m_pos);
2479 }
2480 static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
2481 constexpr index_t c_offset =
2482 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2483 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
2484 constexpr auto cidx = Number<c_offset>{};
2485
2486 if constexpr(IsInputGemm) // gu fusion
2487 {
2488 if constexpr(ActivationOperation ==
2489 Activation::silu_and_mul)
2490 {
2491 float gate = c_thread_buf[cidx];
2492 float up = c_thread_buf_up[cidx];
2493 if constexpr(MulRoutedWeight)
2494 {
2495 gate = gate * topk_weights.AsType<float>()[m5];
2496 up = up * topk_weights.AsType<float>()[m5];
2497 }
2499 c_thread_buf_fp32(cidx) = gate * up;
2500 }
2501 else if(ActivationOperation == Activation::gelu_and_mul)
2502 {
2503 float gate = c_thread_buf[cidx];
2504 float up = c_thread_buf_up[cidx];
2505 if constexpr(MulRoutedWeight)
2506 {
2507 gate = gate * topk_weights.AsType<float>()[m5];
2508 up = up * topk_weights.AsType<float>()[m5];
2509 }
2511 c_thread_buf_fp32(cidx) = gate * up;
2512 }
2513 }
2514 else
2515 {
2516 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2517 if constexpr(MulRoutedWeight)
2518 {
2519 c_thread_buf_fp32(cidx) =
2520 topk_weights.AsType<float>()[m5] *
2521 c_thread_buf_fp32[cidx];
2522 }
2523 }
2524 });
2525 });
2526 });
2527 });
2528 });
2529 });
2530
2531 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2533
2534 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2535 static_cast<CShuffleDataType*>(p_shared_0),
2536 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2537
2538 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2539 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2540 make_tuple(
2543 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2544 // shuffle
2545 M1, // M1 = MWave
2546 M2, // M2 * M3 * M4 = MPerXdl
2547 M3,
2548 M4,
2549 M5)),
2553 // per shuffle
2554 N1, // N1 = NWave
2555 N2, // N2 = NXdlPack
2556 N3))), // N3 = NPerXdl
2560 Sequence<>{},
2562
2563 // calculate origin of thread output tensor on global memory
2564 // blockwise GEMM c matrix starting index
2565 const auto c_thread_mtx_on_block =
2566 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2567
2568 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2569 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2570
2571 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2573 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2576
2577 const auto m_thread_data_on_block_idx =
2578 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2579 make_multi_index(m_thread_data_on_block));
2580
2581 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2583 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2586
2587 const auto n_thread_data_on_block_idx =
2588 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2589 make_multi_index(n_thread_data_on_block));
2590
2591 // shuffle: threadwise copy C from VGPR to LDS
2592 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2593 AccDataType,
2594 CShuffleDataType,
2595 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2596 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2598 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2599 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2600 I1,
2601 I1,
2602 M2,
2603 N2,
2604 M3,
2605 I1,
2606 M5,
2607 I1>,
2609 9,
2610 1,
2612 1,
2613 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2615 0,
2616 m_thread_data_on_block_idx[I1],
2617 n_thread_data_on_block_idx[I1],
2618 m_thread_data_on_block_idx[I2],
2619 n_thread_data_on_block_idx[I2],
2620 m_thread_data_on_block_idx[I3],
2621 m_thread_data_on_block_idx[I4],
2622 m_thread_data_on_block_idx[I5],
2623 n_thread_data_on_block_idx[I3]),
2625
2626 using EDataType = CDataType;
2627
2628 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2629 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2630
2631 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2633 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2634
2635 const auto ds_grid_buf = generate_tuple(
2636 [&](auto i) {
2638 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2639 },
2641
2642 // tuple of reference to C/Ds tensor descriptors
2643 const auto c_ds_desc_refs = concat_tuple_of_reference(
2644 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2645 generate_tie([&](auto i) -> const auto& // return type should be reference
2646 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2648
2649 // tuple of reference to C/Ds tensor buffers
2650 const auto c_ds_buf_refs = concat_tuple_of_reference(
2651 tie(c_shuffle_block_buf),
2652 generate_tie([&](auto i) -> const auto& // return type should be reference
2653 { return ds_grid_buf[i]; },
2655
2656 // tuple of starting index of C/Ds blockwise copy
2657 const auto idx_c_ds_block_begin =
2660 [&](auto) {
2661 return make_multi_index(block_m_id, 0, block_n_id, 0);
2662 // return make_multi_index(block_work_idx[I0], 0,
2663 // block_work_idx[I1], 0);
2664 },
2666
2667 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2668 c_grid_desc_mblock_mperblock_nblock_nperblock;
2669
2670 using CDEBlockTransferCluster =
2671 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2672 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2673 constexpr index_t scatter_weight_idx = 3; // hack fix felix
2674 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2676 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2678 decltype(c_ds_desc_refs),
2679 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2680 CElementwiseOperation,
2681 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2682 // Sequence support
2683 // arbitrary type
2684 Sequence<1,
2685 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2686 1,
2687 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2688 CDEBlockTransferCluster,
2689 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2690 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2691 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2692 3, // index_t SrcVectorDim,
2693 3, // index_t DstVectorDim,
2694 CDEShuffleBlockTransferScalarPerVectors,
2699 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2700 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2701 IndexType,
2702 1, // ScatterDim
2703 true, // OutputScatter: false, only use scatter weights
2704 scatter_weight_idx // ScatterWeightIdx: ascale
2705 >{c_ds_desc_refs,
2706 idx_c_ds_block_begin,
2707 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2708 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2709 c_element_op};
2710
2712 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2713
2714 constexpr auto sfc_c_vgpr =
2715 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2716 NXdlPerWave / NXdlPack,
2717 1,
2718 1,
2719 MXdlPack,
2720 NXdlPack,
2721 M2,
2722 1,
2723 M4,
2724 1>,
2726 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2727 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2728 1,
2729 1,
2730 MXdlPack,
2731 NXdlPack,
2732 M2,
2733 1,
2734 M4,
2735 1>>{};
2736
2737 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2738
2739 // space filling curve for shuffled blockwise C/D/E
2740 constexpr auto sfc_cde_block =
2743 Sequence<1,
2744 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2745 1,
2746 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2747
2748 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2749 constexpr auto EMThreads =
2750 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2751 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2752 constexpr auto ENThreads =
2753 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2754 static_for<0, num_access, 1>{}([&](auto access_id) {
2755 // make sure it's safe to write to LDS
2757
2758 auto dstidx = sfc_cde_block.GetIndex(access_id);
2759 const index_t c_token_pos =
2760 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2761 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2762 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2763 IndexType token_offset = fused_token & 0xffffff;
2764 if constexpr(IsInputGemm)
2765 {
2766 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2767 }
2768 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2769 });
2770
2772
2773 // each thread write its data from VGPR to LDS
2774 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2775 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2776 c_thread_buf_fp32,
2777 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2778 c_shuffle_block_buf);
2779
2780 // make sure it's safe to read from LDS
2782
2783 // each block copy its data from LDS to global
2784 cde_block_copy_lds_and_global.Run(
2785 c_ds_desc_refs,
2786 c_ds_buf_refs,
2787 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2788 tie(c_grid_buf),
2789 scatter_offsets);
2790
2791 if constexpr(access_id < num_access - 1)
2792 {
2793 constexpr auto cde_lds_and_global_step =
2794 sfc_cde_block.GetForwardStep(access_id);
2795
2796 // move on Ds
2797 static_for<0, NumDTensor, 1>{}([&](auto i) {
2798 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2799 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2800 });
2801
2802 // move on E
2803 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2804 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2805 I0,
2806 cde_lds_and_global_step);
2807 }
2808 });
2809 }
2810 }
2811};
2812
2813} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm_bns.hpp:48
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp:37
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm.hpp:90
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ PY c_style_pointer_cast(PX p_x)
Definition c_style_pointer_cast.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
Activation
Definition gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition gridwise_moe_gemm.hpp:32
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:748
const index_t * p_sorted_expert_ids
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:809
const index_t * p_sorted_token_ids
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:808
const BScaleDataType * p_b_scale_grid
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:814
const BElementwiseOperation b_element_op
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:819
const BDataType * p_b_grid
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:813
const index_t * p_max_token_id
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:810
DsGridPointer p_ds_grid
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:815
const AElementwiseOperation a_element_op
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:818
CDataType * p_c_grid
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:816
const AScaleDataType * p_a_scale_grid
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:812
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:749
const CElementwiseOperation c_element_op
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:820
const ADataType * p_a_grid
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:811
index_t AK0
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:740
index_t NPadded
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:737
index_t KBatch
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:735
index_t BK0
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:741
index_t TopK
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:725
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:677
index_t KRead
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:738
index_t K
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:728
index_t MPadded
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:736
index_t StrideC
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:734
index_t StrideScaleB
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:732
index_t NumTokens
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:724
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:733
index_t MBlock
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:742
index_t StrideScaleA
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:730
index_t N
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:727
__host__ void Print() const
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:712
index_t StrideA
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:729
index_t StrideB
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:731
index_t M
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:726
index_t KPadded
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:739
index_t NBlock
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:743
index_t a_scale_k_split_offset
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:866
index_t b_scale_k_split_offset
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:867
index_t a_k_split_offset
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:864
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:825
index_t b_k_split_offset
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:865
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:174
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:2046
remove_cvref_t< decltype(BlockGemmMXBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:1014
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:1281
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:1260
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:75
Definition thread_group_tensor_slice_transfer_gather_direct_load.hpp:57
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1041
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
Definition dtype_vector.hpp:10
#define CK_ENV(name)
Definition utility/env.hpp:129