gridwise_normalization_selector.hpp Source File

gridwise_normalization_selector.hpp Source File

Composable Kernel: gridwise_normalization_selector.hpp Source File
gridwise_normalization_selector.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
#pragma once

#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp"
9namespace ck {
10template <typename GridwiseReduction,
11 typename XDataType,
12 typename GammaDataType,
13 typename BetaDataType,
14 typename YDataType,
15 typename SaveMeanInvStdDataType,
16 typename ComputeDataType,
17 typename YElementwiseOperation,
18 typename GridDesc_M_K,
19 typename GridDesc_M>
20__global__ void
21kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
22 const GridDesc_M_K gamma_grid_desc_m_k,
23 const GridDesc_M_K beta_grid_desc_m_k,
24 const GridDesc_M_K y_grid_desc_m_k,
25 const GridDesc_M save_mean_grid_desc_m,
26 const GridDesc_M save_inv_std_grid_desc_m,
27 index_t num_k_block_tile_iteration,
28 ComputeDataType epsilon,
29 const XDataType* const __restrict__ p_x_global,
30 const GammaDataType* const __restrict__ p_gamma_global,
31 const BetaDataType* const __restrict__ p_beta_global,
32 YDataType* const __restrict__ p_y_global,
33 SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
34 SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
35 const YElementwiseOperation y_elementwise_op)
36{
37 GridwiseReduction::Run(x_grid_desc_m_k,
38 gamma_grid_desc_m_k,
39 beta_grid_desc_m_k,
40 y_grid_desc_m_k,
41 save_mean_grid_desc_m,
42 save_inv_std_grid_desc_m,
43 num_k_block_tile_iteration,
44 epsilon,
45 p_x_global,
46 p_gamma_global,
47 p_beta_global,
48 p_y_global,
49 p_save_mean_global,
50 p_save_inv_std_global,
51 y_elementwise_op);
52};
53
54template <typename XDataType,
55 typename GammaDataType,
56 typename BetaDataType,
57 typename YDataType,
58 typename SaveMeanInvStdDataType,
59 typename ComputeDataType,
60 typename YElementwiseOperation,
61 typename GridDesc_M_K,
62 typename GridDesc_M,
63 index_t BlockSize,
64 index_t MThreadClusterSize,
65 index_t KThreadClusterSize,
66 index_t MThreadSliceSize,
67 index_t KThreadSliceSize,
68 index_t XSrcVectorDim,
69 index_t XSrcVectorSize,
70 index_t GammaSrcVectorDim,
71 index_t GammaSrcVectorSize,
72 index_t BetaSrcVectorDim,
73 index_t BetaSrcVectorSize,
74 index_t YDstVectorDim,
75 index_t YDstVectorSize,
76 index_t SaveMeanInvStdDstVectorSize,
77 bool UseWelford>
78auto NormalizationKernelSelector(bool isSweepOnce)
79{
80 using GridwiseNormalizationGenericNaive =
82 GammaDataType,
83 BetaDataType,
84 YDataType,
85 SaveMeanInvStdDataType,
86 ComputeDataType,
87 YElementwiseOperation,
88 GridDesc_M_K,
89 GridDesc_M,
90 BlockSize,
91 MThreadClusterSize,
92 KThreadClusterSize,
93 MThreadSliceSize,
94 KThreadSliceSize,
95 XSrcVectorDim,
96 XSrcVectorSize,
97 GammaSrcVectorDim,
98 GammaSrcVectorSize,
99 BetaSrcVectorDim,
100 BetaSrcVectorSize,
101 YDstVectorDim,
102 YDstVectorSize,
103 SaveMeanInvStdDstVectorSize,
104 false>;
105 using GridwiseNormalizationSweepOnceNaive =
107 GammaDataType,
108 BetaDataType,
109 YDataType,
110 SaveMeanInvStdDataType,
111 ComputeDataType,
112 YElementwiseOperation,
113 GridDesc_M_K,
114 GridDesc_M,
115 BlockSize,
116 MThreadClusterSize,
117 KThreadClusterSize,
118 MThreadSliceSize,
119 KThreadSliceSize,
120 XSrcVectorDim,
121 XSrcVectorSize,
122 GammaSrcVectorDim,
123 GammaSrcVectorSize,
124 BetaSrcVectorDim,
125 BetaSrcVectorSize,
126 YDstVectorDim,
127 YDstVectorSize,
128 SaveMeanInvStdDstVectorSize,
129 true>;
130 using GridwiseNormalizationGenericWelford =
132 GammaDataType,
133 BetaDataType,
134 YDataType,
135 SaveMeanInvStdDataType,
136 ComputeDataType,
137 YElementwiseOperation,
138 GridDesc_M_K,
139 GridDesc_M,
140 BlockSize,
141 MThreadClusterSize,
142 KThreadClusterSize,
143 MThreadSliceSize,
144 KThreadSliceSize,
145 XSrcVectorDim,
146 XSrcVectorSize,
147 GammaSrcVectorDim,
148 GammaSrcVectorSize,
149 BetaSrcVectorDim,
150 BetaSrcVectorSize,
151 YDstVectorDim,
152 YDstVectorSize,
153 SaveMeanInvStdDstVectorSize,
154 false>;
155 using GridwiseNormalizationSweepOnceWelford =
157 GammaDataType,
158 BetaDataType,
159 YDataType,
160 SaveMeanInvStdDataType,
161 ComputeDataType,
162 YElementwiseOperation,
163 GridDesc_M_K,
164 GridDesc_M,
165 BlockSize,
166 MThreadClusterSize,
167 KThreadClusterSize,
168 MThreadSliceSize,
169 KThreadSliceSize,
170 XSrcVectorDim,
171 XSrcVectorSize,
172 GammaSrcVectorDim,
173 GammaSrcVectorSize,
174 BetaSrcVectorDim,
175 BetaSrcVectorSize,
176 YDstVectorDim,
177 YDstVectorSize,
178 SaveMeanInvStdDstVectorSize,
179 true>;
180
181 if constexpr(UseWelford)
182 {
183 return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceWelford,
184 XDataType,
185 GammaDataType,
186 BetaDataType,
187 YDataType,
188 SaveMeanInvStdDataType,
189 ComputeDataType,
190 YElementwiseOperation,
191 GridDesc_M_K,
192 GridDesc_M>
193 : kernel_normalization<GridwiseNormalizationGenericWelford,
194 XDataType,
195 GammaDataType,
196 BetaDataType,
197 YDataType,
198 SaveMeanInvStdDataType,
199 ComputeDataType,
200 YElementwiseOperation,
201 GridDesc_M_K,
202 GridDesc_M>;
203 }
204 else
205 {
206 return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceNaive,
207 XDataType,
208 GammaDataType,
209 BetaDataType,
210 YDataType,
211 SaveMeanInvStdDataType,
212 ComputeDataType,
213 YElementwiseOperation,
214 GridDesc_M_K,
215 GridDesc_M>
216 : kernel_normalization<GridwiseNormalizationGenericNaive,
217 XDataType,
218 GammaDataType,
219 BetaDataType,
220 YDataType,
221 SaveMeanInvStdDataType,
222 ComputeDataType,
223 YElementwiseOperation,
224 GridDesc_M_K,
225 GridDesc_M>;
226 }
227}
228
229} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M save_mean_grid_desc_m, const GridDesc_M save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition gridwise_normalization_selector.hpp:21
auto NormalizationKernelSelector(bool isSweepOnce)
Definition gridwise_normalization_selector.hpp:78
GridwiseNormalizationNaiveVariance_mk_to_mk
Definition gridwise_normalization_naive_variance.hpp:42
GridwiseNormalizationWelfordVariance_mk_to_mk
Definition gridwise_normalization_welford_variance.hpp:40