24 template <
typename ALayout,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
48 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49 typename ABlockTransferThreadClusterArrangeOrder,
50 typename ABlockTransferSrcAccessOrder,
51 index_t ABlockTransferSrcVectorDim,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t ABlockTransferDstScalarPerVector_AK1,
55 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
56 typename BBlockTransferThreadClusterArrangeOrder,
57 typename BBlockTransferSrcAccessOrder,
58 index_t BBlockTransferSrcVectorDim,
59 index_t BBlockTransferSrcScalarPerVector,
60 index_t BBlockTransferDstScalarPerVector_BK1,
62 index_t CShuffleMXdlPerWavePerShuffle,
63 index_t CShuffleNXdlPerWavePerShuffle,
64 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
65 typename CDEShuffleBlockTransferScalarPerVectors,
68 typename ComputeTypeA = CDataType,
69 typename ComputeTypeB = ComputeTypeA,
70 typename LDSTypeA = ComputeTypeA,
71 typename LDSTypeB = ComputeTypeB>
81 AElementwiseOperation,
82 BElementwiseOperation,
83 CElementwiseOperation>
91 template <index_t NXdlPerWave_>
103 AElementwiseOperation,
104 BElementwiseOperation,
105 CElementwiseOperation,
117 ABlockTransferThreadClusterLengths_AK0_M_AK1,
118 ABlockTransferThreadClusterArrangeOrder,
119 ABlockTransferSrcAccessOrder,
120 ABlockTransferSrcVectorDim,
121 ABlockTransferSrcScalarPerVector,
122 ABlockTransferDstScalarPerVector_AK1,
125 BBlockTransferThreadClusterLengths_BK0_N_BK1,
126 BBlockTransferThreadClusterArrangeOrder,
127 BBlockTransferSrcAccessOrder,
128 BBlockTransferSrcVectorDim,
129 BBlockTransferSrcScalarPerVector,
130 BBlockTransferDstScalarPerVector_BK1,
133 CShuffleMXdlPerWavePerShuffle,
134 CShuffleNXdlPerWavePerShuffle,
135 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136 CDEShuffleBlockTransferScalarPerVectors,
146 using Argument =
typename GridwiseGemm64::Argument;
152 template <
typename GridwiseGemm>
153 float RunImp(
const typename GridwiseGemm::Argument& arg,
156 if(stream_config.log_level_ > 0)
161 if(!GridwiseGemm::CheckValidity(arg))
163 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
167 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
171 index_t k_grain = arg.KBatch * KPerBlock;
172 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
174 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
176 const auto Run = [&](
const auto& kernel) {
177 if(stream_config.flush_cache)
180 std::array<std::size_t, NumDTensor> DsSize;
184 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
185 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
186 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
187 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
190 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType);
192 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType);
194 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
195 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
199 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() *
sizeof(DDataType);
204 stream_config.rotating_count,
208 rotating_mem.Print();
210 auto run_flush_cache = [&]() {
217 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
219 arg_.M * arg_.N *
sizeof(CDataType),
220 stream_config.stream_id_));
235 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
237 arg.M * arg.N *
sizeof(CDataType),
238 stream_config.stream_id_));
241 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
245 constexpr auto estimated_reg_a = MPerBlock * KPerBlock *
sizeof(ADataType) / BlockSize /
246 4 * (1 + GridwiseGemm::NWave);
247 constexpr auto estimated_reg_b =
248 NPerBlock * KPerBlock *
sizeof(BDataType) / BlockSize / 4 * (2);
249 constexpr auto estimated_reg_c =
250 MPerBlock * NPerBlock *
sizeof(GemmAccDataType) / BlockSize / 4;
251 constexpr auto estimated_reg_total =
252 estimated_reg_a + estimated_reg_b + estimated_reg_c;
254 constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
259 if(has_main_k_block_loop)
266 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
289 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
316 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
341 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
367 throw std::runtime_error(
"todo: only v1 v2 and v3 support now");
378 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
401 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
423 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
448 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
473 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
499 throw std::runtime_error(
"todo: only v3 support now");
513 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
546 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
563 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
577 std::array<const void*, NumDTensor> p_ds,
584 std::array<index_t, NumDTensor> StrideDs,
587 AElementwiseOperation a_element_op,
588 BElementwiseOperation b_element_op,
589 CElementwiseOperation c_element_op)
591 return Argument{
static_cast<const ADataType*
>(p_a),
592 static_cast<const BDataType*
>(p_b),
594 static_cast<CDataType*
>(p_c),
613 std::array<const void*, NumDTensor> p_ds,
620 std::array<ck::index_t, NumDTensor> StrideDs,
623 AElementwiseOperation a_element_op,
624 BElementwiseOperation b_element_op,
625 CElementwiseOperation c_element_op)
override
627 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
628 static_cast<const BDataType*
>(p_b),
630 static_cast<CDataType*
>(p_c),
647 return std::make_unique<Invoker>(
Invoker{});
653 auto str = std::stringstream();
655 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
659 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
665 str <<
"DeviceGemmXdlUniversal"
668 << std::string(ALayout::name)[0]
669 << std::string(BLayout::name)[0]
670 << std::string(CLayout::name)[0]
675 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
677 << MPerXDL<<
"x"<<NPerXDL <<
", "
679 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
681 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
682 <<
"BlkGemmPipelineScheduler: "
683 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
684 <<
"BlkGemmPipelineVersion: "
685 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
686 <<
"BlkGemmPipelinePrefetchStages: "
687 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:39
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:82
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:168
ck::GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp:950
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:151
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:510
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:153
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:84
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:651
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:144
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:611
GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:92
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:575
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:88
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:570
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:146
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:523
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:517
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:87
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:143
int GetPreShuffleParameters() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:147
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:86
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:608
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp:645
Definition device_gemm_multiple_d.hpp:125
Definition flush_cache.hpp:174