gemm_pipeline_agmem_bgmem_creg_v1.hpp Source File

gemm_pipeline_agmem_bgmem_creg_v1.hpp Source File#

Composable Kernel: gemm_pipeline_agmem_bgmem_creg_v1.hpp Source File
gemm_pipeline_agmem_bgmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// A Tile Window: global memory
13// B Tile Window: global memory
14// C Distributed tensor: register
15template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
17{
21
25
29
32
35
37
// Compile-time index aliases for number<0>, number<1> and number<2>.
// In the visible listing, I0 and I1 index get_window_lengths() inside the
// layout static_asserts of operator(); I2 is not referenced in the visible lines.
38 using I0 = number<0>;
39 using I1 = number<1>;
40 using I2 = number<2>;
41
42 static constexpr index_t BlockSize = Problem::kBlockSize;
43
// Per-block M/N/K extents, copied from BlockGemmShape::kM/kN/kK.
// operator() static_asserts that the A and B window lengths equal these values
// (in the order implied by ALayout/BLayout) and uses them as the LDS window lengths.
44 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
45 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
46 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
47
// Returns Problem::VectorSizeA unchanged.
// The IsWave32Host template parameter (default false) is accepted but not
// consulted in this body.
48 template <bool IsWave32Host = false>
49 static constexpr index_t GetVectorSizeA()
50 {
51 return Problem::VectorSizeA;
52 }
// Returns Problem::VectorSizeB unchanged.
// The IsWave32Host template parameter (default false) is accepted but not
// consulted in this body.
53 template <bool IsWave32Host = false>
54 static constexpr index_t GetVectorSizeB()
55 {
56 return Problem::VectorSizeB;
57 }
// Returns Problem::VectorSizeC unchanged. Unlike the A/B getters, this one is not a template.
58 static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
59
// Forwards to Policy::GetSmemPackA<Problem>(); the value is decided entirely by the Policy.
60 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
// Forwards to Policy::GetSmemPackB<Problem>(); the value is decided entirely by the Policy.
61 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
62
// Padding flags mirrored from Problem::kPadM/kPadN/kPadK.
// GetName() joins the three flags with 'x' as part of the pipeline name string.
63 static constexpr bool kPadM = Problem::kPadM;
64 static constexpr bool kPadN = Problem::kPadN;
65 static constexpr bool kPadK = Problem::kPadK;
66
67 static constexpr bool Preshuffle = Problem::Preshuffle;
68
69 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
70
71 static constexpr index_t kLdsAlignmentInBytes = 16;
72
73 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
74 {
75 // clang-format off
76 return concat('_', "pipeline_AGmemBGmemCRegV1",
79 concat('x', kPadM, kPadN, kPadK));
80 // clang-format on
81 }
82
83 // For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
84 static constexpr bool DoubleSmemBuffer = false;
85
// Returns Problem::TransposeC unchanged; callable from host and device (CK_TILE_HOST_DEVICE).
86 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
87
89 {
90 return Policy::template GetSmemSize<Problem>();
91 }
92
93 template <typename AsDramBlockWindowTmp,
94 typename BsDramBlockWindowTmp,
95 typename AElementFunction,
96 typename BElementFunction,
97 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
99 bool>* = nullptr>
100 CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
101 const AElementFunction& a_element_func,
102 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
103 const BElementFunction& b_element_func,
104 index_t num_loop,
105 void* p_smem) const
106 {
107 using ADramBlockWindowTmp =
108 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
109 using BDramBlockWindowTmp =
110 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
111
112 static_assert(
113 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
114 std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
115 "wrong!");
116
117 constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
118 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
119
120 static_assert(is_a_col_major
121 ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
122 kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
123 : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
124 kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
125 "A block window has incorrect lengths for defined ALayout!");
126 static_assert(is_b_row_major
127 ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
128 kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
129 : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
130 kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
131 "B block window has incorrect lengths for defined BLayout!");
132 // A tile in LDS
133 ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
134
135 constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
136
137 auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
138
139 constexpr index_t a_lds_block_space_size_aligned =
140 integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
143
144 // B tile in LDS
145 BDataType* p_b_lds = static_cast<BDataType*>(
146 static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
147
148 constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
149
150 auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
151
152 // A DRAM tile window for load
153 auto as_copy_dram_window = generate_tuple(
154 [&](auto idx) {
155 return make_tile_window(
156 a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
158 a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
159 Policy::template MakeADramTileDistribution<Problem>());
160 },
161 number<AsLayout::size()>{});
162
163 // A LDS tile window for store
164 auto a_copy_lds_window = make_tile_window(
165 a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
166
167 // B DRAM tile window for load
168 auto bs_copy_dram_window = generate_tuple(
169 [&](auto idx) {
170 return make_tile_window(
171 b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
173 b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
174 Policy::template MakeBDramTileDistribution<Problem>());
175 },
176 number<BsLayout::size()>{});
177
178 // B LDS tile window for store
179 auto b_copy_lds_window = make_tile_window(
180 b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
181
182 // Tile distribution for load from lds
183 constexpr auto a_lds_load_tile_distr =
184 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
185 constexpr auto b_lds_load_tile_distr =
186 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
187
188 // A LDS tile for block GEMM
189 auto a_lds_gemm_window =
190 make_tile_window(a_lds_block,
192 {0, 0},
193 a_lds_load_tile_distr);
194
195 // B LDS tile for block GEMM
196 auto b_lds_gemm_window =
197 make_tile_window(b_lds_block,
199 {0, 0},
200 b_lds_load_tile_distr);
201
202 // Block GEMM
203 auto block_gemm = BlockGemm();
204
205 // Acc register tile
206 auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
207
208 // prefetch
209 // global read 0
210 // Load tile — during value loading, an elementwise function is executed for each A0,
211 // A1, … AN. The values A0, A1, … AN are read by the same thread.
212 auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
213
214 // Load tile — during value loading, an elementwise function is executed for each B0,
215 // B1, … BN. The values B0, B1, … BN are read by the same thread.
216 auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
217
218 {
219 // move to 1
220 // Move each A — the enhanced function move_tile_window is executed, which takes a tuple
221 // as input.
222 move_tile_window(as_copy_dram_window, {0, kKPerBlock});
223 // Move each B — the enhanced function move_tile_window is executed, which takes a tuple
224 // as input.
225 move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
226
227 // initialize C
228 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
229
230 // LDS write 0
231 if constexpr(is_a_col_major)
232 {
234 Policy::template MakeShuffledARegTileDistribution<Problem>());
235 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
236 store_tile(a_copy_lds_window, a_shuffle_tmp);
237 }
238 else
239 {
240 store_tile(a_copy_lds_window, elementwise_As_res);
241 }
242
243 // LDS write 0
244 if constexpr(is_b_row_major)
245 {
247 Policy::template MakeShuffledBRegTileDistribution<Problem>());
248 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
249 store_tile(b_copy_lds_window, b_shuffle_tmp);
250 }
251 else
252 {
253 store_tile(b_copy_lds_window, elementwise_Bs_res);
254 }
255 }
256
257 index_t iCounter = num_loop - 1;
258 while(iCounter > 0)
259 {
260 // global read i + 1
261 elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
262 elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
263
265
266 // GEMM i
267 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
268
270
271 // move to i + 2
272 move_tile_window(as_copy_dram_window, {0, kKPerBlock});
273 move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
274
275 // LDS write i + 1
276 if constexpr(is_a_col_major)
277 {
278 auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
279 Policy::template MakeShuffledARegTileDistribution<Problem>());
280 transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
281 store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
282 }
283 else
284 {
285 store_tile(a_copy_lds_window, elementwise_As_res);
286 }
287
288 // LDS write i + 1
289 if constexpr(is_b_row_major)
290 {
291 auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
292 Policy::template MakeShuffledBRegTileDistribution<Problem>());
293 transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
294 store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
295 }
296 else
297 {
298 store_tile(b_copy_lds_window, elementwise_Bs_res);
299 }
300
301 iCounter--;
302 }
303
304 // tail
305 {
307
308 // GEMM num_loop - 1
309 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
310 }
311
312 return c_block_tile;
313 }
314
315 template <typename AsDramBlockWindowTmp,
316 typename BsDramBlockWindowTmp,
317 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
319 bool>* = nullptr>
320 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
321 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
322 index_t num_loop,
323 void* p_smem) const
324 {
325 return operator()(
326 a_dram_block_window_tmp,
327 [](auto& e, const ADataType & a) { e = a; },
328 b_dram_block_window_tmp,
329 [](auto& e, const BDataType & b) { e = b; },
330 num_loop,
331 p_smem);
332 }
333
334 template <typename ADramBlockWindowTmp,
335 typename BDramBlockWindowTmp,
336 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
338 bool>* = nullptr>
339 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
340 const BDramBlockWindowTmp& b_dram_block_window_tmp,
341 index_t num_loop,
342 void* p_smem) const
343 {
344 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
345 ck_tile::make_tuple(b_dram_block_window_tmp),
346 num_loop,
347 p_smem);
348 }
349};
350
351} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:17
static constexpr index_t kNPerBlock
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:45
number< 0 > I0
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:38
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:34
number< 1 > I1
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:39
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:18
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:86
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:30
static constexpr bool Preshuffle
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:67
static constexpr bool kPadN
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:64
static constexpr index_t kKPerBlock
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:46
static constexpr index_t kMPerBlock
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:44
static constexpr index_t BlockSize
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:42
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:26
static constexpr index_t kLdsAlignmentInBytes
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:71
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:60
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:320
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:20
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:88
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:31
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:58
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:61
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:339
number< 2 > I2
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:40
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:49
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:33
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:100
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:27
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:69
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:28
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:54
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:23
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:84
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:36
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:24
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:73
static constexpr bool kPadK
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:65
static constexpr bool kPadM
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:63
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:19
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_agmem_bgmem_creg_v1.hpp:22