gridwise_normalization_naive_variance.hpp Source File

gridwise_normalization_naive_variance.hpp Source File

Composable Kernel: gridwise_normalization_naive_variance.hpp Source File
gridwise_normalization_naive_variance.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
13
14namespace ck {
15
16// Y = Normalization(X, Beta, Gamma)
//
// Gridwise normalization over a 2-D (M, K) view. Every M row owned by a
// workgroup is reduced along K to a mean and an inverse standard deviation,
// and then
//     y = (x - mean) * inv_std * gamma + beta
// is written out through threadwise_y_store (constructed with
// y_elementwise_op). The per-row mean / inv_std are optionally saved.
//
// "Naive variance": the variance is obtained as E[x^2] - E[x]^2 from two
// running sums (sum of x and sum of x^2).
// NOTE(review): this formula can lose precision, and round slightly below
// zero, when |mean| is large relative to the standard deviation; only the
// added epsilon keeps the sqrt argument positive.
//
// Template parameters (as used in this file):
//   *VectorDim   : 0 = vector accesses along M, 1 = vector accesses along K
//                  (see the static_asserts below).
//   *VectorSize  : elements per vector access. X, Y, Gamma and Beta sizes
//                  must all be equal because they share thread_buffer_desc_m_k.
//   SweepOnce    : true  -> each thread's whole K slice of x is read from
//                           global memory once into ThreadBufferNumber
//                           register buffers;
//                  false -> K is walked in num_k_block_tile_iteration tiles
//                           twice (pass 1 reduces, pass 2 re-reads x and
//                           normalizes).
//
// NOTE(review): the struct-name line (original line 41) was lost when this
// page was extracted.
17template <typename XDataType,
18 typename GammaDataType,
19 typename BetaDataType,
20 typename YDataType,
21 typename SaveMeanInvStdDataType,
22 typename ComputeDataType,
23 typename YElementwiseOperation,
24 typename GridDesc_M_K,
25 typename GridDesc_M,
26 index_t BlockSize,
27 index_t MThreadClusterSize,
28 index_t KThreadClusterSize,
29 index_t MThreadSliceSize,
30 index_t KThreadSliceSize,
31 index_t XSrcVectorDim,
32 index_t XSrcVectorSize,
33 index_t GammaSrcVectorDim,
34 index_t GammaSrcVectorSize,
35 index_t BetaSrcVectorDim,
36 index_t BetaSrcVectorSize,
37 index_t YDstVectorDim,
38 index_t YDstVectorSize,
39 index_t SaveMeanInvStdDstVectorSize,
40 bool SweepOnce>
42{
    // Compile-time configuration checks: each vector size must divide the
    // thread slice along the dimension it is vectorized on.
43 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
44 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
45 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
46
47 static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
48 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
49 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
50
    // Saved statistics are per-row, so their vector size must divide the M slice.
51 static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
52 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
53 "configuration, please check!");
54
    // All (M, K) register buffers are described by the same thread_buffer_desc_m_k,
    // so every tensor must use the same vector width.
55 static_assert(XSrcVectorSize == YDstVectorSize);
56 static_assert(XSrcVectorSize == GammaSrcVectorSize);
57 static_assert(XSrcVectorSize == BetaSrcVectorSize);
58
    // When X is vectorized along M, the thread-cluster arrange order and the
    // thread-buffer access order become Sequence<1, 0> instead of
    // Sequence<0, 1>. This is per the ThreadClusterArrangeOrder /
    // ThreadBufferDimAccessOrder entries in the symbol index.
59 static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
60
    // NOTE(review): the definitions of ThreadClusterLengths_M_K,
    // ThreadBufferDimAccessOrder, ThreadClusterArrangeOrder, ThreadBufferLengths_M_K,
    // thread_buffer_desc_m_k, ThreadBufferLengths_M, ThreadReduceSrcDesc_M_K,
    // ThreadReduceDstDesc_M, BlockwiseSumReduce, ThreadwiseSumReduce and
    // PassThroughOp are missing from this extracted listing. Their signatures
    // appear in the symbol index at the end of the page.
62
65
68
69 static constexpr auto thread_cluster_desc =
71
75
77 static constexpr auto thread_buffer_desc_m =
79
84
86 BlockSize,
90 true>;
91
96 true>;
97
99
100 static constexpr auto I0 = Number<0>{};
101 static constexpr auto I1 = Number<1>{};
102 static constexpr auto I2 = Number<2>{};
103
    // Rows (M) handled by one workgroup, and K elements covered by one block tile.
104 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
105 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    // K distance the slice windows advance between consecutive thread buffers:
    // one XSrcVectorSize-wide vector for each thread of the K cluster.
106 static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
107
    // Number of vector-wide register buffers per thread per K tile;
    // ThreadBufferNumber * K_BlockTileStepSize == K_BlockTileSize.
108 static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
109
    // Run(): normalize the M rows owned by this workgroup (block_global_id).
    //   x/gamma/beta/y_grid_desc_m_k : (M, K) descriptors of the four tensors.
    //   save_*_grid_desc_m           : (M) descriptors of the optional saved stats.
    //   num_k_block_tile_iteration   : number of K_BlockTileSize tiles per row;
    //                                  only read on the !SweepOnce path.
    //   epsilon                      : added to the variance before the sqrt.
    //   p_save_mean_global / p_save_inv_std_global :
    //                                  may be nullptr; the corresponding
    //                                  statistic is then not written.
    //   y_elementwise_op             : passed to the constructor of threadwise_y_store.
110 __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
111 const GridDesc_M_K& gamma_grid_desc_m_k,
112 const GridDesc_M_K& beta_grid_desc_m_k,
113 const GridDesc_M_K& y_grid_desc_m_k,
114 const GridDesc_M& save_mean_grid_desc_m,
115 const GridDesc_M& save_inv_std_grid_desc_m,
116 index_t num_k_block_tile_iteration,
117 ComputeDataType epsilon,
118 const XDataType* const __restrict__ p_x_global,
119 const GammaDataType* const __restrict__ p_gamma_global,
120 const BetaDataType* const __restrict__ p_beta_global,
121 YDataType* const __restrict__ p_y_global,
122 SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
123 SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
124 const YElementwiseOperation y_elementwise_op)
125 {
126 // LDS
        // One ComputeDataType slot per thread, used as scratch by BlockwiseSumReduce.
127 __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
128
129 auto reduce_work_buf =
130 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
131
132 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
133 p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
134
135 auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
136 p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
137
138 auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
139 p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
140
        // Tuples of register buffers, each MThreadSliceSize * VectorSize elements.
        // The buffer type and the tuple-size argument are on lines missing here;
        // presumably the tuple size is ThreadBufferNumber. TODO confirm.
141 auto x_thread_buf = generate_tuple(
142 [&](auto) {
144 ComputeDataType,
145 MThreadSliceSize * XSrcVectorSize,
146 true>{};
147 },
149
150 auto gamma_thread_buf = generate_tuple(
151 [&](auto) {
153 ComputeDataType,
154 MThreadSliceSize * GammaSrcVectorSize,
155 true>{};
156 },
158
        // beta reuses gamma's registers. Gamma is fully consumed by the
        // multiply loop before beta is loaded.
159 auto& beta_thread_buf = gamma_thread_buf;
160
161 auto y_thread_buf = generate_tuple(
162 [&](auto) {
164 ComputeDataType,
165 MThreadSliceSize * YDstVectorSize,
166 true>{};
167 },
169
        // x^2 reuses y's registers. x^2 is only needed for the reduction,
        // which finishes before y is computed.
170 auto& x_square_thread_buf = y_thread_buf;
171
        // Per-row accumulators (one entry per row of this thread's M slice).
        // var_thread_buf and inv_std_thread_buf are initialized from
        // mean_square_thread_buf. Their declaration types are on missing lines;
        // presumably they are references aliasing the same registers. TODO confirm.
173 mean_thread_buf;
175 mean_square_thread_buf;
177 var_thread_buf = mean_square_thread_buf;
179 inv_std_thread_buf = mean_square_thread_buf;
180
181 const index_t thread_local_id = get_thread_local_1d_id();
182 const index_t block_global_id = get_block_1d_id();
183
        // Map the flat thread id onto its (M, K) position in the thread cluster.
184 const auto thread_cluster_idx =
185 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
186
187 const auto thread_m_cluster_id = thread_cluster_idx[I0];
188 const auto thread_k_cluster_id = thread_cluster_idx[I1];
189
        // Every transfer below starts at this thread's M slice inside the
        // workgroup's M tile, and at this thread's first K vector.
190 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
191 ComputeDataType,
192 GridDesc_M_K,
193 decltype(thread_buffer_desc_m_k),
196 XSrcVectorDim,
197 XSrcVectorSize,
198 1,
199 true>(
200 x_grid_desc_m_k,
201 make_multi_index(block_global_id * M_BlockTileSize +
202 thread_m_cluster_id * MThreadSliceSize,
203 thread_k_cluster_id * XSrcVectorSize));
204
205 auto threadwise_gamma_load =
207 ComputeDataType,
208 GridDesc_M_K,
209 decltype(thread_buffer_desc_m_k),
212 GammaSrcVectorDim,
213 GammaSrcVectorSize,
214 1,
215 true>(
216 gamma_grid_desc_m_k,
217 make_multi_index(block_global_id * M_BlockTileSize +
218 thread_m_cluster_id * MThreadSliceSize,
219 thread_k_cluster_id * GammaSrcVectorSize));
220
221 auto threadwise_beta_load =
223 ComputeDataType,
224 GridDesc_M_K,
225 decltype(thread_buffer_desc_m_k),
228 BetaSrcVectorDim,
229 BetaSrcVectorSize,
230 1,
231 true>(
232 beta_grid_desc_m_k,
233 make_multi_index(block_global_id * M_BlockTileSize +
234 thread_m_cluster_id * MThreadSliceSize,
235 thread_k_cluster_id * BetaSrcVectorSize));
236
237 auto threadwise_y_store =
239 YDataType,
240 decltype(thread_buffer_desc_m_k),
241 GridDesc_M_K,
242 YElementwiseOperation,
245 YDstVectorDim,
246 YDstVectorSize,
248 1,
249 true>(
250 y_grid_desc_m_k,
251 make_multi_index(block_global_id * M_BlockTileSize +
252 thread_m_cluster_id * MThreadSliceSize,
253 thread_k_cluster_id * YDstVectorSize),
254 y_elementwise_op);
255
        // Stores for the optional per-row statistics (M-only descriptors).
256 auto threadwise_mean_store =
258 SaveMeanInvStdDataType,
259 decltype(thread_buffer_desc_m),
260 GridDesc_M,
263 Sequence<0>, // DimAccessOrder
264 0, // SrcVectorDim
265 SaveMeanInvStdDstVectorSize, // ScalarPerVector
267 1,
268 true>(
269 save_mean_grid_desc_m,
270 make_multi_index(block_global_id * M_BlockTileSize +
271 thread_m_cluster_id * MThreadSliceSize),
272 PassThroughOp{});
273
274 auto threadwise_inv_std_store =
276 SaveMeanInvStdDataType,
277 decltype(thread_buffer_desc_m),
278 GridDesc_M,
281 Sequence<0>, // DimAccessOrder
282 0, // SrcVectorDim
283 SaveMeanInvStdDstVectorSize, // ScalarPerVector
285 1,
286 true>(
287 save_inv_std_grid_desc_m,
288 make_multi_index(block_global_id * M_BlockTileSize +
289 thread_m_cluster_id * MThreadSliceSize),
290 PassThroughOp{});
291
        // fwd step: advance one thread buffer along K.
        // bwd step: go back one full K tile; it is only used on the !SweepOnce path.
292 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
293 constexpr auto thread_copy_bwd_step_m_k =
294 make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
295
296 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
297 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
298
299 const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
300 p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
301
302 const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
303 p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
304
305 // E(x), E[x^2], var(x)
306 // FIXME: Should not hack the transform from deviceOP
        // reduce_length is the divisor that turns the sums into means. It is read
        // from transform I2 of x's descriptor, so it depends on how the device op
        // built that descriptor. Presumably it is the true (unpadded) K length.
        // TODO confirm.
307 ComputeDataType reduce_length = type_convert<ComputeDataType>(
308 x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
309
        // Reset both per-row accumulators to the Add identity. The enclosing
        // static_for over the M slice (`auto I`) is on a line missing here.
311 mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
312 mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
313 });
314
315 // Separate sweep once and sweep twice pipeline
316 if constexpr(SweepOnce)
317 {
            // SweepOnce: for each thread buffer `i`, load x and gamma once, form
            // x^2, and fold both into the per-row sum / sum-of-squares.
319 threadwise_x_load.Run(x_grid_desc_m_k,
320 x_global_val_buf,
322 make_tuple(I0, I0),
323 x_thread_buf(i));
324
325 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
326 gamma_global_val_buf,
328 make_tuple(I0, I0),
329 gamma_thread_buf(i));
330
332 static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
333 constexpr auto offset_m_k =
334 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
335 x_square_thread_buf(i)(Number<offset_m_k>{}) =
336 x_thread_buf(i)(Number<offset_m_k>{}) *
337 x_thread_buf(i)(Number<offset_m_k>{});
338 });
339 });
340
341 ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
342 ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
343
                // Advance to the next buffer, except after the last one. These
                // windows are not used again on this path.
