device_normalization_bwd_gamma_beta_impl.hpp Source File

device_normalization_bwd_gamma_beta_impl.hpp Source File

Composable Kernel: device_normalization_bwd_gamma_beta_impl.hpp Source File
device_normalization_bwd_gamma_beta_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
16
17// M is Invariant dimension, K is reduced dimension
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
22template <typename GridwiseReduction,
23 typename DYDataType,
24 typename XDataType,
25 typename MeanInvStdDataType,
26 typename DGammaDataType,
27 typename DBetaDataType,
28 typename GridDesc_M_K,
29 typename GridDesc_M>
30__global__ void
31kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k,
32 const GridDesc_M_K x_grid_desc_m_k,
33 const GridDesc_M_K mean_grid_desc_m_k,
34 const GridDesc_M_K inv_std_grid_desc_m_k,
35 const GridDesc_M dgamma_grid_desc_m,
36 const GridDesc_M dbeta_grid_desc_m,
37 index_t num_k_block_tile_iteration,
38 const DYDataType* const __restrict__ p_dy_global,
39 const XDataType* const __restrict__ p_x_global,
40 const MeanInvStdDataType* const __restrict__ p_mean_global,
41 const MeanInvStdDataType* const __restrict__ p_inv_std_global,
42 DGammaDataType* const __restrict__ p_dgamma_global,
43 DBetaDataType* const __restrict__ p_dbeta_global)
44{
45 GridwiseReduction::Run(dy_grid_desc_m_k,
46 x_grid_desc_m_k,
47 mean_grid_desc_m_k,
48 inv_std_grid_desc_m_k,
49 dgamma_grid_desc_m,
50 dbeta_grid_desc_m,
51 num_k_block_tile_iteration,
52 p_dy_global,
53 p_x_global,
54 p_mean_global,
55 p_inv_std_global,
56 p_dgamma_global,
57 p_dbeta_global);
58};
59
60template <typename DYDataType,
61 typename XDataType,
62 typename MeanInvStdDataType,
63 typename ComputeDataType,
64 typename DGammaDataType,
65 typename DBetaDataType,
66 index_t Rank,
67 index_t NumReduceDim,
68 index_t BlockSize,
69 index_t MThreadClusterSize,
70 index_t KThreadClusterSize,
71 index_t MThreadSliceSize,
72 index_t KThreadSliceSize,
73 bool IsDYFastestDimReduced,
74 index_t DYSrcVectorSize,
75 bool IsXFastestDimReduced,
76 index_t XSrcVectorSize,
77 bool IsMeanInvStdFastestDimReduced,
78 index_t MeanInvStdSrcVectorSize,
79 index_t DGammaDstVectorSize,
80 index_t DBetaDstVectorSize>
82 : public DeviceNormalizationBwdGammaBeta<DYDataType,
83 XDataType,
84 MeanInvStdDataType,
85 DGammaDataType,
86 DBetaDataType,
87 Rank,
88 NumReduceDim>
89{
90 static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
91 static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
92 static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
93
94 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
95
96 static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
97 (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
98 "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
99
100 static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
101 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
102 "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
103
104 static_assert(
105 (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
106 (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
107 "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
108 "check!");
109
110 static_assert(
111 ((MThreadSliceSize % DGammaDstVectorSize == 0) ||
112 (MThreadSliceSize % DBetaDstVectorSize == 0)),
113 "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
114 "check!");
115
116 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
117 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
118 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
119
120 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
121 static_assert(!reduceAllDim);
122
123 static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
124 const std::vector<index_t>& inStrides,
125 int numBlockTileIteration)
126 {
127 const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<Rank>{});
128 const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<Rank>{});
129
130 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
131
132 const auto in_grid_desc_m_k = [&]() {
133 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
135
136 const auto reduceDimLengths =
137 make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
138 const auto invariantDimLengths =
139 make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
140
141 return transform_tensor_descriptor(inDesc,
142 make_tuple(make_merge_transform(invariantDimLengths),
143 make_merge_transform(reduceDimLengths)),
144 make_tuple(InvariantDims{}, ReduceDims{}),
146 }();
147
148 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
149 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
150
151 const auto inPad_M =
152 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
153 const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
154
155 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
156 in_grid_desc_m_k,
157 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
158 make_right_pad_transform(reduceLength, inPad_K)),
161
162 return in_grid_desc_m_k_padded;
163 }
164
165 static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
166 const std::vector<index_t>& outStrides)
167 {
168 const auto tupleDstLengths =
169 generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumInvariantDim>{});
170 const auto tupleDstStrides =
171 generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumInvariantDim>{});
172
173 auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
174
175 auto out_grid_desc_m = transform_tensor_descriptor(
176 outDesc,
177 make_tuple(make_merge_transform(tupleDstLengths)),
180
181 const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
182
183 const auto outPad =
184 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
185
186 auto out_grid_desc_m_padded = transform_tensor_descriptor(
187 out_grid_desc_m,
188 make_tuple(make_right_pad_transform(invariantLength, outPad)),
191 return (out_grid_desc_m_padded);
192 };
193
194 using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
195 using GridDesc_M = decltype(MakeDst1dDescriptor({1}, {1}));
196
199 XDataType,
200 MeanInvStdDataType,
201 ComputeDataType,
202 DGammaDataType,
203 DBetaDataType,
206 BlockSize,
207 MThreadClusterSize,
208 KThreadClusterSize,
209 MThreadSliceSize,
210 KThreadSliceSize,
212 DYSrcVectorSize,
214 XSrcVectorSize,
216 MeanInvStdSrcVectorSize,
217 DGammaDstVectorSize,
218 DBetaDstVectorSize>;
219
220 struct Argument : public BaseArgument
221 {
222 Argument(const std::vector<index_t> inLengths,
223 const std::vector<index_t> dyStrides,
224 const std::vector<index_t> xStrides,
225 const std::vector<index_t> meanStrides,
226 const std::vector<index_t> invStdStrides,
227 const std::vector<index_t> outLengths,
228 const std::vector<index_t> dgammaStrides,
229 const std::vector<index_t> dbetaStrides,
230 const std::vector<index_t> reduceDims,
231 const DYDataType* p_dy,
232 const XDataType* p_x,
233 const MeanInvStdDataType* p_mean,
234 const MeanInvStdDataType* p_invStd,
235 DGammaDataType* p_dgamma,
236 DBetaDataType* p_dbeta)
237 : p_dy_(p_dy),
238 p_x_(p_x),
239 p_mean_(p_mean),
240 p_invStd_(p_invStd),
241 p_dgamma_(p_dgamma),
242 p_dbeta_(p_dbeta),
243 outLengths_{outLengths},
244 dgammaStrides_{dgammaStrides},
245 dbetaStrides_{dbetaStrides}
246 {
252 shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
253
255
257
259
266
269 }
270
271 const DYDataType* p_dy_;
272 const XDataType* p_x_;
273 const MeanInvStdDataType* p_mean_;
274 const MeanInvStdDataType* p_invStd_;
275 DGammaDataType* p_dgamma_;
276 DBetaDataType* p_dbeta_;
277
278 std::vector<index_t> inLengths_;
279 std::vector<index_t> dyStrides_;
280 std::vector<index_t> xStrides_;
281 std::vector<index_t> meanStrides_;
282 std::vector<index_t> invStdStrides_;
283 std::vector<index_t> outLengths_;
284 std::vector<index_t> dgammaStrides_;
285 std::vector<index_t> dbetaStrides_;
286
288 size_t gridSize_;
289
290 // Source descriptor
295
296 // Destination descriptor
299
300 index_t MRaw_; // Invariant length
301 index_t KRaw_; // reduce length
302 };
303
304 struct Invoker : public BaseInvoker
305 {
306 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
307 {
308 const auto kernel_main =
310 DYDataType,
311 XDataType,
312 MeanInvStdDataType,
313 DGammaDataType,
314 DBetaDataType,
316 GridDesc_M>;
317
318 return launch_and_time_kernel(stream_config,
319 kernel_main,
320 dim3(arg.gridSize_),
321 dim3(BlockSize),
322 0,
330 arg.p_dy_,
331 arg.p_x_,
332 arg.p_mean_,
333 arg.p_invStd_,
334 arg.p_dgamma_,
335 arg.p_dbeta_);
336 }
337
338 float Run(const BaseArgument* p_arg,
339 const StreamConfig& stream_config = StreamConfig{}) override
340 {
341 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
342 }
343 };
344
345 template <index_t SrcVectorDim, index_t SrcVectorSize>
346 bool IsSrcVectorDimSizeValid(const std::vector<index_t>& lengths,
347 const std::vector<index_t>& strides)
348 {
349 if constexpr(SrcVectorSize == 1)
350 return true;
351
352 // Fastest dimension is not reduced
353 if constexpr(SrcVectorDim == 0)
354 {
355 if constexpr(NumInvariantDim == 0)
356 return false;
357
358 if(strides[NumInvariantDim - 1] != 1)
359 return false;
360
361 if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
362 return false;
363 }
364 else // Fastest dimension is reduced
365 {
366 if(strides[Rank - 1] != 1)
367 return false;
368
369 if(lengths[Rank - 1] % SrcVectorSize != 0)
370 return false;
371 };
372
373 return true;
374 }
375
376 template <index_t DstVectorSize>
377 bool IsDstVectorSizeValid(const std::vector<index_t>& lengths,
378 const std::vector<index_t>& strides)
379 {
380 if constexpr(DstVectorSize == 1)
381 return true;
382
383 if(strides[NumInvariantDim - 1] != 1)
384 return false;
385
386 if(lengths[NumInvariantDim - 1] % DstVectorSize != 0)
387 return false;
388
389 return true;
390 }
391
392 bool IsSupportedArgument(const BaseArgument* p_arg) override
393 {
394 const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
395
396 bool pass = true;
398 p_arg_->dyStrides_);
400 p_arg_->xStrides_);
402 p_arg_->inLengths_, p_arg_->meanStrides_);
404 p_arg_->inLengths_, p_arg_->invStdStrides_);
405
406 pass &=
408 pass &=
410
411 return pass;
412 }
413
414 std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
415 const std::vector<index_t> dyStrides,
416 const std::vector<index_t> xStrides,
417 const std::vector<index_t> meanStrides,
418 const std::vector<index_t> invStdStrides,
419 const std::vector<index_t> outLengths,
420 const std::vector<index_t> dgammaStrides,
421 const std::vector<index_t> dbetaStrides,
422 const std::vector<index_t> reduceDims,
423 const void* p_dy,
424 const void* p_x,
425 const void* p_mean,
426 const void* p_invStd,
427 void* p_dgamma,
428 void* p_dbeta) override
429 {
430 if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
431 meanStrides.size() != Rank || invStdStrides.size() != Rank)
432 throw std::runtime_error("dimension is incorrect");
433
434 if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim ||
435 dbetaStrides.size() != NumInvariantDim)
436 throw std::runtime_error("dimension is incorrect");
437
438 return std::make_unique<Argument>(inLengths,
439 dyStrides,
440 xStrides,
441 meanStrides,
442 invStdStrides,
443 outLengths,
444 dgammaStrides,
445 dbetaStrides,
446 reduceDims,
447 static_cast<const DYDataType*>(p_dy),
448 static_cast<const XDataType*>(p_x),
449 static_cast<const MeanInvStdDataType*>(p_mean),
450 static_cast<const MeanInvStdDataType*>(p_invStd),
451 static_cast<DGammaDataType*>(p_dgamma),
452 static_cast<DBetaDataType*>(p_dbeta));
453 }
454
455 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
456 {
457 return std::make_unique<Invoker>();
458 }
459
460 std::string GetTypeString() const override
461 {
462 auto str = std::stringstream();
463
464 // clang-format off
465 str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ",";
466 str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
467 str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
468 str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ;
469 str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">";
470 // clang-format on
471
472 return str.str();
473 }
474};
475
476} // namespace device
477} // namespace tensor_operation
478} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
__global__ void kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M dgamma_grid_desc_m, const GridDesc_M dbeta_grid_desc_m, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DGammaDataType *const __restrict__ p_dgamma_global, DBetaDataType *const __restrict__ p_dbeta_global)
Definition device_normalization_bwd_gamma_beta_impl.hpp:31
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_normalization_bwd_gamma_beta.hpp:37
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_normalization_bwd_gamma_beta.hpp:22
Definition device_normalization_bwd_gamma_beta_impl.hpp:221
GridDesc_M dgamma_grid_desc_m_
Definition device_normalization_bwd_gamma_beta_impl.hpp:297
std::vector< index_t > xStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:280
GridDesc_M_K mean_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:293
DBetaDataType * p_dbeta_
Definition device_normalization_bwd_gamma_beta_impl.hpp:276
std::vector< index_t > dyStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:279
size_t gridSize_
Definition device_normalization_bwd_gamma_beta_impl.hpp:288
std::vector< index_t > inLengths_
Definition device_normalization_bwd_gamma_beta_impl.hpp:278
const MeanInvStdDataType * p_invStd_
Definition device_normalization_bwd_gamma_beta_impl.hpp:274
GridDesc_M_K inv_std_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:294
std::vector< index_t > outLengths_
Definition device_normalization_bwd_gamma_beta_impl.hpp:283
const MeanInvStdDataType * p_mean_
Definition device_normalization_bwd_gamma_beta_impl.hpp:273
DGammaDataType * p_dgamma_
Definition device_normalization_bwd_gamma_beta_impl.hpp:275
int numBlockTileIteration_
Definition device_normalization_bwd_gamma_beta_impl.hpp:287
index_t MRaw_
Definition device_normalization_bwd_gamma_beta_impl.hpp:300
index_t KRaw_
Definition device_normalization_bwd_gamma_beta_impl.hpp:301
std::vector< index_t > dgammaStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:284
const XDataType * p_x_
Definition device_normalization_bwd_gamma_beta_impl.hpp:272
GridDesc_M dbeta_grid_desc_m_
Definition device_normalization_bwd_gamma_beta_impl.hpp:298
std::vector< index_t > dbetaStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:285
GridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:292
GridDesc_M_K dy_grid_desc_m_k_
Definition device_normalization_bwd_gamma_beta_impl.hpp:291
Argument(const std::vector< index_t > inLengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > outLengths, const std::vector< index_t > dgammaStrides, const std::vector< index_t > dbetaStrides, const std::vector< index_t > reduceDims, const DYDataType *p_dy, const XDataType *p_x, const MeanInvStdDataType *p_mean, const MeanInvStdDataType *p_invStd, DGammaDataType *p_dgamma, DBetaDataType *p_dbeta)
Definition device_normalization_bwd_gamma_beta_impl.hpp:222
std::vector< index_t > invStdStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:282
std::vector< index_t > meanStrides_
Definition device_normalization_bwd_gamma_beta_impl.hpp:281
const DYDataType * p_dy_
Definition device_normalization_bwd_gamma_beta_impl.hpp:271
Definition device_normalization_bwd_gamma_beta_impl.hpp:305
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_bwd_gamma_beta_impl.hpp:306
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_bwd_gamma_beta_impl.hpp:338
Definition device_normalization_bwd_gamma_beta_impl.hpp:89
bool IsSrcVectorDimSizeValid(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_bwd_gamma_beta_impl.hpp:346
static constexpr index_t K_BlockTileSize
Definition device_normalization_bwd_gamma_beta_impl.hpp:118
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_bwd_gamma_beta_impl.hpp:392
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int numBlockTileIteration)
Definition device_normalization_bwd_gamma_beta_impl.hpp:123
static constexpr index_t M_BlockTileSize
Definition device_normalization_bwd_gamma_beta_impl.hpp:117
decltype(MakeSrc2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition device_normalization_bwd_gamma_beta_impl.hpp:194
bool IsDstVectorSizeValid(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_bwd_gamma_beta_impl.hpp:377
static auto MakeDst1dDescriptor(const std::vector< index_t > &outLengths, const std::vector< index_t > &outStrides)
Definition device_normalization_bwd_gamma_beta_impl.hpp:165
decltype(MakeDst1dDescriptor({1}, {1})) GridDesc_M
Definition device_normalization_bwd_gamma_beta_impl.hpp:195
static constexpr index_t MeanInvStdSrcVectorDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:92
static constexpr index_t DYSrcVectorDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:90
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_bwd_gamma_beta_impl.hpp:455
std::string GetTypeString() const override
Definition device_normalization_bwd_gamma_beta_impl.hpp:460
static constexpr bool reduceAllDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:120
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > inLengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > outLengths, const std::vector< index_t > dgammaStrides, const std::vector< index_t > dbetaStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_mean, const void *p_invStd, void *p_dgamma, void *p_dbeta) override
Definition device_normalization_bwd_gamma_beta_impl.hpp:414
static constexpr index_t XSrcVectorDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:91
GridwiseNormalizationBwdGammaBeta_mk_to_k< DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, GridDesc_M_K, GridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, DYSrcVectorDim, DYSrcVectorSize, XSrcVectorDim, XSrcVectorSize, MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize > GridwiseNormalizationBwdGammaBeta
Definition device_normalization_bwd_gamma_beta_impl.hpp:197
static constexpr index_t NumInvariantDim
Definition device_normalization_bwd_gamma_beta_impl.hpp:116