device_grouped_gemm_xdl_fixed_nk.hpp Source File

device_grouped_gemm_xdl_fixed_nk.hpp Source File

Composable Kernel: device_grouped_gemm_xdl_fixed_nk.hpp Source File
device_grouped_gemm_xdl_fixed_nk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23template <typename GridwiseGemm,
24 typename GemmDesc,
25 GemmSpecialization GemmSpec,
26 bool Zeroing,
27 typename ALayout,
28 typename BLayout,
29 typename DsLayout,
30 typename ELayout,
31 typename DsDataType,
32 typename Block2ETileMap,
33 typename GroupedGemmBlock2ETileMap,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CDEElementwiseOperation,
37 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
38 bool HasMainKBlockLoop>
39__global__ void
40#if CK_USE_LAUNCH_BOUNDS
42#endif
44 uint32_t* barrier_count,
45 const index_t barrier_size_grp,
46 const index_t group_count,
47 const index_t grid_size_grp,
48 const index_t KBatch,
49 const AElementwiseOperation a_element_op,
50 const BElementwiseOperation b_element_op,
51 const CDEElementwiseOperation c_element_op)
52{
53#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
54 if constexpr(GridwiseGemm::template IsValidCompilationParameter<EGlobalMemoryDataOperation>())
55 {
56 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
57
58 const index_t block_id = get_block_1d_id();
59
60 const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
62
63 const index_t group_id = block_id / grid_size_grp;
64
65 if(group_id >= group_count)
66 return;
67
68 const index_t M = gemm_desc_ptr[group_id].M;
69 const index_t N = gemm_desc_ptr[group_id].N;
70 const index_t K = gemm_desc_ptr[group_id].K;
71
72 if(M == 0 || N == 0 || K == 0)
73 return;
74
75 const auto StrideA = gemm_desc_ptr[group_id].StrideA;
76 const auto StrideB = gemm_desc_ptr[group_id].StrideB;
77 const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
78 const auto StrideE = gemm_desc_ptr[group_id].StrideE;
79
80 const auto e_grid_desc_m_n =
81 GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
82
83 const index_t BlockStart = group_id * grid_size_grp;
84
85 const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
86
87 const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
88
89 constexpr auto NumDTensor = DsDataType::Size();
90
91 using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
92
93 DsGridPointer p_ds_grid_;
94
95 static_for<0, NumDTensor, 1>{}([&](auto i) {
96 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
97 // D pointer
98 p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
99 });
100
101 index_t id_off = 0;
102 index_t id_local = get_block_1d_id() - BlockStart;
103
104 const index_t mn_blocks = local_grid_size / KBatch;
105
106 while(id_local < local_grid_size)
107 {
108 const auto block_2_etile_map =
109 GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
110
111 if constexpr(Zeroing)
112 {
113 auto barrier_count_finished =
114 barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
115 GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
116 EGlobalMemoryDataOperation,
117 GemmSpec,
118 ALayout,
119 BLayout,
120 DsLayout,
121 ELayout>(gemm_desc_ptr[group_id].p_a_grid,
122 gemm_desc_ptr[group_id].p_b_grid,
123 p_ds_grid_,
124 gemm_desc_ptr[group_id].p_e_grid,
125 p_shared,
126 barrier_count_finished,
127 a_element_op,
128 b_element_op,
129 c_element_op,
130 M,
131 N,
132 K,
133 StrideA,
134 StrideB,
135 StrideDs,
136 StrideE,
137 KBatch,
138 block_2_etile_map);
139 }
140 else
141 {
142
143 GridwiseGemm::template Run<HasMainKBlockLoop,
144 EGlobalMemoryDataOperation,
145 GemmSpec,
146 ALayout,
147 BLayout,
148 DsLayout,
149 ELayout>(gemm_desc_ptr[group_id].p_a_grid,
150 gemm_desc_ptr[group_id].p_b_grid,
151 p_ds_grid_,
152 gemm_desc_ptr[group_id].p_e_grid,
153 p_shared,
154 nullptr,
155 a_element_op,
156 b_element_op,
157 c_element_op,
158 M,
159 N,
160 K,
161 StrideA,
162 StrideB,
163 StrideDs,
164 StrideE,
165 KBatch,
166 block_2_etile_map);
167 }
168
169 id_off += grid_size_grp;
170 id_local += grid_size_grp;
171 }
172 }
173#else
174 ignore = gemm_descs_const;
175 ignore = barrier_count;
176 ignore = barrier_size_grp;
177 ignore = group_count;
178 ignore = grid_size_grp;
179 ignore = KBatch;
180 ignore = a_element_op;
181 ignore = b_element_op;
182 ignore = c_element_op;
183#endif
184}
185
186template <typename ALayout,
187 typename BLayout,
188 typename DsLayout,
189 typename ELayout,
190 typename ADataType,
191 typename BDataType,
192 typename AccDataType,
193 typename CShuffleDataType,
194 typename DsDataType,
195 typename EDataType,
196 typename AElementwiseOperation,
197 typename BElementwiseOperation,
198 typename CDEElementwiseOperation,
199 GemmSpecialization GemmSpec,
200 ck::index_t NumPrefetch,
201 ck::index_t BlockSize,
202 ck::index_t MPerBlock,
203 ck::index_t NPerBlock,
204 ck::index_t KPerBlock,
205 ck::index_t AK1,
206 ck::index_t BK1,
207 ck::index_t MPerXDL,
208 ck::index_t NPerXDL,
209 ck::index_t MXdlPerWave,
210 ck::index_t NXdlPerWave,
211 typename ABlockTransferThreadClusterLengths_K0_M_K1,
212 typename ABlockTransferThreadClusterArrangeOrder,
213 typename ABlockTransferSrcAccessOrder,
214 ck::index_t ABlockTransferSrcVectorDim,
215 ck::index_t ABlockTransferSrcScalarPerVector,
216 ck::index_t ABlockTransferDstScalarPerVector_K1,
217 bool ABlockLdsExtraM,
218 typename BBlockTransferThreadClusterLengths_K0_N_K1,
219 typename BBlockTransferThreadClusterArrangeOrder,
220 typename BBlockTransferSrcAccessOrder,
221 ck::index_t BBlockTransferSrcVectorDim,
222 ck::index_t BBlockTransferSrcScalarPerVector,
223 ck::index_t BBlockTransferDstScalarPerVector_K1,
224 bool BBlockLdsExtraN,
225 index_t CShuffleMXdlPerWavePerShuffle,
226 index_t CShuffleNXdlPerWavePerShuffle,
227 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
228 index_t CDEBlockTransferScalarPerVector_NPerBlock,
231 typename ComputeType = ADataType,
232 typename ALDSType = ComputeType,
233 typename BLDSType = ComputeType>
235 BLayout,
236 DsLayout,
237 ELayout,
238 ADataType,
239 BDataType,
240 DsDataType,
241 EDataType,
242 AElementwiseOperation,
243 BElementwiseOperation,
244 CDEElementwiseOperation>
245{
248 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
249 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
250
251 static constexpr index_t NumDTensor = DsDataType::Size();
252
253 static constexpr auto I0 = Number<0>{};
254 static constexpr auto I1 = Number<1>{};
255 static constexpr auto I2 = Number<2>{};
256
257 using AComputeType = ComputeType;
258 using BComputeType = ComputeType;
259
260 // GridwiseGemm
261 template <index_t NXdlPerWave_>
263 ADataType, // TODO: distinguish A/B datatype
264 BDataType,
267 AccDataType,
268 CShuffleDataType,
269 DsDataType,
270 EDataType,
271 AElementwiseOperation,
272 BElementwiseOperation,
273 CDEElementwiseOperation,
274 NumPrefetch, // NumGemmKPrefetchStage
275 BlockSize,
276 MPerBlock,
277 NPerBlock,
278 KPerBlock,
279 AK1,
280 BK1,
281 MPerXDL,
282 NPerXDL,
283 MXdlPerWave,
284 NXdlPerWave_,
285 ABlockTransferThreadClusterLengths_K0_M_K1,
286 ABlockTransferThreadClusterArrangeOrder,
287 ABlockTransferSrcAccessOrder,
288 ABlockTransferSrcVectorDim,
289 ABlockTransferSrcScalarPerVector,
290 ABlockTransferDstScalarPerVector_K1,
291 false, // AThreadTransferSrcResetCoordinateAfterRun,
292 ABlockLdsExtraM,
293 BBlockTransferThreadClusterLengths_K0_N_K1,
294 BBlockTransferThreadClusterArrangeOrder,
295 BBlockTransferSrcAccessOrder,
296 BBlockTransferSrcVectorDim,
297 BBlockTransferSrcScalarPerVector,
298 BBlockTransferDstScalarPerVector_K1,
299 false, // BThreadTransferSrcResetCoordinateAfterRun,
300 BBlockLdsExtraN,
301 CShuffleMXdlPerWavePerShuffle,
302 CShuffleNXdlPerWavePerShuffle,
303 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
304 CDEBlockTransferScalarPerVector_NPerBlock,
305 LoopSched,
306 PipelineVer,
307 ALDSType,
308 BLDSType>;
311
312 template <typename UnderlyingBlockToCTileMap>
314 {
315 using underlying_type = UnderlyingBlockToCTileMap;
316
318 UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
319 {
320 block_to_ctile_map_ = block_to_ctile_map;
321 block_start_ = block_start;
322 id_off_ = id_off;
323 }
324
325 template <typename TopIdx>
326 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
327 {
328 auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
330
331 return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
332 }
333
334 template <typename CTileIdx, typename CTileDim>
335 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
336 const CTileDim& c_tile_dim) const
337 {
338 return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
339 }
340
341 template <typename CGridDesc_M_N>
342 __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
343 {
344 return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
345 }
346
347 template <typename CGridDesc_M_N>
348 __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
349 {
350 return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
351 }
352
353 UnderlyingBlockToCTileMap block_to_ctile_map_;
356 };
357
358 template <index_t MPerBlock_, index_t NPerBlock_>
360 {
361 static constexpr auto I0 = Number<0>{};
362 static constexpr auto I1 = Number<1>{};
363
364 __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
365
374
376 index_t N,
377 index_t KBatch,
378 index_t M01 = 8)
379 : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
380 {
381 }
382
383 template <typename CGridDesc_M_N>
385 const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
387 c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
388 {
389 }
390
391 __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
392 {
393 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
394 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
395
396 return M0 * N0 * KBatch_;
397 }
398
399 template <typename CGridDesc_M_N>
400 __host__ __device__ constexpr index_t
401 CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
402 {
403 return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
404 }
405
406 template <typename CGridDesc_M_N>
407 __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
408 {
409 return true;
410 }
411
412 template <typename TopIdx>
413 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
414 {
415 auto block_1d_id = idx_top[I0];
416
417 const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
418 const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
419
420 block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
421
422 const index_t idx_ksplit = block_1d_id / (M0 * N0);
423 block_1d_id = block_1d_id % (M0 * N0);
424
425 index_t idx_N0 = block_1d_id % N0;
426 index_t idx_M0 = block_1d_id / N0;
427
428 const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
429
430 index_t idx_M00 = idx_M0 / M01_;
431 index_t idx_M01 = idx_M0 % M01_;
432 index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
433
434 return make_tuple(idx_ksplit,
435 idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
436 idx_N0_M01_local / M01_adapt);
437 }
438
439 template <typename CTileIdx, typename CTileDim>
440 __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
441 const CTileDim& /* c_tile_dim */) const
442 {
443 return true; // always valid provided that user gets grid size from CalculateGridSize()
444 }
445
446 private:
447 index_t M_;
448 index_t N_;
449 index_t KBatch_;
450 index_t M01_;
451 };
452
455
456 // TODO: replace with GroupedGemmKernelArgument
458 {
459 // pointers
460 const void* a_ptr_;
461 const void* b_ptr_;
462 std::array<const void*, NumDTensor> ds_ptr_;
463 void* e_ptr_;
464
467 std::array<index_t, NumDTensor> StrideDs_;
469 };
470
471 // Argument
472 struct Argument : public BaseArgument
473 {
474
475 void UpdateKBatch(index_t k_batch)
476 {
477 k_batch_ = k_batch;
478
479 if(k_batch_ < 1)
480 {
481
482 throw std::runtime_error("wrong! k_batch must be > 0");
483 }
484
486
487 const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
488 const index_t N = gemm_desc_kernel_arg_[0].N_;
489
490 const auto e_grid_desc_m_n =
491 GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
492 AverM, N, StrideE);
493
494 const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
495
496 grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
497
499 }
500
501 Argument(std::vector<const void*>&,
502 std::vector<const void*>&,
503 std::vector<std::array<const void*, NumDTensor>>&,
504 std::vector<void*>&,
505 std::vector<GemmDesc>& gemm_descs,
506 AElementwiseOperation a_element_op,
507 BElementwiseOperation b_element_op,
508 CDEElementwiseOperation c_element_op)
509 : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
510 {
511 grid_size_ = 0;
512
513 k_batch_ = 1;
514
516
517 group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
518
520
521 index_t group_id = 0;
522
523 sum_of_m = gemm_descs[0].M_;
525 const index_t N = gemm_descs[0].N_;
526 const index_t K = gemm_descs[0].K_;
527
528 for(std::size_t i = 0; i < gemm_descs.size(); i++)
529 {
530 if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
531 {
532 throw std::runtime_error("wrong! M/N/K is not identical");
533 }
534
535 a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
536 b_mtx_nraw_kraw_.emplace_back(N, K);
537
538 const index_t StrideA = gemm_descs[i].stride_A_;
539 const index_t StrideB = gemm_descs[i].stride_B_;
540 const index_t StrideE = gemm_descs[i].stride_C_;
541
542 // pointer
543 std::array<const void*, NumDTensor> p_ds_grid;
544
545 static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });
546
547 std::array<index_t, NumDTensor> StrideDs;
548
549 static_for<0, NumDTensor, 1>{}([&](auto j) {
550 // using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
551
552 if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
553 {
554 throw std::runtime_error(
555 "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
556 }
557
558 StrideDs[j] = gemm_descs[i].stride_Ds_[j];
559 });
560
561 const auto e_grid_desc_m_n =
562 GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
563 AverM, N, StrideE);
564
565 // block-to-e-tile map
566 const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
567
568 grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
569
570 if(group_id * grid_size_grp_ != grid_size_)
571 {
572 throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
573 }
574
576
577 // check block-to-E-tile
578
579 if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
580 {
581 throw std::runtime_error("wrong! block_2_etile_map validation failed");
582 }
583
584 if(get_warp_size() == 64)
585 {
586 if constexpr(NXdlPerWave64 > 0)
587 {
588 if(!GridwiseGemm64::template CheckValidity<ALayout,
589 BLayout,
590 DsLayout,
591 ELayout,
592 GemmSpec>(
593 AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
594 {
595 throw std::runtime_error(
596 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid "
597 "setting");
598 }
599 }
600 else
601 {
602 throw std::runtime_error(
603 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
604 }
605 }
606 else
607 {
608 if constexpr(NXdlPerWave32 > 0)
609 {
610 if(!GridwiseGemm32::template CheckValidity<ALayout,
611 BLayout,
612 DsLayout,
613 ELayout,
614 GemmSpec>(
615 AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
616 {
617 throw std::runtime_error(
618 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid "
619 "setting");
620 }
621 }
622 else
623 {
624 throw std::runtime_error(
625 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
626 }
627 }
628
630 nullptr,
631 nullptr,
632 p_ds_grid,
633 nullptr,
634 AverM,
635 N,
636 K,
637 StrideA,
638 StrideB,
639 StrideDs,
640 StrideE,
641 });
642
643 group_id++;
644 }
645
646 const auto e_grid_desc_sum_m_n =
647 GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
649
650 const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
651
652 barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
653 }
654
655 // private:
657
658 AElementwiseOperation a_element_op_;
659 BElementwiseOperation b_element_op_;
660 CDEElementwiseOperation c_element_op_;
661
662 std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
663 std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
664 std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
665
667
672
674 };
675
676 // Invoker
677 struct Invoker : public BaseInvoker
678 {
680
681 template <typename GridwiseGemm>
682 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
683 {
684 bool has_main_k_block_loop = true;
685
686 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
687 {
688 const auto KPad =
689 GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_);
690
691 if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
692 {
693 throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
694 }
695 }
696
697 if(arg.grouped_gemm_kernel_args_dev == nullptr)
698 {
699 throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
700 }
701
702 float ave_time = 0;
703
704 auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
705 if(arg.k_batch_ == 1)
706 {
707 const auto kernel =
709 GroupedGemmKernelArgument<NumDTensor>,
710 GemmSpec,
711 false,
712 ALayout,
713 BLayout,
714 DsLayout,
715 ELayout,
716 DsDataType,
719 AElementwiseOperation,
720 BElementwiseOperation,
721 CDEElementwiseOperation,
722 e_global_memory_operation_,
723 has_main_k_block_loop_>;
724
726 stream_config,
727 kernel,
728 dim3(arg.grid_size_),
729 dim3(BlockSize),
730 0,
732 nullptr,
734 arg.gemm_desc_kernel_arg_.size(),
735 arg.grid_size_grp_,
736 arg.k_batch_,
737 arg.a_element_op_,
738 arg.b_element_op_,
739 arg.c_element_op_);
740 }
741 else
742 {
743 const auto kernel =
745 GroupedGemmKernelArgument<NumDTensor>,
746 GemmSpec,
747 true,
748 ALayout,
749 BLayout,
750 DsLayout,
751 ELayout,
752 DsDataType,
755 AElementwiseOperation,
756 BElementwiseOperation,
757 CDEElementwiseOperation,
758 e_global_memory_operation_,
759 has_main_k_block_loop_>;
760
762 stream_config,
763 kernel,
764 dim3(arg.grid_size_),
765 dim3(BlockSize),
766 0,
768 reinterpret_cast<uint32_t*>(arg.p_workspace_),
770 arg.gemm_desc_kernel_arg_.size(),
771 arg.grid_size_grp_,
772 arg.k_batch_,
773 arg.a_element_op_,
774 arg.b_element_op_,
775 arg.c_element_op_);
776 }
777 };
778
780 constexpr auto Set = InMemoryDataOperationEnum::Set;
781
782 // For bf16 datatype only kbatch = 1 scenario is supported. This condition is
783 // enforced in IsSupportedArgument function
784 if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
785 {
786 if(has_main_k_block_loop)
787 {
788 ave_time = launch_kernel(integral_constant<bool, true>{},
789 integral_constant<InMemoryDataOperationEnum, Set>{});
790 }
791 else
792 {
793 ave_time = launch_kernel(integral_constant<bool, false>{},
794 integral_constant<InMemoryDataOperationEnum, Set>{});
795 }
796 }
797 else
798 {
799 if(arg.k_batch_ > 1)
800 {
801 if(has_main_k_block_loop)
802 {
803 ave_time = launch_kernel(
804 integral_constant<bool, true>{},
805 integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
806 }
807 else
808 {
809 ave_time = launch_kernel(
810 integral_constant<bool, false>{},
811 integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
812 }
813 }
814 else
815 {
816 if(has_main_k_block_loop)
817 {
818 ave_time =
819 launch_kernel(integral_constant<bool, true>{},
820 integral_constant<InMemoryDataOperationEnum, Set>{});
821 }
822 else
823 {
824 ave_time =
825 launch_kernel(integral_constant<bool, false>{},
826 integral_constant<InMemoryDataOperationEnum, Set>{});
827 }
828 }
829 }
830
831 return ave_time;
832 }
833
835
836 // polymorphic
837 float Run(const BaseArgument* p_arg,
838 const StreamConfig& stream_config = StreamConfig{}) override
839 {
840 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
841 }
842 };
843
844 static bool IsSupportedArgument(const Argument& arg)
845 {
847 {
848 return false;
849 }
850
851 bool supported = true;
852
853 // If we use padding we do not support vector loads for dimensions not divisible by
854 // vector load size.
855 if constexpr(GemmSpec != GemmSpecialization::Default)
856 {
857 // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
858 // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
859 const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
860 const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
861
862 for(index_t i = 0; i < arg.group_count_; ++i)
863 {
864 const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
865 const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
866
867 supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
868 supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
869 }
870 }
871
872 // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
873 // instruction that supports bf16 and we cannot use splitk because of that
874 if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
875 {
876 supported = supported & (arg.k_batch_ == 1);
877 }
878
879 return supported;
880 }
881
882 // polymorphic
883 bool IsSupportedArgument(const BaseArgument* p_arg) override
884 {
885 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
886 }
887
888 static auto MakeArgument(std::vector<const void*>& p_As,
889 std::vector<const void*>& p_Bs,
890 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
891 std::vector<void*>& p_Es,
892 std::vector<GemmDesc> gemm_descs,
893 AElementwiseOperation a_element_op,
894 BElementwiseOperation b_element_op,
895 CDEElementwiseOperation c_element_op)
896 {
897 return Argument{
898 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
899 }
900
901 static auto MakeInvoker() { return Invoker{}; }
902
903 // polymorphic
904 std::unique_ptr<BaseArgument>
905 MakeArgumentPointer(std::vector<const void*>& p_As,
906 std::vector<const void*>& p_Bs,
907 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
908 std::vector<void*>& p_Es,
909 std::vector<GemmDesc>& gemm_descs,
910 AElementwiseOperation a_element_op,
911 BElementwiseOperation b_element_op,
912 CDEElementwiseOperation c_element_op) override
913 {
914 return std::make_unique<Argument>(
915 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
916 }
917
918 // polymorphic
919 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
920 {
921 return std::make_unique<Invoker>(Invoker{});
922 }
923
924 // polymorphic
925 std::string GetTypeString() const override
926 {
927 auto str = std::stringstream();
928
929 // clang-format off
930 str << "DeviceGroupedGemm_Xdl_Fixed_NK"
931 << "<"
932 << BlockSize << ", "
933 << MPerBlock << ", "
934 << NPerBlock << ", "
935 << KPerBlock << ", "
936 << AK1 << ", "
937 << BK1 << ", "
938 << MPerXDL << ", "
939 << NPerXDL << ", "
940 << MXdlPerWave << ", "
941 << NXdlPerWave << ", "
942 << ABlockTransferSrcScalarPerVector << ", "
943 << BBlockTransferSrcScalarPerVector << ", "
944 << CShuffleMXdlPerWavePerShuffle << ", "
945 << CShuffleNXdlPerWavePerShuffle << ", "
946 << getGemmSpecializationString(GemmSpec)
947 << ">";
948 // clang-format on
949
950 return str.str();
951 }
952
953 // polymorphic
954 void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
955 {
956 auto arg_ptr = dynamic_cast<Argument*>(p_arg);
957 if(arg_ptr)
958 {
959 arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
960 }
961 else
962 throw std::runtime_error("The argument pointer is not an object of "
963 "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
964 }
965
966 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
967 {
968 auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
969 if(arg_ptr)
970 {
971 return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t);
972 }
973 else
974 throw std::runtime_error("The argument pointer is not an object of "
975 "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
976 }
977
978 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
979 {
980 auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
981 if(arg_ptr)
982 {
983 return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
984 }
985 else
986 throw std::runtime_error("The argument pointer is not an object of "
987 "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
988 }
989
991 void* p_workspace,
992 const StreamConfig& stream_config = StreamConfig{}) const override
993 {
994 auto arg_ptr = dynamic_cast<Argument*>(p_arg);
995 if(arg_ptr)
996 {
997 arg_ptr->p_workspace_ = p_workspace;
998 }
999 else
1000 throw std::runtime_error("The argument pointer is not an object of "
1001 "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
1002
1004 hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_));
1005 }
1006
1007 static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
1008
1009 // polymorphic
1010 void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
1011 {
1012 auto arg_ptr = dynamic_cast<Argument*>(p_arg);
1013 if(arg_ptr)
1014 {
1015 arg_ptr->UpdateKBatch(k_batch);
1016 }
1017 else
1018 throw std::runtime_error("The argument pointer is not an object of "
1019 "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
1020 }
1021
1022 void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
1023 {
1024 auto arg_ptr = dynamic_cast<Argument*>(p_arg);
1025 if(arg_ptr)
1026 {
1027 arg_ptr->UpdateKBatch(kbatch);
1028 }
1029 else
1030 throw std::runtime_error("The argument pointer is not an object of "
1031 "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
1032 }
1033};
1034
1035} // namespace device
1036} // namespace tensor_operation
1037} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const index_t grid_size_grp, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:41
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
unsigned int uint32_t
Definition stdint.h:126
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:79
Definition functional2.hpp:33
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_gemm_xdl_fixed_nk.hpp:473
index_t grid_size_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:668
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:662
index_t barrier_size_grp_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:670
index_t group_count_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:656
CDEElementwiseOperation c_element_op_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:660
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:658
index_t grid_size_grp_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:669
index_t k_batch_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:673
const void * grouped_gemm_kernel_args_dev
Definition device_grouped_gemm_xdl_fixed_nk.hpp:666
void UpdateKBatch(index_t k_batch)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:475
index_t sum_of_m
Definition device_grouped_gemm_xdl_fixed_nk.hpp:671
Argument(std::vector< const void * > &, std::vector< const void * > &, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:501
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:664
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:659
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:663
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:407
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:413
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const CGridDesc_M_N &c_grid_desc_m_n, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:384
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, index_t N, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:375
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:440
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:401
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:391
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
index_t M_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:465
index_t N_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:465
const void * a_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:460
void * e_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:463
index_t K_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:465
index_t StrideE_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:468
const void * b_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:461
std::array< index_t, NumDTensor > StrideDs_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:467
std::array< const void *, NumDTensor > ds_ptr_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:462
index_t StrideB_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:466
index_t StrideA_
Definition device_grouped_gemm_xdl_fixed_nk.hpp:466
Definition device_grouped_gemm_xdl_fixed_nk.hpp:678
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:837
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_xdl_fixed_nk.hpp:682
DeviceOp::Argument Argument
Definition device_grouped_gemm_xdl_fixed_nk.hpp:679
UnderlyingBlockToCTileMap underlying_type
Definition device_grouped_gemm_xdl_fixed_nk.hpp:315
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:335
__host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:342
__host__ __device__ OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off=0)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:317
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:348
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_xdl_fixed_nk.hpp:326
Definition device_grouped_gemm_xdl_fixed_nk.hpp:245
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:844
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:966
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition device_grouped_gemm_xdl_fixed_nk.hpp:1022
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:905
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:883
static constexpr auto I0
Definition device_grouped_gemm_xdl_fixed_nk.hpp:253
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_xdl_fixed_nk.hpp:309
OffsettedBlockToCTileMapMLoops< Block2ETileMap > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_xdl_fixed_nk.hpp:454
void SetKBatch(BaseArgument *p_arg, index_t k_batch) const override
Sets the k batch size.
Definition device_grouped_gemm_xdl_fixed_nk.hpp:1010
static constexpr index_t NumDTensor
Definition device_grouped_gemm_xdl_fixed_nk.hpp:251
ComputeType BComputeType
Definition device_grouped_gemm_xdl_fixed_nk.hpp:258
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops< MPerBlock, NPerBlock > Block2ETileMap
Definition device_grouped_gemm_xdl_fixed_nk.hpp:453
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_xdl_fixed_nk.hpp:249
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:919
DeviceGroupedGemm_Xdl_Fixed_NK DeviceOp
Definition device_grouped_gemm_xdl_fixed_nk.hpp:246
void SetDeviceKernelArgs(BaseArgument *p_arg, void *kernel_args) const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:954
static constexpr auto I2
Definition device_grouped_gemm_xdl_fixed_nk.hpp:255
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_xdl_fixed_nk.hpp:978
std::string GetTypeString() const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:925
static void SetKBatch(Argument &arg, index_t k_batch)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:1007
GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, BDataType, AComputeType, BComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer, ALDSType, BLDSType > GridwiseGemmBase
Definition device_grouped_gemm_xdl_fixed_nk.hpp:262
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &stream_config=StreamConfig{}) const override
Definition device_grouped_gemm_xdl_fixed_nk.hpp:990
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_xdl_fixed_nk.hpp:248
static constexpr auto I1
Definition device_grouped_gemm_xdl_fixed_nk.hpp:254
static auto MakeInvoker()
Definition device_grouped_gemm_xdl_fixed_nk.hpp:901
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_xdl_fixed_nk.hpp:310
ComputeType AComputeType
Definition device_grouped_gemm_xdl_fixed_nk.hpp:257
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_fixed_nk.hpp:888
Definition device_grouped_gemm_fixed_nk.hpp:34
Definition device_grouped_gemm.hpp:80
Structure representing the arguments of a single GEMM problem.
Definition device_grouped_gemm.hpp:29