block_fmha_pipeline_qr_ks_vs.hpp Source File

block_fmha_pipeline_qr_ks_vs.hpp Source File#

Composable Kernel: block_fmha_pipeline_qr_ks_vs.hpp Source File
block_fmha_pipeline_qr_ks_vs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14// This pipeline loads the Q tile once through a register tile distribution and stages K and V through LDS
15template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
17{
33
36 static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
37 static_assert(kQLoadOnce == Policy::QLoadOnce);
38
39 static constexpr index_t kBlockSize = Problem::kBlockSize;
40
41 static constexpr index_t kM0 = BlockFmhaShape::kM0;
42 static constexpr index_t kN0 = BlockFmhaShape::kN0;
43 static constexpr index_t kK0 = BlockFmhaShape::kK0;
44 static constexpr index_t kN1 = BlockFmhaShape::kN1;
45 static constexpr index_t kK1 = BlockFmhaShape::kK1;
46 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
47 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
48
49 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
50
51 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
52 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
53 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
54 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
55 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
56 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
57 static constexpr auto BiasEnum = Problem::BiasEnum;
58 static constexpr bool kStoreLSE = Problem::kStoreLSE;
59 static constexpr bool kHasDropout = Problem::kHasDropout;
60
61 static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
62 static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
63
64 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
65 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
68
69 // Vector length of the last dimension, used to create the tensor view (and to decide the buffer_load vector length)
70 // ... together with the tensor distribution. The tensor distribution should be able to override this.
71 static constexpr index_t kAlignmentQ =
72 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
73 static constexpr index_t kAlignmentK =
74 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
// Alignment used for V. It drops to 1 when the relevant dimension is padded:
// kPadHeadDimV is checked for a RowMajor VLayout, kPadSeqLenK for any other layout.
// Otherwise the value comes from Policy::GetAlignmentV<Problem>().
75 static constexpr index_t kAlignmentV = []() {
76 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
77 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
78 else
79 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
80 }();
81
82 static constexpr index_t kAlignmentO =
83 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
84 static constexpr index_t kAlignmentBias =
85 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
86
// kBlockPerCu: uses Problem::kBlockPerCu when it is not -1; otherwise a default is
// picked from kQKHeaddim (<= 32 -> 2, <= 64 -> 3, <= 128 -> 1 or 2, anything larger -> 1).
// NOTE(review): presumably the number of blocks intended to be resident per compute unit -
// confirm against the kernel launch-bounds usage.
87 static constexpr index_t kBlockPerCu = []() {
88 if constexpr(Problem::kBlockPerCu != -1)
89 return Problem::kBlockPerCu; // explicit override supplied by the Problem
90 else
91 {
92 if constexpr(kQKHeaddim <= 32)
93 {
94 return 2;
95 }
96 else if constexpr(kQKHeaddim <= 64)
97 {
98 return 3;
99 }
100 else if constexpr(kQKHeaddim <= 128)
101 {
// NOTE(review): original line 102 (the condition selecting between the two returns
// below) is missing from this listing - restore it from the upstream header.
103 return 1;
104 else
105 return 2;
106 }
107 else if constexpr(kQKHeaddim <= 256)
108 {
109 return 1;
110 }
111 else
112 {
// Same value as the <= 256 branch; kept as a separate fallback branch.
113 return 1;
114 }
115 }
116 }();
117
118 static constexpr const char* name = "qr";
119
120 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
121
123 {
124 return Policy::template GetSmemSize<Problem>();
125 }
126
127 template <typename QDramBlockWindowTmp,
128 typename KDramBlockWindowTmp,
129 typename VDramBlockWindowTmp,
130 typename BiasDramBlockWindowTmp,
131 typename RandValDramBlockWindowTmp,
132 typename LSEDramBlockWindowTmp,
133 typename QElementFunction,
134 typename KElementFunction,
135 typename VElementFunction,
136 typename BiasElementFunction,
137 typename LSEElementFunction,
138 typename SAccElementFunction,
139 typename PComputeElementFunction,
140 typename OAccElementFunction,
141 typename PositionEncoding,
142 typename AttentionVariantParams,
143 typename BlockIndices>
145 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
146 const QElementFunction& q_element_func,
147 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
148 const KElementFunction& k_element_func,
149 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
150 const VElementFunction& v_element_func,
151 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
152 const BiasElementFunction& bias_element_func,
153 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
154 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
155 const LSEElementFunction& lse_element_func,
156 const SAccElementFunction& s_acc_element_func,
157 const PComputeElementFunction& p_compute_element_func,
158 const OAccElementFunction& o_acc_element_func,
159 FmhaMask mask,
160 PositionEncoding position_encoding,
161 float scale_s,
162 const AttentionVariant& variant,
163 const AttentionVariantParams& variant_params,
164 const BlockIndices& block_indices,
165 void* smem_ptr,
166 DropoutType& dropout) const
167 {
168 static_assert(
169 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
170 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
171 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
172 "wrong!");
173
174 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
175 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
176 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
177 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
178 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
179 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
180 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
181 "wrong!");
182
183 // K tile in LDS
184 KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
185 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
187 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
188 auto k_lds_window =
190
191 // V tile in LDS
193 reinterpret_cast<VDataType*>(smem_ptr),
194 Policy::template MakeVLdsBlockDescriptor<Problem>());
195 auto v_lds_window = make_tile_window(
196 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
197
198 // Block GEMM
199 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
200 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
201
202 auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
203 q_dram_block_window_tmp.get_window_lengths(),
204 q_dram_block_window_tmp.get_window_origin(),
205 Policy::template MakeQRegTileDistribution<Problem>());
206
207 auto q = load_tile(q_dram_window);
208
209 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
210 auto s_acc = SaccBlockTileType{};
211
212 // reduction function for softmax
213 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
214 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
215
216 // infer Sacc, S, P, M, L, Oacc type
217 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
218
219 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
220 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
221
222 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
223
224 // init Oacc, M, L
225 auto o_acc = OaccBlockTileType{};
226 auto m = MLBlockTileType{};
227 auto l = MLBlockTileType{};
228
229 clear_tile(o_acc);
231 clear_tile(l);
232
233 const auto q_origin = q_dram_window.get_window_origin();
234 const auto [seqlen_k_start, seqlen_k_end] =
235 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
236
237 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
238
239 // check early exit if no work to do
240 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
241 {
242 if(num_total_loop <= 0)
243 {
244 if constexpr(kStoreLSE)
245 {
246 auto lse =
247 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
248
250
251 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
252 }
253
254 // Note: o_acc is still all zeros here (cleared above), so return it as-is.
255 // Note: q has been loaded but not fenced; that is ignored on this path.
256 return o_acc;
257 }
258 }
259
260 auto k_dram_block_window =
261 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
262 k_dram_block_window_tmp.get_window_lengths(),
263 {seqlen_k_start, 0});
264
265 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
266 auto bias_dram_window =
267 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
268 bias_dram_block_window_tmp.get_window_lengths(),
269 {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
270 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
271
272 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
273 randval_dram_block_window_tmp, seqlen_k_start);
274
275 auto v_dram_window =
276 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
277 v_dram_block_window_tmp.get_window_lengths(),
278 {0, seqlen_k_start}, // TODO: hdim split?
279 Policy::template MakeVDramTileDistribution<Problem>());
280
281 auto q_tile = tile_elementwise_in(q_element_func, q);
282
283 // prefetch K tile
284 index_t i_total_loops = 0;
285 constexpr index_t k0_loops = kQKHeaddim / kK0;
286 constexpr index_t k1_loops = kN0 / kK1;
287 // schedule_gemm0: emits instruction-scheduling hints; it is invoked right after each gemm_0 call.
288 // The barrier sequence is selected at compile time with `if constexpr` (no runtime lambda selection):
// it is emitted only when get_warp_size() == 64 and kQKHeaddim == 256, otherwise the lambda body is empty.
289 auto schedule_gemm0 = [] {
290 using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
291 constexpr auto WarpGemmConfig =
292 BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
// WarpGemmConfig is indexed as <0> = warp-gemm type, <1> = MWarp, <2> = NWarp.
293 using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
294 constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
295 constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
296 constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
297 constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
298 constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
// Number of warp-gemm tiles in one kM0 x kN0 x kK0 block gemm, divided by the warp count;
// presumably the MFMA instructions each warp issues per gemm_0 call - TODO confirm.
299 constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
300 (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
301 if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
302 {
// One iteration below covers 8 MFMA (2 + 2 + 4) and 4 DS reads (2 + 1 + 1),
// hence the requirement that NumMfmaInsts be a multiple of 8.
303 static_assert(NumMfmaInsts % 8 == 0);
304 static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
305 __builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // group of 2 DS reads
306 __builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // group of 2 MFMA
307 __builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // group of 1 DS read
308 __builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // group of 2 MFMA
309 __builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // group of 1 DS read
310 __builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // group of 4 MFMA
311 });
312 }
313 };
314
315 static_assert(2 <= k0_loops);
316 static_assert(1 <= k1_loops);
317 do
318 {
319 // STAGE 1, QK gemm
320 auto k_dram_window = make_tile_window(
321 k_dram_block_window.get_bottom_tensor_view(),
322 k_dram_block_window.get_window_lengths(),
323 k_dram_block_window.get_window_origin(),
324 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
325 // load
326
327 auto k_block_tile = load_tile(k_dram_window);
328 {
329 move_tile_window(k_dram_window, {0, kK0});
330 clear_tile(s_acc); // initialize C
331 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
332 k_block_tile = load_tile(k_dram_window);
333 }
334
336 {
337 __builtin_amdgcn_sched_barrier(
338 0); // prevent from messing up the order of global loads
339 }
340 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
342 {
343 __builtin_amdgcn_sched_barrier(
344 0); // prevent from messing up the order of global loads
345 }
346
347 if constexpr(k0_loops > 2)
348 {
349 static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
351 gemm_0(s_acc,
352 get_slice_tile(q_tile,
353 sequence<0, i_k0 * kK0>{},
354 sequence<kM0, (i_k0 + 1) * kK0>{}),
355 k_lds_window);
356 schedule_gemm0();
358 move_tile_window(k_dram_window, {0, kK0});
359
361 k_lds_window,
362 tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
363 k_block_tile = load_tile(k_dram_window); // global read i + 2
364 });
365 }
366
367 const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
368 { // tail
370 gemm_0(s_acc,
371 get_slice_tile(q_tile,
372 sequence<0, (k0_loops - 2) * kK0>{},
373 sequence<kM0, (k0_loops - 1) * kK0>{}),
374 k_lds_window);
375 schedule_gemm0();
377
378 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
380
381 gemm_0(s_acc,
382 get_slice_tile(q_tile,
383 sequence<0, (k0_loops - 1) * kK0>{},
384 sequence<kM0, k0_loops * kK0>{}),
385 k_lds_window);
386 schedule_gemm0();
387 }
388
389 // STAGE 2, scale_s, add bias, mask, softmax
391 {
392 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
393 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
395 [&](auto& x, const auto& y) {
396#if !CK_TILE_FMHA_FWD_FAST_EXP2
397 x += type_convert<SaccDataType>(bias_element_func(y));
398#else
400 type_convert<SaccDataType>(bias_element_func(y));
401#endif
402 },
403 s_acc,
404 bias_tile);
405 }
406 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
407 {
408 const auto k_origin = k_dram_block_window.get_window_origin();
409 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
410 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
411 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
412 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
413 const auto tile_idx = get_x_indices_from_distributed_indices(
414 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
415
416 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
417 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
418 constexpr auto i_j_idx = make_tuple(idx0, idx1);
419
420 s_acc(i_j_idx) *= scale_s;
421 position_encoding.update(s_acc(i_j_idx), row, col);
422 });
423 });
424 }
425 else
426 {
427 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
428 if constexpr(kHasLogitsSoftCap)
429 {
430 auto apply_logits_transform =
431 [&variant, &variant_params, &block_indices](auto& x) {
432 x = variant.LogitsTransform(variant_params,
433 variant.QueryTransform(variant_params, x),
434 block_indices.batch_idx,
435 block_indices.qo_head_idx,
436 block_indices.kv_head_idx);
437 };
438#if !CK_TILE_FMHA_FWD_FAST_EXP2
439 tile_elementwise_inout(apply_logits_transform, s_acc);
440#else
441 tile_elementwise_inout(apply_logits_transform, s_acc);
442#endif
443 }
444 else
445 {
446#if !CK_TILE_FMHA_FWD_FAST_EXP2
447 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
448#endif
449 }
450 }
451 move_tile_window(bias_dram_window, {0, kN0});
452 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
453 {
454 const auto k_origin = k_dram_block_window.get_window_origin();
455 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
456 k_origin.at(number<0>{}),
457 number<kM0>{},
458 number<kN0>{});
459 if(need_perpixel_check)
460 {
462 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
463 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
464 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
465 return !variant.LogitsMask(variant_params,
466 block_indices.batch_idx,
467 row,
468 col,
469 block_indices.qo_head_idx,
470 block_indices.kv_head_idx);
471 });
472 }
473 }
474
475 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
477 s,
478 sequence<1>{},
479 f_max,
480 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
482
483 const auto m_old = m; // m{j-1}
485 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
486
488 s.get_tile_distribution()); // Pcompute{j}
489
490 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
494 FmhaMask::IsMasking)
495 {
498 : raw_m;
499 }
500 else
501 {
502 return raw_m;
503 }
504 };
505
506 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
507 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
508 constexpr auto i_idx = make_tuple(idx0);
509#if CK_TILE_FMHA_FWD_FAST_EXP2
510 auto row_max = scale_s * get_validated_m(m[i_idx]);
511#endif
512 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
513 constexpr auto i_j_idx = make_tuple(idx0, idx1);
514#if CK_TILE_FMHA_FWD_FAST_EXP2
517 {
518 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
519 }
520 else
521 {
522 if constexpr(kHasLogitsSoftCap)
523 {
524 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
525 }
526 else
527 {
528 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
529 }
530 }
531#else
532 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
533#endif
534 });
535 });
536
538 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
539
541 // l{j}, Oacc{j}
542 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
543 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
544 constexpr auto i_idx = make_tuple(idx0);
545#if CK_TILE_FMHA_FWD_FAST_EXP2
546 const auto tmp = [&]() {
549 {
550 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
551 }
552 else
553 {
554 if constexpr(kHasLogitsSoftCap)
555 {
556
557 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
558 }
559 else
560 {
561 auto row_max = scale_s * get_validated_m(m[i_idx]);
562 return exp2(scale_s * m_old[i_idx] - row_max);
563 }
564 }
565 }();
566#else
567 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
568#endif
569 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
570 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
571 constexpr auto i_j_idx = make_tuple(idx0, idx1);
572 // FIXME: this uses a different equation from the FA v2 paper,
573 // but produces the correct result.
574 // Is the equation wrong?
575 o_acc(i_j_idx) *= tmp;
576 });
577 });
578
579 if constexpr(kHasDropout)
580 {
581 // K and dropout use the same address in LDS, finish loading from k_lds_window by
582 // gemm_0 to reuse LDS.
584 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
585 smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
586 }
587
589 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
590 {
592 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
593 shuffle_tile(v_shuffle_tmp, v_prefetch);
595 v_lds_window,
596 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
597 }
598 else
599 {
600 store_tile(v_lds_window,
601 tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
602 }
603 move_tile_window(v_dram_window, {0, kK1});
604
605 const auto p =
606 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
607
608 // STAGE 3, KV gemm
609 if constexpr(k1_loops > 1)
610 {
611 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
612 const auto v = load_tile(v_dram_window); // load next v
614 gemm_1(o_acc,
616 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
617 v_lds_window);
619 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
620 {
622 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
623 shuffle_tile(v_shuffle_tmp, v);
624 store_tile(v_lds_window,
625 tile_elementwise_in(v_element_func,
626 v_shuffle_tmp)); // store the prefetch
627 }
628 else
629 {
630 store_tile(v_lds_window,
631 tile_elementwise_in(v_element_func, v)); // store next v
632 }
633 move_tile_window(v_dram_window, {0, kK1});
634 });
635 }
636 // move K tile windows
637 move_tile_window(k_dram_block_window, {kN0, 0});
638 // tail
639 {
641 gemm_1(o_acc,
642 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
643 v_lds_window);
645 }
646 } while(++i_total_loops < num_total_loop);
647
648 // store lse
649 if constexpr(kStoreLSE)
650 {
651 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
652
653 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
654 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
655 constexpr auto i_idx = make_tuple(idx0);
656#if CK_TILE_FMHA_FWD_FAST_EXP2
659 {
660 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
661 }
662 else
663 {
664 if constexpr(kHasLogitsSoftCap)
665 {
666 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
667 }
668 else
669 {
670 lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
671 }
672 }
673#else
674 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
675#endif
676 });
677
678 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
679 }
680
681 // finally, O
682 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
683
684 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
685 constexpr auto i_idx = make_tuple(idx0);
686 const auto tmp = [&]() {
687 if constexpr(FmhaMask::IsMasking)
688 {
689 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
690 }
691 else
692 return 1 / l[i_idx];
693 }();
694 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
695 constexpr auto i_j_idx = make_tuple(idx0, idx1);
696 o_acc(i_j_idx) *= tmp;
697 });
698 });
699
700 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
701
702 return o_acc;
703 }
704
// Convenience overload of operator() that takes no element-wise functors.
// It forwards every argument unchanged to the full overload and passes identity{} for each
// of the eight element functions (q, k, v, bias, lse, s_acc, p_compute, o_acc), returning
// whatever the full overload returns.
// NOTE(review): original line 714 (the return-type line preceding `operator()`) is missing
// from this listing.
705 template <typename QDramBlockWindowTmp,
706 typename KDramBlockWindowTmp,
707 typename VDramBlockWindowTmp,
708 typename BiasDramBlockWindowTmp,
709 typename RandValDramBlockWindowTmp,
710 typename LSEDramBlockWindowTmp,
711 typename PositionEncoding,
712 typename AttentionVariantParams,
713 typename BlockIndices>
715 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
716 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
717 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
718 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
719 RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
720 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
721 FmhaMask mask,
722 PositionEncoding position_encoding,
723 float scale_s,
724 const AttentionVariant& variant,
725 const AttentionVariantParams& variant_params,
726 const BlockIndices& block_indices,
727 void* smem_ptr,
728 DropoutType& dropout) const
729 {
730 return operator()(q_dram_block_window_tmp,
731 identity{}, // q_element_func
732 k_dram_block_window_tmp,
733 identity{}, // k_element_func
734 v_dram_block_window_tmp,
735 identity{}, // v_element_func
736 bias_dram_block_window_tmp,
737 identity{}, // bias_element_func
738 randval_dram_block_window_tmp,
739 lse_dram_block_window_tmp,
740 identity{}, // lse_element_func
741 identity{}, // s_acc_element_func
742 identity{}, // p_compute_element_func
743 identity{}, // o_acc_element_func
744 mask,
745 position_encoding,
746 scale_s,
747 variant,
748 variant_params,
749 block_indices,
750 smem_ptr,
751 dropout);
752 }
753};
754
755} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned int uint32_t
Definition stdint.h:126
Definition block_fmha_pipeline_qr_ks_vs.hpp:17
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:28
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:24
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs.hpp:44
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs.hpp:715
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:29
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs.hpp:53
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs.hpp:47
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:25
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs.hpp:19
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs.hpp:73
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs.hpp:36
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs.hpp:31
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:23
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:20
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:27
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs.hpp:71
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs.hpp:59
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs.hpp:52
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs.hpp:42
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs.hpp:57
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs.hpp:75
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs.hpp:87
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:26
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:22
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs.hpp:43
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs.hpp:58
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs.hpp:82
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs.hpp:46
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs.hpp:51
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs.hpp:122
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs.hpp:41
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs.hpp:118
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs.hpp:45
static constexpr uint32_t MFMA
Definition block_fmha_pipeline_qr_ks_vs.hpp:62
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs.hpp:18
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs.hpp:55
static constexpr uint32_t DS_READ
Definition block_fmha_pipeline_qr_ks_vs.hpp:61
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs.hpp:32
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs.hpp:84
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs.hpp:35
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs.hpp:39
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:30
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs.hpp:21
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_pipeline_qr_ks_vs.hpp:120
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs.hpp:145
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs.hpp:34
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs.hpp:56
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs.hpp:54
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469