1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup herk
20 //! @{
21 
22 
23 
24 class herk_helper
25   {
26   public:
27 
28   template<typename eT>
29   inline
30   static
31   void
inplace_conj_copy_upper_tri_to_lower_tri(Mat<eT> & C)32   inplace_conj_copy_upper_tri_to_lower_tri(Mat<eT>& C)
33     {
34     // under the assumption that C is a square matrix
35 
36     const uword N = C.n_rows;
37 
38     for(uword k=0; k < N; ++k)
39       {
40       eT* colmem = C.colptr(k);
41 
42       for(uword i=(k+1); i < N; ++i)
43         {
44         colmem[i] = std::conj( C.at(k,i) );
45         }
46       }
47     }
48 
49 
50   template<typename eT>
51   static
52   arma_hot
53   inline
54   eT
dot_conj_row(const uword n_elem,const eT * const A,const Mat<eT> & B,const uword row)55   dot_conj_row(const uword n_elem, const eT* const A, const Mat<eT>& B, const uword row)
56     {
57     arma_extra_debug_sigprint();
58 
59     typedef typename get_pod_type<eT>::result T;
60 
61     T val_real = T(0);
62     T val_imag = T(0);
63 
64     for(uword i=0; i<n_elem; ++i)
65       {
66       const std::complex<T>& X = A[i];
67       const std::complex<T>& Y = B.at(row,i);
68 
69       const T a = X.real();
70       const T b = X.imag();
71 
72       const T c = Y.real();
73       const T d = Y.imag();
74 
75       val_real += (a*c) + (b*d);
76       val_imag += (b*c) - (a*d);
77       }
78 
79     return std::complex<T>(val_real, val_imag);
80     }
81 
82   };
83 
84 
85 
86 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
87 class herk_vec
88   {
89   public:
90 
91   template<typename T, typename TA>
92   arma_hot
93   inline
94   static
95   void
apply(Mat<std::complex<T>> & C,const TA & A,const T alpha=T (1),const T beta=T (0))96   apply
97     (
98           Mat< std::complex<T> >& C,
99     const TA&                     A,
100     const T                       alpha = T(1),
101     const T                       beta  = T(0)
102     )
103     {
104     arma_extra_debug_sigprint();
105 
106     typedef std::complex<T> eT;
107 
108     const uword A_n_rows = A.n_rows;
109     const uword A_n_cols = A.n_cols;
110 
111     // for beta != 0, C is assumed to be hermitian
112 
113     // do_trans_A == false  ->   C = alpha * A   * A^H + beta*C
114     // do_trans_A == true   ->   C = alpha * A^H * A   + beta*C
115 
116     const eT* A_mem = A.memptr();
117 
118     if(do_trans_A == false)
119       {
120       if(A_n_rows == 1)
121         {
122         const eT acc = op_cdot::direct_cdot(A_n_cols, A_mem, A_mem);
123 
124              if( (use_alpha == false) && (use_beta == false) )  { C[0] =       acc;             }
125         else if( (use_alpha == true ) && (use_beta == false) )  { C[0] = alpha*acc;             }
126         else if( (use_alpha == false) && (use_beta == true ) )  { C[0] =       acc + beta*C[0]; }
127         else if( (use_alpha == true ) && (use_beta == true ) )  { C[0] = alpha*acc + beta*C[0]; }
128         }
129       else
130       for(uword row_A=0; row_A < A_n_rows; ++row_A)
131         {
132         const eT& A_rowdata = A_mem[row_A];
133 
134         for(uword k=row_A; k < A_n_rows; ++k)
135           {
136           const eT acc = A_rowdata * std::conj( A_mem[k] );
137 
138           if( (use_alpha == false) && (use_beta == false) )
139             {
140                               C.at(row_A, k) = acc;
141             if(row_A != k)  { C.at(k, row_A) = std::conj(acc); }
142             }
143           else
144           if( (use_alpha == true) && (use_beta == false) )
145             {
146             const eT val = alpha*acc;
147 
148                               C.at(row_A, k) = val;
149             if(row_A != k)  { C.at(k, row_A) = std::conj(val); }
150             }
151           else
152           if( (use_alpha == false) && (use_beta == true) )
153             {
154                               C.at(row_A, k) =           acc  + beta*C.at(row_A, k);
155             if(row_A != k)  { C.at(k, row_A) = std::conj(acc) + beta*C.at(k, row_A); }
156             }
157           else
158           if( (use_alpha == true) && (use_beta == true) )
159             {
160             const eT val = alpha*acc;
161 
162                               C.at(row_A, k) =           val  + beta*C.at(row_A, k);
163             if(row_A != k)  { C.at(k, row_A) = std::conj(val) + beta*C.at(k, row_A); }
164             }
165           }
166         }
167       }
168     else
169     if(do_trans_A == true)
170       {
171       if(A_n_cols == 1)
172         {
173         const eT acc = op_cdot::direct_cdot(A_n_rows, A_mem, A_mem);
174 
175              if( (use_alpha == false) && (use_beta == false) )  { C[0] =       acc;             }
176         else if( (use_alpha == true ) && (use_beta == false) )  { C[0] = alpha*acc;             }
177         else if( (use_alpha == false) && (use_beta == true ) )  { C[0] =       acc + beta*C[0]; }
178         else if( (use_alpha == true ) && (use_beta == true ) )  { C[0] = alpha*acc + beta*C[0]; }
179         }
180       else
181       for(uword col_A=0; col_A < A_n_cols; ++col_A)
182         {
183         // col_A is interpreted as row_A when storing the results in matrix C
184 
185         const eT A_coldata = std::conj( A_mem[col_A] );
186 
187         for(uword k=col_A; k < A_n_cols ; ++k)
188           {
189           const eT acc = A_coldata * A_mem[k];
190 
191           if( (use_alpha == false) && (use_beta == false) )
192             {
193                               C.at(col_A, k) = acc;
194             if(col_A != k)  { C.at(k, col_A) = std::conj(acc); }
195             }
196           else
197           if( (use_alpha == true ) && (use_beta == false) )
198             {
199             const eT val = alpha*acc;
200 
201                               C.at(col_A, k) = val;
202             if(col_A != k)  { C.at(k, col_A) = std::conj(val); }
203             }
204           else
205           if( (use_alpha == false) && (use_beta == true ) )
206             {
207                               C.at(col_A, k) =           acc  + beta*C.at(col_A, k);
208             if(col_A != k)  { C.at(k, col_A) = std::conj(acc) + beta*C.at(k, col_A); }
209             }
210           else
211           if( (use_alpha == true ) && (use_beta == true ) )
212             {
213             const eT val = alpha*acc;
214 
215                               C.at(col_A, k) =           val  + beta*C.at(col_A, k);
216             if(col_A != k)  { C.at(k, col_A) = std::conj(val) + beta*C.at(k, col_A); }
217             }
218           }
219         }
220       }
221     }
222 
223   };
224 
225 
226 
227 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
228 class herk_emul
229   {
230   public:
231 
232   template<typename T, typename TA>
233   arma_hot
234   inline
235   static
236   void
apply(Mat<std::complex<T>> & C,const TA & A,const T alpha=T (1),const T beta=T (0))237   apply
238     (
239           Mat< std::complex<T> >& C,
240     const TA&                     A,
241     const T                       alpha = T(1),
242     const T                       beta  = T(0)
243     )
244     {
245     arma_extra_debug_sigprint();
246 
247     typedef std::complex<T> eT;
248 
249     // do_trans_A == false  ->   C = alpha * A   * A^H + beta*C
250     // do_trans_A == true   ->   C = alpha * A^H * A   + beta*C
251 
252     if(do_trans_A == false)
253       {
254       Mat<eT> AA;
255 
256       op_htrans::apply_mat_noalias(AA, A);
257 
258       herk_emul<true, use_alpha, use_beta>::apply(C, AA, alpha, beta);
259       }
260     else
261     if(do_trans_A == true)
262       {
263       const uword A_n_rows = A.n_rows;
264       const uword A_n_cols = A.n_cols;
265 
266       for(uword col_A=0; col_A < A_n_cols; ++col_A)
267         {
268         // col_A is interpreted as row_A when storing the results in matrix C
269 
270         const eT* A_coldata = A.colptr(col_A);
271 
272         for(uword k=col_A; k < A_n_cols ; ++k)
273           {
274           const eT acc = op_cdot::direct_cdot(A_n_rows, A_coldata, A.colptr(k));
275 
276           if( (use_alpha == false) && (use_beta == false) )
277             {
278                               C.at(col_A, k) = acc;
279             if(col_A != k)  { C.at(k, col_A) = std::conj(acc); }
280             }
281           else
282           if( (use_alpha == true) && (use_beta == false) )
283             {
284             const eT val = alpha*acc;
285 
286                               C.at(col_A, k) = val;
287             if(col_A != k)  { C.at(k, col_A) = std::conj(val); }
288             }
289           else
290           if( (use_alpha == false) && (use_beta == true) )
291             {
292                               C.at(col_A, k) =           acc  + beta*C.at(col_A, k);
293             if(col_A != k)  { C.at(k, col_A) = std::conj(acc) + beta*C.at(k, col_A); }
294             }
295           else
296           if( (use_alpha == true) && (use_beta == true) )
297             {
298             const eT val = alpha*acc;
299 
300                               C.at(col_A, k) =           val  + beta*C.at(col_A, k);
301             if(col_A != k)  { C.at(k, col_A) = std::conj(val) + beta*C.at(k, col_A); }
302             }
303           }
304         }
305       }
306     }
307 
308   };
309 
310 
311 
312 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
313 class herk
314   {
315   public:
316 
317   template<typename T, typename TA>
318   inline
319   static
320   void
apply_blas_type(Mat<std::complex<T>> & C,const TA & A,const T alpha=T (1),const T beta=T (0))321   apply_blas_type( Mat<std::complex<T>>& C, const TA& A, const T alpha = T(1), const T beta = T(0) )
322     {
323     arma_extra_debug_sigprint();
324 
325     const uword threshold = 16;
326 
327     if(A.is_vec())
328       {
329       // work around poor handling of vectors by herk() in ATLAS 3.8.4 and standard BLAS
330 
331       herk_vec<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
332 
333       return;
334       }
335 
336 
337     if( (A.n_elem <= threshold) )
338       {
339       herk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
340       }
341     else
342       {
343       #if defined(ARMA_USE_ATLAS)
344         {
345         if(use_beta == true)
346           {
347           typedef typename std::complex<T> eT;
348 
349           // use a temporary matrix, as we can't assume that matrix C is already symmetric
350           Mat<eT> D(C.n_rows, C.n_cols, arma_nozeros_indicator());
351 
352           herk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha);
353 
354           // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1
355           arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem);
356 
357           return;
358           }
359 
360         atlas::cblas_herk<T>
361           (
362           atlas::CblasColMajor,
363           atlas::CblasUpper,
364           (do_trans_A) ? CblasConjTrans : atlas::CblasNoTrans,
365           C.n_cols,
366           (do_trans_A) ? A.n_rows : A.n_cols,
367           (use_alpha) ? alpha : T(1),
368           A.mem,
369           (do_trans_A) ? A.n_rows : C.n_cols,
370           (use_beta) ? beta : T(0),
371           C.memptr(),
372           C.n_cols
373           );
374 
375         herk_helper::inplace_conj_copy_upper_tri_to_lower_tri(C);
376         }
377       #elif defined(ARMA_USE_BLAS)
378         {
379         if(use_beta == true)
380           {
381           typedef typename std::complex<T> eT;
382 
383           // use a temporary matrix, as we can't assume that matrix C is already symmetric
384           Mat<eT> D(C.n_rows, C.n_cols, arma_nozeros_indicator());
385 
386           herk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha);
387 
388           // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1
389           arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem);
390 
391           return;
392           }
393 
394         arma_extra_debug_print("blas::herk()");
395 
396         const char uplo = 'U';
397 
398         const char trans_A = (do_trans_A) ? 'C' : 'N';
399 
400         const blas_int n = blas_int(C.n_cols);
401         const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols);
402 
403         const T local_alpha = (use_alpha) ? alpha : T(1);
404         const T local_beta  = (use_beta)  ? beta  : T(0);
405 
406         const blas_int lda = (do_trans_A) ? k : n;
407 
408         arma_extra_debug_print( arma_str::format("blas::herk(): trans_A = %c") % trans_A );
409 
410         blas::herk<T>
411           (
412           &uplo,
413           &trans_A,
414           &n,
415           &k,
416           &local_alpha,
417           A.mem,
418           &lda,
419           &local_beta,
420           C.memptr(),
421           &n // &ldc
422           );
423 
424         herk_helper::inplace_conj_copy_upper_tri_to_lower_tri(C);
425         }
426       #else
427         {
428         herk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
429         }
430       #endif
431       }
432 
433     }
434 
435 
436 
437   template<typename eT, typename TA>
438   inline
439   static
440   void
apply(Mat<eT> & C,const TA & A,const eT alpha=eT (1),const eT beta=eT (0),const typename arma_not_cx<eT>::result * junk=nullptr)441   apply( Mat<eT>& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = nullptr )
442     {
443     arma_ignore(C);
444     arma_ignore(A);
445     arma_ignore(alpha);
446     arma_ignore(beta);
447     arma_ignore(junk);
448 
449     // herk() cannot be used by non-complex matrices
450 
451     return;
452     }
453 
454 
455 
456   template<typename TA>
457   arma_inline
458   static
459   void
apply(Mat<std::complex<float>> & C,const TA & A,const float alpha=float (1),const float beta=float (0))460   apply
461     (
462           Mat< std::complex<float> >& C,
463     const TA&                         A,
464     const float                       alpha = float(1),
465     const float                       beta  = float(0)
466     )
467     {
468     herk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta);
469     }
470 
471 
472 
473   template<typename TA>
474   arma_inline
475   static
476   void
apply(Mat<std::complex<double>> & C,const TA & A,const double alpha=double (1),const double beta=double (0))477   apply
478     (
479           Mat< std::complex<double> >& C,
480     const TA&                          A,
481     const double                       alpha = double(1),
482     const double                       beta  = double(0)
483     )
484     {
485     herk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta);
486     }
487 
488   };
489 
490 
491 
492 //! @}
493