device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp Source File

device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp Source File#

Composable Kernel: device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp Source File
device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename GridwiseGemm,
25 typename A0B0B1DataType,
26 typename D0sPointer,
27 typename D1sPointer,
28 typename E1DataType,
29 typename A0ElementwiseOperation,
30 typename B0ElementwiseOperation,
31 typename CDE0ElementwiseOperation,
32 typename B1ElementwiseOperation,
33 typename CDE1ElementwiseOperation,
34 typename A0GridDesc_AK0_M_AK1,
35 typename B0GridDesc_BK0_N_BK1,
36 typename D0sGridDesc_M_N,
37 typename B1GridDesc_BK0_N_BK1,
38 typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
39 typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
40 typename Block2E1TileMap,
41 typename ComputeBasePtrOfStridedBatch,
42 bool HasMainKBlockLoop>
43__global__ void
44#if CK_USE_LAUNCH_BOUNDS
46#endif
48 const A0B0B1DataType* __restrict__ p_a0_grid,
49 const A0B0B1DataType* __restrict__ p_b0_grid,
50 D0sPointer p_d0s_grid,
51 const A0B0B1DataType* __restrict__ p_b1_grid,
52 D1sPointer p_d1s_grid,
53 E1DataType* __restrict__ p_e1_grid,
54 const A0ElementwiseOperation a0_element_op,
55 const B0ElementwiseOperation b0_element_op,
56 const CDE0ElementwiseOperation cde0_element_op,
57 const B1ElementwiseOperation b1_element_op,
58 const CDE1ElementwiseOperation cde1_element_op,
59 const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
60 const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
61 const D0sGridDesc_M_N d0s_griddesc_m_n,
62 const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
63 const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
64 d1s_grid_desc_mblock_mperblock_nblock_nperblock,
65 const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
66 e1_grid_desc_mblock_mperblock_nblock_nperblock,
67 const Block2E1TileMap block_2_e1tile_map,
68 const index_t batch_count,
69 const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
70{
71#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
72 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
73 {
74 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75 const index_t num_blocks_per_batch =
76 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
77 const index_t g_idx =
78 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
79
80 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
81 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
82 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
83 static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
84 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
85 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
86 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
87 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
88
89 static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
90 const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
91 static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
92 p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
93 });
94
95 static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) {
96 const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
97 static_cast<long_index_t>(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
98 p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset;
99 });
100
101 GridwiseGemm::template Run<HasMainKBlockLoop>(
102 p_a0_grid + a_batch_offset,
103 p_b0_grid + b_batch_offset,
104 p_d0s_grid,
105 p_b1_grid + b1_batch_offset,
106 p_d1s_grid,
107 p_e1_grid + c_batch_offset,
108 p_shared,
109 a0_element_op,
110 b0_element_op,
111 cde0_element_op,
112 b1_element_op,
113 cde1_element_op,
114 a0_grid_desc_ak0_m_ak1,
115 b0_grid_desc_bk0_n_bk1,
116 d0s_griddesc_m_n,
117 b1_grid_desc_bk0_n_bk1,
118 d1s_grid_desc_mblock_mperblock_nblock_nperblock,
119 e1_grid_desc_mblock_mperblock_nblock_nperblock,
120 block_2_e1tile_map);
121 }
122#else
123 ignore = p_a0_grid;
124 ignore = p_b0_grid;
125 ignore = p_d0s_grid;
126 ignore = p_b1_grid;
127 ignore = p_d1s_grid;
128 ignore = p_e1_grid;
129 ignore = a0_element_op;
130 ignore = b0_element_op;
131 ignore = cde0_element_op;
132 ignore = b1_element_op;
133 ignore = cde1_element_op;
134 ignore = a0_grid_desc_ak0_m_ak1;
135 ignore = b0_grid_desc_bk0_n_bk1;
136 ignore = d0s_griddesc_m_n;
137 ignore = b1_grid_desc_bk0_n_bk1;
138 ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock;
139 ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock;
140 ignore = block_2_e1tile_map;
141 ignore = batch_count;
142 ignore = compute_base_ptr_of_batch;
143#endif
144}
145
146// Computes C = A * B0 * B1
147// ^^^^^^ (Acc0)
148// ^^^^^^^^^^^ (Acc1)
149template <typename A0Layout,
150 typename B0Layout, // B0Layout
151 typename D0sLayout,
152 typename B1Layout,
153 typename D1sLayout,
154 typename E1Layout,
155 typename A0DataType,
156 typename B0DataType,
157 typename Acc0DataType,
158 typename D0sDataType,
159 typename B1DataType,
160 typename Acc1DataType,
161 typename C1ShuffleDataType,
162 typename D1sDataType,
163 typename E1DataType,
164 typename A0ElementwiseOperation,
165 typename B0ElementwiseOperation,
166 typename CDE0ElementwiseOperation,
167 typename B1ElementwiseOperation,
168 typename CDE1ElementwiseOperation,
169 bool PadGemm0M,
170 bool PadGemm0N,
171 bool PadGemm0K,
172 bool PadGemm1N,
173 bool PadGemm1K,
174 index_t NumGemm0KPrefetchStage,
175 index_t BlockSize,
176 index_t Gemm0MPerBlock,
177 index_t Gemm0NPerBlock,
178 index_t Gemm0KPerBlock,
179 index_t Gemm1NPerBlock,
180 index_t Gemm1KPerBlock,
181 index_t A0K1,
182 index_t B0K1,
183 index_t B1K1,
184 index_t Gemm0MPerXdl,
185 index_t Gemm0NPerXdl,
186 index_t Gemm0MXdlPerWave,
187 index_t Gemm0NXdlPerWave,
188 index_t Gemm1NXdlPerWave,
189 typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
190 typename A0BlockTransferThreadClusterArrangeOrder,
191 typename A0BlockTransferSrcAccessOrder,
192 index_t A0BlockTransferSrcVectorDim,
193 index_t A0BlockTransferSrcScalarPerVector,
194 index_t A0BlockTransferDstScalarPerVector_AK1,
195 bool A0BlockLdsExtraM,
196 typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
197 typename B0BlockTransferThreadClusterArrangeOrder,
198 typename B0BlockTransferSrcAccessOrder,
199 index_t B0BlockTransferSrcVectorDim,
200 index_t B0BlockTransferSrcScalarPerVector,
201 index_t B0BlockTransferDstScalarPerVector_BK1,
202 bool B0BlockLdsExtraN,
203 index_t CDE0BlockTransferSrcVectorDim,
204 index_t CDE0BlockTransferSrcScalaerPerVector,
205 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
206 typename B1BlockTransferThreadClusterArrangeOrder,
207 typename B1BlockTransferSrcAccessOrder,
208 index_t B1BlockTransferSrcVectorDim,
209 index_t B1BlockTransferSrcScalarPerVector,
210 index_t B1BlockTransferDstScalarPerVector_BK1,
211 bool B1BlockLdsExtraN,
212 index_t C1ShuffleMXdlPerWavePerShuffle,
213 index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
214 typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
215 index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
219 B0Layout,
220 D0sLayout,
221 B1Layout,
222 D1sLayout,
223 E1Layout,
224 A0DataType,
225 B0DataType,
226 D0sDataType,
227 B1DataType,
228 D1sDataType,
229 E1DataType,
230 A0ElementwiseOperation,
231 B0ElementwiseOperation,
232 CDE0ElementwiseOperation,
233 B1ElementwiseOperation,
234 CDE1ElementwiseOperation>
235{
237
238 static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2<BlockSize,
239 Gemm0NPerBlock,
240 Gemm0MPerBlock,
241 Gemm0NPerXdl,
242 Gemm0MPerXdl,
243 Gemm0NXdlPerWave,
244 true>();
245 static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2<BlockSize,
246 Gemm0NPerBlock,
247 Gemm0MPerBlock,
248 Gemm0NPerXdl,
249 Gemm0MPerXdl,
250 Gemm0NXdlPerWave,
251 false>();
252
253 static constexpr index_t NumD0Tensor = D0sDataType::Size();
254 static constexpr index_t NumD1Tensor = D1sDataType::Size();
255
256 static constexpr auto I0 = Number<0>{};
257 static constexpr auto I1 = Number<1>{};
258 static constexpr auto I2 = Number<2>{};
259 static constexpr auto I3 = Number<3>{};
260 static constexpr auto I4 = Number<4>{};
261 static constexpr auto I5 = Number<5>{};
262 static constexpr auto I6 = Number<6>{};
263 static constexpr auto I7 = Number<7>{};
264 static constexpr auto I8 = Number<8>{};
265 static constexpr auto I9 = Number<9>{};
266
267 static constexpr auto gemm0_padder =
269 Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock};
270
271 static constexpr auto gemm1_padder =
273 Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock};
274
275 // for Gemm0
276 static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
277 {
278 const auto a0_grid_desc_mraw_kraw = [&]() {
280 {
281 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
282 make_tuple(StrideA0, I1));
283 }
285 {
286 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
287 make_tuple(I1, StrideA0));
288 }
289 }();
290
291 return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw);
292 }
293
294 // for Gemm0
295 static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
296 {
297 const auto b0_grid_desc_nraw_kraw = [&]() {
299 {
300 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
301 make_tuple(I1, StrideB));
302 }
304 {
305 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
306 make_tuple(StrideB, I1));
307 }
308 }();
309
310 return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw);
311 }
312
313 // for Gemm0
314 template <typename DLay>
315 static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
316 {
317 const auto d0_grid_desc_mraw_nraw = [&]() {
319 {
320 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
321 make_tuple(StrideD0, I1));
322 }
324 {
325 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
326 make_tuple(I1, StrideD0));
327 }
328 }();
329
330 return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw);
331 }
332
333 // for Gemm1
334 static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
335 {
336 const auto b1_grid_desc_nraw_kraw = [&]() {
338 {
339 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
340 make_tuple(I1, StrideB));
341 }
343 {
344 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
345 make_tuple(StrideB, I1));
346 }
347 }();
348
349 return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw);
350 }
351
352 // for Gemm1
353 template <typename ELay>
354 static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
355 {
356 const auto e1_grid_desc_mraw_nraw = [&]() {
358 {
359 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
360 make_tuple(StrideE1, I1));
361 }
363 {
364 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
365 make_tuple(I1, StrideE1));
366 }
367 }();
368
369 return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
370 }
371
372 static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
373 const std::array<index_t, NumD1Tensor>& NRaws,
374 const std::array<index_t, NumD1Tensor>& DsStride)
375 {
376 return generate_tuple(
377 [&](auto i) {
378 using DLayout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
379
380 return DeviceOp::MakeD0GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
381 },
383 }
384
385 static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
386 const std::array<index_t, NumD1Tensor>& NRaws,
387 const std::array<index_t, NumD1Tensor>& DsStride)
388 {
389 return generate_tuple(
390 [&](auto i) {
391 using DLayout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
392
393 return DeviceOp::MakeE1GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
394 },
396 }
397
399 {
401 index_t BatchStrideB0,
402 std::array<index_t, NumD0Tensor> BatchStrideD0s,
403 index_t BatchStrideB1,
404 std::array<index_t, NumD1Tensor> BatchStrideD1s,
405 index_t BatchStrideE1)
406 : BatchStrideA0_(BatchStrideA0),
407 BatchStrideB0_(BatchStrideB0),
408 BatchStrideD0s_(BatchStrideD0s),
409 BatchStrideB1_(BatchStrideB1),
410 BatchStrideD1s_(BatchStrideD1s),
411 BatchStrideE1_(BatchStrideE1)
412 {
413 }
414
415 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
416 {
417 return g_idx * static_cast<long_index_t>(BatchStrideA0_);
418 }
419
420 __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
421 {
422 return g_idx * static_cast<long_index_t>(BatchStrideB0_);
423 }
424
425 template <index_t I>
426 __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
427 Number<I> d1_idx) const
428 {
429 return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
430 }
431
432 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
433 {
434 return g_idx * static_cast<long_index_t>(BatchStrideB1_);
435 }
436
437 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
438 {
439 return g_idx * static_cast<long_index_t>(BatchStrideE1_);
440 }
441
442 template <index_t I>
443 __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
444 {
445 return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
446 }
447
448 private:
449 index_t BatchStrideA0_;
450 index_t BatchStrideB0_;
451 std::array<index_t, NumD0Tensor> BatchStrideD0s_;
452 index_t BatchStrideB1_;
453 std::array<index_t, NumD1Tensor> BatchStrideD1s_;
454 index_t BatchStrideE1_;
455 };
456
457 using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1));
458 using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1));
460 using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1));
463
464 // GridwiseGemm
465 template <index_t Gemm0MXdlPerWave_>
467 A0DataType, // TODO: distinguish A/B datatype
468 Acc0DataType,
469 D0sDataType,
470 Acc1DataType,
471 C1ShuffleDataType,
472 D1sDataType,
473 E1DataType,
474 A0ElementwiseOperation,
475 B0ElementwiseOperation,
476 CDE0ElementwiseOperation,
477 B1ElementwiseOperation,
478 CDE1ElementwiseOperation,
486 NumGemm0KPrefetchStage,
487 BlockSize,
488 Gemm0MPerBlock,
489 Gemm0NPerBlock,
490 Gemm0KPerBlock,
491 Gemm1NPerBlock,
492 Gemm1KPerBlock,
493 A0K1,
494 B0K1,
495 B1K1,
496 Gemm0MPerXdl,
497 Gemm0NPerXdl,
498 Gemm0MXdlPerWave_,
499 Gemm0NXdlPerWave,
500 Gemm1NXdlPerWave,
501 A0BlockTransferThreadClusterLengths_AK0_M_AK1,
502 A0BlockTransferThreadClusterArrangeOrder,
503 A0BlockTransferSrcAccessOrder,
504 A0BlockTransferSrcVectorDim,
505 A0BlockTransferSrcScalarPerVector,
506 A0BlockTransferDstScalarPerVector_AK1,
507 true,
508 A0BlockLdsExtraM,
509 B0BlockTransferThreadClusterLengths_BK0_N_BK1,
510 B0BlockTransferThreadClusterArrangeOrder,
511 B0BlockTransferSrcAccessOrder,
512 B0BlockTransferSrcVectorDim,
513 B0BlockTransferSrcScalarPerVector,
514 B0BlockTransferDstScalarPerVector_BK1,
515 true,
516 B0BlockLdsExtraN,
517 CDE0BlockTransferSrcVectorDim,
518 CDE0BlockTransferSrcScalaerPerVector,
519 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
520 B1BlockTransferThreadClusterArrangeOrder,
521 B1BlockTransferSrcAccessOrder,
522 B1BlockTransferSrcVectorDim,
523 B1BlockTransferSrcScalarPerVector,
524 B1BlockTransferDstScalarPerVector_BK1,
525 false,
526 B1BlockLdsExtraN,
527 C1ShuffleMXdlPerWavePerShuffle,
528 C1ShuffleGemm0NXdlPerWavePerShuffle,
529 CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
530 CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
531 LoopSched>;
534
537 A0GridDesc_M_K{}))>;
540 B0GridDesc_N_K{}))>;
543 B1GridDesc_N_K{}))>;
544
545 // Argument
546 struct Argument : public BaseArgument
547 {
548 Argument(const A0DataType* p_a0_grid,
549 const B0DataType* p_b0_grid,
550 std::array<const void*, NumD0Tensor> p_d0s_grid,
551 const B1DataType* p_b1_grid,
552 std::array<const void*, NumD1Tensor> p_d1s_grid,
553 E1DataType* p_e1_grid,
554 index_t MRaw,
555 index_t NRaw,
556 index_t KRaw,
557 index_t Gemm1NRaw, // = ORaw
558 index_t Batch,
559 index_t StrideA0,
560 index_t StrideB0,
561 std::array<index_t, NumD0Tensor> StrideD0s,
562 index_t StrideB1,
563 std::array<index_t, NumD1Tensor> StrideD1s,
564 index_t StrideE1,
565 index_t BatchStrideA0,
566 index_t BatchStrideB0,
567 std::array<index_t, NumD0Tensor> BatchStrideD0s,
568 index_t BatchStrideB1,
569 std::array<index_t, NumD1Tensor> BatchStrideD1s,
570 index_t BatchStrideE1,
571 A0ElementwiseOperation a0_element_op,
572 B0ElementwiseOperation b0_element_op,
573 CDE0ElementwiseOperation cde0_element_op,
574 B1ElementwiseOperation b1_element_op,
575 CDE1ElementwiseOperation cde1_element_op)
576 : p_a0_grid_{p_a0_grid},
577 p_b0_grid_{p_b0_grid},
578 p_d0s_grid_{},
579 p_b1_grid_{p_b1_grid},
580 p_d1s_grid_{},
581 p_e1_grid_{p_e1_grid},
585 b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)},
588 DeviceOp::MakeE1GridDescriptor_M_N<E1Layout>(MRaw, Gemm1NRaw, StrideE1)},
590 GridwiseGemm64::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)},
592 GridwiseGemm64::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)},
594 GridwiseGemm64::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)},
595 block_2_e1tile_map_{GridwiseGemm64::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)},
596 a0_element_op_{a0_element_op},
597 b0_element_op_{b0_element_op},
598 cde0_element_op_{cde0_element_op},
599 b1_element_op_{b1_element_op},
600 cde1_element_op_{cde1_element_op},
601 batch_count_(Batch),
602 compute_base_ptr_of_batch_{BatchStrideA0,
603 BatchStrideB0,
604 BatchStrideD0s,
605 BatchStrideB1,
606 BatchStrideD1s,
607 BatchStrideE1}
608 {
609 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
610 {
611 std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
612 << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
613 std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
614 << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
615 std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
616 << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
617 std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
618 << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
619 std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
620 << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
621 }
622
623 static_for<0, NumD0Tensor, 1>{}([&](auto i) {
624 using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
625 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
626
627 // D0 pointer
628 p_d0s_grid_(i) = static_cast<const D0DataType*>(p_d0s_grid[i]);
629
630 // D0 desc
632 DeviceOp::MakeD0GridDescriptor_M_N<D0Layout>(MRaw, NRaw, StrideD0s[i]);
633 });
634
635 static_for<0, NumD1Tensor, 1>{}([&](auto i) {
636 using D1Layout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
637 using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
638
639 // D1 pointer
640 p_d1s_grid_(i) = static_cast<const D1DataType*>(p_d1s_grid[i]);
641
642 // D1 desc
644 DeviceOp::MakeE1GridDescriptor_M_N<D1Layout>(MRaw, Gemm1NRaw, StrideD1s[i]);
645 });
646 }
647
648 // private:
649 // pointers
650 const A0DataType* p_a0_grid_;
651 const B0DataType* p_b0_grid_;
653 const B1DataType* p_b1_grid_;
655 E1DataType* p_e1_grid_;
656
657 // tensor descriptors for problem definition
664
665 // tensor descriptors for block/thread-wise copy
669
670 // block-to-e1-tile map
672
673 // element-wise op
674 A0ElementwiseOperation a0_element_op_;
675 B0ElementwiseOperation b0_element_op_;
676 CDE0ElementwiseOperation cde0_element_op_;
677 B1ElementwiseOperation b1_element_op_;
678 CDE1ElementwiseOperation cde1_element_op_;
679
680 // batch
683 };
684
685 // Invoker
686 struct Invoker : public BaseInvoker
687 {
689
690 template <typename GridwiseGemm>
691 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
692 {
693 if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
698 {
699 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
700 }
701
702 auto e1_grid_desc_mblock_mperblock_nblock_nperblock =
703 GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
705
706 auto d1s_grid_desc_mblock_mperblock_nblock_nperblock =
707 GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
709
710 const index_t grid_size =
711 arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_;
712
713 // Gemm0_K
714 const auto K = arg.a0_grid_desc_m_k_.GetLength(I1);
715
716 auto launch_kernel = [&](auto has_main_k_block_loop_) {
718 GridwiseGemm,
719 A0DataType, // TODO: distiguish A/B datatype
720 typename GridwiseGemm::D0sGridPointer,
721 typename GridwiseGemm::D1sGridPointer,
722 E1DataType,
723 A0ElementwiseOperation,
724 B0ElementwiseOperation,
725 CDE0ElementwiseOperation,
726 B1ElementwiseOperation,
727 CDE1ElementwiseOperation,
732 typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
733 typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
734 typename GridwiseGemm::DefaultBlock2E1TileMap,
735 ComputeBasePtrOfStridedBatch,
736 has_main_k_block_loop_>;
737
738 return launch_and_time_kernel(stream_config,
739 kernel,
740 dim3(grid_size),
741 dim3(BlockSize),
742 0,
743 arg.p_a0_grid_,
744 arg.p_b0_grid_,
745 arg.p_d0s_grid_,
746 arg.p_b1_grid_,
747 arg.p_d1s_grid_,
748 arg.p_e1_grid_,
749 arg.a0_element_op_,
750 arg.b0_element_op_,
752 arg.b1_element_op_,
758 d1s_grid_desc_mblock_mperblock_nblock_nperblock,
759 e1_grid_desc_mblock_mperblock_nblock_nperblock,
761 arg.batch_count_,
763 };
764
765 // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
766 // to concern Gemm0's loop
767 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
768 {
769 return launch_kernel(integral_constant<bool, true>{});
770 }
771 else
772 {
773 return launch_kernel(integral_constant<bool, false>{});
774 }
775 }
776
777 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
778 {
779 if(get_warp_size() == 64)
780 {
781 if constexpr(Gemm0MXdlPerWave64 > 0)
782 {
783 return RunImp<GridwiseGemm64>(arg, stream_config);
784 }
785 }
786 else
787 {
788 if constexpr(Gemm0MXdlPerWave32 > 0)
789 {
790 return RunImp<GridwiseGemm32>(arg, stream_config);
791 }
792 }
793 return 0;
794 }
795
796 // polymorphic
797 float Run(const BaseArgument* p_arg,
798 const StreamConfig& stream_config = StreamConfig{}) override
799 {
800 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
801 }
802 };
803
// Compile-time gate for this instantiation; currently accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepts_every_configuration = true;
    return accepts_every_configuration;
}
809
810 // check if DsLayout is supported
811 template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
812 static bool CheckDLayout()
813 {
814 static bool valid = true;
815 // iterate over DLayout tuple
816 static_for<0, NumDTensor, 1>{}([&](auto i) {
817 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
818 // if RefLayout and DLayout are same, keep valid true, otherwise false
819 valid = valid && is_same_v<RefLayout, DLayout>;
820 });
821 return valid;
822 }
823
824 static bool IsSupportedArgument(const Argument& arg)
825 {
827 {
828 return false;
829 }
830
831 // Check supported layouts
832 // A0 - Row
833 // B0 - Col
834 // D0s - Rows
835 // B1 - Row or Col
836 // D1s - Rows
837 // E1 - Row
845 {
846 return false;
847 }
848
849 if(get_warp_size() == 64)
850 {
851 if constexpr(Gemm0MXdlPerWave64 > 0)
852 {
858 }
859 }
860 else
861 {
862 if constexpr(Gemm0MXdlPerWave32 > 0)
863 {
869 }
870 }
871 return false;
872 }
873
874 // polymorphic
875 bool IsSupportedArgument(const BaseArgument* p_arg) override
876 {
877 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
878 }
879
880 static auto MakeArgument(const A0DataType* p_a0,
881 const B0DataType* p_b0,
882 std::array<const void*, NumD0Tensor> p_d0s,
883 const B1DataType* p_b1,
884 std::array<const void*, NumD1Tensor> p_d1s,
885 E1DataType* p_e1,
886 index_t MRaw,
887 index_t NRaw,
888 index_t KRaw,
889 index_t Gemm1NRaw,
890 index_t Batch,
891 index_t StrideA0,
892 index_t StrideB0,
893 std::array<index_t, NumD0Tensor> StrideD0s,
894 index_t StrideB1,
895 std::array<index_t, NumD1Tensor> StrideD1s,
896 index_t StrideE1,
897 index_t BatchStrideA0,
898 index_t BatchStrideB0,
899 std::array<index_t, NumD0Tensor> BatchStrideD0s,
900 index_t BatchStrideB1,
901 std::array<index_t, NumD1Tensor> BatchStrideD1s,
902 index_t BatchStrideE1,
903 A0ElementwiseOperation a0_element_op,
904 B0ElementwiseOperation b0_element_op,
905 CDE0ElementwiseOperation cde0_element_op,
906 B1ElementwiseOperation b1_element_op,
907 CDE1ElementwiseOperation cde1_element_op)
908 {
909 return Argument{p_a0, p_b0,
910 p_d0s, p_b1,
911 p_d1s, p_e1,
912 MRaw, NRaw,
913 KRaw, Gemm1NRaw,
914 Batch, StrideA0,
915 StrideB0, StrideD0s,
916 StrideB1, StrideD1s,
917 StrideE1, BatchStrideA0,
918 BatchStrideB0, BatchStrideD0s,
919 BatchStrideB1, BatchStrideD1s,
920 BatchStrideE1, a0_element_op,
921 b0_element_op, cde0_element_op,
922 b1_element_op, cde1_element_op};
923 }
924
925 static auto MakeInvoker() { return Invoker{}; }
926
927 // polymorphic
928 std::unique_ptr<BaseArgument>
929 MakeArgumentPointer(const void* p_a0,
930 const void* p_b0,
931 std::array<const void*, NumD0Tensor> p_d0s,
932 const void* p_b1,
933 std::array<const void*, NumD1Tensor> p_d1s,
934 void* p_e1,
935 index_t MRaw,
936 index_t NRaw,
937 index_t KRaw,
938 index_t Gemm1NRaw,
939 index_t Batch,
940 index_t StrideA0,
941 index_t StrideB0,
942 std::array<ck::index_t, NumD0Tensor> StrideD0s,
943 index_t StrideB1,
944 std::array<ck::index_t, NumD1Tensor> StrideD1s,
945 index_t StrideE1,
946 index_t BatchStrideA0,
947 index_t BatchStrideB0,
948 std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
949 index_t BatchStrideB1,
950 std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
951 index_t BatchStrideE1,
952 A0ElementwiseOperation a0_element_op,
953 B0ElementwiseOperation b0_element_op,
954 CDE0ElementwiseOperation cde0_element_op,
955 B1ElementwiseOperation b1_element_op,
956 CDE1ElementwiseOperation cde1_element_op) override
957 {
958 return std::make_unique<Argument>(static_cast<const A0DataType*>(p_a0),
959 static_cast<const B0DataType*>(p_b0),
960 p_d0s,
961 static_cast<const B1DataType*>(p_b1),
962 p_d1s,
963 static_cast<E1DataType*>(p_e1),
964 MRaw,
965 NRaw,
966 KRaw,
967 Gemm1NRaw,
968 Batch,
969 StrideA0,
970 StrideB0,
971 StrideD0s,
972 StrideB1,
973 StrideD1s,
974 StrideE1,
975 BatchStrideA0,
976 BatchStrideB0,
977 BatchStrideD0s,
978 BatchStrideB1,
979 BatchStrideD1s,
980 BatchStrideE1,
981 a0_element_op,
982 b0_element_op,
983 cde0_element_op,
984 b1_element_op,
985 cde1_element_op);
986 }
987
988 // polymorphic
989 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
990 {
991 return std::make_unique<Invoker>(Invoker{});
992 }
993
994 // polymorphic
995 std::string GetTypeString() const override
996 {
997 auto str = std::stringstream();
998
999 // clang-format off
1000 str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle"
1001 << "<"
1002 << BlockSize << ", "
1003 << Gemm0MPerBlock << ", "
1004 << Gemm0NPerBlock << ", "
1005 << Gemm0KPerBlock << ", "
1006 << A0K1 << ", "
1007 << B0K1 << ", "
1008 << B1K1 << ", "
1009 << Gemm0MPerXdl << ", "
1010 << Gemm0NPerXdl << ", "
1011 << Gemm0MXdlPerWave << ", "
1012 << Gemm0NXdlPerWave << ", "
1013 << Gemm1NXdlPerWave << "> ";
1014 // clang-format on
1015
1016 return str.str();
1017 }
1018};
1019
1020} // namespace device
1021} // namespace tensor_operation
1022} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_batched_gemm_gemm_xdl_cshuffle_v1(const A0B0B1DataType *__restrict__ p_a0_grid, const A0B0B1DataType *__restrict__ p_b0_grid, D0sPointer p_d0s_grid, const A0B0B1DataType *__restrict__ p_b1_grid, D1sPointer p_d1s_grid, E1DataType *__restrict__ p_e1_grid, const A0ElementwiseOperation a0_element_op, const B0ElementwiseOperation b0_element_op, const CDE0ElementwiseOperation cde0_element_op, const B1ElementwiseOperation b1_element_op, const CDE1ElementwiseOperation cde1_element_op, const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, const D0sGridDesc_M_N d0s_griddesc_m_n, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock d1s_grid_desc_mblock_mperblock_nblock_nperblock, const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap block_2_e1tile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:47
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
int64_t long_index_t
Definition ck.hpp:300
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:86
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::CheckValidity
static __host__ constexpr bool CheckValidity(const A0GridDesc_M_K &a0_grid_desc_m_k, const B0GridDesc_N_K &b0_grid_desc_n_k, const B1GridDesc_N_K &b1_grid_desc_n_k, const E1GridDesc_M_N &e1_grid_desc_m_n, const Block2E1TileMap &block_2_e1tile_map)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:286
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::D0sGridPointer
decltype(MakeD0sGridPointer()) D0sGridPointer
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:553
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultB1GridDescriptor_BK0_N_BK1
__host__ static __device__ constexpr auto MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:443
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::DefaultBlock2E1TileMap
remove_cvref_t< decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))> DefaultBlock2E1TileMap
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:520
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultB0GridDescriptor_BK0_N_BK1
__host__ static __device__ constexpr auto MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K &b0_grid_desc_n_k)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:386
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultA0GridDescriptor_AK0_M_AK1
__host__ static __device__ constexpr auto MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K &a0_grid_desc_m_k)
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:369
ck::GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::D1sGridPointer
decltype(MakeD1sGridPointer()) D1sGridPointer
Definition gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp:554
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:547
CDE1ElementwiseOperation cde1_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:678
A0GridDesc_M_K a0_grid_desc_m_k_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:658
const B0DataType * p_b0_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:651
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:668
index_t batch_count_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:681
const B1DataType * p_b1_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:653
B0GridDesc_N_K b0_grid_desc_n_k_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:659
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:682
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:666
GridwiseGemm64::D0sGridPointer p_d0s_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:652
Argument(const A0DataType *p_a0_grid, const B0DataType *p_b0_grid, std::array< const void *, NumD0Tensor > p_d0s_grid, const B1DataType *p_b1_grid, std::array< const void *, NumD1Tensor > p_d1s_grid, E1DataType *p_e1_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA0, index_t StrideB0, std::array< index_t, NumD0Tensor > StrideD0s, index_t StrideB1, std::array< index_t, NumD1Tensor > StrideD1s, index_t StrideE1, index_t BatchStrideA0, index_t BatchStrideB0, std::array< index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:548
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:677
E1GridDesc_M_N e1_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:663
B1GridDesc_N_K b1_grid_desc_n_k_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:661
B0ElementwiseOperation b0_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:675
E1DataType * p_e1_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:655
D1sGridDesc_M_N d1s_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:662
D0sGridDesc_M_N d0s_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:660
A0ElementwiseOperation a0_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:674
CDE0ElementwiseOperation cde0_element_op_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:676
GridwiseGemm64::D1sGridPointer p_d1s_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:654
GridwiseGemm64::DefaultBlock2E1TileMap block_2_e1tile_map_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:671
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:667
const A0DataType * p_a0_grid_
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:650
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:687
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:777
DeviceOp::Argument Argument
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:688
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:797
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:691
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:235
static constexpr index_t NumD1Tensor
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:254
remove_cvref_t< decltype(MakeD0sGridDescriptor_M_N({}, {}, {}))> D0sGridDesc_M_N
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:459
static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:334
static constexpr auto Gemm0MXdlPerWave32
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:245
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultB0GridDescriptor_BK0_N_BK1( B0GridDesc_N_K{}))> B0GridDesc_BK0_N_BK1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:538
remove_cvref_t< decltype(MakeD1sGridDescriptor_M_N({}, {}, {}))> D1sGridDesc_M_N
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:461
static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:295
static constexpr auto I7
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:263
static bool CheckDLayout()
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:812
static constexpr auto I8
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:264
static constexpr auto gemm0_padder
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:267
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:804
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:875
GridwiseGemmBase< Gemm0MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:533
static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:276
static constexpr auto I4
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:260
static auto MakeD1sGridDescriptor_M_N(const std::array< index_t, NumD1Tensor > &MRaws, const std::array< index_t, NumD1Tensor > &NRaws, const std::array< index_t, NumD1Tensor > &DsStride)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:385
GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, Acc0DataType, D0sDataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, A0GridDesc_M_K, B0GridDesc_N_K, D0sGridDesc_M_N, B1GridDesc_N_K, D1sGridDesc_M_N, E1GridDesc_M_N, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave_, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, true, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, true, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:466
GridwiseGemmBase< math::max(Gemm0MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:532
static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:315
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:236
static auto MakeInvoker()
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:925
static constexpr auto I0
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:256
static auto MakeArgument(const A0DataType *p_a0, const B0DataType *p_b0, std::array< const void *, NumD0Tensor > p_d0s, const B1DataType *p_b1, std::array< const void *, NumD1Tensor > p_d1s, E1DataType *p_e1, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA0, index_t StrideB0, std::array< index_t, NumD0Tensor > StrideD0s, index_t StrideB1, std::array< index_t, NumD1Tensor > StrideD1s, index_t StrideE1, index_t BatchStrideA0, index_t BatchStrideB0, std::array< index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:880
static constexpr auto I2
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:258
decltype(MakeB1GridDescriptor_N_K(1, 1, 1)) B1GridDesc_N_K
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:460
static auto MakeD0sGridDescriptor_M_N(const std::array< index_t, NumD1Tensor > &MRaws, const std::array< index_t, NumD1Tensor > &NRaws, const std::array< index_t, NumD1Tensor > &DsStride)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:372
static constexpr auto gemm1_padder
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:271
std::string GetTypeString() const override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:995
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultA0GridDescriptor_AK0_M_AK1( A0GridDesc_M_K{}))> A0GridDesc_AK0_M_AK1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:535
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:824
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a0, const void *p_b0, std::array< const void *, NumD0Tensor > p_d0s, const void *p_b1, std::array< const void *, NumD1Tensor > p_d1s, void *p_e1, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA0, index_t StrideB0, std::array< ck::index_t, NumD0Tensor > StrideD0s, index_t StrideB1, std::array< ck::index_t, NumD1Tensor > StrideD1s, index_t StrideE1, index_t BatchStrideA0, index_t BatchStrideB0, std::array< ck::index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< ck::index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op) override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:929
static constexpr auto I1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:257
static constexpr auto Gemm0MXdlPerWave64
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:238
static constexpr auto I9
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:265
decltype(MakeA0GridDescriptor_M_K(1, 1, 1)) A0GridDesc_M_K
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:457
decltype(MakeE1GridDescriptor_M_N< E1Layout >(1, 1, 1)) E1GridDesc_M_N
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:462
decltype(MakeB0GridDescriptor_N_K(1, 1, 1)) B0GridDesc_N_K
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:458
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultB1GridDescriptor_BK0_N_BK1( B1GridDesc_N_K{}))> B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:541
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:989
static constexpr auto I5
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:261
static constexpr auto I6
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:262
static constexpr auto I3
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:259
static constexpr index_t NumD0Tensor
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:253
static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:354
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:399
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:415
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number< I > d1_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:443
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:432
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:420
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, Number< I > d1_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:426
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, index_t BatchStrideB0, std::array< index_t, NumD0Tensor > BatchStrideD0s, index_t BatchStrideB1, std::array< index_t, NumD1Tensor > BatchStrideD1s, index_t BatchStrideE1)
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:400
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp:437
Definition device_batched_gemm_multiple_d_gemm_multiple_d.hpp:33
Definition matrix_padder.hpp:204
#define CK_ENV(name)
Definition utility/env.hpp:129