device_gemm_wmma_cshuffle_v3.hpp Source File

device_gemm_wmma_cshuffle_v3.hpp Source File

Composable Kernel: device_gemm_wmma_cshuffle_v3.hpp Source File
device_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
124template <typename ALayout,
125 typename BLayout,
126 typename CLayout,
127 typename ADataType,
128 typename BDataType,
129 typename CDataType,
130 typename AccDataType,
131 typename CShuffleDataType,
132 typename AElementwiseOperation,
133 typename BElementwiseOperation,
134 typename CElementwiseOperation,
135 GemmSpecialization GemmSpec,
136 index_t BlockSize,
137 index_t MPerBlock,
138 index_t NPerBlock,
139 index_t KPerBlock,
140 index_t AK1,
141 index_t BK1,
142 index_t MPerWmma,
143 index_t NPerWmma,
144 index_t MRepeat,
145 index_t NRepeat,
146 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
147 typename ABlockTransferThreadClusterArrangeOrder,
148 typename ABlockTransferSrcAccessOrder,
149 index_t ABlockTransferSrcVectorDim,
150 index_t ABlockTransferSrcScalarPerVector,
151 index_t ABlockTransferDstScalarPerVector_AK1,
152 bool ABlockLdsExtraM,
153 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154 typename BBlockTransferThreadClusterArrangeOrder,
155 typename BBlockTransferSrcAccessOrder,
156 index_t BBlockTransferSrcVectorDim,
157 index_t BBlockTransferSrcScalarPerVector,
158 index_t BBlockTransferDstScalarPerVector_BK1,
159 bool BBlockLdsExtraN,
160 index_t CShuffleMRepeatPerShuffle,
161 index_t CShuffleNRepeatPerShuffle,
162 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
163 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
166 typename ComputeTypeA = CDataType,
167 typename ComputeTypeB = ComputeTypeA,
168 bool PermuteA = false,
169 bool PermuteB = false>
171 BLayout,
172 CLayout,
173 ADataType,
174 BDataType,
175 CDataType,
176 AElementwiseOperation,
177 BElementwiseOperation,
178 CElementwiseOperation>
179{
181 ALayout,
182 BLayout,
183 Tuple<>, // DsLayout
184 CLayout,
187 AccDataType,
188 CShuffleDataType,
189 Tuple<>, // DsDataType
190 CDataType,
191 AElementwiseOperation,
192 BElementwiseOperation,
193 CElementwiseOperation,
194 GemmSpec,
195 BlockSize,
196 MPerBlock,
197 NPerBlock,
198 KPerBlock,
199 AK1,
200 BK1,
201 MPerWmma,
202 NPerWmma,
203 MRepeat,
204 NRepeat,
205 ABlockTransferThreadClusterLengths_AK0_M_AK1,
206 ABlockTransferThreadClusterArrangeOrder,
207 ABlockTransferSrcAccessOrder,
208 ABlockTransferSrcVectorDim,
209 ABlockTransferSrcScalarPerVector,
210 ABlockTransferDstScalarPerVector_AK1,
211 false,
212 ABlockLdsExtraM,
213 BBlockTransferThreadClusterLengths_BK0_N_BK1,
214 BBlockTransferThreadClusterArrangeOrder,
215 BBlockTransferSrcAccessOrder,
216 BBlockTransferSrcVectorDim,
217 BBlockTransferSrcScalarPerVector,
218 BBlockTransferDstScalarPerVector_BK1,
219 false,
220 BBlockLdsExtraN,
221 CShuffleMRepeatPerShuffle,
222 CShuffleNRepeatPerShuffle,
223 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
225 BlkGemmPipeSched,
226 BlkGemmPipelineVer,
227 ComputeTypeA,
228 ComputeTypeB,
229 PermuteA,
230 PermuteB>;
231
// Kernel argument type: the gridwise kernel's own Argument is reused directly as
// this device op's argument (MakeArgument / MakeArgumentPointer construct it).
using Argument = typename GridwiseGemm::Argument;
233
238 Tuple<>,
239 CDataType,
240 MPerBlock,
241 NPerBlock,
242 KPerBlock,
243 BlockSize,
244 AK1,
245 BK1,
246 GemmSpec,
248 BlkGemmPipeSched,
249 BlkGemmPipelineVer,
250 ComputeTypeA,
251 ComputeTypeB>;
252
// Invoker
// Launcher type taken from DeviceGemmCommon (DeviceGemm_Wmma_CShuffleV3_Common);
// this op does not define an invoker of its own.
using Invoker = typename DeviceGemmCommon::Invoker;
255
256 static bool IsSupportedArgument(const Argument& arg)
257 {
259 }
260
261 // polymorphic
262 bool IsSupportedArgument(const BaseArgument* p_arg) override
263 {
264 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
265 }
266
267 index_t GetKPerBlock() override { return KPerBlock; }
268
269 bool GetPermuteA() override { return PermuteA; }
270 bool GetPermuteB() override { return PermuteB; }
271
272 static auto MakeArgument(const ADataType* p_a,
273 const BDataType* p_b,
274 CDataType* p_c,
275 index_t M,
276 index_t N,
277 index_t K,
278 index_t StrideA,
279 index_t StrideB,
280 index_t StrideC,
281 index_t KBatch,
282 AElementwiseOperation a_element_op,
283 BElementwiseOperation b_element_op,
284 CElementwiseOperation cde_element_op)
285 {
286 return Argument{std::array<const void*, 1>{p_a},
287 std::array<const void*, 1>{p_b},
288 std::array<const void*, 0>{}, // p_ds_grid_
289 p_c,
290 M,
291 N,
292 K,
293 std::array<index_t, 1>{StrideA},
294 std::array<index_t, 1>{StrideB},
295 std::array<index_t, 0>{}, // StrideDs_
296 StrideC,
297 KBatch,
298 a_element_op,
299 b_element_op,
300 cde_element_op};
301 }
302
303 static auto MakeInvoker() { return Invoker{}; }
304
305 // polymorphic
306 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
307 const void* p_b,
308 void* p_c,
309 index_t M,
310 index_t N,
311 index_t K,
312 index_t StrideA,
313 index_t StrideB,
314 index_t StrideC,
315 index_t KBatch,
316 AElementwiseOperation a_element_op,
317 BElementwiseOperation b_element_op,
318 CElementwiseOperation c_element_op) override
319 {
320 return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
321 std::array<const void*, 1>{p_b},
322 std::array<const void*, 0>{}, // p_ds_grid_
323 static_cast<CDataType*>(p_c),
324 M,
325 N,
326 K,
327 std::array<index_t, 1>{StrideA},
328 std::array<index_t, 1>{StrideB},
329 std::array<index_t, 0>{}, // StrideDs_
330 StrideC,
331 KBatch,
332 a_element_op,
333 b_element_op,
334 c_element_op);
335 }
336
337 // polymorphic
338 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
339 {
340 return std::make_unique<Invoker>(Invoker{});
341 }
342
343 // polymorphic
344 std::string GetTypeString() const override
345 {
346 auto str = std::stringstream();
347
348 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
351
352 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
358
359 // clang-format off
360 str << "DeviceGemm_Wmma_CShuffleV3"
361 << "<"
362 << getGemmSpecializationString(GemmSpec) << ", "
363 << std::string(ALayout::name)[0]
364 << std::string(BLayout::name)[0]
365 << std::string(CLayout::name)[0]
366 << ">"
367 << " BlkSize: "
368 << BlockSize << ", "
369 << "BlkTile: "
370 << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
371 << "WaveTile: "
372 << MPerWmma << "x"<<NPerWmma << ", "
373 << "WaveMap: "
374 << MRepeat << "x" << NRepeat << ", "
375 << "VmemReadVec: "
376 << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
377 << "BlkGemmPipelineScheduler: "
378 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
379 << "BlkGemmPipelineVersion: "
380 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
381 << "BlkGemmPipelinePrefetchStages: "
382 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
383 << "KPack: "
385 // clang-format on
386
387 return str.str();
388 }
390};
391
392} // namespace device
393} // namespace tensor_operation
394} // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
Definition device_base.hpp:197
Definition device_gemm_wmma_cshuffle_v3_common.hpp:43
"Universal" GEMM operation with SplitK support.
Definition device_gemm_wmma_cshuffle_v3.hpp:179
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation cde_element_op)
Definition device_gemm_wmma_cshuffle_v3.hpp:272
std::string GetTypeString() const override
Definition device_gemm_wmma_cshuffle_v3.hpp:344
typename DeviceGemmCommon::Invoker Invoker
Definition device_gemm_wmma_cshuffle_v3.hpp:254
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_wmma_cshuffle_v3.hpp:306
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_gemm_wmma_cshuffle_v3.hpp:180
static auto MakeInvoker()
Definition device_gemm_wmma_cshuffle_v3.hpp:303
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_wmma_cshuffle_v3.hpp:262
bool GetPermuteA() override
Definition device_gemm_wmma_cshuffle_v3.hpp:269
DeviceGemm_Wmma_CShuffleV3_Common< GridwiseGemm, Tuple< ADataType >, Tuple< BDataType >, Tuple<>, CDataType, MPerBlock, NPerBlock, KPerBlock, BlockSize, AK1, BK1, GemmSpec, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > DeviceGemmCommon
Definition device_gemm_wmma_cshuffle_v3.hpp:234
typename GridwiseGemm::Argument Argument
Definition device_gemm_wmma_cshuffle_v3.hpp:232
bool GetPermuteB() override
Definition device_gemm_wmma_cshuffle_v3.hpp:270
index_t GetKPerBlock() override
Definition device_gemm_wmma_cshuffle_v3.hpp:267
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_wmma_cshuffle_v3.hpp:338
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3.hpp:256
Definition device_gemm_v2.hpp:22