device_pool2d_fwd_nhwc_nhwc.hpp Source File

device_pool2d_fwd_nhwc_nhwc.hpp Source File

Composable Kernel: device_pool2d_fwd_nhwc_nhwc.hpp Source File
device_pool2d_fwd_nhwc_nhwc.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
// Forward 2-D pooling for NHWC input and NHWC output. The pooling is expressed
// as a 2-D reduction over an [M, K] view of the input (M = N*Ho*Wo*C output
// elements, K = Y*X window taps) and executed by kernel_reduce_threadwise.
//
// Template parameters:
//   InDataType, OutDataType - element types of the input and output tensors.
//   IndexDataType           - element type of the index output (per the original
//                             comment, only relevant when OutputIndex == true).
//   ComputeDataType         - type forwarded to the gridwise reduction and kernel
//                             for the reduction arithmetic (presumably the
//                             accumulation type — confirm).
//   ReduceOpId              - reduction applied over each pooling window.
//   OutputIndex             - forwarded to the kernel; when true an index output
//                             is produced as well.
//   BlockSize               - threads per block (used as the launch block size).
//   M/KThreadClusterSize,
//   M/KThreadSliceSize      - per-block tiling; M_BlockTileSize and
//                             K_BlockTileSize below are their products.
//   InSrcOutDstVectorSize   - vector width of input loads and output stores
//                             (vectorized along M).
23template <typename InDataType,
24 typename OutDataType,
25 typename IndexDataType, // enable if OutputIndex == true
26 typename ComputeDataType,
27 ck::ReduceTensorOp ReduceOpId,
28 bool OutputIndex,
29 ck::index_t BlockSize,
30 ck::index_t MThreadClusterSize,
31 ck::index_t KThreadClusterSize,
32 ck::index_t MThreadSliceSize,
33 ck::index_t KThreadSliceSize,
34 ck::index_t InSrcOutDstVectorSize>
// NOTE(review): original line 35 is missing from this listing. It presumably reads
// `struct DevicePool2dFwd_NHWC_NHWC : public DevicePoolFwd<4,` (the name matches the
// string built in GetTypeString, and 4 matches InOutRank below) — confirm against
// the real header.
36 2,
37 InDataType,
38 OutDataType,
39 IndexDataType,
40 tensor_layout::convolution::NHWC,
41 tensor_layout::convolution::NHWC,
42 ReduceOpId,
43 OutputIndex>
44{
// Compile-time indices used to pick tuple elements (e.g. descs[I0], descs[I1]).
45 static constexpr auto I0 = Number<0>{};
46 static constexpr auto I1 = Number<1>{};
47
// Rank of the input/output tensors (N, C, H, W) and of the pooling window (Y, X).
48 static constexpr index_t InOutRank = 4;
49 static constexpr index_t WindowRank = 2;
50
// NOTE(review): the declarations that originally sat here (lines 51, 53-54, 56-57) are
// missing from this listing; per the cross-reference they define the ReduceOperation,
// InElementwiseOperation and AccElementwiseOperation aliases.
52
55
58
// Extent of M and K covered by one thread block. The [M, K] view is right-padded to
// multiples of these, and the launch uses one block per M_BlockTileSize rows of M.
59 static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
60 static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
61
// Builds the two grid descriptors consumed by the reduction kernel:
//   - A: the input viewed as a 2-D [M, K] matrix, M = N*Ho*Wo*C (one row per output
//     element) and K = Y*X (the taps of one pooling window);
//   - B: the output viewed as a 1-D [M] vector.
// M is rounded up to a multiple of M_BlockTileSize (MPad) and K to a multiple of
// K_BlockTileSize (KPad).
//
// All *_nchw_* vectors are in NCHW index order (index 1 is C) regardless of the
// NHWC memory layout; the stride vectors carry the actual layout. N and C are read
// from the input lengths only. Returns make_tuple(A, B).
//
// NOTE(review): many argument lines of the transform_tensor_descriptor calls below
// (original lines 111, 114-116, 120-121, 124-126, 130-133, 137-139, 152-154,
// 158-160) are missing from this listing.
62 static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_nchw_lengths,
63 std::vector<ck::index_t> output_nchw_lengths,
64 std::vector<ck::index_t> input_nchw_stride,
65 std::vector<ck::index_t> output_nchw_stride,
66 std::vector<ck::index_t> window_spatial_yx_lengths,
67 std::vector<ck::index_t> window_yx_strides,
68 std::vector<ck::index_t> window_yx_dilations,
69 std::vector<ck::index_t> input_left_hw_pads,
70 std::vector<ck::index_t> input_right_hw_pads)
71 {
72 const index_t N = input_nchw_lengths[0];
73 const index_t C = input_nchw_lengths[1];
74 const index_t Hi = input_nchw_lengths[2];
75 const index_t Wi = input_nchw_lengths[3];
76
77 const index_t Ho = output_nchw_lengths[2];
78 const index_t Wo = output_nchw_lengths[3];
79 const index_t Y = window_spatial_yx_lengths[0];
80 const index_t X = window_spatial_yx_lengths[1];
81
82 const index_t WindowStrideH = window_yx_strides[0];
83 const index_t WindowStrideW = window_yx_strides[1];
84
85 const index_t WindowDilationH = window_yx_dilations[0];
86 const index_t WindowDilationW = window_yx_dilations[1];
87
88 const index_t InLeftPadH = input_left_hw_pads[0];
89 const index_t InLeftPadW = input_left_hw_pads[1];
90
91 const index_t InRightPadH = input_right_hw_pads[0];
92 const index_t InRightPadW = input_right_hw_pads[1];
93
// Padding that rounds M and K up to whole block tiles.
94 const index_t MRaw = N * Ho * Wo * C;
95 const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
96
97 const index_t KRaw = Y * X;
98 const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
99
100 // A[ReduceM, ReduceK]
// Input strides, read in NCHW index order.
101 const index_t Ni_stride = input_nchw_stride[0];
102 const index_t Ci_stride = input_nchw_stride[1];
103 const index_t Hi_stride = input_nchw_stride[2];
104 const index_t Wi_stride = input_nchw_stride[3];
105
// Naive input descriptor with its dimensions ordered (N, Hi, Wi, C).
106 const auto in_grid_desc_n_hi_wi_c = make_naive_tensor_descriptor(
107 make_tuple(N, Hi, Wi, C), make_tuple(Ni_stride, Hi_stride, Wi_stride, Ci_stride));
108
// Apply the left/right spatial padding to H and W.
109 const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
110 in_grid_desc_n_hi_wi_c,
112 make_pad_transform(Hi, InLeftPadH, InRightPadH),
113 make_pad_transform(Wi, InLeftPadW, InRightPadW),
117
// Split each padded spatial dimension into (window tap, output position) with embed
// transforms whose coefficients are (dilation, stride).
118 const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
119 in_grid_desc_n_hip_wip_c,
122 make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
123 make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
127
// Presumably merges the dimensions into the raw [MRaw, KRaw] matrix (arguments
// missing from this listing).
128 const auto in_grid_desc_reducemraw_reducekraw =
129 transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
134
// Presumably right-pads M by MPad and K by KPad (arguments missing from this listing).
135 const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
136 in_grid_desc_reducemraw_reducekraw,
140
141 // B[ReduceM]
142 const index_t No_stride = output_nchw_stride[0];
143 const index_t Co_stride = output_nchw_stride[1];
144 const index_t Ho_stride = output_nchw_stride[2];
145 const index_t Wo_stride = output_nchw_stride[3];
146
// NOTE(review): this descriptor describes the OUTPUT tensor, but it is built with the
// INPUT spatial lengths (Hi, Wi). It looks like it should be make_tuple(N, Ho, Wo, C),
// matching the variable name and MRaw = N * Ho * Wo * C. With Hi/Wi, anything derived
// from the descriptor's lengths (e.g. its element-space size) reflects the input
// extent rather than the output's. Confirm and fix.
147 const auto out_grid_desc_n_ho_wo_c = make_naive_tensor_descriptor(
148 make_tuple(N, Hi, Wi, C), make_tuple(No_stride, Ho_stride, Wo_stride, Co_stride));
149
150 const auto out_grid_desc_reducemraw =
151 transform_tensor_descriptor(out_grid_desc_n_ho_wo_c,
155
156 const auto out_grid_desc_reducem =
157 transform_tensor_descriptor(out_grid_desc_reducemraw,
161
162 return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
163 }
164
// NOTE(review): original line 165 is missing. Per the cross-reference, this line is the
// tail of `using ABGridDescs = decltype(...)`. The descriptor types are deduced by
// naming a call to the factory with empty vectors inside decltype. That is an
// unevaluated operand, so the factory is not actually executed here.
166 decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
167
// NOTE(review): original lines 168-169 are missing; per the cross-reference they define
// AGridDesc_M_K and BGridDesc_M as the element types of ABGridDescs.
170
// Holds what the Invoker needs to launch one pooling problem:
//   - device pointers;
//   - the [M, K] / [M] grid descriptors;
//   - the elementwise operators;
//   - copies of the lengths/strides that IsSupportedArgument inspects.
171 struct Argument : public BaseArgument
172 {
// Parameters are in NCHW index order. Unlike MakeArgumentPointer, the window lengths
// come after the strides here. The indices stride is accepted but unused:
// MakeArgumentPointer requires it to equal output_nchw_stride.
173 Argument(const InDataType* p_in_dev,
174 OutDataType* p_out_dev,
175 IndexDataType* p_out_indices_dev,
176 std::vector<ck::index_t>& input_nchw_lengths,
177 std::vector<ck::index_t>& output_nchw_lengths,
178 std::vector<ck::index_t>& input_nchw_stride,
179 std::vector<ck::index_t>& output_nchw_stride,
180 std::vector<ck::index_t>&, // indices_nchw_stride
181 std::vector<ck::index_t>& window_spatial_yx_lengths,
182 std::vector<ck::index_t>& window_yx_strides,
183 std::vector<ck::index_t>& window_yx_dilations,
184 std::vector<ck::index_t>& input_left_hw_pads,
185 std::vector<ck::index_t>& input_right_hw_pads)
186 : p_in_dev_{p_in_dev},
187 p_out_dev_{p_out_dev},
188 p_out_indices_dev_{p_out_indices_dev},
// NOTE(review): original initializer lines 189-190 are missing from this listing.
191 input_nchw_lengths_{input_nchw_lengths},
192 output_nchw_lengths_{output_nchw_lengths},
193 input_nchw_stride_{input_nchw_stride},
194 output_nchw_stride_{output_nchw_stride}
195 {
// Build the grid descriptors once here; Invoker::Run reads them from the Argument.
196 const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_nchw_lengths,
197 output_nchw_lengths,
198 input_nchw_stride,
199 output_nchw_stride,
200 window_spatial_yx_lengths,
201 window_yx_strides,
202 window_yx_dilations,
203 input_left_hw_pads,
204 input_right_hw_pads);
205
206 a_grid_desc_m_k_ = descs[I0];
207 b_grid_desc_m_ = descs[I1];
208
// Number of elements in one pooling window (Y * X).
209 int32_t reduceLength = window_spatial_yx_lengths[0] * window_spatial_yx_lengths[1];
210
// NOTE(review): original lines 211-212 are missing. They presumably use reduceLength to
// initialize in_element_op_ / acc_element_op_; the cross-reference lists
// GetElementwiseOperator(int32_t reduceLength), which returns both operators.
213 }
214
215 const InDataType* p_in_dev_;
216 OutDataType* p_out_dev_;
// Index output; presumably only written when OutputIndex is true.
217 IndexDataType* p_out_indices_dev_;
// NOTE(review): the members a_grid_desc_m_k_ and b_grid_desc_m_ (lines 218-219) and
// in_element_op_ and acc_element_op_ (lines 221-222) are missing from this listing.
220
223
224 // for checking vector load/store
225 std::vector<ck::index_t> input_nchw_lengths_;
226 std::vector<ck::index_t> output_nchw_lengths_;
227 std::vector<ck::index_t> input_nchw_stride_;
228 std::vector<ck::index_t> output_nchw_stride_;
229 };
230
// Launches the pooling reduction kernel for an Argument. Holds no state.
231 struct Invoker : public BaseInvoker
232 {
// Instantiates the threadwise gridwise reduction for this configuration and launches
// kernel_reduce_threadwise with one block per M_BlockTileSize rows of M. Returns the
// value of launch_and_time_kernel (presumably the measured kernel time when timing is
// enabled in stream_config — confirm).
233 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
234 {
235 // for NHWC, the dim C is the fastest dimension, and is not reduced.
236 // Hence, it is in M dimension for reduction kernel.
237 static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
238
// NOTE(review): lines 240 and 244-249 of this alias are missing from this listing.
239 using gridwise_reduce =
241 OutDataType,
242 ComputeDataType,
243 IndexDataType,
250 false, // propagate_nan
251 BlockSize,
252 MThreadSliceSize,
253 KThreadSliceSize,
254 InSrcOutDstVectorDim,
255 InSrcOutDstVectorSize,
256 InSrcOutDstVectorSize>;
257
// NOTE(review): the remaining template arguments (lines 267-270) are missing.
258 const auto kernel =
259 kernel_reduce_threadwise<gridwise_reduce,
260 OutputIndex,
261 true, // pooling need to return global index
262 false, // don't have index input
263 InDataType,
264 OutDataType,
265 ComputeDataType,
266 IndexDataType,
271
// a_grid_desc_m_k_ is built with M padded up to a multiple of M_BlockTileSize, so this
// division should be exact.
272 ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
273
274 const index_t grid_size = (M / M_BlockTileSize);
275
// Kernel arguments follow kernel_reduce_threadwise's parameter order:
//   descriptors, elementwise ops, alpha = 1, input values,
//   no input indices (nullptr), beta = 0, output values, output indices.
// NOTE(review): lines 281 and 290 are missing. By that order they pass the [M, K]
// input descriptor and the output-index pointer.
276 return launch_and_time_kernel(stream_config,
277 kernel,
278 dim3(grid_size),
279 dim3(BlockSize),
280 0,
282 arg.b_grid_desc_m_,
283 arg.in_element_op_,
284 arg.acc_element_op_,
285 float(1),
286 arg.p_in_dev_,
287 nullptr,
288 float(0),
289 arg.p_out_dev_,
291 }
292
// Type-erased entry point; forwards to Run(const Argument&).
// NOTE(review): if p_arg is not an Argument of this device op, dynamic_cast yields
// nullptr and the dereference below is undefined behaviour; consider checking first.
293 float Run(const BaseArgument* p_arg,
294 const StreamConfig& stream_config = StreamConfig{}) override
295 {
296 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
297 }
298 };
299
300 bool IsSupportedArgument(const BaseArgument* p_arg) override
301 {
302 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
303
304 // C should be fastest dimension
305 if(pArg->input_nchw_stride_[1] != 1)
306 return false;
307
308 for(int i = 0; i < InOutRank; ++i)
309 {
310 if(pArg->input_nchw_stride_[i] == 1 &&
311 pArg->input_nchw_lengths_[i] % InSrcOutDstVectorSize != 0)
312 return false;
313
314 if(pArg->output_nchw_stride_[i] == 1 &&
315 pArg->output_nchw_lengths_[i] % InSrcOutDstVectorSize != 0)
316 return false;
317 }
318
319 return true;
320 }
321
322 virtual std::unique_ptr<BaseArgument>
323 MakeArgumentPointer(const void* p_in_dev,
324 void* p_out_dev,
325 void* p_out_indices_dev,
326 std::vector<ck::index_t> input_nchw_lengths,
327 std::vector<ck::index_t> window_yx_lengths,
328 std::vector<ck::index_t> output_nchw_lengths,
329 std::vector<ck::index_t> input_nchw_stride,
330 std::vector<ck::index_t> output_nchw_stride,
331 std::vector<ck::index_t> indices_nchw_stride,
332 std::vector<ck::index_t> window_yx_strides,
333 std::vector<ck::index_t> window_yx_dilations,
334 std::vector<ck::index_t> input_left_hw_pads,
335 std::vector<ck::index_t> input_right_hw_pads,
336 std::vector<ck::index_t> pooling_dims) override
337 {
338 if(input_nchw_lengths.size() != InOutRank || window_yx_lengths.size() != WindowRank ||
339 input_nchw_lengths.size() != InOutRank || window_yx_strides.size() != WindowRank ||
340 window_yx_dilations.size() != WindowRank || input_left_hw_pads.size() != WindowRank ||
341 input_right_hw_pads.size() != WindowRank)
342 throw std::runtime_error("dimension is incorrect");
343
344 if(pooling_dims != std::vector<ck::index_t>{2, 3})
345 throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
346
347 if(output_nchw_stride != indices_nchw_stride)
348 throw std::runtime_error(
349 "output_nchw_stride need to be equal to indices_nchw_stride for now");
350
351 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
352 static_cast<OutDataType*>(p_out_dev),
353 static_cast<IndexDataType*>(p_out_indices_dev),
354 input_nchw_lengths,
355 output_nchw_lengths,
356 input_nchw_stride,
357 output_nchw_stride,
358 indices_nchw_stride,
359 window_yx_lengths,
360 window_yx_strides,
361 window_yx_dilations,
362 input_left_hw_pads,
363 input_right_hw_pads);
364 }
365
366 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
367 {
368 return std::make_unique<Invoker>(Invoker{});
369 }
370
371 std::string GetTypeString() const override
372 {
373 auto str = std::stringstream();
374
375 // clang-format off
376 str << "DevicePool2dFwd_NHWC_NHWC<" << BlockSize << ",";
377 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
378 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
379 str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
380 // clang-format on
381
382 return str.str();
383 }
384};
385
386} // namespace device
387} // namespace tensor_operation
388} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
ReduceTensorOp
Definition reduction_enums.hpp:9
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:28
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_2d_reduction_threadwise.hpp:84
Definition utility/sequence.hpp:43
Definition reduction_operator_mapping.hpp:20
tensor_operation::element_wise::PassThrough AccElementwiseOperation
Definition reduction_operator_mapping.hpp:93
static std::tuple< InElementwiseOperation, AccElementwiseOperation > GetElementwiseOperator(int32_t reduceLength)
Definition reduction_operator_mapping.hpp:96
tensor_operation::element_wise::PassThrough InElementwiseOperation
Definition reduction_operator_mapping.hpp:92
Definition device_base.hpp:197
Definition device_pool2d_fwd_nhwc_nhwc.hpp:172
std::vector< ck::index_t > output_nchw_lengths_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:226
OutDataType * p_out_dev_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:216
AGridDesc_M_K a_grid_desc_m_k_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:218
const InDataType * p_in_dev_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:215
std::vector< ck::index_t > input_nchw_stride_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:227
InElementwiseOperation in_element_op_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:221
std::vector< ck::index_t > output_nchw_stride_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:228
IndexDataType * p_out_indices_dev_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:217
Argument(const InDataType *p_in_dev, OutDataType *p_out_dev, IndexDataType *p_out_indices_dev, std::vector< ck::index_t > &input_nchw_lengths, std::vector< ck::index_t > &output_nchw_lengths, std::vector< ck::index_t > &input_nchw_stride, std::vector< ck::index_t > &output_nchw_stride, std::vector< ck::index_t > &, std::vector< ck::index_t > &window_spatial_yx_lengths, std::vector< ck::index_t > &window_yx_strides, std::vector< ck::index_t > &window_yx_dilations, std::vector< ck::index_t > &input_left_hw_pads, std::vector< ck::index_t > &input_right_hw_pads)
Definition device_pool2d_fwd_nhwc_nhwc.hpp:173
BGridDesc_M b_grid_desc_m_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:219
std::vector< ck::index_t > input_nchw_lengths_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:225
AccElementwiseOperation acc_element_op_
Definition device_pool2d_fwd_nhwc_nhwc.hpp:222
Definition device_pool2d_fwd_nhwc_nhwc.hpp:232
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_pool2d_fwd_nhwc_nhwc.hpp:233
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_pool2d_fwd_nhwc_nhwc.hpp:293
Definition device_pool2d_fwd_nhwc_nhwc.hpp:44
std::string GetTypeString() const override
Definition device_pool2d_fwd_nhwc_nhwc.hpp:371
remove_cvref_t< decltype(ABGridDescs{}[I0])> AGridDesc_M_K
Definition device_pool2d_fwd_nhwc_nhwc.hpp:168
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector< ck::index_t > input_nchw_lengths, std::vector< ck::index_t > output_nchw_lengths, std::vector< ck::index_t > input_nchw_stride, std::vector< ck::index_t > output_nchw_stride, std::vector< ck::index_t > window_spatial_yx_lengths, std::vector< ck::index_t > window_yx_strides, std::vector< ck::index_t > window_yx_dilations, std::vector< ck::index_t > input_left_hw_pads, std::vector< ck::index_t > input_right_hw_pads)
Definition device_pool2d_fwd_nhwc_nhwc.hpp:62
static constexpr index_t WindowRank
Definition device_pool2d_fwd_nhwc_nhwc.hpp:49
static constexpr ck::index_t M_BlockTileSize
Definition device_pool2d_fwd_nhwc_nhwc.hpp:59
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})) ABGridDescs
Definition device_pool2d_fwd_nhwc_nhwc.hpp:165
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_pool2d_fwd_nhwc_nhwc.hpp:300
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_pool2d_fwd_nhwc_nhwc.hpp:366
remove_cvref_t< decltype(ABGridDescs{}[I1])> BGridDesc_M
Definition device_pool2d_fwd_nhwc_nhwc.hpp:169
typename reduce_unary_operator< ReduceOpId, true, true >::AccElementwiseOperation AccElementwiseOperation
Definition device_pool2d_fwd_nhwc_nhwc.hpp:56
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_dev, void *p_out_dev, void *p_out_indices_dev, std::vector< ck::index_t > input_nchw_lengths, std::vector< ck::index_t > window_yx_lengths, std::vector< ck::index_t > output_nchw_lengths, std::vector< ck::index_t > input_nchw_stride, std::vector< ck::index_t > output_nchw_stride, std::vector< ck::index_t > indices_nchw_stride, std::vector< ck::index_t > window_yx_strides, std::vector< ck::index_t > window_yx_dilations, std::vector< ck::index_t > input_left_hw_pads, std::vector< ck::index_t > input_right_hw_pads, std::vector< ck::index_t > pooling_dims) override
Definition device_pool2d_fwd_nhwc_nhwc.hpp:323
typename reduce_unary_operator< ReduceOpId, true, true >::InElementwiseOperation InElementwiseOperation
Definition device_pool2d_fwd_nhwc_nhwc.hpp:53
typename reduce_binary_operator< ReduceOpId >::opType ReduceOperation
Definition device_pool2d_fwd_nhwc_nhwc.hpp:51
static constexpr ck::index_t K_BlockTileSize
Definition device_pool2d_fwd_nhwc_nhwc.hpp:60
static constexpr auto I0
Definition device_pool2d_fwd_nhwc_nhwc.hpp:45
static constexpr index_t InOutRank
Definition device_pool2d_fwd_nhwc_nhwc.hpp:48
static constexpr auto I1
Definition device_pool2d_fwd_nhwc_nhwc.hpp:46
Definition device_pool_fwd.hpp:25