device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp Source File

device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp Source File#

Composable Kernel: device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp Source File
device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#ifndef __HIPCC_RTC__
6#include <iostream>
7#include <sstream>
8#endif
9
11#ifndef __HIPCC_RTC__
14#endif
15
24
25namespace ck {
26namespace tensor_operation {
27namespace device {
28
29template <typename GridwiseGemm,
30 typename FloatAB,
31 typename FloatC,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename AccElementwiseOperation,
35 typename B1ElementwiseOperation,
36 typename CElementwiseOperation,
37 typename AGridDesc_AK0_M_AK1,
38 typename BGridDesc_BK0_N_BK1,
39 typename B1GridDesc_BK0_N_BK1,
40 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
41 typename Block2CTileMap,
42 typename ComputeBasePtrOfStridedBatch,
43 typename C0MatrixMask,
44 bool HasMainKBlockLoop>
45__global__ void
46#if CK_USE_LAUNCH_BOUNDS
48#endif
50 const FloatAB* __restrict__ p_a_grid,
51 const FloatAB* __restrict__ p_b_grid,
52 const FloatAB* __restrict__ p_b1_grid,
53 FloatC* __restrict__ p_c_grid,
54 const AElementwiseOperation a_element_op,
55 const BElementwiseOperation b_element_op,
56 const AccElementwiseOperation acc_element_op,
57 const B1ElementwiseOperation b1_element_op,
58 const CElementwiseOperation c_element_op,
59 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
60 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
61 const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
62 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
63 c_grid_desc_mblock_mperblock_nblock_nperblock,
64 const Block2CTileMap block_2_ctile_map,
65 const index_t batch_count,
66 const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
67 const C0MatrixMask c0_matrix_mask)
68{
69#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
70 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
71 {
72 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
73 const index_t num_blocks_per_batch =
74 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
75 const index_t g_idx =
76 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
77
78 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
79 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
80 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
81 static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
82 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
83 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
84 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
85 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
86
87 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
88 p_b_grid + b_batch_offset,
89 p_b1_grid + b1_batch_offset,
90 p_c_grid + c_batch_offset,
91 p_shared,
92 a_element_op,
93 b_element_op,
94 acc_element_op,
95 b1_element_op,
96 c_element_op,
97 a_grid_desc_ak0_m_ak1,
98 b_grid_desc_bk0_n_bk1,
99 b1_grid_desc_bk0_n_bk1,
100 c_grid_desc_mblock_mperblock_nblock_nperblock,
101 block_2_ctile_map,
102 c0_matrix_mask);
103 }
104#else
105 ignore = p_a_grid;
106 ignore = p_b_grid;
107 ignore = p_b1_grid;
108 ignore = p_c_grid;
109 ignore = a_element_op;
110 ignore = b_element_op;
111 ignore = acc_element_op;
112 ignore = b1_element_op;
113 ignore = c_element_op;
114 ignore = a_grid_desc_ak0_m_ak1;
115 ignore = b_grid_desc_bk0_n_bk1;
116 ignore = b1_grid_desc_bk0_n_bk1;
117 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
118 ignore = block_2_ctile_map;
119 ignore = batch_count;
120 ignore = compute_base_ptr_of_batch;
121 ignore = c0_matrix_mask;
122#endif // end of if (defined(__gfx9__))
123}
124
125// Computes C = A * B0 * B1
126//              ^^^^^^ (Acc0)
127//              ^^^^^^^^^^^ (Acc1)
128
129// When using NPadding as GemmSpecialization, AccElementwiseOperation should be set to
130// ScaleAndResetNaNToMinusInfinity.
131// if !isNan(AccElement)
132// AccElement *= scale
133// else
134// AccElement = -INFINITY
135// Otherwise, result may be wrong.
136
137template <typename ALayout,
138 typename BLayout, // B0Layout
139 typename B1Layout,
140 typename CLayout,
141 typename ADataType,
142 typename BDataType,
143 typename B1DataType,
144 typename CDataType,
145 typename GemmAccDataType,
146 typename CShuffleDataType,
147 typename AElementwiseOperation,
148 typename BElementwiseOperation,
149 typename AccElementwiseOperation,
150 typename B1ElementwiseOperation,
151 typename CElementwiseOperation,
152 GemmSpecialization GemmSpec,
153 index_t NumGemmKPrefetchStage,
154 index_t BlockSize,
155 index_t MPerBlock,
156 index_t NPerBlock, // Gemm0NPerBlock
157 index_t KPerBlock, // Gemm0KPerBlock
158 index_t Gemm1NPerBlock,
159 index_t Gemm1KPerBlock,
160 index_t AK1,
161 index_t BK1,
162 index_t B1K1,
163 index_t MPerXDL,
164 index_t NPerXDL,
165 index_t MXdlPerWave,
166 index_t NXdlPerWave,
167 index_t Gemm1NXdlPerWave,
168 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
169 typename ABlockTransferThreadClusterArrangeOrder,
170 typename ABlockTransferSrcAccessOrder,
171 index_t ABlockTransferSrcVectorDim,
172 index_t ABlockTransferSrcScalarPerVector,
173 index_t ABlockTransferDstScalarPerVector_AK1,
174 bool ABlockLdsExtraM,
175 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
176 typename BBlockTransferThreadClusterArrangeOrder,
177 typename BBlockTransferSrcAccessOrder,
178 index_t BBlockTransferSrcVectorDim,
179 index_t BBlockTransferSrcScalarPerVector,
180 index_t BBlockTransferDstScalarPerVector_BK1,
181 bool BBlockLdsExtraN,
182 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
183 typename B1BlockTransferThreadClusterArrangeOrder,
184 typename B1BlockTransferSrcAccessOrder,
185 index_t B1BlockTransferSrcVectorDim,
186 index_t B1BlockTransferSrcScalarPerVector,
187 index_t B1BlockTransferDstScalarPerVector_BK1,
188 bool B1BlockLdsExtraN,
189 index_t CShuffleMXdlPerWavePerShuffle,
190 index_t CShuffleNXdlPerWavePerShuffle,
191 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
192 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
196 : public DeviceBatchedGemmSoftmaxGemm<ALayout,
197 BLayout,
198 B1Layout,
199 CLayout,
200 ADataType,
201 BDataType,
202 B1DataType,
203 CDataType,
204 AElementwiseOperation,
205 BElementwiseOperation,
206 AccElementwiseOperation,
207 B1ElementwiseOperation,
208 CElementwiseOperation,
209 MaskOutUpperTriangle>
210{
211
213 static constexpr auto MXdlPerWave64 =
214 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
215 static constexpr auto MXdlPerWave32 =
216 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
217
218 static constexpr auto I0 = Number<0>{};
219 static constexpr auto I1 = Number<1>{};
220 static constexpr auto I2 = Number<2>{};
221
222 static constexpr auto matrix_padder =
224 MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
225
226 static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
227 {
228 const auto a_grid_desc_mraw_kraw = [&]() {
230 {
231 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
232 make_tuple(StrideA, I1));
233 }
235 {
236 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
237 make_tuple(I1, StrideA));
238 }
239 }();
240
241 const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
242
243 const auto M = a_grid_desc_m_k.GetLength(I0);
244 const auto K = a_grid_desc_m_k.GetLength(I1);
245
246 const auto AK0 = K / AK1;
247
248 return transform_tensor_descriptor(a_grid_desc_m_k,
253 }
254
255 static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
256 {
257 const auto b_grid_desc_nraw_kraw = [&]() {
259 {
260 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
261 make_tuple(I1, StrideB));
262 }
264 {
265 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
266 make_tuple(StrideB, I1));
267 }
268 }();
269
270 const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
271
272 const auto N = b_grid_desc_n_k.GetLength(I0);
273 const auto K = b_grid_desc_n_k.GetLength(I1);
274
275 const auto BK0 = K / BK1;
276
277 return transform_tensor_descriptor(b_grid_desc_n_k,
282 }
283
284 // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
285 static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
286 {
287 const auto b1_grid_desc_nraw_kraw = [&]() {
289 {
290 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
291 make_tuple(I1, StrideB));
292 }
294 {
295 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
296 make_tuple(StrideB, I1));
297 }
298 }();
299
300 const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
301
302 const auto N = b1_grid_desc_n_k.GetLength(I0);
303 const auto K = b1_grid_desc_n_k.GetLength(I1);
304
305 const auto B1K0 = K / B1K1;
306
308 b1_grid_desc_n_k,
313 }
314
315 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
316 {
317 const auto c_grid_desc_mraw_nraw = [&]() {
319 {
320 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
321 make_tuple(StrideC, I1));
322 }
324 {
325 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
326 make_tuple(I1, StrideC));
327 }
328 }();
329
330 return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
331 }
332
334 {
336 index_t BatchStrideB,
337 index_t BatchStrideB1,
338 index_t BatchStrideC)
339 : BatchStrideA_(BatchStrideA),
340 BatchStrideB_(BatchStrideB),
341 BatchStrideB1_(BatchStrideB1),
342 BatchStrideC_(BatchStrideC)
343 {
344 }
345
346 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
347 {
348 return g_idx * static_cast<long_index_t>(BatchStrideA_);
349 }
350
351 __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
352 {
353 return g_idx * static_cast<long_index_t>(BatchStrideB_);
354 }
355
356 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
357 {
358 return g_idx * static_cast<long_index_t>(BatchStrideB1_);
359 }
360
361 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
362 {
363 return g_idx * static_cast<long_index_t>(BatchStrideC_);
364 }
365
366 private:
367 index_t BatchStrideA_;
368 index_t BatchStrideB_;
369 index_t BatchStrideB1_;
370 index_t BatchStrideC_;
371 };
372
376 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
377
381
382 // GridwiseGemm
383 template <index_t MXdlPerWave_>
385 ADataType, // TODO: distinguish A/B datatype
386 GemmAccDataType,
387 CShuffleDataType,
388 CDataType,
389 AElementwiseOperation,
390 BElementwiseOperation,
391 AccElementwiseOperation,
392 B1ElementwiseOperation,
393 CElementwiseOperation,
399 NumGemmKPrefetchStage,
400 BlockSize,
401 MPerBlock,
402 NPerBlock,
403 KPerBlock,
404 Gemm1NPerBlock,
405 Gemm1KPerBlock,
406 AK1,
407 BK1,
408 B1K1,
409 MPerXDL,
410 NPerXDL,
411 MXdlPerWave_,
412 NXdlPerWave,
413 Gemm1NXdlPerWave,
414 ABlockTransferThreadClusterLengths_AK0_M_AK1,
415 ABlockTransferThreadClusterArrangeOrder,
416 ABlockTransferSrcAccessOrder,
417 ABlockTransferSrcVectorDim,
418 ABlockTransferSrcScalarPerVector,
419 ABlockTransferDstScalarPerVector_AK1,
420 true,
421 ABlockLdsExtraM,
422 BBlockTransferThreadClusterLengths_BK0_N_BK1,
423 BBlockTransferThreadClusterArrangeOrder,
424 BBlockTransferSrcAccessOrder,
425 BBlockTransferSrcVectorDim,
426 BBlockTransferSrcScalarPerVector,
427 BBlockTransferDstScalarPerVector_BK1,
428 true,
429 BBlockLdsExtraN,
430 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
431 B1BlockTransferThreadClusterArrangeOrder,
432 B1BlockTransferSrcAccessOrder,
433 B1BlockTransferSrcVectorDim,
434 B1BlockTransferSrcScalarPerVector,
435 B1BlockTransferDstScalarPerVector_BK1,
436 false,
437 B1BlockLdsExtraN,
438 CShuffleMXdlPerWavePerShuffle,
439 CShuffleNXdlPerWavePerShuffle,
440 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
441 CShuffleBlockTransferScalarPerVector_NPerBlock,
442 LoopSched,
443 matrix_padder.PadN,
447
448#ifndef __HIPCC_RTC__
449 // Argument
450 struct Argument : public BaseArgument
451 {
452 Argument(const ADataType* p_a_grid,
453 const BDataType* p_b_grid,
454 const B1DataType* p_b1_grid,
455 CDataType* p_c_grid,
456 index_t MRaw,
457 index_t NRaw,
458 index_t KRaw,
459 index_t Gemm1NRaw, // = ORaw
460 index_t Batch,
461 index_t StrideA,
462 index_t StrideB,
463 index_t StrideB1,
464 index_t StrideC,
465 index_t BatchStrideA,
466 index_t BatchStrideB,
467 index_t BatchStrideB1,
468 index_t BatchStrideC,
469 AElementwiseOperation a_element_op,
470 BElementwiseOperation b_element_op,
471 AccElementwiseOperation acc_element_op,
472 B1ElementwiseOperation b1_element_op,
473 CElementwiseOperation c_element_op)
474 : p_a_grid_{p_a_grid},
475 p_b_grid_{p_b_grid},
476 p_b1_grid_{p_b1_grid},
477 p_c_grid_{p_c_grid},
481 DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
482 c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)},
483 block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
484 a_element_op_{a_element_op},
485 b_element_op_{b_element_op},
486 acc_element_op_{acc_element_op},
487 b1_element_op_{b1_element_op},
488 c_element_op_{c_element_op},
489 batch_count_(Batch),
490 compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
491 c0_matrix_mask_{NRaw},
492 raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
493 {
494 }
495
496 // private:
497 const ADataType* p_a_grid_;
498 const BDataType* p_b_grid_;
499 const B1DataType* p_b1_grid_;
500 CDataType* p_c_grid_;
506 AElementwiseOperation a_element_op_;
507 BElementwiseOperation b_element_op_;
508 AccElementwiseOperation acc_element_op_;
509 B1ElementwiseOperation b1_element_op_;
510 CElementwiseOperation c_element_op_;
513
514 // check C0 masking and padding
516
517 // For robust IsSupportedArgument() check
518 std::vector<index_t> raw_lengths_m_n_k_o_;
519 };
520
521 // Invoker
522 struct Invoker : public BaseInvoker
523 {
525
526 template <typename GridwiseGemm>
527 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
528 {
529 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
534 {
535 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
536 }
537 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
538 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
539 arg.c_grid_desc_m_n_);
540 const index_t grid_size =
541 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
542
543 // Gemm0_K
544 const auto K =
545 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
546
547 float ave_time = 0;
548
549 auto launch_kernel = [&](auto has_main_k_block_loop_) {
551 GridwiseGemm,
552 ADataType, // TODO: distinguish A/B datatype
553 CDataType,
554 AElementwiseOperation,
555 BElementwiseOperation,
556 AccElementwiseOperation,
557 B1ElementwiseOperation,
558 CElementwiseOperation,
562 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
563 typename GridwiseGemm::DefaultBlock2CTileMap,
564 ComputeBasePtrOfStridedBatch,
566 has_main_k_block_loop_>;
567
568 return launch_and_time_kernel(stream_config,
569 kernel,
570 dim3(grid_size),
571 dim3(BlockSize),
572 0,
573 arg.p_a_grid_,
574 arg.p_b_grid_,
575 arg.p_b1_grid_,
576 arg.p_c_grid_,
577 arg.a_element_op_,
578 arg.b_element_op_,
579 arg.acc_element_op_,
580 arg.b1_element_op_,
581 arg.c_element_op_,
585 c_grid_desc_mblock_mperblock_nblock_nperblock,
587 arg.batch_count_,
589 arg.c0_matrix_mask_);
590 };
591
592 // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
593 // to concern Gemm0's loop
594 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
595 {
596 ave_time = launch_kernel(integral_constant<bool, true>{});
597 }
598 else
599 {
600 ave_time = launch_kernel(integral_constant<bool, false>{});
601 }
602
603 return ave_time;
604 }
605
606 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
607 {
608 if(get_warp_size() == 64)
609 {
610 if constexpr(MXdlPerWave64 > 0)
611 {
612 return RunImp<GridwiseGemm64>(arg, stream_config);
613 }
614 }
615 else
616 {
617 if constexpr(MXdlPerWave32 > 0)
618 {
619 return RunImp<GridwiseGemm32>(arg, stream_config);
620 }
621 }
622 return 0;
623 }
624
625 // polymorphic
626 float Run(const BaseArgument* p_arg,
627 const StreamConfig& stream_config = StreamConfig{}) override
628 {
629 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
630 }
631 };
632#endif
633
// Compile-time validity probe for this instantiation.
// Currently a placeholder that accepts every parameter combination.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
639
640 static constexpr bool
641 IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
642 {
643 // check vector load/store
646
647 // check vector load of A
648 if constexpr(is_same_v<ALayout, Row>)
649 {
650 if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
651 {
652 return false;
653 }
654 }
655 else if constexpr(is_same_v<ALayout, Col>)
656 {
657 if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
658 {
659 return false;
660 }
661 }
662 else
663 {
664 return false;
665 }
666
667 // check vector load of B
668 if constexpr(is_same_v<BLayout, Row>)
669 {
670 if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
671 {
672 return false;
673 }
674 }
675 else if constexpr(is_same_v<BLayout, Col>)
676 {
677 if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
678 {
679 return false;
680 }
681 }
682 else
683 {
684 return false;
685 }
686
687 // check vector load of B1
688 if constexpr(is_same_v<B1Layout, Row>)
689 {
690 if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
691 {
692 return false;
693 }
694 }
695 else if constexpr(is_same_v<B1Layout, Col>)
696 {
697 if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
698 {
699 return false;
700 }
701 }
702 else
703 {
704 return false;
705 }
706
707 // check vector load of C
708 if constexpr(is_same_v<CLayout, Row>)
709 {
710 if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
711 {
712 return false;
713 }
714 }
715 else if constexpr(is_same_v<CLayout, Col>)
716 {
717 if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
718 {
719 return false;
720 }
721 }
722 else
723 {
724 return false;
725 }
726
727 return true;
728 }
729
730#ifndef __HIPCC_RTC__
731 static bool IsSupportedArgument(const Argument& arg)
732 {
734 {
735 return false;
736 }
737 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
738 // vector is out of bounds
739 const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
740 const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
741 const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
742 const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
743
744 if(get_warp_size() == 64)
745 {
746 if constexpr(MXdlPerWave64 > 0)
747 {
752 arg.block_2_ctile_map_) and
753 IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
754 }
755 }
756 else
757 {
758 if constexpr(MXdlPerWave32 > 0)
759 {
764 arg.block_2_ctile_map_) and
765 IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
766 }
767 }
768 return false;
769 }
770
771 // polymorphic
772 bool IsSupportedArgument(const BaseArgument* p_arg) override
773 {
774 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
775 }
776
777 static auto MakeArgument(const ADataType* p_a,
778 const BDataType* p_b,
779 const B1DataType* p_b1,
780 CDataType* p_c,
781 index_t MRaw,
782 index_t NRaw,
783 index_t KRaw,
784 index_t Gemm1NRaw,
785 index_t Batch,
786 index_t StrideA,
787 index_t StrideB,
788 index_t StrideB1,
789 index_t StrideC,
790 index_t BatchStrideA,
791 index_t BatchStrideB,
792 index_t BatchStrideB1,
793 index_t BatchStrideC,
794 AElementwiseOperation a_element_op,
795 BElementwiseOperation b_element_op,
796 AccElementwiseOperation acc_element_op,
797 B1ElementwiseOperation b1_element_op,
798 CElementwiseOperation c_element_op)
799 {
800 return Argument{p_a, p_b, p_b1, p_c, MRaw,
801 NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
802 StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
803 BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
804 b1_element_op, c_element_op};
805 }
806
807 static auto MakeInvoker() { return Invoker{}; }
808
809 // polymorphic
810 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
811 const void* p_b,
812 const void* p_b1,
813 void* p_c,
814 index_t MRaw,
815 index_t NRaw,
816 index_t KRaw,
817 index_t Gemm1NRaw,
818 index_t Batch,
819 index_t StrideA,
820 index_t StrideB,
821 index_t StrideB1,
822 index_t StrideC,
823 index_t BatchStrideA,
824 index_t BatchStrideB,
825 index_t BatchStrideB1,
826 index_t BatchStrideC,
827 AElementwiseOperation a_element_op,
828 BElementwiseOperation b_element_op,
829 AccElementwiseOperation acc_element_op,
830 B1ElementwiseOperation b1_element_op,
831 CElementwiseOperation c_element_op) override
832 {
833 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
834 static_cast<const BDataType*>(p_b),
835 static_cast<const B1DataType*>(p_b1),
836 static_cast<CDataType*>(p_c),
837 MRaw,
838 NRaw,
839 KRaw,
840 Gemm1NRaw,
841 Batch,
842 StrideA,
843 StrideB,
844 StrideB1,
845 StrideC,
846 BatchStrideA,
847 BatchStrideB,
848 BatchStrideB1,
849 BatchStrideC,
850 a_element_op,
851 b_element_op,
852 acc_element_op,
853 b1_element_op,
854 c_element_op);
855 }
856
857 // polymorphic
858 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
859 {
860 return std::make_unique<Invoker>(Invoker{});
861 }
862
863 // polymorphic
864 std::string GetTypeString() const override
865 {
866 auto str = std::stringstream();
867
868 // clang-format off
869 str << "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
870 << "<"
871 << BlockSize << ", "
872 << MPerBlock << ", "
873 << NPerBlock << ", "
874 << KPerBlock << ", "
875 << AK1 << ", "
876 << BK1 << ", "
877 << MPerBlock << ", "
878 << Gemm1NPerBlock << ", "
879 << Gemm1KPerBlock << ", "
880 << B1K1 << ", "
881 << getGemmSpecializationString(GemmSpec) << ">";
882 // clang-format on
883
884 return str.str();
885 }
886#endif
887
888 template <class ADesc, class BDesc, class B1Desc, class CDesc>
890 {
891 template <class AGridDescriptor>
892 static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
893 {
894 const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
895
896 const auto M = a_grid_desc_m_k.GetLength(I0);
897 const auto K = a_grid_desc_m_k.GetLength(I1);
898
899 const auto AK0 = K / AK1;
900
902 a_grid_desc_m_k,
907 }
908
909 template <class BGridDescriptor>
910 static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
911 {
912 const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
913
914 const auto N = b_grid_desc_n_k.GetLength(I0);
915 const auto K = b_grid_desc_n_k.GetLength(I1);
916
917 const auto BK0 = K / BK1;
918
920 b_grid_desc_n_k,
925 }
926
927 template <class B1GridDescriptor>
928 static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
929 {
930 const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
931
932 const auto N = b1_grid_desc_n_k.GetLength(I0);
933 const auto K = b1_grid_desc_n_k.GetLength(I1);
934
935 const auto B1K0 = K / B1K1;
936
938 b1_grid_desc_n_k,
943 }
944
945 template <class CGridDescriptor>
946 static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
947 {
948 return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
949 }
950
958
959 // GridwiseGemm
960 template <index_t MXdlPerWave_>
962 ADataType, // TODO: distinguish A/B datatype
963 GemmAccDataType,
964 CShuffleDataType,
965 CDataType,
966 AElementwiseOperation,
967 BElementwiseOperation,
968 AccElementwiseOperation,
969 B1ElementwiseOperation,
970 CElementwiseOperation,
976 NumGemmKPrefetchStage,
977 BlockSize,
978 MPerBlock,
979 NPerBlock,
980 KPerBlock,
981 Gemm1NPerBlock,
982 Gemm1KPerBlock,
983 AK1,
984 BK1,
985 B1K1,
986 MPerXDL,
987 NPerXDL,
988 MXdlPerWave_,
989 NXdlPerWave,
990 Gemm1NXdlPerWave,
991 ABlockTransferThreadClusterLengths_AK0_M_AK1,
992 ABlockTransferThreadClusterArrangeOrder,
993 ABlockTransferSrcAccessOrder,
994 ABlockTransferSrcVectorDim,
995 ABlockTransferSrcScalarPerVector,
996 ABlockTransferDstScalarPerVector_AK1,
997 true,
998 ABlockLdsExtraM,
999 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1000 BBlockTransferThreadClusterArrangeOrder,
1001 BBlockTransferSrcAccessOrder,
1002 BBlockTransferSrcVectorDim,
1003 BBlockTransferSrcScalarPerVector,
1004 BBlockTransferDstScalarPerVector_BK1,
1005 true,
1006 BBlockLdsExtraN,
1007 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
1008 B1BlockTransferThreadClusterArrangeOrder,
1009 B1BlockTransferSrcAccessOrder,
1010 B1BlockTransferSrcVectorDim,
1011 B1BlockTransferSrcScalarPerVector,
1012 B1BlockTransferDstScalarPerVector_BK1,
1013 false,
1014 B1BlockLdsExtraN,
1015 CShuffleMXdlPerWavePerShuffle,
1016 CShuffleNXdlPerWavePerShuffle,
1017 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1018 CShuffleBlockTransferScalarPerVector_NPerBlock,
1019 LoopSched,
1020 matrix_padder.PadN,
1022
1025
1034
1035 // element-wise op
1036 AElementwiseOperation a_element_op;
1037 BElementwiseOperation b_element_op;
1038 B1ElementwiseOperation b1_element_op;
1039 CElementwiseOperation c_element_op;
1040
1042 bool is_valid = false;
1043
1044 constexpr Descriptor(ADesc a,
1045 BDesc b,
1046 B1Desc b1,
1047 CDesc c,
1048 AElementwiseOperation a_element_op_,
1049 BElementwiseOperation b_element_op_,
1050 B1ElementwiseOperation b1_element_op_,
1051 CElementwiseOperation c_element_op_)
1056 block_2_ctile_map{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
1058 GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1060 has_main_k_block_loop{GridwiseGemm64::CalculateHasMainKBlockLoop(
1061 a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
1062 c0_matrix_mask{c.GetLength(I1)},
1063 a_element_op{a_element_op_},
1064 b_element_op{b_element_op_},
1065 b1_element_op{b1_element_op_},
1066 c_element_op{c_element_op_},
1067 is_valid{false}
1068 {
1069 if(get_warp_size() == 64)
1070 {
1071 if constexpr(MXdlPerWave64 > 0)
1072 {
1079 b_grid_desc_bk0_n_bk1.GetLength(I1),
1080 a_grid_desc_ak0_m_ak1.GetLength(I0) *
1081 a_grid_desc_ak0_m_ak1.GetLength(I2),
1082 b1_grid_desc_bk0_n_bk1.GetLength(I1)) and
1083 GridwiseGemm64::template IsValidCompilationParameter<>();
1084 }
1085 }
1086 else
1087 {
1088 if constexpr(MXdlPerWave32 > 0)
1089 {
1096 b_grid_desc_bk0_n_bk1.GetLength(I1),
1097 a_grid_desc_ak0_m_ak1.GetLength(I0) *
1098 a_grid_desc_ak0_m_ak1.GetLength(I2),
1099 b1_grid_desc_bk0_n_bk1.GetLength(I1)) and
1100 GridwiseGemm32::template IsValidCompilationParameter<>();
1101 }
1102 }
1103 }
1104 constexpr bool IsValid() const { return is_valid; }
1105 };
1106
1107 template <class ADesc, class BDesc, class B1Desc, class CDesc>
1108 static constexpr auto
1110 BDesc b,
1111 B1Desc b1,
1112 CDesc c,
1113 AElementwiseOperation a_element_op = AElementwiseOperation{},
1114 BElementwiseOperation b_element_op = BElementwiseOperation{},
1115 B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
1116 CElementwiseOperation c_element_op = CElementwiseOperation{})
1117 {
1118 return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
1119 a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
1120 }
1121
1122 template <class Desc>
1123 __device__ static void Run(const Desc& desc,
1124 const float scale,
1125 const ADataType* __restrict__ p_a_grid,
1126 const ADataType* __restrict__ p_b_grid,
1127 const ADataType* __restrict__ p_b1_grid,
1128 CDataType* __restrict__ p_c_grid)
1129 {
1130#ifndef __HIPCC_RTC__
1131 assert(desc.is_valid);
1132#endif
1133 using GridwiseGemm = conditional_t<get_warp_size() == 64,
1134 typename Desc::GridwiseGemm64,
1135 typename Desc::GridwiseGemm32>;
1136 __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
1137 AccElementwiseOperation acc_element_op{scale};
1138
1139 if(desc.has_main_k_block_loop)
1140 {
1141 GridwiseGemm::template Run<true>(
1142 p_a_grid,
1143 p_b_grid,
1144 p_b1_grid,
1145 p_c_grid,
1146 p_shared_block,
1147 desc.a_element_op,
1148 desc.b_element_op,
1149 acc_element_op,
1150 desc.b1_element_op,
1151 desc.c_element_op,
1152 desc.a_grid_desc_ak0_m_ak1,
1153 desc.b_grid_desc_bk0_n_bk1,
1154 desc.b1_grid_desc_bk0_n_bk1,
1155 desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
1156 desc.block_2_ctile_map,
1157 desc.c0_matrix_mask);
1158 }
1159 else
1160 {
1161 GridwiseGemm::template Run<false>(
1162 p_a_grid,
1163 p_b_grid,
1164 p_b1_grid,
1165 p_c_grid,
1166 p_shared_block,
1167 desc.a_element_op,
1168 desc.b_element_op,
1169 acc_element_op,
1170 desc.b1_element_op,
1171 desc.c_element_op,
1172 desc.a_grid_desc_ak0_m_ak1,
1173 desc.b_grid_desc_bk0_n_bk1,
1174 desc.b1_grid_desc_bk0_n_bk1,
1175 desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
1176 desc.block_2_ctile_map,
1177 desc.c0_matrix_mask);
1178 }
1179 }
1180};
1181
1182} // namespace device
1183} // namespace tensor_operation
1184} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, D0sPointer p_d0s_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const C0DEElementwiseOperation c0de_element_op, const B1ElementwiseOperation b1_element_op, const C1DEElementwiseOperation c1de_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDesc_M_N d0s_griddesc_m_n, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask)
Definition device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:48
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition ck/stream_config.hpp:10
Gridwise gemm + softmax + gemm fusion.
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:87
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:231
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:323
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:319
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition masking_specialization.hpp:57
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:361
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:356
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:351
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:335
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:346
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:451
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:503
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:509
index_t batch_count_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:511
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:512
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:500
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:505
AElementwiseOperation a_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:506
std::vector< index_t > raw_lengths_m_n_k_o_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:518
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:502
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:499
CElementwiseOperation c_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:510
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:497
C0MatrixMask c0_matrix_mask_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:515
const BDataType * p_b_grid_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:498
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:452
BElementwiseOperation b_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:507
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:504
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:501
AccElementwiseOperation acc_element_op_
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:508
remove_cvref_t< decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))> B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:955
constexpr bool IsValid() const
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1104
BElementwiseOperation b_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1037
AElementwiseOperation a_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1036
B1ElementwiseOperation b1_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1038
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1028
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor &b1_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:928
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1031
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1026
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor &c_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:946
bool is_valid
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1042
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor &b_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:910
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor &a_grid_desc)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:892
C0MatrixMask c0_matrix_mask
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1030
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))> BGridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:953
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle > GridwiseGemmBase
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:961
constexpr Descriptor(ADesc a, BDesc b, B1Desc b1, CDesc c, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, B1ElementwiseOperation b1_element_op_, CElementwiseOperation c_element_op_)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1044
GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1033
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1023
bool has_main_k_block_loop
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1041
remove_cvref_t< decltype(MakeCGridDescriptor_M_N(CDesc{}))> CGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:957
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))> AGridDesc_AK0_M_AK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:951
CElementwiseOperation c_element_op
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1039
CGridDesc_M_N c_grid_desc_m_n
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1029
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1027
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1024
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:523
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:527
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:626
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:606
DeviceOp::Argument Argument
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:524
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:210
static constexpr auto I2
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:220
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:858
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_b1, void *p_c, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:810
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:641
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:445
decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)) B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:375
static constexpr auto I1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:219
static constexpr auto matrix_padder
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:222
static constexpr auto make_descriptor(ADesc a, BDesc b, B1Desc b1, CDesc c, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, B1ElementwiseOperation b1_element_op=B1ElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1109
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:212
static constexpr auto I0
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:218
static constexpr auto MXdlPerWave32
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:215
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:374
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:285
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:373
std::string GetTypeString() const override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:864
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:446
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle > GridwiseGemmBase
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:384
static constexpr auto MXdlPerWave64
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:213
conditional_t< MaskOutUpperTriangle, C0MatrixMask_impl< MaskOutUpperTrianglePredicate >, C0MatrixMask_impl< MaskDisabledPredicate > > C0MatrixMask
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:378
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:226
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:376
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const B1DataType *p_b1, CDataType *p_c, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:777
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:634
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:772
static auto MakeInvoker()
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:807
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:255
static __device__ void Run(const Desc &desc, const float scale, const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, const ADataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:1123
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:315
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp:731
Definition device_batched_gemm_softmax_gemm.hpp:31
Definition matrix_padder.hpp:63