codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File#
codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
Go to the documentation of this file.
42 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
50 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
53 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
57 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
58 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
376 using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
489 using APointers = ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
490 using BPointers = ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition utility/array.hpp:14
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:518
ck::Array< index_t, NDimSpatial > conv_filter_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:772
ck::Array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:770
CDEElementwiseOperation cde_element_op_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:761
ck::Array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:767
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:731
ComputePtrOffsetOfStridedBatch< NumATensor, NumBTensor, NumDTensor > compute_ptr_offset_of_batch_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:756
ck::Array< index_t, NDimSpatial > input_left_pads_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:774
ck::Array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:771
AGridDesc_M_K a_grid_desc_m_k_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:739
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:746
__device__ __host__ Argument(APointers p_as, BPointers p_bs, const ck::Array< const void *, NumDTensor > &p_ds, void *p_e, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_dilations, const ck::Array< index_t, NDimSpatial > &input_left_pads, const ck::Array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:562
EDataType * p_e_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:732
ck::Array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:764
Block2ETileMap block_2_etile_map_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:752
ck::Array< index_t, NDimSpatial > input_right_pads_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:775
EGridDesc_M_N e_grid_desc_m_n_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:742
BGridPointer p_bs_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:730
index_t num_group_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:735
AElementwiseOperation a_element_op_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:759
ck::Array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:766
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:749
AGridPointer p_as_grid_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:729
BGridDesc_N_K b_grid_desc_n_k_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:740
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:737
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:748
BElementwiseOperation b_element_op_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:760
ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:769
__host__ __device__ void init_ds_e_grid_desc()
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:520
ck::Array< index_t, NDimSpatial > conv_filter_dilations_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:773
ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:768
DsGridDesc_M_N ds_grid_desc_m_n_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:741
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:745
ck::Array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:765
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:358
static constexpr auto I1
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:372
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:441
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:434
ck::conditional_t< isMultiA, ck::Array< const void *, NumATensor > &, const void * > APointers
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:489
__host__ static __device__ auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:383
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:499
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:513
static constexpr auto NXdlPerWave32
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:362
static __device__ __host__ auto MakeArgument(APointers p_as, BPointers p_bs, const ck::Array< const void *, NumDTensor > &p_ds, void *p_e, const ck::Array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const ck::Array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const ck::Array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const ck::Array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const ck::Array< ck::Array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const ck::Array< ck::Array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const ck::Array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const ck::Array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const ck::Array< long_index_t, NDimSpatial > &conv_filter_strides, const ck::Array< long_index_t, NDimSpatial > &conv_filter_dilations, const ck::Array< long_index_t, NDimSpatial > &input_left_pads, const ck::Array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1012
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:435
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:361
static constexpr auto I0
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:371
static constexpr bool isMultiA
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:364
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:376
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:508
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:437
__host__ static __device__ auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:396
ck::conditional_t< isMultiA||isMultiB, GridwiseGemmMultipleABD_xdl_cshuffle< GridwiseGemmMultiABDTemplateParameters >, GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmTemplateParameters > > GridwiseGemmBase
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:481
ck::conditional_t<!isMultiB &&isMultiA, Tuple< BDataType >, BDataType > GemmBDataType
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:447
remove_cvref_t< decltype(GetBGridPointer< isMultiA||isMultiB, GridwiseGemm64, BDataType >())> BGridPointer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:495
static constexpr index_t NumATensor
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:367
ck::conditional_t<!isMultiA &&isMultiB, Tuple< ADataType >, ADataType > GemmADataType
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:446
static constexpr index_t NumDTensor
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:369
static __device__ __host__ auto MakeArgument(APointers p_as, BPointers p_bs, const ck::Array< const void *, NumDTensor > &p_ds, void *p_e, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const ck::Array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const ck::Array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const ck::Array< ck::Array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const ck::Array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_strides, const ck::Array< index_t, NDimSpatial > &conv_filter_dilations, const ck::Array< index_t, NDimSpatial > &input_left_pads, const ck::Array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:970
static constexpr auto I2
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:373
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:502
ck::conditional_t< isMultiB, ck::Array< const void *, NumBTensor > &, const void * > BPointers
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:490
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:505
static constexpr bool isMultiB
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:365
remove_cvref_t< decltype(GetAGridPointer< isMultiA||isMultiB, GridwiseGemm64, ADataType >())> AGridPointer
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:493
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:485
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:422
static __device__ __host__ bool IsSupportedArgument(const Argument &arg)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:802
CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:359
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:439
static constexpr auto matrix_padder
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:378
static constexpr auto I3
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:374
static __device__ __host__ bool check_gemm_validity(const Argument &arg)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:778
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:486
static constexpr index_t NumBTensor
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:368
__host__ static __device__ auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:409
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition matrix_padder.hpp:180