25template <
typename ALayout,
30 typename BScaleDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
50 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51 typename ABlockTransferThreadClusterArrangeOrder,
52 typename ABlockTransferSrcAccessOrder,
53 index_t ABlockTransferSrcVectorDim,
54 index_t ABlockTransferSrcScalarPerVector,
55 index_t ABlockTransferDstScalarPerVector_AK1,
57 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58 typename BBlockTransferThreadClusterArrangeOrder,
59 typename BBlockTransferSrcAccessOrder,
60 index_t BBlockTransferSrcVectorDim,
61 index_t BBlockTransferSrcScalarPerVector,
62 index_t BBlockTransferDstScalarPerVector_BK1,
64 index_t CShuffleMXdlPerWavePerShuffle,
65 index_t CShuffleNXdlPerWavePerShuffle,
66 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
70 typename ComputeTypeA = CDataType,
71 typename ComputeTypeB = ComputeTypeA,
72 bool PermuteA =
false,
73 bool PermuteB =
false>
83 AElementwiseOperation,
84 BElementwiseOperation,
85 CElementwiseOperation>
89 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
92 template <index_t NXdlPerWave_>
102 AElementwiseOperation,
103 BElementwiseOperation,
104 CElementwiseOperation,
118 ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 ABlockTransferThreadClusterArrangeOrder,
120 ABlockTransferSrcAccessOrder,
121 ABlockTransferSrcVectorDim,
122 ABlockTransferSrcScalarPerVector,
123 ABlockTransferDstScalarPerVector_AK1,
126 BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 BBlockTransferThreadClusterArrangeOrder,
128 BBlockTransferSrcAccessOrder,
129 BBlockTransferSrcVectorDim,
130 BBlockTransferSrcScalarPerVector,
131 BBlockTransferDstScalarPerVector_BK1,
134 CShuffleMXdlPerWavePerShuffle,
135 CShuffleNXdlPerWavePerShuffle,
136 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 CShuffleBlockTransferScalarPerVector_NPerBlock,
147 using Argument =
typename GridwiseGemm64::Argument;
164 struct Invoker :
public BaseInvoker
166 template <
typename Gr
idwiseGemm>
167 float RunImp(
const typename GridwiseGemm::Argument& arg,
170 if(stream_config.log_level_ > 0)
175 if(!GridwiseGemm::CheckValidity(arg))
177 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
181 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
185 index_t k_grain = arg.KBatch * KPerBlock;
186 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
188 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
190 const auto Run = [&](
const auto& kernel) {
191 if(stream_config.flush_cache)
195 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
196 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
197 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
198 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
200 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
202 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
206 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
207 rotating_mem.Print();
209 auto run_flush_cache = [&]() {
216 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
218 arg_.M * arg_.N *
sizeof(CDataType),
219 stream_config.stream_id_));
234 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
236 arg.M * arg.N *
sizeof(CDataType),
237 stream_config.stream_id_));
240 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
244 constexpr index_t minimum_occupancy =
247 MPerBlock * NPerBlock * KPerBlock *
sizeof(ADataType) <= 128 * 128 * 64 * 2)
252 if(has_main_k_block_loop)
282 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
292 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
304 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
318 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
320 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
333 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
335 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
348 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
350 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
363 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
365 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
377 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
379 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
394 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
404 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
416 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
418 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
430 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
432 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
445 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
447 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
460 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
462 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
475 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
477 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
489 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
491 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
510 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
533 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
559 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
582 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
639 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
686 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
703 const BDataType* p_b,
712 const BScaleDataType* p_b_scale,
714 AElementwiseOperation a_element_op,
715 BElementwiseOperation b_element_op,
716 CElementwiseOperation c_element_op)
748 const void* p_b_scale,
750 AElementwiseOperation a_element_op,
751 BElementwiseOperation b_element_op,
752 CElementwiseOperation c_element_op)
override
754 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
755 static_cast<const BDataType*
>(p_b),
756 static_cast<CDataType*
>(p_c),
764 static_cast<const BScaleDataType*
>(p_b_scale),
774 return std::make_unique<Invoker>(
Invoker{});
780 auto str = std::stringstream();
782 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
786 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
794 str <<
"DeviceGemmXdlUniversal"
797 << std::string(ALayout::name)[0]
798 << std::string(BLayout::name)[0]
799 << std::string(CLayout::name)[0]
804 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
806 << MPerXDL<<
"x"<<NPerXDL <<
", "
808 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
810 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
811 <<
"BlkGemmPipelineScheduler: "
812 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
813 <<
"BlkGemmPipelineVersion: "
814 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
815 <<
"BlkGemmPipelinePrefetchStages: "
816 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
Definition data_type.hpp:187
Definition device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition device_gemm_xdl_cshuffle_v3.hpp:263
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:167
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3.hpp:747
"Universal" GEMM operation with SplitK support.
Definition device_gemm_xdl_cshuffle_v3.hpp:178
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3.hpp:235
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:693
index_t GetKPerBlock() override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:698
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:772
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:649
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:643
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3.hpp:234
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:778
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3.hpp:181
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3.hpp:185
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, const BScaleDataType *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:702
bool GetPermuteB() override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:700
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:735
static constexpr index_t BPackedSize
Definition device_gemm_xdl_cshuffle_v3.hpp:246
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3.hpp:237
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3.hpp:182
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:738
static constexpr index_t APackedSize
Definition device_gemm_xdl_cshuffle_v3.hpp:239
Definition device_gemm_v2.hpp:93
Definition flush_cache.hpp:299