welford_helper.hpp Source File

welford_helper.hpp Source File#

Composable Kernel: welford_helper.hpp Source File
welford_helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6namespace ck {
7namespace tensor_operation {
8namespace device {
9
10template <index_t K_BlockTileSize, index_t KThreadSliceSize>
12{
14 long_index_t reduce_length)
15 : numBlockTileIteration_{numBlockTileIteration}
16 {
17 count_in_last_tile_ = reduce_length % K_BlockTileSize;
18 };
19
20 __device__ index_t operator()(index_t thread_k_cluster_id) const
21 {
22 if(count_in_last_tile_ == 0)
23 return (KThreadSliceSize * numBlockTileIteration_);
24 else
25 {
26 index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize;
27 index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize;
28
29 if(thread_k_cluster_id < num_complete_slice)
30 return (KThreadSliceSize * numBlockTileIteration_);
31 else if(thread_k_cluster_id == num_complete_slice)
32 return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice);
33 else
34 return (KThreadSliceSize * (numBlockTileIteration_ - 1));
35 };
36 };
37
40};
41
42template <index_t K_BlockTileSize, index_t KThreadSliceSize>
44{
46 index_t numBlockTileIteration,
47 long_index_t reduce_length)
48 : blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration}
49 {
51 reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1);
53 (last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
54 };
55
56 __device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
57 {
58 if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ ||
59 block_local_id < blkGroupSize_ - 1)
60 return (KThreadSliceSize * numBlockTileIteration_);
61
62 index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize;
63
64 if(count_in_last_tile == 0)
65 return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
66 else
67 {
68 index_t num_complete_slice = count_in_last_tile / KThreadSliceSize;
69
70 if(thread_k_cluster_id < num_complete_slice)
71 return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
72 else if(thread_k_cluster_id == num_complete_slice)
73 return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) +
74 count_in_last_tile);
75 else
76 return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1));
77 };
78 };
79
82
85};
86
87} // namespace device
88} // namespace tensor_operation
89} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
int64_t long_index_t
Definition ck.hpp:300
__device__ index_t operator()(index_t thread_k_cluster_id) const
Definition welford_helper.hpp:20
GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration, long_index_t reduce_length)
Definition welford_helper.hpp:13
__device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
Definition welford_helper.hpp:56
index_t numBlockTileIterationByLastBlock_
Definition welford_helper.hpp:84
GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize, index_t numBlockTileIteration, long_index_t reduce_length)
Definition welford_helper.hpp:45