thread_group_tensor_slice_transfer_v4r2.hpp Source File

thread_group_tensor_slice_transfer_v4r2.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v4r2.hpp Source File
thread_group_tensor_slice_transfer_v4r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
23template <typename ThreadGroup,
24 typename ElementwiseOperation,
25 typename DstInMemOps, // Sequence
26 typename BlockSliceLengths,
27 typename ThreadClusterLengths,
28 typename ThreadClusterArrangeOrder,
29 typename SrcDatas,
30 typename DstDatas,
31 typename SrcDescs,
32 typename DstDescs,
33 typename SrcDimAccessOrder,
34 typename DstDimAccessOrder,
35 index_t SrcVectorDim,
36 index_t DstVectorDim,
37 typename SrcsScalarPerVector, // Sequence
38 typename DstsScalarPerVector, // Sequence
39 typename SrcsScalarStrideInVector, // Sequence
40 typename DstsScalarStrideInVector, // Sequence
41 typename ThreadTransferSrcsResetCoordinateAfterRun, // Sequence
42 typename ThreadTransferDstsResetCoordinateAfterRun, // Sequence
43 index_t NumThreadScratch = 1>
45{
46 static constexpr index_t nDim =
// Counts of source and destination entries, as reported by SrcDescs::Size() and DstDescs::Size().
48 static constexpr index_t nSrc = SrcDescs::Size();
49 static constexpr index_t nDst = DstDescs::Size();
50
// Slice lengths handled by one thread: BlockSliceLengths divided by ThreadClusterLengths.
// The constructor static_asserts that multiplying this back by ThreadClusterLengths
// yields exactly BlockSliceLengths.
51 static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
52
54
56 const SrcDescs& src_descs,
57 const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
58 const DstDescs& dst_descs,
59 const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
60 const ElementwiseOperation& element_op)
61 : threadwise_transfer_(src_descs,
63 dst_descs,
65 element_op)
66
67 {
68 static_assert(nDim == ThreadClusterLengths::Size() &&
69 nDim == ThreadClusterArrangeOrder::Size() &&
70 nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(),
71 "wrong! nDim not consistent");
72
73 static_for<0, nSrc, 1>{}([&](auto src_i) {
74 static_assert(nDim ==
76 "wrong! nDim not consistent");
77 });
78
79 static_for<0, nDst, 1>{}([&](auto dst_i) {
80 static_assert(nDim ==
82 "wrong! nDim not consistent");
83 });
84
85 static_assert(
86 is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
87 "wrong! threads should be mapped to cover entire slicing window");
88
89 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
90 "wrong! ThreadGroup::GetNumOfThread() too small");
91
92 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
93 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
94 {
95 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
96 make_multi_index(ThreadGroup::GetThreadId()));
97
98 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
99
100 const auto src_thread_slice_origins = generate_tuple(
101 [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
102 Number<nSrc>{});
103
104 const auto dst_thread_slice_origins = generate_tuple(
105 [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
106 Number<nDst>{});
107
108 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
109 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
110 }
111 }
112
113 template <typename SrcBuffers, index_t ThreadScratchId = 0>
114 __device__ void RunRead(const SrcDescs& src_descs,
115 const SrcBuffers& src_bufs,
117 {
118 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
119 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
120 {
121 threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
122 }
123 }
124
125 template <typename DstBuffers, index_t ThreadScratchId = 0>
126 __device__ void RunWrite(const DstDescs& dst_descs,
127 DstBuffers& dst_bufs,
129 {
130 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
131 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
132 {
133 threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
134 }
135 }
136
137 template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
138 __device__ void Run(const SrcDescs& src_descs,
139 const SrcBuffer& src_bufs,
140 const DstDescs& dst_descs,
141 DstBuffer& dst_bufs,
142 Number<ThreadScratchId> thread_scratch_id)
143 {
144 RunRead(src_descs, src_bufs, thread_scratch_id);
145 RunWrite(dst_descs, dst_bufs, thread_scratch_id);
146 }
147
148 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
149 {
150 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
151 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
152 {
153 threadwise_transfer_.MoveSrcSliceWindow(src_descs, step);
154 }
155 }
156
157 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
158 {
159 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
160 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
161 {
162 threadwise_transfer_.MoveDstSliceWindow(dst_descs, step);
163 }
164 }
165
166 private:
// Compile-time cluster descriptor built from ThreadClusterLengths and
// ThreadClusterArrangeOrder. The constructor uses it to turn a thread id into a
// cluster index (CalculateBottomIndex), and every method compares its element
// size against the thread id / thread count to decide whether the thread takes part.
167 static constexpr auto thread_cluster_desc_ =
168 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
169
// Per-thread transfer type: ThreadwiseTensorSliceTransfer_v3r2 instantiated with the
// per-thread slice lengths and this class's template arguments forwarded as-is.
170 using ThreadwiseTransfer =
171 ThreadwiseTensorSliceTransfer_v3r2<decltype(thread_slice_lengths),
172 ElementwiseOperation,
173 DstInMemOps,
174 SrcDatas,
175 DstDatas,
176 SrcDescs,
177 DstDescs,
178 SrcDimAccessOrder,
179 DstDimAccessOrder,
180 SrcVectorDim,
181 DstVectorDim,
182 SrcsScalarPerVector,
183 DstsScalarPerVector,
184 SrcsScalarStrideInVector,
185 DstsScalarStrideInVector,
186 ThreadTransferSrcsResetCoordinateAfterRun,
187 ThreadTransferDstsResetCoordinateAfterRun,
188 NumThreadScratch>;
189
// The object that RunRead / RunWrite / Move*SliceWindow delegate to.
190 ThreadwiseTransfer threadwise_transfer_;
191};
192
193} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v4r2.hpp:53
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers &dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r2.hpp:126
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v4r2.hpp:48
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:148
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:157
__device__ void Run(const SrcDescs &src_descs, const SrcBuffer &src_bufs, const DstDescs &dst_descs, DstBuffer &dst_bufs, Number< ThreadScratchId > thread_scratch_id)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:138
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r2.hpp:114
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r2(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v4r2.hpp:55
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v4r2.hpp:46
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v4r2.hpp:49
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r2.hpp:51
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:99
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers &dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:305