Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp Source File

Go to the documentation of this file.

(Extraction note: the numbered listing below skips some source lines — 6–15, 41, 49, 73, 99, 123, 177 and 226 — so it is not a complete copy of the header.)
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck {
18
// BlockGemmBPreshufflePipeline_Selector
//
// Compile-time selector for the blockwise XDL GEMM pipeline used with a
// pre-shuffled B operand. It takes no function arguments. The choice is made
// purely from template arguments via `if constexpr`, and every active branch
// returns a value-initialized `Pipeline<...>{}`. Because the return type is
// `auto`, the deduced return type *is* the selected pipeline type. Callers
// presumably consume it through decltype -- TODO confirm at call sites.
//
// Template parameters that steer the selection:
//   BlkGemmPipelineVer   - picks the pipeline family (v1 / v2 / v3); any other
//                          value reaches the final `else` branch.
//   ADataType, BDataType - compared with std::is_same to choose between two
//                          variants in the v1 and v3 families.
//   GUFusion             - defaults to false. It is consulted only in the v1
//                          family, and in v3 when ADataType == BDataType.
//   MRepeat              - must be >= 4 for v3 (see the static_assert below).
// In every instantiation visible here, all other parameters (BlkGemmPipeSche,
// BlockSize, ..., KPack) are forwarded unchanged, in declaration order, to the
// chosen pipeline template. BlkGemmPipelineVer and GUFusion are not forwarded.
//
// Dispatch table. Lines marked "lost" are absent from this extracted listing.
// The names in parentheses are guesses based on the cross-referenced pipeline
// headers -- TODO confirm against the real header.
//   v1, A == B,  GUFusion -> lost line 49  (presumably the gufusion_v1 pipeline)
//   v1, A == B, !GUFusion -> lost line 73  (presumably BlockwiseGemmXdlops_pipeline_bpreshuffle_v1)
//   v1, A != B,  GUFusion -> lost line 99  (presumably the gufusion_dequant_v1 pipeline)
//   v1, A != B, !GUFusion -> lost line 123 (presumably the dequant_v1 pipeline)
//   v2                    -> BlockwiseGemmXdlops_pipeline_bpreshuffle_v2
//   v3, A == B,  GUFusion -> lost line 177 (presumably the gufusion_v3 pipeline)
//   v3, A == B, !GUFusion -> BlockwiseGemmXdlops_pipeline_bpreshuffle_v3
//   v3, A != B            -> lost line 226 (presumably the dequant_v3 pipeline)
//   anything else         -> prints to std::cerr; there is no return statement,
//                            so the deduced return type is void.
19template <BlockGemmPipelineVersion BlkGemmPipelineVer,
20 BlockGemmPipelineScheduler BlkGemmPipeSche,
21 index_t BlockSize,
22 typename ADataType,
23 typename BDataType,
24 typename ComputeDataType,
25 typename AccDataType,
26 typename ATileDesc,
27 typename BTileDesc,
28 typename AMmaTileDesc,
29 typename BMmaTileDesc,
30 index_t ABlockTransferSrcScalarPerVector,
31 index_t BBlockTransferSrcScalarPerVector,
32 index_t MPerBlock,
33 index_t NPerBlock,
34 index_t KPerBlock,
35 index_t MPerXDL,
36 index_t NPerXDL,
37 index_t MRepeat,
38 index_t NRepeat,
39 index_t KPack,
40 bool GUFusion = false>
// NOTE(review): source line 41 (the declarator) is missing from this listing.
// The cross-reference entry for line 41 gives it as
// `constexpr auto BlockGemmBPreshufflePipeline_Selector()`.
42{
43 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
44 {
45 if constexpr(std::is_same<ADataType, BDataType>::value)
46 {
47 if constexpr(GUFusion)
48 {
// NOTE(review): line 49 -- presumably `return <pipeline template><` -- is
// missing from this listing.
50 BlkGemmPipeSche,
51 BlockSize,
52 ADataType,
53 BDataType,
54 ComputeDataType,
55 AccDataType,
56 ATileDesc,
57 BTileDesc,
58 AMmaTileDesc,
59 BMmaTileDesc,
60 ABlockTransferSrcScalarPerVector,
61 BBlockTransferSrcScalarPerVector,
62 MPerBlock,
63 NPerBlock,
64 KPerBlock,
65 MPerXDL,
66 NPerXDL,
67 MRepeat,
68 NRepeat,
69 KPack>{};
70 }
71 else
72 {
// NOTE(review): line 73 is missing. By analogy with lines 149 and 202 it
// presumably reads
// `return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,`
// -- TODO confirm.
74 BlockSize,
75 ADataType,
76 BDataType,
77 ComputeDataType,
78 AccDataType,
79 ATileDesc,
80 BTileDesc,
81 AMmaTileDesc,
82 BMmaTileDesc,
83 ABlockTransferSrcScalarPerVector,
84 BBlockTransferSrcScalarPerVector,
85 MPerBlock,
86 NPerBlock,
87 KPerBlock,
88 MPerXDL,
89 NPerXDL,
90 MRepeat,
91 NRepeat,
92 KPack>{};
93 }
94 }
95 else
96 {
97 if constexpr(GUFusion)
98 {
// NOTE(review): line 99 -- presumably `return <pipeline template><` -- is
// missing from this listing.
100 BlkGemmPipeSche,
101 BlockSize,
102 ADataType,
103 BDataType,
104 ComputeDataType,
105 AccDataType,
106 ATileDesc,
107 BTileDesc,
108 AMmaTileDesc,
109 BMmaTileDesc,
110 ABlockTransferSrcScalarPerVector,
111 BBlockTransferSrcScalarPerVector,
112 MPerBlock,
113 NPerBlock,
114 KPerBlock,
115 MPerXDL,
116 NPerXDL,
117 MRepeat,
118 NRepeat,
119 KPack>{};
120 }
121 else
122 {
// NOTE(review): line 123 -- presumably `return <pipeline template><` -- is
// missing from this listing.
124 BlkGemmPipeSche,
125 BlockSize,
126 ADataType,
127 BDataType,
128 ComputeDataType,
129 AccDataType,
130 ATileDesc,
131 BTileDesc,
132 AMmaTileDesc,
133 BMmaTileDesc,
134 ABlockTransferSrcScalarPerVector,
135 BBlockTransferSrcScalarPerVector,
136 MPerBlock,
137 NPerBlock,
138 KPerBlock,
139 MPerXDL,
140 NPerXDL,
141 MRepeat,
142 NRepeat,
143 KPack>{};
144 }
145 }
146 }
147 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
148 {
// v2 checks neither GUFusion nor whether ADataType == BDataType, so every v2
// request -- including GUFusion == true -- gets this same pipeline.
// NOTE(review): if the v2 pipeline does not implement fusion itself, consider
// rejecting GUFusion here with a static_assert instead of ignoring it.
149 return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlkGemmPipeSche,
150 BlockSize,
151 ADataType,
152 BDataType,
153 ComputeDataType,
154 AccDataType,
155 ATileDesc,
156 BTileDesc,
157 AMmaTileDesc,
158 BMmaTileDesc,
159 ABlockTransferSrcScalarPerVector,
160 BBlockTransferSrcScalarPerVector,
161 MPerBlock,
162 NPerBlock,
163 KPerBlock,
164 MPerXDL,
165 NPerXDL,
166 MRepeat,
167 NRepeat,
168 KPack>{};
169 }
170 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
171 {
172 static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
173 if constexpr(std::is_same<ADataType, BDataType>::value)
174 {
175 if constexpr(GUFusion)
176 {
// NOTE(review): line 177 -- presumably `return <pipeline template><` -- is
// missing from this listing.
178 BlkGemmPipeSche,
179 BlockSize,
180 ADataType,
181 BDataType,
182 ComputeDataType,
183 AccDataType,
184 ATileDesc,
185 BTileDesc,
186 AMmaTileDesc,
187 BMmaTileDesc,
188 ABlockTransferSrcScalarPerVector,
189 BBlockTransferSrcScalarPerVector,
190 MPerBlock,
191 NPerBlock,
192 KPerBlock,
193 MPerXDL,
194 NPerXDL,
195 MRepeat,
196 NRepeat,
197 KPack>{};
198 }
199 else
200 {
201
202 return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
203 BlockSize,
204 ADataType,
205 BDataType,
206 ComputeDataType,
207 AccDataType,
208 ATileDesc,
209 BTileDesc,
210 AMmaTileDesc,
211 BMmaTileDesc,
212 ABlockTransferSrcScalarPerVector,
213 BBlockTransferSrcScalarPerVector,
214 MPerBlock,
215 NPerBlock,
216 KPerBlock,
217 MPerXDL,
218 NPerXDL,
219 MRepeat,
220 NRepeat,
221 KPack>{};
222 }
223 }
// ADataType != BDataType: unlike the v1 family, GUFusion is not checked on
// this path, so the same pipeline is chosen whether or not fusion was
// requested. NOTE(review): confirm this is intended, or add a
// static_assert(!GUFusion, ...) so the request is not silently dropped.
224 else
225 {
// NOTE(review): line 226 -- presumably `return <pipeline template><` -- is
// missing from this listing.
227 BlkGemmPipeSche,
228 BlockSize,
229 ADataType,
230 BDataType,
231 ComputeDataType,
232 AccDataType,
233 ATileDesc,
234 BTileDesc,
235 AMmaTileDesc,
236 BMmaTileDesc,
237 ABlockTransferSrcScalarPerVector,
238 BBlockTransferSrcScalarPerVector,
239 MPerBlock,
240 NPerBlock,
241 KPerBlock,
242 MPerXDL,
243 NPerXDL,
244 MRepeat,
245 NRepeat,
246 KPack>{};
247 }
248 }
249 else
250 {
// NOTE(review): an unsupported BlkGemmPipelineVer is reported only at run time
// (and only if this function is actually called). This branch has no return
// statement, so for such versions the deduced return type is void. Any caller
// that uses the result type would then likely fail later with a less obvious
// error. A dependent-false static_assert here would reject the configuration
// at compile time instead.
251 std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
252 }
253}
254
255} // namespace ck
Cross-references (symbol tooltips from the original page; entries whose symbol name did not survive extraction are marked "name not captured"):
- (name not captured) — Definition ck.hpp:268
- int32_t index_t — Definition ck.hpp:299
- BlockGemmPipelineVersion — Definition blkgemmpipe_scheduler.hpp:12
  - @ v2 — Definition blkgemmpipe_scheduler.hpp:15
  - @ v3 — Definition blkgemmpipe_scheduler.hpp:16
  - @ v1 — Definition blkgemmpipe_scheduler.hpp:14
- constexpr auto BlockGemmBPreshufflePipeline_Selector() — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp:41
- BlockGemmPipelineScheduler — Definition blkgemmpipe_scheduler.hpp:25
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp:37
- (name not captured) — Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:47