1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup fn_trace
20 //! @{
21
22
23 template<typename T1>
24 arma_warn_unused
25 inline
26 typename T1::elem_type
trace(const Base<typename T1::elem_type,T1> & X)27 trace(const Base<typename T1::elem_type, T1>& X)
28 {
29 arma_extra_debug_sigprint();
30
31 typedef typename T1::elem_type eT;
32
33 const Proxy<T1> P(X.get_ref());
34
35 const uword N = (std::min)(P.get_n_rows(), P.get_n_cols());
36
37 eT val1 = eT(0);
38 eT val2 = eT(0);
39
40 uword i,j;
41 for(i=0, j=1; j<N; i+=2, j+=2)
42 {
43 val1 += P.at(i,i);
44 val2 += P.at(j,j);
45 }
46
47 if(i < N)
48 {
49 val1 += P.at(i,i);
50 }
51
52 return val1 + val2;
53 }
54
55
56
57 template<typename T1>
58 arma_warn_unused
59 inline
60 typename T1::elem_type
trace(const Op<T1,op_diagmat> & X)61 trace(const Op<T1, op_diagmat>& X)
62 {
63 arma_extra_debug_sigprint();
64
65 typedef typename T1::elem_type eT;
66
67 const diagmat_proxy<T1> A(X.m);
68
69 const uword N = (std::min)(A.n_rows, A.n_cols);
70
71 eT val = eT(0);
72
73 for(uword i=0; i<N; ++i)
74 {
75 val += A[i];
76 }
77
78 return val;
79 }
80
81
82
83 //! speedup for trace(A*B); non-complex elements
84 template<typename T1, typename T2>
85 arma_warn_unused
86 inline
87 typename enable_if2< is_cx<typename T1::elem_type>::no, typename T1::elem_type>::result
trace(const Glue<T1,T2,glue_times> & X)88 trace(const Glue<T1, T2, glue_times>& X)
89 {
90 arma_extra_debug_sigprint();
91
92 typedef typename T1::elem_type eT;
93
94 const partial_unwrap<T1> tmp1(X.A);
95 const partial_unwrap<T2> tmp2(X.B);
96
97 const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
98 const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
99
100 const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
101 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
102
103 arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
104
105 if( (A.n_elem == 0) || (B.n_elem == 0) )
106 {
107 return eT(0);
108 }
109
110 const uword A_n_rows = A.n_rows;
111 const uword A_n_cols = A.n_cols;
112
113 const uword B_n_rows = B.n_rows;
114 const uword B_n_cols = B.n_cols;
115
116 eT acc = eT(0);
117
118 if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
119 {
120 const uword N = (std::min)(A_n_rows, B_n_cols);
121
122 eT acc1 = eT(0);
123 eT acc2 = eT(0);
124
125 for(uword k=0; k < N; ++k)
126 {
127 const eT* B_colptr = B.colptr(k);
128
129 // condition: A_n_cols = B_n_rows
130
131 uword j;
132
133 for(j=1; j < A_n_cols; j+=2)
134 {
135 const uword i = (j-1);
136
137 const eT tmp_i = B_colptr[i];
138 const eT tmp_j = B_colptr[j];
139
140 acc1 += A.at(k, i) * tmp_i;
141 acc2 += A.at(k, j) * tmp_j;
142 }
143
144 const uword i = (j-1);
145
146 if(i < A_n_cols)
147 {
148 acc1 += A.at(k, i) * B_colptr[i];
149 }
150 }
151
152 acc = (acc1 + acc2);
153 }
154 else
155 if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == false) )
156 {
157 const uword N = (std::min)(A_n_cols, B_n_cols);
158
159 for(uword k=0; k < N; ++k)
160 {
161 const eT* A_colptr = A.colptr(k);
162 const eT* B_colptr = B.colptr(k);
163
164 // condition: A_n_rows = B_n_rows
165 acc += op_dot::direct_dot(A_n_rows, A_colptr, B_colptr);
166 }
167 }
168 else
169 if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true ) )
170 {
171 const uword N = (std::min)(A_n_rows, B_n_rows);
172
173 for(uword k=0; k < N; ++k)
174 {
175 // condition: A_n_cols = B_n_cols
176 for(uword i=0; i < A_n_cols; ++i)
177 {
178 acc += A.at(k,i) * B.at(k,i);
179 }
180 }
181 }
182 else
183 if( (partial_unwrap<T1>::do_trans == true ) && (partial_unwrap<T2>::do_trans == true ) )
184 {
185 const uword N = (std::min)(A_n_cols, B_n_rows);
186
187 for(uword k=0; k < N; ++k)
188 {
189 const eT* A_colptr = A.colptr(k);
190
191 // condition: A_n_rows = B_n_cols
192 for(uword i=0; i < A_n_rows; ++i)
193 {
194 acc += A_colptr[i] * B.at(k,i);
195 }
196 }
197 }
198
199 return (use_alpha) ? (alpha * acc) : acc;
200 }
201
202
203
204 //! speedup for trace(A*B); complex elements
205 template<typename T1, typename T2>
206 arma_warn_unused
207 inline
208 typename enable_if2< is_cx<typename T1::elem_type>::yes, typename T1::elem_type>::result
trace(const Glue<T1,T2,glue_times> & X)209 trace(const Glue<T1, T2, glue_times>& X)
210 {
211 arma_extra_debug_sigprint();
212
213 typedef typename T1::pod_type T;
214 typedef typename T1::elem_type eT;
215
216 const partial_unwrap<T1> tmp1(X.A);
217 const partial_unwrap<T2> tmp2(X.B);
218
219 const typename partial_unwrap<T1>::stored_type& A = tmp1.M;
220 const typename partial_unwrap<T2>::stored_type& B = tmp2.M;
221
222 const bool use_alpha = partial_unwrap<T1>::do_times || partial_unwrap<T2>::do_times;
223 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
224
225 arma_debug_assert_trans_mul_size< partial_unwrap<T1>::do_trans, partial_unwrap<T2>::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
226
227 if( (A.n_elem == 0) || (B.n_elem == 0) )
228 {
229 return eT(0);
230 }
231
232 const uword A_n_rows = A.n_rows;
233 const uword A_n_cols = A.n_cols;
234
235 const uword B_n_rows = B.n_rows;
236 const uword B_n_cols = B.n_cols;
237
238 eT acc = eT(0);
239
240 if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == false) )
241 {
242 const uword N = (std::min)(A_n_rows, B_n_cols);
243
244 T acc_real = T(0);
245 T acc_imag = T(0);
246
247 for(uword k=0; k < N; ++k)
248 {
249 const eT* B_colptr = B.colptr(k);
250
251 // condition: A_n_cols = B_n_rows
252
253 for(uword i=0; i < A_n_cols; ++i)
254 {
255 // acc += A.at(k, i) * B_colptr[i];
256
257 const std::complex<T>& xx = A.at(k, i);
258 const std::complex<T>& yy = B_colptr[i];
259
260 const T a = xx.real();
261 const T b = xx.imag();
262
263 const T c = yy.real();
264 const T d = yy.imag();
265
266 acc_real += (a*c) - (b*d);
267 acc_imag += (a*d) + (b*c);
268 }
269 }
270
271 acc = std::complex<T>(acc_real, acc_imag);
272 }
273 else
274 if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == false) )
275 {
276 const uword N = (std::min)(A_n_cols, B_n_cols);
277
278 T acc_real = T(0);
279 T acc_imag = T(0);
280
281 for(uword k=0; k < N; ++k)
282 {
283 const eT* A_colptr = A.colptr(k);
284 const eT* B_colptr = B.colptr(k);
285
286 // condition: A_n_rows = B_n_rows
287
288 for(uword i=0; i < A_n_rows; ++i)
289 {
290 // acc += std::conj(A_colptr[i]) * B_colptr[i];
291
292 const std::complex<T>& xx = A_colptr[i];
293 const std::complex<T>& yy = B_colptr[i];
294
295 const T a = xx.real();
296 const T b = xx.imag();
297
298 const T c = yy.real();
299 const T d = yy.imag();
300
301 // take into account the complex conjugate of xx
302
303 acc_real += (a*c) + (b*d);
304 acc_imag += (a*d) - (b*c);
305 }
306 }
307
308 acc = std::complex<T>(acc_real, acc_imag);
309 }
310 else
311 if( (partial_unwrap<T1>::do_trans == false) && (partial_unwrap<T2>::do_trans == true) )
312 {
313 const uword N = (std::min)(A_n_rows, B_n_rows);
314
315 T acc_real = T(0);
316 T acc_imag = T(0);
317
318 for(uword k=0; k < N; ++k)
319 {
320 // condition: A_n_cols = B_n_cols
321 for(uword i=0; i < A_n_cols; ++i)
322 {
323 // acc += A.at(k,i) * std::conj(B.at(k,i));
324
325 const std::complex<T>& xx = A.at(k, i);
326 const std::complex<T>& yy = B.at(k, i);
327
328 const T a = xx.real();
329 const T b = xx.imag();
330
331 const T c = yy.real();
332 const T d = -yy.imag(); // take the conjugate
333
334 acc_real += (a*c) - (b*d);
335 acc_imag += (a*d) + (b*c);
336 }
337 }
338
339 acc = std::complex<T>(acc_real, acc_imag);
340 }
341 else
342 if( (partial_unwrap<T1>::do_trans == true) && (partial_unwrap<T2>::do_trans == true) )
343 {
344 const uword N = (std::min)(A_n_cols, B_n_rows);
345
346 T acc_real = T(0);
347 T acc_imag = T(0);
348
349 for(uword k=0; k < N; ++k)
350 {
351 const eT* A_colptr = A.colptr(k);
352
353 // condition: A_n_rows = B_n_cols
354 for(uword i=0; i < A_n_rows; ++i)
355 {
356 // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i));
357
358 const std::complex<T>& xx = A_colptr[i];
359 const std::complex<T>& yy = B.at(k, i);
360
361 const T a = xx.real();
362 const T b = -xx.imag(); // take the conjugate
363
364 const T c = yy.real();
365 const T d = -yy.imag(); // take the conjugate
366
367 acc_real += (a*c) - (b*d);
368 acc_imag += (a*d) + (b*c);
369 }
370 }
371
372 acc = std::complex<T>(acc_real, acc_imag);
373 }
374
375 return (use_alpha) ? eT(alpha * acc) : eT(acc);
376 }
377
378
379
380 //! trace of sparse object; generic version
381 template<typename T1>
382 arma_warn_unused
383 inline
384 typename T1::elem_type
trace(const SpBase<typename T1::elem_type,T1> & expr)385 trace(const SpBase<typename T1::elem_type,T1>& expr)
386 {
387 arma_extra_debug_sigprint();
388
389 typedef typename T1::elem_type eT;
390
391 const SpProxy<T1> P(expr.get_ref());
392
393 const uword N = (std::min)(P.get_n_rows(), P.get_n_cols());
394
395 eT acc = eT(0);
396
397 if( (is_SpMat<typename SpProxy<T1>::stored_type>::value) && (P.get_n_nonzero() >= 5*N) )
398 {
399 const unwrap_spmat<typename SpProxy<T1>::stored_type> U(P.Q);
400
401 const SpMat<eT>& X = U.M;
402
403 for(uword i=0; i < N; ++i)
404 {
405 acc += X.at(i,i); // use binary search
406 }
407 }
408 else
409 {
410 typename SpProxy<T1>::const_iterator_type it = P.begin();
411
412 const uword P_n_nz = P.get_n_nonzero();
413
414 for(uword i=0; i < P_n_nz; ++i)
415 {
416 if(it.row() == it.col()) { acc += (*it); }
417
418 ++it;
419 }
420 }
421
422 return acc;
423 }
424
425
426
427 //! trace of sparse object; speedup for trace(A + B)
428 template<typename T1, typename T2>
429 arma_warn_unused
430 inline
431 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_plus> & expr)432 trace(const SpGlue<T1, T2, spglue_plus>& expr)
433 {
434 arma_extra_debug_sigprint();
435
436 const unwrap_spmat<T1> UA(expr.A);
437 const unwrap_spmat<T2> UB(expr.B);
438
439 arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "addition");
440
441 return (trace(UA.M) + trace(UB.M));
442 }
443
444
445
446 //! trace of sparse object; speedup for trace(A - B)
447 template<typename T1, typename T2>
448 arma_warn_unused
449 inline
450 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_minus> & expr)451 trace(const SpGlue<T1, T2, spglue_minus>& expr)
452 {
453 arma_extra_debug_sigprint();
454
455 const unwrap_spmat<T1> UA(expr.A);
456 const unwrap_spmat<T2> UB(expr.B);
457
458 arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "subtraction");
459
460 return (trace(UA.M) - trace(UB.M));
461 }
462
463
464
465 //! trace of sparse object; speedup for trace(A % B)
466 template<typename T1, typename T2>
467 arma_warn_unused
468 inline
469 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_schur> & expr)470 trace(const SpGlue<T1, T2, spglue_schur>& expr)
471 {
472 arma_extra_debug_sigprint();
473
474 typedef typename T1::elem_type eT;
475
476 const unwrap_spmat<T1> UA(expr.A);
477 const unwrap_spmat<T2> UB(expr.B);
478
479 const SpMat<eT>& A = UA.M;
480 const SpMat<eT>& B = UB.M;
481
482 arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise multiplication");
483
484 const uword N = (std::min)(A.n_rows, A.n_cols);
485
486 eT acc = eT(0);
487
488 for(uword i=0; i<N; ++i)
489 {
490 acc += A.at(i,i) * B.at(i,i);
491 }
492
493 return acc;
494 }
495
496
497
498 //! trace of sparse object; speedup for trace(A*B)
499 template<typename T1, typename T2>
500 arma_warn_unused
501 inline
502 typename T1::elem_type
trace(const SpGlue<T1,T2,spglue_times> & expr)503 trace(const SpGlue<T1, T2, spglue_times>& expr)
504 {
505 arma_extra_debug_sigprint();
506
507 typedef typename T1::elem_type eT;
508
509 // better-than-nothing implementation
510
511 const unwrap_spmat<T1> UA(expr.A);
512 const unwrap_spmat<T2> UB(expr.B);
513
514 const SpMat<eT>& A = UA.M;
515 const SpMat<eT>& B = UB.M;
516
517 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
518
519 if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
520 {
521 return eT(0);
522 }
523
524 const uword N = (std::min)(A.n_rows, B.n_cols);
525
526 eT acc = eT(0);
527
528 // TODO: the threshold may need tuning for complex matrices
529 if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
530 {
531 for(uword k=0; k < N; ++k)
532 {
533 typename SpMat<eT>::const_col_iterator B_it = B.begin_col_no_sync(k);
534 typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
535
536 while(B_it != B_it_end)
537 {
538 const eT B_val = (*B_it);
539 const uword i = B_it.row();
540
541 acc += A.at(k,i) * B_val;
542
543 ++B_it;
544 }
545 }
546 }
547 else
548 {
549 const SpMat<eT> AB = A * B;
550
551 acc = trace(AB);
552 }
553
554 return acc;
555 }
556
557
558
559 //! trace of sparse object; speedup for trace(A.t()*B); non-complex elements
560 template<typename T1, typename T2>
561 arma_warn_unused
562 inline
563 typename enable_if2< is_cx<typename T1::elem_type>::no, typename T1::elem_type>::result
trace(const SpGlue<SpOp<T1,spop_htrans>,T2,spglue_times> & expr)564 trace(const SpGlue<SpOp<T1, spop_htrans>, T2, spglue_times>& expr)
565 {
566 arma_extra_debug_sigprint();
567
568 typedef typename T1::elem_type eT;
569
570 const unwrap_spmat<T1> UA(expr.A.m);
571 const unwrap_spmat<T2> UB(expr.B);
572
573 const SpMat<eT>& A = UA.M;
574 const SpMat<eT>& B = UB.M;
575
576 // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation
577 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
578
579 if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
580 {
581 return eT(0);
582 }
583
584 const uword N = (std::min)(A.n_cols, B.n_cols);
585
586 eT acc = eT(0);
587
588 if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
589 {
590 for(uword k=0; k < N; ++k)
591 {
592 typename SpMat<eT>::const_col_iterator B_it = B.begin_col_no_sync(k);
593 typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
594
595 while(B_it != B_it_end)
596 {
597 const eT B_val = (*B_it);
598 const uword i = B_it.row();
599
600 acc += A.at(i,k) * B_val;
601
602 ++B_it;
603 }
604 }
605 }
606 else
607 {
608 const SpMat<eT> AtB = A.t() * B;
609
610 acc = trace(AtB);
611 }
612
613 return acc;
614 }
615
616
617
618 //! trace of sparse object; speedup for trace(A.t()*B); complex elements
619 template<typename T1, typename T2>
620 arma_warn_unused
621 inline
622 typename enable_if2< is_cx<typename T1::elem_type>::yes, typename T1::elem_type>::result
trace(const SpGlue<SpOp<T1,spop_htrans>,T2,spglue_times> & expr)623 trace(const SpGlue<SpOp<T1, spop_htrans>, T2, spglue_times>& expr)
624 {
625 arma_extra_debug_sigprint();
626
627 typedef typename T1::elem_type eT;
628
629 const unwrap_spmat<T1> UA(expr.A.m);
630 const unwrap_spmat<T2> UB(expr.B);
631
632 const SpMat<eT>& A = UA.M;
633 const SpMat<eT>& B = UB.M;
634
635 // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation
636 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
637
638 if( (A.n_nonzero == 0) || (B.n_nonzero == 0) )
639 {
640 return eT(0);
641 }
642
643 const uword N = (std::min)(A.n_cols, B.n_cols);
644
645 eT acc = eT(0);
646
647 // TODO: the threshold may need tuning for complex matrices
648 if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) )
649 {
650 for(uword k=0; k < N; ++k)
651 {
652 typename SpMat<eT>::const_col_iterator B_it = B.begin_col_no_sync(k);
653 typename SpMat<eT>::const_col_iterator B_it_end = B.end_col_no_sync(k);
654
655 while(B_it != B_it_end)
656 {
657 const eT B_val = (*B_it);
658 const uword i = B_it.row();
659
660 acc += std::conj(A.at(i,k)) * B_val;
661
662 ++B_it;
663 }
664 }
665 }
666 else
667 {
668 const SpMat<eT> AtB = A.t() * B;
669
670 acc = trace(AtB);
671 }
672
673 return acc;
674 }
675
676
677
678 //! @}
679