344 if constexpr(i != ThreadBufferNumber - 1)
345 {
346 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
347 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
348 thread_copy_fwd_step_m_k);
349 }
350 });
351
            // Combine the per-thread partial sums across the block with
            // BlockwiseSumReduce (a PartitionedBlockwiseReduction over
            // ThreadClusterLengths_M_K per the symbol index), then divide by
            // reduce_length.
            // NOTE(review): the statement guarded by `if constexpr(I > 0)` and the
            // line between the two reductions are missing here. Presumably both are
            // block_sync_lds(), so reduce_work_buf can be reused safely. TODO confirm.
353 if constexpr(I > 0)
355
356 BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
357 mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
358
360
361 BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
362 mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
363
364 // var(x) = E[x^2] - E[x]^2
365 var_thread_buf(I) =
366 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
367
                // inv_std = 1 / sqrt(var + epsilon)
368 inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
369 ck::math::sqrt(var_thread_buf(I) + epsilon);
370 });
371
372 // save mean and inverse std for backward (optional)
            // Only threads with thread_k_cluster_id == 0 write, so each row's
            // statistics are written by one K thread. A nullptr destination is skipped.
373 if(thread_k_cluster_id == 0)
374 {
375 if(p_save_mean_global != nullptr)
376 {
377 threadwise_mean_store.Run(thread_buffer_desc_m,
378 make_tuple(I0),
379 mean_thread_buf,
380 save_mean_grid_desc_m,
381 save_mean_global_val_buf);
382 }
383 if(p_save_inv_std_global != nullptr)
384 {
385 threadwise_inv_std_store.Run(thread_buffer_desc_m,
386 make_tuple(I0),
387 inv_std_thread_buf,
388 save_inv_std_grid_desc_m,
389 save_inv_std_global_val_buf);
390 }
391 }
392
393 // normalization
            // y = (x - mean) * inv_std * gamma for every buffer iK0, row iM and
            // vector lane iK1. The outer static_for headers are on lines missing here.
