22template <
typename GridwiseGemm,
24 typename FloatDsPointer,
26 typename FloatRsPointer,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename QsElementwiseOperation,
31 typename RsElementwiseOperation,
32 typename AGridDesc_AK0_M_AK1,
33 typename BGridDesc_BK0_N_BK1,
34 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
35 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename RsGridDescriptor_MBlock_MPerBlock,
37 typename Block2ETileMap,
38 bool HasMainKBlockLoop>
40#if CK_USE_LAUNCH_BOUNDS
44 const FloatAB* __restrict__ p_a_grid,
45 const FloatAB* __restrict__ p_b_grid,
46 FloatDsPointer p_ds_grid,
47 FloatE* __restrict__ p_e_grid,
48 FloatRsPointer p_rs_grid,
49 const AElementwiseOperation a_element_op,
50 const BElementwiseOperation b_element_op,
51 const CDEElementwiseOperation cde_element_op,
52 const QsElementwiseOperation qs_element_op,
53 const RsElementwiseOperation rs_element_op,
54 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
55 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
56 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
57 ds_grid_desc_mblock_mperblock_nblock_nperblock,
58 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
59 e_grid_desc_mblock_mperblock_nblock_nperblock,
60 const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
61 const Block2ETileMap block_2_etile_map)
63#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
64 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
66 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
68 GridwiseGemm::template Run<HasMainKBlockLoop>(
80 a_grid_desc_ak0_m_ak1,
81 b_grid_desc_bk0_n_bk1,
82 ds_grid_desc_mblock_mperblock_nblock_nperblock,
83 e_grid_desc_mblock_mperblock_nblock_nperblock,
84 rs_grid_desc_mblock_mperblock,
98 ignore = a_grid_desc_ak0_m_ak1;
99 ignore = b_grid_desc_bk0_n_bk1;
100 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
101 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
102 ignore = rs_grid_desc_mblock_mperblock;
103 ignore = block_2_etile_map;
125template <
typename ALayout,
130 typename GemmAccDataType,
131 typename CShuffleDataType,
134 typename ReduceAccDataType,
136 typename AElementwiseOperation,
137 typename BElementwiseOperation,
138 typename CDEElementwiseOperation,
139 typename QsElementwiseOperation,
140 typename RsElementwiseOperation,
141 typename ThreadReduceOperations,
142 typename RsGlobalMemoryDataOperation,
155 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
156 typename ABlockTransferThreadClusterArrangeOrder,
157 typename ABlockTransferSrcAccessOrder,
158 index_t ABlockTransferSrcVectorDim,
159 index_t ABlockTransferSrcScalarPerVector,
160 index_t ABlockTransferDstScalarPerVector_AK1,
161 bool ABlockLdsExtraM,
162 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
163 typename BBlockTransferThreadClusterArrangeOrder,
164 typename BBlockTransferSrcAccessOrder,
165 index_t BBlockTransferSrcVectorDim,
166 index_t BBlockTransferSrcScalarPerVector,
167 index_t BBlockTransferDstScalarPerVector_BK1,
168 bool BBlockLdsExtraN,
169 index_t CShuffleMXdlPerWavePerShuffle,
170 index_t CShuffleNXdlPerWavePerShuffle,
171 typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
172 index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
173 index_t RThreadTransferDstScalarPerVector_MPerBlock,
184 AElementwiseOperation,
185 BElementwiseOperation,
186 CDEElementwiseOperation,
187 QsElementwiseOperation,
188 RsElementwiseOperation>
209 const auto a_grid_desc_mraw_kraw = [&]() {
222 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
227 const auto b_grid_desc_nraw_kraw = [&]() {
240 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
245 const auto e_grid_desc_mraw_nraw = [&]() {
258 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
267 const auto MPad = M - MRaw;
283 return r_grid_desc_mraw;
293 template <index_t NXdlPerWave_>
302 AElementwiseOperation,
303 BElementwiseOperation,
304 CDEElementwiseOperation,
305 QsElementwiseOperation,
306 RsElementwiseOperation,
307 ThreadReduceOperations,
309 RsGlobalMemoryDataOperation,
314 NumGemmKPrefetchStage,
325 ABlockTransferThreadClusterLengths_AK0_M_AK1,
326 ABlockTransferThreadClusterArrangeOrder,
327 ABlockTransferSrcAccessOrder,
328 ABlockTransferSrcVectorDim,
329 ABlockTransferSrcScalarPerVector,
330 ABlockTransferDstScalarPerVector_AK1,
333 BBlockTransferThreadClusterLengths_BK0_N_BK1,
334 BBlockTransferThreadClusterArrangeOrder,
335 BBlockTransferSrcAccessOrder,
336 BBlockTransferSrcVectorDim,
337 BBlockTransferSrcScalarPerVector,
338 BBlockTransferDstScalarPerVector_BK1,
341 CShuffleMXdlPerWavePerShuffle,
342 CShuffleNXdlPerWavePerShuffle,
343 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
344 CDEReduceThreadTransferScalarPerVector_NPerBlock,
345 RThreadTransferDstScalarPerVector_MPerBlock,
363 const void* p_b_grid,
364 std::array<const void*, NumDTensor> p_ds_grid,
366 std::array<void*, NumRTensor> p_rs_grid,
372 std::array<index_t, NumDTensor> StrideDs,
374 AElementwiseOperation a_element_op,
375 BElementwiseOperation b_element_op,
376 CDEElementwiseOperation cde_element_op,
377 QsElementwiseOperation qs_element_op,
378 RsElementwiseOperation rs_element_op)
379 :
p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
380 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
382 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
403 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
408 p_rs_grid_(i) =
static_cast<RDataType*
>(p_rs_grid[i]);
446 template <
typename Gr
idwiseGemm>
455 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
458 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
460 ds_grid_desc_mblock_mperblock_nblock_nperblock = {};
464 rs_grid_desc_mblock_mperblock = {};
466 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
467 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
471 const auto d_grid_desc_m_n =
473 ds_grid_desc_mblock_mperblock_nblock_nperblock(i) =
474 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
478 static_for<0, NumRTensor, 1>{}([&](
auto i) {
479 rs_grid_desc_mblock_mperblock(i) =
490 constexpr bool has_main_loop = has_main_k_block_loop.value;
495 typename GridwiseGemm::DsGridPointer,
497 typename GridwiseGemm::RsGridPointer,
498 AElementwiseOperation,
499 BElementwiseOperation,
500 CDEElementwiseOperation,
501 QsElementwiseOperation,
502 RsElementwiseOperation,
506 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
508 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
510 typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
512 typename GridwiseGemm::DefaultBlock2ETileMap,
532 ds_grid_desc_mblock_mperblock_nblock_nperblock,
533 e_grid_desc_mblock_mperblock_nblock_nperblock,
534 rs_grid_desc_mblock_mperblock,
540 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
558 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
601 std::array<const void*, NumDTensor> p_ds,
603 std::array<void*, NumRTensor> p_rs,
609 std::array<index_t, NumDTensor> StrideDs,
611 AElementwiseOperation a_element_op,
612 BElementwiseOperation b_element_op,
613 CDEElementwiseOperation cde_element_op,
614 QsElementwiseOperation qs_element_op,
615 RsElementwiseOperation rs_element_op)
641 std::array<const void*, NumDTensor> p_ds,
643 std::array<void*, NumRTensor> p_rs,
649 std::array<index_t, NumDTensor> StrideDs,
651 AElementwiseOperation a_element_op,
652 BElementwiseOperation b_element_op,
653 CDEElementwiseOperation cde_element_op,
654 QsElementwiseOperation qs_element_op,
655 RsElementwiseOperation rs_element_op)
override
657 return std::make_unique<Argument>(p_a,
679 return std::make_unique<Invoker>(
Invoker{});
685 auto str = std::stringstream();
688 str <<
"DeviceGemmMultipleDMultipleR_Xdl_CShuffle"
699 << MXdlPerWave <<
", "
700 << NXdlPerWave <<
", "
701 << ABlockTransferSrcScalarPerVector <<
", "
702 << BBlockTransferSrcScalarPerVector <<
", "
703 << CShuffleMXdlPerWavePerShuffle <<
", "
704 << CShuffleNXdlPerWavePerShuffle
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatDsPointer p_ds_grid, FloatE *__restrict__ p_e_grid, FloatRsPointer p_rs_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const QsElementwiseOperation qs_element_op, const RsElementwiseOperation rs_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:43
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:74
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::RsGridPointer decltype(MakeTsGridPointer< RsDataType, false >()) RsGridPointer
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:317
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeRGridDescriptor_MBlock_MPerBlock __host__ static __device__ constexpr auto MakeRGridDescriptor_MBlock_MPerBlock(const RGridDesc_M &r_grid_desc_m)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:279
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::DefaultBlock2ETileMap remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:313
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeDefaultAGridDescriptor_AK0_M_AK1 __host__ static __device__ constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:174
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::DsGridPointer decltype(MakeTsGridPointer< DsDataType, true >()) DsGridPointer
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:316
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const RGridDesc_M &r_grid_desc_m, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:208
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeDefaultBGridDescriptor_BK0_N_BK1 __host__ static __device__ constexpr auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:190
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:361
std::array< index_t, NumDTensor > stride_ds_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:421
RGridDesc_M r_grid_desc_m_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:426
Block2ETileMap block_2_etile_map_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:431
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:434
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:436
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:423
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:429
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, std::array< void *, NumRTensor > p_rs_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:362
index_t MRaw_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:419
QsElementwiseOperation qs_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:437
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:416
RsElementwiseOperation rs_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:438
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:424
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:428
index_t NRaw_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:420
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:425
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:415
GridwiseGemm64::RsGridPointer p_rs_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:418
EDataType * p_e_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:417
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:435
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:414
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:443
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:444
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:447
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:555
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:189
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:562
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:599
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:594
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:204
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:287
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:350
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:194
decltype(MakeEGridDescriptor_M_N(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:289
DeviceGemmMultipleDMultipleR_Xdl_CShuffle DeviceOp
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:190
static constexpr auto I1
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:200
static constexpr index_t NumRTensor
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:197
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:225
std::string GetTypeString() const override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:683
static constexpr auto I2
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:201
static auto MakeInvoker()
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:636
static auto MakeRGridDescriptor_M(index_t MRaw)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:262
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:207
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:353
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:294
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:243
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op) override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:639
static constexpr auto I3
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:202
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:357
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:196
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:348
decltype(MakeRGridDescriptor_M(1)) RGridDesc_M
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:290
static constexpr auto I0
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:199
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:193
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:288
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:347
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:677
Definition device_gemm_multiple_d_multiple_r.hpp:41
Definition matrix_padder.hpp:180