generic_permute_kernel.hpp Source File

generic_permute_kernel.hpp Source File

Composable Kernel: generic_permute_kernel.hpp Source File
generic_permute_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
9
10namespace ck_tile {
11
12/* Independent host-side arguments for GenericPermuteKernel (plain struct, no template).
 *
 * Only the first `rank` entries of `shape` and `perm` are read by
 * GenericPermuteKernel::TotalElements / MakeKargs; the remaining entries are ignored.
 * NOTE(review): rank <= kMaxRanks, and perm being a valid permutation of [0, rank),
 * are assumed by MakeKargs but not checked; violating either reads/writes out of bounds
 * or silently produces a wrong result.
13 */
15{
16 static constexpr index_t kMaxRanks = 8; // TODO: hardcoded; maximum number of dims that can be described
17
18 const void* p_src; // source tensor buffer; presumably a device pointer (it is dereferenced in the kernel) - confirm
19 void* p_dst; // destination buffer, written as a packed tensor in permuted dim order
21 index_t shape[kMaxRanks]; // input shape; MakeKargs treats the input as packed row-major (last dim contiguous)
22 index_t perm[kMaxRanks]; // permute index: output dim i takes input dim perm[i] (torch.permute semantics)
23};
24
25/*
26Generic permute: simulates torch.permute(...).contiguous(). For example, the following
   torch sequence is one permutation this kernel can express (6 dims after the view):
27x_ = x.view(x.shape[0],
28 x.shape[1]//16, 16,
29 x.shape[2]//32, 4, 8)
30x_ = x_.permute(0,1,3,4,2,5)
31x_ = x_.contiguous()
32x_ = x_.view(x.shape[0], x.shape[1], x.shape[2])
33
34This kernel is not meant to be high-performance (it is just adequate). It provides functional
35support for permutations of up to kMaxRanks dims with a single kernel, copying one element
   per thread.
36
37*/
38template <typename Problem_>
40{
42
44 static constexpr index_t kBlockSize = Problem::kBlockSize; // threads per block (see BlockSize())
45 static constexpr index_t kMaxRanks = Problem::kMaxRanks; // MakeKargs pads every permutation up to this many dims
46 static constexpr bool KeepLastDim = Problem::KeepLastDim; // NOTE(review): appears unused by this kernel - confirm
47
// Kernel arguments: built on the host by MakeKargs(), consumed by operator().
48 struct __attribute__((packed)) Kargs
49 {
50 const void* p_src;
51 void* p_dst;
52 // index_t rank;
54 index_t perm_length[kMaxRanks]; // length of output dim j (= input length of dim perm[j])
55 index_t perm_stride[kMaxRanks]; // input stride (in elements) of dim perm[j], i.e. the source step along output dim j
56 };
57
// Number of elements: product of the first h.rank entries of h.shape (1 when rank == 0).
// NOTE(review): accumulated in index_t (32-bit); tensors with more than INT32_MAX elements overflow.
59 {
60 index_t n = 1;
61 for(auto i = 0; i < h.rank; i++)
62 {
63 n *= h.shape[i];
64 }
65 return n;
66 }
67
// Builds Kargs from host arguments. Steps: pad shape/perm up to kMaxRanks, compute packed
// row-major input strides, then gather the length and stride for each output dim.
69 {
70 Kargs a;
71 a.p_src = h.p_src;
72 a.p_dst = h.p_dst;
73
74 // assert rank <= kMaxRanks -- NOTE(review): not actually checked. rank > kMaxRanks writes past x_shape/perm, and perm[j] >= kMaxRanks reads past x_shape/x_stride below.
75 index_t i = 0;
76
77 index_t perm[kMaxRanks];
78 index_t x_shape[kMaxRanks];
79 index_t x_stride[kMaxRanks];
80 // index_t perm_length[kMaxRanks];
81
// copy the caller-provided dims
82 for(; i < h.rank; i++)
83 {
84 x_shape[i] = h.shape[i];
85 perm[i] = h.perm[i];
86 }
// pad the remaining dims: length 1 with an identity mapping, so they do not change the result
87 for(; i < kMaxRanks; i++)
88 {
89 x_shape[i] = 1;
90 perm[i] = i; // identity: maps onto a padded length-1 dim
91 }
92
// packed row-major strides of the (padded) input: the last dim is contiguous
93 index_t stride = 1;
94 for(index_t j = kMaxRanks - 1; j >= 0; j--)
95 {
96 x_stride[j] = stride;
97 stride *= x_shape[j];
98 }
99
// output dim j walks input dim perm[j]: take that dim's length and input stride
100 for(index_t j = 0; j < kMaxRanks; j++)
101 {
102 a.perm_length[j] = x_shape[perm[j]];
103 a.perm_stride[j] = x_stride[perm[j]];
104 }
105
106 a.num_elements = TotalElements(h);
107 return a;
108 }
109
// 1-D grid with one thread per element: ceil(total / BlockSize()) blocks.
// NOTE(review): total == 0 gives a zero-sized grid; confirm callers never launch in that case.
111 {
112 auto total = TotalElements(h);
113 auto grids = dim3((total + BlockSize() - 1) / BlockSize());
114 // printf("### total:%d, grids:%dx%dx%d\n", total, );
115 return grids;
116 }
117
// Threads per block, taken from Problem::kBlockSize.
118 CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
119
// Device entry point: each thread copies one element. The linear thread index `id`
// enumerates the output in packed order. The same linear index is decoded against the
// permuted source view to locate the input element.
121 {
// NOTE(review): computed in 32-bit index_t; assumes num_elements fits in index_t.
122 index_t id = blockIdx.x * BlockSize() + threadIdx.x;
123
// tail threads of the last block have nothing to copy
124 if(id >= kargs.num_elements)
125 return;
126
// lift the runtime length/stride arrays into tuples for the tensor-view API
127 const auto perm_length =
128 generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
129 const auto perm_stride =
130 generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
131
132 const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
133 DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
134
// Source viewed in output-dim order (lengths perm_length, strides perm_stride). All dims
// are then merged so that a single linear index addresses it.
136 p_src, perm_length, perm_stride, number<1>{}, number<1>{});
137
138 const auto src_view = transform_tensor_view(
139 src_view_0,
140 make_tuple(make_merge_transform(perm_length)),
143
// Destination is packed over perm_length. Merging its dims makes linear index `id` map
// directly onto the contiguous output order.
145 p_dst, perm_length, number<1>{});
146
147 auto dst_view = transform_tensor_view(
148 dst_view_0,
149 make_tuple(make_merge_transform(perm_length)),
152
153 // TODO: hard code to vector 1 (one element per load/store, no vectorization)
154 using vector_t = thread_buffer<DataType, 1>;
155
// The same linear index is used for both views; only src_coord resolves through the permuted strides.
156 const auto src_coord =
157 make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
158 const auto dst_coord =
159 make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
160
161 // printf("src id:%d, os:%d\n", id, src_coord.get_offset());
162 // printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
163
164 const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
165 dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
166 }
167};
168
169} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition tensor_coordinate.hpp:60
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view_packed(DataType *__restrict__ p, const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tensor_view.hpp:494
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition generic_permute_kernel.hpp:49
const void * p_src
Definition generic_permute_kernel.hpp:50
index_t perm_stride[kMaxRanks]
Definition generic_permute_kernel.hpp:55
index_t perm_length[kMaxRanks]
Definition generic_permute_kernel.hpp:54
index_t num_elements
Definition generic_permute_kernel.hpp:53
void * p_dst
Definition generic_permute_kernel.hpp:51
Definition generic_permute_kernel.hpp:15
const void * p_src
Definition generic_permute_kernel.hpp:18
index_t perm[kMaxRanks]
Definition generic_permute_kernel.hpp:22
void * p_dst
Definition generic_permute_kernel.hpp:19
index_t shape[kMaxRanks]
Definition generic_permute_kernel.hpp:21
static constexpr index_t kMaxRanks
Definition generic_permute_kernel.hpp:16
index_t rank
Definition generic_permute_kernel.hpp:20
Definition generic_permute_kernel.hpp:40
static constexpr index_t kBlockSize
Definition generic_permute_kernel.hpp:44
static constexpr index_t kMaxRanks
Definition generic_permute_kernel.hpp:45
static CK_TILE_HOST constexpr Kargs MakeKargs(const GenericPermuteHostArgs &h)
Definition generic_permute_kernel.hpp:68
remove_cvref_t< typename Problem::DataType > DataType
Definition generic_permute_kernel.hpp:43
static CK_TILE_HOST constexpr index_t TotalElements(const GenericPermuteHostArgs &h)
Definition generic_permute_kernel.hpp:58
static constexpr bool KeepLastDim
Definition generic_permute_kernel.hpp:46
static CK_TILE_HOST constexpr auto GridSize(GenericPermuteHostArgs h)
Definition generic_permute_kernel.hpp:110
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition generic_permute_kernel.hpp:120
ck_tile::remove_cvref_t< Problem_ > Problem
Definition generic_permute_kernel.hpp:41
static CK_TILE_HOST_DEVICE constexpr auto BlockSize()
Definition generic_permute_kernel.hpp:118
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/debug.hpp:67