dpp_gemm.hpp Source File
dpp_gemm.hpp
Go to the documentation of this file.
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ static __device__ auto CalculateBThreadOriginDataIndex_K_N()
Definition dpp_gemm.hpp:506
__device__ void Run(const ADataType &p_a_wave, const BDataType &p_b_wave, CDataType &p_c_thread) const
Definition dpp_gemm.hpp:449
__host__ static __device__ constexpr auto GetCMNThreadBlkLengths()
Definition dpp_gemm.hpp:532
static __device__ constexpr index_t GetRegSizePerDpp()
Definition dpp_gemm.hpp:442
static __device__ auto GetLaneGroupIdInWave()
Definition dpp_gemm.hpp:473
static __device__ CIndex GetBeginOfThreadBlk()
Definition dpp_gemm.hpp:512
static __device__ auto GetLaneIdInLaneGroup()
Definition dpp_gemm.hpp:468
__host__ static __device__ auto CalculateAThreadOriginDataIndex_K_M()
Definition dpp_gemm.hpp:498
Definition dpp_gemm.hpp:322
static constexpr auto GetDpp()
Definition utility/sequence.hpp:43
Definition amd_gemm_dpp.hpp:37
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:156
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:149
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:146
static constexpr index_t wave_size
Definition dpp_gemm.hpp:143
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:144
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:147
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:151
static constexpr bool share_a
Definition dpp_gemm.hpp:152
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:145
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:150
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:148
static constexpr bool share_a
Definition dpp_gemm.hpp:242
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:238
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:240
static constexpr index_t wave_size
Definition dpp_gemm.hpp:233
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:246
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:241
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:235
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:237
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:234
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:236
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:239
static constexpr bool share_a
Definition dpp_gemm.hpp:302
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:297
static constexpr index_t wave_size
Definition dpp_gemm.hpp:293
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:301
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:295
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:300
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:299
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:298
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:306
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:294
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:296
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:268
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:267
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:270
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:266
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:264
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:269
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:276
static constexpr index_t wave_size
Definition dpp_gemm.hpp:263
static constexpr bool share_a
Definition dpp_gemm.hpp:272
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:265
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:271
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:54
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:57
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:56
static constexpr index_t wave_size
Definition dpp_gemm.hpp:53
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:60
static constexpr bool share_a
Definition dpp_gemm.hpp:62
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:58
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:61
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:55
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:66
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:59
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:208
static constexpr index_t wave_size
Definition dpp_gemm.hpp:203
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:205
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:216
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:210
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:209
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:206
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:204
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:211
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:207
static constexpr bool share_a
Definition dpp_gemm.hpp:212
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:177
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:174
static constexpr index_t wave_size
Definition dpp_gemm.hpp:173
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:180
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:181
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:179
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:175
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:178
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:176
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:186
static constexpr bool share_a
Definition dpp_gemm.hpp:182
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:116
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:118
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:114
static constexpr index_t wave_size
Definition dpp_gemm.hpp:113
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:121
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:126
static constexpr bool share_a
Definition dpp_gemm.hpp:122
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:119
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:115
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:117
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:120
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:88
static constexpr bool share_a
Definition dpp_gemm.hpp:92
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:87
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:90
static constexpr index_t wave_size
Definition dpp_gemm.hpp:83
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:89
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:96
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:91
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:84
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:86
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:85
Definition dpp_gemm.hpp:48
Definition functional2.hpp:33