combined_element_wise_operation.hpp Source File

combined_element_wise_operation.hpp Source File#

Composable Kernel: combined_element_wise_operation.hpp Source File
combined_element_wise_operation.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9namespace tensor_operation {
10namespace element_wise {
11
12// y = UnaryOp0(UnaryOp1(...(x)))
13template <typename... UnaryOpsSet>
15{
16 static constexpr const char* name = "UnaryCombinedOp";
17
18 __host__ __device__ UnaryCombinedOp() : unary_ops_() {}
19
20 __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
21
22 template <typename Y, typename X>
23 __host__ __device__ void operator()(Y& y, const X& x) const
24 {
25 // Execute first unary op to copy data to y
26 unary_ops_.At(Number<0>{})(y, x);
27
28 static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
29 };
30
31 Tuple<UnaryOpsSet...> unary_ops_;
32};
33
34// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
35template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
37{
38 static constexpr const char* name = "BinaryWithUnaryCombinedOp";
39
40 __host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
41
42 __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
43 UnaryOp0 unary_op0,
44 UnaryOp1 unary_op1)
45 : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
46 {
47 }
48
49 template <typename Y, typename X0, typename X1>
50 __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
51 {
52 Y unary_x0_tmp_result;
53 Y unary_x1_tmp_result;
54 unary_op0_(unary_x0_tmp_result, x0);
55 unary_op1_(unary_x1_tmp_result, x1);
56 binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
57 };
58
59 private:
60 BinaryOp binary_op_;
61 UnaryOp0 unary_op0_;
62 UnaryOp1 unary_op1_;
63};
64
65// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
66template <typename BinaryOp0,
67 typename BinaryOp1,
68 typename UnaryOp0,
69 typename UnaryOp1,
70 typename UnaryOp2>
72{
73 static constexpr const char* name = "TrinaryWithUnaryCombinedOp";
74
75 __host__ __device__ TrinaryWithUnaryCombinedOp()
76 : binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
77 {
78 }
79
80 __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
81 BinaryOp0 binary_op1,
82 UnaryOp0 unary_op0,
83 UnaryOp1 unary_op1,
84 UnaryOp2 unary_op2)
85 : binary_op0_(binary_op0),
86 binary_op1_(binary_op1),
87 unary_op0_(unary_op0),
88 unary_op1_(unary_op1),
89 unary_op2_(unary_op2)
90 {
91 }
92
93 template <typename Y, typename X0, typename X1, typename X2>
94 __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
95 {
96
97 Y unary_x0_tmp_result;
98 Y unary_x1_tmp_result;
99 Y unary_x2_tmp_result;
100 unary_op0_(unary_x0_tmp_result, x0);
101 unary_op1_(unary_x1_tmp_result, x1);
102 unary_op2_(unary_x2_tmp_result, x2);
103 binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
104 binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
105 };
106
107 private:
108 BinaryOp0 binary_op0_{};
109 BinaryOp1 binary_op1_{};
110 UnaryOp0 unary_op0_{};
111 UnaryOp1 unary_op1_{};
112 UnaryOp2 unary_op2_{};
113};
114
117
118} // namespace element_wise
119} // namespace tensor_operation
120} // namespace ck
Definition binary_element_wise_operation.hpp:11
UnaryCombinedOp< Scale, Scale, Relu > ScaleScaleRelu
Definition combined_element_wise_operation.hpp:116
UnaryCombinedOp< Scale, Scale, PassThrough > ScaleScalePass
Definition combined_element_wise_operation.hpp:115
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
integral_constant< index_t, N > Number
Definition number.hpp:12
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, UnaryOp0 unary_op0, UnaryOp1 unary_op1)
Definition combined_element_wise_operation.hpp:42
__host__ __device__ BinaryWithUnaryCombinedOp()
Definition combined_element_wise_operation.hpp:40
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition combined_element_wise_operation.hpp:50
static constexpr const char * name
Definition combined_element_wise_operation.hpp:38
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, BinaryOp0 binary_op1, UnaryOp0 unary_op0, UnaryOp1 unary_op1, UnaryOp2 unary_op2)
Definition combined_element_wise_operation.hpp:80
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1, const X2 &x2) const
Definition combined_element_wise_operation.hpp:94
static constexpr const char * name
Definition combined_element_wise_operation.hpp:73
__host__ __device__ TrinaryWithUnaryCombinedOp()
Definition combined_element_wise_operation.hpp:75
Definition combined_element_wise_operation.hpp:15
__host__ __device__ UnaryCombinedOp()
Definition combined_element_wise_operation.hpp:18
Tuple< UnaryOpsSet... > unary_ops_
Definition combined_element_wise_operation.hpp:31
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops)
Definition combined_element_wise_operation.hpp:20
__host__ __device__ void operator()(Y &y, const X &x) const
Definition combined_element_wise_operation.hpp:23
static constexpr const char * name
Definition combined_element_wise_operation.hpp:16