reduction_functions_blockwise.hpp Source File

reduction_functions_blockwise.hpp Source File

Composable Kernel: reduction_functions_blockwise.hpp Source File
reduction_functions_blockwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
11
12// clang-format off
13// Assume:
14// 1) work_buffer is a caller-allocated workspace buffer (typically LDS); it holds none of the input/output data
15// 2) work_buffer holds AccDataType elements and has room for at least BlockSize of them
16// 3) on entry, in_out_value is each thread's input value (held in a VGPR)
17// 4) on exit, in_out_value is overwritten with the reduced result of the thread's M row
18// clang-format on
// Block-wide reduction. The BlockSize threads form a BufferLength_M x BufferLength_K
// cluster: the lengths come from ThreadClusterLengths_M_K, and make_cluster_descriptor
// lays the cluster out according to ThreadClusterArrangeOrder. The K threads that share
// an M index combine their values pairwise with Accumulation::Calculate. Every thread
// of that M row then receives the combined value.
//
// NOTE(review): this listing drops several source lines (the numbering skips 26-27,
// 37-38, 46, 52 and 61). The dropped lines include the struct name, the default
// Accumulation type and the block_buf_desc_m_k definition. They presumably also include
// the thread-index computation and the static_for header that drives the reduction
// steps. Consult the real header before editing.
19template <typename AccDataType,
20 index_t BlockSize,
21 typename ThreadClusterLengths_M_K,
22 typename ThreadClusterArrangeOrder,
23 typename OpReduce,
24 bool PropagateNan,
25 typename Accumulation =
28{
29 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
30 "The product of cluster lengths should be same as BlockSize!");
31
32 static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
33 static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
34
    // A single K element per row would leave nothing to combine.
35 static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
36
39
40 static constexpr auto thread_cluster_desc =
41 make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
42
    // Reduces in_out_value across the K threads of each M row, using work_buffer as
    // scratch. On return, in_out_value holds the row's combined value in every thread.
43 template <typename BufferType>
44 __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
45 {
47 "Buffer data type should be consistent as AccDataType!");
48
        // NOTE(review): the halving steps below assume get_shift<N>() returns log2(N),
        // i.e. that BufferLength_K is a power of two. TODO confirm.
49 constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
50
51 const auto thread_cluster_idx =
53
54 const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
55 const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
56
        // Publish this thread's value to its (m, k) slot. The barrier ensures every slot
        // is written before any thread reads a neighbour's slot.
57 work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
58
59 __syncthreads();
60
        // One tree step per I (sequential addressing). The stride indOffset starts at
        // 2^(cluster_len_shift - 1) and halves with each increment of I. Threads with
        // k < indOffset fold slot (m, k + indOffset) into slot (m, k).
62 constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
63
64 if(thread_k_cluster_id < indOffset)
65 {
66 index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
67 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
68 make_tuple(0, indOffset));
69
70 AccDataType opData1 = work_buffer[offset1];
71 AccDataType opData2 = work_buffer[offset2];
72 Accumulation::Calculate(opData1, opData2);
73 work_buffer(offset1) = opData1;
74 }
75
            // The barrier sits outside the `if`, so every thread in the block reaches it.
            // It orders this step's writes before the next step's reads.
76 __syncthreads();
77 });
78
        // Assuming the steps end at stride 1, slot (m, 0) now holds row m's result, and
        // every thread reads it for its own row.
        // NOTE(review): no barrier follows this read. A caller that overwrites
        // work_buffer immediately after Reduce presumably needs its own __syncthreads().
