pk_fp4.hpp Source File

pk_fp4.hpp Source File#

Composable Kernel: pk_fp4.hpp Source File
pk_fp4.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cmath>
10
// Backend switch for the FP4 conversions in this header:
//   1 - building for gfx950, the conversions are routed through impl::_from_f4 / impl::_to_f4
//   0 - every other target, the generic conversion helpers are used instead
#ifdef __gfx950__
#define CK_TILE_FP4_CVT_DEVICE 1
#else
#define CK_TILE_FP4_CVT_DEVICE 0
#endif

// Set to a non-zero value to build the lookup-table based pk_fp4 -> float/fp16 conversions
// (and the e2m1_to_*_table members) instead of the default implementations.
#define TEST_convert_with_table 0
18
19namespace ck_tile {
20
21using fp32_t = float;
22using fp32x2_t = float __attribute__((ext_vector_type(2)));
23using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
24using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
25
26// Helpers: constexpr-safe access to elements of ext_vector_type(2)
27// Some compilers don't allow operator[] in constant expressions for vector types.
28// We use bit_cast to a trivially copyable representation to extract lanes.
29namespace detail {
31{
32 _Float16 e[2];
33};
35{
37};
39{
40 float e[2];
41};
42
43CK_TILE_HOST_DEVICE constexpr _Float16 lane0(const fp16x2_t& v)
44{
45 return ck_tile::bit_cast<fp16x2_repr>(v).e[0];
46}
47CK_TILE_HOST_DEVICE constexpr _Float16 lane1(const fp16x2_t& v)
48{
49 return ck_tile::bit_cast<fp16x2_repr>(v).e[1];
50}
51
53{
54 return ck_tile::bit_cast<bf16x2_repr>(v).e[0];
55}
57{
58 return ck_tile::bit_cast<bf16x2_repr>(v).e[1];
59}
60
61CK_TILE_HOST_DEVICE constexpr float lane0(const fp32x2_t& v)
62{
63 return ck_tile::bit_cast<fp32x2_repr>(v).e[0];
64}
65CK_TILE_HOST_DEVICE constexpr float lane1(const fp32x2_t& v)
66{
67 return ck_tile::bit_cast<fp32x2_repr>(v).e[1];
68}
69} // namespace detail
70
71struct pk_float4_e2m1_t;
72CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
73
74// TODO: Add stochastic method
76{
77 // TODO: Can we merge raw_type and type?
79 using type = raw_type;
81
83 template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
84 CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
85 {
86 }
87 CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
88 : data{float_to_pk_fp4(init, scale)}
89 {
90 }
91 CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
92 CK_TILE_HOST_DEVICE constexpr type& get() { return data; }
93 CK_TILE_HOST_DEVICE constexpr type get() const { return data; }
94
95 CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
96 CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
97 CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
98 CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
99 CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
100 CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
101
102 CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
103 CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
104 CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
105 CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
106 CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
107 CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
108
109 template <index_t I>
111 {
112 return _unpack(number<I>{});
113 }
115 const pk_float4_e2m1_t& x1)
116 {
117 return _pack(x0.get(), x1.get());
118 }
119
120 template <index_t I>
122 CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1)
123 {
124 return (x1 << 4) | (x0 & 0b00001111);
125 }
126
127#if TEST_convert_with_table
128 static constexpr float e2m1_to_fp32_table[16] = {
129 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
130 static constexpr fp16_t e2m1_to_fp16_table[16] = {
131 bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
132 bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
133 bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
134 bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
135 bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
136 bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
137 bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
138 bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
139 bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
140 bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
141 bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
142 bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
143 bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
144 bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
145 bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
146 bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
147 };
148#endif
149};
150
153
154template <>
156{
158
159 static constexpr int exp = 2;
160 static constexpr int mant = 1;
161 static constexpr int bias = 1;
162 static constexpr int PackedSize = 2;
163};
164
165// limits
166template <class T>
167struct numeric;
168
169template <>
171{
172 static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
173 static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
174 static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
175 static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
176 static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
177 static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
178 CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
179 CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
183 CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
185
186 CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
187 // N/A
188 CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
189 // N/A
190 CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
191 // N/A
192 CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
193};
194
195template <index_t I>
197{
198 static_assert(I < 2, "Index is out of range.");
199 if constexpr(I == 1)
200 return (data >> 4);
201 else
202 return data & 0b00001111;
203}
// TODO: consider replacing this macro to improve performance
206
207#if CK_TILE_FP4_CVT_DEVICE
208namespace impl {
209
210template <typename T>
211CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
212{
213 if constexpr(std::is_same_v<T, fp32_t>)
214 {
215 fp32x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
216 return detail::lane0(tmp);
217 }
218 else if constexpr(std::is_same_v<T, fp32x2_t>)
219 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
220 else if constexpr(std::is_same_v<T, fp16_t>)
221 {
222 fp16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
223 return detail::lane0(tmp);
224 }
225 else if constexpr(std::is_same_v<T, fp16x2_t>)
226 return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
227 else if constexpr(std::is_same_v<T, bf16_t>)
228 {
229 bf16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
230 return detail::lane0(tmp);
231 }
232 else if constexpr(std::is_same_v<T, bf16x2_t>)
233 return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
234 else
235 static_assert(std::false_type::value, "Unsupported type.");
236 return T{};
237}
238template <typename T>
239CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
240{
241 union
242 {
243 uint32_t u32;
244 pk_fp4_raw_t pf4[4];
245 } cvt{0};
246 if constexpr(std::is_same_v<T, fp32_t>)
247 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
248 else if constexpr(std::is_same_v<T, fp32x2_t>)
249 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
250 cvt.u32, detail::lane0(src), detail::lane1(src), scale, 0);
251 else if constexpr(std::is_same_v<T, fp16_t>)
252 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
253 else if constexpr(std::is_same_v<T, fp16x2_t>)
254 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
255 else if constexpr(std::is_same_v<T, bf16_t>)
256 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
257 else if constexpr(std::is_same_v<T, bf16x2_t>)
258 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
259 else
260 static_assert(std::false_type::value, "Unsupported type.");
261 return cvt.pf4[0];
262}
263
264} // namespace impl
265#endif
266
267CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
268{
269#if CK_TILE_FP4_CVT_DEVICE
270 return impl::_from_f4<bf16_t>(data, scale);
271#else
273#endif
274}
275
277{
278#if CK_TILE_FP4_CVT_DEVICE
279 return impl::_from_f4<bf16x2_t>(data, scale);
280#else
283#endif
284}
285
// TODO: make it generic so that we can convert from other source types directly.
287CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
288{
289#if CK_TILE_FP4_CVT_DEVICE
290 return impl::_to_f4(x, scale);
291#else
292 return convert_to_type<pk_fp4_t>(x, scale);
293#endif
294}
295CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
296{
297#if CK_TILE_FP4_CVT_DEVICE
298 return impl::_to_f4(x, scale);
299#else
300 auto res = convert_to_type<pk_fp4_t>(x, scale);
301 return pk_fp4_t::_pack(res, res);
302#endif
303}
304CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
305{
306#if CK_TILE_FP4_CVT_DEVICE
307 return impl::_to_f4(x, scale);
308#else
309 auto res = float_to_mxfp4(type_convert<float>(x), scale);
310 return pk_fp4_t::_pack(res, res);
311#endif
312}
313CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
314{
315#if CK_TILE_FP4_CVT_DEVICE
316 return impl::_to_f4(x, scale);
317#else
318 auto res = float_to_mxfp4(type_convert<float>(x), scale);
319 return pk_fp4_t::_pack(res, res);
320#endif
321}
323{
324#if CK_TILE_FP4_CVT_DEVICE
325 return impl::_to_f4(x, scale);
326#else
328 float_to_mxfp4(detail::lane1(x), scale));
329#endif
330}
332{
333#if CK_TILE_FP4_CVT_DEVICE
334 return impl::_to_f4(x, scale);
335#else
337 float_to_mxfp4(detail::lane1(x), scale));
338#endif
339}
341{
342#if CK_TILE_FP4_CVT_DEVICE
343 return impl::_to_f4(x, scale);
344#else
346 float_to_mxfp4(detail::lane1(x), scale));
347#endif
348}
349
351{
352 return x.to_fp32x2(scale);
353}
355{
356 return x.to_fp16x2(scale);
357}
359{
360 return x.to_bf16x2(scale);
361}
362CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
363{
364 return x.to_float(scale);
365}
366CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
367{
368 return x.to_fp16(scale);
369}
370CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
371{
372 return x.to_bf16(scale);
373}
374
375#if TEST_convert_with_table == 0
376CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
377{
378#if CK_TILE_FP4_CVT_DEVICE
379 return impl::_from_f4<fp32_t>(data, scale);
380#else
382#endif
383}
385{
386#if CK_TILE_FP4_CVT_DEVICE
387 return impl::_from_f4<fp32x2_t>(data, scale);
388#else
391#endif
392}
393
394CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
395{
396#if CK_TILE_FP4_CVT_DEVICE
397 return impl::_from_f4<fp16_t>(data, scale);
398#else
400#endif
401}
403{
404#if CK_TILE_FP4_CVT_DEVICE
405 return impl::_from_f4<fp16x2_t>(data, scale);
406#else
409#endif
410}
411#else
412CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
413{
414 return e2m1_to_fp32_table[_unpack(number<0>{})] * scale;
415}
416CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
417{
418 return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale};
419}
420CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
421{
422 return type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale;
423}
424CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
425{
426 return fp16x2_t{
427 type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale),
429 scale)};
430}
431#endif
432
433} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
CK_TILE_HOST_DEVICE constexpr _Float16 lane0(const fp16x2_t &v)
Definition pk_fp4.hpp:43
CK_TILE_HOST_DEVICE constexpr _Float16 lane1(const fp16x2_t &v)
Definition pk_fp4.hpp:47
Definition tile/core/arch/amd_buffer_addressing.hpp:110
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale=1.f)
Definition mxfp_convert.hpp:50
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:354
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:358
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float &x, float scale=1.f)
Definition pk_fp4.hpp:295
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value, float scale=1.f)
Definition mxfp_convert.hpp:61
ushort bfloat16_t
Definition bfloat16.hpp:111
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:350
typename pk_fp4_t::type pk_fp4_raw_t
Definition pk_fp4.hpp:152
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t &x, float scale)
Definition pk_fp4.hpp:340
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t &x, float scale)
Definition pk_fp4.hpp:322
pk_float4_e2m1_t pk_fp4_t
Definition pk_fp4.hpp:151
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t &x, float scale)
Definition pk_fp4.hpp:304
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t &x, float scale)
Definition pk_fp4.hpp:331
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:366
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
Definition pk_fp4.hpp:287
float fp32x2_t
Definition pk_fp4.hpp:22
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:370
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:362
float fp32_t
Definition pk_fp4.hpp:21
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t &x, float scale)
Definition pk_fp4.hpp:313
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
Definition pk_fp4.hpp:35
bfloat16_t e[2]
Definition pk_fp4.hpp:36
Definition pk_fp4.hpp:31
_Float16 e[2]
Definition pk_fp4.hpp:32
Definition pk_fp4.hpp:39
float e[2]
Definition pk_fp4.hpp:40
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t lowest()
Definition pk_fp4.hpp:180
static CK_TILE_HOST_DEVICE constexpr bool has_inf()
Definition pk_fp4.hpp:186
static constexpr pk_fp4_raw_t binary_min_normal
Definition pk_fp4.hpp:172
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t round_error()
Definition pk_fp4.hpp:182
static constexpr pk_fp4_raw_t binary_max_subnorm
Definition pk_fp4.hpp:176
static constexpr pk_fp4_raw_t binary_min_subnorm
Definition pk_fp4.hpp:175
static constexpr pk_fp4_raw_t binary_zero
Definition pk_fp4.hpp:177
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t denorm_min()
Definition pk_fp4.hpp:184
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t epsilon()
Definition pk_fp4.hpp:181
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t signaling_NaN()
Definition pk_fp4.hpp:192
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t min()
Definition pk_fp4.hpp:178
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t max()
Definition pk_fp4.hpp:179
static constexpr pk_fp4_raw_t binary_lowest_normal
Definition pk_fp4.hpp:174
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t zero()
Definition pk_fp4.hpp:183
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t quiet_NaN()
Definition pk_fp4.hpp:190
static constexpr pk_fp4_raw_t binary_max_normal
Definition pk_fp4.hpp:173
static CK_TILE_HOST_DEVICE constexpr pk_fp4_t infinity()
Definition pk_fp4.hpp:188
static constexpr int bias
Definition pk_fp4.hpp:161
pk_fp4_raw_t bitwise_type
Definition pk_fp4.hpp:157
static constexpr int mant
Definition pk_fp4.hpp:160
static constexpr int exp
Definition pk_fp4.hpp:159
static constexpr int PackedSize
Definition pk_fp4.hpp:162
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/numeric/numeric.hpp:18
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
Definition pk_fp4.hpp:76
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale=1.f) const
Definition pk_fp4.hpp:384
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(float init, float scale=1.f)
Definition pk_fp4.hpp:87
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number< I >) const
Definition pk_fp4.hpp:110
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale=1.f) const
Definition pk_fp4.hpp:394
CK_TILE_HOST_DEVICE constexpr type get() const
Definition pk_fp4.hpp:93
CK_TILE_HOST_DEVICE static constexpr type _pack(const type x0, const type x1)
Definition pk_fp4.hpp:122
CK_TILE_HOST_DEVICE constexpr type & get()
Definition pk_fp4.hpp:92
uint8_t raw_type
Definition pk_fp4.hpp:78
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t()
Definition pk_fp4.hpp:82
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale=1.f) const
Definition pk_fp4.hpp:267
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init)
Definition pk_fp4.hpp:84
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale=1.f) const
Definition pk_fp4.hpp:402
type data
Definition pk_fp4.hpp:80
raw_type type
Definition pk_fp4.hpp:79
CK_TILE_HOST_DEVICE static constexpr pk_float4_e2m1_t pack(const pk_float4_e2m1_t &x0, const pk_float4_e2m1_t &x1)
Definition pk_fp4.hpp:114
CK_TILE_HOST_DEVICE constexpr type _unpack(number< I >) const
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale=1.f) const
Definition pk_fp4.hpp:276
CK_TILE_HOST_DEVICE constexpr float to_float(float scale=1.f) const
Definition pk_fp4.hpp:376
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition tile/core/numeric/numeric.hpp:106