thread_group_tensor_slice_transfer_v6r1r2.hpp Source File

thread_group_tensor_slice_transfer_v6r1r2.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v6r1r2.hpp Source File
thread_group_tensor_slice_transfer_v6r1r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14// this version does following things to avoid scratch memory issue
15// 1. Use StaticallyIndexedArray instead of C array for thread buffer
16// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
17// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
18template <typename ThreadGroup,
19 typename ElementwiseOperation,
20 typename SliceLengths,
21 typename ThreadClusterLengths,
22 typename ThreadClusterArrangeOrder,
23 typename SrcData,
24 typename DstData,
25 typename SrcDesc,
26 typename DstDesc,
27 typename DimAccessOrder,
28 index_t VectorDim,
29 index_t ScalarPerVector,
30 bool ThreadTransferSrcResetCoordinateAfterRun,
31 bool ThreadTransferDstResetCoordinateAfterRun>
33{
35
36 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
37
39
41 const SrcDesc& src_desc,
42 const Index& src_block_slice_origin,
43 const DstDesc& dst_desc,
44 const Index& dst_block_slice_origin,
45 const ElementwiseOperation& element_op)
46 : threadwise_transfer_(src_desc,
48 dst_desc,
50 element_op)
51
52 {
55 nDim == ThreadClusterLengths::Size() &&
56 nDim == ThreadClusterArrangeOrder::Size() &&
57 nDim == DimAccessOrder::Size(),
58 "wrong! nDim not consistent");
59
60 static_assert(
61 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
62 "wrong! threads should be mapped to cover entire slicing window");
63
64 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
65 "wrong! ThreadGroup::GetNumOfThread() too small");
66
67 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
68 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
69 {
70 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
71 make_multi_index(ThreadGroup::GetThreadId()));
72
73 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
74
75 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
76 src_block_slice_origin + thread_data_idx_begin);
77 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
78 dst_block_slice_origin + thread_data_idx_begin);
79 }
80 }
81
82 template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
83 __device__ void Run(const SrcDesc& src_desc,
84 const SrcBuffer& src_buf,
85 const DstDesc& dst_desc,
86 DstBuffer& dst_buf)
87 {
88 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
89 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
90 {
91 threadwise_transfer_.template Run<SrcBuffer, DstBuffer, DstInMemOp>(
92 src_desc, src_buf, dst_desc, dst_buf);
93 }
94 }
95
96 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
97 {
98 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
99 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
100 {
101 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
102 }
103 }
104
105 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
106 {
107 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
108 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
109 {
110 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
111 }
112 }
113
114 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
115 {
116 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
117 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
118 {
119 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
120 make_multi_index(ThreadGroup::GetThreadId()));
121
122 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
123
124 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
125 src_block_slice_origin + thread_data_idx_begin);
126 }
127 }
128
129 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin)
130 {
131 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
132 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
133 {
134 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
135 make_multi_index(ThreadGroup::GetThreadId()));
136
137 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
138
139 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
140 dst_block_slice_origin + thread_data_idx_begin);
141 }
142 }
143
144 private:
145 static constexpr auto thread_cluster_desc_ =
146 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
147
148 using ThreadwiseTransfer =
149 ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
150 DstData,
151 SrcDesc,
152 DstDesc,
153 ElementwiseOperation,
154 decltype(thread_slice_lengths),
155 DimAccessOrder,
156 VectorDim,
157 ScalarPerVector,
158 ThreadTransferSrcResetCoordinateAfterRun,
159 ThreadTransferDstResetCoordinateAfterRun>;
160
161 ThreadwiseTransfer threadwise_transfer_;
162};
163
164} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:105
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:36
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_block_slice_origin)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:114
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:34
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:38
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:40
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_block_slice_origin)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:129
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:83
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:96
Definition type.hpp:177