42template <
typename GridwiseGemm,
54 typename OffsettedBlockToCTileMap,
55 typename LocalBlock2ETileMap,
56 typename AElementwiseOperation,
57 typename BElementwiseOperation,
58 typename CDEElementwiseOperation,
62#if CK_USE_LAUNCH_BOUNDS
67 const AElementwiseOperation a_element_op,
68 const BElementwiseOperation b_element_op,
69 const CDEElementwiseOperation cde_element_op)
71#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
72 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
74 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
75 __shared__
uint8_t p_shared[shared_size];
76 __shared__
uint8_t p_shared1[shared_size];
78 const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*
>(
81 constexpr auto NumDTensor = DsDataType::Size();
98 while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
99 group_id < group_count)
101 group_offset += grid_size_grp;
104 if(group_id >= group_count)
107 M = gemm_desc_ptr[group_id].M;
108 N = gemm_desc_ptr[group_id].N;
109 K = gemm_desc_ptr[group_id].K;
111 if(M == 0 || N == 0 || K == 0)
118 LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
119 grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
121 gemm_tile_id_start = group_offset;
122 gemm_tile_id_end = group_offset + grid_size_grp;
125 using DsGridPointer =
decltype(GridwiseGemm::MakeDsGridPointer());
126 DsGridPointer p_ds_grid;
130 p_ds_grid(i) =
static_cast<const DDataType*
>(gemm_desc_ptr[group_id].p_ds_grid[i]);
133 static constexpr index_t kbatch = 1;
134 static constexpr index_t k_grain = kbatch * KPerBlock;
135 index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
137 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
140 b2c_tile_map.UpdateTileOffset(tile_offset);
142 using Problem =
typename GridwiseGemm::Problem;
143 auto problem = Problem(gemm_desc_ptr[group_id].M,
144 gemm_desc_ptr[group_id].N,
145 gemm_desc_ptr[group_id].K,
146 gemm_desc_ptr[group_id].StrideA,
147 gemm_desc_ptr[group_id].StrideB,
148 gemm_desc_ptr[group_id].StrideDs,
149 gemm_desc_ptr[group_id].StrideE,
152 if(has_main_k_block_loop)
161 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
162 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
164 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
165 static_cast<void*
>(p_shared),
174 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
180 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
181 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
183 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
184 static_cast<void*
>(p_shared),
191 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
197 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
198 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
200 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
201 static_cast<void*
>(p_shared),
209 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
211 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
217 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
218 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
220 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
221 static_cast<void*
>(p_shared),
230 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
238 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
239 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
241 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
242 static_cast<void*
>(p_shared),
251 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
259 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
260 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
262 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
263 static_cast<void*
>(p_shared),
272 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
280 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
281 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
283 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
284 static_cast<void*
>(p_shared),
293 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
295 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
301 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
302 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
304 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
305 static_cast<void*
>(p_shared),
314 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
322 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
323 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
325 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
326 static_cast<void*
>(p_shared),
338 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
344 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
345 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
347 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
348 static_cast<void*
>(p_shared),
349 static_cast<void*
>(p_shared1),
362 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
363 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
365 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
366 static_cast<void*
>(p_shared),
367 static_cast<void*
>(p_shared1),
384 static_cast<const ADataType*
>(gemm_desc_ptr[group_id].p_a_grid),
385 static_cast<const BDataType*
>(gemm_desc_ptr[group_id].p_b_grid),
387 static_cast<EDataType*
>(gemm_desc_ptr[group_id].p_e_grid),
388 static_cast<void*
>(p_shared),
400 }
while(group_id < group_count);
403 ignore = gemm_descs_const;
411template <
typename ALayout,
417 typename AccDataType,
418 typename CShuffleDataType,
421 typename AElementwiseOperation,
422 typename BElementwiseOperation,
423 typename CDEElementwiseOperation,
436 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
437 typename ABlockTransferThreadClusterArrangeOrder,
438 typename ABlockTransferSrcAccessOrder,
439 index_t ABlockTransferSrcVectorDim,
440 index_t ABlockTransferSrcScalarPerVector,
441 index_t ABlockTransferDstScalarPerVector_AK1,
443 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
444 typename BBlockTransferThreadClusterArrangeOrder,
445 typename BBlockTransferSrcAccessOrder,
446 index_t BBlockTransferSrcVectorDim,
447 index_t BBlockTransferSrcScalarPerVector,
448 index_t BBlockTransferDstScalarPerVector_BK1,
450 index_t CShuffleMXdlPerWavePerShuffle,
451 index_t CShuffleNXdlPerWavePerShuffle,
452 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
453 typename CDEShuffleBlockTransferScalarPerVectors,
456 typename ComputeTypeA = EDataType,
457 typename ComputeTypeB = ComputeTypeA>
468 AElementwiseOperation,
469 BElementwiseOperation,
470 CDEElementwiseOperation>
478 template <index_t NXdlPerWave_>
490 AElementwiseOperation,
491 BElementwiseOperation,
492 CDEElementwiseOperation,
504 ABlockTransferThreadClusterLengths_AK0_M_AK1,
505 ABlockTransferThreadClusterArrangeOrder,
506 ABlockTransferSrcAccessOrder,
507 ABlockTransferSrcVectorDim,
508 ABlockTransferSrcScalarPerVector,
509 ABlockTransferDstScalarPerVector_AK1,
512 BBlockTransferThreadClusterLengths_BK0_N_BK1,
513 BBlockTransferThreadClusterArrangeOrder,
514 BBlockTransferSrcAccessOrder,
515 BBlockTransferSrcVectorDim,
516 BBlockTransferSrcScalarPerVector,
517 BBlockTransferDstScalarPerVector_BK1,
520 CShuffleMXdlPerWavePerShuffle,
521 CShuffleNXdlPerWavePerShuffle,
522 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
523 CDEShuffleBlockTransferScalarPerVectors,
539 std::vector<const void*>& ,
540 std::vector<std::array<const void*, NumDTensor>>& ,
541 std::vector<void*>& ,
542 const std::vector<GemmDesc>& gemm_descs,
543 AElementwiseOperation a_element_op,
544 BElementwiseOperation b_element_op,
545 CDEElementwiseOperation cde_element_op,
546 int occupancy_num_blocks,
557 for(
const auto& desc : gemm_descs)
559 const auto M = desc.M_;
560 const auto N = desc.N_;
562 tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
609 template <
typename Gr
idwiseGemm>
611 const void* dev_gemm_args,
614 if(dev_gemm_args ==
nullptr)
616 std::ostringstream err;
617 err <<
"The gemm arguments device buffer is not allocated!" <<
" In " << __FILE__
618 <<
":" << __LINE__ <<
", in function: " << __func__;
619 throw std::runtime_error(err.str());
623 ave_time = DispatchKernel<GridwiseGemm>(arg, dev_gemm_args, stream_config);
641 template <
typename Gr
idwiseGemm>
646 std::ostringstream err;
647 err <<
"The gemm arguments device buffer is not allocated!" <<
" In " << __FILE__
648 <<
":" << __LINE__ <<
", in function: " << __func__;
649 throw std::runtime_error(err.str());
660 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
664 template <
typename Gr
idwiseGemm>
665 float DispatchKernel(
const Argument& arg,
666 const void* dev_gemm_args,
683 AElementwiseOperation,
684 BElementwiseOperation,
685 CDEElementwiseOperation,
688 return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
691 template <
typename KernelFunction>
692 int CalculateMaxOccupancyGridSize(
const KernelFunction& kernel,
696 int occ_num_blocks = 0;
697 size_t dyn_shared_mem_per_blk = 0;
699 &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
701 int cu_count = getAvailableComputeUnitCount(stream_config);
705 std::cout <<
"MaxActiveBlocksPerCU: " << occ_num_blocks
706 <<
", available CUs count: " << cu_count <<
", occup. grid size: "
714 template <
typename KernelFunction>
715 float LaunchKernel(
const KernelFunction& kernel,
717 const void* dev_gemm_args,
720 int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
724 std::cout <<
"grid_size: " << grid_size <<
" tile_count: " << arg.tile_count_
739 arg.cde_element_op_);
755 bool supported =
true;
761 std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
762 std::array<index_t, NumDTensor> stride_Ds;
776 using GridArg =
typename GridwiseGemm64::Argument;
777 GridArg gridwise_arg(
nullptr,
779 placeholder_p_ds_grid,
804 using GridArg =
typename GridwiseGemm32::Argument;
805 GridArg gridwise_arg(
nullptr,
807 placeholder_p_ds_grid,
859 AElementwiseOperation,
860 BElementwiseOperation,
861 CDEElementwiseOperation,
865 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
887 AElementwiseOperation,
888 BElementwiseOperation,
889 CDEElementwiseOperation,
893 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
900 std::vector<const void*>& p_Bs,
901 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
902 std::vector<void*>& p_Es,
903 std::vector<GemmDesc>& gemm_descs,
904 AElementwiseOperation a_elementwise_op,
905 BElementwiseOperation b_elementwise_op,
906 CDEElementwiseOperation cde_elementwise_op)
911 hipDeviceProp_t dev_prop;
915 num_cu = dev_prop.multiProcessorCount;
929 std::unique_ptr<BaseArgument>
931 std::vector<const void*>& p_Bs,
932 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
933 std::vector<void*>& p_Es,
934 std::vector<GemmDesc>& gemm_descs,
935 AElementwiseOperation a_elementwise_op,
936 BElementwiseOperation b_elementwise_op,
937 CDEElementwiseOperation cde_elementwise_op)
override
942 hipDeviceProp_t dev_prop;
946 num_cu = dev_prop.multiProcessorCount;
948 return std::make_unique<Argument>(p_As,
964 return std::make_unique<Invoker>(
Invoker{});
969 auto str = std::ostringstream();
971 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
975 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
983 str <<
"DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
985 << std::string(ALayout::name)[0] <<
","
986 << std::string(BLayout::name)[0] <<
","
987 << std::string(ELayout::name)[0] <<
","
996 << MXdlPerWave <<
", "
997 << NXdlPerWave <<
", "
998 << ABlockTransferSrcScalarPerVector <<
", "
999 << BBlockTransferSrcScalarPerVector <<
", "
1000 << CShuffleMXdlPerWavePerShuffle <<
", "
1001 << CShuffleNXdlPerWavePerShuffle <<
", "
1003 <<
"BlkGemmPipelineScheduler: "
1004 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
1005 <<
"BlkGemmPipelineVersion: "
1006 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
1014 void* p_dev_kernel_args,
1015 const void* p_host_kernel_args)
const
1021 hipMemcpyHostToDevice));
1025 void* p_dev_kernel_args,
1026 const void* p_host_kernel_args)
const override
1029 *
dynamic_cast<Argument*
>(p_arg), p_dev_kernel_args, p_host_kernel_args);
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Entry point kernel for device-wide Grouped GEMM operation.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:65
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
unsigned char uint8_t
Definition stdint.h:124
Definition ck/stream_config.hpp:10
int log_level_
Definition ck/stream_config.hpp:13
Definition block_to_ctile_map.hpp:271
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
ck::GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:1186
Definition block_to_ctile_map.hpp:920
Definition block_to_ctile_map.hpp:872
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:537
CDEElementwiseOperation cde_element_op_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:573
const std::vector< GemmDesc > & gemm_descs_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:570
index_t tile_count_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:574
int gpu_cu_count_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:569
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:572
const void * p_dev_gemm_args_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:567
index_t group_count_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:566
int occupancy_num_blocks_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:568
Argument(std::vector< const void * > &, std::vector< const void * > &, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &, const std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, int occupancy_num_blocks, int gpu_cu_count)
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:538
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:571
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:595
float Run(const Argument &arg, const void *dev_gemm_args, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:610
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:657
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:642
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:578
static constexpr int BLOCK_SUBSCRIPTION_FACTOR
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:581
static int GetCuBlocks()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:586
static constexpr int CU_SIMDS
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:583
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:471
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:962
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1042
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:833
BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2ETileMap
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:532
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:474
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:749
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1024
static auto MakeInvoker()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:960
void SetDeviceKernelArgs(Argument &arg, void *p_dev_kernel_args) const
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1032
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:743
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:930
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:476
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1037
static int GetKernelOccupancy()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:838
GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:479
OffsettedBlockToCTileMap2< Block2ETileMap > OffsettedLocalBlock2ETileMap
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:533
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop DeviceOp
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:472
std::string GetTypeString() const override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:967
GroupedGemmKernelArgument< NumDTensor > KernelArguments
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:531
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:475
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op)
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:899
void SetDeviceKernelArgs(Argument &arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1013
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:529
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:528
Grouped GEMM kernel using output Tile Looping algorithm.
Definition device_grouped_gemm_tile_loop.hpp:43
Definition device_grouped_gemm.hpp:80
Structure representing single GEMM problem arguments.
Definition device_grouped_gemm.hpp:29