gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp Source File

gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp Source File

Composable Kernel: gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp Source File
gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
// Kernel entry point (kernel_gemm_bias_add_reduce_xdl_cshuffle_v1).
//
// Thin device-side wrapper: when compiled for one of the targets listed in the
// preprocessor guard below, and when GridwiseGemm::IsValidCompilationParameter<>()
// holds at compile time, it declares a __shared__ byte array sized by
// GridwiseGemm::GetSharedMemoryNumberOfByte() and forwards every argument,
// unchanged and in the same order, to GridwiseGemm::Run<HasMainKBlockLoop>().
// For any other target the body only assigns each parameter to `ignore`.
//
// Template parameters mirror the argument types one-to-one; HasMainKBlockLoop is
// passed straight through as the template argument of GridwiseGemm::Run.
21template <typename GridwiseGemm,
22 typename FloatAB,
23 typename FloatC,
24 typename FloatC0,
25 typename FloatC1,
26 typename ReducePtrsGlobal,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CElementwiseOperation,
30 typename C1ElementwiseOperation,
31 typename ReduceInElementwiseOperations,
32 typename ReduceAccElementwiseOperations,
33 typename AGridDesc_AK0_M_AK1,
34 typename BGridDesc_BK0_N_BK1,
35 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
37 typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
38 typename ReduceGridDescriptor_MBlock_MPerBlock,
39 typename Block2CTileMap,
40 bool HasMainKBlockLoop>
41__global__ void
42#if CK_USE_LAUNCH_BOUNDS
44#endif
46 const FloatAB* __restrict__ p_a_grid,
47 const FloatAB* __restrict__ p_b_grid,
48 FloatC* __restrict__ p_c_grid,
49 const FloatC0* __restrict__ p_bias_grid,
50 const FloatC1* __restrict__ p_d0_grid,
51 ReducePtrsGlobal p_reduces_grid,
52 const AElementwiseOperation a_element_op,
53 const BElementwiseOperation b_element_op,
54 const CElementwiseOperation c_element_op,
55 const C1ElementwiseOperation c1_element_op,
56 const ReduceInElementwiseOperations reduce_in_element_ops,
57 const ReduceAccElementwiseOperations reduce_out_element_ops,
58 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
59 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
60 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
61 c_grid_desc_mblock_mperblock_nblock_nperblock,
62 const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
63 c0_grid_desc_mblock_mperblock_nblock_nperblock,
64 const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
65 c1_grid_desc_mblock_mperblock_nblock_nperblock,
66 const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
67 const Block2CTileMap block_2_ctile_map)
68{
69#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
70 defined(__gfx12__)
// Compile-time gate: an instantiation whose tuning parameters are rejected by
// IsValidCompilationParameter compiles to an empty kernel body.
71 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
72 {
// Single LDS scratch buffer for the whole block; Run() receives it as p_shared.
73 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
74
75 GridwiseGemm::template Run<HasMainKBlockLoop>(
76 p_a_grid,
77 p_b_grid,
78 p_c_grid,
79 p_bias_grid,
80 p_d0_grid,
81 p_reduces_grid,
82 p_shared,
83 a_element_op,
84 b_element_op,
85 c_element_op,
86 c1_element_op,
87 reduce_in_element_ops,
88 reduce_out_element_ops,
89 a_grid_desc_ak0_m_ak1,
90 b_grid_desc_bk0_n_bk1,
91 c_grid_desc_mblock_mperblock_nblock_nperblock,
92 c0_grid_desc_mblock_mperblock_nblock_nperblock,
93 c1_grid_desc_mblock_mperblock_nblock_nperblock,
94 reduce_grid_desc_mblock_mperblock,
95 block_2_ctile_map);
96 }
97#else
// Target not covered by the guard above: reference every parameter through
// `ignore` so the otherwise empty body does not leave them unused.
98 ignore = p_a_grid;
99 ignore = p_b_grid;
100 ignore = p_c_grid;
101 ignore = p_bias_grid;
102 ignore = p_d0_grid;
103 ignore = p_reduces_grid;
104 ignore = a_element_op;
105 ignore = b_element_op;
106 ignore = c_element_op;
107 ignore = c1_element_op;
108 ignore = reduce_in_element_ops;
109 ignore = reduce_out_element_ops;
110 ignore = a_grid_desc_ak0_m_ak1;
111 ignore = b_grid_desc_bk0_n_bk1;
112 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
113 ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
114 ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
115 ignore = reduce_grid_desc_mblock_mperblock;
116 ignore = block_2_ctile_map;
117#endif // end of supported-target check (gfx908 / gfx90a / gfx94 / gfx11 / gfx12)
118}
119
120template <typename FloatAB,
121 typename FloatGemmAcc,
122 typename FloatCShuffle,
123 typename FloatC,
124 typename FloatC0,
125 typename FloatC1,
126 typename FloatReduceAcc,
127 typename ReducePtrsGlobal,
128 typename AElementwiseOperation,
129 typename BElementwiseOperation,
130 typename CElementwiseOperation,
131 typename C1ElementwiseOperation,
132 typename ReduceOperations,
133 typename ReduceInElementwiseOperations,
134 typename ReduceAccElementwiseOperations,
135 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
136 typename ReduceGlobalMemoryDataOperation,
137 typename AGridDesc_AK0_M_AK1,
138 typename BGridDesc_BK0_N_BK1,
139 typename CGridDesc_M_N,
140 typename C0GridDesc_M_N,
141 typename C1GridDesc_M_N,
142 typename ReduceGridDesc_M,
143 index_t NumGemmKPrefetchStage,
144 index_t BlockSize,
145 index_t MPerBlock,
146 index_t NPerBlock,
147 index_t KPerBlock,
148 index_t AK1Value,
149 index_t BK1Value,
150 index_t MPerXdl,
151 index_t NPerXdl,
152 index_t MXdlPerWave,
153 index_t NXdlPerWave,
154 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
155 typename ABlockTransferThreadClusterArrangeOrder,
156 typename ABlockTransferSrcAccessOrder,
157 index_t ABlockTransferSrcVectorDim,
158 index_t ABlockTransferSrcScalarPerVector,
159 index_t ABlockTransferDstScalarPerVector_AK1,
160 bool AThreadTransferSrcResetCoordinateAfterRun,
161 index_t ABlockLdsExtraM,
162 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
163 typename BBlockTransferThreadClusterArrangeOrder,
164 typename BBlockTransferSrcAccessOrder,
165 index_t BBlockTransferSrcVectorDim,
166 index_t BBlockTransferSrcScalarPerVector,
167 index_t BBlockTransferDstScalarPerVector_BK1,
168 bool BThreadTransferSrcResetCoordinateAfterRun,
169 index_t BBlockLdsExtraN,
170 index_t CShuffleMXdlPerWavePerShuffle,
171 index_t CShuffleNXdlPerWavePerShuffle,
172 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
173 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
174 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
175 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
176 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
177 LoopScheduler LoopSched,
180{
181 static constexpr auto I0 = Number<0>{};
182 static constexpr auto I1 = Number<1>{};
183 static constexpr auto I2 = Number<2>{};
184 static constexpr auto I3 = Number<3>{};
185 static constexpr auto I4 = Number<4>{};
186 static constexpr auto I5 = Number<5>{};
187 static constexpr auto I6 = Number<6>{};
188 static constexpr auto I7 = Number<7>{};
189
190 // K1 should be Number<...>
191 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
192 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
193 static constexpr auto AK1 = Number<AK1Value>{};
194 static constexpr auto BK1 = Number<BK1Value>{};
195
197
200
201 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
202 {
203 // A matrix in LDS memory, dst of blockwise copy
207 }
208
209 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
210 {
211 // B matrix in LDS memory, dst of blockwise copy
215 }
216
217 __host__ __device__ static constexpr auto
219 {
220 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
221 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
222
223 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
227 I1,
229
230 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
231 }
232
// Returns the number of bytes of LDS the kernel must reserve for one block.
//
// Two footprints are computed and the larger one is returned:
//   * the A tile plus the B tile, each rounded up to a multiple of
//     lcm(AK1, BK1) elements, times sizeof(FloatAB);
//   * the C-shuffle tile, times sizeof(FloatCShuffle).
// Run() places both the A/B tiles and the C-shuffle tile at the same p_shared
// base address, which is why the maximum (not the sum) is sufficient.
233 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
234 {
235 // LDS allocation for A and B: be careful of alignment
236 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
237 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
238
239 // lds max alignment (in elements): least common multiple of the two K1 vector widths
240 constexpr auto max_lds_align = math::lcm(AK1, BK1);
241
242 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
243 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
244
245 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
246 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
247
248 // LDS space needed by the C-shuffle tile
249 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
251
252 constexpr auto c_block_size =
253 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
254
// Element counts are converted to bytes here; the two terms use different element types.
255 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
256 sizeof(FloatAB),
257 c_block_size * sizeof(FloatCShuffle));
258 }
259
// Compile-time predicate used by the kernel entry point to decide whether the
// body should be instantiated at all.
//
// Forwards the block/tile tuning parameters, the output element type FloatC and
// the C memory operation to
// ck::tensor_operation::device::IsValidGemmCompilationParameter and returns its
// result.
//
// NOTE(review): the function-level template parameter CGlobalMemoryDataOperation_
// (default Set) is not referenced in the body; the class-level
// CGlobalMemoryDataOperation is what gets forwarded. Confirm this is intentional.
260 template <
261 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
262 __device__ static bool constexpr IsValidCompilationParameter()
263 {
264 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
265 BlockSize,
266 MPerBlock,
267 NPerBlock,
268 MPerXdl,
269 NPerXdl,
270 MXdlPerWave,
271 NXdlPerWave,
272 FloatC,
273 CGlobalMemoryDataOperation>();
274 }
275
276 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
//
// Checks whether this gridwise GEMM can run on the given problem.
//
// a_grid_desc_ak0_m_ak1: A descriptor; length(I1) is read as M and
//                        length(I0) * length(I2) as K.
// b_grid_desc_bk0_n_bk1: B descriptor; length(I1) is read as N.
// c_grid_desc_m_n:       C descriptor; its two lengths must equal (M, N).
// block_2_ctile_map:     block-to-tile map; its own CheckValidity must accept
//                        c_grid_desc_m_n.
//
// Returns true only if every check below passes. Note that padding is not
// handled here: M, N and K must be exact multiples of the block tile sizes.
277 template <typename Block2CTileMap>
278 __host__ __device__ static constexpr bool
279 CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
280 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
281 const CGridDesc_M_N& c_grid_desc_m_n,
282 const Block2CTileMap& block_2_ctile_map)
283 {
284 // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
285 // is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
286 // "wrong! K1 need to be known at compile-time");
287
// The block tile must split evenly into per-wave XDL tiles in both dimensions.
288 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
289 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
290 "Invalid tuning param!");
291
292 const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
293 const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
294 const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
295
// C must have exactly the M x N shape implied by A and B.
296 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
297 return false;
298
// No partial tiles: every problem dimension must be a multiple of its block size.
299 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
300 return false;
301
302 // check gridwise gemm pipeline
303 const auto num_k_loop = K / KPerBlock;
304
305 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
306 {
307 return false;
308 }
309
// Finally defer to the tile map's own validity check for this C descriptor.
310 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
311 {
312 return false;
313 }
314
315 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
316 return true;
317 }
318
// Tells the caller which HasMainKBlockLoop value to instantiate the kernel with.
//
// K: total length of the reduction (K) dimension.
// Returns GridwiseGemmPipe::CalculateHasMainLoop applied to the number of
// K-block iterations, K / KPerBlock (integer division; CheckValidity already
// requires K to be a multiple of KPerBlock).
319 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
320 {
321 const index_t num_loop = K / KPerBlock;
322
323 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
324 }
325
326 template <typename CGridDesc_M_N_>
327 __host__ __device__ static constexpr auto
328 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n)
329 {
330 const auto M = c_grid_desc_m_n.GetLength(I0);
331 const auto N = c_grid_desc_m_n.GetLength(I1);
332
333 const auto MBlock = M / MPerBlock;
334 const auto NBlock = N / NPerBlock;
335
336 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
337 c_grid_desc_m_n,
342
343 return c_grid_desc_mblock_mperblock_nblock_nperblock;
344 }
345
346 __host__ __device__ static constexpr auto
347 MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
348 {
349 const auto M = d_grid_desc_m.GetLength(I0);
350 const auto MBlock = M / MPerBlock;
351
352 const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
353 d_grid_desc_m,
357
358 return reduce_grid_desc_mblock_mperblock;
359 }
360
361 // return block_id to C matrix tile idx (m0, n0) mapping
362 __host__ __device__ static constexpr auto
363 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
364 {
366 c_grid_desc_m_n);
367 }
368
371 CGridDesc_M_N{}))>;
372
375 C0GridDesc_M_N{}))>;
376
379 C1GridDesc_M_N{}))>;
380
382 remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
383
385 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
386
387 template <bool HasMainKBlockLoop, typename Block2CTileMap>
388 __device__ static void
389 Run(const FloatAB* __restrict__ p_a_grid,
390 const FloatAB* __restrict__ p_b_grid,
391 FloatC* __restrict__ p_c_grid,
392 const FloatC0* __restrict__ p_bias_grid,
393 const FloatC1* __restrict__ p_d0_grid,
394 ReducePtrsGlobal p_reduces_grid,
395 void* __restrict__ p_shared,
396 const AElementwiseOperation& a_element_op,
397 const BElementwiseOperation& b_element_op,
398 const CElementwiseOperation& c_element_op,
399 const C1ElementwiseOperation& c1_element_op,
400 const ReduceInElementwiseOperations& reduce_in_element_ops,
401 const ReduceAccElementwiseOperations& reduce_out_element_ops,
402 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
403 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
405 c_grid_desc_mblock_mperblock_nblock_nperblock,
407 c0_grid_desc_mblock_mperblock_nblock_nperblock,
409 c1_grid_desc_mblock_mperblock_nblock_nperblock,
410 const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
411 const Block2CTileMap& block_2_ctile_map)
412 {
413 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
414 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
415 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
416 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
418 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
420 p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
422 p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
423
424 // divide block work by [M, N]
425 const auto block_work_idx =
426 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
427
428 if(!block_2_ctile_map.ValidCTileIndex(
429 block_work_idx,
430 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
431 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
432 {
433 return;
434 }
435
436 // HACK: this force m/n_block_data_idx_on_grid into SGPR
437 const index_t m_block_data_idx_on_grid =
438 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
439
440 const index_t n_block_data_idx_on_grid =
441 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
442
443 // lds max alignment
444 constexpr auto max_lds_align = math::lcm(AK1, BK1);
445
446 // A matrix in LDS memory, dst of blockwise copy
447 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
448
449 // B matrix in LDS memory, dst of blockwise copy
450 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
451
452 // A matrix blockwise copy
453 auto a_blockwise_copy =
455 AElementwiseOperation,
459 ABlockTransferThreadClusterLengths_AK0_M_AK1,
460 ABlockTransferThreadClusterArrangeOrder,
461 FloatAB,
462 FloatAB,
463 decltype(a_grid_desc_ak0_m_ak1),
464 decltype(a_block_desc_ak0_m_ak1),
465 ABlockTransferSrcAccessOrder,
467 ABlockTransferSrcVectorDim,
468 2,
469 ABlockTransferSrcScalarPerVector,
470 ABlockTransferDstScalarPerVector_AK1,
471 1,
472 1,
473 AThreadTransferSrcResetCoordinateAfterRun,
474 true,
475 NumGemmKPrefetchStage>(
476 a_grid_desc_ak0_m_ak1,
477 make_multi_index(0, m_block_data_idx_on_grid, 0),
478 a_element_op,
479 a_block_desc_ak0_m_ak1,
480 make_multi_index(0, 0, 0),
482
483 // B matrix blockwise copy
484 auto b_blockwise_copy =
486 BElementwiseOperation,
490 BBlockTransferThreadClusterLengths_BK0_N_BK1,
491 BBlockTransferThreadClusterArrangeOrder,
492 FloatAB,
493 FloatAB,
494 decltype(b_grid_desc_bk0_n_bk1),
495 decltype(b_block_desc_bk0_n_bk1),
496 BBlockTransferSrcAccessOrder,
498 BBlockTransferSrcVectorDim,
499 2,
500 BBlockTransferSrcScalarPerVector,
501 BBlockTransferDstScalarPerVector_BK1,
502 1,
503 1,
504 BThreadTransferSrcResetCoordinateAfterRun,
505 true,
506 NumGemmKPrefetchStage>(
507 b_grid_desc_bk0_n_bk1,
508 make_multi_index(0, n_block_data_idx_on_grid, 0),
509 b_element_op,
510 b_block_desc_bk0_n_bk1,
511 make_multi_index(0, 0, 0),
513
514 // GEMM definition
515 // c_mtx += transpose(a_mtx) * b_mtx
516 // a_mtx[K0PerBlock, MPerBlock] is in LDS
517 // b_mtx[K0PerBlock, NPerBlock] is in LDS
518 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
519 // register
520 // sanity check
521 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
522 constexpr bool is_single_rate_mfma =
524 lcm_AK1_BK1 <= 4) ||
525 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
527 lcm_AK1_BK1 < 32))
528 ? true
529 : false;
530 constexpr auto is_scale_mfma = false;
531 constexpr index_t KPack = math::max(
532 lcm_AK1_BK1,
534 selected_mfma.k_per_blk);
535
537 BlockSize,
538 FloatAB,
539 FloatAB,
540 FloatGemmAcc,
541 decltype(a_block_desc_ak0_m_ak1),
542 decltype(b_block_desc_bk0_n_bk1),
543 MPerXdl,
544 NPerXdl,
545 MXdlPerWave,
546 NXdlPerWave,
547 KPack,
548 LoopSched>();
549
550 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
551
552 // LDS allocation for A and B: be careful of alignment
553 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
554 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
555
557 static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
558
560 static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
561 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
562
563 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
564 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
565
566 // gridwise GEMM pipeline
567 const auto gridwise_gemm_pipeline =
569
570 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
571 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
572 KPerBlock);
573
574 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
575 a_block_desc_ak0_m_ak1,
576 a_blockwise_copy,
577 a_grid_buf,
578 a_block_buf,
579 a_block_slice_copy_step,
580 b_grid_desc_bk0_n_bk1,
581 b_block_desc_bk0_n_bk1,
582 b_blockwise_copy,
583 b_grid_buf,
584 b_block_buf,
585 b_block_slice_copy_step,
586 blockwise_gemm,
587 c_thread_buf,
588 num_k_block_main_loop);
589
590 // shuffle C + reduction + write out
591 {
592 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
593 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
594 "wrong!");
595
596 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
597 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
598
599 // TODO: hacky, fix it!
600 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
601 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
602
603 // TODO: hacky, fix it!
604 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
605 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
606 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
607
608 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
609 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
610 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
611 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
612 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
613 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
614 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
615 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
616
617 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
619
620 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
621 static_cast<FloatCShuffle*>(p_shared),
622 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
623
624 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
625 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
629 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
630 M1, // M1 = MWave
631 M2, // M2 * M3 * M4 = MPerXdl
632 M3,
633 M4)),
636 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
637 N1, // N1 = NWave
638 N2))), // N2 = NPerXdl
642
643 // calculate origin of thread output tensor on global memory
644 // blockwise GEMM c matrix starting index
645 const auto c_thread_mtx_on_block =
646 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
647
648 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
649 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
650
651 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
653 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
656
657 const auto m_thread_data_on_block_idx =
658 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
659 make_multi_index(m_thread_data_on_block));
660
661 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
666
667 const auto n_thread_data_on_block_idx =
668 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
669 make_multi_index(n_thread_data_on_block));
670
671 // shuffle: threadwise copy C from VGPR to LDS
672 auto c_thread_copy_vgpr_to_lds =
674 FloatCShuffle,
675 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
676 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
678 Sequence<CShuffleMXdlPerWavePerShuffle,
679 CShuffleNXdlPerWavePerShuffle,
680 I1,
681 I1,
682 M2,
683 I1,
684 M4,
685 I1>,
687 7,
688 1,
690 1,
691 true>{
692 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
694 0,
695 m_thread_data_on_block_idx[I1],
696 n_thread_data_on_block_idx[I1],
697 m_thread_data_on_block_idx[I2],
698 m_thread_data_on_block_idx[I3],
699 m_thread_data_on_block_idx[I4],
700 n_thread_data_on_block_idx[I2]),
702
703 // space filling curve for threadwise C in VGPR
704 constexpr auto sfc_c_vgpr =
707 Sequence<CShuffleMXdlPerWavePerShuffle,
708 CShuffleNXdlPerWavePerShuffle,
709 1,
710 1,
711 M2,
712 1,
713 M4,
714 1>>{};
715
716 // space filling curve for shuffled blockwise C in global mem
717 constexpr auto sfc_c_global =
720 Sequence<1,
721 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
722 1,
723 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
724
725 // TODO: this should be implemented as a blockwise reduction
726 // LDS c_reduce_block_desc_mperblock_nperblock
727 constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
728 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
732 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
735 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
738
739 static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
740 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
741 BlockSize,
742 "wrong!");
743
744 static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
745 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
746 0 &&
747 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
748 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
749 0,
750 "wrong!");
751
752 constexpr index_t mreduce_per_thread =
753 (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
754 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);
755
756 constexpr index_t nreduce_per_thread =
757 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
758 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);
759
760 constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
762
763 // VGPR c_reduce_thread_desc_mperblock_nperblock
764 constexpr auto c_reduce_thread_desc_mperblock_nperblock =
767
768 // VGPR reduce_thread_desc_mperblock
769 constexpr auto reduce_thread_desc_mperblock =
771
772 // VGPR reduce_thread_desc_mblock_mperblock
773 constexpr auto reduce_thread_desc_mblock_mperblock =
775
777 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
778
779 // reduce: threadwise copy from LDS to VGPR
780 constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
781 CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
782
783 const auto c_reduce_thread_cluster_idx =
784 c_reduce_thread_cluster_desc.CalculateBottomIndex(
786
787 const auto c_reduce_thread_data_idx_begin =
788 c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
789
790 auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
791 FloatCShuffle,
792 FloatReduceAcc,
793 decltype(c_reduce_block_desc_mperblock_nperblock),
794 decltype(c_reduce_thread_desc_mperblock_nperblock),
795 decltype(c_reduce_thread_lengths_mperblock_nperblock),
797 1,
798 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
799 1,
800 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
801
802 auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
803 [&](auto I) {
804 auto p_reduce_grid = p_reduces_grid[I];
805 auto reduce_acc_element_op = reduce_out_element_ops[I];
806
808 FloatReduceAcc,
809 remove_pointer_t<decltype(p_reduce_grid)>,
810 decltype(reduce_thread_desc_mblock_mperblock),
811 decltype(reduce_grid_desc_mblock_mperblock),
812 decltype(reduce_acc_element_op),
815 1,
816 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
817 ReduceGlobalMemoryDataOperation::At(I),
818 1,
819 false>{reduce_grid_desc_mblock_mperblock,
820 make_multi_index(block_work_idx[I0], // mblock
821 c_reduce_thread_data_idx_begin[I0]), // mperblock
822 reduce_acc_element_op};
823 },
824 Number<p_reduces_grid.Size()>{});
825
826 // c0 and c1
827 constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
830
831 constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
832 c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock;
833
835 c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
836
837 auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
838 FloatC0,
839 FloatReduceAcc,
840 decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
841 decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
844 3,
845 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
846 1,
847 true>(
848 c0_grid_desc_mblock_mperblock_nblock_nperblock,
850 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
851 I0,
852 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
853
854 auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
855 FloatC1,
856 FloatReduceAcc,
857 decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock),
858 decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
861 3,
862 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
863 1,
864 true>(
865 c1_grid_desc_mblock_mperblock_nblock_nperblock,
867 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
868 I0,
869 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
870
871 constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
874
875 auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
876 FloatReduceAcc,
877 FloatC,
878 decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
879 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
882 Sequence<0, 1, 2, 3>, // DimAccessOrder
883 3, // DstVectorDim
884 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
886 1,
887 true>{
888 c_grid_desc_mblock_mperblock_nblock_nperblock,
890 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
891 I0,
892 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
894
895 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
896
897 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
898
899 static_for<0, num_access, 1>{}([&](auto access_id) {
900 // each thread write its data from VGPR to LDS
901 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
902 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
903 c_thread_buf,
904 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
905 c_shuffle_block_buf);
906
907 // make sure it's safe to write to LDS
909 {
910 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
911 c_shuffle_block_buf,
912 c_reduce_thread_desc_mperblock_nperblock,
913 make_tuple(I0, I0),
914 c_reduce_thread_buf);
915
916 c0_thread_copy_global_to_vgpr.Run(
917 c0_grid_desc_mblock_mperblock_nblock_nperblock,
918 c0_grid_buf,
919 c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
920 make_tuple(I0, I0, I0, I0),
921 c01_thread_buf);
922
923 // c = activation(c + bias)
924 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
925 [&](auto i) {
926 FloatReduceAcc out;
927 c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
928 c_reduce_thread_buf(i) = out;
929 });
930
931 c1_thread_copy_global_to_vgpr.Run(
932 c1_grid_desc_mblock_mperblock_nblock_nperblock,
933 c1_grid_buf,
934 c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
935 make_tuple(I0, I0, I0, I0),
936 c01_thread_buf);
937
938 // c = c + c1_functor(c1)
939 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
940 [&](auto i) {
941 c1_element_op(c01_thread_buf(i), c01_thread_buf(i));
942 c_reduce_thread_buf(i) += c01_thread_buf(i);
943 });
944
945 c_reduce_thread_copy_vgpr_to_global.Run(
946 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
947 make_tuple(I0, I0, I0, I0),
948 c_reduce_thread_buf,
949 c_grid_desc_mblock_mperblock_nblock_nperblock,
950 c_grid_buf);
951
952 static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
953 auto& p_reduce_grid = p_reduces_grid[In];
954
956 p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
957
958 auto reduce_thread_buf =
960 reduce_thread_desc_mperblock.GetElementSpaceSize());
961
962 auto& reduce_in_element_op = reduce_in_element_ops[In];
963
964 auto& reduce_thread_copy_vgpr_to_global =
965 reduce_tuple_thread_copy_vgpr_to_global(In);
966
967 using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
968 using ThreadwiseReduce =
969 ThreadwiseReduction<FloatReduceAcc,
970 decltype(c_reduce_thread_desc_mperblock_nperblock),
971 decltype(reduce_thread_desc_mperblock),
972 ReduceOperation,
973 false>;
974
975 // Global write Gemm shuffle + reduction
976 const auto reduce_identityVal =
977 ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
978
980 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
981
982 // reduce in VGPR
985 constexpr auto offset =
986 Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
987 make_tuple(im, in))>{};
988
989 reduce_in_element_op(c_reduce_thread_buf(offset),
990 c_reduce_thread_buf(offset));
991 });
992 });
993
994 ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
995
996 // copy from VGPR to Global
997 reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
998 make_tuple(I0, I0),
999 reduce_thread_buf,
1000 reduce_grid_desc_mblock_mperblock,
1001 reduce_grid_buf);
1002
1003 if constexpr(access_id < num_access - 1)
1004 {
1005 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1006 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
1007 reduce_grid_desc_mblock_mperblock,
1008 make_tuple(c_global_step[I0], c_global_step[I1]));
1009 }
1010 });
1011 }
1012
1013 if constexpr(access_id < num_access - 1)
1014 {
1015 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1016
1017 // move on C
1018 c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
1019 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1020
1021 // move on C0
1022 c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
1023 c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1024
1025 // move on C1
1026 c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
1027 c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1028 }
1029 });
1030 } // Reduction
1031 }
1032};
1033
1034} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_bias_grid, const FloatC1 *__restrict__ p_d0_grid, ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const C1ElementwiseOperation c1_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c0_grid_desc_mblock_mperblock_nblock_nperblock, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:45
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:180
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:384
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::AK1
static constexpr auto AK1
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:193
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::MakeReduceGridDescriptor_MBlock_MPerBlock
__host__ static __device__ constexpr auto MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M &d_grid_desc_m)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:347
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:279
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I4
static constexpr auto I4
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:185
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::ThisThreadBlock
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:196
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::BK0
static constexpr auto BK0
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:192
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:319
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::GridwiseGemmPipe
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVersion::v1, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:198
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I0
static constexpr auto I0
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:181
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))> C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:373
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::AK0
static constexpr auto AK0
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:191
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::BK1
static constexpr auto BK1
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:194
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I5
static constexpr auto I5
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:186
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::MakeDefaultBlock2CTileMap
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:363
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
__host__ static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:209
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I3
static constexpr auto I3
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:184
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::GetSharedMemoryNumberOfByte
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:233
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
__host__ static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:201
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))> C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:377
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I7
static constexpr auto I7
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:188
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::ReduceGridDescriptor_MBlock_MPerBlock
remove_cvref_t< decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))> ReduceGridDescriptor_MBlock_MPerBlock
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:381
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:218
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::IsValidCompilationParameter
static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:262
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:328
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I1
static constexpr auto I1
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:182
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const BiasDataType *__restrict__ p_bias_grid, const D0DataType *__restrict__ p_d0_grid, ReducePtrsGlobal p_reduces_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const D0ElementwiseOperation &c1_element_op, const ReduceInElementwiseOperations &reduce_in_element_ops, const ReduceAccElementwiseOperations &reduce_out_element_ops, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c0_grid_desc_mblock_mperblock_nblock_nperblock, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c1_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock &reduce_grid_desc_mblock_mperblock, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:389
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I6
static constexpr auto I6
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:187
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:369
ck::GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >< math::max(NXdlPerWave64, 1)>::I2
static constexpr auto I2
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:183
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes.
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:276
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:389
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340