rmsnorm2d_fwd_pipeline_one_pass.hpp Source File

rmsnorm2d_fwd_pipeline_one_pass.hpp Source File#

Composable Kernel: rmsnorm2d_fwd_pipeline_one_pass.hpp Source File
rmsnorm2d_fwd_pipeline_one_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
11namespace ck_tile {
12
// One-pass RMSNorm 2D forward pipeline. operator() loads the x tile once, reduces the
// sum of squares, derives the inverse RMS, scales by gamma and hands the result to an
// Epilogue functor that performs the final store.
//
// Template parameters:
//   Problem_ - supplies the data types and the Traits flags read below.
//   Policy_  - supplies tile distributions, the block-reduce helpers and the smem size.
//
// NOTE(review): this text is an extracted HTML listing; each line still carries its
// listing line number, and every listing line that held a hyperlinked symbol was dropped.
// Listing line 14 (the struct declaration itself) is one of them.
13template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
15{
// NOTE(review): listing lines 16-17, 19-23 and 25-26 are missing here. The symbol index
// at the end of this page lists them as the aliases Problem, Policy, XDataType,
// GammaDataType, ComputeDataType, YDataType, InvRmsDataType, and XResidualDataType /
// YResidualDataType (both aliases of XDataType).
18
24
27
// True when the problem's gamma type is anything other than ck_tile::null_type.
28 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
// Compile-time switches forwarded from Problem::Traits; both gate code in operator().
29 static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
30 static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
31
32 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
// kPadM is hard-coded to false rather than read from the problem (see the TODO).
33 static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
34 static constexpr bool kPadN = Problem::Traits::kPadN;
// Fusion selectors; their enumerator types are declared outside this file.
35 static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
36 static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
37
// Short pipeline tag chosen at compile time from kNeedCrossWarpSync.
38 static constexpr const char* name = []() {
39 if constexpr(kNeedCrossWarpSync)
40 return "bpr_op"; // block per row
41 else
42 return "wpr_op"; // warp per row
43 }();
44
// NOTE(review): listing line 45 (the function signature) is missing; the symbol index
// gives it as `static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()`.
// The body simply forwards to the policy's GetSmemSize for this Problem.
46 {
47 return Policy::template GetSmemSize<Problem>();
48 }
49
// Runs the one-pass RMSNorm forward for one tile.
//
// Parameters:
//   x_window_          - input window; re-wrapped with the policy's X distribution and loaded.
//   x_residual_window_ - residual input window; loaded with the X distribution.
//   gamma_window_      - gamma window; loaded with the policy's gamma distribution.
//   y_window_          - output window; only forwarded to the Epilogue.
//   y_residual_window_ - window that receives x + residual inside the fused-add block;
//                        taken by const reference but re-wrapped into a local window
//                        that is written through store_tile.
//   inv_rms_window     - receives the inverse RMS when kSaveInvRms is set.
//   sm_scale_window_   - forwarded to the Epilogue in the first quant branch only.
//   y_scale_window_    - forwarded to the Epilogue in both quant branches.
//   unquant_y_window   - forwarded to the Epilogue when kSaveUnquant is set.
//   epsilon            - added to the mean square before rsqrtf.
//   row_size           - divisor that turns the square sum into a mean.
//   smem               - scratch pointer used by the cross-warp reduction and passed
//                        to the Epilogue in the quant branches.
//   (unnamed Epilogue) - used only for its type; a fresh Epilogue{} is constructed
//                        at each call site.
// No value is returned by any visible statement.
50 template <typename XWindow,
51 typename XResidualWindow,
52 typename GammaWindow,
53 typename YWindow,
54 typename YResidualWindow,
55 typename InvRmsWindow,
56 typename SmoothScaleWindow,
57 typename YScaleWindow,
58 typename UnquantYWindow,
59 typename Epilogue>
60 CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
61 const XResidualWindow& x_residual_window_,
62 const GammaWindow& gamma_window_,
63 YWindow& y_window_,
64 const YResidualWindow& y_residual_window_,
65 InvRmsWindow& inv_rms_window,
66 const SmoothScaleWindow& sm_scale_window_,
67 YScaleWindow& y_scale_window_,
68 UnquantYWindow& unquant_y_window,
69 ComputeDataType epsilon,
70 ck_tile::index_t row_size,
71 void* smem,
72 Epilogue) const
73 {
// Re-wrap the caller's windows with the policy's tile distributions. The three
// x-shaped windows share the X distribution; gamma has its own.
74 const auto x_window =
75 make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
76 const auto gamma_window = make_tile_window(
77 gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
78 const auto x_residual_window = make_tile_window(
79 x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
80 auto y_residual_window = make_tile_window(
81 y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
82
// SquareAdd is used for the first reduction stage only; the two later stages
// combine the partial results with plain Add.
83 auto reduce_square_sum_func = ReduceOp::SquareAdd{};
84 auto reduce_sum_func = ReduceOp::Add{};
85 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
86 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
87 auto block_reduce2d_cross_warp_sync =
88 Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
89
// Both loads are issued unconditionally, before the fused-add guard below.
90 auto x = load_tile(x_window);
91 auto x_resi = load_tile(x_residual_window);
92
// Gamma is loaded regardless of kHasGamma.
93 // load gamma (TODO: support no gamma?)
94 const auto gamma = load_tile(gamma_window);
95
// All arithmetic from here on is done in ComputeDataType.
96 auto acc = cast_tile<ComputeDataType>(x);
97
// NOTE(review): listing lines 98-99 (the guard for this block) are missing. The symbol
// index references PRE_ADD and PRE_ADD_STORE, so this is presumably an `if constexpr`
// on kFusedAdd covering both values - confirm against the upstream header.
100 {
101 sweep_tile(x_resi, [&](auto idx) {
102 // compute x = x_resi + x
103 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
104 });
// NOTE(review): listing line 105 (the guard for the store below) is missing; presumably
// `if constexpr` on the PRE_ADD_STORE case - confirm against the upstream header.
106 {
107 store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
108 }
109 }
110
111 // compute mean square each-thread->cross-lane->cross-warp
112 auto square_sum = block_reduce2d(acc,
113 reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
114 reduce_square_sum_func);
115 block_reduce2d_sync(square_sum, reduce_sum_func);
116 block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
117
// inv_rms = rsqrtf(square_sum / row_size + epsilon), applied element-wise to the
// reduced tensor.
118 // compute inv-rms
119 auto inv_rms = tile_elementwise_in(
120 [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
121
122 if constexpr(kSaveInvRms)
123 store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
124
// The result tensor reuses x's tile distribution. Inside the sweep, idx component 0
// selects the inv_rms element and component 1 selects the gamma element.
// inv_rms is captured by copy as inv_rms_; everything else is captured by reference.
125 // rmsnorm computation
126 auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
127 sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
128 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
129 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
130
131 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
132
133 auto rmsn_ = acc[idx] * inv_rms_[i_idx] * gamma_;
134
135 rmsn(idx) = rmsn_;
136 });
137
// NOTE(review): listing line 138 (the guard for this branch) is missing. This branch is
// the only one that passes sm_scale_window_, and the symbol index references
// SMOOTH_DYNAMIC_QUANT, so it is presumably `if constexpr` on that kFusedQuant value -
// confirm against the upstream header.
139 {
140 if constexpr(kSaveUnquant)
141 {
142 Epilogue{}(
143 unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
144 }
145 else
146 {
147 Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
148 }
149 }
// NOTE(review): listing line 150 (the guard for this branch) is missing; presumably
// `else if constexpr` on the DYNAMIC_QUANT value of kFusedQuant - confirm upstream.
151 {
152 if constexpr(kSaveUnquant)
153 {
154 Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
155 }
156 else
157 {
158 Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
159 }
160 }
// Fallback branch: the Epilogue gets only the output window and the result, and
// nullptr in place of smem.
161 else
162 {
163 Epilogue{}(y_window_, rmsn, nullptr);
164 }
165 }
166};
167} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
@ SMOOTH_DYNAMIC_QUANT
Definition rmsnorm2d_fwd_traits.hpp:29
@ DYNAMIC_QUANT
Definition rmsnorm2d_fwd_traits.hpp:30
@ PRE_ADD_STORE
Definition rmsnorm2d_fwd_traits.hpp:14
@ PRE_ADD
Definition rmsnorm2d_fwd_traits.hpp:16
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition reduce_operator.hpp:14
Definition reduce_operator.hpp:40
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:15
static constexpr bool kSaveUnquant
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:30
static constexpr auto kFusedAdd
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:35
static constexpr bool kSaveInvRms
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:29
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:19
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:22
XDataType XResidualDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:25
static constexpr bool kPadN
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:34
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:20
static constexpr bool kHasGamma
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:28
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const XResidualWindow &x_residual_window_, const GammaWindow &gamma_window_, YWindow &y_window_, const YResidualWindow &y_residual_window_, InvRmsWindow &inv_rms_window, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window_, UnquantYWindow &unquant_y_window, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:60
XDataType YResidualDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:26
static constexpr auto kFusedQuant
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:36
ck_tile::remove_cvref_t< Problem_ > Problem
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:16
static constexpr const char * name
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:38
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:21
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:45
ck_tile::remove_cvref_t< Policy_ > Policy
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:17
static constexpr bool kNeedCrossWarpSync
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:32
static constexpr bool kPadM
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:33
ck_tile::remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition rmsnorm2d_fwd_pipeline_one_pass.hpp:23