gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp Source File

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16#include "ck/utility/env.hpp"
18
19namespace ck {
20
21#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
22#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
23// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
24// pipelines in the same kernel function. Blockers:
25// 1. Two separate declarations of __shared__ arrays are the key to making sure data accesses
26// operate on two distinct LDS chunks.
27// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
28// not reuse the same LDS buffer when __shared__ is declared inside blkgemmpipe.
// Split-K MX GEMM kernel entry, single-LDS-buffer flavour.
// One __shared__ workspace of GridwiseGemm::GetSharedMemoryNumberOfByte() bytes is handed
// to GridwiseGemm::Run.
// Both kernel_gemm_xdl_cshuffle_v3_mx templates share this name and the leading
// `bool Use2LDS` parameter. The lines that tell them apart (and that declare TailNum) are
// not visible here; presumably an enable_if on Use2LDS. TODO confirm.
// The body is compiled only for __gfx950__ and only when
// IsValidCompilationParameter<CGlobalMemoryDataOperation>() holds; otherwise it does nothing.
29template <bool Use2LDS,
30 typename GridwiseGemm,
31 bool HasMainKBlockLoop,
32 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
33 index_t MinimumOccupancy = 1,
36#if CK_USE_LAUNCH_BOUNDS
37__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
38#endif
39 // __attribute__((amdgpu_waves_per_eu(1, 1)))
40 kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
41{
42#if defined(__gfx950__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
44 {
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
// blockIdx.z selects this block's K batch. The constructor also narrows karg.K
// (karg is this kernel's by-value copy) to that batch's K range.
47 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
48
// Advance every global pointer to this K batch (C only moves in the reduce-add mode).
49 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
50 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
51 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
52 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
53 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
54 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
55 p_shared,
56 karg);
57 }
58#else
59 ignore = karg;
60#endif // end of if (defined(__gfx950__))
61}
62
// Split-K MX GEMM kernel entry, double-LDS-buffer flavour.
// Two separate __shared__ arrays, each of GridwiseGemm::GetSharedMemoryNumberOfByte() bytes,
// are passed to GridwiseGemm::Run_2Lds, so the LDS footprint is twice that of the
// single-buffer kernel.
// The lines distinguishing this template from the single-LDS one (and declaring TailNum)
// are not visible; presumably an enable_if on Use2LDS. TODO confirm.
63template <bool Use2LDS,
64 typename GridwiseGemm,
65 bool HasMainKBlockLoop,
66 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
67 index_t MinimumOccupancy = 1,
70#if CK_USE_LAUNCH_BOUNDS
71__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
72#endif
73 // __attribute__((amdgpu_waves_per_eu(1, 1)))
74 kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
75{
76#if defined(__gfx950__)
77 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
78 {
79 // Passing two LDS pointers is the key to telling the compiler that ds_read/write
80 // operate on different lds chunks at the same time without an order dependency
81 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
82 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
83
// blockIdx.z selects this block's K batch; karg.K is narrowed on this kernel's copy.
84 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
85
86 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
87 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
88 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
89 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
90 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
91 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
92 p_shared_0,
93 p_shared_1,
94 karg);
95 }
96#else
97 ignore = karg;
98#endif // end of if (defined(__gfx950__))
99}
100#endif
101
102template <typename ALayout,
103 typename BLayout,
104 typename CLayout,
105 typename ADataType,
106 typename AScaleDataType,
107 typename BDataType,
108 typename BScaleDataType,
109 typename AccDataType,
110 typename CShuffleDataType,
111 typename CDataType,
112 typename AElementwiseOperation,
113 typename BElementwiseOperation,
114 typename CElementwiseOperation,
116 index_t ScaleBlockSize, // Scaling block size
117 index_t BlockSize, // Thread block size
118 index_t MPerBlock,
119 index_t NPerBlock,
120 index_t KPerBlock,
121 index_t AK1Value,
122 index_t BK1Value,
123 index_t MPerXdl,
124 index_t NPerXdl,
125 index_t MXdlPerWave,
126 index_t NXdlPerWave,
127 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
128 typename ABlockTransferThreadClusterArrangeOrder,
129 typename ABlockTransferSrcAccessOrder,
130 index_t ABlockTransferSrcVectorDim,
131 index_t ABlockTransferSrcScalarPerVector,
132 index_t ABlockTransferDstScalarPerVector_AK1,
133 bool AThreadTransferSrcResetCoordinateAfterRun,
134 index_t ABlockLdsExtraM,
135 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
136 typename BBlockTransferThreadClusterArrangeOrder,
137 typename BBlockTransferSrcAccessOrder,
138 index_t BBlockTransferSrcVectorDim,
139 index_t BBlockTransferSrcScalarPerVector,
140 index_t BBlockTransferDstScalarPerVector_BK1,
141 bool BThreadTransferSrcResetCoordinateAfterRun,
142 index_t BBlockLdsExtraN,
143 index_t CShuffleMXdlPerWavePerShuffle,
144 index_t CShuffleNXdlPerWavePerShuffle,
145 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
146 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
149 typename ComputeTypeA =
150 ADataType, // XXX: These should always be the same as ADataType and BDataType
151 typename ComputeTypeB =
152 BDataType, // TODO: Hardcode them and remove from the list of template parameters
153 bool PermuteA = false,
154 bool PermuteB = false>
156{
157
158 static constexpr auto I0 = Number<0>{};
159 static constexpr auto I1 = Number<1>{};
160 static constexpr auto I2 = Number<2>{};
161 static constexpr auto I3 = Number<3>{};
162 static constexpr auto I4 = Number<4>{};
163 static constexpr auto I5 = Number<5>{};
164 static constexpr auto I6 = Number<6>{};
165 static constexpr auto I7 = Number<7>{};
166 static constexpr auto I8 = Number<8>{};
167 static constexpr auto I9 = Number<9>{};
168
169 // K1 should be Number<...>
170 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
171 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
172 static constexpr auto AK1Number = Number<AK1Value>{};
173 static constexpr auto BK1Number = Number<BK1Value>{};
174
175 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
176 static constexpr bool is_single_rate_mfma = false;
177 static constexpr auto is_scale_mfma = true;
178
179 static constexpr auto MXdlPack = 2;
180 static constexpr auto NXdlPack = 2;
181 static constexpr auto KXdlPack = 2;
182
183 // KPack is at least the k_per_blk of the selected mfma
184 //
185 // Should be a multiple of k_per_blk.
186 // TODO: Move this to blockwise pipeline base
187 // KPack is expressed in packed data-type units for packed (pk) A/B types
188
191
// KPack: k_per_blk of the MFMA chosen by MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl,
// ComputeTypeB, ..., is_scale_mfma>, divided by a quantity on a line not visible here.
// Presumably that divisor is APackedSize, per the comment above. TODO confirm.
192 static constexpr index_t KPack =
194 MfmaSelector<ComputeTypeA,
195 MPerXdl,
196 NPerXdl,
197 ComputeTypeB,
199 is_scale_mfma>::selected_mfma.k_per_blk /
201
202 static constexpr index_t NLane = NPerXdl;
203 static constexpr index_t KLane = 64 / NLane;
204 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
205 static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
206
208
210 static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
211 static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
212 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
213 "A scale pack data type too large!");
214 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
215 "B scale pack data type too large!");
216
217 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
218 {
219 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
220 }
221
222 __host__ static auto CalculateMPadded(index_t M)
223 {
224 return math::integer_least_multiple(M, MPerBlock);
225 }
226
227 __host__ static auto CalculateNPadded(index_t N)
228 {
229 return math::integer_least_multiple(N, NPerBlock);
230 }
231
232 __host__ static auto CalculateKPadded(index_t K)
233 {
234 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
235 }
236
237 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
238 {
239 auto K_t = K_Batch * KPerBlock;
240 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
241 }
242
243 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
244 {
245 auto K_t = K_Batch * KPerBlock;
246 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
247 }
248
// N0 extent of the preshuffled B layout for a problem with N columns.
// The body (line 251) is not visible here; presumably N divided (rounded up) by NLane.
// TODO confirm.
249 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
250 {
252 }
// K0 extent of the preshuffled B layout for a problem with K (packed) elements.
// The body (line 255) is not visible here; presumably K divided (rounded up) by
// KLane * KPack. TODO confirm.
253 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
254 {
256 }
257
258 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
259 {
260 auto K_t = K_Batch * KPerBlock;
261 return (K + K_t - 1) / K_t * KPerBlock;
262 }
263
264 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
265 {
266 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
267 auto K_t = K_Batch * KReadVec;
268 return (K + K_t - 1) / K_t * KReadVec;
269 }
270
271 __host__ static auto CalculateMBlock(index_t M)
272 {
273 return math::integer_divide_ceil(M, MPerBlock);
274 }
275
276 __host__ static auto CalculateNBlock(index_t N)
277 {
278 return math::integer_divide_ceil(N, NPerBlock);
279 }
280
// Builds the MMA-tile view of a (K0, MN, K1) block tile descriptor.
// The MN dimension is split using MNXdlPerWave / MNWaves / MNXdlPack / MNPerXdl
// (Number<MNPerXdl> is visible in the unmerge lengths). Several lines of both transforms are
// not visible here.
// When IsXor is true the tile is first passed through an extra transform (permuted_desc);
// presumably an XOR swizzle as the flag name suggests. TODO confirm.
281 template <index_t MNXdlPerWave,
282 index_t MNWaves,
283 index_t MNXdlPack,
284 index_t MNPerXdl,
285 bool IsXor,
286 typename TileDesc_K0_MN_K1>
287 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
288 {
// Lengths of the incoming tile descriptor: dim 0 = K0, dim 1 = MN, dim 2 = K1.
289 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
290 constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
291 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
292
293 if constexpr(IsXor)
294 {
295 constexpr auto permuted_desc = transform_tensor_descriptor(
296 TileDesc_K0_MN_K1{},
301
303 permuted_desc,
309 Number<MNPerXdl>{}))),
312 }
313 else
314 {
316 TileDesc_K0_MN_K1{},
322 Number<MNPerXdl>{}))),
325 }
326 }
327
// Global-memory descriptor for A, viewed as (AK0, M, AK1), padded according to GemmSpec.
// M, K        : problem sizes (K already in packed units; see Argument).
// MPad, KPad  : padded sizes used by the right-pad transforms.
// StrideA     : leading-dimension stride of the raw (M, K) tensor.
// AK0         : not referenced in the visible lines of this function.
328 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
329 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
330 {
// Raw (M, K) view. The first branch uses strides (StrideA, 1), i.e. K-contiguous (row-major);
// the second uses (1, StrideA), i.e. M-contiguous. The branch conditions are not visible.
331 const auto a_grid_desc_mraw_kraw = [&]() {
333 {
334 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
335 }
337 {
338 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
339 }
340 }();
341
342 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
343
344 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
345 GemmSpec == GemmSpecialization::MNKPadding)
346 {
347 // pad both M and K
348 const auto a_grid_desc_m_k =
349 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
351 make_right_pad_transform(K, KPad - K)),
354
355 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
356 a_grid_desc_m_k,
361
362 return a_grid_desc_ak0_m_ak1;
363 }
364 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
365 GemmSpec == GemmSpecialization::MNPadding)
366 {
367 // pad M, but not K
368 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
369 a_grid_desc_mraw_kraw,
371 make_right_pad_transform(M, MPad - M)),
374
375 return a_grid_desc_ak0_m_ak1;
376 }
377 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
378 GemmSpec == GemmSpecialization::NKPadding)
379 {
380 // pad K, but not M
381 const auto a_grid_desc_m_k = transform_tensor_descriptor(
382 a_grid_desc_mraw_kraw,
386
387 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
388 a_grid_desc_m_k,
393
394 return a_grid_desc_ak0_m_ak1;
395 }
396 else
397 {
398 // not pad M or K
// K is unmerged into (K / KPerBlock, AK0Number, AK1Value), so this path assumes K is a
// multiple of KPerBlock (CheckValidity rejects other K for non-K-padding specs). The
// result then goes through two further transforms whose arguments are not visible here.
399 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
400 a_grid_desc_mraw_kraw,
401 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
405
406 const auto a_grid_desc_permuted = transform_tensor_descriptor(
407 a_grid_desc_ak0_m_ak1,
413
414 const auto a_grid_desc = transform_tensor_descriptor(
415 a_grid_desc_permuted,
422
423 return a_grid_desc;
424 }
425 }
426
// Global-memory descriptor for the preshuffled B tensor.
// The visible lengths are:
//   (ceil(N0 / (NWave * NXdlPack)), NWave, NXdlPack, K0, WaveSize * KPack)
// The head of the return expression (line 432) is not visible; presumably
// make_naive_tensor_descriptor_packed. TODO confirm.
// WaveSize is derived as BlockSize / (MWave * NWave), i.e. threads per wave for this tiling.
427 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
428 {
429 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
430 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// Innermost extent: one wave's worth of KPack-wide lanes laid out contiguously.
431 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
433 math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber));
434 }
435
// Global-memory descriptor for B, viewed as (BK0, N, BK1), padded according to GemmSpec.
// K, N        : problem sizes (K in packed units).
// KPad, NPad  : padded sizes used by the right-pad transforms.
// StrideB     : leading-dimension stride of the raw (N, K) tensor.
// BK0         : not referenced in the visible lines (the PermuteB path derives its own BK0_).
436 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
437 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
438 {
// Raw (N, K) view. The first branch uses strides (1, StrideB), i.e. N-contiguous; the second
// uses (StrideB, 1), i.e. K-contiguous. The branch conditions are not visible.
439 const auto b_grid_desc_nraw_kraw = [&]() {
441 {
442 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
443 }
445 {
446 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
447 }
448 }();
449
450 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
451
// Packed 4-bit B types are rejected for any padding specialization (the first half of each
// static_assert condition is on a line not visible here).
453 GemmSpec != GemmSpecialization::Default),
454 "pk_i4_t does not support padding");
456 GemmSpec != GemmSpecialization::Default),
457 "f4x2_pk_t does not support padding");
458
459 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
460 GemmSpec == GemmSpecialization::MNKPadding)
461 {
462 // pad both N and K
463 const auto b_grid_desc_n_k =
464 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
466 make_right_pad_transform(K, KPad - K)),
469
470 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
471 b_grid_desc_n_k,
476
477 return b_grid_desc_bk0_n_bk1;
478 }
479 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
480 GemmSpec == GemmSpecialization::MNPadding)
481 {
482 // pad N, but not K
483 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
484 b_grid_desc_nraw_kraw,
486 make_right_pad_transform(N, NPad - N)),
489
490 return b_grid_desc_bk0_n_bk1;
491 }
492 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
493 GemmSpec == GemmSpecialization::MKPadding)
494 {
495 // pad K, but not N
496 const auto b_grid_desc_n_k = transform_tensor_descriptor(
497 b_grid_desc_nraw_kraw,
501
502 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
503 b_grid_desc_n_k,
508
509 return b_grid_desc_bk0_n_bk1;
510 }
511 else
512 {
513 if constexpr(!PermuteB)
514 {
515 // not pad N or K
// K is unmerged into (K / KPerBlock, BK0Number, BK1Value): assumes K % KPerBlock == 0.
516 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
517 b_grid_desc_nraw_kraw,
519 make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
523
524 const auto b_grid_desc_permuted = transform_tensor_descriptor(
525 b_grid_desc_bk0_n_bk1,
531
532 const auto b_grid_desc = transform_tensor_descriptor(
533 b_grid_desc_permuted,
540
541 return b_grid_desc;
542 }
543 else
544 {
545 // Weight Tile Permute
// B is stored as a packed (BK00, N, BK01, BK1) tensor. BK00 is derived from StrideB
// rather than the BK0 argument (see the commented-out line).
546 constexpr index_t BK01 = KPerBlock / BK1Value;
547 // const index_t BK00 = BK0 / BK01;
548 const index_t BK0_ = StrideB / BK1Value;
549 const index_t BK00 = BK0_ / BK01;
550
551 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
552 make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
553
554 const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
555 b_grid_desc_bk00_n_bk01_bk1_permute,
561
562 return b_grid_desc_bk0_n_bk1_permute;
563 }
564 }
565 }
566
// MMA-tile view of A's LDS block descriptor.
// MWaves = number of waves along M for this tiling. The call consuming it (line 573) is
// not visible; presumably MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, ...>.
// TODO confirm.
567 template <typename ABlockDesc_AK0_M_AK1>
568 __host__ __device__ static constexpr auto
569 MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
570 {
571 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
572
574 ABlockDesc_AK0_M_AK1{});
575 }
576
// MMA-tile view of B's block descriptor.
// NWaves recomputes the same value as the class-level NWave (NPerBlock / NPerXdl / NXdlPerWave).
// The consuming call (line 583) is not visible; presumably MakeGemmMmaTileDescriptor.
// TODO confirm.
577 template <typename BBlockDesc_BK0_N_BK1>
578 __host__ __device__ static constexpr auto
579 MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
580 {
581 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
582
584 BBlockDesc_BK0_N_BK1{});
585 }
586
// Global-memory descriptor for C over (M, N); the name/parameter line (588) is not visible.
// The live path always right-pads both M and N, regardless of GemmSpec. The GemmSpec-dependent
// variant below is disabled with #if 0 and never compiled.
587 __host__ __device__ static auto
589 {
// Raw (M, N) view. The first branch uses strides (StrideC, 1), i.e. N-contiguous; the second
// uses (1, StrideC), i.e. M-contiguous. The branch conditions are not visible.
590 const auto c_grid_desc_mraw_nraw = [&]() {
592 {
593 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
594 }
596 {
597 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
598 }
599 }();
600
601 // pad M and N
602 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
604 make_right_pad_transform(N, NPad - N)),
607#if 0
608 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
609
610 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
611 GemmSpec == GemmSpecialization::MNKPadding)
612 {
613 // pad M and N
614 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
616 make_right_pad_transform(N, NPad - N)),
619 }
620 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
621 GemmSpec == GemmSpecialization::MKPadding)
622 {
623 // pad M, but not N
625 c_grid_desc_mraw_nraw,
629 }
630 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
631 GemmSpec == GemmSpecialization::NKPadding)
632 {
633 // pad N, but not M
635 c_grid_desc_mraw_nraw,
639 }
640 else
641 {
642 // not pad M or N
643 return c_grid_desc_mraw_nraw;
644 }
645#endif
646 }
647
// Host-constructed description of one GEMM problem: sizes, strides (including the A/B scale
// strides), the split-K batch count, and quantities derived from them.
// MPadded / NPadded / MBlock / NBlock are printed by Print(), but their initializers and all
// member declarations are on lines not visible here.
648 struct Problem
649 {
650 __host__ Problem(index_t M_,
651 index_t N_,
652 index_t K_,
653 index_t StrideA_,
654 index_t StrideScaleA_,
655 index_t StrideB_,
656 index_t StrideScaleB_,
657 index_t StrideC_,
658 index_t KBatch_)
659 : M{M_},
660 N{N_},
661 K{K_},
662 StrideA{StrideA_},
663 StrideScaleA{StrideScaleA_},
664 StrideB{StrideB_},
665 StrideScaleB{StrideScaleB_},
666 StrideC{StrideC_},
667 KBatch{KBatch_},
// Per-batch K quantities: KRead = K read by every batch but the last, KPadded = padded
// per-batch K, AK0/BK0 = per-batch K0 extents of A and B.
670 KRead{CalculateKRead(K_, KBatch_)},
671 KPadded{CalculateKPadded(K_, KBatch_)},
672 AK0{CalculateAK0Padded(K_, KBatch_)},
673 BK0{CalculateBK0Padded(K_, KBatch_)},
676 {
677 }
678
// Dumps every field to std::cout on one line (debug aid).
679 __host__ void Print() const
680 {
681 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
682 << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
683 << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
684 << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
685 << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
686 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
687 << ", " << "NBlock: " << NBlock << "}" << std::endl;
688 }
689
707 };
708
709 // Argument: Problem plus device pointers, element-wise ops and the split-K reduce flag.
// K_, StrideA_ and StrideB_ are given in elements and converted to packed-storage units by
// dividing by APackedSize / BPackedSize before being stored in Problem. Scale and C strides are
// stored unchanged.
// NOTE(review): the shared K is converted with APackedSize only, which assumes
// APackedSize == BPackedSize. TODO confirm.
// The is_reduce member is declared on a line not visible here.
711 {
712 __host__ Argument(const ADataType* p_a_grid_,
713 const AScaleDataType* p_a_scale_grid_,
714 const BDataType* p_b_grid_,
715 const BScaleDataType* p_b_scale_grid_,
716 CDataType* p_c_grid_,
717 index_t M_,
718 index_t N_,
719 index_t K_,
720 index_t StrideA_,
721 index_t StrideScaleA_,
722 index_t StrideB_,
723 index_t StrideScaleB_,
724 index_t StrideC_,
725 index_t k_batch_,
726 AElementwiseOperation a_element_op_,
727 BElementwiseOperation b_element_op_,
728 CElementwiseOperation c_element_op_,
729 bool is_reduce_ = false)
730 : Problem{M_,
731 N_,
732 K_ / APackedSize,
733 StrideA_ / APackedSize,
734 StrideScaleA_,
735 StrideB_ / BPackedSize,
736 StrideScaleB_,
737 StrideC_,
738 k_batch_},
739 p_a_grid{p_a_grid_},
740 p_a_scale_grid{p_a_scale_grid_},
741 p_b_grid{p_b_grid_},
742 p_b_scale_grid{p_b_scale_grid_},
743 p_c_grid{p_c_grid_},
744 a_element_op{a_element_op_},
745 b_element_op{b_element_op_},
746 c_element_op{c_element_op_},
747 is_reduce(is_reduce_)
748 {
749 }
750
// True when split-K is active and each batch writes its own C slice
// (SplitKBatchOffset then offsets C by k_id * M * N).
751 __host__ __device__ inline bool IsReduceAdd() const
752 {
753 return (Problem::KBatch > 1) && is_reduce;
754 }
755
// True when split-K is active without the reduce mode; presumably all batches then
// accumulate into the same C (C offset stays 0). TODO confirm the atomic path.
756 __host__ __device__ inline bool IsAtomicAdd() const
757 {
758 return (Problem::KBatch > 1) && (!is_reduce);
759 }
760
761 const ADataType* p_a_grid;
762 const AScaleDataType* p_a_scale_grid;
763 const BDataType* p_b_grid;
764 const BScaleDataType* p_b_scale_grid;
765 CDataType* p_c_grid;
766
767 const AElementwiseOperation a_element_op;
768 const BElementwiseOperation b_element_op;
769 const CElementwiseOperation c_element_op;
771 };
772
// Per-block split-K offsets, computed on device from the batch index k_id.
// Side effect: the constructor rewrites karg.K to the K length of this batch. Every batch but
// the last gets KRead; the last gets the remainder. The kernels pass their by-value copy of karg.
// The struct header and the a_k/b_k/c_reduce offset member declarations are on lines not shown.
774 {
775
776 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
777 {
// A offset. The first branch advances by KRead elements (K-contiguous A); the second by
// KRead rows of StrideA. The layout conditions are not visible here.
779 {
780 a_k_split_offset = k_id * karg.KRead;
781 }
783 {
784 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
785 }
786
// B offset. The else branch covers the preshuffled layout: KRead * NPerXdl per batch, or
// KRead * N when PermuteB. The layout conditions are not visible here.
788 {
789 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
790 }
792 {
793 if constexpr(!PermuteB)
794 {
795 b_k_split_offset = k_id * karg.KRead * NPerXdl;
796 }
797 else
798 {
799 const int k0_offset = karg.KRead * karg.N;
800 b_k_split_offset = k_id * k0_offset;
801 }
802 }
803
804 // Calculate A scale offset
// KRead / (ScaleBlockSize / APackedSize) = scale blocks per batch (KRead is in packed units);
// scaled by the MXdlPack * MPerXdl rows and divided by the scales packed per element.
805 a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
806 MPerXdl / scale_pack_size_a;
807
808 // Calculate B scale offset
809 b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
810 NPerXdl / scale_pack_size_b;
811
812 if(k_id < (karg.KBatch - 1))
813 {
814 karg.K = karg.KRead;
815 }
816 else
817 {
818 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
819 }
820
// NOTE(review): k_id * M * N is evaluated in index_t; for large M*N*KBatch this may
// overflow 32 bits. TODO confirm index_t width and problem-size limits.
821 if(karg.IsReduceAdd())
822 {
823 c_reduce_offset = k_id * karg.M * karg.N;
824 }
825 else
826 {
827 c_reduce_offset = 0;
828 }
829 }
830
833 index_t a_scale_k_split_offset; // K-split offset into the A scale tensor
834 index_t b_scale_k_split_offset; // K-split offset into the B scale tensor
836 };
837
// LDS descriptor for one A block tile, viewed as (AK0, MPerBlock, AK1).
// There are three visible variants:
//  - a contiguous layout when ABlockLdsExtraM is set or the pipeline is v4;
//  - a permuted layout in the branch whose condition (line 853) is not visible;
//  - the "ColumnMajor A" branch, which folds K/M (kfold, mpair) before permuting.
// Several transform arguments are on lines not visible here.
838 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
839 {
840 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
841 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
842
843 // A matrix in LDS memory, dst of blockwise copy
844 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
845 {
846 // contiguous in LDS
850 }
851 // The xor tensor transformation requires extra VGPRs, which would cause register spills
852 // in some cases.
854 {
855 constexpr auto a_lds_block_desc =
858
859 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
860 a_lds_block_desc,
866
867 return a_lds_block_desc_permuted;
868 }
869 else // ColumnMajor A
870 {
871 // The kfold and mpair dimensions are not always required.
872 // More dimensions in merge_transform make it harder for the compiler to generate
873 // immarg offsets.
874 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
875 constexpr auto M1 = MPerBlock / M0;
876
877 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
878 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
879 constexpr auto KThreadRead = WaveSize / MPerXdl;
880 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
881
// kfold / mpair size groups against a 128-unit budget of AK1 * M0 (or MPerXdl) * sizeof(ADataType),
// presumably a 128-byte LDS row. TODO confirm.
882 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
883 ? 1
884 : 128 / (AK1Number * M0 * sizeof(ADataType));
885 constexpr auto KThreadReadPerm =
886 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
887 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
888 : KThreadRead;
889
890 // 1 <= mpair <= M0 (clamped to M0 below)
891 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
892 ? 1
893 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
894 ? M0
895 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
896
897 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
901 Number<kfold * M0 / mpair>{},
903 AK1Number));
904
905 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
906 a_lds_block_desc,
911 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
918
919 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
920 a_lds_block_desc_permuted,
929 Sequence<1>{},
930 Sequence<2>{},
931 Sequence<3>{},
932 Sequence<4>{},
933 Sequence<5>{}),
935 Sequence<2>{},
938 Sequence<6>{},
939 Sequence<7>{}));
940
941 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
942 a_lds_block_desc_unmerged,
945 Number<KThreadWrite / kfold / KThreadReadPerm>{},
953
954 return a_lds_block_desc_ak0_m_ak1;
955 }
956 }
957
// Block-level descriptor for the preshuffled B tile. The dimension order is given by the comment
// inside; the descriptor construction itself (lines 961-965) is mostly not visible here.
958 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
959 {
960 // K0 -> N0/NWave/NXdlPack -> NWave -> NXdlPack -> KLane -> NLane -> KPack
962 I1,
966 }
967
// LDS descriptor of the C-shuffle staging tile, shaped (MBlock, MPerBlock, NBlock, NPerBlock).
// The function signature (line 968) is not visible; presumably
// GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock. TODO confirm.
// NWave is commented out below because a class-level NWave with the same value already exists.
969 {
970 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
971 // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
972
973 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
977 I1,
979
980 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
981 }
982
// Template arguments of the block-GEMM pipeline selector. The `using ... = ` head (lines 983-984)
// and some arguments (995-1000) are not visible. The result is presumably the BlockwiseGemmPipe
// type queried in CheckValidity / CalculateHasMainKBlockLoop. TODO confirm.
985 BlkGemmPipelineVer,
986 BlkGemmPipeSched,
987 BlockSize,
988 ScaleBlockSize,
989 ADataType,
990 AScaleDataType,
991 BDataType,
992 BScaleDataType,
993 ComputeTypeA,
994 AccDataType,
1001 ABlockTransferSrcScalarPerVector,
1002 BBlockTransferSrcScalarPerVector,
1003 MPerBlock,
1004 NPerBlock,
1005 KPerBlock,
1006 MPerXdl,
1007 NPerXdl,
1008 MXdlPerWave,
1009 NXdlPerWave,
1010 KPack>())>;
1011
// Bytes of LDS one workspace needs.
// Returns max(aligned A tile bytes, C-shuffle tile bytes): the A tile and the C-shuffle tile are
// sized as alternatives, not summed. No B term appears here. The two-LDS kernel allocates two
// workspaces of this size.
1012 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1013 {
1014 // LDS allocation for A only (no B term is sized here): be careful of alignment
1015 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1016
1017 // lds max alignment
1018 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1019
1020 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1021 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1022
1023 // LDS allocation for C shuffle in LDS
1024 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1026
1027 constexpr auto c_block_size =
1028 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1029
// NOTE(review): element counts are multiplied by sizeof of the (possibly packed) data type;
// assumes the A descriptor counts packed elements. TODO confirm for f4x2_pk_t.
1030 return math::max(a_block_space_size_aligned * sizeof(ADataType),
1031 c_block_size * sizeof(CShuffleDataType));
1032 }
1033
1035
1036 // Host-side check that a runtime Argument is supported by this compile-time configuration.
// Returns false, optionally logging when CK_LOGGING is enabled, if:
//  - M/N/K do not tile evenly for the chosen GemmSpec;
//  - A/B/C vector widths do not divide the relevant dimension;
//  - split-K is requested without the reduce mode in the branch whose condition is not visible;
//  - the per-batch K loop is too short for the pipeline's prefetch depth.
// karg.K is already in packed units here (see Argument).
1037 __host__ static constexpr bool CheckValidity(const Argument& karg)
1038 {
1039 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1040 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1041 "Invalid tuning param!");
1042
// Note: the check is against ScaleBlockSize / BPackedSize (packed units), despite the message.
1043 static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1044 "KPerBlock should be multiple of ScaleBlockSize");
1045
1046 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1051 {
1052 if(!(karg.M % MPerBlock == 0))
1053 {
1054 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1055 {
1056 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1057 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1058 << std::endl;
1059 }
1060 return false;
1061 }
1062 }
1063
1064 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1069 {
1070 if(!(karg.N % NPerBlock == 0))
1071 {
1072 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1073 {
1074 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1075 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1076 << std::endl;
1077 }
1078 return false;
1079 }
1080 }
1081
// Without K padding, K must split into KBatch whole KPerBlock tiles.
1082 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1086 {
1087 auto K_t = karg.KBatch * KPerBlock;
1088 if(!(karg.K % K_t == 0))
1089 {
1090 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1091 {
1092 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1093 << karg.K << " " << __FILE__ << ":" << __LINE__
1094 << ", in function: " << __func__ << std::endl;
1095 }
1096 return false;
1097 }
1098 }
1099 else
1100 {
// With K padding, reject splits where the first KBatch-1 batches already consume all of K,
// which would leave the last batch empty (this rejection is not logged).
1101 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1102 auto K_t = karg.KBatch * KReadVec;
1103 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1104 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1105 {
1106 return false;
1107 }
1108 }
1109
// A vector-load width must divide K or M; the layout condition (line 1110) is not visible.
1111 {
1112 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1113 {
1114 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1115 {
1116 std::cout << "Arg K (" << karg.K
1117 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1118 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1119 << __LINE__ << ", in function: " << __func__ << std::endl;
1120 }
1121 return false;
1122 }
1123 }
1124 else
1125 {
1126 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1127 {
1128 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1129 {
1130 std::cout << "Arg M (" << karg.M
1131 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1132 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1133 << __LINE__ << ", in function: " << __func__ << std::endl;
1134 }
1135 return false;
1136 }
1137 }
1138
// B vector-load width must divide N or K; the layout condition (line 1139) is not visible.
1140 {
1141 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1142 {
1143 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1144 {
1145 std::cout << "Arg N (" << karg.N
1146 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1147 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1148 << __LINE__ << ", in function: " << __func__ << std::endl;
1149 }
1150 return false;
1151 }
1152 }
1153 else
1154 {
1155 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1156 {
1157 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1158 {
1159 std::cout << "Arg K (" << karg.K
1160 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1161 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1162 << __LINE__ << ", in function: " << __func__ << std::endl;
1163 }
1164 return false;
1165 }
1166 }
1167
// C store vector width must divide N or M; the layout condition (line 1168) is not visible.
1169 {
1170 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1171 {
1172 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1173 {
1174 std::cout << "Arg N (" << karg.N
1175 << ") value is not a multiple of "
1176 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1177 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1178 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1179 << std::endl;
1180 }
1181 return false;
1182 }
1183 }
1184 else
1185 {
1186 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1187 {
1188 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189 {
1190 std::cout << "Arg M (" << karg.M
1191 << ") value is not a multiple of "
1192 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1193 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1194 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1195 << std::endl;
1196 }
1197 return false;
1198 }
1199 }
1200
// Split-K without reduce mode is rejected in this branch (its condition, lines 1201-1204,
// is not visible).
// NOTE(review): the "KBatch > 1 is not support yet" message is logged whenever
// !IsReduceAdd(), even when KBatch == 1 and the argument is accepted; the log should
// move inside the KBatch > 1 check.
1205 {
1206 if(!karg.IsReduceAdd())
1207 {
1208 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1209 {
1210 std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
1211 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1212 }
1213 if(karg.KBatch > 1)
1214 {
1215 return false;
1216 }
1217 }
1218 }
1219 // check gridwise gemm pipeline
// Number of KPerBlock tiles per batch; non-v1 pipelines need more than PrefetchStages of them.
1220 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1221
1222 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1223 {
1224 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1225 {
1226 return false;
1227 }
1228 }
1229 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1230 return true;
1231 }
1232
1233 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1234 {
1235 const index_t num_loop = K / KPerBlock;
1236
1237 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1238 }
1239
1240 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1241 {
1242 const index_t num_loop = K / KPerBlock;
1243
1244 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1245 }
1246
// Re-views an (M, N) C grid descriptor as (MBlock, MPerBlock, NBlock, NPerBlock).
// The transform arguments (lines 1253-1256) are not visible; presumably unmerges of M into
// (MBlock, MPerBlock) and N into (NBlock, NPerBlock). TODO confirm.
1247 template <typename CGridDesc>
1248 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1249 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1250 {
1251 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1252 c_grid_desc_m_n,
1257
1258 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1259 }
1260
1261 // return block_id to C matrix tile idx (m0, n0) mapping
1262 // if arch = gfx942
1264 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1265
// Device-side body of the single-LDS-buffer MX GEMM (the caller supplies one shared buffer,
// p_shared). Steps visible below:
//   1. Wrap the A/B/C and A/B-scale grid pointers in global dynamic buffers.
//   2. Map get_block_1d_id() to a C tile via Block2CTileMap; return early if the tile is invalid.
//   3. Build the A/B blockwise copies into LDS plus per-thread A/B scale copies.
//   4. Run BlockwiseGemmPipe over num_k_block_main_loop K blocks.
//   5. Shuffle the C accumulators through LDS (reusing p_shared) and write them to global
//      memory with CGlobalMemoryDataOperation.
// NOTE(review): this is a Doxygen-rendered listing. Several hyperlinked source lines are absent,
// e.g. the c_grid_buf declaration (1292) and what are presumably the LDS barrier calls
// (1678, 1688).
1266 template <typename AGridDesc_AK0_M_K1,
1267 typename AScaleGridDesc_AM_AK,
1268 typename BGridDesc_BK0_N_K1,
1269 typename BScaleGridDesc_BN_AK,
1270 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1271 bool HasMainKBlockLoop,
1272 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1273 TailNumber TailNum = TailNumber::Odd>
1274 __device__ static void Run(const ADataType* p_a_grid,
1275 const AScaleDataType* p_a_scale_grid,
1276 const BDataType* p_b_grid,
1277 const BScaleDataType* p_b_scale_grid,
1278 CDataType* p_c_grid,
1279 void* p_shared,
1280 const Problem& problem,
1281 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1282 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1283 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1284 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1285 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1286 c_grid_desc_mblock_mperblock_nblock_nperblock)
1287 {
1288 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1289 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1290 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1291 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1293 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1294
1295 // A Scale buffer
1296 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1297 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1298
1299 // B Scale buffer
1300 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1301 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1302
1303 const AElementwiseOperation a_element_op{};
1304 const BElementwiseOperation b_element_op{};
1305 const CElementwiseOperation c_element_op{};
1306
1307 // divide block work by [M, N]
1308 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1309
1310 const auto block_work_idx =
1311 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1312
1313 if(!block_2_ctile_map.ValidCTileIndex(
1314 block_work_idx,
1315 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1316 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1317 {
1318 return;
1319 }
1320
1321 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1322 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1323
1324 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1325 const index_t m_block_data_idx_on_grid =
1326 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1327
// NOTE(review): this N origin is scaled by NXdlPerWave only. Run_2Lds uses
// block_n_id * NXdlPerWave / NXdlPack, and the B-scale origin below uses
// block_n_id * NPerBlock / NPerXdl / NXdlPack. The A origin above uses block_m_id * MPerBlock.
// Confirm which unit the N dimension of b_grid_desc_bk0_n_bk1 expects here.
1328 const index_t n_block_data_idx_on_grid =
1329 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1330
1331 // lds max alignment
1332 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1333
1334 // A matrix in LDS memory, dst of blockwise copy
1335 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1336
1337 // B matrix in LDS memory, dst of blockwise copy
1338 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1339
1340 auto a_blockwise_copy =
1343 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1344 ABlockTransferThreadClusterArrangeOrder,
1345 ADataType,
1346 ADataType,
1347 decltype(a_grid_desc_ak0_m_ak1),
1348 decltype(a_block_desc_ak0_m_ak1),
1349 ABlockTransferSrcAccessOrder,
1350 ABlockTransferSrcVectorDim,
1351 2,
1352 ABlockTransferSrcScalarPerVector>(
1353 a_grid_desc_ak0_m_ak1,
1354 make_multi_index(0, m_block_data_idx_on_grid, 0),
1355 a_block_desc_ak0_m_ak1,
1356 make_multi_index(0, 0, 0));
1357
1358 // B matrix blockwise copy
1359 auto b_blockwise_copy =
1362 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1363 BBlockTransferThreadClusterArrangeOrder,
1364 BDataType,
1365 BDataType,
1366 decltype(b_grid_desc_bk0_n_bk1),
1367 decltype(b_block_desc_bk0_n_bk1),
1368 BBlockTransferSrcAccessOrder,
1369 BBlockTransferSrcVectorDim,
1370 2,
1371 BBlockTransferSrcScalarPerVector>(
1372 b_grid_desc_bk0_n_bk1,
1373 make_multi_index(0, n_block_data_idx_on_grid, 0),
1374 b_block_desc_bk0_n_bk1,
1375 make_multi_index(0, 0, 0));
1376
1377 // LDS allocation for A and B: be careful of alignment
// A is placed at the start of p_shared. B begins a_block_space_size_aligned A-elements later,
// i.e. after A's footprint rounded up to a multiple of max_lds_align.
1378 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1379 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1380
1381 // Cast after lds
1383 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1384
1386 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1387 a_block_space_size_aligned * sizeof(ADataType)),
1388 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1389
1390 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1391 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1392
1393 // Blockwise GEMM pipeline
1394 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1395 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1396 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1397
// Number of KPerBlock iterations: length(I0) * length(I2) of the A descriptor (per its name,
// AK0 * AK1, i.e. the K extent) divided by KPerBlock, truncating. Forced uniform via readfirstlane.
1398 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1399 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1400 KPerBlock);
1401
1402 // Initial thread mapping for:
1403 // BlockSize = 256
1404 // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
1405 // For each [m0, n0] tile, there are 4 waves:
1406 // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
1407 // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
1408 // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
1409 // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
1410
1411 // BlockSize = 128
1412 // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
1413 // For each [m0, n0] tile, there are 2 waves:
1414 // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
1415 // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
1416
1417 // TODO: Document initial thread mapping for more combinations of parameters
1418
1419 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1420 const auto waveId_m = wave_idx[I0];
1421 const auto waveId_n = wave_idx[I1];
1422
1423 // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1424
1425 // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
1426 // mfma.selected_mfma.num_threads_per_blk;
1427
// Lane index within the wave times KXdlPack * MXdlPack. The same offset is reused for the
// B scale copy below.
1428 // A wave access continuous memory
1429 auto thread_offset_shuffled =
1430 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1431
1432 auto a_thread_offset_m = waveId_m;
1433
1434 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1435 AScaleDataType,
1436 AScaleDataType,
1437 decltype(a_scale_grid_desc_am_ak),
1438 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1439 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1440 Sequence<0, 1, 2>, // DimAccessOrder
1441 2, // SrcVectorDim
1442 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1443 1, // SrcScalarStrideInVector
1444 true>(a_scale_grid_desc_am_ak,
1445 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1446 0,
1447 thread_offset_shuffled / scale_pack_size_a));
1448
1449 auto b_thread_offset_n = waveId_n;
1450
// NOTE(review): SliceLengths uses KXdlPack * NXdlPack, but SrcScalarPerVector uses
// KXdlPack * MXdlPack, and thread_offset_shuffled is scaled by MXdlPack. These agree only when
// MXdlPack == NXdlPack; confirm that invariant (or static_assert it).
1451 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1452 BScaleDataType,
1453 BScaleDataType,
1454 decltype(b_scale_grid_desc_bn_ak),
1455 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1456 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1457 Sequence<0, 1, 2>, // DimAccessOrder
1458 2, // SrcVectorDim
1459 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1460 1, // SrcScalarStrideInVector
1461 true>(b_scale_grid_desc_bn_ak,
1462 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1463 0,
1464 thread_offset_shuffled / scale_pack_size_b));
1465
1466 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1467 a_block_desc_ak0_m_ak1,
1468 a_blockwise_copy,
1469 a_grid_buf,
1470 a_block_buf,
1471 a_block_slice_copy_step,
1472 b_grid_desc_bk0_n_bk1,
1473 b_block_desc_bk0_n_bk1,
1474 b_blockwise_copy,
1475 b_grid_buf,
1476 b_block_buf,
1477 b_block_slice_copy_step,
1478 c_thread_buf,
1479 a_scale_grid_desc_am_ak,
1480 a_scale_thread_copy,
1481 a_scale_grid_buf,
1482 b_scale_grid_desc_bn_ak,
1483 b_scale_thread_copy,
1484 b_scale_grid_buf,
1485 num_k_block_main_loop);
1486
1487 // shuffle C and write out
1488 {
1489 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1490 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1491 "wrong!");
1492 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1493 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1494 "wrong!");
1495
1496 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1497
1498 // TODO: hacky, fix it!
1499 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1500 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1501
1502 // TODO: hacky, fix it!
1503 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1504 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1505 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1506
1507 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1508 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1509 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1510 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1511 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1512 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1513 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1514 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1515 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1516 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1517
1518 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1520
// The C shuffle tile reuses p_shared, the same LDS that held the A/B tiles during the pipeline.
1521 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1522 static_cast<CShuffleDataType*>(p_shared),
1523 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1524
1525 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1526 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1527 make_tuple(
1530 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
1531 // shuffle
1532 M1, // M1 = MWave
1533 M2, // M2 = MXdlPack
1534 M3, // M3 * M4 * M5 = MPerXdl
1535 M4,
1536 M5)),
1539 Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave) per
1540 // shuffle
1541 N1, // N1 = NWave
1542 N2, // N2 = NXdlPack
1543 N3))), // N3 = NPerXdl
1547 Sequence<>{},
1549
1550 // calculate origin of thread output tensor on global memory
1551 // blockwise GEMM c matrix starting index
1552 const auto c_thread_mtx_on_block =
1553 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1554
1555 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1556 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1557
1558 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1560 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1563
1564 const auto m_thread_data_on_block_idx =
1565 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1566 make_multi_index(m_thread_data_on_block));
1567
1568 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1570 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1573
1574 const auto n_thread_data_on_block_idx =
1575 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1576 make_multi_index(n_thread_data_on_block));
1577
1578 // shuffle: threadwise copy C from VGPR to LDS
1579 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1580 AccDataType,
1581 CShuffleDataType,
1582 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1583 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1585 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1586 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1587 I1,
1588 I1,
1589 M2,
1590 N2,
1591 M3,
1592 I1,
1593 M5,
1594 I1>,
1596 9,
1597 1,
1599 1,
1600 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1602 0,
1603 m_thread_data_on_block_idx[I1],
1604 n_thread_data_on_block_idx[I1],
1605 m_thread_data_on_block_idx[I2],
1606 n_thread_data_on_block_idx[I2],
1607 m_thread_data_on_block_idx[I3],
1608 m_thread_data_on_block_idx[I4],
1609 m_thread_data_on_block_idx[I5],
1610 n_thread_data_on_block_idx[I3]),
1612
// NOTE(review): NWave is not declared in this function (only MWave is, above); presumably it
// is a class-scope constant.
1613 // shuffle: blockwise copy C from LDS to global
1614 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1615 ThisThreadBlock, // ThreadGroup
1616 CElementwiseOperation, // ElementwiseOperation,
1617 CGlobalMemoryDataOperation, // DstInMemOp,
1618 Sequence<1,
1619 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1620 1,
1621 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1622 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1623 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1624 CShuffleDataType, // typename SrcData,
1625 CDataType, // typename DstData,
1626 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1627 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1628 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1629 3, // index_t VectorDim,
1630 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1631 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1632 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1633 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1634 make_multi_index(0, 0, 0, 0),
1635 c_grid_desc_mblock_mperblock_nblock_nperblock,
1636 make_multi_index(block_m_id, 0, block_n_id, 0),
1637 c_element_op};
1638
1639 // space filling curve for threadwise C in VGPR
1640 constexpr auto sfc_c_vgpr =
1641 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
1642 NXdlPerWave / NXdlPack,
1643 1,
1644 1,
1645 MXdlPack,
1646 NXdlPack,
1647 M2,
1648 1,
1649 M4,
1650 1>,
1652 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1653 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1654 1,
1655 1,
1656 MXdlPack,
1657 NXdlPack,
1658 M2,
1659 1,
1660 M4,
1661 1>>{};
1662
1663 // space filling curve for shuffled blockwise C in global mem
1664 constexpr auto sfc_c_global =
1667 Sequence<1,
1668 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1669 1,
1670 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1671
1672 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1673
1674 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1675
// One iteration per shuffle slice:
//   1. copy this thread's accumulators VGPR -> LDS;
//   2. copy the block's slice LDS -> global;
//   3. advance the global destination window (skipped after the last access).
1676 static_for<0, num_access, 1>{}([&](auto access_id) {
1677 // make sure it's safe to write to LDS
1679
1680 // each thread write its data from VGPR to LDS
1681 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1682 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1683 c_thread_buf,
1684 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1685 c_shuffle_block_buf);
1686
1687 // make sure it's safe to read from LDS
1689
1690 // each block copy its data from LDS to global
1691 c_shuffle_block_copy_lds_to_global.Run(
1692 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1693 c_shuffle_block_buf,
1694 c_grid_desc_mblock_mperblock_nblock_nperblock,
1695 c_grid_buf);
1696
1697 if constexpr(access_id < num_access - 1)
1698 {
1699 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1700
1701 // move on C
1702 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1703 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1704 }
1705 });
1706 }
1707 }
1708
// Convenience overload: builds the A/B/C grid descriptors and the packed A/B scale descriptors
// from `problem`, then forwards everything to the descriptor-taking Run overload with the
// same HasMainKBlockLoop / CGlobalMemoryDataOperation / TailNum.
// NOTE(review): this Doxygen listing omits source lines 1727, 1736 and 1742.
1709 template <bool HasMainKBlockLoop,
1710 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1711 TailNumber TailNum = TailNumber::Odd>
1712 __device__ static void Run(const ADataType* p_a_grid,
1713 const AScaleDataType* p_a_scale_grid,
1714 const BDataType* p_b_grid,
1715 const BScaleDataType* p_b_scale_grid,
1716 CDataType* p_c_grid,
1717 void* p_shared,
1718 const Problem& problem)
1719 {
1720 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1721 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1722 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1723 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1724 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1725 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1726 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1728 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1729
1730 // A/B shuffled scale for better 8-bit scale access pattern
1731 // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
// NOTE(review): M and N are divided by (MXdlPack * MPerXdl) and (NXdlPack * NPerXdl) with
// truncating integer division, so presumably the caller guarantees divisibility -- confirm.
// The literal 64 in the K extent presumably is the wave size; verify.
1732 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1733 make_tuple(problem.M / (MXdlPack * MPerXdl),
1734 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1735 (KXdlPack * 64 / MPerXdl),
1737
1738 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1739 make_tuple(problem.N / (NXdlPack * NPerXdl),
1740 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1741 (KXdlPack * 64 / NPerXdl),
1743
1744 Run<decltype(a_grid_desc_ak0_m_ak1),
1745 decltype(a_scale_grid_desc_am_ak),
1746 decltype(b_grid_desc_bk0_n_bk1),
1747 decltype(b_scale_grid_desc_bn_ak),
1748 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1749 HasMainKBlockLoop,
1750 CGlobalMemoryDataOperation,
1751 TailNum>(p_a_grid,
1752 p_a_scale_grid,
1753 p_b_grid,
1754 p_b_scale_grid,
1755 p_c_grid,
1756 p_shared,
1757 problem,
1758 a_grid_desc_ak0_m_ak1,
1759 a_scale_grid_desc_am_ak,
1760 b_grid_desc_bk0_n_bk1,
1761 b_scale_grid_desc_bn_ak,
1762 c_grid_desc_mblock_mperblock_nblock_nperblock);
1763 }
1764
1765 template <typename AGridDesc_AK0_M_K1,
1766 typename AScaleGridDesc_AM_AK,
1767 typename BGridDesc_BK0_N_K1,
1768 typename BScaleGridDesc_BN_AK,
1769 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1770 bool HasMainKBlockLoop,
1771 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1772 TailNumber TailNum = TailNumber::Odd>
1773 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1774 const AScaleDataType* p_a_scale_grid,
1775 const BDataType* p_b_grid,
1776 const BScaleDataType* p_b_scale_grid,
1777 CDataType* p_c_grid,
1778 void* p_shared_0,
1779 void* p_shared_1,
1780 const Problem& problem,
1781 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1782 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1783 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1784 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1785 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1786 c_grid_desc_mblock_mperblock_nblock_nperblock)
1787 {
1788 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1789 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1790 const auto b_grid_buf =
1792 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1794 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1795
1796 // A Scale buffer
1797 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1798 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1799
1800 // B Scale buffer
1801 const auto b_scale_grid_buf =
1803 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1804
1805 const CElementwiseOperation c_element_op{};
1806
1807 // divide block work by [M, N]
1808 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1809
1810 const auto block_work_idx =
1811 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1812
1813 if(!block_2_ctile_map.ValidCTileIndex(
1814 block_work_idx,
1815 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1816 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1817 {
1818 return;
1819 }
1820
1821 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1822 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1823
1824 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1825 const index_t m_block_data_idx_on_grid =
1826 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1827
1828 const index_t n_block_data_idx_on_grid =
1829 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
1830
1831 // lds max alignment
1832 // constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1833
1834 // A matrix in LDS memory, dst of blockwise copy
1835 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1836
1837 // B matrix in LDS memory, dst of blockwise copy
1838 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1839
1840 auto a_blockwise_copy =
1843 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1844 ABlockTransferThreadClusterArrangeOrder,
1845 ADataType,
1846 ADataType,
1847 decltype(a_grid_desc_ak0_m_ak1),
1848 decltype(a_block_desc_ak0_m_ak1),
1849 ABlockTransferSrcAccessOrder,
1850 ABlockTransferSrcVectorDim,
1851 2,
1852 ABlockTransferSrcScalarPerVector>(
1853 a_grid_desc_ak0_m_ak1,
1854 make_multi_index(0, m_block_data_idx_on_grid, 0),
1855 a_block_desc_ak0_m_ak1,
1856 make_multi_index(0, 0, 0));
1857
1858 // dummys
1860 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1862 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1863 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1864
1865 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1866 BDataType,
1867 BDataType,
1868 decltype(b_grid_desc_bk0_n_bk1),
1869 decltype(b_block_desc_bk0_n_bk1), // actually the thread desc
1870 Sequence<Number<NXdlPerWave / NXdlPack>{},
1871 I1,
1876 4,
1877 BBlockTransferSrcScalarPerVector,
1878 BThreadTransferSrcResetCoordinateAfterRun,
1879 true>(b_grid_desc_bk0_n_bk1,
1880 make_multi_index(n_block_data_idx_on_grid,
1882 0,
1883 0,
1884 KPack * (get_thread_local_1d_id() % WarpSize)));
1885
1886 // LDS allocation for A and B: be careful of alignment
1887 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1888 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1889
1890 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1891 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1892
1893 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1894
1895 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1896 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
1897
1898 // Blockwise GEMM pipeline
1899 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1900 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1901 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1902
1903 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1904 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1905 KPerBlock);
1906
1907 // Initial thread mapping for:
1908 // BlockSize = 256
1909 // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
1910 // For each [m0, n0] tile, there are 4 waves:
1911 // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
1912 // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
1913 // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
1914 // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
1915
1916 // BlockSize = 128
1917 // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
1918 // For each [m0, n0] tile, there are 2 waves:
1919 // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
1920 // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
1921
1922 // TODO: Document initial thread mapping for more combinations of parameters
1923
1924 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1925 const auto waveId_m = wave_idx[I0];
1926 const auto waveId_n = wave_idx[I1];
1927
1928 // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1929
1930 // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
1931 // mfma.selected_mfma.num_threads_per_blk;
1932
1933 // A wave access continuous memory
1934 auto thread_offset_shuffled =
1935 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1936
1937 auto a_thread_offset_m = waveId_m;
1938
1939 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1940 AScaleDataType,
1941 AScaleDataType,
1942 decltype(a_scale_grid_desc_am_ak),
1943 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1944 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1945 Sequence<0, 1, 2>, // DimAccessOrder
1946 2, // SrcVectorDim
1947 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1948 1, // SrcScalarStrideInVector
1949 true>(a_scale_grid_desc_am_ak,
1950 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1951 0,
1952 thread_offset_shuffled / scale_pack_size_a));
1953
1954 auto b_thread_offset_n = waveId_n;
1955
1956 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1957 BScaleDataType,
1958 BScaleDataType,
1959 decltype(b_scale_grid_desc_bn_ak),
1960 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1961 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1962 Sequence<0, 1, 2>, // DimAccessOrder
1963 2, // SrcVectorDim
1964 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1965 1, // SrcScalarStrideInVector
1966 true>(b_scale_grid_desc_bn_ak,
1967 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1968 0,
1969 thread_offset_shuffled / scale_pack_size_b));
1970
1971 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1972 a_block_desc_ak0_m_ak1,
1973 a_blockwise_copy,
1974 a_grid_buf,
1975 a_block_bufs,
1976 a_block_slice_copy_step,
1977 b_grid_desc_bk0_n_bk1,
1978 b_block_desc_bk0_n_bk1,
1979 b_blockwise_copy,
1980 b_grid_buf,
1981 b_block_bufs,
1982 b_block_slice_copy_step,
1983 c_thread_buf,
1984 a_scale_grid_desc_am_ak,
1985 a_scale_thread_copy,
1986 a_scale_grid_buf,
1987 b_scale_grid_desc_bn_ak,
1988 b_scale_thread_copy,
1989 b_scale_grid_buf,
1990 num_k_block_main_loop);
1991
1992 // shuffle C and write out
1993 {
1994 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1995 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1996 "wrong!");
1997 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1998 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1999 "wrong!");
2000
2001 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2002 // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2003
2004 // TODO: hacky, fix it!
2005 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2006 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2007
2008 // TODO: hacky, fix it!
2009 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2010 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2011 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2012
2013 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2014 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2015 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2016 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2017 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2018 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2019 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2020 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2021 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2022 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2023
2024 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2026
2027 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2028 static_cast<CShuffleDataType*>(p_shared_0),
2029 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2030
2031 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2032 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2033 make_tuple(
2036 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2037 // shuffle
2038 M1, // M1 = MWave
2039 M2, // M2 = MXdlPack
2040 M3, // M3 * M4 * M5 = MPerXdl
2041 M4,
2042 M5)),
2045 Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave) per
2046 // shuffle
2047 N1, // N1 = NWave
2048 N2, // N2 = NXdlPack
2049 N3))), // N3 = NPerXdl
2053 Sequence<>{},
2055
2056 // calculate origin of thread output tensor on global memory
2057 // blockwise GEMM c matrix starting index
2058 const auto c_thread_mtx_on_block =
2059 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2060
2061 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2062 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2063
2064 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2066 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2069
2070 const auto m_thread_data_on_block_idx =
2071 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2072 make_multi_index(m_thread_data_on_block));
2073
2074 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2076 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2079
2080 const auto n_thread_data_on_block_idx =
2081 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2082 make_multi_index(n_thread_data_on_block));
2083
2084 // shuffle: threadwise copy C from VGPR to LDS
2085 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2086 AccDataType,
2087 CShuffleDataType,
2088 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2089 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2091 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2092 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2093 I1,
2094 I1,
2095 M2,
2096 N2,
2097 M3,
2098 I1,
2099 M5,
2100 I1>,
2102 9,
2103 1,
2105 1,
2106 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2108 0,
2109 m_thread_data_on_block_idx[I1],
2110 n_thread_data_on_block_idx[I1],
2111 m_thread_data_on_block_idx[I2],
2112 n_thread_data_on_block_idx[I2],
2113 m_thread_data_on_block_idx[I3],
2114 m_thread_data_on_block_idx[I4],
2115 m_thread_data_on_block_idx[I5],
2116 n_thread_data_on_block_idx[I3]),
2118
2119 // calculate C grid descriptor
2120 constexpr auto DWORD_BYTES = 4;
2121 constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType);
2122
2123 constexpr auto CShuffleBlockTransferClusterLengths = [&]() {
2124 if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
2125 {
2126 return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{};
2127 }
2128 // Atomic operation
2129 else
2130 {
2131 return generate_sequence_v2(
2132 [&](auto i) {
2133 if constexpr(i == 3)
2134 {
2135 return Number<
2136 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
2137 .At(i) *
2138 CShuffleBlockTransferScalarPerVector_NPerBlock /
2139 atomic_vector_size>{};
2140 }
2141 else if constexpr(i == 1)
2142 {
2143 return Number<
2144 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
2145 .At(i) /
2146 CShuffleBlockTransferScalarPerVector_NPerBlock *
2147 atomic_vector_size>{};
2148 }
2149 else
2150 {
2151 return Number<
2152 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
2153 .At(i)>{};
2154 }
2155 },
2156 Number<4>{});
2157 }
2158 }();
2159
2160 constexpr auto CShuffleBlockTransferScalarPerVector = [&]() {
2161 if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
2162 {
2163 return CShuffleBlockTransferScalarPerVector_NPerBlock;
2164 }
2165 else
2166 {
2167 return atomic_vector_size;
2168 }
2169 }();
2170
2171 // shuffle: blockwise copy C from LDS to global
2172 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2173 ThisThreadBlock, // ThreadGroup
2174 CElementwiseOperation, // ElementwiseOperation,
2175 CGlobalMemoryDataOperation, // DstInMemOp,
2176 Sequence<1,
2177 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2178 1,
2179 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2180 decltype(CShuffleBlockTransferClusterLengths),
2181 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2182 CShuffleDataType, // typename SrcData,
2183 CDataType, // typename DstData,
2184 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2185 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2186 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2187 3, // index_t VectorDim,
2188 CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector,
2189 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2190 false> // bool ThreadTransferDstResetCoordinateAfterRun>
2191 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2192 make_multi_index(0, 0, 0, 0),
2193 c_grid_desc_mblock_mperblock_nblock_nperblock,
2194 make_multi_index(block_m_id, 0, block_n_id, 0),
2195 c_element_op};
2196
2197 // space filling curve for threadwise C in VGPR
2198 constexpr auto sfc_c_vgpr =
2199 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2200 NXdlPerWave / NXdlPack,
2201 1,
2202 1,
2203 MXdlPack,
2204 NXdlPack,
2205 M2,
2206 1,
2207 M4,
2208 1>,
2210 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2211 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2212 1,
2213 1,
2214 MXdlPack,
2215 NXdlPack,
2216 M2,
2217 1,
2218 M4,
2219 1>>{};
2220
2221 // space filling curve for shuffled blockwise C in global mem
2222 constexpr auto sfc_c_global =
2225 Sequence<1,
2226 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2227 1,
2228 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2229
2230 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2231
2232 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2233
2234 static_for<0, num_access, 1>{}([&](auto access_id) {
2235 // make sure it's safe to write to LDS
2237
2238 // each thread write its data from VGPR to LDS
2239 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2240 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2241 c_thread_buf,
2242 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2243 c_shuffle_block_buf);
2244
2245 // make sure it's safe to read from LDS
2247
2248 // each block copy its data from LDS to global
2249 c_shuffle_block_copy_lds_to_global.Run(
2250 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2251 c_shuffle_block_buf,
2252 c_grid_desc_mblock_mperblock_nblock_nperblock,
2253 c_grid_buf);
2254
2255 if constexpr(access_id < num_access - 1)
2256 {
2257 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2258
2259 // move on C
2260 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2261 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2262 }
2263 });
2264 }
2265 }
2266
// Convenience entry point for the dual-LDS-buffer path.
// It builds every grid descriptor (A, A-scale, preshuffled B, B-scale, C) from `problem`.
// It then forwards everything, including both LDS pointers unchanged, to the
// descriptor-taking Run_2Lds overload.
//
// Template parameters (all forwarded as-is to the inner Run_2Lds):
//   HasMainKBlockLoop          - main-K-loop selector for the pipeline.
//   CGlobalMemoryDataOperation - in-memory operation used when writing C (e.g. Set or atomic).
//   TailNum                    - pipeline tail variant; defaults to TailNumber::Odd.
// Parameters:
//   p_a_grid, p_a_scale_grid   - A operand and its scales.
//   p_b_grid, p_b_scale_grid   - preshuffled B operand and its scales.
//   p_c_grid                   - output C.
//   p_shared_0, p_shared_1     - the two separate LDS chunks used by the 2-LDS pipeline.
//   problem                    - sizes, strides, padding and block counts used to build
//                                the descriptors.
2267 template <bool HasMainKBlockLoop,
2268 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2269 TailNumber TailNum = TailNumber::Odd>
2270 __device__ static void Run_2Lds(const ADataType* p_a_grid,
2271 const AScaleDataType* p_a_scale_grid,
2272 const BDataType* p_b_grid,
2273 const BScaleDataType* p_b_scale_grid,
2274 CDataType* p_c_grid,
2275 void* p_shared_0,
2276 void* p_shared_1,
2277 const Problem& problem)
2278 {
// B is stored preshuffled. Its descriptor is therefore built only from the shuffled
// N0/K0 extents, not from N/K/StrideB as in MakeBGridDescriptor_BK0_N_BK1.
2279 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
2280 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
2281 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2282 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2283 const auto b_grid_desc_bk0_n_bk1 =
2284 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
2285
// C is first described as M x N, then re-viewed as (MBlock, MPerBlock, NBlock, NPerBlock)
// tiles, the layout the C-shuffle epilogue writes through.
2286 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2287 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2288 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2290 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2291
2292 // A/B shuffled scale for better 8-bit scale access pattern
2293 // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
2294 // We pad the M unconditionally for Scale
// M is rounded up to a multiple of ScaleBlockSize before being split into
// (MXdlPack * MPerXdl)-row groups.
// NOTE(review): the row stride below is built from problem.K * problem.KBatch. This
// appears to treat problem.K as the per-split-K extent; confirm against the Problem
// constructor and SplitKBatchOffset.
// NOTE(review): the literal 64 in the K-group size is presumably the wavefront size;
// confirm.
2295 const auto Padded_Scale_M =
2296 math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
2297 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2298 make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2299 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2300 (KXdlPack * 64 / MPerXdl),
2302 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2303 (ScaleBlockSize / APackedSize)) *
2304 MPerXdl * MXdlPack / scale_pack_size_a,
2306 1));
2307
// Same layout for the B scales, using the N-side XDL packing and BPackedSize.
// NOTE(review): unlike A, N is not padded here. problem.N / (NXdlPack * NPerXdl)
// truncates if N is not a multiple of that product; confirm the preshuffled-B layout
// guarantees divisibility.
2308 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2309 make_tuple(problem.N / (NXdlPack * NPerXdl),
2310 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2311 (KXdlPack * 64 / NPerXdl),
2313 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2314 (ScaleBlockSize / BPackedSize)) *
2315 NPerXdl * NXdlPack / scale_pack_size_b,
2317 1));
2318
// Delegate to the descriptor-based implementation. Descriptor types are deduced via
// decltype; both LDS pointers are passed through unchanged.
2319 Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2320 decltype(a_scale_grid_desc_am_ak),
2321 decltype(b_grid_desc_bk0_n_bk1),
2322 decltype(b_scale_grid_desc_bn_ak),
2323 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2324 HasMainKBlockLoop,
2325 CGlobalMemoryDataOperation,
2326 TailNum>(p_a_grid,
2327 p_a_scale_grid,
2328 p_b_grid,
2329 p_b_scale_grid,
2330 p_c_grid,
2331 p_shared_0,
2332 p_shared_1,
2333 problem,
2334 a_grid_desc_ak0_m_ak1,
2335 a_scale_grid_desc_am_ak,
2336 b_grid_desc_bk0_n_bk1,
2337 b_scale_grid_desc_bn_ak,
2338 c_grid_desc_mblock_mperblock_nblock_nperblock);
2339 }
2340};
2341
2342} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp:37
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__global__ enable_if_t<!Use2LDS, void > kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:711
__host__ Argument(const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:712
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:763
const BScaleDataType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:764
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:751
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:756
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:765
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:769
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:767
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:768
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:770
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:761
const AScaleDataType * p_a_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:762
index_t StrideScaleB
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:696
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:690
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:691
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:699
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:700
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:679
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:698
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:693
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:706
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:692
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:704
index_t StrideScaleA
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:694
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:705
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:697
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:703
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:702
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:701
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:695
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:650
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:776
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:835
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:831
index_t a_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:833
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:832
index_t b_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:834
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:156
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1274
__host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:287
__host__ static __device__ auto CalculateBN0Shuffled(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:249
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:222
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:258
__host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:569
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1233
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1012
__host__ static __device__ auto CalculateBK0Shuffled(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:253
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:217
static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:958
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:588
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1712
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:436
remove_cvref_t< decltype(BlockGemmMXBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:983
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:328
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:232
static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:838
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:276
__host__ static __device__ auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:427
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:2270
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:227
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:264
__host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:579
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:243
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1248
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:237
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1773
static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:968
static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1240
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:271
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129