396 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
397 constexpr auto offset_m_k =
398 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
399
400 // normalize
401 y_thread_buf(iK0)(Number<offset_m_k>{}) =
402 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
403 inv_std_thread_buf(iM);
404
405 // gamma (beta is added in a later loop, after it has been loaded)
406 y_thread_buf(iK0)(Number<offset_m_k>{}) =
407 y_thread_buf(iK0)(Number<offset_m_k>{}) *
408 gamma_thread_buf(iK0)(Number<offset_m_k>{});
409 });
410 });
411 });
412
            // Gamma has been consumed, so load beta into the same registers
            // (beta_thread_buf aliases gamma_thread_buf).
414 threadwise_beta_load.Run(beta_grid_desc_m_k,
415 beta_global_val_buf,
417 make_tuple(I0, I0),
418 beta_thread_buf(i));
419
420 if constexpr(i != ThreadBufferNumber - 1)
421 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
422 thread_copy_fwd_step_m_k);
423 });
424
427 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
428 constexpr auto offset_m_k =
429 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
430
431 // beta
432 y_thread_buf(iK0)(Number<offset_m_k>{}) =
433 y_thread_buf(iK0)(Number<offset_m_k>{}) +
434 beta_thread_buf(iK0)(Number<offset_m_k>{});
435 });
436 });
437 });
438
            // Write y buffer by buffer through threadwise_y_store.
