block_universal_gemm_as_bs_cr.hpp Source File

block_universal_gemm_as_bs_cr.hpp Source File#

Composable Kernel: block_universal_gemm_as_bs_cr.hpp Source File
block_universal_gemm_as_bs_cr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14// A is block window on shared memory
15// B is block window on shared memory
16// C is block distributed tensor
17template <typename Problem_,
19 index_t UnaryOpSize_ = 8>
21{
22 private:
23 // TODO: This should be in Policy - UniversalGemmPolicyBase ?
// Compile-time traits for the block GEMM, derived from the pipeline problem and
// the GEMM policy: block tile extents, the warp-level GEMM and warp layout
// returned by the policy, and the per-warp iteration counts computed from them.
// NOTE(review): `Problem` and `BlockGemmShape` are used below, but their
// declarations are not visible in this excerpt - confirm against the full header.
24 template <typename PipelineProblem_, typename GemmPolicy_>
25 struct GemmTraits_
26 {
28 using Policy = remove_cvref_t<GemmPolicy_>;
34
35 static constexpr index_t kBlockSize = Problem::kBlockSize;
36 static constexpr auto Scheduler = Problem::Scheduler;
37
// Block tile extents along M, N and K, taken from BlockGemmShape.
38 static constexpr index_t MPerBlock = BlockGemmShape::kM;
39 static constexpr index_t NPerBlock = BlockGemmShape::kN;
40 static constexpr index_t KPerBlock = BlockGemmShape::kK;
41
// The policy returns a container whose element 0 is the warp GEMM object and
// whose elements 1 and 2 are the warp counts along M and N.
42 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
43
44 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
45
46 static constexpr index_t MWarp = config.template at<1>();
47 static constexpr index_t NWarp = config.template at<2>();
48
49 using I0 = number<0>;
50 using I1 = number<1>;
51
// The policy's warp layout and warp tile must agree with BlockGemmShape.
52 static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
53 "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
54 static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
55 "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
56 static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
57 "Error! WarpGemm's M is not consisten with BlockGemmShape!");
58 static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
59 "Error! WarpGemm's N is not consisten with BlockGemmShape!");
60
// Number of warp-GEMM steps each warp performs along M, N and K to cover the
// block tile (integer division; exactness for M and N is asserted just below).
61 static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
62 static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
63 static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
64
// NOTE(review): only M and N are checked for exact coverage; no equivalent
// check that KPerBlock is a multiple of WarpGemm::kK is visible here.
65 static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
66 "Error! Warps should cover all Block tile!");
67 static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
68 "Error! Warps should cover all Block tile!");
69
// Extent of the block tile covered by one iteration along each dimension.
70 static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
71 static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
72 static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
73
74 // Controls how many MAC clusters (MFMA blocks) we have per wave.
75 // I.e., if
76 // InterWaveSchedulingMacClusters = 1;
77 // KPerBlock == 32
78 // WarpGemm::kK = 8
79 // then we would group all 4 WarpGemms into a single MAC cluster.
80 // But if we set InterWaveSchedulingMacClusters = 2, then we would
81 // split those 4 warp gemms into two groups.
82 static constexpr index_t InterWaveSchedulingMacClusters = 1;
83
84 // should be at least equal to: WarpGemm::Impl::kABKPerLane
85 static constexpr index_t KPack = WarpGemm::kKPerThread;
// K elements handled per thread across all K iterations of the warp GEMM.
86 static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
87 };
88
89 public:
90 using Traits = GemmTraits_<Problem_, Policy_>;
91
96
99
100 static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
101 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
102 static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
103
104 static constexpr index_t MWarp = Traits::MWarp;
105 static constexpr index_t NWarp = Traits::NWarp;
106
107 static constexpr auto Scheduler = Traits::Scheduler;
108
109 using AWarpDstr = typename WarpGemm::AWarpDstr;
110 using BWarpDstr = typename WarpGemm::BWarpDstr;
111 using CWarpDstr = typename WarpGemm::CWarpDstr;
112
113 using AWarpTensor = typename WarpGemm::AWarpTensor;
114 using BWarpTensor = typename WarpGemm::BWarpTensor;
115 using CWarpTensor = typename WarpGemm::CWarpTensor;
116
117 static constexpr auto a_warp_y_lengths =
118 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
119 static constexpr auto b_warp_y_lengths =
120 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
121 static constexpr auto c_warp_y_lengths =
122 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
123
127
128 static constexpr index_t APackedSize =
130 static constexpr index_t BPackedSize =
132
133 using I0 = number<0>;
134 using I1 = number<1>;
135
137 {
138 constexpr index_t KPerThread = Traits::KPerThread;
139 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
140 constexpr index_t KPerInnerLoop =
141 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
142 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
143
144 using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
147
148 constexpr auto a_block_outer_dstr_encoding =
155 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
156 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
157
158 return a_block_dstr_encode;
159 }
160
162 {
163 constexpr index_t KPerThread = Traits::KPerThread;
164 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
165 constexpr index_t KPerInnerLoop =
166 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
167 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
168
169 using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
172
173 constexpr auto b_block_outer_dstr_encoding =
180 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
181 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
182
183 return b_block_dstr_encode;
184 }
185
186 template <GemmPipelineScheduler Scheduler, typename GemmTraits>
188 {
189 };
190
191 template <typename GemmTraits>
193 {
198
201
204
205 // C += A * B
206 template <typename CBlockTensor,
207 typename ASmemBlockWindow,
208 typename BSmemBlockWindow,
209 bool ALoadTranspose = false,
210 bool BLoadTranspose = false>
211 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
212 const ASmemBlockWindow& a_block_window,
213 const BSmemBlockWindow& b_block_window,
216 {
217 static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
218 "The CDataType as defined in traits should be the same as correspoinding "
219 "C block tensor data type!");
220 static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
221 std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
222 "The ADataType and BDataType as defined in "
223 "traits should be the same as correspoinding block window data type!");
224
225 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
226 {
227 Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
228 }
229 else
230 {
231 load_tile(a_warp_tile_, a_block_window);
232 }
233 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
234 {
235 Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
236 }
237 else
238 {
239 load_tile(b_warp_tile_, b_block_window);
240 }
241 // hot loop:
242 static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
243 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
244 // read A warp tensor from A block tensor
245 AWarpTensor a_warp_tensor;
246
247 a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
248 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
249 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
250
251 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
252 // read B warp tensor from B block tensor
253 BWarpTensor b_warp_tensor;
254
255 b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
256 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
257 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
258
259 // read C warp tensor from C block tensor
260 CWarpTensor c_warp_tensor;
261
262 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
263 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
264 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
265
266 // warp GEMM
267 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
268
269 // write C warp tensor into C block tensor
270 c_block_tensor.set_y_sliced_thread_data(
271 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
272 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
273 c_warp_tensor.get_thread_buffer());
274 });
275 });
276 });
277 }
278 };
279
280 template <typename GemmTraits>
282 {
287
290
293
294 template <typename ASmemBlockWindow,
295 typename BSmemBlockWindow,
296 bool ALoadTranspose = false,
297 bool BLoadTranspose = false>
298 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
299 const BSmemBlockWindow& b_block_window,
302 {
303 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
304 {
305 Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
306 }
307 else if constexpr(ALoadTranspose)
308 {
309 a_warp_tile_ = load_tile_transpose(a_block_window);
310 }
311 else
312 {
313 load_tile(a_warp_tile_, a_block_window);
314 }
315 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
316 {
317 Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
318 }
319 else if constexpr(BLoadTranspose)
320 {
321 b_warp_tile_ = load_tile_transpose(b_block_window);
322 }
323 else
324 {
325 load_tile(b_warp_tile_, b_block_window);
326 }
327 }
328
329 // C += A * B
330 template <typename CBlockTensor,
331 typename ASmemBlockWindow,
332 typename BSmemBlockWindow,
333 bool ALoadTranspose = false,
334 bool BLoadTranspose = false>
335 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
336 const ASmemBlockWindow&,
337 const BSmemBlockWindow&,
340 {
341 static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
342 "The CDataType as defined in traits should be the same as correspoinding "
343 "C block tensor data type!");
344
345 // hot loop:
346 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
347 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
348 // read A warp tensor from A block tensor
349 AWarpTensor a_warp_tensor;
350
351 a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
352 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
353 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
354
355 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
356 // read B warp tensor from B block tensor
357 BWarpTensor b_warp_tensor;
358
359 b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
360 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
361 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
362
363 // read C warp tensor from C block tensor
364 CWarpTensor c_warp_tensor;
365
366 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
367 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
368 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
369
370 // warp GEMM
371 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
372
373 // write C warp tensor into C block tensor
374 c_block_tensor.set_y_sliced_thread_data(
375 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
376 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
377 c_warp_tensor.get_thread_buffer());
378 });
379 });
380 });
381 }
382 };
383
384 template <typename GemmTraits>
386 {
387 static constexpr index_t KPerThread = GemmTraits::KPerThread;
388 static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
389 static constexpr index_t KPerInnerLoop =
390 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
392 static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
393
394 static constexpr auto ALdsTileDistr =
396 static constexpr auto BLdsTileDistr =
398
401
404
405 template <index_t KIdx,
406 typename ASmemBlockWindow,
407 typename BSmemBlockWindow,
408 bool ALoadTranspose = false,
409 bool BLoadTranspose = false>
410 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
411 const BSmemBlockWindow& b_block_window,
414 {
415 constexpr auto a_lds_load_distr = [&]() {
416 if constexpr(ALoadTranspose)
419 ADataType>::TransposedDstrEncode{});
420 else
422 }();
423 constexpr auto b_lds_load_distr = [&]() {
424 if constexpr(BLoadTranspose)
427 BDataType>::TransposedDstrEncode{});
428 else
430 }();
431 constexpr auto a_lds_shape = []() {
432 if constexpr(ALoadTranspose)
434 else
436 }();
437 constexpr auto b_lds_shape = []() {
438 if constexpr(BLoadTranspose)
440 else
442 }();
443 constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
444 constexpr auto a_offset =
445 ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
446 constexpr auto b_offset =
447 BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
448
449 auto a_lds_gemm_window = make_tile_window(
450 a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
451 auto b_lds_gemm_window = make_tile_window(
452 b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
453
454 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
455 {
456 Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
457 }
458 else if constexpr(ALoadTranspose)
459 {
460 a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
461 }
462 else
463 {
464 load_tile(a_warp_tile_, a_lds_gemm_window);
465 }
466 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
467 {
468 Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
469 }
470 else if constexpr(BLoadTranspose)
471 {
472 b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
473 }
474 else
475 {
476 load_tile(b_warp_tile_, b_lds_gemm_window);
477 }
478 }
479
480 // C += A * B
481 template <typename CBlockTensor,
482 typename ASmemBlockWindow,
483 typename BSmemBlockWindow,
484 bool ALoadTranspose = false,
485 bool BLoadTranspose = false>
486 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
487 const ASmemBlockWindow& a_block_window,
488 const BSmemBlockWindow& b_block_window,
489 bool_constant<ALoadTranspose> a_load_tr = {},
490 bool_constant<BLoadTranspose> b_load_tr = {})
491 {
492 static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
493 "The CDataType as defined in traits should be the same as correspoinding "
494 "C block tensor data type!");
495
496 // hot loop:
497 static_for<0, KRepeat, 1>{}([&](auto kIter) {
498 LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
499 __builtin_amdgcn_sched_barrier(0);
500 // NOTE: Synchronize threads in a workgroup at the start of each MAC
501 // cluster except the first, as we can shorten the non-MAC cluster a bit
502 // and there's no observable negative impact. The desired effect is waves in
503 // a workgroup executing MAC in sync. This avoids some out-of-sync waves
504 // hijacking MAC resources from other workgroups and reducing the chance of
505 // latency hiding by waiting for the rest of the workgroup at the eventual
506 // sync point.
507 if constexpr(kIter.value != 0 || KRepeat == 1)
508 {
509 __builtin_amdgcn_s_barrier();
510 __builtin_amdgcn_sched_barrier(0);
511 }
512
513 static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
514 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
515 // read A warp tensor from A block tensor
516 AWarpTensor a_warp_tensor;
517
518 a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
519 merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
520 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
521 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
522 // read B warp tensor from B block tensor
523 BWarpTensor b_warp_tensor;
524
525 b_warp_tensor.get_thread_buffer() =
526 b_warp_tile_.get_y_sliced_thread_data(
527 merge_sequences(sequence<nIter, kInnerIter>{},
529 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
530 // read C warp tensor from C block tensor
531 CWarpTensor c_warp_tensor;
532
533 c_warp_tensor.get_thread_buffer() =
534 c_block_tensor.get_y_sliced_thread_data(
535 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
536 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
537
538 // The block_sync_lds() here performs double duty:
539 // A) safeguard against data hazard, because the barrier from
540 // blockwise_gemm is moved here;
541 // B) reduce VMEM FIFO congestion by applying small delays to
542 // different wavefronts.
543 // It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
544 if constexpr(kIter.value == KRepeat - 1 &&
545 kInnerIter.value == KInnerLoopIter - 1 &&
546 mIter.value == MIterPerWarp - 1 &&
547 nIter.value == NIterPerWarp - 1)
548 {
549 __builtin_amdgcn_sched_barrier(0);
551 __builtin_amdgcn_sched_barrier(0);
552 }
553 // warp GEMM
554 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
555
556 // write C warp tensor into C block tensor
557 c_block_tensor.set_y_sliced_thread_data(
558 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
559 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
560 c_warp_tensor.get_thread_buffer());
561
562 if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
563 nIter.value == 0)
564 {
565 __builtin_amdgcn_sched_barrier(0);
566 __builtin_amdgcn_s_setprio(1);
567 __builtin_amdgcn_sched_barrier(0);
568 }
569 });
570 });
571 });
572
573 __builtin_amdgcn_sched_barrier(0);
574 __builtin_amdgcn_s_setprio(0);
575 __builtin_amdgcn_sched_barrier(0);
576 });
577 }
578 };
579
580 public:
// Creates the C block tile: embeds WarpGemm::CWarpDstrEncoding into an outer
// block-level encoding, turns the combined encoding into a static tile
// distribution, and returns a static distributed tensor of CDataType over it.
// NOTE(review): the template arguments of the outer tile_distribution_encoding
// are not visible in this excerpt - see the full header for the exact layout.
581 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
582 {
583 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
590
// Combine the outer (block-level) encoding with the warp GEMM's C encoding.
591 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
592 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
593 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
594 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
595
596 return c_block_tensor;
597 }
598
// Forwards both block windows and the two compile-time transpose tags to
// block_gemm_impl_.LocalPrefetch(). In the implementations visible in this
// file, that call loads the windows into the impl's a_warp_tile_ / b_warp_tile_.
// a_load_tr / b_load_tr are bool_constant tags defaulting to `{}`; they are
// passed through unchanged.
// NOTE(review): one BlockGemmImpl specialization in this file declares
// LocalPrefetch with a leading `index_t KIdx` template parameter, and another
// shows no LocalPrefetch in the visible lines - confirm this wrapper is only
// instantiated for a scheduler whose impl accepts this exact call.
599 template <typename ASmemBlockWindow,
600 typename BSmemBlockWindow,
601 bool ALoadTranspose = false,
602 bool BLoadTranspose = false>
603 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
604 const BSmemBlockWindow& b_block_window,
605 bool_constant<ALoadTranspose> a_load_tr = {},
606 bool_constant<BLoadTranspose> b_load_tr = {})
607 {
608 block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
609 }
610
611 // C += A * B
// Accumulating form: forwards the caller's C block tensor, both block windows
// and the transpose tags to block_gemm_impl_. The tensor is taken by non-const
// reference and the visible impls write their results back into it, so the
// caller's c_block_tensor is updated in place; nothing is returned.
612 template <typename CBlockTensor,
613 typename ASmemBlockWindow,
614 typename BSmemBlockWindow,
615 bool ALoadTranspose = false,
616 bool BLoadTranspose = false>
617 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
618 const ASmemBlockWindow& a_block_window,
619 const BSmemBlockWindow& b_block_window,
620 bool_constant<ALoadTranspose> a_load_tr = {},
621 bool_constant<BLoadTranspose> b_load_tr = {})
622 {
623 block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
624 }
625
626 // C = A * B
// Convenience form: creates a new C tile with MakeCBlockTile(), runs
// block_gemm_impl_ on it with the given windows and transpose tags, and
// returns the tile by value.
// NOTE(review): the impl call operators in this file are documented as
// "C += A * B" and read C before each warp GEMM; whether the tile returned by
// MakeCBlockTile() starts out zeroed is not visible here - confirm.
627 template <typename ASmemBlockWindow,
628 typename BSmemBlockWindow,
629 bool ALoadTranspose = false,
630 bool BLoadTranspose = false>
631 CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
632 const BSmemBlockWindow& b_block_window,
633 bool_constant<ALoadTranspose> a_load_tr = {},
634 bool_constant<BLoadTranspose> b_load_tr = {})
635 {
636 auto c_block_tensor = MakeCBlockTile();
637 block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
638 return c_block_tensor;
639 }
640
641 private:
642 BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
643};
644
645} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
@ Default
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:15
@ Interwave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:17
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_asmem_bsmem_creg_v1_default_policy.hpp:15
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose >={}, bool_constant< BLoadTranspose >={})
Definition block_universal_gemm_as_bs_cr.hpp:298
decltype(make_static_distributed_tensor< ComputeDataType >(BLdsTileDistr)) BLdsTile
Definition block_universal_gemm_as_bs_cr.hpp:289
static constexpr auto ALdsTileDistr
Definition block_universal_gemm_as_bs_cr.hpp:283
decltype(make_static_distributed_tensor< ComputeDataType >(ALdsTileDistr)) ALdsTile
Definition block_universal_gemm_as_bs_cr.hpp:288
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ASmemBlockWindow &, const BSmemBlockWindow &, bool_constant< ALoadTranspose >={}, bool_constant< BLoadTranspose >={})
Definition block_universal_gemm_as_bs_cr.hpp:335
static constexpr auto BLdsTileDistr
Definition block_universal_gemm_as_bs_cr.hpp:285
static constexpr index_t KPerThread
Definition block_universal_gemm_as_bs_cr.hpp:387
static constexpr index_t NumMacClusters
Definition block_universal_gemm_as_bs_cr.hpp:388
static constexpr index_t KRepeat
Definition block_universal_gemm_as_bs_cr.hpp:391
static constexpr index_t KPerInnerLoop
Definition block_universal_gemm_as_bs_cr.hpp:389
decltype(make_static_distributed_tensor< ComputeDataType >(ALdsTileDistr)) ALdsTile
Definition block_universal_gemm_as_bs_cr.hpp:399
decltype(make_static_distributed_tensor< ComputeDataType >(BLdsTileDistr)) BLdsTile
Definition block_universal_gemm_as_bs_cr.hpp:400
static constexpr index_t KInnerLoopIter
Definition block_universal_gemm_as_bs_cr.hpp:392
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose >={}, bool_constant< BLoadTranspose >={})
Definition block_universal_gemm_as_bs_cr.hpp:410
static constexpr auto BLdsTileDistr
Definition block_universal_gemm_as_bs_cr.hpp:396
static constexpr auto ALdsTileDistr
Definition block_universal_gemm_as_bs_cr.hpp:394
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose > a_load_tr={}, bool_constant< BLoadTranspose > b_load_tr={})
Definition block_universal_gemm_as_bs_cr.hpp:486
decltype(make_static_distributed_tensor< ComputeDataType >(ALdsTileDistr)) ALdsTile
Definition block_universal_gemm_as_bs_cr.hpp:199
decltype(make_static_distributed_tensor< ComputeDataType >(BLdsTileDistr)) BLdsTile
Definition block_universal_gemm_as_bs_cr.hpp:200
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose >={}, bool_constant< BLoadTranspose >={})
Definition block_universal_gemm_as_bs_cr.hpp:211
static constexpr auto BLdsTileDistr
Definition block_universal_gemm_as_bs_cr.hpp:196
static constexpr auto ALdsTileDistr
Definition block_universal_gemm_as_bs_cr.hpp:194
Definition block_universal_gemm_as_bs_cr.hpp:188
Definition block_universal_gemm_as_bs_cr.hpp:21
remove_cvref_t< typename Traits::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_bs_cr.hpp:94
static constexpr index_t NIterPerWarp
Definition block_universal_gemm_as_bs_cr.hpp:102
number< 0 > I0
Definition block_universal_gemm_as_bs_cr.hpp:133
static constexpr auto b_warp_y_index_zeros
Definition block_universal_gemm_as_bs_cr.hpp:125
typename WarpGemm::AWarpTensor AWarpTensor
Definition block_universal_gemm_as_bs_cr.hpp:113
remove_cvref_t< typename Traits::WarpGemm > WarpGemm
Definition block_universal_gemm_as_bs_cr.hpp:98
static constexpr auto a_warp_y_index_zeros
Definition block_universal_gemm_as_bs_cr.hpp:124
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_universal_gemm_as_bs_cr.hpp:161
static constexpr auto b_warp_y_lengths
Definition block_universal_gemm_as_bs_cr.hpp:119
typename WarpGemm::AWarpDstr AWarpDstr
Definition block_universal_gemm_as_bs_cr.hpp:109
typename WarpGemm::CWarpTensor CWarpTensor
Definition block_universal_gemm_as_bs_cr.hpp:115
GemmTraits_< Problem_, Policy_ > Traits
Definition block_universal_gemm_as_bs_cr.hpp:90
static constexpr index_t MWarp
Definition block_universal_gemm_as_bs_cr.hpp:104
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_universal_gemm_as_bs_cr.hpp:92
static constexpr index_t APackedSize
Definition block_universal_gemm_as_bs_cr.hpp:128
static constexpr index_t NWarp
Definition block_universal_gemm_as_bs_cr.hpp:105
typename WarpGemm::BWarpDstr BWarpDstr
Definition block_universal_gemm_as_bs_cr.hpp:110
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose > a_load_tr={}, bool_constant< BLoadTranspose > b_load_tr={})
Definition block_universal_gemm_as_bs_cr.hpp:617
static constexpr index_t MIterPerWarp
Definition block_universal_gemm_as_bs_cr.hpp:101
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_universal_gemm_as_bs_cr.hpp:581
number< 1 > I1
Definition block_universal_gemm_as_bs_cr.hpp:134
typename WarpGemm::CWarpDstr CWarpDstr
Definition block_universal_gemm_as_bs_cr.hpp:111
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose > a_load_tr={}, bool_constant< BLoadTranspose > b_load_tr={})
Definition block_universal_gemm_as_bs_cr.hpp:603
static constexpr auto c_warp_y_index_zeros
Definition block_universal_gemm_as_bs_cr.hpp:126
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window, bool_constant< ALoadTranspose > a_load_tr={}, bool_constant< BLoadTranspose > b_load_tr={})
Definition block_universal_gemm_as_bs_cr.hpp:631
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_universal_gemm_as_bs_cr.hpp:95
static constexpr auto c_warp_y_lengths
Definition block_universal_gemm_as_bs_cr.hpp:121
static constexpr auto a_warp_y_lengths
Definition block_universal_gemm_as_bs_cr.hpp:117
typename WarpGemm::BWarpTensor BWarpTensor
Definition block_universal_gemm_as_bs_cr.hpp:114
static constexpr auto Scheduler
Definition block_universal_gemm_as_bs_cr.hpp:107
static constexpr index_t BPackedSize
Definition block_universal_gemm_as_bs_cr.hpp:130
remove_cvref_t< InterleavedPKTypeLoader< ComputeDataType, UnaryOpSize_ > > Loader
Definition block_universal_gemm_as_bs_cr.hpp:97
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_universal_gemm_as_bs_cr.hpp:136
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_universal_gemm_as_bs_cr.hpp:93
static constexpr index_t KIterPerWarp
Definition block_universal_gemm_as_bs_cr.hpp:100
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192