gridwise_elementwise_layernorm_welford_variance.hpp Source File

gridwise_elementwise_layernorm_welford_variance.hpp Source File

Composable Kernel: gridwise_elementwise_layernorm_welford_variance.hpp Source File
gridwise_elementwise_layernorm_welford_variance.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15// X = Elementwise(input1, input2, input3, ...)
16// Y = Normalization(X, beta, gamma)
// Compile-time configuration of the fused "elementwise then layernorm" gridwise kernel.
// NOTE(review): this listing jumps from line 40 to line 42, so the declaration line that names
// the struct is not visible here; several other numbered lines below are also absent.
//
// Parameter groups, as they are used further down in this file:
//   - InDataTypePointerTuple / InGrid2dDescTuple: pointers and 2-D descriptors of the inputs that
//     x_elementwise_op combines into x.
//   - X/Gamma/Beta/Y DataType: element types of the cached x, the gamma and beta operands and
//     the output y; AccDataType is the type of the per-thread buffers and of mean/variance.
//   - M/K ThreadClusterSize and M/K ThreadSliceSize: define the block tile sizes below.
//   - *VectorDim / *VectorSize: vectorised dimension (0 or 1, per the static_asserts below) and
//     vector width of each tensor access.
//   - SweepOnce: when true, Run forces a single K-tile iteration and skips the x cache.
17template <typename InDataTypePointerTuple,
18 typename XDataType,
19 typename GammaDataType,
20 typename BetaDataType,
21 typename YDataType,
22 typename AccDataType,
23 typename XElementwiseOperation,
24 typename YElementwiseOperation,
25 typename InGrid2dDescTuple,
26 typename GridDesc_M_K,
27 index_t BlockSize,
28 index_t MThreadClusterSize,
29 index_t KThreadClusterSize,
30 index_t MThreadSliceSize,
31 index_t KThreadSliceSize,
32 index_t XSrcVectorDim,
33 index_t XSrcVectorSize,
34 index_t GammaSrcVectorDim,
35 index_t GammaSrcVectorSize,
36 index_t BetaSrcVectorDim,
37 index_t BetaSrcVectorSize,
38 index_t YDstVectorDim,
39 index_t YDstVectorSize,
40 bool SweepOnce>
42{
// The thread slice along the vectorised x dimension must be a whole number of x vectors.
43 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
44 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
45 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
46
// Same divisibility requirement for the y store.
// NOTE(review): no equivalent static_assert for the gamma/beta vector sizes is visible in this
// listing, although they are used as divisors on lines 88-89 — confirm against the original file.
47 static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
48 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
49 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
50
// Number of entries in the input pointer tuple.
51 static constexpr index_t NumInput = InDataTypePointerTuple::Size();
52
// True when x is vectorised along dimension 0 (M).
53 static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
54
// NOTE(review): lines 55-78 are mostly absent from this listing; only the start of the
// thread_cluster_desc definition and a stray "BlockSize," argument survive.
56
59
62
63 static constexpr auto thread_cluster_desc =
65
70
73
75 BlockSize,
78
// Compile-time index constants.
79 static constexpr auto I0 = Number<0>{};
80 static constexpr auto I1 = Number<1>{};
81 static constexpr auto I2 = Number<2>{};
82
// Extent of one block tile in M and K, and the K distance between two consecutive vector
// chunks of the same thread (one x vector per thread of the K cluster).
83 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
84 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
85 static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
86
// Number of vector-wide chunks that make up one thread's K slice, per tensor.
87 static constexpr auto XThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
88 static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / GammaSrcVectorSize>{};
89 static constexpr auto BetaThreadBufferNumber = Number<KThreadSliceSize / BetaSrcVectorSize>{};
90 static constexpr auto YThreadBufferNumber = Number<KThreadSliceSize / YDstVectorSize>{};
91
// Counts how many K elements the thread at position thread_k_cluster_id of the K cluster
// covers: KThreadSliceSize for every full block tile, plus the in-range part of its vector
// chunks in the last, partial tile. The only visible caller (Run) assigns the result to
// threadwise_welford.max_count_.
//
// x_grid_desc_m_k      descriptor the K length is read from
// thread_k_cluster_id  this thread's index along K inside the thread cluster
// returns              element count for this thread
92 __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
93 int thread_k_cluster_id)
94 {
// K length, taken from the upper lengths of the descriptor's transform at index 2.
// NOTE(review): this depends on how GridDesc_M_K was built (which transform sits at index 2);
// confirm against the code that creates the descriptor.
95 int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
// Contribution of the full tiles: KThreadSliceSize elements per complete K_BlockTileSize tile.
96 int kPerThread =
97 kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
// Remaining K length after the full tiles
// (kPerThread * KThreadClusterSize == full tiles * K_BlockTileSize).
98 int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
99
100 if(kPerBlockTail > 0)
101 {
// NOTE(review): line 102 is absent from this listing. The use of `i` below and the closing
// "});" on line 108 imply a lambda-based loop over the thread's vector chunks was opened
// there — confirm against the original file.
// thread_max_len: exclusive end, within the tail tile, of this thread's i-th vector chunk.
103 int thread_max_len =
104 (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
// delta: how many elements of the chunk lie past the tail, limited to [0, XSrcVectorSize].
// The unclamped value assigned on line 105 is overwritten straight away by line 106.
105 int delta = thread_max_len - kPerBlockTail;
106 delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
// Only the in-range part of the chunk is counted.
107 kPerThread += XSrcVectorSize - delta;
108 });
109 }
110
111 return kPerThread;
112 }
113
// Device entry point of the kernel: builds x from the inputs with x_elementwise_op, computes
// mean and variance of x with Welford's method, then writes
//     y = (x - mean) / sqrt(var + epsilon) * gamma + beta
// through threadwise_y_store, which is constructed with y_elementwise_op.
//
// Two passes over the K tiles:
//   pass 1 (forward)  - load inputs, build x, update the per-thread Welford state and, unless
//                       SweepOnce, store x into x_lds_val_buf;
//   pass 2 (backward) - unless SweepOnce, reload x from x_lds_val_buf; load gamma and beta,
//                       normalise and store y. The windows are first moved to the last tile
//                       and stepped back by one tile per iteration.
//
// Parameters:
//   in_grid_2d_desc_tuple / p_in_global_tuple  descriptors and pointers of the inputs
//   x/gamma/beta/y_grid_desc_m_k               descriptors of the cached x, gamma, beta and y
//   num_k_block_tile_iteration                 number of K tiles per pass (forced to 1 when SweepOnce)
//   epsilon                                    added to the variance before the square root
//   p_x_lds_                                   storage behind x_lds_val_buf, the x cache
//                                              (presumably LDS, judging by the name — confirm)
//   p_gamma_global, p_beta_global, p_y_global  global-memory pointers of gamma, beta and y
//   x_elementwise_op, y_elementwise_op         functors applied when forming x and when storing y
//
// NOTE(review): many numbered lines are absent from this listing (e.g. 143, 151, 339, 346,
// 395-397, 416, 426, 436, 456, 466, 480). Most of them appear to open the lambda-based loops
// whose bodies and closing "});" remain — confirm every such spot against the original file.
114 __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
115 const GridDesc_M_K& x_grid_desc_m_k,
116 const GridDesc_M_K& gamma_grid_desc_m_k,
117 const GridDesc_M_K& beta_grid_desc_m_k,
118 const GridDesc_M_K& y_grid_desc_m_k,
119 index_t num_k_block_tile_iteration,
120 AccDataType epsilon,
121 const InDataTypePointerTuple p_in_global_tuple,
122 XDataType* const __restrict__ p_x_lds_,
123 const GammaDataType* const __restrict__ p_gamma_global,
124 const BetaDataType* const __restrict__ p_beta_global,
125 YDataType* const __restrict__ p_y_global,
126 const XElementwiseOperation x_elementwise_op,
127 const YElementwiseOperation y_elementwise_op)
128 {
// With SweepOnce the caller-supplied iteration count is ignored.
129 if constexpr(SweepOnce)
130 {
131 num_k_block_tile_iteration = 1;
132 }
133
134 const index_t thread_local_id = get_thread_local_1d_id();
135 const index_t block_global_id = get_block_1d_id();
136 const index_t grid_size = get_grid_size();
137
// One buffer view per input, sized by that input's descriptor.
138 auto in_global_buf_tuple = generate_tuple(
139 [&](auto I) {
140 static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() ==
141 2); // matrix dimension
142
144 p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
145 },
147
148 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
149 p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
150
// NOTE(review): line 151 is absent; presumably it starts the creation of x_lds_val_buf, which
// is used later. The size passed is the x element space divided by the grid size.
152 p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
153
// Per-thread buffers of AccDataType; each holds MThreadSliceSize * <vector size> elements.
// The lines naming the buffer type (158, 169, 178, 187, 196) are absent from this listing.
154 auto in_thread_buf_tuple = generate_tuple(
155 [&](auto) {
156 return generate_tuple(
157 [&](auto) {
159 AccDataType,
160 MThreadSliceSize * XSrcVectorSize,
161 true>{};
162 },
164 },
166
167 auto x_thread_buf = generate_tuple(
168 [&](auto) {
170 AccDataType,
171 MThreadSliceSize * XSrcVectorSize,
172 true>{};
173 },
175
176 auto gamma_thread_buf = generate_tuple(
177 [&](auto) {
179 AccDataType,
180 MThreadSliceSize * GammaSrcVectorSize,
181 true>{};
182 },
184
185 auto beta_thread_buf = generate_tuple(
186 [&](auto) {
188 AccDataType,
189 MThreadSliceSize * BetaSrcVectorSize,
190 true>{};
191 },
193
194 auto y_thread_buf = generate_tuple(
195 [&](auto) {
197 AccDataType,
198 MThreadSliceSize * YDstVectorSize,
199 true>{};
200 },
202
// NOTE(review): the declarations of mean_thread_buf and var_thread_buf (used from line 340 on)
// are not visible; they presumably sat on the absent lines 203-204.
205
// Position of this thread inside the (M, K) thread cluster.
206 const auto thread_cluster_idx =
207 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
208
209 const auto thread_m_cluster_id = thread_cluster_idx[I0];
210 const auto thread_k_cluster_id = thread_cluster_idx[I1];
211
// Every transfer below moves one MThreadSliceSize x XSrcVectorSize slice per call.
// NOTE(review): the same slice shape is used for gamma, beta and y, whose start offsets use
// their own vector sizes; this looks consistent only when those sizes equal XSrcVectorSize —
// confirm.
212 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
213
214 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
216
// Input loaders: start at this block's M tile and this thread's first K vector.
217 auto in_global_load_tuple = generate_tuple(
218 [&](auto I) {
219 using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
221
222 return ThreadwiseTensorSliceTransfer_v2<DataType,
223 AccDataType,
224 decltype(in_grid_2d_desc_tuple[I]),
225 decltype(thread_buffer_desc_m_k),
226 ThreadBufferLengths_M_K,
228 XSrcVectorDim,
229 XSrcVectorSize,
230 1,
231 false>{
232 in_grid_2d_desc_tuple[I],
233 make_multi_index(block_global_id * M_BlockTileSize +
234 thread_m_cluster_id * MThreadSliceSize,
235 thread_k_cluster_id * XSrcVectorSize)};
236 },
238
// Loader for the cached x. Unlike the input/gamma/beta/y transfers, its M origin has no
// block_global_id offset.
239 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
240 AccDataType,
241 GridDesc_M_K,
242 decltype(thread_buffer_desc_m_k),
243 ThreadBufferLengths_M_K,
245 XSrcVectorDim,
246 XSrcVectorSize,
247 1,
248 true>(
249 x_grid_desc_m_k,
250 make_multi_index(thread_m_cluster_id * MThreadSliceSize,
251 thread_k_cluster_id * XSrcVectorSize));
252
253 auto threadwise_gamma_load =
255 AccDataType,
256 GridDesc_M_K,
257 decltype(thread_buffer_desc_m_k),
258 ThreadBufferLengths_M_K,
260 GammaSrcVectorDim,
261 GammaSrcVectorSize,
262 1,
263 true>(
264 gamma_grid_desc_m_k,
265 make_multi_index(block_global_id * M_BlockTileSize +
266 thread_m_cluster_id * MThreadSliceSize,
267 thread_k_cluster_id * GammaSrcVectorSize));
268
269 auto threadwise_beta_load =
271 AccDataType,
272 GridDesc_M_K,
273 decltype(thread_buffer_desc_m_k),
274 ThreadBufferLengths_M_K,
276 BetaSrcVectorDim,
277 BetaSrcVectorSize,
278 1,
279 true>(
280 beta_grid_desc_m_k,
281 make_multi_index(block_global_id * M_BlockTileSize +
282 thread_m_cluster_id * MThreadSliceSize,
283 thread_k_cluster_id * BetaSrcVectorSize));
284
// Writer for the x cache; x is stored unchanged (pass_through_op), at the same block-local
// origin that threadwise_x_load reads from.
286 PassThrough pass_through_op;
287 auto threadwise_x_store =
289 XDataType,
290 decltype(thread_buffer_desc_m_k),
291 GridDesc_M_K,
293 ThreadBufferLengths_M_K,
295 XSrcVectorDim,
296 XSrcVectorSize,
298 1,
299 true>(
300 x_grid_desc_m_k,
301 make_multi_index(thread_m_cluster_id * MThreadSliceSize,
302 thread_k_cluster_id * XSrcVectorSize),
303 pass_through_op);
304
// Writer for y; constructed with y_elementwise_op.
305 auto threadwise_y_store =
307 YDataType,
308 decltype(thread_buffer_desc_m_k),
309 GridDesc_M_K,
310 YElementwiseOperation,
311 ThreadBufferLengths_M_K,
313 YDstVectorDim,
314 YDstVectorSize,
316 1,
317 true>(
318 y_grid_desc_m_k,
319 make_multi_index(block_global_id * M_BlockTileSize +
320 thread_m_cluster_id * MThreadSliceSize,
321 thread_k_cluster_id * YDstVectorSize),
322 y_elementwise_op);
323
// Window steps along K: forward by one vector chunk (K_BlockTileStepSize); backward by one
// whole block tile, or not at all when SweepOnce.
324 // Copy x from Cache
325 // one pass: fwd, second pass: bwd
326 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
327 constexpr auto thread_copy_bwd_step_m_k =
328 make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
329
330 const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
331 p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
332
333 const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
334 p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
335
// Per-thread Welford accumulator; max_count_ is set to the number of K elements this
// thread covers, as counted by GetKPerThread.
336 auto threadwise_welford = ThreadwiseWelford();
337 threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
338
// Zero-initialise mean and variance (the loop header on line 339 is absent from this listing).
340 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
341 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
342 });
343
// ---- Pass 1: build x and accumulate the Welford statistics, tile by tile. ----
344 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
345 {
// NOTE(review): line 346 is absent; the `iK0` index used below and the "});" on line 392
// imply an outer loop over the thread's vector chunks starts there.
347 static_for<0, NumInput, 1>{}([&](auto I) { // input load loop
348 in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
349 in_global_buf_tuple[I],
350 thread_buffer_desc_m_k,
351 make_tuple(I0, I0),
352 in_thread_buf_tuple(iK0)(I));
353
354 in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
355 thread_copy_fwd_step_m_k);
356 });
357
// Apply x_elementwise_op element by element: all inputs at offset_m_k in, x at offset_m_k out.
358 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop
359 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
360 constexpr auto offset_m_k =
361 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
362
363 // get reference to in data
364 const auto in_data_refs = generate_tie(
365 // return type should be lvalue
366 [&](auto I) -> const auto& {
367 return in_thread_buf_tuple(iK0)(I)(Number<offset_m_k>{});
368 },
370
371 // get reference to dst data
372 auto out_data_refs = generate_tie(
373 // return type should be lvalue
374 [&](auto) -> auto& { return x_thread_buf(iK0)(Number<offset_m_k>{}); },
375 I1);
376
377 unpack2(x_elementwise_op, out_data_refs, in_data_refs);
378 });
379 });
// Fold this chunk of x into the running mean/variance.
380 threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);
381
// Cache x for the second pass; not needed when the data is swept only once.
382 if constexpr(!SweepOnce)
383 {
384 threadwise_x_store.Run(thread_buffer_desc_m_k,
385 make_tuple(I0, I0),
386 x_thread_buf(iK0),
387 x_grid_desc_m_k,
388 x_lds_val_buf);
389 threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k,
390 thread_copy_fwd_step_m_k);
391 }
392 });
393 }
394
// Combine the per-thread statistics through BlockwiseWelford, one (mean, var) pair per index I.
// NOTE(review): lines 395 and 397 are absent; the statement guarded by `if constexpr(I > 0)`
// is not visible (block_sync_lds appears in the page's symbol index, so it is presumably a
// barrier between successive reductions — confirm).
396 if constexpr(I > 0)
398
399 int count = threadwise_welford.cur_count_;
400 BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
401 });
402
// Offset from the first to the last K tile: (iterations - 1) tiles, each made of
// XThreadBufferNumber forward steps. Pass 2 starts from there.
403 auto thread_copy_tail_m_k =
404 (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;
405
406 if constexpr(!SweepOnce)
407 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k);
408 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
409 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
410 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
411
// ---- Pass 2: normalise and write y, from the last tile back to the first. ----
412 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
413 {
// Reload x from the cache. With SweepOnce, x_thread_buf still holds the values from pass 1.
414 if constexpr(!SweepOnce)
415 {
417 threadwise_x_load.Run(x_grid_desc_m_k,
418 x_lds_val_buf,
419 thread_buffer_desc_m_k,
420 make_tuple(I0, I0),
421 x_thread_buf(i));
422 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
423 });
424 }
425
// Load gamma for this tile.
// NOTE(review): the gamma, beta and y windows advance by the x-based step
// (KThreadClusterSize * XSrcVectorSize) — consistent only if their vector sizes match x; confirm.
427 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
428 gamma_global_val_buf,
429 thread_buffer_desc_m_k,
430 make_tuple(I0, I0),
431 gamma_thread_buf(i));
432 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
433 thread_copy_fwd_step_m_k);
434 });
435
// Reciprocal standard deviation for row iM (the enclosing loop headers are on absent lines).
437 auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
439 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
440 constexpr auto offset_m_k =
441 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
442
443 // normalize
444 y_thread_buf(iK0)(Number<offset_m_k>{}) =
445 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
446 divisor;
447
448 // gamma
449 y_thread_buf(iK0)(Number<offset_m_k>{}) =
450 y_thread_buf(iK0)(Number<offset_m_k>{}) *
451 gamma_thread_buf(iK0)(Number<offset_m_k>{});
452 });
453 });
454 });
455
// Load beta for this tile.
457 threadwise_beta_load.Run(beta_grid_desc_m_k,
458 beta_global_val_buf,
459 thread_buffer_desc_m_k,
460 make_tuple(I0, I0),
461 beta_thread_buf(i));
462 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
463 thread_copy_fwd_step_m_k);
464 });
465
468 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
469 constexpr auto offset_m_k =
470 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
471
472 // beta
473 y_thread_buf(iK0)(Number<offset_m_k>{}) =
474 y_thread_buf(iK0)(Number<offset_m_k>{}) +
475 beta_thread_buf(iK0)(Number<offset_m_k>{});
476 });
477 });
478 });
479
// Write this tile of y.
481 threadwise_y_store.Run(thread_buffer_desc_m_k,
482 make_tuple(I0, I0),
483 y_thread_buf(i),
484 y_grid_desc_m_k,
485 y_global_val_buf);
486 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
487 });
488
// The windows appear to have advanced one tile while processing it (the chunk-loop headers
// are on absent lines); stepping back two tiles lands on the start of the previous tile.
// With SweepOnce the backward step is zero.
489 if constexpr(!SweepOnce)
490 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
491 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
492 2 * thread_copy_bwd_step_m_k);
493 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
494 2 * thread_copy_bwd_step_m_k);
495 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
496 }
497 }
498};
499
500} // namespace ck
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition utility/math.hpp:148
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_elementwise_layernorm_welford_variance.hpp:42
static __device__ void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, const GridDesc_M_K &x_grid_desc_m_k, const GridDesc_M_K &gamma_grid_desc_m_k, const GridDesc_M_K &beta_grid_desc_m_k, const GridDesc_M_K &y_grid_desc_m_k, index_t num_k_block_tile_iteration, AccDataType epsilon, const InDataTypePointerTuple p_in_global_tuple, XDataType *const __restrict__ p_x_lds_, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, const XElementwiseOperation x_elementwise_op, const YElementwiseOperation y_elementwise_op)
Definition gridwise_elementwise_layernorm_welford_variance.hpp:114
static __device__ int GetKPerThread(const GridDesc_M_K &x_grid_desc_m_k, int thread_k_cluster_id)
Definition gridwise_elementwise_layernorm_welford_variance.hpp:92
Definition multi_index_transform.hpp:13
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340