440 threadwise_y_store.Run(thread_buffer_desc_m_k,
441 make_tuple(I0, I0),
442 y_thread_buf(i),
443 y_grid_desc_m_k,
444 y_global_val_buf);
445
446 if constexpr(i != ThreadBufferNumber - 1)
447 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
448 thread_copy_fwd_step_m_k);
449 });
450 } // end of sweep once
451 else
452 {
            // Pass 1: stream x tile by tile (forward), accumulating the per-row
            // sum of x and sum of x^2. The x window advances one thread buffer
            // per load, so it ends num_k_block_tile_iteration tiles past its start.
453 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
454 {
456 threadwise_x_load.Run(x_grid_desc_m_k,
457 x_global_val_buf,
459 make_tuple(I0, I0),
460 x_thread_buf(i));
461 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
462
464 static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
465 constexpr auto offset_m_k =
466 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
467 x_square_thread_buf(i)(Number<offset_m_k>{}) =
468 x_thread_buf(i)(Number<offset_m_k>{}) *
469 x_thread_buf(i)(Number<offset_m_k>{});
470 });
471 });
472
473 ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
474 ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
475 });
476 }
477
            // Block-level reduction and statistics; same scheme as the SweepOnce
            // path. NOTE(review): the lines guarded by `if constexpr(I > 0)` and the
            // line between the reductions are missing here (presumably
            // block_sync_lds(); TODO confirm).
