1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup fn_trace
20 //! @{
21 
22 
23 template<typename T1>
24 arma_warn_unused
25 inline
26 typename T1::elem_type
trace(const Base<typename T1::elem_type,T1> & X)27 trace(const Base<typename T1::elem_type, T1>& X)
28   {
29   arma_extra_debug_sigprint();
30 
31   typedef typename T1::elem_type eT;
32 
33   const Proxy<T1> P(X.get_ref());
34 
35   const uword N = (std::min)(P.get_n_rows(), P.get_n_cols());
36 
37   eT val1 = eT(0);
38   eT val2 = eT(0);
39 
40   uword i,j;
41   for(i=0, j=1; j<N; i+=2, j+=2)
42     {
43     val1 += P.at(i,i);
44     val2 += P.at(j,j);
45     }
46 
47   if(i < N)
48     {
49     val1 += P.at(i,i);
50     }
51 
52   return val1 + val2;
53   }
54 
55 
56 
57 template<typename T1>
58 arma_warn_unused
59 inline
60 typename T1::elem_type
trace(const Op<T1,op_diagmat> & X)61 trace(const Op<T1, op_diagmat>& X)
62   {
63   arma_extra_debug_sigprint();
64 
65   typedef typename T1::elem_type eT;
66 
67   const diagmat_proxy<T1> A(X.m);
68 
69   const uword N = (std::min)(A.n_rows, A.n_cols);
70 
71   eT val = eT(0);
72 
73   for(uword i=0; i<N; ++i)
74     {
75     val += A[i];
76     }
77 
78   return val;
79   }
80 
81 
82 
83 //! speedup for trace(A*B); non-complex elements
84 template<typename T1, typename T2>
85 arma_warn_unused
86 inline
87 typename enable_if2< is_cx<typename T1::elem_type>::no, typename T1::elem_type>::result
trace(const Glue<T1,T2,glue_times> & X)88 trace(const Glue<T1, T2, glue_times>& X)
89   {
90   arma_extra_debug_sigprint();
91 
92   typedef typename T1::elem_type eT;
93 
94   const partial_unwrap<T1> tmp1(X.A);
95   const partial_unwrap<T2> tmp2(X.B);
96 
97   const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
98   const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
99 
100   const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
101   const eT       alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
102 
103   arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
104 
105   if( (A.n_elem == 0) || (B.n_elem == 0) )
106     {
107     return eT(0);
108     }
109 
110   const uword A_n_rows = A.n_rows;
111   const uword A_n_cols = A.n_cols;
112 
113   const uword B_n_rows = B.n_rows;
114   const uword B_n_cols = B.n_cols;
115 
116   eT acc = eT(0);
117 
118   if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
119     {
120     const uword N = (std::min)(A_n_rows, B_n_cols);
121 
122     eT acc1 = eT(0);
123     eT acc2 = eT(0);
124 
125     for(uword k=0; k < N; ++k)
126       {
127       const eT* B_colptr = B.colptr(k);
128 
129       // condition: A_n_cols = B_n_rows
130 
131       uword j;
132 
133       for(j=1; j < A_n_cols; j+=2)
134         {
135         const uword i = (j-1);
136 
137         const eT tmp_i = B_colptr[i];
138         const eT tmp_j = B_colptr[j];
139 
140         acc1 += A.at(k, i) * tmp_i;
141         acc2 += A.at(k, j) * tmp_j;
142         }
143 
144       const uword i = (j-1);
145 
146       if(i < A_n_cols)
147         {
148         acc1 += A.at(k, i) * B_colptr[i];
149         }
150       }
151 
152     acc = (acc1 + acc2);
153     }
154   else
155   if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == false) )
156     {
157     const uword N = (std::min)(A_n_cols, B_n_cols);
158 
159     for(uword k=0; k < N; ++k)
160       {
161       const eT* A_colptr = A.colptr(k);
162       const eT* B_colptr = B.colptr(k);
163 
164       // condition: A_n_rows = B_n_rows
165       acc += op_dot::direct_dot(A_n_rows, A_colptr, B_colptr);
166       }
167     }
168   else
169   if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true ) )
170     {
171     const uword N = (std::min)(A_n_rows, B_n_rows);
172 
173     for(uword k=0; k < N; ++k)
174       {
175       // condition: A_n_cols = B_n_cols
176       for(uword i=0; i < A_n_cols; ++i)
177         {
178         acc += A.at(k,i) * B.at(k,i);
179         }
180       }
181     }
182   else
183   if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == true ) )
184     {
185     const uword N = (std::min)(A_n_cols, B_n_rows);
186 
187     for(uword k=0; k < N; ++k)
188       {
189       const eT* A_colptr = A.colptr(k);
190 
191       // condition: A_n_rows = B_n_cols
192       for(uword i=0; i < A_n_rows; ++i)
193         {
194         acc += A_colptr[i] * B.at(k,i);
195         }
196       }
197     }
198 
199   return (use_alpha) ? (alpha * acc) : acc;
200   }
201 
202 
203 
204 //! speedup for trace(A*B); complex elements
205 template<typename T1, typename T2>
206 arma_warn_unused
207 inline
208 typename enable_if2< is_cx<typename T1::elem_type>::yes, typename T1::elem_type>::result
trace(const Glue<T1,T2,glue_times> & X)209 trace(const Glue<T1, T2, glue_times>& X)
210   {
211   arma_extra_debug_sigprint();
212 
213   typedef typename T1::pod_type   T;
214   typedef typename T1::elem_type eT;
215 
216   const partial_unwrap<T1> tmp1(X.A);
217   const partial_unwrap<T2> tmp2(X.B);
218 
219   const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
220   const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
221 
222   const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
223   const eT       alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
224 
225   arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
226 
227   if( (A.n_elem == 0) || (B.n_elem == 0) )
228     {
229     return eT(0);
230     }
231 
232   const uword A_n_rows = A.n_rows;
233   const uword A_n_cols = A.n_cols;
234 
235   const uword B_n_rows = B.n_rows;
236   const uword B_n_cols = B.n_cols;
237 
238   eT acc = eT(0);
239 
240   if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
241     {
242     const uword N = (std::min)(A_n_rows, B_n_cols);
243 
244     T acc_real = T(0);
245     T acc_imag = T(0);
246 
247     for(uword k=0; k < N; ++k)
248       {
249       const eT* B_colptr = B.colptr(k);
250 
251       // condition: A_n_cols = B_n_rows
252 
253       for(uword i=0; i < A_n_cols; ++i)
254         {
255         // acc += A.at(k, i) * B_colptr[i];
256 
257         const std::complex<T>& xx = A.at(k, i);
258         const std::complex<T>& yy = B_colptr[i];
259 
260         const T a = xx.real();
261         const T b = xx.imag();
262 
263         const T c = yy.real();
264         const T d = yy.imag();
265 
266         acc_real += (a*c) - (b*d);
267         acc_imag += (a*d) + (b*c);
268         }
269       }
270 
271     acc = std::complex<T>(acc_real, acc_imag);
272     }
273   else
274   if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == false) )
275     {
276     const uword N = (std::min)(A_n_cols, B_n_cols);
277 
278     T acc_real = T(0);
279     T acc_imag = T(0);
280 
281     for(uword k=0; k < N; ++k)
282       {
283       const eT* A_colptr = A.colptr(k);
284       const eT* B_colptr = B.colptr(k);
285 
286       // condition: A_n_rows = B_n_rows
287 
288       for(uword i=0; i < A_n_rows; ++i)
289         {
290         // acc += std::conj(A_colptr[i]) * B_colptr[i];
291 
292         const std::complex<T>& xx = A_colptr[i];
293         const std::complex<T>& yy = B_colptr[i];
294 
295         const T a = xx.real();
296         const T b = xx.imag();
297 
298         const T c = yy.real();
299         const T d = yy.imag();
300 
301         // take into account the complex conjugate of xx
302 
303         acc_real += (a*c) + (b*d);
304         acc_imag += (a*d) - (b*c);
305         }
306       }
307 
308     acc = std::complex<T>(acc_real, acc_imag);
309     }
310   else
311   if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true) )
312     {
313     const uword N = (std::min)(A_n_rows, B_n_rows);
314 
315     T acc_real = T(0);
316     T acc_imag = T(0);
317 
318     for(uword k=0; k < N; ++k)
319       {
320       // condition: A_n_cols = B_n_cols
321       for(uword i=0; i < A_n_cols; ++i)
322         {
323         // acc += A.at(k,i) * std::conj(B.at(k,i));
324 
325         const std::complex<T>& xx = A.at(k, i);
326         const std::complex<T>& yy = B.at(k, i);
327 
328         const T a = xx.real();
329         const T b = xx.imag();
330 
331         const T c =  yy.real();
332         const T d = -yy.imag();  // take the conjugate
333 
334         acc_real += (a*c) - (b*d);
335         acc_imag += (a*d) + (b*c);
336         }
337       }
338 
339     acc = std::complex<T>(acc_real, acc_imag);
340     }
341   else
342   if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == true) )
343     {
344     const uword N = (std::min)(A_n_cols, B_n_rows);
345 
346     T acc_real = T(0);
347     T acc_imag = T(0);
348 
349     for(uword k=0; k < N; ++k)
350       {
351       const eT* A_colptr = A.colptr(k);
352 
353       // condition: A_n_rows = B_n_cols
354       for(uword i=0; i < A_n_rows; ++i)
355         {
356         // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i));
357 
358         const std::complex<T>& xx = A_colptr[i];
359         const std::complex<T>& yy = B.at(k, i);
360 
361         const T a =  xx.real();
362         const T b = -xx.imag();  // take the conjugate
363 
364         const T c =  yy.real();
365         const T d = -yy.imag();  // take the conjugate
366 
367         acc_real += (a*c) - (b*d);
368         acc_imag += (a*d) + (b*c);
369         }
370       }
371 
372     acc = std::complex<T>(acc_real, acc_imag);
373     }
374 
375   return (use_alpha) ? eT(alpha * acc) : eT(acc);
376   }
377 
378 
379 
380 //! trace of sparse object; generic version
381 template<typename T1>
382 arma_warn_unused
383 inline
384 typename T1::elem_type
trace(const SpBase<typename T1::elem_type,T1> & expr)385 trace(const SpBase<typename T1::elem_type,T1>& expr)
386   {
387   arma_extra_debug_sigprint();
388 
389   typedef typename T1::elem_type eT;
390 
391   const SpProxy<T1> P(expr.get_ref());
392 
393   const uword N = (std::min)(P.get_n_rows(), P.get_n_cols());
394 
395   eT acc = eT(0);
396 
397   if( (is_SpMat<typename SpProxy<T1>::stored_type>::value) && (P.get_n_nonzero() >= 5*N) )
398     {
399     const unwrap_spmat<typename SpProxy<T1>::stored_type> U(P.Q);
400 
401     const SpMat<eT>& X = U.M;
402 
403     for(uword i=0; i < N; ++i)
404       {
405       acc += X.at(i,i);  // use binary search
406       }
407     }
408   else
409     {
410     typename SpProxy<T1>::const_iterator_type it = P.begin();
411 
412     const uword P_n_nz = P.get_n_nonzero();
413 
414     for(uword i=0; i < P_n_nz; ++i)
415       {
416       if(it.row() == it.col())  { acc += (*it); }
417 
418       ++it;
419       }
420     }
421 
422   return acc;
423   }
424 
425 
426 
427 //! trace of sparse object; speedup for trace(A + B)
428 template<typename T1, typename T2>
429 arma_warn_unused
430 inline
431 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_plus> & expr)432 trace(const SpGlue<T1, T2, spglue_plus>& expr)
433   {
434   arma_extra_debug_sigprint();
435 
436   const unwrap_spmat<T1> UA(expr.A);
437   const unwrap_spmat<T2> UB(expr.B);
438 
439   arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "addition");
440 
441   return (trace(UA.M) + trace(UB.M));
442   }
443 
444 
445 
446 //! trace of sparse object; speedup for trace(A - B)
447 template<typename T1, typename T2>
448 arma_warn_unused
449 inline
450 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_minus> & expr)451 trace(const SpGlue<T1, T2, spglue_minus>& expr)
452   {
453   arma_extra_debug_sigprint();
454 
455   const unwrap_spmat<T1> UA(expr.A);
456   const unwrap_spmat<T2> UB(expr.B);
457 
458   arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "subtraction");
459 
460   return (trace(UA.M) - trace(UB.M));
461   }
462 
463 
464 
465 //! trace of sparse object; speedup for trace(A % B)
466 template<typename T1, typename T2>
467 arma_warn_unused
468 inline
469 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_schur> & expr)470 trace(const SpGlue<T1, T2, spglue_schur>& expr)
471   {
472   arma_extra_debug_sigprint();
473 
474   typedef typename T1::elem_type eT;
475 
476   const unwrap_spmat<T1> UA(expr.A);
477   const unwrap_spmat<T2> UB(expr.B);
478 
479   const SpMat<eT>& A = UA.M;
480   const SpMat<eT>& B = UB.M;
481 
482   arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise multiplication");
483 
484   const uword N = (std::min)(A.n_rows, A.n_cols);
485 
486   eT acc = eT(0);
487 
488   for(uword i=0; i<N; ++i)
489     {
490     acc += A.at(i,i) * B.at(i,i);
491     }
492 
493   return acc;
494   }
495 
496 
497 
498 //! trace of sparse object; speedup for trace(A*B)
499 template<typename T1, typename T2>
500 arma_warn_unused
501 inline
502 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_times> & expr)503 trace(const SpGlue<T1, T2, spglue_times>& expr)
504   {
505   arma_extra_debug_sigprint();
506 
507   typedef typename T1::elem_type eT;
508 
509   // better-than-nothing implementation
510 
511   const unwrap_spmat<T1> UA(expr.A);
512   const unwrap_spmat<T2> UB(expr.B);
513 
514   const SpMat<eT>& A = UA.M;
515   const SpMat<eT>& B = UB.M;
516 
517   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
518 
519   if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
520     {
521     return eT(0);
522     }
523 
524   const uword N = (std::min)(A.n_rows, B.n_cols);
525 
526   eT acc = eT(0);
527 
528   // TODO: the threshold may need tuning for complex matrices
529   if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
530     {
531     for(uword k=0; k < N; ++k)
532       {
533       typename SpMat<eT>::const_col_iterator B_it     = B.begin_col_no_sync(k);
534       typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
535 
536       while(B_it != B_it_end)
537         {
538         const eT    B_val = (*B_it);
539         const uword i     = B_it.row();
540 
541         acc += A.at(k,i) * B_val;
542 
543         ++B_it;
544         }
545       }
546     }
547   else
548     {
549     const SpMat<eT> AB = A * B;
550 
551     acc = trace(AB);
552     }
553 
554   return acc;
555   }
556 
557 
558 
559 //! trace of sparse object; speedup for trace(A.t()*B); non-complex elements
560 template<typename T1, typename T2>
561 arma_warn_unused
562 inline
563 typename enable_if2< is_cx<typename T1::elem_type>::no, typename T1::elem_type>::result
trace(const SpGlue<SpOp<T1,spop_htrans>,T2,spglue_times> & expr)564 trace(const SpGlue<SpOp<T1, spop_htrans>, T2, spglue_times>& expr)
565   {
566   arma_extra_debug_sigprint();
567 
568   typedef typename T1::elem_type eT;
569 
570   const unwrap_spmat<T1> UA(expr.A.m);
571   const unwrap_spmat<T2> UB(expr.B);
572 
573   const SpMat<eT>& A = UA.M;
574   const SpMat<eT>& B = UB.M;
575 
576   // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation
577   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
578 
579   if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
580     {
581     return eT(0);
582     }
583 
584   const uword N = (std::min)(A.n_cols, B.n_cols);
585 
586   eT acc = eT(0);
587 
588   if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
589     {
590     for(uword k=0; k < N; ++k)
591       {
592       typename SpMat<eT>::const_col_iterator B_it     = B.begin_col_no_sync(k);
593       typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
594 
595       while(B_it != B_it_end)
596         {
597         const eT    B_val = (*B_it);
598         const uword i     = B_it.row();
599 
600         acc += A.at(i,k) * B_val;
601 
602         ++B_it;
603         }
604       }
605     }
606   else
607     {
608     const SpMat<eT> AtB = A.t() * B;
609 
610     acc = trace(AtB);
611     }
612 
613   return acc;
614   }
615 
616 
617 
618 //! trace of sparse object; speedup for trace(A.t()*B); complex elements
619 template<typename T1, typename T2>
620 arma_warn_unused
621 inline
622 typename enable_if2< is_cx<typename T1::elem_type>::yes, typename T1::elem_type>::result
trace(const SpGlue<SpOp<T1,spop_htrans>,T2,spglue_times> & expr)623 trace(const SpGlue<SpOp<T1, spop_htrans>, T2, spglue_times>& expr)
624   {
625   arma_extra_debug_sigprint();
626 
627   typedef typename T1::elem_type eT;
628 
629   const unwrap_spmat<T1> UA(expr.A.m);
630   const unwrap_spmat<T2> UB(expr.B);
631 
632   const SpMat<eT>& A = UA.M;
633   const SpMat<eT>& B = UB.M;
634 
635   // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation
636   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
637 
638   if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
639     {
640     return eT(0);
641     }
642 
643   const uword N = (std::min)(A.n_cols, B.n_cols);
644 
645   eT acc = eT(0);
646 
647   // TODO: the threshold may need tuning for complex matrices
648   if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
649     {
650     for(uword k=0; k < N; ++k)
651       {
652       typename SpMat<eT>::const_col_iterator B_it     = B.begin_col_no_sync(k);
653       typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
654 
655       while(B_it != B_it_end)
656         {
657         const eT    B_val = (*B_it);
658         const uword i     = B_it.row();
659 
660         acc += std::conj(A.at(i,k)) * B_val;
661 
662         ++B_it;
663         }
664       }
665     }
666   else
667     {
668     const SpMat<eT> AtB = A.t() * B;
669 
670     acc = trace(AtB);
671     }
672 
673   return acc;
674   }
675 
676 
677 
678 //! @}
679