146 SplitImageInfo info{
false, 1, 1, 1};
152 static_cast<long_index_t>(N) * D_out * H_out * W_out * C * G;
154 static_cast<long_index_t>(N) * D_out * H_out * W_out * K * G;
155 const long_index_t input_bytes = input_elements *
sizeof(ADataType);
156 const long_index_t output_bytes = output_elements *
sizeof(CDataType);
158 (input_bytes > output_bytes) ? input_bytes : output_bytes;
162 if(max_tensor_bytes > TwoGB && N > 1)
174 index_t d_piece = D_out / d_split;
175 index_t h_piece = H_out / h_split;
176 index_t w_piece = W_out / w_split;
178 return static_cast<long_index_t>(effective_N) * d_piece * h_piece * w_piece * K * G *
183 const long_index_t memory_after_split_n = calc_memory(1, 1, 1);
186 if(memory_after_split_n <= TwoGB)
188 info.should_split =
false;
193 info.should_split =
true;
201 constexpr index_t MAX_TOTAL_PIECES = 64;
204 info.num_d_pieces = 1;
205 info.num_h_pieces = 1;
206 info.num_w_pieces = 1;
211 index_t max_d_split = (D_out < MAX_TOTAL_PIECES) ? D_out : MAX_TOTAL_PIECES;
212 for(
index_t d_split = 2; d_split <= max_d_split; d_split++)
214 info.num_d_pieces = d_split;
215 if(calc_memory(d_split, 1, 1) <= TwoGB)
226 index_t max_h_split = MAX_TOTAL_PIECES / info.num_d_pieces;
227 max_h_split = (H_out < max_h_split) ? H_out : max_h_split;
229 for(
index_t h_split = 2; h_split <= max_h_split; h_split++)
231 info.num_h_pieces = h_split;
232 if(calc_memory(info.num_d_pieces, h_split, 1) <= TwoGB)
241 index_t max_w_split = MAX_TOTAL_PIECES / (info.num_d_pieces * info.num_h_pieces);
242 max_w_split = (W_out < max_w_split) ? W_out : max_w_split;
244 for(
index_t w_split = 2; w_split <= max_w_split; w_split++)
246 info.num_w_pieces = w_split;
247 if(calc_memory(info.num_d_pieces, info.num_h_pieces, w_split) <= TwoGB)
255 info.num_d_pieces = (D_out < 4) ? D_out : 4;
256 info.num_h_pieces = (H_out < 4) ? H_out : 4;
257 info.num_w_pieces = (W_out < 4) ? W_out : 4;
272 : G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
273 N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
274 original_N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.original_N_)},
275 Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
276 Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
277 Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
278 Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
279 Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
280 Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
281 Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
282 Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
283 X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
284 K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
285 C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
286 ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
287 ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
288 ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
289 ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
290 ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
291 ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
292 InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
293 InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
294 InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
295 InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
296 InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
297 InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
298 ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
307 const ConvDimsType& b_g_k_c_xs_lengths,
308 const ConvDimsType& c_g_n_k_wos_lengths,
309 const ConvSpatialDimsType& conv_filter_strides,
310 const ConvSpatialDimsType& conv_filter_dilations,
311 const ConvSpatialDimsType& input_left_pads,
312 const ConvSpatialDimsType& input_right_pads)
313 : G_{a_g_n_c_wis_lengths[I0]},
316 Wi_{a_g_n_c_wis_lengths[I3]},
319 Wo_{c_g_n_k_wos_lengths[I3]},
322 X_{b_g_k_c_xs_lengths[I3]},
323 K_{c_g_n_k_wos_lengths[I2]},
324 C_{b_g_k_c_xs_lengths[I2]},
327 ConvStrideW_{conv_filter_strides[I0]},
330 ConvDilationW_{conv_filter_dilations[I0]},
333 InLeftPadW_{input_left_pads[I0]},
336 InRightPadW_{input_right_pads[I0]},
339 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
340 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
341 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
342 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
345 original_N_ = N_ = c_g_n_k_wos_lengths[I1];
349 N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
358 const ConvDimsType& b_g_k_c_xs_lengths,
359 const ConvDimsType& c_g_n_k_wos_lengths,
360 const ConvSpatialDimsType& conv_filter_strides,
361 const ConvSpatialDimsType& conv_filter_dilations,
362 const ConvSpatialDimsType& input_left_pads,
363 const ConvSpatialDimsType& input_right_pads)
364 : G_{a_g_n_c_wis_lengths[I0]},
366 Hi_{a_g_n_c_wis_lengths[I3]},
367 Wi_{a_g_n_c_wis_lengths[I4]},
369 Ho_{c_g_n_k_wos_lengths[I3]},
370 Wo_{c_g_n_k_wos_lengths[I4]},
372 Y_{b_g_k_c_xs_lengths[I3]},
373 X_{b_g_k_c_xs_lengths[I4]},
374 K_{c_g_n_k_wos_lengths[I2]},
375 C_{b_g_k_c_xs_lengths[I2]},
377 ConvStrideH_{conv_filter_strides[I0]},
378 ConvStrideW_{conv_filter_strides[I1]},
380 ConvDilationH_{conv_filter_dilations[I0]},
381 ConvDilationW_{conv_filter_dilations[I1]},
383 InLeftPadH_{input_left_pads[I0]},
384 InLeftPadW_{input_left_pads[I1]},
386 InRightPadH_{input_right_pads[I0]},
387 InRightPadW_{input_right_pads[I1]},
390 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
391 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
392 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
393 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
396 original_N_ = N_ = c_g_n_k_wos_lengths[I1];
400 N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
409 const ConvDimsType& b_g_k_c_xs_lengths,
410 const ConvDimsType& c_g_n_k_wos_lengths,
411 const ConvSpatialDimsType& conv_filter_strides,
412 const ConvSpatialDimsType& conv_filter_dilations,
413 const ConvSpatialDimsType& input_left_pads,
414 const ConvSpatialDimsType& input_right_pads)
415 : G_{a_g_n_c_wis_lengths[I0]},
416 Di_{a_g_n_c_wis_lengths[I3]},
417 Hi_{a_g_n_c_wis_lengths[I4]},
418 Wi_{a_g_n_c_wis_lengths[I5]},
419 Do_{c_g_n_k_wos_lengths[I3]},
420 Ho_{c_g_n_k_wos_lengths[I4]},
421 Wo_{c_g_n_k_wos_lengths[I5]},
422 Z_{b_g_k_c_xs_lengths[I3]},
423 Y_{b_g_k_c_xs_lengths[I4]},
424 X_{b_g_k_c_xs_lengths[I5]},
425 K_{c_g_n_k_wos_lengths[I2]},
426 C_{b_g_k_c_xs_lengths[I2]},
427 ConvStrideD_{conv_filter_strides[I0]},
428 ConvStrideH_{conv_filter_strides[I1]},
429 ConvStrideW_{conv_filter_strides[I2]},
430 ConvDilationD_{conv_filter_dilations[I0]},
431 ConvDilationH_{conv_filter_dilations[I1]},
432 ConvDilationW_{conv_filter_dilations[I2]},
433 InLeftPadD_{input_left_pads[I0]},
434 InLeftPadH_{input_left_pads[I1]},
435 InLeftPadW_{input_left_pads[I2]},
436 InRightPadD_{input_right_pads[I0]},
437 InRightPadH_{input_right_pads[I1]},
438 InRightPadW_{input_right_pads[I2]},
441 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
442 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
443 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
444 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
447 original_N_ = N_ = c_g_n_k_wos_lengths[I1];
451 N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
474 IndexType WiStride_ = G_ * C_;
475 IndexType CStrideTensorA_ = 1;
476 IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
477 IndexType GStrideTensorA_ = C_;
481 if constexpr(NumGroupsToMerge == 1)
485 make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
499 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
504 in_gemmm_groups_gemmk_desc,
513 if constexpr(NumGroupsToMerge == 1)
516 const auto in_n_wi_c_desc =
548 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
579 if constexpr(NumGroupsToMerge == 1)
583 make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
606 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
629 if constexpr(NumGroupsToMerge == 1)
633 make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
665 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
705 IndexType HiStride_ = Wi_ * G_ * C_;
706 IndexType WiStride_ = G_ * C_;
707 IndexType CStrideTensorA_ = 1;
708 IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
709 IndexType GStrideTensorA_ = C_;
713 if constexpr(NumGroupsToMerge == 1)
717 make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
731 make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
733 NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
738 in_gemmm_groups_gemmk_desc,
747 if constexpr(NumGroupsToMerge == 1)
749 const auto in_n_hi_wi_c_desc =
751 make_tuple(NStrideTensorA_, HiStride_, WiStride_),
774 in_n_y_ho_x_wo_c_desc,
784 make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
789 in_n_hi_wi_groups_c_desc,
798 in_n_hip_wip_groups_c_desc,
809 in_n_y_ho_x_wo_groups_c_desc,
818 if constexpr(NumGroupsToMerge == 1)
822 make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
845 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
847 NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
852 in_n_hi_wi_groups_c_desc,
864 in_n_ho_wo_groups_c_desc,
873 if constexpr(NumGroupsToMerge == 1)
877 make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
902 in_n_y_ho_x_wo_c_desc,
912 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
914 NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
919 in_n_hi_wi_groups_c_desc,
931 in_n_hip_wip_groups_c_desc,
948 in_n_y_ho_x_wo_groups_c_desc,
964 IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
965 IndexType HiStride_ = Wi_ * G_ * C_;
966 IndexType WiStride_ = G_ * C_;
967 IndexType CStrideTensorA_ = 1;
968 IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
969 IndexType GStrideTensorA_ = C_;
973 if constexpr(NumGroupsToMerge == 1)
977 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
991 make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
1002 in_gemmm_groups_gemmk_desc,
1012 if constexpr(NumGroupsToMerge == 1)
1016 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_),
1021 in_n_di_hi_wi_c_desc,
1030 in_n_hip_wip_c_desc,
1043 in_n_z_do_y_ho_x_wo_c_desc,
1053 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
1054 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
1059 in_n_di_hi_wi_c_desc,
1071 in_n_hip_wip_c_desc,
1089 in_n_z_do_y_ho_x_wo_c_desc,
1099 if constexpr(NumGroupsToMerge == 1)
1103 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
1108 in_n_di_hi_wi_c_desc,
1120 in_n_do_ho_wo_c_desc,
1129 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
1140 in_n_di_hi_wi_c_desc,
1161 in_n_do_ho_wo_c_desc,
1171 if constexpr(NumGroupsToMerge == 1)
1175 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
1180 in_n_di_hi_wi_c_desc,
1192 in_n_hip_wip_c_desc,
1210 in_n_z_do_y_ho_x_wo_c_desc,
1219 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
1230 in_n_di_hi_wi_c_desc,
1251 in_n_hip_wip_c_desc,
1275 in_n_z_do_y_ho_x_wo_c_desc,
1293 IndexType CStrideTensorB_ = 1;
1294 IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
1295 IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
1299 using FilterSizeNumType =
1300 std::conditional_t<NDimSpatial == 1,
1302 std::conditional_t<NDimSpatial == 2, number<9>,
number<27>>>;
1304 if constexpr(NumGroupsToMerge == 1)
1315 make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
1316 make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
1320 wei_gemmn_groups_gemmk_desc,
1329 if constexpr(NumGroupsToMerge == 1)
1340 make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
1344 wei_gemmn_groups_gemmk_desc,
1360 IndexType WoStride_ = G_ * K_;
1361 IndexType KStrideTensorC_ = 1;
1362 IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
1363 IndexType GStrideTensorC_ = K_;
1365 const IndexType NDoHoWo = N_ * Wo_;
1366 if constexpr(NumGroupsToMerge == 1)
1376 make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
1378 NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_),
1383 nhwo_groups_k_1_desc,
1393 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1394 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1395 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1405 unmerged_padded_desc,
1421 IndexType HoStride_ = Wo_ * G_ * K_;
1422 IndexType WoStride_ = G_ * K_;
1423 IndexType KStrideTensorC_ = 1;
1424 IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
1425 IndexType GStrideTensorC_ = K_;
1427 const IndexType NDoHoWo = N_ * Ho_ * Wo_;
1428 if constexpr(NumGroupsToMerge == 1)
1437 const auto nhwo_groups_k_1_desc =
1449 nhwo_groups_k_1_desc,
1459 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1460 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1461 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1471 unmerged_padded_desc,
1486 IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
1487 IndexType HoStride_ = Wo_ * G_ * K_;
1488 IndexType WoStride_ = G_ * K_;
1489 IndexType KStrideTensorC_ = 1;
1490 IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
1491 IndexType GStrideTensorC_ = K_;
1493 const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
1494 if constexpr(NumGroupsToMerge == 1)
1503 const auto nhwo_groups_k_1_desc =
1516 nhwo_groups_k_1_desc,
1526 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1527 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1528 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1538 unmerged_padded_desc,