479 if constexpr(I > 0)
481
482 BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
483 mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
484
486
487 BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
488 mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
489
490 // var(x) = E[x^2] - E[x]^2
491 var_thread_buf(I) =
492 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
493
                // NOTE(review): the SweepOnce path writes
                // type_convert<ComputeDataType>(1.0f) / sqrt(...), while this path
                // relies on the int literal 1 converting implicitly. The two paths
                // could be made consistent.
494 inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
495 });
496
            // Optionally save the per-row mean / inv_std; one K thread per row
            // writes, and nullptr destinations are skipped.
497 if(thread_k_cluster_id == 0)
498 {
499 if(p_save_mean_global != nullptr)
500 {
501 threadwise_mean_store.Run(thread_buffer_desc_m,
502 make_tuple(I0),
503 mean_thread_buf,
504 save_mean_grid_desc_m,
505 save_mean_global_val_buf);
506 }
507 if(p_save_inv_std_global != nullptr)
508 {
509 threadwise_inv_std_store.Run(thread_buffer_desc_m,
510 make_tuple(I0),
511 inv_std_thread_buf,
512 save_inv_std_grid_desc_m,
513 save_inv_std_global_val_buf);
514 }
515 }
516
            // Reposition every window on the LAST K tile, because pass 2 walks the
            // tiles in reverse order:
            //  - x advanced num_k_block_tile_iteration tiles in pass 1 and steps
            //    back one tile;
            //  - gamma, beta and y have not moved yet and step forward
            //    (num_k_block_tile_iteration - 1) tiles, since
            //    ThreadBufferNumber * fwd step == K_BlockTileSize.
517 auto thread_copy_tail_m_k =
518 (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
519
520 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
521 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
522 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
523 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
524
            // Pass 2 (last tile first): re-read x, read gamma then beta for the
            // current tile, compute y = (x - mean) * inv_std * gamma + beta, and
            // store it.
525 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
526 {
528 threadwise_x_load.Run(x_grid_desc_m_k,
529 x_global_val_buf,
531 make_tuple(I0, I0),
532 x_thread_buf(i));
533 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
534 });
535
537 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
538 gamma_global_val_buf,
540 make_tuple(I0, I0),
541 gamma_thread_buf(i));
542
543 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
544 thread_copy_fwd_step_m_k);
545 });
546
549 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
550 constexpr auto offset_m_k =
551 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
552
553 // normalize
554 y_thread_buf(iK0)(Number<offset_m_k>{}) =
555 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
556 inv_std_thread_buf(iM);
557
558 // gamma
559 y_thread_buf(iK0)(Number<offset_m_k>{}) =
560 y_thread_buf(iK0)(Number<offset_m_k>{}) *
561 gamma_thread_buf(iK0)(Number<offset_m_k>{});
562 });
563 });
564 });
565
                // Beta overwrites gamma's registers (beta_thread_buf aliases
                // gamma_thread_buf); gamma is no longer needed for this tile.
