functional.hpp Source File

functional.hpp Source File

Composable Kernel: functional.hpp Source File
tile/core/utility/functional.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10#include <stdint.h>
11#include <utility>
12
13namespace ck_tile {
14
15namespace detail {
16
17struct swallow
18{
19 template <typename... Ts>
20 CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
21 {
22 }
23};
24
25template <class>
27
28template <index_t... Is>
30{
31 template <class F>
32 CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
33 {
34 swallow{(f(number<Is>{}), 0)...};
35 }
36};
37
38} // namespace detail
39
40// F signature: F(number<Iter>)
41template <index_t NBegin, index_t NEnd, index_t Increment>
43{
45 {
46 static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
47 "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
48 static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
49 "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
50 "NBegin >= NEnd)");
51 }
52
53 template <class F>
59};
60
61namespace detail {
62
63template <typename T, T... Is>
64struct applier
65{
66 template <typename F>
67 CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
68 {
69 // tweak -fbracket-depth if compilation fails. Clang default limit is 256
70 (f(number<Is>{}), ...);
71 }
72};
73
74template <int32_t Size> // == sizeof...(Is)
75using make_applier = __make_integer_seq<applier, index_t, Size>;
76
77} // namespace detail
78
79template <index_t N>
80struct static_for<0, N, 1> : detail::make_applier<N>
81{
82 using detail::make_applier<N>::operator();
83};
84
86{
87 template <typename T>
88 CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
89 {
90 return std::forward<T>(arg);
91 }
92};
93
94namespace detail {
95
96// RemainLengths: sequence<...>
97// Orders: sequence<...>
98template <class RemainLengths, class Orders>
100{
102 {
103 static_assert(RemainLengths::size() > 0, "wrong! should not get here");
104 }
105
106 // F signature: F(sequence<...>)
107 // CurrentOrderedId: sequence<...>
108 template <class F, class CurrentOrderedId>
109 CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
110 {
111 static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
112 static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
113 f, CurrentOrderedId::push_back(I));
114 });
115 }
116};
117
118template <class Orders>
119struct static_ford_impl<sequence<>, Orders>
120{
121 // F signature: F(sequence<...>)
122 // OrderedId: sequence<...>
123 template <class F, class OrderedId>
124 CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
125 {
126 // retrive unordered Id
127 f(OrderedId::reorder_old_to_new(Orders{}));
128 }
129};
130
131} // namespace detail
132
133// Lengths is sequence<...>, it is the length of each dimension for
134// N-dimensional loop
135// Orders is sequence<...>, it is the order of dimension in which static_ford
136// will loop over each
137// dimension
138template <class Lengths,
139 class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
141{
143 {
144 static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
145 static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
146 }
147
148 // F signature: F(sequence<...> multi_id)
149 // multi_id is the unordered multi-index
150 template <class F>
151 CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
152 {
153 constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
154 detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
155 }
156};
157
158namespace detail {
159
160template <typename Indices>
162
163template <index_t... Is>
164struct unpack_impl<sequence<Is...>>
165{
166 template <typename F, typename X>
167 CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
168 {
169#if 0
170 return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
171#else
172 return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
173#endif
174 }
175};
176
177template <typename Seq0, typename Seq1>
179
180// TODO: remove this, after properly implementing unpack that takes any number of containers
181template <index_t... Is, index_t... Js>
182struct unpack2_impl<sequence<Is...>, sequence<Js...>>
183{
184 template <typename F, typename X, typename Y>
185 CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
186 {
187#if 0
188 return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
189 std::forward<Y>(y).at(number<Js>{})...);
190#else
191 return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
192 std::forward<Y>(y).template at<Js>()...);
193#endif
194 }
195};
196
197} // namespace detail
198
199template <typename F, typename X>
200CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
201{
202 using X_ = remove_reference_t<X>;
203 return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
204 std::forward<F>(f), std::forward<X>(x));
205}
206
207// TODO: properly implement unpack that takes any number of containers
208template <typename F, typename X, typename Y>
209CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
210{
211 using X_ = remove_reference_t<X>;
212 using Y_ = remove_reference_t<Y>;
213 return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
214 typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
215 std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
216}
217
// Compile-time selection, z = predicate ? x : y.
// Only the branch that is taken participates in return-type deduction, so X
// and Y may be unrelated types. The result is returned by value (`auto`
// decays): it is move-constructed from an rvalue argument and copied from an
// lvalue.
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(!predicate)
        return std::forward<Y>(y);
    else
        return std::forward<X>(x);
}
231
232} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
__make_integer_seq< applier, index_t, Size > make_applier
Definition tile/core/utility/functional.hpp:75
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition tile/core/utility/functional.hpp:209
constexpr auto conditional_expr(X &&x, Y &&y)
Definition tile/core/utility/functional.hpp:220
typename std::remove_reference< T >::type remove_reference_t
Definition type_traits.hpp:15
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto unpack(F &&f, X &&x)
Definition tile/core/utility/functional.hpp:200
int32_t index_t
Definition integer.hpp:9
Definition tile/core/container/sequence.hpp:287
Definition tile/core/utility/functional.hpp:65
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
Definition tile/core/utility/functional.hpp:67
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
Definition tile/core/utility/functional.hpp:32
Definition tile/core/utility/functional.hpp:26
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
Definition tile/core/utility/functional.hpp:124
Definition tile/core/utility/functional.hpp:100
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
Definition tile/core/utility/functional.hpp:109
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
Definition tile/core/utility/functional.hpp:101
Definition tile/core/utility/functional.hpp:18
CK_TILE_HOST_DEVICE constexpr swallow(Ts &&...)
Definition tile/core/utility/functional.hpp:20
CK_TILE_HOST_DEVICE constexpr auto operator()(F &&f, X &&x, Y &&y) const
Definition tile/core/utility/functional.hpp:185
Definition tile/core/utility/functional.hpp:178
CK_TILE_HOST_DEVICE constexpr auto operator()(F &&f, X &&x) const
Definition tile/core/utility/functional.hpp:167
Definition tile/core/utility/functional.hpp:161
Definition tile/core/utility/functional.hpp:86
CK_TILE_HOST_DEVICE constexpr T && operator()(T &&arg) const noexcept
Definition tile/core/utility/functional.hpp:88
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
Definition tile/core/utility/functional.hpp:54
CK_TILE_HOST_DEVICE constexpr static_for()
Definition tile/core/utility/functional.hpp:44
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
Definition tile/core/utility/functional.hpp:151
CK_TILE_HOST_DEVICE constexpr static_ford()
Definition tile/core/utility/functional.hpp:142