gridwise_gemm_pipeline_v2.hpp Source File

gridwise_gemm_pipeline_v2.hpp Source File

Composable Kernel: gridwise_gemm_pipeline_v2.hpp Source File
gridwise_gemm_pipeline_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
11{
12 __host__ __device__ static constexpr bool IsSupported(const index_t num_loop)
13 {
14 // TODO: improve applicability
15 return num_loop % 2 == 0;
16 }
17
18 __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop)
19 {
20 return (num_loop / 2) > 1;
21 }
22
23 template <bool HasMainLoop,
24 typename AGridDesc,
25 typename ABlockDesc,
26 typename ABlockTransfer,
27 typename AGridBuffer,
28 typename ABlockBuffer,
29 typename ABlockTransferStep,
30 typename BGridDesc,
31 typename BBlockDesc,
32 typename BBlockTransfer,
33 typename BGridBuffer,
34 typename BBlockBuffer,
35 typename BBlockTransferStep,
36 typename BlockwiseGemm,
37 typename CThreadBuffer>
38 __device__ static void Run(const AGridDesc& a_grid_desc,
39 const ABlockDesc& a_block_desc,
40 ABlockTransfer& a_blockwise_copy,
41 const AGridBuffer& a_grid_buf,
42 ABlockBuffer& a_block_buf,
43 const ABlockTransferStep& a_block_copy_step,
44 const BGridDesc& b_grid_desc,
45 const BBlockDesc& b_block_desc,
46 BBlockTransfer& b_blockwise_copy,
47 const BGridBuffer& b_grid_buf,
48 BBlockBuffer& b_block_buf,
49 const BBlockTransferStep& b_block_copy_step,
50 const BlockwiseGemm& blockwise_gemm,
51 CThreadBuffer& c_thread_buf,
52 index_t num_loop)
53 {
54 // global read 0
55 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
56 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
57
58 // move to 1
59 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
60 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
61
62 // Initialize C
63 c_thread_buf.Clear();
64
65 // LDS write 0
66 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
67 // global Read 1
68 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
69
70 // LDS write 0
71 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
72 // global Read 1
73 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
74
75 // main body
76 if constexpr(HasMainLoop)
77 {
78 index_t i = 0;
79
80 do
81 {
82#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
83 __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
84#endif
85
87
88 // GEMM i
89 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
90
92
93 // move to i + 2
94 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
95 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
96
97 // LDS write i + 1
98 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
99 // global read i + 2
100 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
101
102 // LDS write i + 1
103 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
104 // global read i + 2
105 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
106
107 ++i;
108 } while(i < (num_loop - 2));
109 }
110
111 // tail
112 {
114
115 // GEMM num_loop - 2
116 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
117
119
120 // LDS write num_loop - 1
121 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
122 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
123
125
126 // GEMM num_loop - 1
127 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
128 }
129 }
130};
131
132} // namespace ck
#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
Definition ck.hpp:218
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Definition gridwise_gemm_pipeline_v2.hpp:11
__host__ static __device__ constexpr bool IsSupported(const index_t num_loop)
Definition gridwise_gemm_pipeline_v2.hpp:12
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v2.hpp:38
__host__ static __device__ constexpr bool CalculateHasMainLoop(const index_t num_loop)
Definition gridwise_gemm_pipeline_v2.hpp:18