20 template <
typename Problem>
25 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
26 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
41 return a_lds_block_desc;
45 template <
typename Problem>
48 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
49 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
64 return b_lds_block_desc;
67 template <
typename Problem>
73 sizeof(
typename Problem::ADataType) *
78 template <
typename Problem>
84 sizeof(
typename Problem::BDataType) *
89 template <
typename Problem>
94 constexpr index_t smem_size = smem_size_a + smem_size_b;
99 template <
typename Problem>
102 return Problem::VectorLoadSize;
105 template <
typename Problem>
108 return Problem::VectorLoadSize;
111 template <
typename Problem>
117 constexpr index_t BlockSize = Problem::kBlockSize;
119 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
120 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
122 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
124 constexpr index_t M1 = Problem::VectorSizeA;
125 constexpr index_t M0 = MPerBlock / M1;
126 constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
127 static_assert(total_pixels % M1 == 0);
128 constexpr index_t K3 = total_pixels / M1;
130 static_assert(KPack % K3 == 0);
131 constexpr index_t K2 = KPack / K3;
136 static_assert(KPerBlock == K0 * K1 * K2 * K3);
148 constexpr index_t K2_m = K2 / K1;
150 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
162 constexpr index_t K1 = 16 /
sizeof(ADataType);
163 constexpr index_t K0 = KPerBlock / K1;
169 static_assert(M2 != 0,
"M2 is zero, which will lead to a division by zero error.");
170 static_assert(M1 != 0,
"M1 is zero, which will lead to a division by zero error.");
171 constexpr index_t M0 = MPerBlock / (M2 * M1);
172 static_assert(M0 * M1 * M2 == MPerBlock,
173 "Incorrect M0, M2, M1 configuration! "
174 "M0, M1, M2 must cover whole MPerBlock!");
186 constexpr index_t M1 = MPerBlock / (M2 * M0);
187 static_assert(M0 * M1 * M2 == MPerBlock,
188 "Incorrect M0, M1, M2 configuration! "
189 "M0, M1, M2 must cover whole MPerBlock!");
201 template <
typename Problem>
207 constexpr index_t BlockSize = Problem::kBlockSize;
209 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
210 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
212 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
214 constexpr index_t N1 = Problem::VectorSizeB;
215 constexpr index_t N0 = NPerBlock / N1;
216 constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
217 static_assert(total_pixels % N1 == 0);
218 constexpr index_t K3 = total_pixels / N1;
220 static_assert(KPack % K3 == 0);
221 constexpr index_t K2 = KPack / K3;
226 static_assert(KPerBlock == K0 * K1 * K2 * K3);
238 constexpr index_t K2_m = K2 / K1;
240 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
253 constexpr index_t K1 = Problem::VectorLoadSize /
sizeof(BDataType);
254 constexpr index_t K0 = KPerBlock / K1;
260 static_assert(N2 != 0,
"N2 is zero, which will lead to a division by zero error.");
261 static_assert(N1 != 0,
"N1 is zero, which will lead to a division by zero error.");
262 constexpr index_t N0 = NPerBlock / (N2 * N1);
263 static_assert(N0 * N1 * N2 == NPerBlock,
264 "Incorrect N0, N1, N2 configuration! "
265 "N0, N1, N2 must cover whole NPerBlock!");
279 constexpr index_t N1 = NPerBlock / (N2 * N0);
280 static_assert(N0 * N1 * N2 == NPerBlock,
281 "Incorrect N0, N1, N2 configuration! "
282 "N0, N1, N2 must cover whole NPerBlock!");
294 template <
typename Problem>
299 static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
300 constexpr index_t kBlockSize = Problem::kBlockSize;
301 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
302 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
304 constexpr index_t N1 = Problem::VectorLoadSize /
sizeof(BDataType);
305 constexpr index_t N0 = kNPerBlock / N1;
306 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
307 static_assert(total_pixels % N1 == 0);
308 constexpr index_t K3 = total_pixels / N1;
310 static_assert(kKPack % K3 == 0);
311 constexpr index_t K2 = kKPack / K3;
313 if constexpr(warp_size % (K2 * N0) == 0)
315 constexpr index_t K1 = warp_size / (K2 * N0);
316 constexpr index_t K0 = kBlockSize / warp_size;
329 constexpr index_t K2_m = K2 / K1;
331 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
342 template <
typename Problem>
347 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
348 constexpr index_t kBlockSize = Problem::kBlockSize;
349 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
350 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
352 constexpr index_t M1 = Problem::VectorLoadSize /
sizeof(ADataType);
353 constexpr index_t M0 = kMPerBlock / M1;
354 constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
355 static_assert(total_pixels % M1 == 0);
356 constexpr index_t K3 = total_pixels / M1;
358 static_assert(kKPack % K3 == 0);
359 constexpr index_t K2 = kKPack / K3;
361 if constexpr(warp_size % (K2 * M0) == 0)
363 constexpr index_t K1 = warp_size / (K2 * M0);
364 constexpr index_t K0 = kBlockSize / warp_size;
377 constexpr index_t K2_m = K2 / K1;
379 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
390 template <
typename Problem>
393 using AccDataType = float;
394 using BlockWarps =
typename Problem::BlockGemmShape::BlockWarps;
395 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
397 typename Problem::ComputeDataType,
402 Problem::TransposeC>;
405 typename Problem::BDataType,
406 typename Problem::CDataType,
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_universal_gemm_as_bs_cr.hpp:21
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:14
static constexpr auto I2
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:100
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBRegBlockDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:295
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackB()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:106
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeA()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:68
static constexpr auto I1
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr auto MakeBDramTileDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:202
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:90
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:391
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegBlockDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:343
static CK_TILE_HOST_DEVICE constexpr auto MakeBLdsBlockDescriptor()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:46
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:112
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeB()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:79
static constexpr auto I0
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:15
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192