24template <
typename GridwiseGemm,
25 typename ComputePtrOffsetOfStridedBatch,
26 bool HasMainKBlockLoop,
31#if CK_USE_LAUNCH_BOUNDS
35 typename GridwiseGemm::Argument karg,
38 const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
40#if(defined(__gfx11__) || defined(__gfx12__))
45 (std::is_same_v<c_data_type, ck::half_t> ||
46 std::is_same_v<c_data_type, ck::bhalf_t>)))
49 constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
50 typename GridwiseGemm::EpilogueCShuffle>();
56 __shared__
char p_shared[LDS_size];
69 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
72 typename GridwiseGemm::AsGridPointer p_as_grid_shift;
76 p_as_grid_shift(i) =
static_cast<const ADataType_*
>(karg.p_as_grid[i]) +
77 splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
81 typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
85 p_bs_grid_shift(i) =
static_cast<const BDataType_*
>(karg.p_bs_grid[i]) +
86 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
89 auto epilogue_args =
typename GridwiseGemm::EpilogueCShuffle{};
91 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95 karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
96 karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
103#if defined(__gfx11__)
108 ignore = compute_ptr_offset_of_batch;
208template <
typename ALayout,
213 typename BScaleDataType,
215 typename AccDataType,
216 typename CShuffleDataType,
217 typename AElementwiseOperation,
218 typename BElementwiseOperation,
219 typename CElementwiseOperation,
233 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
234 typename ABlockTransferThreadClusterArrangeOrder,
235 typename ABlockTransferSrcAccessOrder,
236 index_t ABlockTransferSrcVectorDim,
237 index_t ABlockTransferSrcScalarPerVector,
238 index_t ABlockTransferDstScalarPerVector_AK1,
239 bool ABlockLdsExtraM,
240 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
241 typename BBlockTransferThreadClusterArrangeOrder,
242 typename BBlockTransferSrcAccessOrder,
243 index_t BBlockTransferSrcVectorDim,
244 index_t BBlockTransferSrcScalarPerVector,
245 index_t BBlockTransferDstScalarPerVector_BK1,
246 bool BBlockLdsExtraN,
247 index_t CShuffleMRepeatPerShuffle,
248 index_t CShuffleNRepeatPerShuffle,
249 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
250 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
253 typename ComputeTypeA = CDataType,
254 typename ComputeTypeB = ComputeTypeA,
255 bool PermuteA =
false,
256 bool PermuteB =
false>
267 AElementwiseOperation,
268 BElementwiseOperation,
269 CElementwiseOperation>
273 static_assert(PermuteA ==
false,
274 "Permute A functionality not supported by DeviceBatchedGemm operations.\n");
275 static_assert(PermuteB ==
false,
276 "Permute B functionality not supported by DeviceBatchedGemm operations.\n");
284 : BatchStrideA_(BatchStrideA),
285 BatchStrideB_(BatchStrideB),
286 BatchStrideC_(BatchStrideC),
287 BatchStrideScaleB_(BatchStrideScaleB)
293 return g_idx *
static_cast<long_index_t>(BatchStrideA_);
303 return g_idx *
static_cast<long_index_t>(BatchStrideC_);
307 return g_idx *
static_cast<long_index_t>(BatchStrideScaleB_);
330 AElementwiseOperation,
331 BElementwiseOperation,
332 CElementwiseOperation,
346 ABlockTransferThreadClusterLengths_AK0_M_AK1,
347 ABlockTransferThreadClusterArrangeOrder,
348 ABlockTransferSrcAccessOrder,
349 ABlockTransferSrcVectorDim,
350 ABlockTransferSrcScalarPerVector,
351 ABlockTransferDstScalarPerVector_AK1,
354 BBlockTransferThreadClusterLengths_BK0_N_BK1,
355 BBlockTransferThreadClusterArrangeOrder,
356 BBlockTransferSrcAccessOrder,
357 BBlockTransferSrcVectorDim,
358 BBlockTransferSrcScalarPerVector,
359 BBlockTransferDstScalarPerVector_BK1,
362 CShuffleMRepeatPerShuffle,
363 CShuffleNRepeatPerShuffle,
364 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
377 const BDataType* p_b_grid_,
378 CDataType* p_c_grid_,
390 const BScaleDataType* p_b_scale_grid_,
393 AElementwiseOperation a_element_op_,
394 BElementwiseOperation b_element_op_,
395 CElementwiseOperation c_element_op_,
396 bool is_reduce_ =
false)
398 std::array<const void*, 1>{p_b_grid_},
399 std::array<const void*, 0>{},
404 std::array<index_t, 1>{StrideA_},
405 std::array<index_t, 1>{StrideB_},
406 std::array<index_t, 0>{},
417 BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_}
443 if(stream_config.log_level_ > 0)
446 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
451 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
467 index_t k_grain = arg.KBatch * KPerBlock;
468 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
472 const auto Run = [&](
const auto& kernel) {
473 if(stream_config.flush_cache)
478 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
480 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
487 std::array<std::size_t, 1> size_as_buffers;
488 size_as_buffers[0] = a_grid_desc_ak0_m_ak1[
Number<0>{}].GetElementSpaceSize() *
491 std::array<std::size_t, 1> size_bs_buffers;
492 size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[
Number<0>{}].GetElementSpaceSize() *
500 stream_config.rotating_count,
503 std::array<std::size_t, 0>{});
504 rotating_mem.Print();
506 auto run_flush_cache = [&]() {
516 hipMemsetAsync(arg_.p_e_grid,
518 arg_.Batch * arg_.M * arg_.N *
sizeof(CDataType),
519 stream_config.stream_id_));
530 arg_.compute_ptr_offset_of_batch);
534 auto clear_workspace = [&]() {
542 hipMemsetAsync(arg.p_e_grid,
544 arg.
Batch * arg.M * arg.N *
sizeof(CDataType),
545 stream_config.stream_id_));
560 constexpr index_t minimum_occupancy = []() {
567 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
575 if(has_main_k_block_loop)
585 ComputePtrOffsetOfStridedBatch,
604 throw std::runtime_error(
"Pipeline not implemented");
616 ComputePtrOffsetOfStridedBatch,
642 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
659 if constexpr(std::is_same_v<CDataType, ck::half_t> ||
660 std::is_same_v<CDataType, ck::bhalf_t>)
669 if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
670 std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
699 const BDataType* p_b,
712 const BScaleDataType* p_b_scale,
714 AElementwiseOperation,
715 BElementwiseOperation,
716 CElementwiseOperation,
755 const void* p_b_scale,
758 AElementwiseOperation a_element_op,
759 BElementwiseOperation b_element_op,
760 CElementwiseOperation c_element_op)
override
762 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
763 static_cast<const BDataType*
>(p_b),
764 static_cast<CDataType*
>(p_c),
776 static_cast<const BScaleDataType*
>(p_b_scale),
787 return std::make_unique<Invoker>(
Invoker{});
793 auto str = std::stringstream();
795 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
799 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
807 str <<
"DeviceBatchedGemm_Wmma_CShuffleV3_BScale"
810 << std::string(ALayout::name)[0]
811 << std::string(BLayout::name)[0]
812 << std::string(CLayout::name)[0]
817 << MPerBlock <<
"x" << NPerBlock <<
"x" << KPerBlock <<
", "
819 << MPerWmma <<
"x"<<NPerWmma <<
", "
821 << MRepeat <<
"x" << NRepeat <<
", "
823 << ABlockTransferSrcScalarPerVector <<
"x" << BBlockTransferSrcScalarPerVector <<
", "
824 <<
"BlkGemmPipelineScheduler: "
825 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
826 <<
"BlkGemmPipelineVersion: "
827 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
828 <<
"BlkGemmPipelinePrefetchStages: "
829 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages <<
", "
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_batched_gemm_b_scale_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:34
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
int64_t long_index_t
Definition ck.hpp:300
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:127
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBsGridDescriptor_BK0_N_BK1 __host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAsGridDescriptor_AK0_M_AK1 __host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::APackedSize static constexpr index_t APackedSize
ck::GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::BPackedSize static constexpr index_t BPackedSize
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:375
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:422
index_t Batch
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:421
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, index_t BatchStrideA_, index_t BatchStrideB_, index_t BatchStrideC_, index_t BatchStrideScaleB_, const BScaleDataType *p_b_scale_grid_, index_t Batch_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:376
Helper structure responsible for kernel invocation.
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:435
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:441
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:639
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:279
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:280
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:291
__host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:305
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:296
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:301
"Universal" Batched GEMM operation without SplitK support.
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:270
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:646
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const BScaleDataType *p_b_scale, index_t Batch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t KBatch=1)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:698
GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:318
std::string GetTypeString() const override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:791
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:785
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const void *p_b_scale, index_t Batch, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:741
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:690
bool GetPermuteB() override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:696
index_t GetKPerBlock() override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:695
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:652
static auto MakeInvoker()
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:738
Definition device_batched_gemm.hpp:60
Definition flush_cache.hpp:21