79 index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
80
81 in_out_value = work_buffer[offset];
82 };
83};
84
85// clang-format off
86// Assume:
87// 1) work_buffer is a caller-allocated workspace buffer (typically LDS); it holds none of the input/output data
88// 2) work_buffer holds AccDataType elements and has room for at least BlockSize of them
89// 3) on entry, in_out_value is each thread's input value (held in a VGPR)
90// 4) on exit, in_out_value is overwritten with the reduced result of the thread's M row
91// clang-format on
// Block-wide reduction over a BufferLength_M x BufferLength_K thread cluster. This is the
// same algorithm as the arrange-order variant, but the caller supplies the cluster
// descriptor directly as ThreadClusterDesc. By default, pairs are combined with
// detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>.
//
// NOTE(review): this listing drops several source lines (the numbering skips 100,
// 110-111, 118, 124 and 133). The dropped lines include the struct name and the
// block_buf_desc_m_k definition. They presumably also include the thread-index
// computation and the static_for header. Consult the real header before editing.
92template <typename AccDataType,
93 index_t BlockSize,
94 typename ThreadClusterLengths_M_K,
95 typename ThreadClusterDesc,
96 typename OpReduce,
97 bool PropagateNan,
98 typename Accumulation =
99 detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
101{
102 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
103 "The product of cluster lengths should be same as BlockSize!");
104
105 static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
106 static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
107
    // A single K element per row would leave nothing to combine.
108 static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
109
112
    // Caller-provided mapping between thread ids and (m, k) cluster coordinates.
113 static constexpr auto thread_cluster_desc = ThreadClusterDesc{};
114
    // Reduces in_out_value across the K threads of each M row, using work_buffer as
    // scratch. On return, in_out_value holds the row's combined value in every thread.
115 template <typename BufferType>
116 __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
117 {
119 "Buffer data type should be consistent as AccDataType!");
120
        // NOTE(review): the halving steps below assume get_shift<N>() returns log2(N),
        // i.e. that BufferLength_K is a power of two. TODO confirm.
121 constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
122
123 const auto thread_cluster_idx =
125
126 const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
127 const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
128
        // Publish this thread's value to its (m, k) slot. The barrier ensures every slot
        // is written before any thread reads a neighbour's slot.
129 work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
130
131 __syncthreads();
132
        // One tree step per I (sequential addressing). The stride indOffset starts at
        // 2^(cluster_len_shift - 1) and halves with each increment of I. Threads with
        // k < indOffset fold slot (m, k + indOffset) into slot (m, k).
134 constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
135
136 if(thread_k_cluster_id < indOffset)
137 {
138 index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
139 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
140 make_tuple(0, indOffset));
141
142 AccDataType opData1 = work_buffer[offset1];
143 AccDataType opData2 = work_buffer[offset2];
144 Accumulation::Calculate(opData1, opData2);
145 work_buffer(offset1) = opData1;
146 }
147
            // The barrier sits outside the `if`, so every thread in the block reaches it.
            // It orders this step's writes before the next step's reads.
148 __syncthreads();
149 });
150
        // Assuming the steps end at stride 1, slot (m, 0) now holds row m's result, and
        // every thread reads it for its own row.
        // NOTE(review): no barrier follows this read. A caller that overwrites
        // work_buffer immediately after Reduce presumably needs its own __syncthreads().
151 index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
152
153 in_out_value = work_buffer[offset];
154 };
155};
156
157// clang-format off
158// Assume:
159// 1) work_val_buffer/work_idx_buffer are caller-allocated workspace buffers (typically LDS); they hold none of the input/output data
160// 2) work_val_buffer/work_idx_buffer hold AccDataType/IndexDataType elements and each has room for at least BlockSize of them
161// 3) on entry, in_out_value/in_out_index are each thread's input value and index (held in VGPRs)
162// 4) on exit, in_out_value/in_out_index are overwritten with the reduced value and its index for the thread's M row
163// clang-format on
// Block-wide reduction that carries an index alongside each value. The BlockSize threads
// form a BufferLength_M x BufferLength_K cluster: the lengths come from
// ThreadClusterLengths_M_K, and make_cluster_descriptor lays the cluster out according
// to ThreadClusterArrangeOrder. The K threads of each M row combine (value, index) pairs
// with Accumulation::Calculate. By default Calculate is
// detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>.
//
// NOTE(review): this listing drops several source lines (the numbering skips 174,
// 184-185, 197, 199, 205 and 215). The dropped lines include the struct name, the
// block_buf_desc_m_k definition and the static_assert openers. They presumably also
// include the thread-index computation and the static_for header. Consult the real
// header before editing.
164template <
165 typename AccDataType,
166 typename IndexDataType,
167 index_t BlockSize,
168 typename ThreadClusterLengths_M_K,
169 typename ThreadClusterArrangeOrder,
170 typename OpReduce,
171 bool PropagateNan,
172 typename Accumulation =
173 detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
175{
176 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
177 "The product of cluster lengths should be same as BlockSize!");
178
179 static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
180 static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
181
    // A single K element per row would leave nothing to combine.
