30 typename InElementwiseOperation,
31 typename WeiElementwiseOperation,
32 typename OutElementwiseOperation,
42 typename M1N1ThreadClusterM1Xs,
43 typename M1N1ThreadClusterN1Xs,
44 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
45 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
46 typename ABlockTransferThreadClusterArrangeOrder,
47 typename ABlockTransferSrcAccessOrder,
48 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
49 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
50 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
51 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
52 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
53 typename BBlockTransferThreadClusterArrangeOrder,
54 typename BBlockTransferSrcAccessOrder,
55 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
56 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
57 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
58 typename CThreadTransferSrcDstAccessOrder,
59 index_t CThreadTransferSrcDstVectorDim,
60 index_t CThreadTransferDstScalarPerVector>
64 ck::tuple_element_t<NDimSpatial - 1,
65 ck::Tuple<ck::tensor_layout::convolution::NWC,
66 ck::tensor_layout::convolution::NHWC,
67 ck::tensor_layout::convolution::NDHWC>>,
68 ck::tuple_element_t<NDimSpatial - 1,
69 ck::Tuple<ck::tensor_layout::convolution::KXC,
70 ck::tensor_layout::convolution::KYXC,
71 ck::tensor_layout::convolution::KZYXC>>,
72 ck::tuple_element_t<NDimSpatial - 1,
73 ck::Tuple<ck::tensor_layout::convolution::NWK,
74 ck::tensor_layout::convolution::NHWK,
75 ck::tensor_layout::convolution::NDHWK>>,
79 InElementwiseOperation,
80 WeiElementwiseOperation,
81 OutElementwiseOperation>
101 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
106 std::vector<ck::index_t> input_spatial_lengths,
107 std::vector<ck::index_t> filter_spatial_lengths,
108 std::vector<ck::index_t> output_spatial_lengths,
109 std::vector<ck::index_t> conv_filter_strides,
110 std::vector<ck::index_t> conv_filter_dilations,
111 std::vector<ck::index_t> input_left_pads,
112 std::vector<ck::index_t> input_right_pads,
113 std::vector<ck::index_t> tildes)
119 const index_t Wi = input_spatial_lengths[0];
120 const index_t Wo = output_spatial_lengths[0];
121 const index_t X = filter_spatial_lengths[0];
122 const index_t InLeftPadW = input_left_pads[0];
123 const index_t InRightPadW = input_right_pads[0];
124 const index_t ConvStrideW = conv_filter_strides[0];
125 const index_t ConvDilationW = conv_filter_dilations[0];
127 const auto K0 = K / K1;
131 if constexpr(ConvBackwardDataSpecialization ==
143 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
160 in_n_x_wo_c_grid_desc,
167 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
168 wei_gemmk0_gemmn_gemmk1_grid_desc,
169 in_gemmm_gemmn_grid_desc);
173 const auto out_n_wo_k_grid_desc =
175 const auto wei_k_x_c_grid_desc =
178 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
180 const auto XTilde = ConvStrideW / GcdStrideDilationW;
189 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
194 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
201 out_n_wo_k_grid_desc,
209 out_n_wop_k_grid_desc,
219 out_n_xdot_wtilde_k_grid_desc,
228 out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
246 wei_k_xdot_xtilde_c_grid_desc,
255 wei_k0_k1_xdotslice_c_grid_desc,
272 in_n_wip_c_grid_desc,
281 in_n_xtilde_wtilde_c_grid_desc,
290 in_n_wtildeslice_c_grid_desc,
296 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
297 wei_gemmk0_gemmn_gemmk1_grid_desc,
298 in_gemmm_gemmn_grid_desc);
302 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
307 std::vector<ck::index_t> input_spatial_lengths,
308 std::vector<ck::index_t> filter_spatial_lengths,
309 std::vector<ck::index_t> output_spatial_lengths,
310 std::vector<ck::index_t> conv_filter_strides,
311 std::vector<ck::index_t> conv_filter_dilations,
312 std::vector<ck::index_t> input_left_pads,
313 std::vector<ck::index_t> input_right_pads,
314 std::vector<ck::index_t> tildes)
321 const index_t Hi = input_spatial_lengths[0];
322 const index_t Wi = input_spatial_lengths[1];
324 const index_t Ho = output_spatial_lengths[0];
325 const index_t Wo = output_spatial_lengths[1];
327 const index_t Y = filter_spatial_lengths[0];
328 const index_t X = filter_spatial_lengths[1];
330 const index_t InLeftPadH = input_left_pads[0];
331 const index_t InLeftPadW = input_left_pads[1];
333 const index_t InRightPadH = input_right_pads[0];
334 const index_t InRightPadW = input_right_pads[1];
336 const index_t ConvStrideH = conv_filter_strides[0];
337 const index_t ConvStrideW = conv_filter_strides[1];
339 const index_t ConvDilationH = conv_filter_dilations[0];
340 const index_t ConvDilationW = conv_filter_dilations[1];
342 const auto K0 = K / K1;
344 const auto out_n_ho_wo_k_grid_desc =
346 const auto wei_k_y_x_c_grid_desc =
348 const auto in_n_hi_wi_c_grid_desc =
351 if constexpr(ConvBackwardDataSpecialization ==
363 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
372 in_n_hi_wi_c_grid_desc,
381 in_n_y_ho_x_wo_c_grid_desc,
389 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
390 wei_gemmk0_gemmn_gemmk1_grid_desc,
391 in_gemmm_gemmn_grid_desc);
395 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
396 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
398 const auto YTilde = ConvStrideH / GcdStrideDilationH;
399 const auto XTilde = ConvStrideW / GcdStrideDilationW;
411 math::max(
I0, InLeftPadH - ConvDilationH * (YTilde -
I1)), ConvStrideH);
413 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
420 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
421 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
429 out_n_ho_wo_k_grid_desc,
438 out_n_hop_wop_k_grid_desc,
449 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
451 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
472 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
481 wei_k_y_x_c_grid_desc,
491 const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
513 wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
522 in_n_hi_wi_c_grid_desc,
531 in_n_hip_wip_c_grid_desc,
542 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
563 in_n_htildeslice_wtildeslice_c_grid_desc,
569 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
570 wei_gemmk0_gemmn_gemmk1_grid_desc,
571 in_gemmm_gemmn_grid_desc);
576 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
581 std::vector<ck::index_t> input_spatial_lengths,
582 std::vector<ck::index_t> filter_spatial_lengths,
583 std::vector<ck::index_t> output_spatial_lengths,
584 std::vector<ck::index_t> conv_filter_strides,
585 std::vector<ck::index_t> conv_filter_dilations,
586 std::vector<ck::index_t> input_left_pads,
587 std::vector<ck::index_t> input_right_pads,
588 std::vector<ck::index_t> tildes)
592 const index_t i_ztilde = tildes[0];
593 const index_t i_ytilde = tildes[1];
594 const index_t i_xtilde = tildes[2];
596 const index_t Di = input_spatial_lengths[0];
597 const index_t Hi = input_spatial_lengths[1];
598 const index_t Wi = input_spatial_lengths[2];
600 const index_t Do = output_spatial_lengths[0];
601 const index_t Ho = output_spatial_lengths[1];
602 const index_t Wo = output_spatial_lengths[2];
604 const index_t Z = filter_spatial_lengths[0];
605 const index_t Y = filter_spatial_lengths[1];
606 const index_t X = filter_spatial_lengths[2];
608 const index_t InLeftPadD = input_left_pads[0];
609 const index_t InLeftPadH = input_left_pads[1];
610 const index_t InLeftPadW = input_left_pads[2];
612 const index_t InRightPadD = input_right_pads[0];
613 const index_t InRightPadH = input_right_pads[1];
614 const index_t InRightPadW = input_right_pads[2];
616 const index_t ConvStrideD = conv_filter_strides[0];
617 const index_t ConvStrideH = conv_filter_strides[1];
618 const index_t ConvStrideW = conv_filter_strides[2];
620 const index_t ConvDilationD = conv_filter_dilations[0];
621 const index_t ConvDilationH = conv_filter_dilations[1];
622 const index_t ConvDilationW = conv_filter_dilations[2];
624 const auto K0 = K / K1;
626 const auto out_n_do_ho_wo_k_grid_desc =
628 const auto wei_k_z_y_x_c_grid_desc =
630 const auto in_n_di_hi_wi_c_grid_desc =
633 if constexpr(ConvBackwardDataSpecialization ==
645 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
654 in_n_di_hi_wi_c_grid_desc,
669 in_n_z_do_y_ho_x_wo_c_grid_desc,
682 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
683 wei_gemmk0_gemmn_gemmk1_grid_desc,
684 in_gemmm_gemmn_grid_desc);
688 const auto GcdStrideDilationD =
math::gcd(ConvStrideD, ConvDilationD);
689 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
690 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
692 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
693 const auto YTilde = ConvStrideH / GcdStrideDilationH;
694 const auto XTilde = ConvStrideW / GcdStrideDilationW;
709 math::max(
I0, InLeftPadD - ConvDilationD * (ZTilde -
I1)), ConvStrideD);
711 math::max(
I0, InLeftPadH - ConvDilationH * (YTilde -
I1)), ConvStrideH);
713 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
722 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
723 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
724 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
733 out_n_do_ho_wo_k_grid_desc,
744 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
746 out_n_dop_hop_wop_k_grid_desc,
765 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
767 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
794 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
803 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
805 wei_k_z_y_x_c_grid_desc,
823 const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
851 wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
860 in_n_di_hi_wi_c_grid_desc,
871 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
873 in_n_dip_hip_wip_c_grid_desc,
890 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
892 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
919 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
926 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
927 wei_gemmk0_gemmn_gemmk1_grid_desc,
928 in_gemmm_gemmn_grid_desc);
933 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
937 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
940 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
944 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
947 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
986 M1N1ThreadClusterM1Xs,
987 M1N1ThreadClusterN1Xs,
988 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
989 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
990 ABlockTransferThreadClusterArrangeOrder,
991 ABlockTransferSrcAccessOrder,
992 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
993 ABlockTransferSrcVectorTensorContiguousDimOrder,
994 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
995 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
996 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
997 BBlockTransferThreadClusterArrangeOrder,
998 BBlockTransferSrcAccessOrder,
999 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
1000 BBlockTransferSrcVectorTensorContiguousDimOrder,
1001 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
1002 CThreadTransferSrcDstAccessOrder,
1003 CThreadTransferSrcDstVectorDim,
1004 CThreadTransferDstScalarPerVector>;
1018 const WeiDataType* p_wei_grid,
1019 const OutDataType* p_out_grid,
1023 std::vector<ck::index_t> input_spatial_lengths,
1024 std::vector<ck::index_t> filter_spatial_lengths,
1025 std::vector<ck::index_t> output_spatial_lengths,
1026 std::vector<ck::index_t> conv_filter_strides,
1027 std::vector<ck::index_t> conv_filter_dilations,
1028 std::vector<ck::index_t> input_left_pads,
1029 std::vector<ck::index_t> input_right_pads,
1030 InElementwiseOperation in_element_op,
1031 WeiElementwiseOperation wei_element_op,
1032 OutElementwiseOperation out_element_op)
1053 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
1058 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1059 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1063 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1103 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
1112 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
1113 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1115 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1116 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1120 for(
index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1122 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1127 if(YDotSlice * XDotSlice <= 0)
1144 {i_ytilde, i_xtilde});
1164 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
1175 const auto GcdStrideDilationD =
math::gcd(ConvStrideD, ConvDilationD);
1176 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
1177 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1179 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
1180 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1181 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1186 for(
index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
1188 for(
index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1190 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1196 if(ZDotSlice * YDotSlice * XDotSlice <= 0)
1213 {i_ztilde, i_ytilde, i_xtilde});
1278 std::cout <<
"arg.a_grid_desc_k0_m_k1_container_{"
1284 std::cout <<
"arg.b_grid_desc_k0_n_k1_container_{"
1290 std::cout <<
"arg.c_grid_desc_m_n_container_{ "
1295 std::cout <<
"arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
1307 <<
" ) " << std::endl;
1314 throw std::runtime_error(
1315 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
1321 auto launch_kernel = [&](
auto has_main_k_block_loop,
1322 auto has_double_tail_k_block_loop) {
1323 constexpr bool has_main_loop = has_main_k_block_loop.value;
1324 constexpr bool has_double_loop = has_double_tail_k_block_loop;
1354 const bool has_double_tail_k_block_loop =
1357 if(has_main_k_block_loop && has_double_tail_k_block_loop)
1361 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
1364 integral_constant<bool, false>{});
1366 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
1369 integral_constant<bool, true>{});
1374 integral_constant<bool, false>{});
1383 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1402 if constexpr(ConvBackwardDataSpecialization ==
1406 for(
int i = 0; i < NDimSpatial; i++)
1418 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
1419 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I2] != 1)
1423 if(K1 % srcVectorLengths[
I3] != 0 || K0PerBlock % srcVectorLengths[
I0] != 0)
1430 if(K % (srcVectorLengths[
I0] * srcVectorLengths[
I3]) != 0)
1438 auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
1439 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
1440 if(srcVectorLengths[
I0] != 1 || srcVectorLengths[
I3] != 1)
1444 if(srcLoadLenghts[
I1] % srcVectorLengths[
I1] != 0 ||
1445 srcLoadLenghts[
I2] % srcVectorLengths[
I2] != 0)
1452 if(C % (srcVectorLengths[
I1] * srcVectorLengths[
I2]) != 0)
1458 if(!(arg.
Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1460 std::cout <<
"Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
1461 << arg.
Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
1484 const WeiDataType* p_wei_grid,
1485 const OutDataType* p_out_grid,
1489 std::vector<ck::index_t> input_spatial_lengths,
1490 std::vector<ck::index_t> filter_spatial_lengths,
1491 std::vector<ck::index_t> output_spatial_lengths,
1492 std::vector<ck::index_t> conv_filter_strides,
1493 std::vector<ck::index_t> conv_filter_dilations,
1494 std::vector<ck::index_t> input_left_pads,
1495 std::vector<ck::index_t> input_right_pads,
1496 InElementwiseOperation in_element_op,
1497 WeiElementwiseOperation wei_element_op,
1498 OutElementwiseOperation out_element_op)
1506 input_spatial_lengths,
1507 filter_spatial_lengths,
1508 output_spatial_lengths,
1509 conv_filter_strides,
1510 conv_filter_dilations,
1520 std::unique_ptr<BaseArgument>
1522 const void* p_wei_grid,
1523 const void* p_out_grid,
1527 std::vector<ck::index_t> input_spatial_lengths,
1528 std::vector<ck::index_t> filter_spatial_lengths,
1529 std::vector<ck::index_t> output_spatial_lengths,
1530 std::vector<ck::index_t> conv_filter_strides,
1531 std::vector<ck::index_t> conv_filter_dilations,
1532 std::vector<ck::index_t> input_left_pads,
1533 std::vector<ck::index_t> input_right_pads,
1534 InElementwiseOperation in_element_op,
1535 WeiElementwiseOperation wei_element_op,
1536 OutElementwiseOperation out_element_op)
override
1538 return std::make_unique<Argument>(
static_cast<InDataType*
>(p_in_grid),
1539 static_cast<const WeiDataType*
>(p_wei_grid),
1540 static_cast<const OutDataType*
>(p_out_grid),
1544 input_spatial_lengths,
1545 filter_spatial_lengths,
1546 output_spatial_lengths,
1547 conv_filter_strides,
1548 conv_filter_dilations,
1558 return std::make_unique<Invoker>(
Invoker{});
1563 auto str = std::stringstream();
1566 str <<
"DeviceConvNdBwdDataNwcKxcNwk_Dl"
1568 << BlockSize <<
", "
1569 << MPerBlock <<
", "
1570 << NPerBlock <<
", "
1571 << K0PerBlock <<
", "
1574 if constexpr(ConvBackwardDataSpecialization ==
1577 str<<
" Filter1x1Stride1Pad0";
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:93
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:208
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:153
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:129
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:160
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:241
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_v1r3.hpp:168
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_v1r3.hpp:188
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_conv_bwd_data.hpp:25
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1016
const BDataType * p_b_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1236
std::vector< CGridDesc_M0_M10_M11_N0_N10_N11 > c_grid_desc_m0_m10_m11_n0_n10_n11_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1244
index_t Conv_N_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1253
std::vector< BGridDesc_K0_N0_N1_K1 > b_grid_desc_k0_n0_n1_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1243
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1240
std::vector< ck::index_t > filter_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1258
index_t Conv_C_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1255
InElementwiseOperation c_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1251
std::vector< ck::index_t > output_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1259
std::vector< ck::index_t > input_left_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1262
std::vector< ck::index_t > conv_filter_strides_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1260
std::vector< ck::index_t > input_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1257
WeiElementwiseOperation b_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1250
index_t Conv_K_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1254
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1239
CDataType * p_c_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1237
std::vector< ck::index_t > input_right_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1263
void CreateABCDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1054
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1238
const ADataType * p_a_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1235
std::vector< DefaultBlock2CTileMap > block_2_ctile_map_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1246
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1017
OutElementwiseOperation a_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1249
std::vector< ck::index_t > conv_filter_dilations_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1261
std::vector< AGridDesc_K0_M0_M1_K1 > a_grid_desc_k0_m0_m1_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1242
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1268
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1380
DeviceOp::Argument Argument
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1269
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1271
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:82
InDataType CDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:87
InDataType ABDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:90
static constexpr auto I3
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:95
static constexpr bool IsValidCompilationParameter()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1387
DeviceConvNdBwdDataNwcKxcNwk_Dl DeviceOp
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:83
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:963
static constexpr auto I7
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:99
static constexpr auto I5
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:97
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1556
std::string GetTypeString() const override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1561
static bool IsSupportedArgument(const Argument &arg)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1393
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > tildes)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:103
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1483
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:966
static constexpr auto I2
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:94
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1521
static auto MakeInvoker()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1518
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1012
static constexpr auto I4
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:96
OutDataType ADataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:85
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:967
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1008
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:965
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1478
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1006
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1010
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:970
static constexpr auto I6
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:98
static constexpr auto I0
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:92
static constexpr auto I1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:93
WeiDataType BDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:86
static auto GetABCGridDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:934
#define CK_ENV(name)
Definition utility/env.hpp:129