gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp Source File

gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp Source File

Composable Kernel: gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp Source File
gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/env.hpp"
18
19namespace ck {
20
// Gridwise GEMM (WMMA, C-shuffle, v3 per the file name) with a per-block scale tensor for B.
// Almost all template parameters are forwarded unchanged to the base class / Base alias below;
// BScaleType, ScaleBlockN and ScaleBlockK are consumed only by this struct (B-scale handling).
// NOTE(review): this Doxygen extraction dropped source lines 35, 68-69 and 74-75. GemmSpec,
// BlkGemmPipeSched and BlkGemmPipelineVer are referenced in the argument list below but not
// declared in the visible parameters, and the struct declaration / base-class name line is
// absent. Restore these from the original header before compiling.
21template <typename ALayout,
22 typename BLayout,
23 typename DsLayout,
24 typename ELayout,
25 typename AsDataType,
26 typename BsDataType,
27 typename BScaleType, // element type of the B-scale tensor (see p_b_scale_grid)
28 typename AccDataType,
29 typename CShuffleDataType,
30 typename DsDataType,
31 typename EDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CDEElementwiseOperation,
36 index_t BlockSize,
37 index_t ScaleBlockN, // scale N: N elements covered by one B-scale entry
38 index_t ScaleBlockK, // scale K: K elements covered by one B-scale entry
39 index_t MPerBlock,
40 index_t NPerBlock,
41 index_t KPerBlock,
42 index_t AK1Value,
43 index_t BK1Value,
44 index_t MPerWmma,
45 index_t NPerWmma,
46 index_t MRepeat,
47 index_t NRepeat,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
54 bool AThreadTransferSrcResetCoordinateAfterRun,
55 index_t ABlockLdsExtraM,
56 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 index_t BBlockTransferSrcVectorDim,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t BBlockTransferDstScalarPerVector_BK1,
62 bool BThreadTransferSrcResetCoordinateAfterRun,
63 index_t BBlockLdsExtraN,
64 index_t CShuffleMRepeatPerShuffle,
65 index_t CShuffleNRepeatPerShuffle,
66 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 typename CDEShuffleBlockTransferScalarPerVectors,
70 typename ComputeTypeA = EDataType,
71 typename ComputeTypeB = ComputeTypeA,
72 bool PermuteA = false,
73 bool PermuteB = false>
// Base-class template arguments. Presumably these follow a `: public
// GridwiseGemm_wmma_cshuffle_v3_base<` head on the missing source lines 74-75; confirm.
// The argument list is identical to the Base alias that follows.
76 ALayout,
77 BLayout,
78 DsLayout,
79 ELayout,
80 AsDataType,
81 BsDataType,
82 AccDataType,
83 CShuffleDataType,
84 DsDataType,
85 EDataType,
86 AElementwiseOperation,
87 BElementwiseOperation,
88 CDEElementwiseOperation,
89 GemmSpec,
90 BlockSize,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 AK1Value,
95 BK1Value,
96 MPerWmma,
97 NPerWmma,
98 MRepeat,
99 NRepeat,
100 ABlockTransferThreadClusterLengths_AK0_M_AK1,
101 ABlockTransferThreadClusterArrangeOrder,
102 ABlockTransferSrcAccessOrder,
103 ABlockTransferSrcVectorDim,
104 ABlockTransferSrcScalarPerVector,
105 ABlockTransferDstScalarPerVector_AK1,
106 AThreadTransferSrcResetCoordinateAfterRun,
107 ABlockLdsExtraM,
108 BBlockTransferThreadClusterLengths_BK0_N_BK1,
109 BBlockTransferThreadClusterArrangeOrder,
110 BBlockTransferSrcAccessOrder,
111 BBlockTransferSrcVectorDim,
112 BBlockTransferSrcScalarPerVector,
113 BBlockTransferDstScalarPerVector_BK1,
114 BThreadTransferSrcResetCoordinateAfterRun,
115 BBlockLdsExtraN,
116 CShuffleMRepeatPerShuffle,
117 CShuffleNRepeatPerShuffle,
118 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
119 CDEShuffleBlockTransferScalarPerVectors,
120 BlkGemmPipeSched,
121 BlkGemmPipelineVer,
122 ComputeTypeA,
123 ComputeTypeB,
124 PermuteA,
125 PermuteB,
// NOTE(review): the meaning of this trailing bool is defined by the base template, which is
// not visible here.
126 true>
127{
// Base: alias for the base gridwise GEMM. The cross-reference listing names it
// GridwiseGemm_wmma_cshuffle_v3_base<...>. Its head (source line 128,
// presumably `using Base = GridwiseGemm_wmma_cshuffle_v3_base<`) is missing from this
// extraction. The B-scale parameters (BScaleType, ScaleBlockN, ScaleBlockK) are
// intentionally not forwarded.
129 ALayout,
130 BLayout,
131 DsLayout,
132 ELayout,
133 AsDataType,
134 BsDataType,
135 AccDataType,
136 CShuffleDataType,
137 DsDataType,
138 EDataType,
139 AElementwiseOperation,
140 BElementwiseOperation,
141 CDEElementwiseOperation,
142 GemmSpec,
143 BlockSize,
144 MPerBlock,
145 NPerBlock,
146 KPerBlock,
147 AK1Value,
148 BK1Value,
149 MPerWmma,
150 NPerWmma,
151 MRepeat,
152 NRepeat,
153 ABlockTransferThreadClusterLengths_AK0_M_AK1,
154 ABlockTransferThreadClusterArrangeOrder,
155 ABlockTransferSrcAccessOrder,
156 ABlockTransferSrcVectorDim,
157 ABlockTransferSrcScalarPerVector,
158 ABlockTransferDstScalarPerVector_AK1,
159 AThreadTransferSrcResetCoordinateAfterRun,
160 ABlockLdsExtraM,
161 BBlockTransferThreadClusterLengths_BK0_N_BK1,
162 BBlockTransferThreadClusterArrangeOrder,
163 BBlockTransferSrcAccessOrder,
164 BBlockTransferSrcVectorDim,
165 BBlockTransferSrcScalarPerVector,
166 BBlockTransferDstScalarPerVector_BK1,
167 BThreadTransferSrcResetCoordinateAfterRun,
168 BBlockLdsExtraN,
169 CShuffleMRepeatPerShuffle,
170 CShuffleNRepeatPerShuffle,
171 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172 CDEShuffleBlockTransferScalarPerVectors,
173 BlkGemmPipeSched,
174 BlkGemmPipelineVer,
175 ComputeTypeA,
176 ComputeTypeB,
177 PermuteA,
178 PermuteB,
179 true>;
180
// Re-export compile-time index constants from Base so they can be used unqualified.
181 using Base::I0;
182 using Base::I1;
183 using Base::I2;
184 using Base::I3;
185 using Base::I4;
186 using Base::I5;
187 using Base::I6;
188 using Base::I7;
189
// K0/K1 split constants for A and B (defined in Base).
190 using Base::AK0Number;
191 using Base::AK1Number;
192 using Base::BK0Number;
193 using Base::BK1Number;
194
// Packed-element factors used to scale split-K pointer offsets in SplitKBatchOffset.
195 using Base::APackedSize;
196 using Base::BPackedSize;
197
// NOTE(review): source lines 198-215 are missing from this extraction; only residue line
// numbers remain below. Later code uses MakeAsGridDescriptor_AK0_M_AK1, BlockwiseGemmPipe,
// Block2CTileMap, etc. unqualified, so those are presumably brought in there. Confirm.
211
213
215
// Tensor counts and grid-pointer tuple types shared with Base.
216 using Base::NumATensor;
217 using Base::NumBTensor;
218 using Base::NumDTensor;
219 using typename Base::AsGridPointer;
220 using typename Base::BsGridPointer;
221 using typename Base::DsGridPointer;
// Expose the A/B data-type tuples under member names (presumably for outside users).
222 using AsDataType_ = AsDataType;
223 using BsDataType_ = BsDataType;
224
// Problem: host-constructed GEMM problem description.
// - Holds the sizes, the per-tensor strides for the A/B/D tensors, the E stride, the B-scale
//   row stride (StrideScaleB) and the split-K factor (KBatch).
// - The constructor also computes derived K sizes via the Base Calculate* helpers.
// NOTE(review): the extraction is missing the following source lines:
// - 245-246 and 251-252: initializers, presumably MPadded/NPadded and MBlock/NBlock.
// - 283-285 and 289-299: field declarations. Per the cross-reference list these are M, N, K,
//   StrideE, StrideScaleB, KBatch, MPadded, NPadded, KRead, KPadded, AK0, BK0, MBlock, NBlock.
225 struct Problem
226 {
227 __host__ Problem(index_t M_,
228 index_t N_,
229 index_t K_,
230 std::array<index_t, NumATensor> StrideAs_,
231 std::array<index_t, NumBTensor> StrideBs_,
232 std::array<index_t, NumDTensor> StrideDs_,
233 index_t StrideE_,
234 index_t StrideScaleB_,
235 index_t KBatch_)
236 : M{M_},
237 N{N_},
238 K{K_},
239 StrideAs{StrideAs_},
240 StrideBs{StrideBs_},
241 StrideDs{StrideDs_},
242 StrideE{StrideE_},
243 StrideScaleB{StrideScaleB_},
244 KBatch{KBatch_},
// K extent read per split-K batch, padded K, and the A/B K0 extents, all derived from K and
// KBatch. NOTE(review): the cross-reference lists only CalculateKPadded(index_t K); confirm
// that a (K, K_Batch) overload exists.
247 KRead{CalculateKRead(K_, KBatch_)},
248 KPadded{CalculateKPadded(K_, KBatch_)},
249 AK0{CalculateAK0Padded(K_, KBatch_)},
250 BK0{CalculateBK0Padded(K_, KBatch_)},
253 {
254 }
255
// Debug helper: writes a one-line summary of all sizes and strides to std::cout.
// The line ends with std::endl, which also flushes the stream.
256 __host__ void Print() const
257 {
258 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
259 << "SAs: {";
260 static_for<0, NumATensor, 1>{}([&](auto i) {
261 std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
262 });
263 std::cout << "}, " << "SBs: {";
264 static_for<0, NumBTensor, 1>{}([&](auto i) {
265 std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
266 });
267 std::cout << "}, ";
// D strides are printed only when there are D tensors.
268 if constexpr(NumDTensor > 0)
269 {
270 std::cout << "SDs: { ";
271 static_for<0, NumDTensor, 1>{}([&](auto i) {
272 std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
273 });
274 std::cout << " }, ";
275 }
276 std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
277 << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
278 << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
279 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
280 << std::endl;
281 }
282
// One stride per A / B / D tensor (the remaining fields are on the missing lines noted above).
286 std::array<index_t, NumATensor> StrideAs;
287 std::array<index_t, NumBTensor> StrideBs;
288 std::array<index_t, NumDTensor> StrideDs;
300 };
301
302 // Argument
// Kernel argument: the Problem plus what the device code needs at run time.
// - Typed grid pointers for the A/B/D tensors, the E output pointer and the B-scale pointer.
// - The three elementwise operators.
// - The is_reduce flag, which selects reduce vs atomic split-K output.
// NOTE(review): the extraction is missing the following source lines:
// - 303: the struct head. It is presumably `struct Argument : public Problem`, given the
//   Problem{...} base initializer.
// - 375-377 and 384: the p_as_grid, p_bs_grid, p_ds_grid and is_reduce declarations, per the
//   cross-reference list.
304 {
305 __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
306 std::array<const void*, NumBTensor> p_bs_grid_,
307 std::array<const void*, NumDTensor> p_ds_grid_,
308 EDataType* p_e_grid_,
309 index_t M_,
310 index_t N_,
311 index_t K_,
312 std::array<index_t, NumATensor> StrideAs_,
313 std::array<index_t, NumBTensor> StrideBs_,
314 std::array<index_t, NumDTensor> StrideDs_,
315 index_t StrideE_,
316 index_t StrideScaleB_,
317 const BScaleType* p_b_scale_grid_,
318 index_t k_batch_,
319 AElementwiseOperation a_element_op_,
320 BElementwiseOperation b_element_op_,
321 CDEElementwiseOperation cde_element_op_,
322 bool is_reduce_ = false)
323 : Problem{M_,
324 N_,
325 K_,
326 StrideAs_,
327 StrideBs_,
328 StrideDs_,
329 StrideE_,
330 StrideScaleB_,
331 k_batch_},
332 p_as_grid{},
333 p_bs_grid{},
334 p_ds_grid{},
335 p_e_grid{p_e_grid_},
336 p_b_scale_grid{p_b_scale_grid_},
337 a_element_op{a_element_op_},
338 b_element_op{b_element_op_},
339 cde_element_op{cde_element_op_},
340 is_reduce(is_reduce_)
341 {
// The body only casts each type-erased input pointer to its tensor's element type.
// No descriptors are built here, despite the wording of the comments below.
342 // populate pointer, desc for As
343 static_for<0, NumATensor, 1>{}([&](auto i) {
344 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
345
346 // A pointer
347 p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
348 });
349
350 // populate pointer, desc for Bs
351 static_for<0, NumBTensor, 1>{}([&](auto i) {
352 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
353
354 // B pointer
355 p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
356 });
357
358 static_for<0, NumDTensor, 1>{}([&](auto i) {
359 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
360
361 p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
362 });
363 }
364
// True when split-K is active (KBatch > 1) and is_reduce is set. In that case
// SplitKBatchOffset gives each batch its own M*N slice of E (c_reduce_offset).
365 __host__ __device__ inline bool IsReduceAdd() const
366 {
367 return (Problem::KBatch > 1) && is_reduce;
368 }
369
// True when split-K is active and is_reduce is not set. The name suggests the batches
// accumulate into the same E tile; confirm against the caller's EGlobalMemoryDataOperation.
370 __host__ __device__ inline bool IsAtomicAdd() const
371 {
372 return (Problem::KBatch > 1) && (!is_reduce);
373 }
374
378 EDataType* p_e_grid;
379
380 const BScaleType* p_b_scale_grid;
381 const AElementwiseOperation a_element_op;
382 const BElementwiseOperation b_element_op;
383 const CDEElementwiseOperation cde_element_op;
385 };
386
// SplitKBatchOffset: per-workgroup split-K offsets for batch k_id.
// - Computes element offsets into each A and B tensor and into the B-scale tensor.
// - Computes the E offset used in reduce mode.
// - Side effect: rewrites karg.K to the K extent handled by this batch.
// NOTE(review): the extraction is missing the following source lines:
// - 387: the struct head.
// - 397, 402, 408, 413, 429, 433: presumably the layout `if constexpr` conditions.
// - 399, 404, 410, 417, 423: presumably the static_for heads.
// - 460: the c_reduce_offset declaration.
388 {
389
390 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
391 {
392 // Note: in xdl implementation multiple AB supports one layout
393 // but multiple strides, so we create an array of offsets with
394 // the same values.
395 // It should be fixed later on. Once we will have a thread transfer
396 // more flexible.
// A offsets, first branch: K offset with unit K stride, divided by APackedSize.
// All tensors receive the same value here.
398 {
400 [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
401 }
// A offsets, second branch: K offset scaled by each tensor's own stride.
403 {
405 [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
406 }
407
// B offsets, first branch: K offset scaled by each tensor's own stride.
409 {
411 [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
412 }
414 {
415 if constexpr(!PermuteB)
416 {
418 [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
419 }
420 else
421 {
// With PermuteB each batch advances by a whole KRead x N slab.
// NOTE(review): `int` (rather than index_t) and the products KRead * N * k_id can overflow
// 32 bits for large problems.
422 const int k0_offset = karg.KRead * karg.N;
424 [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
425 }
426 }
427
428 // Calculate B scale offset
// NOTE(review): karg.StrideB is not among the Problem fields; Problem declares StrideBs and
// StrideScaleB. Also, the scale descriptor built in Run is (ceil(N/ScaleBlockN),
// ceil(K/ScaleBlockK)) with strides (StrideScaleB, 1), i.e. K has unit stride regardless of
// BLayout. Multiplying by a B stride here therefore looks wrong; presumably this branch
// should match the other one (k_id * (karg.KRead / ScaleBlockK)). Confirm.
// KRead / ScaleBlockK also truncates if KRead is not a multiple of ScaleBlockK.
430 {
431 scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
432 }
434 {
435 scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
436 }
437
// Every batch except the last handles KRead elements of K; the last takes the remainder.
// This assumes KRead * (KBatch - 1) < K, presumably guaranteed by CalculateKRead. Confirm.
438 if(k_id < karg.KBatch - 1)
439 {
440 karg.K = karg.KRead;
441 }
442 else
443 {
444 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
445 }
446
// Reduce mode: each batch writes its own M x N slice of E (applied to p_e_grid in the
// split-K Run wrapper). NOTE(review): k_id * M * N is index_t arithmetic and may overflow
// for large outputs.
447 if(karg.IsReduceAdd())
448 {
449 c_reduce_offset = k_id * karg.M * karg.N;
450 }
451 else
452 {
453 c_reduce_offset = 0;
454 }
455 }
456
// Element offsets added to the typed A / B grid pointers for this batch.
457 std::array<index_t, NumATensor> a_k_split_offset;
458 std::array<index_t, NumBTensor> b_k_split_offset;
459 index_t scale_k_split_offset; // New member for scale matrix offset
461 };
462
464
// NOTE(review): source lines 463 and 467 are missing. 467 presumably declares the
// Block2CTileMap alias that Run constructs as Block2CTileMap{M, N, 4}. Confirm whether the
// gfx942 remark below still applies to this WMMA kernel.
465 // return block_id to C matrix tile idx (m0, n0) mapping
466 // if arch = gfx942
468 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
469
// MakeBScale: builds the per-thread B-scale loader for workgroup column block_n_id.
// It creates three pieces and wraps them in the pipeline-specific BlockwiseGemmPipe::BScale:
// - a global-memory buffer over the scale tensor;
// - a threadwise copy whose origin is this thread's first scale element;
// - the thread-local descriptor/buffer types.
// NumberOfBuffers is forwarded to BScale unchanged.
// NOTE(review): source lines 479, 486, 495, 499-500 and 509 are missing. They presumably
// hold the WMMA selector, the thread descriptor lengths, the copy type head, the remaining
// copy parameters and the b_scale_thread_buf declaration.
470 template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
471 __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
472 const BScaleType* p_b_scale_grid,
473 index_t block_n_id)
474 {
475 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
476 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
477
// K elements consumed per thread, taken from the selected WMMA instruction (k_per_wmma).
478 static constexpr auto wmma =
480 static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
481
// Scale slice per thread: NRepeat entries along N, ceil(KPerThread / ScaleBlockK) along K.
482 static constexpr auto ScaleSliceSizeN = NRepeat;
483 static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
484
485 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
487
// Number of waves tiling the workgroup's N extent.
488 constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
489
// This thread's N and K position within the workgroup tile, in elements.
// NOTE(review): the literal 32 appears to assume 32 lanes per wave (wave32).
490 auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
491 (get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
492 auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
493
// NOTE(review): in the N start index below, block_n_id * NPerBlock / ScaleBlockN is in scale
// rows, while b_thread_offset_n is in N elements (not divided by ScaleBlockN). The two units
// agree only when ScaleBlockN == 1. The K index does divide by ScaleBlockK. Confirm intent.
494 auto b_scale_thread_copy =
496 BScaleType,
497 decltype(b_scale_grid_desc_bn_ak),
498 decltype(b_scale_thread_desc),
501 1,
502 ScaleSliceSizeK,
503 1,
504 false>(
505 b_scale_grid_desc_bn_ak,
506 make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
507 b_thread_offset_k / ScaleBlockK));
508
510 b_scale_thread_desc.GetElementSpaceSize());
511
512 using BScale =
513 typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
514 ScaleSliceSizeK,
515 NWaves,
516 ScaleBlockK,
517 NumberOfBuffers,
518 decltype(b_scale_grid_desc_bn_ak),
519 decltype(b_scale_thread_copy),
520 decltype(b_scale_grid_buf),
521 decltype(b_scale_thread_buf),
522 decltype(b_scale_thread_desc)>;
523
// The thread buffer appears here only through its type; it is not passed to the constructor.
524 return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
525 }
526
527 __device__ static index_t GetKBlockPerScale()
528 {
529 return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
530 }
531
// Run (pointer-level entry): performs the per-workgroup GEMM with B scaling.
// 1. Builds the A/B/D/E grid descriptors and the B-scale descriptor from `problem`.
// 2. Maps the workgroup id to an (m, n) output tile and returns early for invalid tiles.
// 3. Builds the B-scale loader and delegates the main loop and epilogue to Base::Run.
// NOTE(review): source lines 557 and 560 are missing. They presumably hold the
// MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock and
// MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock call heads.
532 template <bool HasMainKBlockLoop,
533 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
534 TailNumber TailNum,
535 typename EpilogueArgument>
536 __device__ static void Run(AsGridPointer& p_as_grid,
537 BsGridPointer& p_bs_grid,
538 DsGridPointer& p_ds_grid,
539 EDataType* p_e_grid,
540 const BScaleType* p_b_scale_grid,
541 void* p_shared,
542 const Problem& problem,
543 AElementwiseOperation a_element_op,
544 BElementwiseOperation b_element_op,
545 CDEElementwiseOperation cde_element_op,
546 EpilogueArgument& epilogue_args)
547 {
548 const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
549 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
550 const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
551 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
552 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
553 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
554 const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
555 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
556 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
558 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
559 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
561 e_grid_desc_m_n, problem.MBlock, problem.NBlock);
562
563 // B Scale grid
// Scale tensor layout:
// - Shape: ceil(N / ScaleBlockN) rows x ceil(K / ScaleBlockK) columns.
// - Strides: rows are StrideScaleB apart and K is the unit-stride dimension.
// - With split-K, problem.K is presumably this batch's K (see the wrapper below) and
//   p_b_scale_grid is pre-offset by scale_k_split_offset.
564 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
565 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
566 math::integer_divide_ceil(problem.K, ScaleBlockK)),
567 make_tuple(problem.StrideScaleB, 1));
568
569 // divide block work by [M, N]
// NOTE(review): the meaning of the third argument (4) is defined by Block2CTileMap, whose
// alias is not visible in this extraction.
570 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
571
572 const auto block_work_idx =
573 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
574
// Skip workgroups whose tile index falls outside the MBlock x NBlock output grid.
575 if(!block_2_ctile_map.ValidCTileIndex(
576 block_work_idx,
577 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
578 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
579 {
580 return;
581 }
582
// readfirstlane takes the first active lane's value, making the tile ids wave-uniform.
583 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
584 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
585
586 // BScale struct
// NumberOfBuffers = 1 for this pipeline.
587 auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
588
589 const index_t num_k_block_per_scale = GetKBlockPerScale();
590
591 Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
592 decltype(bs_grid_desc_bk0_n_bk1),
593 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
594 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
595 decltype(b_scale_struct),
596 decltype(epilogue_args),
597 HasMainKBlockLoop,
598 EGlobalMemoryDataOperation,
599 TailNum>(p_as_grid,
600 p_bs_grid,
601 p_ds_grid,
602 p_e_grid,
603 p_shared,
604 as_grid_desc_ak0_m_ak1,
605 bs_grid_desc_bk0_n_bk1,
606 ds_grid_desc_mblock_mperblock_nblock_nperblock,
607 e_grid_desc_mblock_mperblock_nblock_nperblock,
608 a_element_op,
609 b_element_op,
610 cde_element_op,
611 block_m_id,
612 block_n_id,
613 num_k_block_per_scale,
614 b_scale_struct,
615 epilogue_args);
616 }
617
618 // NOTE: Wrapper function to have __global__ function in common
619 // between gemm_universal, b_scale, ab_scale, etc.
// Split-K entry point. It shifts the pointers and then calls the pointer-level Run with
// karg as the Problem:
// - each A/B grid pointer by its split-K offset;
// - E by c_reduce_offset;
// - the B-scale pointer by scale_k_split_offset.
// D pointers are passed through unshifted.
// NOTE(review): source line 645 (presumably the `Run<HasMainKBlockLoop,
// EGlobalMemoryDataOperation, TailNum>(` call head) is missing from this extraction.
620 template <bool HasMainKBlockLoop,
621 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
622 TailNumber TailNum,
623 typename EpilogueArgument>
624 __device__ static void Run(void* p_shared,
625 const SplitKBatchOffset& splitk_batch_offset,
626 Argument& karg,
627 EpilogueArgument& epilogue_args)
628 {
629 // shift A matrices pointer for splitk
630 AsGridPointer p_as_grid_splitk;
631 static_for<0, NumATensor, 1>{}([&](auto i) {
632 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
633 p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
634 splitk_batch_offset.a_k_split_offset[i];
635 });
636
637 // shift B matrices pointer for splitk
638 BsGridPointer p_bs_grid_splitk;
639 static_for<0, NumBTensor, 1>{}([&](auto i) {
640 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
641 p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
642 splitk_batch_offset.b_k_split_offset[i];
643 });
644
646 p_as_grid_splitk,
647 p_bs_grid_splitk,
648 karg.p_ds_grid,
649 karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
650 karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
651 p_shared,
652 karg,
653 karg.a_element_op,
654 karg.b_element_op,
655 karg.cde_element_op,
656 epilogue_args);
657 }
658};
659
660} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:271
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:304
const BElementwiseOperation b_element_op
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:382
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:365
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t StrideScaleB_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:305
const AElementwiseOperation a_element_op
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:381
const CDEElementwiseOperation cde_element_op
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:383
const BScaleType * p_b_scale_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:380
AsGridPointer p_as_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:375
EDataType * p_e_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:378
DsGridPointer p_ds_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:377
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:370
bool is_reduce
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:384
BsGridPointer p_bs_grid
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:376
index_t KBatch
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:291
index_t MPadded
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:292
index_t NBlock
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:299
index_t StrideE
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:289
__host__ void Print() const
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:256
index_t AK0
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:296
index_t MBlock
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:298
index_t BK0
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:297
index_t StrideScaleB
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:290
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:288
std::array< index_t, NumATensor > StrideAs
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:286
index_t KPadded
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:295
index_t N
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:284
std::array< index_t, NumBTensor > StrideBs
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:287
index_t K
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:285
index_t KRead
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:294
index_t M
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:283
index_t NPadded
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:293
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t StrideScaleB_, index_t KBatch_)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:227
std::array< index_t, NumATensor > a_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:457
index_t c_reduce_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:460
index_t scale_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:459
std::array< index_t, NumBTensor > b_k_split_offset
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:458
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:390
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:127
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
static constexpr index_t NumATensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr index_t NumBTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static __device__ index_t GetKBlockPerScale()
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:527
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, true > Base
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:128
decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static constexpr index_t BPackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:536
static constexpr index_t NumDTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:624
static __device__ auto MakeBScale(const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const BScaleType *p_b_scale_grid, index_t block_n_id)
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:471
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
static constexpr auto I2
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:126
static constexpr auto I3
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:127
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:546
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:350
static constexpr auto I1
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:125
static constexpr index_t NumATensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
static constexpr auto AK1Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:151
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:521
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
static constexpr auto I6
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:130
static constexpr auto AK0Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:149
static constexpr index_t NumBTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:134
static constexpr auto I0
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
static constexpr index_t APackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:161
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static constexpr auto I7
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:131
static constexpr auto I4
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:128
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
static constexpr index_t BPackedSize
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:168
static constexpr auto BK1Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:152
static constexpr auto BK0Number
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:150
decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:351
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static constexpr index_t NumDTensor
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:508
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
static constexpr auto I5
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:129
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition utility/sequence.hpp:43
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition wmma_gemm.hpp:553
static constexpr auto selected_wmma
Definition wmma_gemm.hpp:636
Definition functional2.hpp:33
Definition device_base.hpp:197