182 static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
183
186
187 static constexpr auto thread_cluster_desc =
188 make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
189
190 // Reduces each value together with its accompanying index; both results are returned
191 template <typename BufferType, typename IdxBufferType>
192 __device__ static void Reduce(BufferType& work_val_buffer,
193 IdxBufferType& work_idx_buffer,
194 AccDataType& in_out_value,
195 IndexDataType& in_out_index)
196 {
198 "Buffer data type should be consistent as AccDataType!");
200 "Buffer data type should be consistent as IndexDataType!");
201
        // NOTE(review): presumably get_shift<N>() returns log2(N), so the step count
        // covers all of BufferLength_K only when it is a power of two. TODO confirm.
202 constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
203
204 const auto thread_cluster_idx =
206
207 const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
208 const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
209
        // Publish this thread's value and index to the same (m, k) slot of both buffers.
        // The barrier ensures every slot is written before any thread reads a neighbour's.
210 work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
211 work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;
212
213 __syncthreads();
214
        // One tree step per I (interleaved addressing). The stride indOffset starts at 1
        // and doubles with each increment of I. Threads whose k is a multiple of
        // 2 * indOffset fold slot (m, k + indOffset) into slot (m, k). The first operand
        // of Calculate therefore always comes from the lower K position.
        // NOTE(review): presumably this ordering is what keeps index tie-breaking
        // deterministic. Confirm against AccumulateWithIndexAndNanCheck.
216 constexpr index_t indOffset = 1 << I();
217
218 if(thread_k_cluster_id % (indOffset * 2) == 0)
219 {
220 index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
221 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
222 make_tuple(0, indOffset));
223
224 AccDataType opData1 = work_val_buffer[offset1];
225 AccDataType opData2 = work_val_buffer[offset2];
226 IndexDataType currIndex1 = work_idx_buffer[offset1];
227 IndexDataType currIndex2 = work_idx_buffer[offset2];
228
                // Calculate updates opData1/currIndex1 in place. Both are written back
                // to slot (m, k).
229 Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
230 work_val_buffer(offset1) = opData1;
231 work_idx_buffer(offset1) = currIndex1;
232 }
233
            // The barrier sits outside the `if`, so every thread in the block reaches it.
            // It orders this step's writes before the next step's reads.
234 __syncthreads();
235 });
236
        // After the final step, slot (m, 0) holds row m's value and index, and every
        // thread reads both for its own row.
        // NOTE(review): no barrier follows these reads. A caller that overwrites either
        // work buffer immediately after Reduce presumably needs its own __syncthreads().
237 index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
238
239 in_out_value = work_val_buffer[offset];
240 in_out_index = work_idx_buffer[offset];
241 };
242};
243
244} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Definition reduction_functions_blockwise.hpp:101
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:116
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition reduction_functions_blockwise.hpp:175
static constexpr auto BufferLength_K
Definition reduction_functions_blockwise.hpp:180
static constexpr auto block_buf_desc_m_k
Definition reduction_functions_blockwise.hpp:184
static constexpr auto thread_cluster_desc
Definition reduction_functions_blockwise.hpp:187
static __device__ void Reduce(BufferType &work_val_buffer, IdxBufferType &work_idx_buffer, AccDataType &in_out_value, IndexDataType &in_out_index)
Definition reduction_functions_blockwise.hpp:192
static constexpr auto BufferLength_M
Definition reduction_functions_blockwise.hpp:179
Definition reduction_functions_accumulate.hpp:28
Definition type.hpp:177
Definition functional2.hpp:33