567 threadwise_beta_load.Run(beta_grid_desc_m_k,
568 beta_global_val_buf,
570 make_tuple(I0, I0),
571 beta_thread_buf(i));
572 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
573 thread_copy_fwd_step_m_k);
574 });
575
578 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
579 constexpr auto offset_m_k =
580 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
581
582 // beta
583 y_thread_buf(iK0)(Number<offset_m_k>{}) =
584 y_thread_buf(iK0)(Number<offset_m_k>{}) +
585 beta_thread_buf(iK0)(Number<offset_m_k>{});
586 });
587 });
588 });
589
591 threadwise_y_store.Run(thread_buffer_desc_m_k,
592 make_tuple(I0, I0),
593 y_thread_buf(i),
594 y_grid_desc_m_k,
595 y_global_val_buf);
596 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
597 thread_copy_fwd_step_m_k);
598 });
599
                // Processing this tile advanced every window by one tile; stepping
                // back two tiles lands on the previous tile.
600 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
601 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
602 2 * thread_copy_bwd_step_m_k);
603 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
604 2 * thread_copy_bwd_step_m_k);
605 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
606 2 * thread_copy_bwd_step_m_k);
607 }
608 } // end of sweep twice
609 }
610};
611
612} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_normalization_naive_variance.hpp:42
static constexpr auto ThreadBufferNumber
Definition gridwise_normalization_naive_variance.hpp:108
PartitionedBlockwiseReduction< ComputeDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, reduce::Add, true > BlockwiseSumReduce
Definition gridwise_normalization_naive_variance.hpp:85
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_normalization_naive_variance.hpp:63
static constexpr auto I0
Definition gridwise_normalization_naive_variance.hpp:100
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_normalization_naive_variance.hpp:82
static constexpr auto thread_buffer_desc_m_k
Definition gridwise_normalization_naive_variance.hpp:73
static __device__ void Run(const GridDesc_M_K &x_grid_desc_m_k, const GridDesc_M_K &gamma_grid_desc_m_k, const GridDesc_M_K &beta_grid_desc_m_k, const GridDesc_M_K &y_grid_desc_m_k, const GridDesc_M &save_mean_grid_desc_m, const GridDesc_M &save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition gridwise_normalization_naive_variance.hpp:110
static constexpr auto I1
Definition gridwise_normalization_naive_variance.hpp:101
static constexpr index_t K_BlockTileSize
Definition gridwise_normalization_naive_variance.hpp:105
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< XSrcVectorSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_normalization_naive_variance.hpp:80
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_normalization_naive_variance.hpp:66
Sequence< MThreadSliceSize, XSrcVectorSize > ThreadBufferLengths_M_K
Definition gridwise_normalization_naive_variance.hpp:72
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_normalization_naive_variance.hpp:61
Sequence< MThreadSliceSize > ThreadBufferLengths_M
Definition gridwise_normalization_naive_variance.hpp:76
static constexpr bool reorder_thread_cluster
Definition gridwise_normalization_naive_variance.hpp:59
static constexpr auto thread_cluster_desc
Definition gridwise_normalization_naive_variance.hpp:69
static constexpr index_t M_BlockTileSize
Definition gridwise_normalization_naive_variance.hpp:104
ThreadwiseReduction< ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, reduce::Add, true > ThreadwiseSumReduce
Definition gridwise_normalization_naive_variance.hpp:92
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_normalization_naive_variance.hpp:98
static constexpr auto thread_buffer_desc_m
Definition gridwise_normalization_naive_variance.hpp:77
static constexpr index_t K_BlockTileStepSize
Definition gridwise_normalization_naive_variance.hpp:106
static constexpr auto I2
Definition gridwise_normalization_naive_variance.hpp:102
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, ComputeDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:173
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340