flatmm_kernel.hpp Source File

flatmm_kernel.hpp Source File

Composable Kernel: flatmm_kernel.hpp Source File
flatmm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
12
13namespace ck_tile {
15{
18 index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
19 : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
20 {
21 }
22
29};
30
31template <int SharedGranularityMN, int SharedGranularityK = 0>
33{
34 static constexpr int GranularityMN = SharedGranularityMN;
35 static constexpr int GranularityK = SharedGranularityK;
36
37 const float* ptr;
38
40 CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
41 CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
42 : ptr(ptr_)
43 {
44 }
45
47 {
49 if constexpr(GranularityMN == 0)
50 {
51 ret.ptr = ptr + offset / GranularityK;
52 }
53 else
54 {
56 }
57 return ret;
58 }
59
60 CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
61};
62
63template <int SharedGranularityMN>
64struct FlatmmScalePointer<SharedGranularityMN, 0>
65{
66 static constexpr int GranularityMN = SharedGranularityMN;
67 static constexpr int GranularityK = 0;
68
69 static_assert(GranularityMN != 0);
70
71 const float* ptr;
73
75 CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
77 : ptr(ptr_), length(length_)
78 {
79 }
80
82 {
84 if constexpr(GranularityMN == 1)
85 {
86 ret.ptr = ptr + offset;
87 ret.length = length - offset;
88 }
89 else
90 {
91 ret.ptr = ptr + offset / GranularityMN;
92 ret.length = length - offset / GranularityMN;
93 }
94 return ret;
95 }
96
98 {
99 // with additional oob check
100 if constexpr(GranularityMN == 1)
101 return i < length ? ptr[i] : 0;
102 else
103 return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
104 }
105};
106
107// shared granularityMN = -1 means no scale
108template <>
110{
111 static constexpr int GranularityMN = -1;
112 static constexpr int GranularityK = 0;
113
114 const float* ptr = nullptr;
115
117 CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
119
125 {
126 return 1; // alway return 1, it doesn't change the result
127 }
128};
129
130template <index_t NumDTensor = 0>
132{
135 const void* b_ptr_,
136 const std::array<const void*, NumDTensor>& ds_ptr_,
137 void* e_ptr_,
138 index_t k_batch_,
139 index_t M_,
140 index_t N_,
141 index_t K_,
142 index_t stride_A_,
143 index_t stride_B_,
144 const std::array<index_t, NumDTensor>& stride_Ds_,
145 index_t stride_E_)
146 : a_ptr(a_ptr_),
147 b_ptr(b_ptr_),
148 ds_ptr(ds_ptr_),
149 e_ptr(e_ptr_),
150 M(M_),
151 N(N_),
152 K(K_),
153 stride_A(stride_A_),
154 stride_B(stride_B_),
155 stride_Ds(stride_Ds_),
156 stride_E(stride_E_),
157 k_batch(k_batch_)
158 {
159 }
160
161 const void* a_ptr;
162 const void* b_ptr;
163 const std::array<const void*, NumDTensor> ds_ptr;
164 union
165 {
166 void* e_ptr;
167 void* c_ptr;
168 };
174 const std::array<index_t, NumDTensor> stride_Ds;
175 union
176 {
179 };
180
182};
183template <class ScaleM = FlatmmScalePointer<-1>,
184 class ScaleN = FlatmmScalePointer<-1>,
185 index_t NumDTensor = 0>
187{
190 const void* b_shuffle_ptr_,
191 const std::array<const void*, NumDTensor>& ds_ptr_,
192 void* c_ptr_,
193 index_t k_batch_,
194 index_t M_,
195 index_t N_,
196 index_t K_,
197 index_t stride_A_,
198 index_t stride_B_,
199 const std::array<index_t, NumDTensor>& stride_Ds_,
200 index_t stride_C_,
201 ScaleM scale_m_ = nullptr,
202 ScaleN scale_n_ = nullptr)
203 : BaseFlatmmHostArgs(a_ptr_,
204 b_shuffle_ptr_,
205 ds_ptr_,
206 c_ptr_,
207 k_batch_,
208 M_,
209 N_,
210 K_,
211 stride_A_,
212 stride_B_,
213 stride_Ds_,
214 stride_C_),
215 scale_m(scale_m_),
216 scale_n(scale_n_)
217 {
218 }
219 ScaleM scale_m = nullptr;
220 ScaleN scale_n = nullptr;
221};
222
223template <int NumberTensor = 0>
226
227template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
229{
230 const void* a_ptr;
231 // const void* b_shuffle_ptr;
232 const void* b_ptr;
233 const std::array<const void*, NumDTensor> ds_ptr;
234 void* e_ptr;
240 std::array<index_t, NumDTensor> stride_Ds;
243 ScaleM scale_m_ptr = nullptr;
244 ScaleN scale_n_ptr = nullptr;
245};
246
247template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
249{
260 static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
261 static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
262
265 // Below type is actually accumulation data type - the output of block GEMM.
267
268 static constexpr index_t NumDTensor = DsDataType::size();
269
270 static constexpr auto I0 = number<0>();
271 static constexpr auto I1 = number<1>();
272 static constexpr auto I2 = number<2>();
273 static constexpr auto I3 = number<3>();
274
275 static_assert(DsLayout::size() == DsDataType::size(),
276 "The size of DsLayout and DsDataType should be the same");
277 // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
278
279 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
280 {
281 // clang-format off
282 return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
283 // clang-format on
284 }
285
286 CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
287 {
288 assert(!UsePersistentKernel);
289 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
290 }
291
292 template <class ScaleM, class ScaleN>
293 CK_TILE_HOST static constexpr auto
294 GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
295 {
296 if constexpr(UsePersistentKernel)
297 {
298 hipDeviceProp_t prop;
299 int deviceId = 0; // default device
300
301 constexpr int block_size = FlatmmKernel::BlockSize().x;
302 int dync_smem_size = 0;
303 int maxActiveBlocksPerCU = 0;
304
305 [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
306
307 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
308 &maxActiveBlocksPerCU,
309 reinterpret_cast<void*>(
310 kentry<1, FlatmmKernel, FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
311 block_size,
312 dync_smem_size);
313
314 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
315 const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
316
317 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
318 // << ", persistent_block_size: " << persistent_block_size
319 // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
320
321 assert(kargs.k_batch == 1);
322 return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
323 }
324 else
325 {
326 return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
327 }
328 }
329
330 CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
331
332 template <class ScaleM, class ScaleN>
333 CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
334 MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
335 {
336 return {hostArgs.a_ptr,
337 hostArgs.b_ptr,
338 hostArgs.ds_ptr,
339 hostArgs.e_ptr,
340 hostArgs.M,
341 hostArgs.N,
342 hostArgs.K,
343 hostArgs.stride_A,
344 hostArgs.stride_B,
345 hostArgs.stride_Ds,
346 hostArgs.stride_E,
347 hostArgs.k_batch,
348 hostArgs.scale_m,
349 hostArgs.scale_n};
350 }
351
353 {
354 return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
355 }
357 {
358 return FlatmmPipeline::GetSmemSize();
359 }
360
362 {
363 template <class KernelArgs>
364 __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
365 {
366 constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
367 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
368 const index_t K_t = kargs.k_batch * K1;
369 const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
370
371 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
372 {
373 a_k_split_offset = k_id * KRead;
374 }
375 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
376 {
377 a_k_split_offset = k_id * KRead * kargs.stride_A;
378 }
379
380 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
381 {
382 b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
383 }
384 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
385 {
386 b_k_split_offset = k_id * KRead * N1;
387 }
388
389 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
390 {
391 splitted_k = KRead;
392 }
393 else
394 {
395 splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
396 }
397 }
398
402 };
403
404 template <class KernelArgs>
405 CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
406 {
407 if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
409 {
410 if(kargs.k_batch != 1)
411 {
412 std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
413 return false;
414 }
415 }
416 if constexpr(UsePersistentKernel)
417 {
418 if(kargs.k_batch != 1)
419 {
420 std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
421 return false;
422 }
423 }
424
425 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
426 {
427 if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
428 {
429 std::cerr << "Can't support K that is not a multiple of KPerBlock"
430 " without padding!"
431 << std::endl;
432 return false;
433 }
434 if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
435 {
436 std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
437 return false;
438 }
439 }
440 else
441 {
442 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
443 {
444 std::cerr << "Can't support M that is not a multiple of MPerBlock"
445 " without padding!"
446 << std::endl;
447 return false;
448 }
449 if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
450 {
451 std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
452 return false;
453 }
454 }
455
456 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
457 {
458 if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
459 {
460 std::cerr << "Can't support N that is not a multiple of NPerBlock"
461 " without padding!"
462 << std::endl;
463 return false;
464 }
465 if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
466 {
467 std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
468 return false;
469 }
470 }
471 else
472 {
473 if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
474 {
475 std::cerr << "Can't support K that is not a multiple of KPerBlock"
476 " without padding!"
477 << std::endl;
478 return false;
479 }
480 if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
481 {
482 std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
483 return false;
484 }
485 }
486
487 bool DTesnorIsValid = {true};
488 static_for<0, NumDTensor, 1>{}([&](auto index) {
489 using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
490 if(std::is_same_v<DiLayout, ELayout> == false)
491 {
492 DTesnorIsValid = false;
493 }
494 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
495 {
496 if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
497 {
498 CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
499 "NPerBlock without padding!");
500 DTesnorIsValid = false;
501 }
502 if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
503 {
504 CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
505 DTesnorIsValid = false;
506 }
507 }
508 else
509 {
510 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
511 {
512 CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
513 "MPerBlock without padding!");
514
515 DTesnorIsValid = false;
516 }
517 if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
518 {
519 CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
520 DTesnorIsValid = false;
521 }
522 }
523 });
524
525 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
526 {
527 if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
528 {
529 std::cerr << "Can't support N that is not a multiple of NPerBlock"
530 " without padding!"
531 << std::endl;
532 return false;
533 }
534 if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
535 {
536 std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
537 return false;
538 }
539 }
540 else
541 {
542 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
543 {
544 std::cerr << "Can't support M that is not a multiple of MPerBlock"
545 " without padding!"
546 << std::endl;
547 return false;
548 }
549 if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
550 {
551 std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
552 return false;
553 }
554 }
555 return DTesnorIsValid;
556 }
557
558 template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
559 CK_TILE_DEVICE static auto
561 const BDataType* b_flat_ptr,
562 const std::array<const void*, NumDTensor>& ds_ptr,
563 EDataType* e_ptr,
564 const KernelArgs& kargs,
565 const SplitKBatchOffset& splitk_batch_offset)
566 {
567 const auto& a_tensor_view = [&]() {
568 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
569 {
571 a_ptr,
572 make_tuple(kargs.M, splitk_batch_offset.splitted_k),
573 make_tuple(kargs.stride_A, 1),
574 number<FlatmmPipeline::GetVectorSizeA()>{},
575 number<1>{});
576 }
577 else
578 {
580 a_ptr,
581 make_tuple(splitk_batch_offset.splitted_k, kargs.M),
582 make_tuple(kargs.stride_A, 1),
583 number<FlatmmPipeline::GetVectorSizeA()>{},
584 number<1>{});
585 }
586 }();
587
588 index_t kFlatK =
589 FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
590 index_t kFlatN = kargs.N * kargs.K / kFlatK;
591 const auto& b_flat_tensor_view = [&]() {
593 b_flat_ptr,
594 make_tuple(kFlatN, kFlatK),
595 make_tuple(kFlatK, 1),
596 number<FlatmmPipeline::GetVectorSizeB()>{},
597 number<1>{});
598 }();
599
600 const auto& ds_tensor_view = generate_tuple(
601 [&](auto i) {
602 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
603 using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
604 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
605 {
607 static_cast<const DDataType_*>(ds_ptr[i]),
608 make_tuple(kargs.M, kargs.N),
609 make_tuple(kargs.stride_Ds[i], 1),
610 number<EpiloguePipeline::GetVectorSizeD(i)>{},
611 number<1>{});
612 }
613 else
614 {
616 static_cast<const DDataType_*>(ds_ptr[i]),
617 make_tuple(kargs.N, kargs.M),
618 make_tuple(kargs.stride_Ds[i], 1),
619 number<EpiloguePipeline::GetVectorSizeD(i)>{},
620 number<1>{});
621 }
622 },
624
625 // TODO: enable vector write for C in ColMajor
626 const auto& e_tensor_view = [&]() {
627 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
628 {
630 e_ptr,
631 make_tuple(kargs.M, kargs.N),
632 make_tuple(kargs.stride_E, 1),
633 number<EpiloguePipeline::GetVectorSizeC()>{},
634 number<1>{});
635 }
636 else
637 {
639 e_ptr,
640 make_tuple(kargs.N, kargs.M),
641 make_tuple(kargs.stride_E, 1),
642 number<1>{},
643 number<1>{});
644 }
645 }();
646
647 constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
648 constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
649
650 constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
651 constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
652
653 auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
654 : 1; // per-token scale
655 auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
656 : 1; // per-channel scale
657
658 static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
659 "only support per-tensor or per-row scaling");
660 static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
661 "only support per-tensor or per-column scaling");
662
664 kargs.scale_m_ptr.ptr,
666 kargs.M / ScaleGranularityM,
667 ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
668 make_tuple(scale_stride_m, 0),
669 number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
670 number<1>{});
672 kargs.scale_n_ptr.ptr,
674 ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
675 kargs.N / ScaleGranularityN),
676 make_tuple(0, scale_stride_n),
677 number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
678 number<1>{});
679
680 return make_tuple(a_tensor_view,
681 b_flat_tensor_view,
682 ds_tensor_view,
683 e_tensor_view,
684 scale_m_view,
685 scale_n_view);
686 }
687
688 template <typename TensorView>
689 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
690 {
691 const auto& a_pad_view = [&]() {
692 const auto& a_tensor_view = views.at(I0);
693 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
694 {
695 return pad_tensor_view(a_tensor_view,
699 }
700 else
701 {
702 return pad_tensor_view(a_tensor_view,
706 }
707 }();
708
709 const auto& b_flat_tensor_view = views.at(I1);
710
711 const auto& ds_pad_view = generate_tuple(
712 [&](auto i) {
713 const auto& d_tensor_view = views.at(I2);
714 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
715 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
716 {
717 return pad_tensor_view(d_tensor_view[i],
721 }
722 else
723 {
724 return pad_tensor_view(d_tensor_view[i],
728 }
729 },
731
732 // TODO vector write in for C in ColMajor
733 const auto& e_pad_view = [&]() {
734 const auto& e_tensor_view = views.at(I3);
735 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
736 {
737 return pad_tensor_view(e_tensor_view,
741 }
742 else
743 {
744 return pad_tensor_view(e_tensor_view,
748 }
749 }();
750
751 return make_tuple(a_pad_view,
752 b_flat_tensor_view,
753 ds_pad_view,
754 e_pad_view,
755 views.at(number<4>{}),
756 views.at(number<5>{}));
757 }
758
759 template <typename PadView>
760 CK_TILE_DEVICE static auto
761 MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
762 {
763 const auto& a_pad_view = views.at(I0);
764 const auto& b_flat_pad_view = views.at(I1);
765 const auto& ds_pad_view = views.at(I2);
766 const auto& e_pad_view = views.at(I3);
767
768 const auto& a_block_window = [&]() {
769 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
770 {
771 return make_tile_window(a_pad_view,
774 {i_m, 0});
775 }
776 else
777 {
778 return make_tile_window(a_pad_view,
781 {0, i_m});
782 }
783 }();
784
785 const auto& b_flat_block_window =
786 make_tile_window(b_flat_pad_view,
789 {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
790
791 const auto ds_block_window = generate_tuple(
792 [&](auto i) {
793 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
794 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
795 {
796 return make_tile_window(ds_pad_view[i],
799 {i_m, i_n});
800 }
801 else
802 {
803 return make_tile_window(ds_pad_view[i],
806 {i_n, i_m});
807 }
808 },
810
811 auto e_block_window = make_tile_window(
812 e_pad_view,
814 {i_m, i_n});
815
816 constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
817 constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
818
819 auto scale_m_window = make_tile_window(views.at(number<4>{}),
821 number < ScaleGranularityKA == 0
822 ? TilePartitioner::NPerBlock
823 : TilePartitioner::KPerBlock > {}),
824 {i_m, 0});
825 auto scale_n_window = make_tile_window(views.at(number<5>{}),
826 make_tuple(number < ScaleGranularityKB == 0
827 ? TilePartitioner::MPerBlock
828 : TilePartitioner::KPerBlock > {},
830 {0, i_n});
831
832 return make_tuple(a_block_window,
833 b_flat_block_window,
834 ds_block_window,
835 e_block_window,
836 scale_m_window,
837 scale_n_window);
838 }
839
840 template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
841 CK_TILE_DEVICE static void
842 RunFlatmm(const ADataType* a_ptr,
843 const BDataType* b_flat_ptr,
844 const std::array<const void*, NumDTensor>& ds_ptr,
845 EDataType* e_ptr,
846 void* smem_ptr_ping,
847 void* smem_ptr_pong,
848 const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
849 const SplitKBatchOffset& splitk_batch_offset,
850 const index_t block_idx_m,
851 const index_t block_idx_n)
852 {
853 // Create Gemm tensor views, pad views and tile windows
854 const auto& gemm_tensor_views_tuple =
856 a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
857 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
858 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
859
860 const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
861
862 // Run GEMM cooperatively by whole workgroup.
863 const auto& a_block_window = gemm_tile_windows.at(I0);
864 const auto& b_flat_block_window = gemm_tile_windows.at(I1);
865 const auto& d_block_window = gemm_tile_windows.at(I2);
866 const auto& c_block_tile = FlatmmPipeline{}.template operator()(
867 a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
868
869 auto scale_m_window = gemm_tile_windows.at(number<4>{});
870 auto scale_n_window = gemm_tile_windows.at(number<5>{});
871
872 // Run Epilogue Pipeline
873 if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
874 {
875 auto& c_block_window = gemm_tile_windows.at(I3);
876 EpiloguePipeline{}.template
877 operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
878 c_block_window,
879 c_block_tile,
880 d_block_window,
881 smem_ptr_ping,
882 scale_m_window,
883 scale_n_window);
884 }
885 else if(UseDefaultScheduler || (get_warp_id() == 0))
886 {
887 // Run Epilogue Pipeline
888 auto& c_block_window = gemm_tile_windows.at(I3);
889 EpiloguePipeline{}.template
890 operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
891 c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
892 }
893 }
894
895 template <class ScaleM, class ScaleN>
896 CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
897 int partition_idx = blockIdx.x) const
898 {
899 int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
900
901 do
902 {
903 const auto [iM, iN] =
904 TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
905 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
906 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
907
908 const SplitKBatchOffset splitk_batch_offset(kargs);
909 // options
910 const ADataType* a_ptr =
911 static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
912 const BDataType* b_flat_ptr =
913 static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
914 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
915
916 // allocate LDS
917 __shared__ char smem_ptr_ping[GetSmemPingSize()];
918 __shared__ char smem_ptr_pong[GetSmemPongSize()];
919
920 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
921 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
923 {
924 constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
926 b_flat_ptr,
927 kargs.ds_ptr,
928 e_ptr,
929 smem_ptr_ping,
930 smem_ptr_pong,
931 kargs,
932 splitk_batch_offset,
933 i_m,
934 i_n);
935 }
936 partition_idx += gridDim.x;
937 } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
938 }
939};
940
941} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
@ atomic_add
Definition arch.hpp:58
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
ScaleFlatmmHostArgs< FlatmmScalePointer<-1 >, FlatmmScalePointer<-1 >, NumberTensor > FlatmmHostArgs
Definition flatmm_kernel.hpp:224
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned int uint32_t
Definition stdint.h:126
index_t N
Definition flatmm_kernel.hpp:170
const void * a_ptr
Definition flatmm_kernel.hpp:161
index_t stride_B
Definition flatmm_kernel.hpp:173
const std::array< index_t, NumDTensor > stride_Ds
Definition flatmm_kernel.hpp:174
index_t stride_C
Definition flatmm_kernel.hpp:178
CK_TILE_HOST BaseFlatmmHostArgs()=default
index_t K
Definition flatmm_kernel.hpp:171
const void * b_ptr
Definition flatmm_kernel.hpp:162
index_t k_batch
Definition flatmm_kernel.hpp:181
index_t stride_E
Definition flatmm_kernel.hpp:177
CK_TILE_HOST BaseFlatmmHostArgs(const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition flatmm_kernel.hpp:134
index_t stride_A
Definition flatmm_kernel.hpp:172
const std::array< const void *, NumDTensor > ds_ptr
Definition flatmm_kernel.hpp:163
void * c_ptr
Definition flatmm_kernel.hpp:167
void * e_ptr
Definition flatmm_kernel.hpp:166
index_t M
Definition flatmm_kernel.hpp:169
index_t b_k_split_offset
Definition flatmm_kernel.hpp:400
index_t a_k_split_offset
Definition flatmm_kernel.hpp:399
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition flatmm_kernel.hpp:364
index_t splitted_k
Definition flatmm_kernel.hpp:401
Definition flatmm_kernel.hpp:229
ScaleN scale_n_ptr
Definition flatmm_kernel.hpp:244
void * e_ptr
Definition flatmm_kernel.hpp:234
std::array< index_t, NumDTensor > stride_Ds
Definition flatmm_kernel.hpp:240
index_t K
Definition flatmm_kernel.hpp:237
ScaleM scale_m_ptr
Definition flatmm_kernel.hpp:243
const void * b_ptr
Definition flatmm_kernel.hpp:232
index_t k_batch
Definition flatmm_kernel.hpp:242
index_t N
Definition flatmm_kernel.hpp:236
index_t stride_B
Definition flatmm_kernel.hpp:239
const std::array< const void *, NumDTensor > ds_ptr
Definition flatmm_kernel.hpp:233
const void * a_ptr
Definition flatmm_kernel.hpp:230
index_t stride_E
Definition flatmm_kernel.hpp:241
index_t M
Definition flatmm_kernel.hpp:235
index_t stride_A
Definition flatmm_kernel.hpp:238
Definition flatmm_kernel.hpp:249
static CK_TILE_HOST constexpr auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition flatmm_kernel.hpp:294
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPongSize()
Definition flatmm_kernel.hpp:356
static CK_TILE_HOST constexpr auto BlockSize()
Definition flatmm_kernel.hpp:330
static CK_TILE_HOST constexpr FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> MakeKernelArgs(const ScaleFlatmmHostArgs< ScaleM, ScaleN, DsDataType::size()> &hostArgs)
Definition flatmm_kernel.hpp:334
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition flatmm_kernel.hpp:252
static constexpr auto I0
Definition flatmm_kernel.hpp:270
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition flatmm_kernel.hpp:250
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition flatmm_kernel.hpp:258
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition flatmm_kernel.hpp:560
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition flatmm_kernel.hpp:259
static constexpr bool UsePersistentKernel
Definition flatmm_kernel.hpp:261
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition flatmm_kernel.hpp:266
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition flatmm_kernel.hpp:254
static constexpr auto I2
Definition flatmm_kernel.hpp:272
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition flatmm_kernel.hpp:689
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition flatmm_kernel.hpp:896
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition flatmm_kernel.hpp:251
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition flatmm_kernel.hpp:257
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition flatmm_kernel.hpp:263
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition flatmm_kernel.hpp:256
static constexpr index_t NumDTensor
Definition flatmm_kernel.hpp:268
static CK_TILE_HOST const std::string GetName()
Definition flatmm_kernel.hpp:279
static constexpr index_t kBlockSize
Definition flatmm_kernel.hpp:260
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition flatmm_kernel.hpp:255
static constexpr auto I3
Definition flatmm_kernel.hpp:273
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition flatmm_kernel.hpp:761
static constexpr auto I1
Definition flatmm_kernel.hpp:271
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPingSize()
Definition flatmm_kernel.hpp:352
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition flatmm_kernel.hpp:842
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition flatmm_kernel.hpp:264
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition flatmm_kernel.hpp:405
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
Definition flatmm_kernel.hpp:286
index_t stride_C
Definition flatmm_kernel.hpp:28
CK_TILE_HOST FlatmmProblem()=default
index_t M
Definition flatmm_kernel.hpp:23
index_t stride_B
Definition flatmm_kernel.hpp:27
CK_TILE_HOST FlatmmProblem(index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
Definition flatmm_kernel.hpp:17
index_t stride_A
Definition flatmm_kernel.hpp:26
index_t N
Definition flatmm_kernel.hpp:24
index_t K
Definition flatmm_kernel.hpp:25
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float *)
Definition flatmm_kernel.hpp:117
static constexpr int GranularityMN
Definition flatmm_kernel.hpp:111
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
Definition flatmm_kernel.hpp:120
static constexpr int GranularityK
Definition flatmm_kernel.hpp:112
const float * ptr
Definition flatmm_kernel.hpp:114
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer()=default
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
Definition flatmm_kernel.hpp:124
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float *, index_t)
Definition flatmm_kernel.hpp:118
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
Definition flatmm_kernel.hpp:81
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float *ptr_, index_t length_)
Definition flatmm_kernel.hpp:76
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float *ptr_)
Definition flatmm_kernel.hpp:75
static constexpr int GranularityMN
Definition flatmm_kernel.hpp:66
CK_TILE_HOST_DEVICE float operator[](index_t i) const
Definition flatmm_kernel.hpp:97
const float * ptr
Definition flatmm_kernel.hpp:71
static constexpr int GranularityK
Definition flatmm_kernel.hpp:67
index_t length
Definition flatmm_kernel.hpp:72
Definition flatmm_kernel.hpp:33
static constexpr int GranularityMN
Definition flatmm_kernel.hpp:34
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float *ptr_, index_t length_)
Definition flatmm_kernel.hpp:41
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
Definition flatmm_kernel.hpp:46
const float * ptr
Definition flatmm_kernel.hpp:37
CK_TILE_HOST_DEVICE FlatmmScalePointer()=default
static constexpr int GranularityK
Definition flatmm_kernel.hpp:35
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float *ptr_)
Definition flatmm_kernel.hpp:40
CK_TILE_HOST_DEVICE float operator[](index_t i) const =delete
Definition flatmm_kernel.hpp:187
CK_TILE_HOST ScaleFlatmmHostArgs()=default
CK_TILE_HOST ScaleFlatmmHostArgs(const void *a_ptr_, const void *b_shuffle_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *c_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_C_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition flatmm_kernel.hpp:189
Definition type_traits.hpp:115
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43