MatrixProductMMA.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#pragma GCC target("cpu=power10,htm")

#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#endif
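// Some compiler versions only ship the older __builtin_mma_assemble_pair
// spelling; the #define above maps the VSX name onto it so the LLVM code
// path further below still builds there.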

namespace Eigen {

namespace internal {

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}
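// A __vector_quad is POWER10's 512-bit MMA accumulator (four VSX vectors);
// xxsetaccz clears all of it, i.e. the whole 4x4 f32 (or 4x2 f64) tile.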

template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);

  bscale<Packet, 4>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, 0, tRes);
}
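// The sequence above is the standard GEMM tile update: the accumulator is
// spilled into four vectors, the destination block is loaded, and bscale
// folds in alpha (roughly tRes += alpha * result) before the store.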

template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
  data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
}
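// Same update as the real case, except the real and imaginary accumulators
// are scaled by the complex alpha (bscalec) and then re-interleaved into
// complex packets (bcouple) before being written back in two halves.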

// Defaults to float32. Since Eigen still supports C++03, we can't use default template arguments.
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
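// Each xvf32ger{pp,np} instruction performs a rank-1 (outer-product) update
// of the 4x4 float accumulator: roughly acc[i][j] += a[i] * b[j] for "pp",
// and acc[i][j] -= a[i] * b[j] for "np" (negative multiply, positive accumulate).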

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}
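// The double-precision GER variants take a __vector_pair (four doubles) as
// one operand and a Packet2d (two doubles) as the other, updating a 4x2 tile
// of doubles in the accumulator.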

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation; this overload is never selected at run time.
}

template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
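// The dispatch above is complex multiplication split into GER calls: with
// lhs = a + bi and rhs = c + di, (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
// Conjugation negates b and/or d, which is what the NegativeAccumulate flags
// (ConjugateLhs, ConjugateRhs, ConjugateLhs == ConjugateRhs) encode.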

// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>(rhs);
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}
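// On GCC the pair is loaded with a single lxvp instruction (32 contiguous
// bytes into an even/odd VSX register pair); on LLVM it is assembled from
// two 16-byte loads, passed in the reversed order the builtin expects.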

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation; the float path never uses a __vector_pair.
}

// PEEL_MMA loop unrolling factor: the number of k (depth) iterations
// processed per pass of the peeled main loop.
#define PEEL_MMA 7

#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
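
// For illustration (not a verbatim expansion): with Scalar == float and
// unroll_factor == 2, one MICRO_MMA_ONE step behaves like
//   RhsPacket rhsV0 = ploadRhs(rhs_ptr);
//   lhsV0 = ploadLhs(lhs_ptr0); lhs_ptr0 += accCols;
//   lhsV1 = ploadLhs(lhs_ptr1); lhs_ptr1 += accCols;
//   pgerMMA(&accZero0, rhsV0, lhsV0);
//   pgerMMA(&accZero1, rhsV0, lhsV1);
//   rhs_ptr += accRows;
// i.e. one rhs packet is reused against up to eight lhs tiles per k step.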

template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}

template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7
  while(row + MAX_MMA_UNROLL*accCols <= rows) {
    gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
    case 7:
      gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6:
      gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5:
      gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4:
      gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3:
      gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2:
      gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1:
      gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
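
// Overall blocking: gemmMMA walks the result in column panels of accRows;
// within a panel, gemmMMA_cols consumes rows in blocks of up to
// MAX_MMA_UNROLL*accCols, a switch handles the last partial block, and
// gemm_extra_row / gemm_extra_cols mop up the remaining rows and columns.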

#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop unrolling factor: the number of k (depth) iterations
// processed per pass of the peeled complex main loop.
#define PEEL_COMPLEX_MMA 3

#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
    lhs_ptr_real##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
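
// The complex micro kernel below mirrors the real one with a 4-way unroll
// and paired real/imaginary accumulators. The packed lhs keeps imaginary
// parts imag_delta (= accCols*strideA) elements after the real parts, and
// the rhs imaginary panel starts accRows*strideB past the real one.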

template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}

template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4:
      gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

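  // The packed complex blocks keep real and imaginary parts in separate
  // runs, so they can be reinterpreted as arrays of the underlying real Scalar.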
  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}

#undef accColsC
#undef advanceRows
#undef advanceCols

#pragma GCC reset_options
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H