device_pool3d_fwd_ndhwc_ndhwc.hpp Source File

device_pool3d_fwd_ndhwc_ndhwc.hpp Source File

Composable Kernel: device_pool3d_fwd_ndhwc_ndhwc.hpp Source File
device_pool3d_fwd_ndhwc_ndhwc.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23template <typename InDataType,
24 typename OutDataType,
25 typename IndexDataType, // enable if OutputIndex == true
26 typename ComputeDataType,
27 ck::ReduceTensorOp ReduceOpId,
28 bool OutputIndex,
29 ck::index_t BlockSize,
30 ck::index_t MThreadClusterSize,
31 ck::index_t KThreadClusterSize,
32 ck::index_t MThreadSliceSize,
33 ck::index_t KThreadSliceSize,
34 ck::index_t InSrcOutDstVectorSize>
36 3,
37 InDataType,
38 OutDataType,
39 IndexDataType,
40 tensor_layout::convolution::NDHWC,
41 tensor_layout::convolution::NDHWC,
42 ReduceOpId,
43 OutputIndex>
44{
45 static constexpr auto I0 = Number<0>{};
46 static constexpr auto I1 = Number<1>{};
47 static constexpr auto I2 = Number<2>{};
48 static constexpr auto I3 = Number<3>{};
49 static constexpr auto I4 = Number<4>{};
50 static constexpr auto I5 = Number<5>{};
51
52 static constexpr index_t InOutRank = 5;
53 static constexpr index_t WindowRank = 3;
54
56
59
62
63 static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
64 static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
65
66 static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
67 std::vector<ck::index_t> output_ncdhw_lengths,
68 std::vector<ck::index_t> input_ncdhw_stride,
69 std::vector<ck::index_t> output_ncdhw_stride,
70 std::vector<ck::index_t> window_spatial_zyx_lengths,
71 std::vector<ck::index_t> window_zyx_strides,
72 std::vector<ck::index_t> window_zyx_dilations,
73 std::vector<ck::index_t> input_left_dhw_pads,
74 std::vector<ck::index_t> input_right_dhw_pads)
75 {
76 const index_t N = input_ncdhw_lengths[0];
77 const index_t C = input_ncdhw_lengths[1];
78 const index_t Di = input_ncdhw_lengths[2];
79 const index_t Hi = input_ncdhw_lengths[3];
80 const index_t Wi = input_ncdhw_lengths[4];
81
82 const index_t Do = output_ncdhw_lengths[2];
83 const index_t Ho = output_ncdhw_lengths[3];
84 const index_t Wo = output_ncdhw_lengths[4];
85
86 const index_t Z = window_spatial_zyx_lengths[0];
87 const index_t Y = window_spatial_zyx_lengths[1];
88 const index_t X = window_spatial_zyx_lengths[2];
89
90 const index_t WindowStrideD = window_zyx_strides[0];
91 const index_t WindowStrideH = window_zyx_strides[1];
92 const index_t WindowStrideW = window_zyx_strides[2];
93
94 const index_t WindowDilationD = window_zyx_dilations[0];
95 const index_t WindowDilationH = window_zyx_dilations[1];
96 const index_t WindowDilationW = window_zyx_dilations[2];
97
98 const index_t InLeftPadD = input_left_dhw_pads[0];
99 const index_t InLeftPadH = input_left_dhw_pads[1];
100 const index_t InLeftPadW = input_left_dhw_pads[2];
101
102 const index_t InRightPadD = input_right_dhw_pads[0];
103 const index_t InRightPadH = input_right_dhw_pads[1];
104 const index_t InRightPadW = input_right_dhw_pads[2];
105
106 const index_t MRaw = N * Do * Ho * Wo * C;
107 const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
108
109 const index_t KRaw = Z * Y * X;
110 const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
111
112 // A[ReduceM, ReduceK]
113 const index_t Ni_stride = input_ncdhw_stride[0];
114 const index_t Ci_stride = input_ncdhw_stride[1];
115 const index_t Di_stride = input_ncdhw_stride[2];
116 const index_t Hi_stride = input_ncdhw_stride[3];
117 const index_t Wi_stride = input_ncdhw_stride[4];
118
119 const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
120 make_tuple(N, Di, Hi, Wi, C),
121 make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));
122
123 const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
124 in_grid_desc_n_di_hi_wi_c,
126 make_pad_transform(Di, InLeftPadD, InRightPadD),
127 make_pad_transform(Hi, InLeftPadH, InRightPadH),
128 make_pad_transform(Wi, InLeftPadW, InRightPadW),
132
133 const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
134 in_grid_desc_n_dip_hip_wip_c,
137 make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
138 make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
139 make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
146 Sequence<7>{}));
147
148 const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
149 in_grid_desc_n_z_do_y_ho_x_wo_c,
150 make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
154
155 const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
156 in_grid_desc_reducemraw_reducekraw,
160
161 // B[ReduceM]
162 const index_t No_stride = output_ncdhw_stride[0];
163 const index_t Co_stride = output_ncdhw_stride[1];
164 const index_t Do_stride = output_ncdhw_stride[2];
165 const index_t Ho_stride = output_ncdhw_stride[3];
166 const index_t Wo_stride = output_ncdhw_stride[4];
167
168 const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
169 make_tuple(N, Di, Hi, Wi, C),
170 make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
171
172 const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
173 out_grid_desc_n_do_ho_wo_c,
174 make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
177
178 const auto out_grid_desc_reducem =
179 transform_tensor_descriptor(out_grid_desc_reducemraw,
183
184 return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
185 }
186
188 decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
189
192
193 struct Argument : public BaseArgument
194 {
195 Argument(const InDataType* p_in_dev,
196 OutDataType* p_out_dev,
197 IndexDataType* p_out_indices_dev,
198 std::vector<ck::index_t>& input_ncdhw_lengths,
199 std::vector<ck::index_t>& output_ncdhw_lengths,
200 std::vector<ck::index_t>& input_ncdhw_stride,
201 std::vector<ck::index_t>& output_ncdhw_stride,
202 std::vector<ck::index_t>&, // indices_ncdhw_stride
203 std::vector<ck::index_t>& window_spatial_zyx_lengths,
204 std::vector<ck::index_t>& window_zyx_strides,
205 std::vector<ck::index_t>& window_zyx_dilations,
206 std::vector<ck::index_t>& input_left_dhw_pads,
207 std::vector<ck::index_t>& input_right_dhw_pads)
208 : p_in_dev_{p_in_dev},
209 p_out_dev_{p_out_dev},
210 p_out_indices_dev_{p_out_indices_dev},
213 input_ncdhw_lengths_{input_ncdhw_lengths},
214 output_ncdhw_lengths_{output_ncdhw_lengths},
215 input_ncdhw_stride_{input_ncdhw_stride},
216 output_ncdhw_stride_{output_ncdhw_stride}
217 {
218 const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
219 output_ncdhw_lengths,
220 input_ncdhw_stride,
221 output_ncdhw_stride,
222 window_spatial_zyx_lengths,
223 window_zyx_strides,
224 window_zyx_dilations,
225 input_left_dhw_pads,
226 input_right_dhw_pads);
227
228 a_grid_desc_m_k_ = descs[I0];
229 b_grid_desc_m_ = descs[I1];
230
231 int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] *
232 window_spatial_zyx_lengths[2];
233
236 }
237
238 const InDataType* p_in_dev_;
239 OutDataType* p_out_dev_;
240 IndexDataType* p_out_indices_dev_;
243
246
247 // for checking vector load/store
248 std::vector<ck::index_t> input_ncdhw_lengths_;
249 std::vector<ck::index_t> output_ncdhw_lengths_;
250 std::vector<ck::index_t> input_ncdhw_stride_;
251 std::vector<ck::index_t> output_ncdhw_stride_;
252 };
253
254 struct Invoker : public BaseInvoker
255 {
256 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
257 {
258 // for NDHWC, the dim C is the fastest dimension, and is not reduced.
259 // Hence, it is in M dimension for reduction kernel.
260 static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
261
262 using gridwise_reduce =
264 OutDataType,
265 ComputeDataType,
266 IndexDataType,
273 false, // propagate_nan
274 BlockSize,
275 MThreadSliceSize,
276 KThreadSliceSize,
277 InSrcOutDstVectorDim,
278 InSrcOutDstVectorSize,
279 InSrcOutDstVectorSize>;
280
281 const auto kernel =
282 kernel_reduce_threadwise<gridwise_reduce,
283 OutputIndex,
284 true, // pooling need to return global index
285 false, // don't have index input
286 InDataType,
287 OutDataType,
288 ComputeDataType,
289 IndexDataType,
294
295 ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
296
297 const index_t grid_size = (M / M_BlockTileSize);
298
299 return launch_and_time_kernel(stream_config,
300 kernel,
301 dim3(grid_size),
302 dim3(BlockSize),
303 0,
305 arg.b_grid_desc_m_,
306 arg.in_element_op_,
307 arg.acc_element_op_,
308 float(1),
309 arg.p_in_dev_,
310 nullptr,
311 float(0),
312 arg.p_out_dev_,
314 }
315
316 float Run(const BaseArgument* p_arg,
317 const StreamConfig& stream_config = StreamConfig{}) override
318 {
319 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
320 }
321 };
322
323 bool IsSupportedArgument(const BaseArgument* p_arg) override
324 {
325 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
326
327 // C should be fastest dimension
328 if(pArg->input_ncdhw_stride_[1] != 1)
329 return false;
330
331 for(int i = 0; i < InOutRank; ++i)
332 {
333 if(pArg->input_ncdhw_stride_[i] == 1 &&
334 pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
335 return false;
336
337 if(pArg->output_ncdhw_stride_[i] == 1 &&
338 pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
339 return false;
340 }
341
342 return true;
343 }
344
345 virtual std::unique_ptr<BaseArgument>
346 MakeArgumentPointer(const void* p_in_dev,
347 void* p_out_dev,
348 void* p_out_indices_dev,
349 std::vector<ck::index_t> input_ncdhw_lengths,
350 std::vector<ck::index_t> window_zyx_lengths,
351 std::vector<ck::index_t> output_ncdhw_lengths,
352 std::vector<ck::index_t> input_ncdhw_stride,
353 std::vector<ck::index_t> output_ncdhw_stride,
354 std::vector<ck::index_t> indices_ncdhw_stride,
355 std::vector<ck::index_t> window_zyx_strides,
356 std::vector<ck::index_t> window_zyx_dilations,
357 std::vector<ck::index_t> input_left_dhw_pads,
358 std::vector<ck::index_t> input_right_dhw_pads,
359 std::vector<ck::index_t> pooling_dims) override
360 {
361 if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
362 input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
363 window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank ||
364 input_right_dhw_pads.size() != WindowRank)
365 throw std::runtime_error("dimension is incorrect");
366
367 if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
368 throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
369
370 if(output_ncdhw_stride != indices_ncdhw_stride)
371 throw std::runtime_error(
372 "output_ncdhw_stride need to be equal to indices_ncdhw_stride for now");
373
374 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
375 static_cast<OutDataType*>(p_out_dev),
376 static_cast<IndexDataType*>(p_out_indices_dev),
377 input_ncdhw_lengths,
378 output_ncdhw_lengths,
379 input_ncdhw_stride,
380 output_ncdhw_stride,
381 indices_ncdhw_stride,
382 window_zyx_lengths,
383 window_zyx_strides,
384 window_zyx_dilations,
385 input_left_dhw_pads,
386 input_right_dhw_pads);
387 }
388
389 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
390 {
391 return std::make_unique<Invoker>(Invoker{});
392 }
393
394 std::string GetTypeString() const override
395 {
396 auto str = std::stringstream();
397
398 // clang-format off
399 str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ",";
400 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
401 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
402 str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
403 // clang-format on
404
405 return str.str();
406 }
407};
408
409} // namespace device
410} // namespace tensor_operation
411} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
ReduceTensorOp
Definition reduction_enums.hpp:9
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:28
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_2d_reduction_threadwise.hpp:84
Definition utility/sequence.hpp:43
Definition reduction_operator_mapping.hpp:20
tensor_operation::element_wise::PassThrough AccElementwiseOperation
Definition reduction_operator_mapping.hpp:93
static std::tuple< InElementwiseOperation, AccElementwiseOperation > GetElementwiseOperator(int32_t reduceLength)
Definition reduction_operator_mapping.hpp:96
tensor_operation::element_wise::PassThrough InElementwiseOperation
Definition reduction_operator_mapping.hpp:92
Definition device_base.hpp:197
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:194
BGridDesc_M b_grid_desc_m_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:242
Argument(const InDataType *p_in_dev, OutDataType *p_out_dev, IndexDataType *p_out_indices_dev, std::vector< ck::index_t > &input_ncdhw_lengths, std::vector< ck::index_t > &output_ncdhw_lengths, std::vector< ck::index_t > &input_ncdhw_stride, std::vector< ck::index_t > &output_ncdhw_stride, std::vector< ck::index_t > &, std::vector< ck::index_t > &window_spatial_zyx_lengths, std::vector< ck::index_t > &window_zyx_strides, std::vector< ck::index_t > &window_zyx_dilations, std::vector< ck::index_t > &input_left_dhw_pads, std::vector< ck::index_t > &input_right_dhw_pads)
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:195
IndexDataType * p_out_indices_dev_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:240
std::vector< ck::index_t > input_ncdhw_lengths_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:248
std::vector< ck::index_t > output_ncdhw_lengths_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:249
const InDataType * p_in_dev_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:238
std::vector< ck::index_t > output_ncdhw_stride_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:251
InElementwiseOperation in_element_op_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:244
AGridDesc_M_K a_grid_desc_m_k_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:241
std::vector< ck::index_t > input_ncdhw_stride_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:250
AccElementwiseOperation acc_element_op_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:245
OutDataType * p_out_dev_
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:239
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:255
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:316
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:256
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:44
static constexpr index_t WindowRank
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:53
static constexpr auto I2
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:47
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})) ABGridDescs
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:187
static constexpr auto I1
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:46
static constexpr auto I3
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:48
static constexpr ck::index_t M_BlockTileSize
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:63
remove_cvref_t< decltype(ABGridDescs{}[I0])> AGridDesc_M_K
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:190
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector< ck::index_t > input_ncdhw_lengths, std::vector< ck::index_t > output_ncdhw_lengths, std::vector< ck::index_t > input_ncdhw_stride, std::vector< ck::index_t > output_ncdhw_stride, std::vector< ck::index_t > window_spatial_zyx_lengths, std::vector< ck::index_t > window_zyx_strides, std::vector< ck::index_t > window_zyx_dilations, std::vector< ck::index_t > input_left_dhw_pads, std::vector< ck::index_t > input_right_dhw_pads)
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:66
static constexpr ck::index_t K_BlockTileSize
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:64
static constexpr index_t InOutRank
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:52
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:389
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:323
typename reduce_unary_operator< ReduceOpId, true, true >::AccElementwiseOperation AccElementwiseOperation
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:60
static constexpr auto I4
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:49
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_dev, void *p_out_dev, void *p_out_indices_dev, std::vector< ck::index_t > input_ncdhw_lengths, std::vector< ck::index_t > window_zyx_lengths, std::vector< ck::index_t > output_ncdhw_lengths, std::vector< ck::index_t > input_ncdhw_stride, std::vector< ck::index_t > output_ncdhw_stride, std::vector< ck::index_t > indices_ncdhw_stride, std::vector< ck::index_t > window_zyx_strides, std::vector< ck::index_t > window_zyx_dilations, std::vector< ck::index_t > input_left_dhw_pads, std::vector< ck::index_t > input_right_dhw_pads, std::vector< ck::index_t > pooling_dims) override
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:346
remove_cvref_t< decltype(ABGridDescs{}[I1])> BGridDesc_M
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:191
typename reduce_unary_operator< ReduceOpId, true, true >::InElementwiseOperation InElementwiseOperation
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:57
std::string GetTypeString() const override
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:394
typename reduce_binary_operator< ReduceOpId >::opType ReduceOperation
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:55
static constexpr auto I5
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:50
static constexpr auto I0
Definition device_pool3d_fwd_ndhwc_ndhwc.hpp:45
Definition device_pool_fwd.hpp:25