xdlops_gemm.hpp Source File

xdlops_gemm.hpp Source File

Composable Kernel: xdlops_gemm.hpp Source File
xdlops_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
10
11namespace ck {
15template <typename T>
16static constexpr bool is_scale_mfma_data_type()
17{
18 using U = element_type_t<T>;
21}
22
23#ifndef CK_CODE_GEN_RTC
27template <typename T>
28static constexpr bool is_scale_mfma_scale_type()
29{
31}
32#endif
33
// Compile-time predicate over a combination of A/B data types and A/B scale types.
// Evaluates to true only when is_scale_mfma_data_type<> accepts both ADataType and
// BDataType AND is_scale_mfma_scale_type<> accepts both AScaleDataType and BScaleDataType;
// any single rejected type makes the whole combination unsupported.
// NOTE(review): in this listing is_scale_mfma_scale_type is declared inside an
// `#ifndef CK_CODE_GEN_RTC` region, while it is referenced here without a visible guard --
// confirm against the original header whether a preprocessor guard was dropped from the listing.
37template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
38static constexpr bool scale_mfma_hw_support()
39{
40 return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
41 is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
42}
43
100
101template <MfmaInstr instr>
103
104template <>
106{
107 static constexpr index_t group_size = 4;
108 static constexpr index_t num_groups_per_blk = 4;
109 static constexpr index_t num_regs_per_blk = 16;
110 static constexpr index_t num_threads_per_blk = 32;
111 static constexpr index_t wave_size = 64;
112 static constexpr index_t num_input_blks = 2;
113 static constexpr index_t num_output_blks = 2;
114 static constexpr index_t m_per_blk = 32;
115 static constexpr index_t n_per_blk = 32;
116 static constexpr index_t k_per_blk = 1;
117 static constexpr bool is_k_reduction = false;
118
119 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
120 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
121 {
123 }
124};
125
126template <>
128{
129 static constexpr index_t group_size = 4;
130 static constexpr index_t num_groups_per_blk = 4;
131 static constexpr index_t num_regs_per_blk = 16;
132 static constexpr index_t num_threads_per_blk = 32;
133 static constexpr index_t wave_size = 64;
134 static constexpr index_t num_input_blks = 2;
135 static constexpr index_t num_output_blks = 1;
136 static constexpr index_t m_per_blk = 32;
137 static constexpr index_t n_per_blk = 32;
138 static constexpr index_t k_per_blk = 1;
139 static constexpr bool is_k_reduction = true;
140
141 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
142 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
143 {
145 }
146};
147
148template <>
150{
151 static constexpr index_t group_size = 4;
152 static constexpr index_t num_groups_per_blk = 1;
153 static constexpr index_t num_regs_per_blk = 4;
154 static constexpr index_t num_threads_per_blk = 16;
155 static constexpr index_t wave_size = 64;
156 static constexpr index_t num_input_blks = 4;
157 static constexpr index_t num_output_blks = 1;
158 static constexpr index_t m_per_blk = 16;
159 static constexpr index_t n_per_blk = 16;
160 static constexpr index_t k_per_blk = 1;
161 static constexpr bool is_k_reduction = true;
162
163 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
164 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
165 {
167 }
168};
169
170template <>
172{
173 static constexpr index_t group_size = 4;
174 static constexpr index_t num_groups_per_blk = 1;
175 static constexpr index_t num_regs_per_blk = 4;
176 static constexpr index_t num_threads_per_blk = 16;
177 static constexpr index_t wave_size = 64;
178 static constexpr index_t num_input_blks = 4;
179 static constexpr index_t num_output_blks = 4;
180 static constexpr index_t m_per_blk = 16;
181 static constexpr index_t n_per_blk = 16;
182 static constexpr index_t k_per_blk = 1;
183 static constexpr bool is_k_reduction = false;
184
185 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
186 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
187 {
189 }
190};
191
192// treat 4x4x1 as a single-blk 4x64 mfma
193template <>
195{
196 static constexpr index_t group_size = 4;
197 static constexpr index_t num_groups_per_blk = 1;
198 static constexpr index_t num_regs_per_blk = 4;
199 static constexpr index_t num_threads_per_blk = 64;
200 static constexpr index_t wave_size = 64;
201 static constexpr index_t num_input_blks = 1;
202 static constexpr index_t num_output_blks = 1;
203 static constexpr index_t m_per_blk = 4;
204 static constexpr index_t n_per_blk = 64;
205 static constexpr index_t k_per_blk = 1;
206 static constexpr bool is_k_reduction = false;
207
208 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
209 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
210 {
212 }
213};
214
215template <>
217{
218 static constexpr index_t group_size = 4;
219 static constexpr index_t num_groups_per_blk = 4;
220 static constexpr index_t num_regs_per_blk = 16;
221 static constexpr index_t num_threads_per_blk = 32;
222 static constexpr index_t wave_size = 64;
223 static constexpr index_t num_input_blks = 2;
224 static constexpr index_t num_output_blks = 2;
225 static constexpr index_t m_per_blk = 32;
226 static constexpr index_t n_per_blk = 32;
227 static constexpr index_t k_per_blk = 4;
228 static constexpr bool is_k_reduction = false;
229
230 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
231 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
232 {
234 }
235};
236
237template <>
239{
240 static constexpr index_t group_size = 4;
241 static constexpr index_t num_groups_per_blk = 4;
242 static constexpr index_t num_regs_per_blk = 16;
243 static constexpr index_t num_threads_per_blk = 32;
244 static constexpr index_t wave_size = 64;
245 static constexpr index_t num_input_blks = 2;
246 static constexpr index_t num_output_blks = 1;
247 static constexpr index_t m_per_blk = 32;
248 static constexpr index_t n_per_blk = 32;
249 static constexpr index_t k_per_blk = 4;
250 static constexpr bool is_k_reduction = true;
251
252 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
253 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
254 {
256 }
257};
258
259template <>
261{
262 static constexpr index_t group_size = 4;
263 static constexpr index_t num_groups_per_blk = 4;
264 static constexpr index_t num_regs_per_blk = 16;
265 static constexpr index_t num_threads_per_blk = 32;
266 static constexpr index_t wave_size = 64;
267 static constexpr index_t num_input_blks = 2;
268 static constexpr index_t num_output_blks = 1;
269 static constexpr index_t m_per_blk = 32;
270 static constexpr index_t n_per_blk = 32;
271 static constexpr index_t k_per_blk = 8;
272 static constexpr bool is_k_reduction = true;
273
274 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
275 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
276 {
278 }
279};
280
281template <>
283{
284 static constexpr index_t group_size = 4;
285 static constexpr index_t num_groups_per_blk = 1;
286 static constexpr index_t num_regs_per_blk = 4;
287 static constexpr index_t num_threads_per_blk = 16;
288 static constexpr index_t wave_size = 64;
289 static constexpr index_t num_input_blks = 4;
290 static constexpr index_t num_output_blks = 1;
291 static constexpr index_t m_per_blk = 16;
292 static constexpr index_t n_per_blk = 16;
293 static constexpr index_t k_per_blk = 8;
294 static constexpr bool is_k_reduction = true;
295
296 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
297 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
298 {
300 }
301};
302
303template <>
305{
306 static constexpr index_t group_size = 4;
307 static constexpr index_t num_groups_per_blk = 1;
308 static constexpr index_t num_regs_per_blk = 4;
309 static constexpr index_t num_threads_per_blk = 16;
310 static constexpr index_t wave_size = 64;
311 static constexpr index_t num_input_blks = 4;
312 static constexpr index_t num_output_blks = 1;
313 static constexpr index_t m_per_blk = 16;
314 static constexpr index_t n_per_blk = 16;
315 static constexpr index_t k_per_blk = 4;
316 static constexpr bool is_k_reduction = true;
317
318 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
319 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
320 {
322 }
323};
324
325template <>
327{
328 static constexpr index_t group_size = 4;
329 static constexpr index_t num_groups_per_blk = 1;
330 static constexpr index_t num_regs_per_blk = 4;
331 static constexpr index_t num_threads_per_blk = 16;
332 static constexpr index_t wave_size = 64;
333 static constexpr index_t num_input_blks = 4;
334 static constexpr index_t num_output_blks = 4;
335 static constexpr index_t m_per_blk = 16;
336 static constexpr index_t n_per_blk = 16;
337 static constexpr index_t k_per_blk = 4;
338 static constexpr bool is_k_reduction = false;
339
340 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
341 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
342 {
344 }
345};
346
347template <>
349{
350 static constexpr index_t group_size = 4;
351 static constexpr index_t num_groups_per_blk = 1;
352 static constexpr index_t num_regs_per_blk = 4;
353 static constexpr index_t num_threads_per_blk = 64;
354 static constexpr index_t wave_size = 64;
355 static constexpr index_t num_input_blks = 1;
356 static constexpr index_t num_output_blks = 1;
357 static constexpr index_t m_per_blk = 4;
358 static constexpr index_t n_per_blk = 64;
359 static constexpr index_t k_per_blk = 4;
360 static constexpr bool is_k_reduction = false;
361
362 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
363 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
364 {
366 }
367};
368
369template <>
371{
372 static constexpr index_t group_size = 4;
373 static constexpr index_t num_groups_per_blk = 4;
374 static constexpr index_t num_regs_per_blk = 16;
375 static constexpr index_t num_threads_per_blk = 32;
376 static constexpr index_t wave_size = 64;
377 static constexpr index_t num_input_blks = 2;
378 static constexpr index_t num_output_blks = 1;
379 static constexpr index_t m_per_blk = 32;
380 static constexpr index_t n_per_blk = 32;
381 static constexpr index_t k_per_blk = 8;
382 static constexpr bool is_k_reduction = true;
383
384 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
385 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
386 {
388 }
389};
390
391template <>
393{
394 static constexpr index_t group_size = 4;
395 static constexpr index_t num_groups_per_blk = 4;
396 static constexpr index_t num_regs_per_blk = 16;
397 static constexpr index_t num_threads_per_blk = 32;
398 static constexpr index_t wave_size = 64;
399 static constexpr index_t num_input_blks = 2;
400 static constexpr index_t num_output_blks = 1;
401 static constexpr index_t m_per_blk = 32;
402 static constexpr index_t n_per_blk = 32;
403 static constexpr index_t k_per_blk = 4;
404 static constexpr bool is_k_reduction = true;
405
406 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
407 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
408 {
410 }
411};
412
413template <>
415{
416 static constexpr index_t group_size = 4;
417 static constexpr index_t num_groups_per_blk = 1;
418 static constexpr index_t num_regs_per_blk = 4;
419 static constexpr index_t num_threads_per_blk = 16;
420 static constexpr index_t wave_size = 64;
421 static constexpr index_t num_input_blks = 4;
422 static constexpr index_t num_output_blks = 1;
423 static constexpr index_t m_per_blk = 16;
424 static constexpr index_t n_per_blk = 16;
425 static constexpr index_t k_per_blk = 8;
426 static constexpr bool is_k_reduction = true;
427
428 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
429 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
430 {
432 }
433};
434
435template <>
437{
438 static constexpr index_t group_size = 4;
439 static constexpr index_t num_groups_per_blk = 1;
440 static constexpr index_t num_regs_per_blk = 4;
441 static constexpr index_t num_threads_per_blk = 16;
442 static constexpr index_t wave_size = 64;
443 static constexpr index_t num_input_blks = 4;
444 static constexpr index_t num_output_blks = 1;
445 static constexpr index_t m_per_blk = 16;
446 static constexpr index_t n_per_blk = 16;
447 static constexpr index_t k_per_blk = 4;
448 static constexpr bool is_k_reduction = true;
449
450 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
451 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
452 {
454 }
455};
456
457template <>
459{
460 static constexpr index_t group_size = 4;
461 static constexpr index_t num_groups_per_blk = 4;
462 static constexpr index_t num_regs_per_blk = 16;
463 static constexpr index_t num_threads_per_blk = 32;
464 static constexpr index_t wave_size = 64;
465 static constexpr index_t num_input_blks = 2;
466 static constexpr index_t num_output_blks = 1;
467 static constexpr index_t m_per_blk = 32;
468 static constexpr index_t n_per_blk = 32;
469 static constexpr index_t k_per_blk = 2;
470 static constexpr bool is_k_reduction = true;
471
472 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
473 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
474 {
476 }
477};
478
479template <>
481{
482 static constexpr index_t group_size = 4;
483 static constexpr index_t num_groups_per_blk = 1;
484 static constexpr index_t num_regs_per_blk = 4;
485 static constexpr index_t num_threads_per_blk = 16;
486 static constexpr index_t wave_size = 64;
487 static constexpr index_t num_input_blks = 4;
488 static constexpr index_t num_output_blks = 1;
489 static constexpr index_t m_per_blk = 16;
490 static constexpr index_t n_per_blk = 16;
491 static constexpr index_t k_per_blk = 2;
492 static constexpr bool is_k_reduction = true;
493
494 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
495 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
496 {
498 }
499};
500
501template <>
503{
504 static constexpr index_t group_size = 4;
505 static constexpr index_t num_groups_per_blk = 4;
506 static constexpr index_t num_regs_per_blk = 16;
507 static constexpr index_t num_threads_per_blk = 32;
508 static constexpr index_t wave_size = 64;
509 static constexpr index_t num_input_blks = 2;
510 static constexpr index_t num_output_blks = 1;
511 static constexpr index_t m_per_blk = 32;
512 static constexpr index_t n_per_blk = 32;
513 static constexpr index_t k_per_blk = 4;
514 static constexpr bool is_k_reduction = true;
515
516 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
517 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
518 {
520 }
521};
522
523template <>
525{
526 static constexpr index_t group_size = 4;
527 static constexpr index_t num_groups_per_blk = 1;
528 static constexpr index_t num_regs_per_blk = 4;
529 static constexpr index_t num_threads_per_blk = 16;
530 static constexpr index_t wave_size = 64;
531 static constexpr index_t num_input_blks = 4;
532 static constexpr index_t num_output_blks = 1;
533 static constexpr index_t m_per_blk = 16;
534 static constexpr index_t n_per_blk = 16;
535 static constexpr index_t k_per_blk = 4;
536 static constexpr bool is_k_reduction = true;
537
538 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
539 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
540 {
542 }
543};
544
545template <>
547{
548 static constexpr index_t group_size = 4;
549 static constexpr index_t num_groups_per_blk = 4;
550 static constexpr index_t num_regs_per_blk = 16;
551 static constexpr index_t num_threads_per_blk = 32;
552 static constexpr index_t wave_size = 64;
553 static constexpr index_t num_input_blks = 2;
554 static constexpr index_t num_output_blks = 1;
555 static constexpr index_t m_per_blk = 32;
556 static constexpr index_t n_per_blk = 32;
557 static constexpr index_t k_per_blk = 8;
558 static constexpr bool is_k_reduction = true;
559
560 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
561 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
562 {
564 }
565};
566
567template <>
569{
570 static constexpr index_t group_size = 4;
571 static constexpr index_t num_groups_per_blk = 1;
572 static constexpr index_t num_regs_per_blk = 4;
573 static constexpr index_t num_threads_per_blk = 16;
574 static constexpr index_t wave_size = 64;
575 static constexpr index_t num_input_blks = 4;
576 static constexpr index_t num_output_blks = 1;
577 static constexpr index_t m_per_blk = 16;
578 static constexpr index_t n_per_blk = 16;
579 static constexpr index_t k_per_blk = 8;
580 static constexpr bool is_k_reduction = true;
581
582 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
583 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
584 {
586 }
587};
588
589template <>
591{
592 static constexpr index_t group_size = 4;
593 static constexpr index_t num_groups_per_blk = 4;
594 static constexpr index_t num_regs_per_blk = 16;
595 static constexpr index_t num_threads_per_blk = 32;
596 static constexpr index_t wave_size = 64;
597 static constexpr index_t num_input_blks = 2;
598 static constexpr index_t num_output_blks = 1;
599 static constexpr index_t m_per_blk = 32;
600 static constexpr index_t n_per_blk = 32;
601 static constexpr index_t k_per_blk = 16;
602 static constexpr bool is_k_reduction = true;
603
604 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
605 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
606 {
608 }
609};
610
611template <>
613{
614 static constexpr index_t group_size = 4;
615 static constexpr index_t num_groups_per_blk = 1;
616 static constexpr index_t num_regs_per_blk = 4;
617 static constexpr index_t num_threads_per_blk = 16;
618 static constexpr index_t wave_size = 64;
619 static constexpr index_t num_input_blks = 4;
620 static constexpr index_t num_output_blks = 1;
621 static constexpr index_t m_per_blk = 16;
622 static constexpr index_t n_per_blk = 16;
623 static constexpr index_t k_per_blk = 16;
624 static constexpr bool is_k_reduction = true;
625
626 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
627 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
628 {
630 }
631};
632
633template <>
635{
636 static constexpr index_t group_size = 1;
637 static constexpr index_t num_groups_per_blk = 4;
638 static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
639 static constexpr index_t num_threads_per_blk = 16;
640 static constexpr index_t wave_size = 64;
641 static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
642 static constexpr index_t num_output_blks = 1;
643 static constexpr index_t m_per_blk = 16;
644 static constexpr index_t n_per_blk = 16;
645 static constexpr index_t k_per_blk = 1;
646 static constexpr bool is_k_reduction = true;
647
648 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
649 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
650 {
652 }
653};
654
655template <>
657{
658 static constexpr index_t group_size = 4;
659 static constexpr index_t num_groups_per_blk = 4;
660 static constexpr index_t num_regs_per_blk = 16;
661 static constexpr index_t num_threads_per_blk = 32;
662 static constexpr index_t wave_size = 64;
663 static constexpr index_t num_input_blks = 2;
664 static constexpr index_t num_output_blks = 1;
665 static constexpr index_t m_per_blk = 32;
666 static constexpr index_t n_per_blk = 32;
667 static constexpr index_t k_per_blk = 8;
668 static constexpr bool is_k_reduction = true;
669
670 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
671 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
672 {
674 }
675};
676
677template <>
679{
680 static constexpr index_t group_size = 4;
681 static constexpr index_t num_groups_per_blk = 1;
682 static constexpr index_t num_regs_per_blk = 4;
683 static constexpr index_t num_threads_per_blk = 16;
684 static constexpr index_t wave_size = 64;
685 static constexpr index_t num_input_blks = 4;
686 static constexpr index_t num_output_blks = 1;
687 static constexpr index_t m_per_blk = 16;
688 static constexpr index_t n_per_blk = 16;
689 static constexpr index_t k_per_blk = 8;
690 static constexpr bool is_k_reduction = true;
691
692 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
693 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
694 {
696 }
697};
698
699template <>
701{
702 static constexpr index_t group_size = 4;
703 static constexpr index_t num_groups_per_blk = 4;
704 static constexpr index_t num_regs_per_blk = 16;
705 static constexpr index_t num_threads_per_blk = 32;
706 static constexpr index_t wave_size = 64;
707 static constexpr index_t num_input_blks = 2;
708 static constexpr index_t num_output_blks = 1;
709 static constexpr index_t m_per_blk = 32;
710 static constexpr index_t n_per_blk = 32;
711 static constexpr index_t k_per_blk = 8;
712 static constexpr bool is_k_reduction = true;
713
714 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
715 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
716 {
718 }
719};
720
721template <>
723{
724 static constexpr index_t group_size = 4;
725 static constexpr index_t num_groups_per_blk = 1;
726 static constexpr index_t num_regs_per_blk = 4;
727 static constexpr index_t num_threads_per_blk = 16;
728 static constexpr index_t wave_size = 64;
729 static constexpr index_t num_input_blks = 4;
730 static constexpr index_t num_output_blks = 1;
731 static constexpr index_t m_per_blk = 16;
732 static constexpr index_t n_per_blk = 16;
733 static constexpr index_t k_per_blk = 8;
734 static constexpr bool is_k_reduction = true;
735
736 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
737 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
738 {
740 }
741};
742
743template <>
745{
746 static constexpr index_t group_size = 4;
747 static constexpr index_t num_groups_per_blk = 4;
748 static constexpr index_t num_regs_per_blk = 16;
749 static constexpr index_t num_threads_per_blk = 32;
750 static constexpr index_t wave_size = 64;
751 static constexpr index_t num_input_blks = 2;
752 static constexpr index_t num_output_blks = 1;
753 static constexpr index_t m_per_blk = 32;
754 static constexpr index_t n_per_blk = 32;
755 static constexpr index_t k_per_blk = 8;
756 static constexpr bool is_k_reduction = true;
757
758 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
759 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
760 {
762 }
763};
764
765template <>
767{
768 static constexpr index_t group_size = 4;
769 static constexpr index_t num_groups_per_blk = 1;
770 static constexpr index_t num_regs_per_blk = 4;
771 static constexpr index_t num_threads_per_blk = 16;
772 static constexpr index_t wave_size = 64;
773 static constexpr index_t num_input_blks = 4;
774 static constexpr index_t num_output_blks = 1;
775 static constexpr index_t m_per_blk = 16;
776 static constexpr index_t n_per_blk = 16;
777 static constexpr index_t k_per_blk = 8;
778 static constexpr bool is_k_reduction = true;
779
780 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
781 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
782 {
784 }
785};
786
787template <>
789{
790 static constexpr index_t group_size = 4;
791 static constexpr index_t num_groups_per_blk = 4;
792 static constexpr index_t num_regs_per_blk = 16;
793 static constexpr index_t num_threads_per_blk = 32;
794 static constexpr index_t wave_size = 64;
795 static constexpr index_t num_input_blks = 2;
796 static constexpr index_t num_output_blks = 1;
797 static constexpr index_t m_per_blk = 32;
798 static constexpr index_t n_per_blk = 32;
799 static constexpr index_t k_per_blk = 8;
800 static constexpr bool is_k_reduction = true;
801
802 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
803 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
804 {
806 }
807};
808
809template <>
811{
812 static constexpr index_t group_size = 4;
813 static constexpr index_t num_groups_per_blk = 1;
814 static constexpr index_t num_regs_per_blk = 4;
815 static constexpr index_t num_threads_per_blk = 16;
816 static constexpr index_t wave_size = 64;
817 static constexpr index_t num_input_blks = 4;
818 static constexpr index_t num_output_blks = 1;
819 static constexpr index_t m_per_blk = 16;
820 static constexpr index_t n_per_blk = 16;
821 static constexpr index_t k_per_blk = 8;
822 static constexpr bool is_k_reduction = true;
823
824 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
825 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
826 {
828 }
829};
830
831template <>
833{
834 // clang-format off
835 static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
836 static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
837 static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
838 static constexpr index_t num_threads_per_blk = 32; // n_per_blk
839 static constexpr index_t wave_size = 64; // fixed
840 static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
841 static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
842 static constexpr index_t m_per_blk = 32; // from the instruction
843 static constexpr index_t n_per_blk = 32; // from the instruction
844 static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
845 static constexpr bool is_k_reduction = true; // ???
846 // clang-format on
847
848 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
849 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
850 {
852 }
853};
854
855template <>
857{
858 // clang-format off
859 static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
860 static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
861 static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
862 static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
863 static constexpr index_t wave_size = 64; // fixed
864 static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
865 static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
866 static constexpr index_t m_per_blk = 16; // from the instruction
867 static constexpr index_t n_per_blk = 16; // from the instruction
868 static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
869 static constexpr bool is_k_reduction = true; // ???
870 // clang-format on
871
872 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
873 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
874 {
876 }
877};
878
879template <>
881{
882 // clang-format off
883 static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
884 static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
885 static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
886 static constexpr index_t num_threads_per_blk = 32; // n_per_blk
887 static constexpr index_t wave_size = 64; // fixed
888 static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
889 static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
890 static constexpr index_t m_per_blk = 32; // from the instruction
891 static constexpr index_t n_per_blk = 32; // from the instruction
892 static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
893 static constexpr bool is_k_reduction = true; // ???
894 // clang-format on
895
896 template <index_t MPerXdlops,
897 index_t NPerXdlops,
898 index_t OpselA,
899 index_t OpselB,
900 class FloatA,
901 class ScaleA,
902 class FloatB,
903 class ScaleB,
904 class FloatC>
905 __device__ void run(const FloatA& a,
906 const ScaleA& scale_a,
907 const FloatB& b,
908 const ScaleB& scale_b,
909 FloatC& reg_c) const
910 {
912 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
913 }
914};
915
916template <>
918{
919 // clang-format off
920 static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
921 static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
922 static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
923 static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
924 static constexpr index_t wave_size = 64; // fixed
925 static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
926 static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
927 static constexpr index_t m_per_blk = 16; // from the instruction
928 static constexpr index_t n_per_blk = 16; // from the instruction
929 static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
930 static constexpr bool is_k_reduction = true; // ???
931 // clang-format on
932
933 template <index_t MPerXdlops,
934 index_t NPerXdlops,
935 index_t OpselA,
936 index_t OpselB,
937 class FloatA,
938 class ScaleA,
939 class FloatB,
940 class ScaleB,
941 class FloatC>
942 __device__ void run(const FloatA& a,
943 const ScaleA& scale_a,
944 const FloatB& b,
945 const ScaleB& scale_b,
946 FloatC& reg_c) const
947 {
948
950 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
951 }
952};
953
973template <>
975{
976 static constexpr index_t wave_size = 64; // fixed
977 static constexpr index_t m_per_blk = 16; // from the instruction
978 static constexpr index_t n_per_blk = 16; // from the instruction
979 static constexpr index_t num_threads_per_blk = n_per_blk; // 16
982 static constexpr index_t group_size = 4;
983 static constexpr index_t num_groups_per_blk = 1;
984 static constexpr index_t num_output_blks = 1;
985 static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2.
986 static constexpr bool is_k_reduction = true;
987
988 // AB register size: 2, CD register size: 4
989 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
990 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
991 {
993 }
994};
995
996template <>
998{
999 static constexpr index_t wave_size = 64; // fixed
1000 static constexpr index_t m_per_blk = 32; // from the instruction
1001 static constexpr index_t n_per_blk = 32; // from the instruction
1002 static constexpr index_t num_threads_per_blk = n_per_blk; // 32
1005 static constexpr index_t group_size = 4; // corresponding to CD rows mapping
1006 static constexpr index_t num_groups_per_blk = 4;
1007 static constexpr index_t num_output_blks = 1;
1008 static constexpr index_t k_per_blk = 2;
1009 static constexpr bool is_k_reduction = true;
1010 // AB register size: 2, CD register size: 16
1011 template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
1012 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1013 {
1015 }
1016};
1017
1018// gfx11
1020{
1021 static constexpr index_t group_size = 8;
1022 static constexpr index_t num_groups_per_blk = 1;
1023 static constexpr index_t num_regs_per_blk = 8;
1024 static constexpr index_t num_threads_per_blk = 16;
1025 static constexpr index_t wave_size = 32;
1026 static constexpr index_t num_input_blks = 1;
1027 static constexpr index_t num_output_blks = 1;
1028 static constexpr index_t m_per_blk = 16;
1029 static constexpr index_t n_per_blk = 16;
1030 static constexpr index_t k_per_blk = 16;
1031 static constexpr bool is_k_reduction = true;
1032};
1033
1034template <>
1036{
1037 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1038 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1039 {
1041 }
1042};
1043
1044template <>
1046{
1047 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1048 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1049 {
1051 }
1052};
1053
1054template <>
1056{
1057 template <index_t MPerWmma,
1058 index_t NPerWmma,
1059 class FloatA,
1060 class FloatB,
1061 class FloatC,
1062 bool neg_a = true,
1063 bool neg_b = true,
1064 bool clamp = false>
1065 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1066 {
1068 }
1069};
1070
1071template <>
1073{
1074 static constexpr index_t k_per_blk = 2;
1075 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1076 __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1077 {
1078 // empty for all unsupported types.
1079 }
1080};
1081
1082// gfx12
1084{
1085 static constexpr index_t group_size = 8;
1086 static constexpr index_t num_groups_per_blk = 1;
1087 static constexpr index_t num_regs_per_blk = 8;
1088 static constexpr index_t num_threads_per_blk = 16;
1089 static constexpr index_t wave_size = 32;
1090 static constexpr index_t num_input_blks = 2;
1091 static constexpr index_t num_output_blks = 1;
1092 static constexpr index_t m_per_blk = 16;
1093 static constexpr index_t n_per_blk = 16;
1094 static constexpr index_t k_per_blk = 8;
1095 static constexpr bool is_k_reduction = true;
1096};
1097
1098template <>
1100{
1101 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1102 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1103 {
1105 }
1106};
1107
1108template <>
1110{
1111 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1112 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1113 {
1115 }
1116};
1117
1118template <>
1120{
1121 template <index_t MPerWmma,
1122 index_t NPerWmma,
1123 class FloatA,
1124 class FloatB,
1125 class FloatC,
1126 bool neg_a = true,
1127 bool neg_b = true,
1128 bool clamp = false>
1129 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1130 {
1132 a, b, reg_c);
1133 }
1134};
1135
1136template <>
1138{
1139 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1140 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1141 {
1143 }
1144};
1145
1146template <>
1148{
1149 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1150 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1151 {
1153 }
1154};
1155
1156template <>
1158{
1159 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1160 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1161 {
1163 }
1164};
1165
1166template <>
1168{
1169 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1170 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1171 {
1173 }
1174};
1175
1176template <>
1178{
1179 static constexpr index_t k_per_blk = 2;
1180 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1181 __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1182 {
1183 // empty for all unsupported types.
1184 }
1185};
1186
1201template <typename base_type,
1202 index_t MPerXdlops,
1203 index_t NPerXdlops,
1204 typename additional_type = base_type,
1205 bool is_single_rate_mfma = false,
1206 bool is_scale_mfma = false>
1208{
1209 template <typename base_type_,
1210 index_t MPerXdlops_,
1211 index_t NPerXdlops_,
1212 typename additional_type_ = base_type_,
1213 bool is_single_rate_mfma_ = false,
1214 bool is_scale_mfma_ = false>
1215 static constexpr auto GetMfma();
1216
1217 template <>
1219 {
1220#if defined(__gfx12__)
1222#elif defined(__gfx11__)
1224#else
1226#endif
1227 }
1228
1229 template <>
1230 constexpr auto GetMfma<float, 64, 64>()
1231 {
1233 }
1234
1235 template <>
1236 constexpr auto GetMfma<float, 32, 64>()
1237 {
1239 }
1240
1241 template <>
1242 constexpr auto GetMfma<float, 16, 64>()
1243 {
1245 }
1246
1247 template <>
1248 constexpr auto GetMfma<float, 8, 64>()
1249 {
1251 }
1252
1253 template <>
1254 constexpr auto GetMfma<float, 4, 64>()
1255 {
1257 }
1258
1259 template <>
1260 constexpr auto GetMfma<float, 32, 32>()
1261 {
1263 }
1264
1265 template <>
1266 constexpr auto GetMfma<float, 16, 16>()
1267 {
1268#if defined(__gfx12__)
1270#elif defined(__gfx11__)
1272#else
1274#endif
1275 }
1276
1277 template <>
1279 {
1280#if defined(__gfx12__)
1282#elif defined(__gfx11__)
1284#elif defined(__gfx942__)
1286#else
1288#endif
1289 }
1290
1291 template <>
1293 {
1294#if defined(__gfx12__)
1296#elif defined(__gfx11__)
1298#elif defined(__gfx942__)
1300#else
1302#endif
1303 }
1304
1305 template <>
1307 {
1309 }
1310
1311 template <>
1313 {
1315 }
1316
1317 template <>
1319 {
1320#if defined(__gfx950__)
1322#else
1324#endif
1325 }
1326 template <>
1331
1332 template <>
1334 {
1335#if defined(__gfx12__)
1337#elif defined(__gfx11__)
1339#elif defined(__gfx950__)
1341#else
1343#endif
1344 }
1345
1346 template <>
1348 {
1349#if defined(__gfx12__)
1351#elif defined(__gfx11__)
1353#else
1355#endif
1356 }
1357
1358 template <>
1360 {
1362 }
1363
1364 template <>
1365 constexpr auto GetMfma<half_t, 8, 64>()
1366 {
1368 }
1369
1370 template <>
1371 constexpr auto GetMfma<half_t, 4, 64>()
1372 {
1374 }
1375
1376 template <>
1378 {
1379#if defined(__gfx950__)
1381#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1383#else
1385#endif
1386 }
1387
1388 template <>
1390 {
1391#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1393#else
1395#endif
1396 }
1397
1398 template <>
1400 {
1401#if defined(__gfx12__)
1403#elif defined(__gfx11__)
1405#elif defined(__gfx950__)
1407#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1409#else
1411#endif
1412 }
1413
1414 template <>
1416 {
1417#if defined(__gfx12__)
1419#elif defined(__gfx11__)
1421#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1423#else
1425#endif
1426 }
1427
1428 template <>
1430 {
1431#if defined(__gfx950__)
1433#elif defined(__gfx942__)
1435#else
1437#endif
1438 }
1439
1440 template <>
1442 {
1443#if defined(__gfx942__) || defined(__gfx950__)
1445#else
1447#endif
1448 }
1449
1450 template <>
1452 {
1453#if defined(__gfx12__)
1455#elif defined(__gfx11__)
1457#elif defined(__gfx950__)
1459#elif defined(__gfx942__)
1461#else
1463#endif
1464 }
1465
1466 template <>
1468 {
1469#if defined(__gfx12__)
1471#elif defined(__gfx11__)
1473#elif defined(__gfx942__) || defined(__gfx950__)
1475#else
1477#endif
1478 }
1479
1480 template <>
1485
1486 template <>
1488 {
1489#if defined(__gfx950__)
1491#else
1493#endif
1494 }
1495
1496 template <>
1501
1502 template <>
1507 template <>
1512 template <>
1514 {
1515#if defined(__gfx12__)
1517#elif defined(__gfx11__)
1519#else
1521#endif
1522 }
1523
1524 template <>
1526 {
1527#if defined(__gfx12__)
1529#elif defined(__gfx11__)
1531#else
1533#endif
1534 }
1535
1536 template <>
1538 {
1539#if defined(__gfx12__)
1541#elif defined(__gfx11__)
1543#elif defined(__gfx950__)
1545#else
1547#endif
1548 }
1549
1550 template <>
1552 {
1553#if defined(__gfx12__)
1555#elif defined(__gfx11__)
1557#else
1559#endif
1560 }
1561
1562 template <>
1564 {
1565#if defined(__gfx12__)
1567#elif defined(__gfx11__)
1569#else
1571#endif
1572 }
1573
1574 template <>
1576 {
1577#if defined(__gfx12__)
1579#elif defined(__gfx11__)
1581#else
1583#endif
1584 }
1585
1586 template <>
1588 {
1589#if defined(__gfx12__)
1591#elif defined(__gfx11__)
1593#else
1595#endif
1596 }
1597
1598 template <>
1603 template <>
1605 {
1606#if defined(__gfx12__)
1608#elif defined(__gfx11__)
1610#else
1612#endif
1613 }
1614 template <>
1619 template <>
1621 {
1622#if defined(__gfx12__)
1624#elif defined(__gfx11__)
1626#else
1628#endif
1629 }
1630
1631 template <>
1636
1637 template <>
1639 {
1640#if defined(__gfx950__)
1642#else
1644#endif
1645 }
1646
1647 template <>
1649 {
1650#if defined(__gfx12__)
1652#elif defined(__gfx11__)
1654#else
1656#endif
1657 }
1658
1659 template <>
1661 {
1662#if defined(__gfx12__)
1664#elif defined(__gfx11__)
1666#elif defined(__gfx950__)
1668#else
1670#endif
1671 }
1672
1673 template <>
1678
1679 template <>
1681 {
1682#if defined(__gfx950__)
1684#else
1686#endif
1687 }
1688
1689 template <>
1691 {
1692#if defined(__gfx12__)
1694#elif defined(__gfx11__)
1696#else
1698#endif
1699 }
1700
1701 template <>
1703 {
1704#if defined(__gfx12__)
1706#elif defined(__gfx11__)
1708#elif defined(__gfx950__)
1710#else
1712#endif
1713 }
1714
1715 template <>
1720
1721 template <>
1723 {
1724#if defined(__gfx950__)
1726#else
1728#endif
1729 }
1730
1731 template <>
1733 {
1734#if defined(__gfx12__)
1736#elif defined(__gfx11__)
1738#else
1740#endif
1741 }
1742
1743 template <>
1745 {
1746#if defined(__gfx12__)
1748#elif defined(__gfx11__)
1750#elif defined(__gfx950__)
1752#else
1754#endif
1755 }
1756
1758 MPerXdlops,
1759 NPerXdlops,
1761 is_single_rate_mfma,
1762 is_scale_mfma>()>{};
1763
1764 __host__ __device__ constexpr MfmaSelector()
1765 {
1766 static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
1767 selected_mfma.num_regs_per_blk,
1768 "wrong! num_regs_per_blk");
1769
1770 static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
1771 "n_per_blk != num_threads_per_blk");
1772#if defined(__gfx11__)
1773 if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1774 {
1775 static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
1776 selected_mfma.m_per_blk,
1777 "m_per_blk != num_input_blks * num_regs_per_blk");
1778 }
1779#else
1780 static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
1781 selected_mfma.m_per_blk,
1782 "m_per_blk != num_input_blks * num_regs_per_blk");
1783#endif
1784
1785 static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
1786 selected_mfma.num_output_blks == 1,
1787 "incorrect num_output_blks");
1788
1789 static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
1790 selected_mfma.m_per_blk * selected_mfma.n_per_blk,
1791 "num_regs_per_blk incorrect");
1792
1793 static_assert(selected_mfma.is_k_reduction ||
1794 (selected_mfma.num_input_blks == selected_mfma.num_output_blks),
1795 "is_k_reduction wrong!");
1796 }
1797
1798 static constexpr bool IsABroadcast()
1799 {
1800 static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
1801 return true;
1802 }
1803
1804 static constexpr index_t GetKPerXdlops()
1805 {
1806 return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
1807 selected_mfma.k_per_blk;
1808 }
1809
1810 static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
1811};
1812
1813template <typename base_type,
1814 index_t MPerXdlops,
1815 index_t NPerXdlops,
1816 index_t KPack,
1817 typename additional_type = base_type,
1818 bool TransposeC = false,
1819 bool is_scale_mfma = false>
1821{
1822 static constexpr auto I0 = Number<0>{};
1823 static constexpr auto I1 = Number<1>{};
1824 static constexpr auto I2 = Number<2>{};
1825 static constexpr auto I3 = Number<3>{};
1826 static constexpr auto I4 = Number<4>{};
1827 static constexpr auto I5 = Number<5>{};
1828
1831
1832 __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
1833
1834 __device__ static constexpr index_t GetNumXdlops()
1835 {
1836 return MPerXdlops * NPerXdlops /
1837 (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
1838 }
1839
1840 __host__ __device__ constexpr XdlopsGemm()
1841 {
1842 static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1843 NPerXdlops == 64,
1844 "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1845
1846 static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1847 MPerXdlops == 64,
1848 "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1849#if defined(__HIP_DEVICE_COMPILE__)
1850 static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
1851#endif
1852 }
1853
1854 // XDL output supporting C = A * B
1855 // M2_N2 -> M2_M3_M4_N2
1856 template <typename CDesc_M0_N0_M1_N1_M2_N2>
1857 __host__ __device__ static constexpr auto
1858 MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1859 {
1860 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1861 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1862 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1863 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1864 constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1865
1867 c_desc_m0_n0_m1_n1_m2_n2,
1874 Number<mfma_instr.group_size>{})),
1877 Sequence<1>{},
1878 Sequence<2>{},
1879 Sequence<3>{},
1880 Sequence<4>{},
1881 Sequence<5>{}),
1883 Sequence<1>{},
1884 Sequence<2>{},
1885 Sequence<3>{},
1887 Sequence<7>{}));
1888 }
1889
1890 // XDL output supporting C = A * B
1891 // M3_N3 -> M3_M4_M5_N3
1892 template <typename CDesc_M0_N0_M1_N1_M2_N2>
1893 __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
1894 const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1895 {
1896 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1897 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1898 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1899 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1900 const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
1901 const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
1902 constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1903
1905 c_desc_m0_n0_m1_n1_m2_n2,
1914 Number<mfma_instr.group_size>{})),
1917 Sequence<1>{},
1918 Sequence<2>{},
1919 Sequence<3>{},
1920 Sequence<4>{},
1921 Sequence<5>{},
1922 Sequence<6>{},
1923 Sequence<7>{}),
1925 Sequence<1>{},
1926 Sequence<2>{},
1927 Sequence<3>{},
1928 Sequence<4>{},
1929 Sequence<5>{},
1931 Sequence<9>{}));
1932 }
1933
1934 // transposed XDL output supporting C' = B' * A'
1935 // M2_N2 -> M2_N2_N3_N4
1936 template <typename CDesc_M0_N0_M1_N1_M2_N2>
1937 __host__ __device__ static constexpr auto
1938 MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1939 {
1940 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1941 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1942 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1943 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1944 constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1945
1947 c_desc_m0_n0_m1_n1_m2_n2,
1955 Number<mfma_instr.group_size>{}))),
1957 Sequence<1>{},
1958 Sequence<2>{},
1959 Sequence<3>{},
1960 Sequence<4>{},
1961 Sequence<5>{}),
1963 Sequence<1>{},
1964 Sequence<2>{},
1965 Sequence<3>{},
1966 Sequence<4>{},
1968 }
1969
1970 template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
1971 __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
1972 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1973 {
1974 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
1975 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
1976 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
1977 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
1978 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
1979 constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1980
1982 c_desc_g_m0_n0_m1_n1_m2_n2,
1989 mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
1990 make_pass_through_transform(mfma_instr.num_threads_per_blk)),
1992 Sequence<1>{},
1993 Sequence<2>{},
1994 Sequence<3>{},
1995 Sequence<4>{},
1996 Sequence<5>{},
1997 Sequence<6>{}),
1999 Sequence<1>{},
2000 Sequence<2>{},
2001 Sequence<3>{},
2002 Sequence<4>{},
2004 Sequence<8>{}));
2005 }
2006
2007 __device__ __host__ static constexpr index_t GetRegSizePerXdlops()
2008 {
2009 return mfma_instr.num_regs_per_blk;
2010 }
2011
2012 __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
2013
2014 template <class FloatA, class FloatB, class FloatC>
2015 __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
2016 {
2017 static_assert(
2024 "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
2025
2026 static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
2027 if constexpr(!TransposeC)
2028 {
2029 mfma_instr.template run<MPerXdlops, NPerXdlops>(
2030 p_a_wave[k], p_b_wave[k], p_c_thread);
2031 }
2032 else
2033 {
2034 mfma_instr.template run<MPerXdlops, NPerXdlops>(
2035 p_b_wave[k], p_a_wave[k], p_c_thread);
2036 }
2037 });
2038 }
2039
2040 template <index_t OpselA,
2041 index_t OpselB,
2042 class FloatA,
2043 class ScaleA,
2044 class FloatB,
2045 class ScaleB,
2046 class FloatC>
2047 __device__ void Run(const FloatA& p_a_wave,
2048 const ScaleA& a_scale_thread,
2049 const FloatB& p_b_wave,
2050 const ScaleB& b_scale_thread,
2051 FloatC& p_c_thread) const
2052 {
2053 static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
2054 if constexpr(!TransposeC)
2055 {
2056 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
2057 p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
2058 }
2059 else
2060 {
2061 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
2062 p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
2063 }
2064 });
2065 }
2066
2067 __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
2068
2069 __device__ static auto GetBlkIdx()
2070 {
2071 const auto laneId = GetLaneId();
2072 constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
2073
2074 constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
2075 make_tuple(
2076 make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
2079
2080 const auto blk_idx =
2081 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
2082
2083 const auto blk_id = blk_idx[I1];
2084 const auto blk_td = blk_idx[I2];
2085
2086 return make_tuple(blk_id, blk_td);
2087 }
2088
2089 template <bool SwizzleA>
2090 __device__ static auto GetGfx11InputBlkIdx()
2091 {
2092 auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
2093 if constexpr(SwizzleA)
2094 {
2095 laneId = ((laneId & 1) << 3) | (laneId >> 1);
2096 }
2097 constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
2099 make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
2102
2103 const auto blk_idx =
2104 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
2105
2106 const auto blk_id = blk_idx[I1];
2107 const auto blk_td = blk_idx[I2];
2108
2109 return make_tuple(blk_id, blk_td);
2110 }
2111
2112 __host__ __device__ static auto CalculateAThreadOriginDataIndex()
2113 {
2114 const auto laneId = GetLaneId();
2115#if defined(__gfx11__)
2116 const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2117#else
2118 const auto blk_idx = GetBlkIdx();
2119#endif
2120
2121 const auto blk_id = blk_idx[I0];
2122 const auto blk_td = blk_idx[I1];
2123
2124 if constexpr(mfma_instr.is_k_reduction)
2125 {
2126 return make_tuple(blk_id, blk_td);
2127 }
2128 else
2129 {
2130 return make_tuple(0, laneId);
2131 }
2132 }
2133
2134 __host__ __device__ static auto CalculateBThreadOriginDataIndex()
2135 {
2136 const auto laneId = GetLaneId();
2137#if defined(__gfx11__)
2138 const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2139#else
2140 const auto blk_idx = GetBlkIdx();
2141#endif
2142
2143 const auto blk_id = blk_idx[I0];
2144 const auto blk_td = blk_idx[I1];
2145
2146 if constexpr(mfma_instr.is_k_reduction)
2147 {
2148 return make_tuple(blk_id, blk_td);
2149 }
2150 else
2151 {
2152 return make_tuple(0, laneId);
2153 }
2154 }
2155
2156 __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
2157 {
2158 const auto blk_idx = GetBlkIdx();
2159
2160 const auto blk_id = blk_idx[I0];
2161 const auto blk_td = blk_idx[I1];
2162
2163 index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
2164 index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
2165
2166 return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
2167 }
2168
2169 __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
2170 {
2171 const auto blk_idx = GetBlkIdx();
2172
2173 const auto blk_id = blk_idx[I0];
2174 const auto blk_td = blk_idx[I1];
2175
2176 return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
2177 }
2178
2179 // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942-
2180 // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t,
2181 // except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
2182 static constexpr bool is_single_rate_mfma =
2184 KPack <= 4) ||
2185 (is_same<base_type, int8_t>::value && KPack <= 8) ||
2188 ? true
2189 : false;
2190 static constexpr auto mfma = MfmaSelector<base_type,
2191 MPerXdlops,
2192 NPerXdlops,
2193 additional_type,
2195 is_scale_mfma>{};
2196
2197 static constexpr auto mfma_instr = mfma.selected_mfma;
2198
2199 static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
2200 static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
2201 static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
2202
2203 __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
2204 {
2205 return make_tuple(
2207 }
2208};
2209
2210} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
MfmaInstr
Definition xdlops_gemm.hpp:45
@ mfma_f32_16x16x32f8bf8
Definition xdlops_gemm.hpp:70
@ mfma_f32_32x32x16f8bf8
Definition xdlops_gemm.hpp:69
@ mfma_f32_32x32x1f32
Definition xdlops_gemm.hpp:46
@ wmma_f32_16x16x16_bf16_gfx12
Definition xdlops_gemm.hpp:92
@ mfma_f32_32x32x4xf32
Definition xdlops_gemm.hpp:84
@ mfma_f32_32x32x64f8f6f4
Definition xdlops_gemm.hpp:79
@ mfma_f32_32x32x16f8f8
Definition xdlops_gemm.hpp:65
@ mfma_f32_32x32x16bf8f8
Definition xdlops_gemm.hpp:71
@ wmma_unsupport_16x16_gfx11
Definition xdlops_gemm.hpp:89
@ mfma_f32_16x16x4f16
Definition xdlops_gemm.hpp:52
@ mfma_f32_16x16x32bf16
Definition xdlops_gemm.hpp:76
@ mfma_i32_16x16x32i8
Definition xdlops_gemm.hpp:63
@ mfma_f32_16x16x1f32
Definition xdlops_gemm.hpp:47
@ wmma_i32_16x16x16_iu8_gfx12
Definition xdlops_gemm.hpp:93
@ mfma_scale_f32_32x32x64f8f6f4
Definition xdlops_gemm.hpp:81
@ wmma_f32_16x16x16_bf8f8_gfx12
Definition xdlops_gemm.hpp:96
@ wmma_f32_16x16x16_f16_gfx12
Definition xdlops_gemm.hpp:91
@ mfma_i32_32x32x16i8
Definition xdlops_gemm.hpp:62
@ mfma_f32_16x16x16f16
Definition xdlops_gemm.hpp:55
@ mfma_f32_16x16x8xf32
Definition xdlops_gemm.hpp:83
@ wmma_i32_16x16x16_iu8
Definition xdlops_gemm.hpp:88
@ mfma_i32_16x16x16i8
Definition xdlops_gemm.hpp:61
@ mfma_i32_32x32x8i8
Definition xdlops_gemm.hpp:60
@ mfma_i32_32x32x32i8
Definition xdlops_gemm.hpp:77
@ mfma_f32_32x32x2f32
Definition xdlops_gemm.hpp:49
@ wmma_f32_16x16x16_bf8bf8_gfx12
Definition xdlops_gemm.hpp:97
@ mfma_i32_16x16x64i8
Definition xdlops_gemm.hpp:78
@ wmma_unsupport_16x16_gfx12
Definition xdlops_gemm.hpp:98
@ mfma_f32_4x4x1f32
Definition xdlops_gemm.hpp:48
@ mfma_f32_16x16x32f16
Definition xdlops_gemm.hpp:74
@ mfma_f32_16x16x16bf16_1k
Definition xdlops_gemm.hpp:57
@ wmma_f32_16x16x16_f8f8_gfx12
Definition xdlops_gemm.hpp:94
@ mfma_f64_16x16x4f64
Definition xdlops_gemm.hpp:64
@ wmma_f32_16x16x16_bf16
Definition xdlops_gemm.hpp:87
@ wmma_f32_16x16x16_f16
Definition xdlops_gemm.hpp:86
@ mfma_f32_16x16x32f8f8
Definition xdlops_gemm.hpp:66
@ mfma_scale_f32_16x16x128f8f6f4
Definition xdlops_gemm.hpp:82
@ mfma_f32_16x16x8bf16
Definition xdlops_gemm.hpp:59
@ mfma_f32_16x16x32bf8bf8
Definition xdlops_gemm.hpp:68
@ mfma_f32_16x16x128f8f6f4
Definition xdlops_gemm.hpp:80
@ mfma_f32_32x32x8f16
Definition xdlops_gemm.hpp:54
@ mfma_f32_16x16x4f32
Definition xdlops_gemm.hpp:50
@ mfma_f32_32x32x16bf8bf8
Definition xdlops_gemm.hpp:67
@ mfma_f32_32x32x16f16
Definition xdlops_gemm.hpp:73
@ mfma_f32_32x32x8bf16_1k
Definition xdlops_gemm.hpp:56
@ mfma_f32_4x4x4f16
Definition xdlops_gemm.hpp:53
@ wmma_f32_16x16x16_f8bf8_gfx12
Definition xdlops_gemm.hpp:95
@ mfma_f32_32x32x4bf16
Definition xdlops_gemm.hpp:58
@ mfma_f32_16x16x32bf8f8
Definition xdlops_gemm.hpp:72
@ mfma_f32_32x32x4f16
Definition xdlops_gemm.hpp:51
@ mfma_f32_32x32x16bf16
Definition xdlops_gemm.hpp:75
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
typename packed_type_info< T >::element_type element_type_t
Definition data_type.hpp:408
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
@ wmma_f32_16x16x16_bf16_gfx12
Definition wmma_gemm.hpp:23
@ wmma_i32_16x16x16_iu8_gfx12
Definition wmma_gemm.hpp:24
@ wmma_f32_16x16x16_bf8f8_gfx12
Definition wmma_gemm.hpp:27
@ wmma_f32_16x16x16_f16_gfx12
Definition wmma_gemm.hpp:22
@ wmma_i32_16x16x16_iu8
Definition wmma_gemm.hpp:19
@ wmma_f32_16x16x16_bf8bf8_gfx12
Definition wmma_gemm.hpp:28
@ wmma_f32_16x16x16_f8f8_gfx12
Definition wmma_gemm.hpp:25
@ wmma_f32_16x16x16_bf16
Definition wmma_gemm.hpp:16
@ wmma_f32_16x16x16_f16
Definition wmma_gemm.hpp:15
@ wmma_f32_16x16x16_f8bf8_gfx12
Definition wmma_gemm.hpp:26
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr bool IsABroadcast()
Definition xdlops_gemm.hpp:1798
static constexpr index_t GetK1PerXdlops()
Definition xdlops_gemm.hpp:1810
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
static constexpr index_t GetKPerXdlops()
Definition xdlops_gemm.hpp:1804
__host__ __device__ constexpr MfmaSelector()
Definition xdlops_gemm.hpp:1764
Definition utility/sequence.hpp:43
MultiIndex< 4 > CIndex4D
Definition xdlops_gemm.hpp:1830
static constexpr auto mfma_instr
Definition xdlops_gemm.hpp:2197
static __device__ constexpr index_t GetWaveSize()
Definition xdlops_gemm.hpp:2012
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition xdlops_gemm.hpp:2134
static __device__ auto GetBlkIdx()
Definition xdlops_gemm.hpp:2069
static constexpr auto I2
Definition xdlops_gemm.hpp:1824
__host__ static __device__ constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition xdlops_gemm.hpp:1971
static __device__ auto GetLaneId()
Definition xdlops_gemm.hpp:2067
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition xdlops_gemm.hpp:2112
static __device__ constexpr index_t GetNumXdlops()
Definition xdlops_gemm.hpp:1834
static constexpr bool is_single_rate_mfma
Definition xdlops_gemm.hpp:2182
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition xdlops_gemm.hpp:2169
static __device__ constexpr index_t GetNumBlks()
Definition xdlops_gemm.hpp:1832
__device__ static __host__ constexpr index_t GetRegSizePerXdlops()
Definition xdlops_gemm.hpp:2007
static __device__ auto GetGfx11InputBlkIdx()
Definition xdlops_gemm.hpp:2090
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition xdlops_gemm.hpp:1938
static constexpr auto I5
Definition xdlops_gemm.hpp:1827
static constexpr auto I3
Definition xdlops_gemm.hpp:1825
static constexpr auto I0
Definition xdlops_gemm.hpp:1822
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition xdlops_gemm.hpp:2047
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition xdlops_gemm.hpp:1893
static constexpr auto I1
Definition xdlops_gemm.hpp:1823
static constexpr auto K1PerXdlops
Definition xdlops_gemm.hpp:2200
static constexpr auto KPerXdlops
Definition xdlops_gemm.hpp:2199
__host__ static __device__ constexpr auto GetCM0M1M2NThreadBlkLengths()
Definition xdlops_gemm.hpp:2203
static constexpr auto I4
Definition xdlops_gemm.hpp:1826
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition xdlops_gemm.hpp:2015
static constexpr auto mfma
Definition xdlops_gemm.hpp:2190
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition xdlops_gemm.hpp:1858
__host__ __device__ constexpr XdlopsGemm()
Definition xdlops_gemm.hpp:1840
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition xdlops_gemm.hpp:2156
MultiIndex< 2 > CIndex
Definition xdlops_gemm.hpp:1829
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition amd_xdlops.hpp:1202
Definition amd_xdlops.hpp:303
Definition amd_xdlops.hpp:193
Definition amd_xdlops.hpp:70
Definition amd_xdlops.hpp:269
Definition amd_xdlops.hpp:1483
Definition amd_xdlops.hpp:1609
Definition amd_xdlops.hpp:159
Definition amd_xdlops.hpp:1546
Definition amd_xdlops.hpp:1420
Definition amd_xdlops.hpp:207
Definition amd_xdlops.hpp:56
Definition amd_xdlops.hpp:331
Definition amd_xdlops.hpp:1641
Definition amd_xdlops.hpp:249
Definition amd_xdlops.hpp:1451
Definition amd_xdlops.hpp:1577
Definition amd_xdlops.hpp:139
Definition amd_xdlops.hpp:1514
Definition amd_xdlops.hpp:1388
Definition amd_xdlops.hpp:15
Definition amd_xdlops.hpp:42
Definition amd_xdlops.hpp:317
Definition amd_xdlops.hpp:112
Definition amd_xdlops.hpp:1661
Definition amd_xdlops.hpp:481
Definition amd_xdlops.hpp:289
Definition amd_xdlops.hpp:179
Definition amd_xdlops.hpp:84
Definition amd_xdlops.hpp:221
Definition amd_xdlops.hpp:461
Definition amd_xdlops.hpp:364
Definition amd_xdlops.hpp:442
Definition amd_xdlops.hpp:403
Definition amd_xdlops.hpp:423
Definition amd_xdlops.hpp:383
Definition amd_xdlops.hpp:345
Definition amd_xdlops.hpp:886
Definition amd_xdlops.hpp:666
Definition amd_wmma.hpp:50
Definition amd_wmma.hpp:25
Definition amd_wmma.hpp:121
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:867
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:862
static constexpr index_t group_size
Definition xdlops_gemm.hpp:859
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:866
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:865
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:869
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:861
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:873
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:863
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:868
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:864
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:860
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:445
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:448
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:440
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:439
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:443
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:444
static constexpr index_t group_size
Definition xdlops_gemm.hpp:438
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:441
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:442
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:447
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:446
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:451
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:312
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:313
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:309
static constexpr index_t group_size
Definition xdlops_gemm.hpp:306
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:315
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:319
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:307
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:314
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:316
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:311
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:310
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:308
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:180
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:174
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:186
static constexpr index_t group_size
Definition xdlops_gemm.hpp:173
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:182
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:178
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:179
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:183
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:181
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:175
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:176
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:177
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:421
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:422
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:420
static constexpr index_t group_size
Definition xdlops_gemm.hpp:416
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:425
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:417
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:424
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:418
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:426
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:429
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:419
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:423
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:728
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:726
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:731
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:737
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:729
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:730
static constexpr index_t group_size
Definition xdlops_gemm.hpp:724
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:732
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:733
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:727
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:725
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:734
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:820
static constexpr index_t group_size
Definition xdlops_gemm.hpp:812
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:825
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:813
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:816
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:819
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:815
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:814
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:818
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:821
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:817
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:822
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:292
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:297
static constexpr index_t group_size
Definition xdlops_gemm.hpp:284
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:288
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:287
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:293
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:289
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:286
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:290
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:291
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:285
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:294
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:772
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:770
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:773
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:776
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:771
static constexpr index_t group_size
Definition xdlops_gemm.hpp:768
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:777
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:781
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:769
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:775
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:778
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:774
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:688
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:687
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:690
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:689
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:681
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:693
static constexpr index_t group_size
Definition xdlops_gemm.hpp:680
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:684
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:683
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:686
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:685
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:682
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:332
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:338
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:331
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:337
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:329
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:341
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:333
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:335
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:336
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:334
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:330
static constexpr index_t group_size
Definition xdlops_gemm.hpp:328
static constexpr index_t group_size
Definition xdlops_gemm.hpp:151
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:156
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:152
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:164
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:157
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:159
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:153
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:155
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:154
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:161
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:160
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:158
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:492
static constexpr index_t group_size
Definition xdlops_gemm.hpp:482
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:487
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:495
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:489
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:488
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:491
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:483
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:490
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:484
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:485
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:486
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:976
static constexpr index_t group_size
Definition xdlops_gemm.hpp:982
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:978
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:979
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:985
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:983
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:984
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:977
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:981
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:986
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:980
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:990
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:378
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:375
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:373
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:374
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:379
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:380
static constexpr index_t group_size
Definition xdlops_gemm.hpp:372
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:381
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:382
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:376
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:385
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:377
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:708
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:705
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:709
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:707
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:703
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:712
static constexpr index_t group_size
Definition xdlops_gemm.hpp:702
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:711
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:704
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:706
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:710
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:715
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:797
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:798
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:800
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:796
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:794
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:792
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:795
static constexpr index_t group_size
Definition xdlops_gemm.hpp:790
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:803
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:799
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:793
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:791
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:268
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:272
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:267
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:271
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:264
static constexpr index_t group_size
Definition xdlops_gemm.hpp:262
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:263
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:275
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:265
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:266
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:269
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:270
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:755
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:752
static constexpr index_t group_size
Definition xdlops_gemm.hpp:746
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:759
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:748
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:747
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:753
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:756
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:750
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:751
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:749
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:754
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:663
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:662
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:664
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:667
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:659
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:668
static constexpr index_t group_size
Definition xdlops_gemm.hpp:658
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:661
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:671
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:665
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:666
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:660
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:113
static constexpr index_t group_size
Definition xdlops_gemm.hpp:107
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:114
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:109
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:111
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:115
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:117
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:108
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:116
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:110
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:120
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:112
static constexpr index_t group_size
Definition xdlops_gemm.hpp:129
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:142
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:138
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:135
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:134
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:137
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:132
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:139
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:133
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:131
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:136
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:130
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:463
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:469
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:473
static constexpr index_t group_size
Definition xdlops_gemm.hpp:460
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:470
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:461
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:466
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:465
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:467
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:468
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:464
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:462
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:226
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:224
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:221
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:227
static constexpr index_t group_size
Definition xdlops_gemm.hpp:218
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:220
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:231
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:219
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:222
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:228
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:225
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:223
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:1009
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:1004
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:1007
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:1008
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:999
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:1006
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:1000
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:1003
static constexpr index_t group_size
Definition xdlops_gemm.hpp:1005
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:1001
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1012
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:1002
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:845
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:836
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:849
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:837
static constexpr index_t group_size
Definition xdlops_gemm.hpp:835
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:842
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:840
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:838
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:841
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:839
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:844
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:843
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:403
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:402
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:398
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:396
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:407
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:400
static constexpr index_t group_size
Definition xdlops_gemm.hpp:394
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:399
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:395
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:397
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:401
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:404
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:249
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:246
static constexpr index_t group_size
Definition xdlops_gemm.hpp:240
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:247
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:250
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:243
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:242
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:244
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:248
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:245
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:241
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:253
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:209
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:197
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:204
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:199
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:203
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:198
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:202
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:206
static constexpr index_t group_size
Definition xdlops_gemm.hpp:196
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:200
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:205
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:201
static constexpr index_t group_size
Definition xdlops_gemm.hpp:350
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:353
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:358
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:357
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:351
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:363
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:355
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:360
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:352
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:354
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:356
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:359
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:644
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:642
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:639
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:649
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:638
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:640
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:645
static constexpr index_t group_size
Definition xdlops_gemm.hpp:636
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:646
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:643
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:637
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:641
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:539
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:532
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:528
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:530
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:536
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:527
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:531
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:534
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:535
static constexpr index_t group_size
Definition xdlops_gemm.hpp:526
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:533
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:529
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:577
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:575
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:579
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:574
static constexpr index_t group_size
Definition xdlops_gemm.hpp:570
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:578
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:583
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:573
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:580
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:576
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:571
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:572
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:621
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:627
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:622
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:618
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:619
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:620
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:615
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:624
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:616
static constexpr index_t group_size
Definition xdlops_gemm.hpp:614
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:617
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:623
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:561
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:553
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:550
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:555
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:554
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:556
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:552
static constexpr index_t group_size
Definition xdlops_gemm.hpp:548
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:557
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:551
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:549
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:558
static constexpr index_t group_size
Definition xdlops_gemm.hpp:592
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:595
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:599
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:598
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:596
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:601
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:605
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:593
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:602
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:597
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:600
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:594
static constexpr index_t group_size
Definition xdlops_gemm.hpp:504
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:506
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:517
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:507
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:513
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:512
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:511
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:514
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:509
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:510
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:508
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:505
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:930
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:922
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:942
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:928
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:925
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:927
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:926
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:929
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:924
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:921
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:923
static constexpr index_t group_size
Definition xdlops_gemm.hpp:920
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:890
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:905
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:893
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:887
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:891
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:884
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:885
static constexpr index_t group_size
Definition xdlops_gemm.hpp:883
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:886
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:888
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:892
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:889
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1048
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1112
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1170
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1160
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1038
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1102
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1150
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1140
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1065
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition xdlops_gemm.hpp:1129
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition xdlops_gemm.hpp:1076
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:1074
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition xdlops_gemm.hpp:1181
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:1179
Definition xdlops_gemm.hpp:1020
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:1029
static constexpr index_t group_size
Definition xdlops_gemm.hpp:1021
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:1028
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:1031
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:1024
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:1027
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:1025
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:1026
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:1022
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:1023
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:1030
Definition xdlops_gemm.hpp:1084
static constexpr index_t n_per_blk
Definition xdlops_gemm.hpp:1093
static constexpr index_t group_size
Definition xdlops_gemm.hpp:1085
static constexpr index_t num_output_blks
Definition xdlops_gemm.hpp:1091
static constexpr index_t m_per_blk
Definition xdlops_gemm.hpp:1092
static constexpr index_t num_threads_per_blk
Definition xdlops_gemm.hpp:1088
static constexpr bool is_k_reduction
Definition xdlops_gemm.hpp:1095
static constexpr index_t num_regs_per_blk
Definition xdlops_gemm.hpp:1087
static constexpr index_t num_groups_per_blk
Definition xdlops_gemm.hpp:1086
static constexpr index_t num_input_blks
Definition xdlops_gemm.hpp:1090
static constexpr index_t wave_size
Definition xdlops_gemm.hpp:1089
static constexpr index_t k_per_blk
Definition xdlops_gemm.hpp:1094
Definition xdlops_gemm.hpp:102
Definition functional2.hpp:33