grouped_convolution_backward_weight_kernel.hpp Source File

grouped_convolution_backward_weight_kernel.hpp Source File

Composable Kernel: grouped_convolution_backward_weight_kernel.hpp Source File
grouped_convolution_backward_weight_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
16
17namespace ck_tile {
18
20template <typename GroupedConvTraitsType_>
22{
23
25 TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
26 GroupedConvTraitsType_::ConvSpecialization,
27 GroupedConvTraitsType_::VectorSizeA,
28 GroupedConvTraitsType_::VectorSizeB,
29 GroupedConvTraitsType_::VectorSizeC,
30 GroupedConvTraitsType_::NumGroupsToMerge>;
31 static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
32
33 template <
34 typename InLay = typename GroupedConvTraitsType_::InLayout,
35 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
36 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
37 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
38 std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
39 std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
40 bool>::type = false>
42 {
43 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
44 static_cast<index_t>(args.N_),
45 static_cast<index_t>(args.C_),
46 static_cast<index_t>(args.input_spatial_lengths_[0])};
47 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
48 static_cast<index_t>(args.K_),
49 static_cast<index_t>(args.C_),
50 static_cast<index_t>(args.filter_spatial_lengths_[0])};
51 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
52 static_cast<index_t>(args.N_),
53 static_cast<index_t>(args.K_),
54 static_cast<index_t>(args.output_spatial_lengths_[0])};
55
56 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
58 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
59 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
60
61 k_batch = args.k_batch;
62
63 in_ptr = args.in_ptr;
64 wei_ptr = args.wei_ptr;
65 for(index_t d = 0; d < NumDTensor; d++)
66 {
67 ds_ptr[d] = args.ds_ptr[d];
68 }
69 out_ptr = args.out_ptr;
70
71 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
78
79 // tuple
80 auto grid_descs =
81 conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
82 GroupedConvTraitsType_::NDimSpatial>();
83
84 a_grid_desc_k_m = grid_descs.at(number<0>{});
85 b_grid_desc_k_n = grid_descs.at(number<1>{});
86 c_grid_desc_m_n = grid_descs.at(number<2>{});
87
88 NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
89 group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
90 group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
91 group_stride_c = args.K_ * args.C_ // C: Wei GKXC
93 std::accumulate(args.filter_spatial_lengths_.begin(),
94 args.filter_spatial_lengths_.end(),
95 1,
96 std::multiplies<index_t>());
97
98 GemmM = a_grid_desc_k_m.get_length(number<1>{});
99 GemmN = b_grid_desc_k_n.get_length(number<1>{});
100 GemmK = a_grid_desc_k_m.get_length(number<0>{});
102
103 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
104 {
105 std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
106 << ", GemmBatch: " << GemmBatch
107 << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
108 }
109 }
110
111 template <
112 typename InLay = typename GroupedConvTraitsType_::InLayout,
113 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
114 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
115 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
116 std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
117 std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
118 bool>::type = false>
120 {
121 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
122 static_cast<index_t>(args.N_),
123 static_cast<index_t>(args.C_),
124 static_cast<index_t>(args.input_spatial_lengths_[0]),
125 static_cast<index_t>(args.input_spatial_lengths_[1])};
126 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
127 static_cast<index_t>(args.K_),
128 static_cast<index_t>(args.C_),
129 static_cast<index_t>(args.filter_spatial_lengths_[0]),
130 static_cast<index_t>(args.filter_spatial_lengths_[1])};
131 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
132 static_cast<index_t>(args.N_),
133 static_cast<index_t>(args.K_),
134 static_cast<index_t>(args.output_spatial_lengths_[0]),
135 static_cast<index_t>(args.output_spatial_lengths_[1])};
136
137 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
138 static_cast<index_t>(args.conv_filter_strides_[1])};
140 static_cast<index_t>(args.conv_filter_dilations_[1])};
141 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
142 static_cast<index_t>(args.input_left_pads_[1])};
143 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
144 static_cast<index_t>(args.input_right_pads_[1])};
145
146 k_batch = args.k_batch;
147
148 in_ptr = args.in_ptr;
149 wei_ptr = args.wei_ptr;
150 for(index_t d = 0; d < NumDTensor; d++)
151 {
152 ds_ptr[d] = args.ds_ptr[d];
153 }
154 out_ptr = args.out_ptr;
155
156 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
163
164 // tuple
165 auto grid_descs =
166 conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
167 GroupedConvTraitsType_::NDimSpatial>();
168
169 a_grid_desc_k_m = grid_descs.at(number<0>{});
170 b_grid_desc_k_n = grid_descs.at(number<1>{});
171 c_grid_desc_m_n = grid_descs.at(number<2>{});
172
173 NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
174 group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
175 group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
176 group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
178 std::accumulate(args.filter_spatial_lengths_.begin(),
179 args.filter_spatial_lengths_.end(),
180 1,
181 std::multiplies<index_t>());
182
183 GemmM = a_grid_desc_k_m.get_length(number<1>{});
184 GemmN = b_grid_desc_k_n.get_length(number<1>{});
185 GemmK = a_grid_desc_k_m.get_length(number<0>{});
187
188 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
189 {
190 std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
191 << ", GemmBatch: " << GemmBatch
192 << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
193 }
194 }
195
196 template <
197 typename InLay = typename GroupedConvTraitsType_::InLayout,
198 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
199 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
200 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
201 std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
202 std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
203 bool>::type = false>
205 {
206 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
207 static_cast<index_t>(args.N_),
208 static_cast<index_t>(args.C_),
209 static_cast<index_t>(args.input_spatial_lengths_[0]),
210 static_cast<index_t>(args.input_spatial_lengths_[1]),
211 static_cast<index_t>(args.input_spatial_lengths_[2])};
212 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
213 static_cast<index_t>(args.K_),
214 static_cast<index_t>(args.C_),
215 static_cast<index_t>(args.filter_spatial_lengths_[0]),
216 static_cast<index_t>(args.filter_spatial_lengths_[1]),
217 static_cast<index_t>(args.filter_spatial_lengths_[2])};
218 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
219 static_cast<index_t>(args.N_),
220 static_cast<index_t>(args.K_),
221 static_cast<index_t>(args.output_spatial_lengths_[0]),
222 static_cast<index_t>(args.output_spatial_lengths_[1]),
223 static_cast<index_t>(args.output_spatial_lengths_[2])};
224
225 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
226 static_cast<index_t>(args.conv_filter_strides_[1]),
227 static_cast<index_t>(args.conv_filter_strides_[2])};
229 static_cast<index_t>(args.conv_filter_dilations_[1]),
230 static_cast<index_t>(args.conv_filter_dilations_[2])};
231 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
232 static_cast<index_t>(args.input_left_pads_[1]),
233 static_cast<index_t>(args.input_left_pads_[2])};
234 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
235 static_cast<index_t>(args.input_right_pads_[1]),
236 static_cast<index_t>(args.input_right_pads_[2])};
237
238 k_batch = args.k_batch;
239
240 in_ptr = args.in_ptr;
241 wei_ptr = args.wei_ptr;
242 for(index_t d = 0; d < NumDTensor; d++)
243 {
244 ds_ptr[d] = args.ds_ptr[d];
245 }
246 out_ptr = args.out_ptr;
247
248 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
255
256 // tuple
257 auto grid_descs =
258 conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
259 GroupedConvTraitsType_::NDimSpatial>();
260
261 a_grid_desc_k_m = grid_descs.at(number<0>{});
262 b_grid_desc_k_n = grid_descs.at(number<1>{});
263 c_grid_desc_m_n = grid_descs.at(number<2>{});
264
265 NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
266 group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
267 group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
268 group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
270 std::accumulate(args.filter_spatial_lengths_.begin(),
271 args.filter_spatial_lengths_.end(),
272 1,
273 std::multiplies<index_t>());
274
275 GemmM = a_grid_desc_k_m.get_length(number<1>{});
276 GemmN = b_grid_desc_k_n.get_length(number<1>{});
277 GemmK = a_grid_desc_k_m.get_length(number<0>{});
279
280 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
281 {
282 std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
283 << ", GemmBatch: " << GemmBatch
284 << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
285 }
286 }
287
289 decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
290
294
295 static constexpr index_t NonSpatialDims = 3;
299
304
311
312 const void* out_ptr;
313 const void* in_ptr;
314 std::array<const void*, NumDTensor> ds_ptr;
315 void* wei_ptr;
316
320
324};
325
363template <typename GroupedConvTraitsType_,
364 typename TilePartitioner_,
365 typename GemmPipeline_,
366 typename EpiloguePipeline_>
368{
369 static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
371 GroupedConvTraitsType_::ConvSpecialization;
378
383
385 static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
386
387 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
388
393
396
397 // TODO(review): IsSplitKSupported is already set to true below; confirm whether this TODO is stale.
398 static constexpr bool IsSplitKSupported = true;
399
400 static constexpr auto I0 = number<0>();
401 static constexpr auto I1 = number<1>();
402 static constexpr auto I2 = number<2>();
403 static constexpr auto I3 = number<3>();
404
405 static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
406 "Not supported!");
407 static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
408 static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
409 static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
410
// Builds a human-readable kernel name by joining its components with '_' via concat().
// Components visible here: the fixed "grouped_convolution_backward_weight" prefix, "gemm" +
// GemmPipeline::GetName(), "epilogue" + EpiloguePipeline::GetName(), and (in one branch only)
// a "merge" + NumGroupsToMerge suffix.
// NOTE(review): the "merge" suffix is appended in the `else` branch, i.e. when
// NumGroupsToMerge <= 1, and omitted when NumGroupsToMerge > 1. That looks inverted
// (the suffix is only informative when groups are actually merged) -- confirm intent.
// NOTE(review): returning `const std::string` by value prevents callers from moving the result.
411 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
412 {
413 constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
414 // clang-format off
// Branch taken when more than one group is merged: name carries no "merge" suffix.
415 if (NumGroupsToMerge > 1) {
416 return concat('_', "grouped_convolution_backward_weight",
418 "gemm",
419 GemmPipeline::GetName(),
420 "epilogue",
421 EpiloguePipeline::GetName());
422 } else {
// Branch taken when NumGroupsToMerge <= 1: name ends with "_merge_<NumGroupsToMerge>".
423 return concat('_', "grouped_convolution_backward_weight",
425 "gemm",
426 GemmPipeline::GetName(),
427 "epilogue",
428 EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
429 }
430 // clang-format on
431 }
432
433 CK_TILE_HOST static constexpr auto
435 {
436 return dim3(
437 TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
438 }
439
// Returns the launch block dimensions as a dim3.
// When is_wave32() reports true the block size is halved (kBlockSize / 2); otherwise the
// full kBlockSize (taken from GemmPipeline::BlockSize) is used.
// NOTE(review): presumably kBlockSize is expressed for 64-lane waves, hence the halving on
// wave32 targets -- confirm against the GemmPipeline definition.
440 CK_TILE_HOST static constexpr auto BlockSize()
441 {
442 return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
443 }
444
447 {
448 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
449 {
450 std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
451 std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
452 std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
453 }
455 }
456
458 {
459 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
460 }
461
463 {
465 const std::size_t k_id = blockIdx.z)
466 {
467 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
468 const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
469 const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
470
473
474 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
475 {
477 }
478 else
479 {
480 splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
481 }
482 }
483
487 };
488
490 const stream_config& s)
491 {
492 return [&]() {
493 if(kargs.k_batch > 1)
494 {
495 // Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
496 // since we require that ConvG % NumGroupsPerBatch == 0.
497 const auto wei_size =
498 kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
499 hipGetErrorString(
500 hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
501 }
502 };
503 }
504
505 CK_TILE_HOST static bool
507 {
508 if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
511 {
512 if(kargs.k_batch != 1)
513 {
514 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
515 {
516 CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
517 }
518 return false;
519 }
520 }
521
522 const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
523 const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
524
525 // check ConvSpecialization
527 {
528 // check if it's 1x1, stride=1 conv
529 for(index_t i = 0; i < NDimSpatial; ++i)
530 {
531 const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
532 const index_t ConvStride = kargs.conv_filter_strides[i];
533 const index_t LeftPad = kargs.input_left_pads[i];
534 const index_t RightPad = kargs.input_right_pads[i];
535
536 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
537 {
538 return false;
539 }
540 }
541 }
543 {
544 // check if it's 1x1 conv
545 for(index_t i = 0; i < NDimSpatial; ++i)
546 {
547 const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
548 const index_t LeftPad = kargs.input_left_pads[i];
549 const index_t RightPad = kargs.input_right_pads[i];
550
551 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
552 {
553 return false;
554 }
555 }
556 }
558 {
559 if(ConvC != 1)
560 {
561 return false;
562 }
563 for(index_t i = 0; i < NDimSpatial; ++i)
564 {
565 const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
566
567 if(filter_spatial_dim != I3)
568 {
569 return false;
570 }
571 }
572 }
573
574 namespace ctc = tensor_layout::convolution;
575
576 if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
577 std::is_same_v<InLayout, ctc::NDHWGC>)
578 {
579 // Check access per C
580 if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
581 {
582 CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
583 "input image!");
584 return false;
585 }
586 }
587 else
588 {
589 CK_TILE_ERROR("Not supported input layout!");
590 return false;
591 }
592
593 if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
594 std::is_same_v<WeiLayout, ctc::GKYXC> ||
595 std::is_same_v<WeiLayout, ctc::GKZYXC>)
596 {
597 if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
598 {
599 CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
600 return false;
601 }
602 }
603 else
604 {
605 CK_TILE_ERROR("Not supported weight layout!");
606 return false;
607 }
608
609 if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
610 std::is_same_v<OutLayout, ctc::NHWGK> ||
611 std::is_same_v<OutLayout, ctc::NDHWGK>)
612 {
613 if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
614 {
615 CK_TILE_ERROR("Conv K is not a multiple of vector store size "
616 "for output image!");
617 return false;
618 }
619 }
620 else
621 {
622 CK_TILE_ERROR("Not supported output layout!");
623 return false;
624 }
625
626 if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
627 {
628 const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
629 if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
630 {
631 CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
632 return false;
633 }
634
635 // TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
636 }
637
638 return true;
639 }
640
641 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
642 CK_TILE_DEVICE static auto
644 const InDataType* b_ptr,
645 const std::array<const void*, NumDTensor>& ds_ptr,
646 WeiDataType* c_ptr,
648 {
649 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
650 static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
651 const auto& a_tensor_view = [&]() {
653 kargs.a_grid_desc_k_m); // A: out
654 }();
655
656 const auto& b_tensor_view = [&]() {
658 kargs.b_grid_desc_k_n); // B: in
659 }();
660
661 const auto& c_tensor_view = [&]() {
663 kargs.c_grid_desc_m_n);
664 }();
665
666 const auto& ds_tensor_view = generate_tuple(
667 [&](auto i) {
668 static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
669 "Not supported!");
670 static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
671 "Not supported!");
672 static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
673 "Not supported!");
674
676 static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
677 },
679
680 return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
681 }
682
// Wraps each GEMM tensor view in a padded view via pad_tensor_view().
// `views` is indexed as I0 = A, I1 = B, I2 = tuple of D views, I3 = C, and the result is
// returned as a tuple in the same (A, B, Ds, C) order.
// NOTE(review): `k_batch` is not referenced on any line visible here; check whether it is
// actually used by the pad_tensor_view arguments.
683 template <typename TensorView>
684 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
685 {
// Each lambda result is bound to a const reference, so the returned temporary's lifetime is
// extended to the end of this function.
686 const auto& a_pad_view = [&]() {
687 const auto& a_tensor_view = views.at(I0);
688 return pad_tensor_view(a_tensor_view,
692 }();
693
694 const auto& b_pad_view = [&]() {
695 const auto& b_tensor_view = views.at(I1);
696 return pad_tensor_view(b_tensor_view,
700 }();
701
// The D tensors are padded one by one; generate_tuple builds a tuple of padded views.
702 const auto& ds_tensor_view = views.at(I2);
703 const auto& ds_pad_view = generate_tuple(
704 [&](auto i) {
705 return pad_tensor_view(ds_tensor_view[i],
709 },
711
712 const auto& c_pad_view = [&]() {
713 const auto& c_tensor_view = views.at(I3);
714 return pad_tensor_view(c_tensor_view,
718 }();
719
// Output order matches the input order: A, B, Ds, C.
720 return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
721 }
722
// Creates the per-workgroup tile windows over the padded views.
// `views` is indexed as I0 = A, I1 = B, I2 = tuple of D views, I3 = C.
// i_m / i_n / i_k are the window origins along M, N and K respectively:
//   A window origin is {i_k, i_m}, B window origin is {i_k, i_n},
//   every D window and the C window use origin {i_m, i_n}.
// Returns a tuple (A window, B window, tuple of D windows, C window).
733 template <typename PadView>
734 CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
735 const index_t i_m,
736 const index_t i_n,
737 const index_t i_k)
738 {
739 const auto& a_pad_view = views.at(I0);
740 const auto& b_pad_view = views.at(I1);
741 const auto& ds_pad_view = views.at(I2);
742 const auto& c_pad_view = views.at(I3);
743
// A is addressed K-major here: first coordinate is K, second is M.
744 const auto& a_block_window = [&]() {
745 return make_tile_window(a_pad_view,
748 {i_k, i_m});
749 }();
750
// B is addressed K-major as well: first coordinate is K, second is N.
751 const auto& b_block_window = [&]() {
752 return make_tile_window(b_pad_view,
755 {i_k, i_n});
756 }();
757
// One window per D tensor, all positioned at the output tile origin.
758 const auto ds_block_window = generate_tuple(
759 [&](auto i) {
760 return make_tile_window(ds_pad_view[i],
763 {i_m, i_n});
764 },
766
// The C window is held by value (non-const) so callers can obtain a mutable reference to it.
767 auto c_block_window = make_tile_window(
768 c_pad_view,
770 {i_m, i_n});
771
772 return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
773 }
774
// Runs one GEMM tile for this workgroup using a single shared-memory buffer.
// a_ptr / b_ptr / ds_ptr / c_ptr are the A, B, D and C tensor base pointers; num_loop is
// forwarded to the GEMM pipeline; block_idx_m / block_idx_n / block_idx_k are forwarded to
// MakeGemmTileWindows as the tile origins.
// The same smem_ptr_0 buffer is handed first to the GEMM pipeline and then to the epilogue.
787 CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
788 const InDataType* b_ptr,
789 const std::array<const void*, NumDTensor>& ds_ptr,
790 WeiDataType* c_ptr,
791 void* smem_ptr_0,
793 const index_t num_loop,
794 const index_t block_idx_m,
795 const index_t block_idx_n,
796 const index_t block_idx_k)
797 {
798 // Create Gemm tensor views, pad views and tile windows
799 const auto& gemm_tensor_views_tuple =
801 a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
802
803 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
// Kept non-const: the C window is fetched below through a mutable reference.
804 auto gemm_tile_windows =
805 MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
806
807 // Run GEMM cooperatively by whole workgroup.
808 const auto& a_block_window = gemm_tile_windows.at(I0);
809 const auto& b_block_window = gemm_tile_windows.at(I1);
810 const auto& d_block_window = gemm_tile_windows.at(I2);
811
// The pipeline consumes only the A and B windows and returns the accumulated C tile.
812 const auto& c_block_tile = GemmPipeline{}.template operator()(
813 a_block_window, b_block_window, num_loop, smem_ptr_0);
814
815 // Run Epilogue Pipeline
816 auto& c_block_window = gemm_tile_windows.at(I3);
817
// The epilogue receives the C window, the C tile and the D windows, reusing smem_ptr_0.
818 EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
819 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
820 }
821
// Variant of RunGemm for GEMM pipelines that take two shared-memory buffers.
// Identical flow to RunGemm, except that the GEMM pipeline is invoked with both smem_ptr_0
// and smem_ptr_1. The epilogue still receives smem_ptr_0 only.
// Both buffers are declared __restrict__, so callers must pass non-overlapping storage.
837 CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
838 const InDataType* b_ptr,
839 const std::array<const void*, NumDTensor>& ds_ptr,
840 WeiDataType* c_ptr,
841 void* __restrict__ smem_ptr_0,
842 void* __restrict__ smem_ptr_1,
844 const index_t num_loop,
845 const index_t block_idx_m,
846 const index_t block_idx_n,
847 const index_t block_idx_k)
848 {
849 // Create Gemm tensor views, pad views and tile windows
850 const auto& gemm_tensor_views_tuple =
852 a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
853 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
// Kept non-const: the C window is fetched below through a mutable reference.
854 auto gemm_tile_windows =
855 MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
856
857 // Run GEMM cooperatively by whole workgroup.
858 const auto& a_block_window = gemm_tile_windows.at(I0);
859 const auto& b_block_window = gemm_tile_windows.at(I1);
860 const auto& d_block_window = gemm_tile_windows.at(I2);
861
// Both shared-memory buffers are passed to the GEMM pipeline.
862 const auto& c_block_tile = GemmPipeline{}.template operator()(
863 a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
864
865 // Run Epilogue Pipeline
866 auto& c_block_window = gemm_tile_windows.at(I3);
867
// Only the first buffer is reused by the epilogue.
868 EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
869 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
870 }
871
873 {
874 const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
875 const auto [iM, iN] =
876 TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
877 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
878 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
879
880 const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
881 const index_t num_loop = amd_wave_read_first_lane(
882 ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
883 const index_t i_k =
884 amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
885
886 const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
887 const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
888 const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
889 const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
890
891 // options
892 // conv_bwd_weight = Out * In = Weight
893 const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
894 const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
895 WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
896
897 __shared__ char smem_ptr_0[GetSmemSize()];
898
899 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
900 {
901 __shared__ char smem_ptr_1[GetSmemSize()];
902 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
903 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
905 {
906 RunGemm2LDS(a_ptr,
907 b_ptr,
908 kargs.ds_ptr,
909 c_ptr,
910 smem_ptr_0,
911 smem_ptr_1,
912 kargs,
913 num_loop,
914 i_m,
915 i_n,
916 i_k);
917 }
918 }
919 else
920 {
921 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
922 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
924 {
925 RunGemm(
926 a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
927 }
928 }
929 }
930};
931
932} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/ops/common/tensor_layout.hpp:27
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_specialization.hpp:14
@ Filter3x3
Definition convolution_specialization.hpp:15
@ Filter1x1Pad0
Definition convolution_specialization.hpp:13
@ atomic_add
Definition arch.hpp:58
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
GroupedConvHostArgs< const void *, void *, const void *, PassThrough > GroupedConvBwdWeightHostArgs
Definition grouped_convolution_utils.hpp:51
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
unsigned int uint32_t
Definition stdint.h:126
The Grouped Convolution kernel device arguments.
Definition grouped_convolution_backward_weight_kernel.hpp:22
long_index_t group_stride_a
Definition grouped_convolution_backward_weight_kernel.hpp:321
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition grouped_convolution_backward_weight_kernel.hpp:300
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())> ABCGridDescs
Definition grouped_convolution_backward_weight_kernel.hpp:288
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition grouped_convolution_backward_weight_kernel.hpp:297
void * wei_ptr
Definition grouped_convolution_backward_weight_kernel.hpp:315
long_index_t group_stride_b
Definition grouped_convolution_backward_weight_kernel.hpp:322
CGridDescMN c_grid_desc_m_n
Definition grouped_convolution_backward_weight_kernel.hpp:319
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition grouped_convolution_backward_weight_kernel.hpp:296
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition grouped_convolution_backward_weight_kernel.hpp:301
AGridDescKM a_grid_desc_k_m
Definition grouped_convolution_backward_weight_kernel.hpp:317
BGridDescKN b_grid_desc_k_n
Definition grouped_convolution_backward_weight_kernel.hpp:318
index_t GemmN
Definition grouped_convolution_backward_weight_kernel.hpp:307
index_t GemmBatch
Definition grouped_convolution_backward_weight_kernel.hpp:309
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition grouped_convolution_backward_weight_kernel.hpp:298
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)
Definition grouped_convolution_backward_weight_kernel.hpp:41
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition grouped_convolution_backward_weight_kernel.hpp:302
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescKN
Definition grouped_convolution_backward_weight_kernel.hpp:292
std::array< const void *, NumDTensor > ds_ptr
Definition grouped_convolution_backward_weight_kernel.hpp:314
TransformConvBwdWeightToGemm< GroupedConvTraitsType_::NDimSpatial, GroupedConvTraitsType_::ConvSpecialization, GroupedConvTraitsType_::VectorSizeA, GroupedConvTraitsType_::VectorSizeB, GroupedConvTraitsType_::VectorSizeC, GroupedConvTraitsType_::NumGroupsToMerge > ConvToGemmTransformer
Definition grouped_convolution_backward_weight_kernel.hpp:24
index_t GemmM
Definition grouped_convolution_backward_weight_kernel.hpp:306
index_t NumGroupsPerBatch
Definition grouped_convolution_backward_weight_kernel.hpp:310
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition grouped_convolution_backward_weight_kernel.hpp:293
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition grouped_convolution_backward_weight_kernel.hpp:303
index_t GemmK
Definition grouped_convolution_backward_weight_kernel.hpp:308
const void * in_ptr
Definition grouped_convolution_backward_weight_kernel.hpp:313
index_t k_batch
Definition grouped_convolution_backward_weight_kernel.hpp:305
static constexpr index_t NonSpatialDims
Definition grouped_convolution_backward_weight_kernel.hpp:295
const void * out_ptr
Definition grouped_convolution_backward_weight_kernel.hpp:312
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescKM
Definition grouped_convolution_backward_weight_kernel.hpp:291
static constexpr index_t NumDTensor
Definition grouped_convolution_backward_weight_kernel.hpp:31
long_index_t group_stride_c
Definition grouped_convolution_backward_weight_kernel.hpp:323
InPtr in_ptr
Definition grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition grouped_convolution_utils.hpp:40
index_t k_batch
Definition grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition grouped_convolution_utils.hpp:41
index_t b_k_split_offset
Definition grouped_convolution_backward_weight_kernel.hpp:485
index_t splitted_k
Definition grouped_convolution_backward_weight_kernel.hpp:486
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const std::size_t k_id=blockIdx.z)
Definition grouped_convolution_backward_weight_kernel.hpp:464
index_t a_k_split_offset
Definition grouped_convolution_backward_weight_kernel.hpp:484
The Grouped Convolution Backward Weight kernel template.
Definition grouped_convolution_backward_weight_kernel.hpp:368
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition grouped_convolution_backward_weight_kernel.hpp:384
static constexpr index_t kBlockSize
Definition grouped_convolution_backward_weight_kernel.hpp:387
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views, const index_t k_batch)
Definition grouped_convolution_backward_weight_kernel.hpp:684
static CK_TILE_HOST constexpr auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition grouped_convolution_backward_weight_kernel.hpp:434
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition grouped_convolution_backward_weight_kernel.hpp:381
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_convolution_backward_weight_kernel.hpp:372
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_convolution_backward_weight_kernel.hpp:373
static CK_TILE_HOST const std::string GetName()
Definition grouped_convolution_backward_weight_kernel.hpp:411
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_convolution_backward_weight_kernel.hpp:837
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition grouped_convolution_backward_weight_kernel.hpp:377
static constexpr auto I2
Definition grouped_convolution_backward_weight_kernel.hpp:402
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition grouped_convolution_backward_weight_kernel.hpp:375
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition grouped_convolution_backward_weight_kernel.hpp:457
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition grouped_convolution_backward_weight_kernel.hpp:506
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_convolution_backward_weight_kernel.hpp:374
static constexpr ConvolutionSpecialization ConvSpecialization
Definition grouped_convolution_backward_weight_kernel.hpp:370
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition grouped_convolution_backward_weight_kernel.hpp:380
static constexpr bool IsSplitKSupported
Definition grouped_convolution_backward_weight_kernel.hpp:398
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_convolution_backward_weight_kernel.hpp:787
static constexpr index_t NDimSpatial
Definition grouped_convolution_backward_weight_kernel.hpp:369
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition grouped_convolution_backward_weight_kernel.hpp:382
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition grouped_convolution_backward_weight_kernel.hpp:379
static CK_TILE_HOST auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const stream_config &s)
Definition grouped_convolution_backward_weight_kernel.hpp:489
remove_cvref_t< typename EpiloguePipeline::ODataType > WeiDataType
Definition grouped_convolution_backward_weight_kernel.hpp:392
static constexpr auto I3
Definition grouped_convolution_backward_weight_kernel.hpp:403
static constexpr auto I0
Definition grouped_convolution_backward_weight_kernel.hpp:400
static CK_TILE_HOST constexpr auto BlockSize()
Definition grouped_convolution_backward_weight_kernel.hpp:440
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
Definition grouped_convolution_backward_weight_kernel.hpp:872
static constexpr auto I1
Definition grouped_convolution_backward_weight_kernel.hpp:401
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_convolution_backward_weight_kernel.hpp:391
static constexpr index_t NumDTensor
Definition grouped_convolution_backward_weight_kernel.hpp:385
static CK_TILE_HOST constexpr GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs &hostArgs)
Definition grouped_convolution_backward_weight_kernel.hpp:446
remove_cvref_t< typename GemmPipeline::ADataType > OutDataType
Definition grouped_convolution_backward_weight_kernel.hpp:389
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition grouped_convolution_backward_weight_kernel.hpp:643
remove_cvref_t< typename GemmPipeline::BDataType > InDataType
Definition grouped_convolution_backward_weight_kernel.hpp:390
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition grouped_convolution_backward_weight_kernel.hpp:376
GroupedConvBwdWeightKernelArgs< GroupedConvTraitsType_ > GroupedConvBwdWeightKernelArgsSpecialized
Definition grouped_convolution_backward_weight_kernel.hpp:394
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Creates views of the data that each workgroup will process.
Definition grouped_convolution_backward_weight_kernel.hpp:734
Definition tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp:22
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition tile/host/convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition tile/host/convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition tile/host/convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition tile/host/convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition tile/host/convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition tile/host/convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition tile/host/convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition tile/host/convolution_parameter.hpp:134
Definition type_traits.hpp:115
Definition tile/core/container/sequence.hpp:49
Definition ck_tile/host/stream_config.hpp:30
hipStream_t stream_id_
Definition ck_tile/host/stream_config.hpp:31
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145