MatrixProduct.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#include "MatrixProductCommon.h"

// Since LLVM doesn't support dynamic dispatch, force either MMA or VSX to be used unconditionally.
#if EIGEN_COMP_LLVM
#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY)
#ifdef __MMA__
#define EIGEN_ALTIVEC_MMA_ONLY
#else
#define EIGEN_ALTIVEC_DISABLE_MMA
#endif
#endif
#endif

#ifdef __has_builtin
#if __has_builtin(__builtin_mma_assemble_acc)
  #define ALTIVEC_MMA_SUPPORT
#endif
#endif

#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
  #include "MatrixProductMMA.h"
#endif

/**************************************************************************************************
 * TODO                                                                                           *
 *  - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized but could be). *
 *  - Check the possibility of transposing as GETREAL and GETIMAG when needed.                    *
 **************************************************************************************************/
namespace Eigen {

namespace internal {

/**************************
 * Constants and typedefs *
 **************************/
template<typename Scalar>
struct quad_traits
{
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype,4> type;
  typedef vectortype rhstype;
  enum
  {
    vectorsize = packet_traits<Scalar>::size,
    size = 4,
    rows = 4
  };
};

template<>
struct quad_traits<double>
{
  typedef Packet2d vectortype;
  typedef PacketBlock<vectortype,4> type;
  typedef PacketBlock<Packet2d,2> rhstype;
  enum
  {
    vectorsize = packet_traits<double>::size,
    size = 2,
    rows = 4
  };
};

// MatrixProduct decomposes complex vectors into a separate real vector and a separate imaginary
// vector, which turned out to be faster than Eigen's usual approach of keeping real/imaginary
// pairs interleaved in a single vector. The constants below are the permute masks used to
// convert between Eigen's layout and MatrixProduct's.

const static Packet16uc p16uc_GETREAL32 = {  0,  1,  2,  3,
                                             8,  9, 10, 11,
                                            16, 17, 18, 19,
                                            24, 25, 26, 27};

const static Packet16uc p16uc_GETIMAG32 = {  4,  5,  6,  7,
                                            12, 13, 14, 15,
                                            20, 21, 22, 23,
                                            28, 29, 30, 31};
const static Packet16uc p16uc_GETREAL64 = {  0,  1,  2,  3,  4,  5,  6,  7,
                                            16, 17, 18, 19, 20, 21, 22, 23};

//[a,ai],[b,bi] = [ai,bi]
const static Packet16uc p16uc_GETIMAG64 = {  8,  9, 10, 11, 12, 13, 14, 15,
                                            24, 25, 26, 27, 28, 29, 30, 31};
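
// Illustrative sketch (not part of the original source): with VSX vec_perm,
// byte indices 0-15 select from the first operand and 16-31 from the second,
// so p16uc_GETREAL32/p16uc_GETIMAG32 deinterleave two complex<float> packets:
//
//   Packet4f ab = {a0.r, a0.i, a1.r, a1.i};    // e.g. cblock.packet[0].v
//   Packet4f cd = {a2.r, a2.i, a3.r, a3.i};    // e.g. cblock.packet[4].v
//   vec_perm(ab, cd, p16uc_GETREAL32);         // -> {a0.r, a1.r, a2.r, a3.r}
//   vec_perm(ab, cd, p16uc_GETIMAG32);         // -> {a0.i, a1.i, a2.i, a3.i}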

/*********************************************
 * Single precision real and complex packing *
 *********************************************/

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
{
  std::complex<Scalar> v;
  if(i < j)
  {
    v.real( dt(j,i).real());
    v.imag(-dt(j,i).imag());
  } else if(i > j)
  {
    v.real( dt(i,j).real());
    v.imag( dt(i,j).imag());
  } else {
    v.real( dt(i,j).real());
    v.imag((Scalar)0.0);
  }
  return v;
}
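
// Illustrative note (not part of the original source): getAdjointVal
// materializes the full selfadjoint matrix from the stored triangle.
// Reading across the diagonal (i < j) returns the conjugate of the mirrored
// element, and the diagonal itself is forced to be purely real:
//
//   getAdjointVal(0, 1, dt)  ==  std::conj(dt(1, 0))
//   getAdjointVal(2, 2, dt)  ==  std::complex<Scalar>(dt(2, 2).real(), 0)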

template<typename Scalar, typename Index, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = N*quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * rows;
  Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= cols; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = k2; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);

        blockBf[rir + k] = v.real();
        blockBf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  for(; j < cols; j++)
  {
    rii = rir + rows;

    for(Index i = k2; i < depth; i++)
    {
      std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j, rhs);

      blockBf[rir] = v.real();
      blockBf[rii] = v.imag();

      rir += 1;
      rii += 1;
    }

    rir += rows;
  }
}

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * depth;
  Scalar* blockAf = (Scalar *)(blockA);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = 0; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);

        blockAf[rir + k] = v.real();
        blockAf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  if (j < rows)
  {
    rii = rir + ((rows - j) * depth);

    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);

        blockAf[rir] = v.real();
        blockAf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + N*vectorSize <= cols; j+=N*vectorSize)
  {
    Index i = k2;
    for(; i < depth; i++)
    {
      for(Index k = 0; k < N*vectorSize; k++)
      {
        if(i <= j+k)
          blockB[ri + k] = rhs(j+k, i);
        else
          blockB[ri + k] = rhs(i, j+k);
      }
      ri += N*vectorSize;
    }
  }

  for(; j < cols; j++)
  {
    for(Index i = k2; i < depth; i++)
    {
      if(j <= i)
        blockB[ri] = rhs(i, j);
      else
        blockB[ri] = rhs(j, i);
      ri += 1;
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    Index i = 0;

    for(; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        if(i <= j+k)
          blockA[ri + k] = lhs(j+k, i);
        else
          blockA[ri + k] = lhs(i, j+k);
      }
      ri += vectorSize;
    }
  }

  if (j < rows)
  {
    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        if(i <= k)
          blockA[ri] = lhs(k, i);
        else
          blockA[ri] = lhs(i, k);
        ri += 1;
      }
    }
  }
}

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
{
  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack std::complex<float64> ***********

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
{
  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack float32 ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<float, Index, nr, StorageOrder>
{
  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack float64 ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<double, Index, nr, StorageOrder>
{
  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};


template<typename Scalar, typename Packet, typename Index, int N>
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
{
  const Index size = 16 / sizeof(Scalar);
  pstore<Scalar>(to + (0 * size), block.packet[0]);
  pstore<Scalar>(to + (1 * size), block.packet[1]);
  if (N > 2) {
    pstore<Scalar>(to + (2 * size), block.packet[2]);
  }
  if (N > 3) {
    pstore<Scalar>(to + (3 * size), block.packet[3]);
  }
}
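
// Illustrative sketch (not part of the original source): since VSX vectors
// are 16 bytes, `size` equals the packet width, so storeBlock lays the N
// packets out contiguously. For example, assuming a 16-float aligned buffer:
//
//   PacketBlock<Packet4f,4> block = ...;
//   storeBlock<float, Packet4f, Index, 4>(buf, block);
//   // buf[0..3] = packet[0], buf[4..7] = packet[1], ..., buf[12..15] = packet[3]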

// General template for lhs & rhs complex packing.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> blockr, blocki;
        PacketBlock<PacketC,8> cblock;

        if (UseLhs) {
          bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
        } else {
          bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
        blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
        blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
        blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
        blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
        blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
        blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
          blocki.packet[2] = -blocki.packet[2];
          blocki.packet[3] = -blocki.packet[3];
        }

        if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
        {
          ptranspose(blockr);
          ptranspose(blocki);
        }

        storeBlock<Scalar, Packet, Index, 4>(blockAt + rir, blockr);
        storeBlock<Scalar, Packet, Index, 4>(blockAt + rii, blocki);

        rir += 4*vectorSize;
        rii += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
        {
          if (UseLhs) {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
          } else {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
          }
        } else {
          if (UseLhs) {
            cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i));
            cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i));
          } else {
            cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1));
            cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3));
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (!UseLhs)
    {
      if(PanelMode) rir -= (offset*(vectorSize - 1));

      for(; j < rows; j++)
      {
        rii = rir + ((PanelMode) ? stride : depth);

        for(Index i = 0; i < depth; i++)
        {
          blockAt[rir] = lhs(i, j).real();

          if(Conjugate)
            blockAt[rii] = -lhs(i, j).imag();
          else
            blockAt[rii] = lhs(i, j).imag();

          rir += 1;
          rii += 1;
        }

        rir += ((PanelMode) ? (2*stride - depth) : depth);
      }
    } else {
      if (j < rows)
      {
        if(PanelMode) rir += (offset*(rows - j - vectorSize));
        rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

        for(Index i = 0; i < depth; i++)
        {
          Index k = j;
          for(; k < rows; k++)
          {
            blockAt[rir] = lhs(k, i).real();

            if(Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] = lhs(k, i).imag();

            rir += 1;
            rii += 1;
          }
        }
      }
    }
  }
};
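
// Illustrative note (not part of the original source): for each vectorSize-wide
// panel of rows, dhs_cpack writes a plane of real parts starting at rir and the
// matching plane of imaginary parts vectorDelta scalars later at rii, so the
// packed block stores split real/imaginary planes per panel instead of
// interleaved pairs. The complex GEMM kernels consume exactly this layout.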

// General template for lhs & rhs packing.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack{
  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> block;

        if (UseLhs) {
          bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, j, i);
        } else {
          bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, i, j);
        }
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          ptranspose(block);
        }

        storeBlock<Scalar, Packet, Index, 4>(blockA + ri, block);

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          if (UseLhs) {
            blockA[ri+0] = lhs(j+0, i);
            blockA[ri+1] = lhs(j+1, i);
            blockA[ri+2] = lhs(j+2, i);
            blockA[ri+3] = lhs(j+3, i);
          } else {
            blockA[ri+0] = lhs(i, j+0);
            blockA[ri+1] = lhs(i, j+1);
            blockA[ri+2] = lhs(i, j+2);
            blockA[ri+3] = lhs(i, j+3);
          }
        } else {
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs.template loadPacket<Packet>(j, i);
          } else {
            lhsV = lhs.template loadPacket<Packet>(i, j);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (!UseLhs)
    {
      if(PanelMode) ri += offset;

      for(; j < rows; j++)
      {
        for(Index i = 0; i < depth; i++)
        {
          blockA[ri] = lhs(i, j);
          ri += 1;
        }

        if(PanelMode) ri += stride - depth;
      }
    } else {
      if (j < rows)
      {
        if(PanelMode) ri += offset*(rows - j);

        for(Index i = 0; i < depth; i++)
        {
          Index k = j;
          for(; k < rows; k++)
          {
            blockA[ri] = lhs(k, i);
            ri += 1;
          }
        }
      }
    }
  }
};

// General template for lhs packing, float64 specialization.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,2> block;
        if(StorageOrder == RowMajor)
        {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);

          ptranspose(block);
        } else {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
        }

        storeBlock<double, Packet2d, Index, 2>(blockA + ri, block);

        ri += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == RowMajor)
        {
          blockA[ri+0] = lhs(j+0, i);
          blockA[ri+1] = lhs(j+1, i);
        } else {
          Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};

// General template for rhs packing, float64 specialization.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += offset*(2*vectorSize);

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,4> block;
        if(StorageOrder == ColMajor)
        {
          PacketBlock<Packet2d,2> block1, block2;
          block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
          block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
          block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
          block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);

          ptranspose(block1);
          ptranspose(block2);

          pstore<double>(blockB + ri    , block1.packet[0]);
          pstore<double>(blockB + ri + 2, block2.packet[0]);
          pstore<double>(blockB + ri + 4, block1.packet[1]);
          pstore<double>(blockB + ri + 6, block2.packet[1]);
        } else {
          block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
          block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
          block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
          block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]

          storeBlock<double, Packet2d, Index, 4>(blockB + ri, block);
        }

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == ColMajor)
        {
          blockB[ri+0] = rhs(i, j+0);
          blockB[ri+1] = rhs(i, j+1);

          ri += vectorSize;

          blockB[ri+0] = rhs(i, j+2);
          blockB[ri+1] = rhs(i, j+3);
        } else {
          Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
    }

    if(PanelMode) ri += offset;

    for(; j < cols; j++)
    {
      for(Index i = 0; i < depth; i++)
      {
        blockB[ri] = rhs(i, j);
        ri += 1;
      }

      if(PanelMode) ri += stride - depth;
    }
  }
};

// General template for lhs complex packing, float64 specialization.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    double* blockAt = reinterpret_cast<double *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,2> blockr, blocki;
        PacketBlock<PacketC,4> cblock;

        if(StorageOrder == ColMajor)
        {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
        } else {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
        }

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index, 2>(blockAt + rir, blockr);
        storeBlock<double, Packet, Index, 2>(blockAt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
        cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<double>(blockAt + rir, blockr.packet[0]);
        pstore<double>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < rows)
    {
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockAt[rir] = lhs(k, i).real();

          if(Conjugate)
            blockAt[rii] = -lhs(k, i).imag();
          else
            blockAt[rii] = lhs(k, i).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

// General template for rhs complex packing, float64 specialization.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
    double* blockBt = reinterpret_cast<double *>(blockB);
    Index j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i < depth; i++)
      {
        PacketBlock<PacketC,4> cblock;
        PacketBlock<Packet,2> blockr, blocki;

        bload<DataMapper, PacketC, Index, 2, ColMajor, false, 4>(cblock, rhs, i, j);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
        blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index, 2>(blockBt + rir, blockr);
        storeBlock<double, Packet, Index, 2>(blockBt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }

      rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if(PanelMode) rir -= (offset*(2*vectorSize - 1));

    for(; j < cols; j++)
    {
      rii = rir + ((PanelMode) ? stride : depth);

      for(Index i = 0; i < depth; i++)
      {
        blockBt[rir] = rhs(i, j).real();

        if(Conjugate)
          blockBt[rii] = -rhs(i, j).imag();
        else
          blockBt[rii] = rhs(i, j).imag();

        rir += 1;
        rii += 1;
      }

      rir += ((PanelMode) ? (2*stride - depth) : depth);
    }
  }
};

/**************
 * GEMM utils *
 **************/

// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex GEMM).
template<typename Packet, bool NegativeAccumulate, int N>
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,N>* acc, const Packet& lhsV, const Packet* rhsV)
{
  if(NegativeAccumulate)
  {
    acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
    if (N > 1) {
      acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
    }
    if (N > 2) {
      acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
    }
    if (N > 3) {
      acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
    }
  } else {
    acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
    if (N > 1) {
      acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
    }
    if (N > 2) {
      acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
    }
    if (N > 3) {
      acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
    }
  }
}
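
// Illustrative note (not part of the original source): this is the scalar
// equivalent of the rank-1 update above, where lhsV holds a column fragment
// and each rhsV[j] is one broadcast rhs coefficient
// (vec_madd(a, b, c) computes a*b + c; vec_nmsub(a, b, c) computes c - a*b):
//
//   for (int j = 0; j < N; j++)
//     for (int i = 0; i < packet_size; i++)
//       acc[j][i] += (NegativeAccumulate ? -1 : 1) * lhs[i] * rhs[j];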

template<int N, typename Scalar, typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
  Packet lhsV = pload<Packet>(lhs);

  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}

template<typename Scalar, typename Packet, typename Index, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV)
{
#ifdef _ARCH_PWR9
  lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
#else
  Index i = 0;
  do {
    lhsV[i] = lhs[i];
  } while (++i < remaining_rows);
#endif
}

template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
  Packet lhsV;
  loadPacketRemaining<Scalar, Packet, Index, remaining_rows>(lhs, lhsV);

  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}

// 512-bit rank-1 update of a complex acc. It takes decoupled accumulators as inputs, and also takes care of the mixed-type cases real * complex and complex * real.
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
  pger_common<Packet, false, N>(accReal, lhsV, rhsV);
  if(LhsIsReal)
  {
    pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
      pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
  }
}

template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
  Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
  Packet lhsVi;
  if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
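
// Illustrative note (not part of the original source): pgerc_common applies
// the complex product identity on the split representation,
//
//   (lr + i*li) * (rr + i*ri) = (lr*rr - li*ri) + i*(lr*ri + li*rr),
//
// so accReal accumulates lr*rr and (negatively) li*ri, while accImag
// accumulates lr*ri and li*rr. The Conjugate/IsReal flags simply flip or
// drop the terms whose sign or presence changes under conjugation.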

template<typename Scalar, typename Packet, typename Index, bool LhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi)
{
#ifdef _ARCH_PWR9
  lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
  if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar));
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#else
  Index i = 0;
  do {
    lhsV[i] = lhs_ptr[i];
    if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i];
  } while (++i < remaining_rows);
  if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#endif
}

template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
  Packet lhsV, lhsVi;
  loadPacketRemaining<Scalar, Packet, Index, LhsIsReal, remaining_rows>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
{
  return ploadu<Packet>(lhs);
}

// Zero the accumulator PacketBlock.
template<typename Scalar, typename Packet, int N>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
{
  acc.packet[0] = pset1<Packet>((Scalar)0);
  if (N > 1) {
    acc.packet[1] = pset1<Packet>((Scalar)0);
  }
  if (N > 2) {
    acc.packet[2] = pset1<Packet>((Scalar)0);
  }
  if (N > 3) {
    acc.packet[3] = pset1<Packet>((Scalar)0);
  }
}

// Scale the PacketBlock vectors by alpha.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
  if (N > 1) {
    acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
  }
  if (N > 2) {
    acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
  }
  if (N > 3) {
    acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
  }
}

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
  if (N > 1) {
    acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
  }
  if (N > 2) {
    acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
  }
  if (N > 3) {
    acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
  }
}

// Complex version of PacketBlock scaling.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
{
  bscalec_common<Packet, N>(cReal, aReal, bReal);

  bscalec_common<Packet, N>(cImag, aImag, bReal);

  pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);

  pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
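
// Illustrative note (not part of the original source): bscalec computes
// c = a * b on the split complex accumulators, i.e.
//
//   cReal = aReal*bReal - aImag*bImag;
//   cImag = aImag*bReal + aReal*bImag;
//
// where b holds the broadcast complex alpha of the GEMM call.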

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
{
  acc.packet[0] = pand(acc.packet[0], pMask);
  if (N > 1) {
    acc.packet[1] = pand(acc.packet[1], pMask);
  }
  if (N > 2) {
    acc.packet[2] = pand(acc.packet[2], pMask);
  }
  if (N > 3) {
    acc.packet[3] = pand(acc.packet[3], pMask);
  }
}

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
{
  band<Packet, N>(aReal, pMask);
  band<Packet, N>(aImag, pMask);

  bscalec<Packet,N>(aReal, aImag, bReal, bImag, cReal, cImag);
}

// Load a PacketBlock; the N parameter makes tuning GEMM easier, so we can add more accumulators as needed.
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
{
  if (StorageOrder == RowMajor) {
    acc.packet[0] = res.template loadPacket<Packet>(row + 0, col);
    if (N > 1) {
      acc.packet[1] = res.template loadPacket<Packet>(row + 1, col);
    }
    if (N > 2) {
      acc.packet[2] = res.template loadPacket<Packet>(row + 2, col);
    }
    if (N > 3) {
      acc.packet[3] = res.template loadPacket<Packet>(row + 3, col);
    }
    if (Complex) {
      acc.packet[0+N] = res.template loadPacket<Packet>(row + 0, col + accCols);
      if (N > 1) {
        acc.packet[1+N] = res.template loadPacket<Packet>(row + 1, col + accCols);
      }
      if (N > 2) {
        acc.packet[2+N] = res.template loadPacket<Packet>(row + 2, col + accCols);
      }
      if (N > 3) {
        acc.packet[3+N] = res.template loadPacket<Packet>(row + 3, col + accCols);
      }
    }
  } else {
    acc.packet[0] = res.template loadPacket<Packet>(row, col + 0);
    if (N > 1) {
      acc.packet[1] = res.template loadPacket<Packet>(row, col + 1);
    }
    if (N > 2) {
      acc.packet[2] = res.template loadPacket<Packet>(row, col + 2);
    }
    if (N > 3) {
      acc.packet[3] = res.template loadPacket<Packet>(row, col + 3);
    }
    if (Complex) {
      acc.packet[0+N] = res.template loadPacket<Packet>(row + accCols, col + 0);
      if (N > 1) {
        acc.packet[1+N] = res.template loadPacket<Packet>(row + accCols, col + 1);
      }
      if (N > 2) {
        acc.packet[2+N] = res.template loadPacket<Packet>(row + accCols, col + 2);
      }
      if (N > 3) {
        acc.packet[3+N] = res.template loadPacket<Packet>(row + accCols, col + 3);
      }
    }
  }
}

const static Packet4i mask41 = { -1,  0,  0,  0 };
const static Packet4i mask42 = { -1, -1,  0,  0 };
const static Packet4i mask43 = { -1, -1, -1,  0 };

const static Packet2l mask21 = { -1, 0 };

template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows)
{
  if (remaining_rows == 0) {
    return pset1<Packet>(float(0.0)); // Not used
  } else {
    switch (remaining_rows) {
      case 1: return Packet(mask41);
      case 2: return Packet(mask42);
      default: return Packet(mask43);
    }
  }
}

template<>
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
{
  if (remaining_rows == 0) {
    return pset1<Packet2d>(double(0.0)); // Not used
  } else {
    return Packet2d(mask21);
  }
}
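
// Illustrative note (not part of the original source): the masks are all-ones
// integer lanes reinterpreted as float/double lanes, so band() can zero the
// lanes past the row fringe. E.g. with remaining_rows == 2 for floats:
//
//   Packet4f mask = bmask<Packet4f>(2);   // bitwise {~0, ~0, 0, 0}
//   band<Packet4f, 4>(acc, mask);         // lanes 2 and 3 of acc become 0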

template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
{
  band<Packet, N>(accZ, pMask);

  bscale<Packet, N>(acc, accZ, pAlpha);
}

template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a,
                Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
  a0 = pset1<Packet>(a[0]);
  if (N > 1) {
    a1 = pset1<Packet>(a[1]);
  } else {
    EIGEN_UNUSED_VARIABLE(a1);
  }
  if (N > 2) {
    a2 = pset1<Packet>(a[2]);
  } else {
    EIGEN_UNUSED_VARIABLE(a2);
  }
  if (N > 3) {
    a3 = pset1<Packet>(a[3]);
  } else {
    EIGEN_UNUSED_VARIABLE(a3);
  }
}

template<>
EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet4f,4>(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
  pbroadcast4<Packet4f>(a, a0, a1, a2, a3);
}

template<>
EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet2d,4>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
  a1 = pload<Packet2d>(a);
  a3 = pload<Packet2d>(a + 2);
  a0 = vec_splat(a1, 0);
  a1 = vec_splat(a1, 1);
  a2 = vec_splat(a3, 0);
  a3 = vec_splat(a3, 1);
}

template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
pbroadcastN(const __UNPACK_TYPE__(Packet) *a,
            Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
  a0 = pset1<Packet>(a[0]);
  if (N > 1) {
    a1 = pset1<Packet>(a[1]);
  } else {
    EIGEN_UNUSED_VARIABLE(a1);
  }
  if (N > 2) {
    a2 = pset1<Packet>(a[2]);
  } else {
    EIGEN_UNUSED_VARIABLE(a2);
  }
  if (N > 3) {
    a3 = pset1<Packet>(a[3]);
  } else {
    EIGEN_UNUSED_VARIABLE(a3);
  }
}

template<> EIGEN_ALWAYS_INLINE void
pbroadcastN<Packet4f,4>(const float *a,
                        Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
  a3 = pload<Packet4f>(a);
  a0 = vec_splat(a3, 0);
  a1 = vec_splat(a3, 1);
  a2 = vec_splat(a3, 2);
  a3 = vec_splat(a3, 3);
}

// PEEL loop factor.
#define PEEL 7
#define PEEL_ROW 7
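
// Illustrative note (not part of the original source): PEEL/PEEL_ROW control
// how many depth iterations the micro-kernels unroll per pass of the main
// loop. The guarded macro expansions below emit eight candidate steps and the
// `PEEL... > peel` checks keep exactly that many active, each with its own
// rhsV set (and, in the row-fringe kernel, its own accumulator), which helps
// hide FMA latency; leftover depth iterations are handled one at a time by
// the trailing loops.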

#define MICRO_UNROLL_PEEL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_ZERO_PEEL(peel) \
  if ((PEEL_ROW > peel) && (peel != 0)) { \
    bsetzero<Scalar, Packet, accRows>(accZero##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##peel); \
  }

#define MICRO_ZERO_PEEL_ROW \
  MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL);

#define MICRO_WORK_PEEL(peel) \
  if (PEEL_ROW > peel) { \
    pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_WORK_PEEL_ROW \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
  MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \
  lhs_ptr += (remaining_rows * PEEL_ROW); \
  rhs_ptr += (accRows * PEEL_ROW);

#define MICRO_ADD_PEEL(peel, sum) \
  if (PEEL_ROW > peel) { \
    for (Index i = 0; i < accRows; i++) { \
      accZero##sum.packet[i] += accZero##peel.packet[i]; \
    } \
  }

#define MICRO_ADD_PEEL_ROW \
  MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
  MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)

template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr,
  PacketBlock<Packet,accRows> &accZero)
{
  Packet rhsV[4];
  pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
  pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
  rhs_ptr += accRows;
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index rows,
  Index cols,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;

  bsetzero<Scalar, Packet, accRows>(accZero0);

  Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
  Index k = 0;
  if (remaining_depth >= PEEL_ROW) {
    MICRO_ZERO_PEEL_ROW
    do
    {
      EIGEN_POWER_PREFETCH(rhs_ptr);
      EIGEN_POWER_PREFETCH(lhs_ptr);
      MICRO_WORK_PEEL_ROW
    } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
    MICRO_ADD_PEEL_ROW
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr, accZero0);
  }

  if ((remaining_depth == depth) && (rows >= accCols))
  {
    bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
    bscale<Packet,accRows>(acc, accZero0, pAlpha, pMask);
    res.template storePacketBlock<Packet,accRows>(row, 0, acc);
  } else {
    for(; k < depth; k++)
    {
      Packet rhsV[4];
      pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
      pger<accRows, Scalar, Packet, Index, false, remaining_rows>(&accZero0, lhs_ptr, rhsV);
      lhs_ptr += remaining_rows;
      rhs_ptr += accRows;
    }

    for(Index j = 0; j < accRows; j++) {
      accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]);
      for(Index i = 0; i < remaining_rows; i++) {
        res(row + i, j) += accZero0.packet[j][i];
      }
    }
  }
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  switch(remaining_rows) {
    case 1:
      gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
      break;
    case 2:
      if (sizeof(Scalar) == sizeof(float)) {
        gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
      }
      break;
    default:
      if (sizeof(Scalar) == sizeof(float)) {
        gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
      }
      break;
  }
}

#define MICRO_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_UNROLL_WORK(func, func2, peel) \
  MICRO_UNROLL(func2); \
  func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
  func(4,peel) func(5,peel) func(6,peel) func(7,peel)

#define MICRO_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_WORK_ONE(iter, peel) \
  if (unroll_factor > iter) { \
    pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }

#define MICRO_TYPE_PEEL4(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
  func(func1,func2,0); func(func1,func2,1); \
  func(func1,func2,2); func(func1,func2,3); \
  func(func1,func2,4); func(func1,func2,5); \
  func(func1,func2,6); func(func1,func2,7);

#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M]; \
  func(func1,func2,0);

#define MICRO_ONE_PEEL4 \
  MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += (accRows * PEEL);

#define MICRO_ONE4 \
  MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += accRows;

#define MICRO_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Scalar, Packet, accRows>(accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)

#define MICRO_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)

#define MICRO_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
    bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
    res.template storePacketBlock<Packet,accRows>(row + iter*accCols, 0, acc); \
  }

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)

template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,accRows> acc;

  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for(; k + PEEL <= depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_PREFETCH
    MICRO_ONE_PEEL4
  }
  for(; k < depth; k++)
  {
    MICRO_ONE4
  }
  MICRO_STORE

  row += unroll_factor*accCols;
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_UNROLL 6
  while(row + MAX_UNROLL*accCols <= rows) {
    gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
    case 7:
      gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 6
    case 6:
      gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 5
    case 5:
      gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 4
    case 4:
      gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 3
    case 3:
      gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 2
    case 2:
      gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_UNROLL > 1
    case 1:
      gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  for (; col < cols; col++) {
    gemm_cols<Scalar, Packet, DataMapper, Index, 1, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

/****************
 * GEMM kernels *
 ****************/
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemm_cols<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
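
// Illustrative note (not part of the original source): gemm consumes blocks
// already packed by dhs_pack. It walks the result in column panels of accRows;
// within each panel gemm_cols handles full row blocks of accCols (unrolled up
// to MAX_UNROLL blocks at a time), gemm_extra_row finishes the masked row
// fringe, and gemm_extra_cols finally processes leftover columns one at a time.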

#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX loop factor.
#define PEEL_COMPLEX 3
#define PEEL_COMPLEX_ROW 3
1767
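// Macros for the complex row-remainder kernel: the depth loop is peeled by
// PEEL_COMPLEX_ROW into independent accumulators (accReal0..3 / accImag0..3)
// which MICRO_COMPLEX_ADD_PEEL_ROW folds back into accumulator 0.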
#define MICRO_COMPLEX_UNROLL_PEEL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_ZERO_PEEL(peel) \
  if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
    bsetzero<Scalar, Packet, accRows>(accReal##peel); \
    bsetzero<Scalar, Packet, accRows>(accImag##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##peel); \
    EIGEN_UNUSED_VARIABLE(accImag##peel); \
  }

#define MICRO_COMPLEX_ZERO_PEEL_ROW \
  MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_ZERO_PEEL);

#define MICRO_COMPLEX_WORK_PEEL(peel) \
  if (PEEL_COMPLEX_ROW > peel) { \
    pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
    pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_WORK_PEEL_ROW \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
  Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
  MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_WORK_PEEL); \
  lhs_ptr_real += (remaining_rows * PEEL_COMPLEX_ROW); \
  if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * PEEL_COMPLEX_ROW); \
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_ROW); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_ROW); \
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);

#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
  if (PEEL_COMPLEX_ROW > peel) { \
    for (Index i = 0; i < accRows; i++) { \
      accReal##sum.packet[i] += accReal##peel.packet[i]; \
      accImag##sum.packet[i] += accImag##peel.packet[i]; \
    } \
  }

#define MICRO_COMPLEX_ADD_PEEL_ROW \
  MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \
  MICRO_COMPLEX_ADD_PEEL(1, 0)

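// One depth step of the complex row-remainder kernel: broadcast the rhs,
// accumulate via pgerc, then advance the lhs/rhs pointers.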
template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
  const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
  const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
  PacketBlock<Packet,accRows> &accReal, PacketBlock<Packet,accRows> &accImag)
{
  Packet rhsV[4], rhsVi[4];
  pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
  if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
  pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
  lhs_ptr_real += remaining_rows;
  if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  rhs_ptr_real += accRows;
  if(!RhsIsReal) rhs_ptr_imag += accRows;
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}

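// Computes the final (remaining_rows x accRows) block with the remainder as a
// compile-time constant. The masked vector store is used when the whole depth
// was consumed and enough rows remain; otherwise the tail is written back
// element by element.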
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index col,
  Index rows,
  Index cols,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
  const Scalar* lhs_ptr_imag = NULL;
  if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet,accRows> taccReal, taccImag;
  PacketBlock<Packetc,accRows> acc0, acc1;
  PacketBlock<Packetc,accRows*2> tRes;

  bsetzero<Scalar, Packet, accRows>(accReal0);
  bsetzero<Scalar, Packet, accRows>(accImag0);

  Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
  Index k = 0;
  if (remaining_depth >= PEEL_COMPLEX_ROW) {
    MICRO_COMPLEX_ZERO_PEEL_ROW
    do
    {
      EIGEN_POWER_PREFETCH(rhs_ptr_real);
      if(!RhsIsReal) {
        EIGEN_POWER_PREFETCH(rhs_ptr_imag);
      }
      EIGEN_POWER_PREFETCH(lhs_ptr_real);
      if(!LhsIsReal) {
        EIGEN_POWER_PREFETCH(lhs_ptr_imag);
      }
      MICRO_COMPLEX_WORK_PEEL_ROW
    } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
    MICRO_COMPLEX_ADD_PEEL_ROW
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal0, accImag0);
  }

  if ((remaining_depth == depth) && (rows >= accCols))
  {
    bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row, 0);
    bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
    bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1);
    res.template storePacketBlock<Packetc,accRows>(row + 0, 0, acc0);
    res.template storePacketBlock<Packetc,accRows>(row + accColsC, 0, acc1);
  } else {
    for(; k < depth; k++)
    {
      Packet rhsV[4], rhsVi[4];
      pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
      if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
      pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal0, &accImag0, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
      lhs_ptr_real += remaining_rows;
      if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
      rhs_ptr_real += accRows;
      if(!RhsIsReal) rhs_ptr_imag += accRows;
    }

    bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag);
    bcouple_common<Packet, Packetc, accRows>(taccReal, taccImag, acc0, acc1);

    if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
    {
      for(Index j = 0; j < accRows; j++) {
        res(row + 0, j) += pfirst<Packetc>(acc0.packet[j]);
      }
    } else {
      for(Index j = 0; j < accRows; j++) {
        PacketBlock<Packetc,1> acc2;
        acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, j) + acc0.packet[j];
        res.template storePacketBlock<Packetc,1>(row + 0, j, acc2);
        if(remaining_rows > accColsC) {
          res(row + accColsC, j) += pfirst<Packetc>(acc1.packet[j]);
        }
      }
    }
  }
}

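// Dispatches on the runtime row remainder so the iteration above can be
// instantiated with remaining_rows as a template constant; remainders of 2
// and 3 can only occur for the float kernels (accCols == 4), hence the
// sizeof checks.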
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  switch(remaining_rows) {
    case 1:
      gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
      break;
    case 2:
      if (sizeof(Scalar) == sizeof(float)) {
        gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
      }
      break;
    default:
      if (sizeof(Scalar) == sizeof(float)) {
        gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
      }
      break;
  }
}

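// Unrolled complex micro-kernel machinery: up to four lhs tiles
// (unroll_factor) are kept in flight, and the depth loop is peeled by
// PEEL_COMPLEX; tiles beyond unroll_factor compile away through the guards.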
#define MICRO_COMPLEX_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  MICRO_COMPLEX_UNROLL(func2); \
  func(0,peel) func(1,peel) func(2,peel) func(3,peel)

#define MICRO_COMPLEX_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
    lhs_ptr_real##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
  if (unroll_factor > iter) { \
    pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
  if (PEEL_COMPLEX > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    if(!RhsIsReal) { \
      pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
  func(func1,func2,0); func(func1,func2,1); \
  func(func1,func2,2); func(func1,func2,3);

#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M], rhsVi0[M]; \
  func(func1,func2,0);

#define MICRO_COMPLEX_ONE_PEEL4 \
  MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
  rhs_ptr_real += (accRows * PEEL_COMPLEX); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);

#define MICRO_COMPLEX_ONE4 \
  MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Scalar, Packet, accRows>(accReal##iter); \
    bsetzero<Scalar, Packet, accRows>(accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)

#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

#define MICRO_COMPLEX_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row + iter*accCols, 0); \
    bscalec<Packet,accRows>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
    bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1); \
    res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + 0, 0, acc0); \
    res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + accColsC, 0, acc1); \
  }

#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)

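// One pass over a full panel: zero the accumulators, run the peeled and
// remainder depth loops, scale by alpha, couple the real/imaginary halves and
// store. 'row' is advanced by the number of rows consumed.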
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1;
  PacketBlock<Packet,accRows> accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet,accRows> taccReal, taccImag;
  PacketBlock<Packetc,accRows> acc0, acc1;
  PacketBlock<Packetc,accRows*2> tRes;

  MICRO_COMPLEX_SRC_PTR
  MICRO_COMPLEX_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_PREFETCH
    MICRO_COMPLEX_ONE_PEEL4
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_ONE4
  }
  MICRO_COMPLEX_STORE

  row += unroll_factor*accCols;
}

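// Processes one panel of accRows columns: full accCols row blocks go through
// the unrolled iteration; any leftover rows go through gemm_complex_extra_row.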
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_UNROLL 3
  while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
    gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_UNROLL > 4
    case 4:
      gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_UNROLL > 3
    case 3:
      gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_UNROLL > 2
    case 2:
      gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_UNROLL > 1
    case 1:
      gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  for (; col < cols; col++) {
    gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, 1, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}

#undef accColsC
#undef advanceCols
#undef advanceRows

/************************************
 * ppc64le template specializations *
 ************************************/
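// Thin adapters that route Eigen's generic gemm_pack_lhs/gemm_pack_rhs entry
// points into the POWER-specific dhs_pack (real) and dhs_cpack (complex)
// packers for both storage orders; each operator() just forwards its
// arguments.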
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
  dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
  dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

// ********* gebp specializations *********
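// Runtime kernel selection for each scalar combination below: when MMA is
// forced (EIGEN_ALTIVEC_MMA_ONLY) the gemmMMA/gemm_complexMMA path is always
// taken; when MMA support was merely detected at build time, the CPU is
// queried via __builtin_cpu_supports("arch_3_1") / ("mma") and the VSX
// gemm/gemm_complex kernels above serve as the fallback.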
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef typename quad_traits<float>::vectortype Packet;
  typedef typename quad_traits<float>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const float* blockB,
                  Index rows, Index depth, Index cols, float alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
               Index rows, Index depth, Index cols, float alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<float>::rows;
    const Index accCols = quad_traits<float>::size;
    void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);

    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
      } else {
        gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
      }
    #else
      gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
               Index rows, Index depth, Index cols, std::complex<float> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<float>::rows;
    const Index accCols = quad_traits<float>::size;
    void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
                          Index, Index, Index, std::complex<float>, Index, Index, Index, Index);

    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
      } else {
        gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
      }
    #else
      gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
               Index rows, Index depth, Index cols, std::complex<float> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<float>::rows;
    const Index accCols = quad_traits<float>::size;
    void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
                          Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
      } else {
        gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
      }
    #else
      gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
               Index rows, Index depth, Index cols, std::complex<float> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<float>::rows;
    const Index accCols = quad_traits<float>::size;
    void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
                          Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
      } else {
        gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
      }
    #else
      gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef typename quad_traits<double>::vectortype Packet;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const double* blockB,
                  Index rows, Index depth, Index cols, double alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
               Index rows, Index depth, Index cols, double alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<double>::rows;
    const Index accCols = quad_traits<double>::size;
    void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);

    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
      } else {
        gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
      }
    #else
      gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
               Index rows, Index depth, Index cols, std::complex<double> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<double>::rows;
    const Index accCols = quad_traits<double>::size;
    void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
                          Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
      } else {
        gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
      }
    #else
      gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
               Index rows, Index depth, Index cols, std::complex<double> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<double>::rows;
    const Index accCols = quad_traits<double>::size;
    void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
                          Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
      } else {
        gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
      }
    #else
      gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
};

template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
  ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
               Index rows, Index depth, Index cols, std::complex<double> alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB)
  {
    const Index accRows = quad_traits<double>::rows;
    const Index accCols = quad_traits<double>::size;
    void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
                          Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
    #ifdef EIGEN_ALTIVEC_MMA_ONLY
      // generate with MMA only
      gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
    #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
        gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
      } else {
        gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
      }
    #else
      gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
    #endif
    gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
  }
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H