1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup SpMat
20 //! @{
21
22
23 /**
24 * Initialize a sparse matrix with size 0x0 (empty).
25 */
26 template<typename eT>
27 inline
SpMat()28 SpMat<eT>::SpMat()
29 : n_rows(0)
30 , n_cols(0)
31 , n_elem(0)
32 , n_nonzero(0)
33 , vec_state(0)
34 , values(nullptr)
35 , row_indices(nullptr)
36 , col_ptrs(nullptr)
37 {
38 arma_extra_debug_sigprint_this(this);
39
40 init_cold(0,0);
41 }
42
43
44
45 /**
46 * Clean up the memory of a sparse matrix and destruct it.
47 */
48 template<typename eT>
49 inline
~SpMat()50 SpMat<eT>::~SpMat()
51 {
52 arma_extra_debug_sigprint_this(this);
53
54 if(values ) { memory::release(access::rw(values)); }
55 if(row_indices) { memory::release(access::rw(row_indices)); }
56 if(col_ptrs ) { memory::release(access::rw(col_ptrs)); }
57 }
58
59
60
61 /**
62 * Constructor with size given.
63 */
64 template<typename eT>
65 inline
SpMat(const uword in_rows,const uword in_cols)66 SpMat<eT>::SpMat(const uword in_rows, const uword in_cols)
67 : n_rows(0)
68 , n_cols(0)
69 , n_elem(0)
70 , n_nonzero(0)
71 , vec_state(0)
72 , values(nullptr)
73 , row_indices(nullptr)
74 , col_ptrs(nullptr)
75 {
76 arma_extra_debug_sigprint_this(this);
77
78 init_cold(in_rows, in_cols);
79 }
80
81
82
83 template<typename eT>
84 inline
SpMat(const SizeMat & s)85 SpMat<eT>::SpMat(const SizeMat& s)
86 : n_rows(0)
87 , n_cols(0)
88 , n_elem(0)
89 , n_nonzero(0)
90 , vec_state(0)
91 , values(nullptr)
92 , row_indices(nullptr)
93 , col_ptrs(nullptr)
94 {
95 arma_extra_debug_sigprint_this(this);
96
97 init_cold(s.n_rows, s.n_cols);
98 }
99
100
101
102 template<typename eT>
103 inline
SpMat(const arma_reserve_indicator &,const uword in_rows,const uword in_cols,const uword new_n_nonzero)104 SpMat<eT>::SpMat(const arma_reserve_indicator&, const uword in_rows, const uword in_cols, const uword new_n_nonzero)
105 : n_rows(0)
106 , n_cols(0)
107 , n_elem(0)
108 , n_nonzero(0)
109 , vec_state(0)
110 , values(nullptr)
111 , row_indices(nullptr)
112 , col_ptrs(nullptr)
113 {
114 arma_extra_debug_sigprint_this(this);
115
116 init_cold(in_rows, in_cols, new_n_nonzero);
117 }
118
119
120
121 template<typename eT>
122 template<typename eT2>
123 inline
SpMat(const arma_layout_indicator &,const SpMat<eT2> & x)124 SpMat<eT>::SpMat(const arma_layout_indicator&, const SpMat<eT2>& x)
125 : n_rows(0)
126 , n_cols(0)
127 , n_elem(0)
128 , n_nonzero(0)
129 , vec_state(0)
130 , values(nullptr)
131 , row_indices(nullptr)
132 , col_ptrs(nullptr)
133 {
134 arma_extra_debug_sigprint_this(this);
135
136 init_cold(x.n_rows, x.n_cols, x.n_nonzero);
137
138 if(x.n_nonzero == 0) { return; }
139
140 if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); }
141 if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); }
142
143 // NOTE: 'values' array is not initialised
144 }
145
146
147
148 /**
149 * Assemble from text.
150 */
151 template<typename eT>
152 inline
SpMat(const char * text)153 SpMat<eT>::SpMat(const char* text)
154 : n_rows(0)
155 , n_cols(0)
156 , n_elem(0)
157 , n_nonzero(0)
158 , vec_state(0)
159 , values(nullptr)
160 , row_indices(nullptr)
161 , col_ptrs(nullptr)
162 {
163 arma_extra_debug_sigprint_this(this);
164
165 init(std::string(text));
166 }
167
168
169
170 template<typename eT>
171 inline
172 SpMat<eT>&
operator =(const char * text)173 SpMat<eT>::operator=(const char* text)
174 {
175 arma_extra_debug_sigprint();
176
177 init(std::string(text));
178
179 return *this;
180 }
181
182
183
184 template<typename eT>
185 inline
SpMat(const std::string & text)186 SpMat<eT>::SpMat(const std::string& text)
187 : n_rows(0)
188 , n_cols(0)
189 , n_elem(0)
190 , n_nonzero(0)
191 , vec_state(0)
192 , values(nullptr)
193 , row_indices(nullptr)
194 , col_ptrs(nullptr)
195 {
196 arma_extra_debug_sigprint();
197
198 init(text);
199 }
200
201
202
203 template<typename eT>
204 inline
205 SpMat<eT>&
operator =(const std::string & text)206 SpMat<eT>::operator=(const std::string& text)
207 {
208 arma_extra_debug_sigprint();
209
210 init(text);
211
212 return *this;
213 }
214
215
216
217 template<typename eT>
218 inline
SpMat(const SpMat<eT> & x)219 SpMat<eT>::SpMat(const SpMat<eT>& x)
220 : n_rows(0)
221 , n_cols(0)
222 , n_elem(0)
223 , n_nonzero(0)
224 , vec_state(0)
225 , values(nullptr)
226 , row_indices(nullptr)
227 , col_ptrs(nullptr)
228 {
229 arma_extra_debug_sigprint_this(this);
230
231 init(x);
232 }
233
234
235
236 template<typename eT>
237 inline
SpMat(SpMat<eT> && in_mat)238 SpMat<eT>::SpMat(SpMat<eT>&& in_mat)
239 : n_rows(0)
240 , n_cols(0)
241 , n_elem(0)
242 , n_nonzero(0)
243 , vec_state(0)
244 , values(nullptr)
245 , row_indices(nullptr)
246 , col_ptrs(nullptr)
247 {
248 arma_extra_debug_sigprint_this(this);
249 arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat);
250
251 (*this).steal_mem(in_mat);
252 }
253
254
255
256 template<typename eT>
257 inline
258 SpMat<eT>&
operator =(SpMat<eT> && in_mat)259 SpMat<eT>::operator=(SpMat<eT>&& in_mat)
260 {
261 arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat);
262
263 (*this).steal_mem(in_mat);
264
265 return *this;
266 }
267
268
269
270 template<typename eT>
271 inline
SpMat(const MapMat<eT> & x)272 SpMat<eT>::SpMat(const MapMat<eT>& x)
273 : n_rows(0)
274 , n_cols(0)
275 , n_elem(0)
276 , n_nonzero(0)
277 , vec_state(0)
278 , values(nullptr)
279 , row_indices(nullptr)
280 , col_ptrs(nullptr)
281 {
282 arma_extra_debug_sigprint_this(this);
283
284 init(x);
285 }
286
287
288
289 template<typename eT>
290 inline
291 SpMat<eT>&
operator =(const MapMat<eT> & x)292 SpMat<eT>::operator=(const MapMat<eT>& x)
293 {
294 arma_extra_debug_sigprint();
295
296 init(x);
297
298 return *this;
299 }
300
301
302
303 //! Insert a large number of values at once.
304 //! locations.row[0] should be row indices, locations.row[1] should be column indices,
305 //! and values should be the corresponding values.
306 //! If sort_locations is false, then it is assumed that the locations and values
307 //! are already sorted in column-major ordering.
308 template<typename eT>
309 template<typename T1, typename T2>
310 inline
SpMat(const Base<uword,T1> & locations_expr,const Base<eT,T2> & vals_expr,const bool sort_locations)311 SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const bool sort_locations)
312 : n_rows(0)
313 , n_cols(0)
314 , n_elem(0)
315 , n_nonzero(0)
316 , vec_state(0)
317 , values(nullptr)
318 , row_indices(nullptr)
319 , col_ptrs(nullptr)
320 {
321 arma_extra_debug_sigprint_this(this);
322
323 const unwrap<T1> locs_tmp( locations_expr.get_ref() );
324 const unwrap<T2> vals_tmp( vals_expr.get_ref() );
325
326 const Mat<uword>& locs = locs_tmp.M;
327 const Mat<eT>& vals = vals_tmp.M;
328
329 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" );
330 arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" );
331 arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" );
332
333 // If there are no elements in the list, max() will fail.
334 if(locs.n_cols == 0) { init_cold(0, 0); return; }
335
336 // Automatically determine size before pruning zeros.
337 uvec bounds = arma::max(locs, 1);
338 init_cold(bounds[0] + 1, bounds[1] + 1);
339
340 // Ensure that there are no zeros
341 const uword N_old = vals.n_elem;
342 uword N_new = 0;
343
344 for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); }
345
346 if(N_new != N_old)
347 {
348 Col<eT> filtered_vals( N_new, arma_nozeros_indicator());
349 Mat<uword> filtered_locs(2, N_new, arma_nozeros_indicator());
350
351 uword index = 0;
352 for(uword i = 0; i < N_old; ++i)
353 {
354 if(vals[i] != eT(0))
355 {
356 filtered_vals[index] = vals[i];
357
358 filtered_locs.at(0, index) = locs.at(0, i);
359 filtered_locs.at(1, index) = locs.at(1, i);
360
361 ++index;
362 }
363 }
364
365 init_batch_std(filtered_locs, filtered_vals, sort_locations);
366 }
367 else
368 {
369 init_batch_std(locs, vals, sort_locations);
370 }
371 }
372
373
374
375 //! Insert a large number of values at once.
376 //! locations.row[0] should be row indices, locations.row[1] should be column indices,
377 //! and values should be the corresponding values.
378 //! If sort_locations is false, then it is assumed that the locations and values
379 //! are already sorted in column-major ordering.
380 //! In this constructor the size is explicitly given.
381 template<typename eT>
382 template<typename T1, typename T2>
383 inline
SpMat(const Base<uword,T1> & locations_expr,const Base<eT,T2> & vals_expr,const uword in_n_rows,const uword in_n_cols,const bool sort_locations,const bool check_for_zeros)384 SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations, const bool check_for_zeros)
385 : n_rows(0)
386 , n_cols(0)
387 , n_elem(0)
388 , n_nonzero(0)
389 , vec_state(0)
390 , values(nullptr)
391 , row_indices(nullptr)
392 , col_ptrs(nullptr)
393 {
394 arma_extra_debug_sigprint_this(this);
395
396 const unwrap<T1> locs_tmp( locations_expr.get_ref() );
397 const unwrap<T2> vals_tmp( vals_expr.get_ref() );
398
399 const Mat<uword>& locs = locs_tmp.M;
400 const Mat<eT>& vals = vals_tmp.M;
401
402 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" );
403 arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" );
404 arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" );
405
406 init_cold(in_n_rows, in_n_cols);
407
408 // Ensure that there are no zeros, unless the user asked not to.
409 if(check_for_zeros)
410 {
411 const uword N_old = vals.n_elem;
412 uword N_new = 0;
413
414 for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); }
415
416 if(N_new != N_old)
417 {
418 Col<eT> filtered_vals( N_new, arma_nozeros_indicator());
419 Mat<uword> filtered_locs(2, N_new, arma_nozeros_indicator());
420
421 uword index = 0;
422 for(uword i = 0; i < N_old; ++i)
423 {
424 if(vals[i] != eT(0))
425 {
426 filtered_vals[index] = vals[i];
427
428 filtered_locs.at(0, index) = locs.at(0, i);
429 filtered_locs.at(1, index) = locs.at(1, i);
430
431 ++index;
432 }
433 }
434
435 init_batch_std(filtered_locs, filtered_vals, sort_locations);
436 }
437 else
438 {
439 init_batch_std(locs, vals, sort_locations);
440 }
441 }
442 else
443 {
444 init_batch_std(locs, vals, sort_locations);
445 }
446 }
447
448
449
450 template<typename eT>
451 template<typename T1, typename T2>
452 inline
SpMat(const bool add_values,const Base<uword,T1> & locations_expr,const Base<eT,T2> & vals_expr,const uword in_n_rows,const uword in_n_cols,const bool sort_locations,const bool check_for_zeros)453 SpMat<eT>::SpMat(const bool add_values, const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations, const bool check_for_zeros)
454 : n_rows(0)
455 , n_cols(0)
456 , n_elem(0)
457 , n_nonzero(0)
458 , vec_state(0)
459 , values(nullptr)
460 , row_indices(nullptr)
461 , col_ptrs(nullptr)
462 {
463 arma_extra_debug_sigprint_this(this);
464
465 const unwrap<T1> locs_tmp( locations_expr.get_ref() );
466 const unwrap<T2> vals_tmp( vals_expr.get_ref() );
467
468 const Mat<uword>& locs = locs_tmp.M;
469 const Mat<eT>& vals = vals_tmp.M;
470
471 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" );
472 arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" );
473 arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" );
474
475 init_cold(in_n_rows, in_n_cols);
476
477 // Ensure that there are no zeros, unless the user asked not to.
478 if(check_for_zeros)
479 {
480 const uword N_old = vals.n_elem;
481 uword N_new = 0;
482
483 for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); }
484
485 if(N_new != N_old)
486 {
487 Col<eT> filtered_vals( N_new, arma_nozeros_indicator());
488 Mat<uword> filtered_locs(2, N_new, arma_nozeros_indicator());
489
490 uword index = 0;
491 for(uword i = 0; i < N_old; ++i)
492 {
493 if(vals[i] != eT(0))
494 {
495 filtered_vals[index] = vals[i];
496
497 filtered_locs.at(0, index) = locs.at(0, i);
498 filtered_locs.at(1, index) = locs.at(1, i);
499
500 ++index;
501 }
502 }
503
504 add_values ? init_batch_add(filtered_locs, filtered_vals, sort_locations) : init_batch_std(filtered_locs, filtered_vals, sort_locations);
505 }
506 else
507 {
508 add_values ? init_batch_add(locs, vals, sort_locations) : init_batch_std(locs, vals, sort_locations);
509 }
510 }
511 else
512 {
513 add_values ? init_batch_add(locs, vals, sort_locations) : init_batch_std(locs, vals, sort_locations);
514 }
515 }
516
517
518
519 //! Insert a large number of values at once.
520 //! Per CSC format, rowind_expr should be row indices,
521 //! colptr_expr should column ptr indices locations,
522 //! and values should be the corresponding values.
523 //! In this constructor the size is explicitly given.
524 //! Values are assumed to be sorted, and the size
525 //! information is trusted
526 template<typename eT>
527 template<typename T1, typename T2, typename T3>
528 inline
SpMat(const Base<uword,T1> & rowind_expr,const Base<uword,T2> & colptr_expr,const Base<eT,T3> & values_expr,const uword in_n_rows,const uword in_n_cols)529 SpMat<eT>::SpMat
530 (
531 const Base<uword,T1>& rowind_expr,
532 const Base<uword,T2>& colptr_expr,
533 const Base<eT, T3>& values_expr,
534 const uword in_n_rows,
535 const uword in_n_cols
536 )
537 : n_rows(0)
538 , n_cols(0)
539 , n_elem(0)
540 , n_nonzero(0)
541 , vec_state(0)
542 , values(nullptr)
543 , row_indices(nullptr)
544 , col_ptrs(nullptr)
545 {
546 arma_extra_debug_sigprint_this(this);
547
548 const unwrap<T1> rowind_tmp( rowind_expr.get_ref() );
549 const unwrap<T2> colptr_tmp( colptr_expr.get_ref() );
550 const unwrap<T3> vals_tmp( values_expr.get_ref() );
551
552 const Mat<uword>& rowind = rowind_tmp.M;
553 const Mat<uword>& colptr = colptr_tmp.M;
554 const Mat<eT>& vals = vals_tmp.M;
555
556 arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object must be a vector" );
557 arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object must be a vector" );
558 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" );
559
560 // Resize to correct number of elements (this also sets n_nonzero)
561 init_cold(in_n_rows, in_n_cols, vals.n_elem);
562
563 arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" );
564 arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" );
565
566 // copy supplied values into sparse matrix -- not checked for consistency
567 arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem );
568 arrayops::copy(access::rwp(col_ptrs), colptr.memptr(), colptr.n_elem );
569 arrayops::copy(access::rwp(values), vals.memptr(), vals.n_elem );
570
571 // important: set the sentinel as well
572 access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
573
574 // make sure no zeros are stored
575 remove_zeros();
576 }
577
578
579
580 template<typename eT>
581 inline
582 SpMat<eT>&
operator =(const eT val)583 SpMat<eT>::operator=(const eT val)
584 {
585 arma_extra_debug_sigprint();
586
587 if(val != eT(0))
588 {
589 // Resize to 1x1 then set that to the right value.
590 init(1, 1, 1); // Sets col_ptrs to 0.
591
592 // Manually set element.
593 access::rw(values[0]) = val;
594 access::rw(row_indices[0]) = 0;
595 access::rw(col_ptrs[1]) = 1;
596 }
597 else
598 {
599 init(0, 0);
600 }
601
602 return *this;
603 }
604
605
606
607 template<typename eT>
608 inline
609 SpMat<eT>&
operator *=(const eT val)610 SpMat<eT>::operator*=(const eT val)
611 {
612 arma_extra_debug_sigprint();
613
614 if(val != eT(0))
615 {
616 sync_csc();
617 invalidate_cache();
618
619 const uword n_nz = n_nonzero;
620
621 eT* vals = access::rwp(values);
622
623 bool has_zero = false;
624
625 for(uword i=0; i<n_nz; ++i)
626 {
627 eT& vals_i = vals[i];
628
629 vals_i *= val;
630
631 if(vals_i == eT(0)) { has_zero = true; }
632 }
633
634 if(has_zero) { remove_zeros(); }
635 }
636 else
637 {
638 (*this).zeros();
639 }
640
641 return *this;
642 }
643
644
645
646 template<typename eT>
647 inline
648 SpMat<eT>&
operator /=(const eT val)649 SpMat<eT>::operator/=(const eT val)
650 {
651 arma_extra_debug_sigprint();
652
653 arma_debug_check( (val == eT(0)), "element-wise division: division by zero" );
654
655 sync_csc();
656 invalidate_cache();
657
658 const uword n_nz = n_nonzero;
659
660 eT* vals = access::rwp(values);
661
662 bool has_zero = false;
663
664 for(uword i=0; i<n_nz; ++i)
665 {
666 eT& vals_i = vals[i];
667
668 vals_i /= val;
669
670 if(vals_i == eT(0)) { has_zero = true; }
671 }
672
673 if(has_zero) { remove_zeros(); }
674
675 return *this;
676 }
677
678
679
680 template<typename eT>
681 inline
682 SpMat<eT>&
operator =(const SpMat<eT> & x)683 SpMat<eT>::operator=(const SpMat<eT>& x)
684 {
685 arma_extra_debug_sigprint();
686
687 init(x);
688
689 return *this;
690 }
691
692
693
694 template<typename eT>
695 inline
696 SpMat<eT>&
operator +=(const SpMat<eT> & x)697 SpMat<eT>::operator+=(const SpMat<eT>& x)
698 {
699 arma_extra_debug_sigprint();
700
701 sync_csc();
702
703 SpMat<eT> out = (*this) + x;
704
705 steal_mem(out);
706
707 return *this;
708 }
709
710
711
712 template<typename eT>
713 inline
714 SpMat<eT>&
operator -=(const SpMat<eT> & x)715 SpMat<eT>::operator-=(const SpMat<eT>& x)
716 {
717 arma_extra_debug_sigprint();
718
719 sync_csc();
720
721 SpMat<eT> out = (*this) - x;
722
723 steal_mem(out);
724
725 return *this;
726 }
727
728
729
730 template<typename eT>
731 inline
732 SpMat<eT>&
operator *=(const SpMat<eT> & y)733 SpMat<eT>::operator*=(const SpMat<eT>& y)
734 {
735 arma_extra_debug_sigprint();
736
737 sync_csc();
738
739 SpMat<eT> z = (*this) * y;
740
741 steal_mem(z);
742
743 return *this;
744 }
745
746
747
748 // This is in-place element-wise matrix multiplication.
749 template<typename eT>
750 inline
751 SpMat<eT>&
operator %=(const SpMat<eT> & y)752 SpMat<eT>::operator%=(const SpMat<eT>& y)
753 {
754 arma_extra_debug_sigprint();
755
756 sync_csc();
757
758 SpMat<eT> z = (*this) % y;
759
760 steal_mem(z);
761
762 return *this;
763 }
764
765
766
767 template<typename eT>
768 inline
769 SpMat<eT>&
operator /=(const SpMat<eT> & x)770 SpMat<eT>::operator/=(const SpMat<eT>& x)
771 {
772 arma_extra_debug_sigprint();
773
774 // NOTE: use of this function is not advised; it is implemented only for completeness
775
776 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
777
778 for(uword c = 0; c < n_cols; ++c)
779 for(uword r = 0; r < n_rows; ++r)
780 {
781 at(r, c) /= x.at(r, c);
782 }
783
784 return *this;
785 }
786
787
788
789 template<typename eT>
790 template<typename T1, typename op_type>
791 inline
SpMat(const SpToDOp<T1,op_type> & expr)792 SpMat<eT>::SpMat(const SpToDOp<T1, op_type>& expr)
793 : n_rows(0)
794 , n_cols(0)
795 , n_elem(0)
796 , n_nonzero(0)
797 , vec_state(0)
798 , values(nullptr)
799 , row_indices(nullptr)
800 , col_ptrs(nullptr)
801 {
802 arma_extra_debug_sigprint_this(this);
803
804 typedef typename T1::elem_type T;
805
806 // Make sure the type is compatible.
807 arma_type_check(( is_same_type< eT, T >::no ));
808
809 op_type::apply(*this, expr);
810 }
811
812
813
814 // Construct a complex matrix out of two non-complex matrices
815 template<typename eT>
816 template<typename T1, typename T2>
817 inline
SpMat(const SpBase<typename SpMat<eT>::pod_type,T1> & A,const SpBase<typename SpMat<eT>::pod_type,T2> & B)818 SpMat<eT>::SpMat
819 (
820 const SpBase<typename SpMat<eT>::pod_type, T1>& A,
821 const SpBase<typename SpMat<eT>::pod_type, T2>& B
822 )
823 : n_rows(0)
824 , n_cols(0)
825 , n_elem(0)
826 , n_nonzero(0)
827 , vec_state(0)
828 , values(nullptr)
829 , row_indices(nullptr)
830 , col_ptrs(nullptr)
831 {
832 arma_extra_debug_sigprint();
833
834 typedef typename T1::elem_type T;
835
836 // Make sure eT is complex and T is not (compile-time check).
837 arma_type_check(( is_cx<eT>::no ));
838 arma_type_check(( is_cx< T>::yes ));
839
840 // Compile-time abort if types are not compatible.
841 arma_type_check(( is_same_type< std::complex<T>, eT >::no ));
842
843 const unwrap_spmat<T1> tmp1(A.get_ref());
844 const unwrap_spmat<T2> tmp2(B.get_ref());
845
846 const SpMat<T>& X = tmp1.M;
847 const SpMat<T>& Y = tmp2.M;
848
849 arma_debug_assert_same_size(X.n_rows, X.n_cols, Y.n_rows, Y.n_cols, "SpMat()");
850
851 const uword l_n_rows = X.n_rows;
852 const uword l_n_cols = X.n_cols;
853
854 // Set size of matrix correctly.
855 init_cold(l_n_rows, l_n_cols, n_unique(X, Y, op_n_unique_count()));
856
857 // Now on a second iteration, fill it.
858 typename SpMat<T>::const_iterator x_it = X.begin();
859 typename SpMat<T>::const_iterator x_end = X.end();
860
861 typename SpMat<T>::const_iterator y_it = Y.begin();
862 typename SpMat<T>::const_iterator y_end = Y.end();
863
864 uword cur_pos = 0;
865
866 while((x_it != x_end) || (y_it != y_end))
867 {
868 if(x_it == y_it) // if we are at the same place
869 {
870 access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, (T) *y_it);
871 access::rw(row_indices[cur_pos]) = x_it.row();
872 ++access::rw(col_ptrs[x_it.col() + 1]);
873
874 ++x_it;
875 ++y_it;
876 }
877 else
878 {
879 if((x_it.col() < y_it.col()) || ((x_it.col() == y_it.col()) && (x_it.row() < y_it.row()))) // if y is closer to the end
880 {
881 access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, T(0));
882 access::rw(row_indices[cur_pos]) = x_it.row();
883 ++access::rw(col_ptrs[x_it.col() + 1]);
884
885 ++x_it;
886 }
887 else // x is closer to the end
888 {
889 access::rw(values[cur_pos]) = std::complex<T>(T(0), (T) *y_it);
890 access::rw(row_indices[cur_pos]) = y_it.row();
891 ++access::rw(col_ptrs[y_it.col() + 1]);
892
893 ++y_it;
894 }
895 }
896
897 ++cur_pos;
898 }
899
900 // Now fix the column pointers; they are supposed to be a sum.
901 for(uword c = 1; c <= n_cols; ++c)
902 {
903 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
904 }
905
906 }
907
908
909
910 template<typename eT>
911 template<typename T1>
912 inline
SpMat(const Base<eT,T1> & x)913 SpMat<eT>::SpMat(const Base<eT, T1>& x)
914 : n_rows(0)
915 , n_cols(0)
916 , n_elem(0)
917 , n_nonzero(0)
918 , vec_state(0)
919 , values(nullptr)
920 , row_indices(nullptr)
921 , col_ptrs(nullptr)
922 {
923 arma_extra_debug_sigprint_this(this);
924
925 (*this).operator=(x);
926 }
927
928
929
930 template<typename eT>
931 template<typename T1>
932 inline
933 SpMat<eT>&
operator =(const Base<eT,T1> & expr)934 SpMat<eT>::operator=(const Base<eT, T1>& expr)
935 {
936 arma_extra_debug_sigprint();
937
938 if(is_same_type< T1, Gen<Mat<eT>, gen_zeros> >::yes)
939 {
940 const Proxy<T1> P(expr.get_ref());
941
942 (*this).zeros( P.get_n_rows(), P.get_n_cols() );
943
944 return *this;
945 }
946
947 if(is_same_type< T1, Gen<Mat<eT>, gen_eye> >::yes)
948 {
949 const Proxy<T1> P(expr.get_ref());
950
951 (*this).eye( P.get_n_rows(), P.get_n_cols() );
952
953 return *this;
954 }
955
956 const quasi_unwrap<T1> tmp(expr.get_ref());
957 const Mat<eT>& x = tmp.M;
958
959 const uword x_n_rows = x.n_rows;
960 const uword x_n_cols = x.n_cols;
961 const uword x_n_elem = x.n_elem;
962
963 // Count number of nonzero elements in base object.
964 uword n = 0;
965
966 const eT* x_mem = x.memptr();
967
968 for(uword i=0; i < x_n_elem; ++i) { n += (x_mem[i] != eT(0)) ? uword(1) : uword(0); }
969
970 init(x_n_rows, x_n_cols, n);
971
972 if(n == 0) { return *this; }
973
974 // Now the memory is resized correctly; set nonzero elements.
975 n = 0;
976 for(uword j = 0; j < x_n_cols; ++j)
977 for(uword i = 0; i < x_n_rows; ++i)
978 {
979 const eT val = (*x_mem); x_mem++;
980
981 if(val != eT(0))
982 {
983 access::rw(values[n]) = val;
984 access::rw(row_indices[n]) = i;
985 access::rw(col_ptrs[j + 1])++;
986 ++n;
987 }
988 }
989
990 // Sum column counts to be column pointers.
991 for(uword c = 1; c <= n_cols; ++c)
992 {
993 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
994 }
995
996 return *this;
997 }
998
999
1000
1001 template<typename eT>
1002 template<typename T1>
1003 inline
1004 SpMat<eT>&
operator +=(const Base<eT,T1> & x)1005 SpMat<eT>::operator+=(const Base<eT, T1>& x)
1006 {
1007 arma_extra_debug_sigprint();
1008
1009 sync_csc();
1010
1011 return (*this).operator=( (*this) + x.get_ref() );
1012 }
1013
1014
1015
1016 template<typename eT>
1017 template<typename T1>
1018 inline
1019 SpMat<eT>&
operator -=(const Base<eT,T1> & x)1020 SpMat<eT>::operator-=(const Base<eT, T1>& x)
1021 {
1022 arma_extra_debug_sigprint();
1023
1024 sync_csc();
1025
1026 return (*this).operator=( (*this) - x.get_ref() );
1027 }
1028
1029
1030
1031 template<typename eT>
1032 template<typename T1>
1033 inline
1034 SpMat<eT>&
operator *=(const Base<eT,T1> & y)1035 SpMat<eT>::operator*=(const Base<eT, T1>& y)
1036 {
1037 arma_extra_debug_sigprint();
1038
1039 sync_csc();
1040
1041 const Proxy<T1> p(y.get_ref());
1042
1043 arma_debug_assert_mul_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "matrix multiplication");
1044
1045 // We assume the matrix structure is such that we will end up with a sparse
1046 // matrix. Assuming that every entry in the dense matrix is nonzero (which is
1047 // a fairly valid assumption), each row with any nonzero elements in it (in this
1048 // matrix) implies an entire nonzero column. Therefore, we iterate over all
1049 // the row_indices and count the number of rows with any elements in them
1050 // (using the quasi-linked-list idea from SYMBMM -- see spglue_times_meat.hpp).
1051 podarray<uword> index(n_rows);
1052 index.fill(n_rows); // Fill with invalid links.
1053
1054 uword last_index = n_rows + 1;
1055 for(uword i = 0; i < n_nonzero; ++i)
1056 {
1057 if(index[row_indices[i]] == n_rows)
1058 {
1059 index[row_indices[i]] = last_index;
1060 last_index = row_indices[i];
1061 }
1062 }
1063
1064 // Now count the number of rows which have nonzero elements.
1065 uword nonzero_rows = 0;
1066 while(last_index != n_rows + 1)
1067 {
1068 ++nonzero_rows;
1069 last_index = index[last_index];
1070 }
1071
1072 SpMat<eT> z(arma_reserve_indicator(), n_rows, p.get_n_cols(), (nonzero_rows * p.get_n_cols())); // upper bound on size
1073
1074 // Now we have to fill all the elements using a modification of the NUMBMM algorithm.
1075 uword cur_pos = 0;
1076
1077 podarray<eT> partial_sums(n_rows);
1078 partial_sums.zeros();
1079
1080 for(uword lcol = 0; lcol < n_cols; ++lcol)
1081 {
1082 const_iterator it = begin();
1083 const_iterator it_end = end();
1084
1085 while(it != it_end)
1086 {
1087 const eT value = (*it);
1088
1089 partial_sums[it.row()] += (value * p.at(it.col(), lcol));
1090
1091 ++it;
1092 }
1093
1094 // Now add all partial sums to the matrix.
1095 for(uword i = 0; i < n_rows; ++i)
1096 {
1097 if(partial_sums[i] != eT(0))
1098 {
1099 access::rw(z.values[cur_pos]) = partial_sums[i];
1100 access::rw(z.row_indices[cur_pos]) = i;
1101 ++access::rw(z.col_ptrs[lcol + 1]);
1102 //printf("colptr %d now %d\n", lcol + 1, z.col_ptrs[lcol + 1]);
1103 ++cur_pos;
1104 partial_sums[i] = 0; // Would it be faster to do this in batch later?
1105 }
1106 }
1107 }
1108
1109 // Now fix the column pointers.
1110 for(uword c = 1; c <= z.n_cols; ++c)
1111 {
1112 access::rw(z.col_ptrs[c]) += z.col_ptrs[c - 1];
1113 }
1114
1115 // Resize to final correct size.
1116 z.mem_resize(z.col_ptrs[z.n_cols]);
1117
1118 // Now take the memory of the temporary matrix.
1119 steal_mem(z);
1120
1121 return *this;
1122 }
1123
1124
1125
1126 // NOTE: use of this function is not advised; it is implemented only for completeness
1127 template<typename eT>
1128 template<typename T1>
1129 inline
1130 SpMat<eT>&
operator /=(const Base<eT,T1> & x)1131 SpMat<eT>::operator/=(const Base<eT, T1>& x)
1132 {
1133 arma_extra_debug_sigprint();
1134
1135 sync_csc();
1136
1137 SpMat<eT> tmp = (*this) / x.get_ref();
1138
1139 steal_mem(tmp);
1140
1141 return *this;
1142 }
1143
1144
1145
1146 template<typename eT>
1147 template<typename T1>
1148 inline
1149 SpMat<eT>&
operator %=(const Base<eT,T1> & x)1150 SpMat<eT>::operator%=(const Base<eT, T1>& x)
1151 {
1152 arma_extra_debug_sigprint();
1153
1154 SpMat<eT> tmp;
1155
1156 // Just call the other order (these operations are commutative)
1157 // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order
1158 spglue_schur_misc::dense_schur_sparse(tmp, x.get_ref(), (*this));
1159
1160 steal_mem(tmp);
1161
1162 return *this;
1163 }
1164
1165
1166
1167 template<typename eT>
1168 template<typename T1>
1169 inline
SpMat(const Op<T1,op_diagmat> & expr)1170 SpMat<eT>::SpMat(const Op<T1, op_diagmat>& expr)
1171 : n_rows(0)
1172 , n_cols(0)
1173 , n_elem(0)
1174 , n_nonzero(0)
1175 , vec_state(0)
1176 , values(nullptr)
1177 , row_indices(nullptr)
1178 , col_ptrs(nullptr)
1179 {
1180 arma_extra_debug_sigprint_this(this);
1181
1182 (*this).operator=(expr);
1183 }
1184
1185
1186
1187 template<typename eT>
1188 template<typename T1>
1189 inline
1190 SpMat<eT>&
operator =(const Op<T1,op_diagmat> & expr)1191 SpMat<eT>::operator=(const Op<T1, op_diagmat>& expr)
1192 {
1193 arma_extra_debug_sigprint();
1194
1195 const diagmat_proxy<T1> P(expr.m);
1196
1197 const uword max_n_nonzero = (std::min)(P.n_rows, P.n_cols);
1198
1199 // resize memory to upper bound
1200 init(P.n_rows, P.n_cols, max_n_nonzero);
1201
1202 uword count = 0;
1203
1204 for(uword i=0; i < max_n_nonzero; ++i)
1205 {
1206 const eT val = P[i];
1207
1208 if(val != eT(0))
1209 {
1210 access::rw(values[count]) = val;
1211 access::rw(row_indices[count]) = i;
1212 access::rw(col_ptrs[i + 1])++;
1213 ++count;
1214 }
1215 }
1216
1217 // fix column pointers to be cumulative
1218 for(uword i = 1; i < n_cols + 1; ++i)
1219 {
1220 access::rw(col_ptrs[i]) += col_ptrs[i - 1];
1221 }
1222
1223 // quick resize without reallocating memory and copying data
1224 access::rw( n_nonzero) = count;
1225 access::rw( values[count]) = eT(0);
1226 access::rw(row_indices[count]) = uword(0);
1227
1228 return *this;
1229 }
1230
1231
1232
1233 template<typename eT>
1234 template<typename T1>
1235 inline
1236 SpMat<eT>&
operator +=(const Op<T1,op_diagmat> & expr)1237 SpMat<eT>::operator+=(const Op<T1, op_diagmat>& expr)
1238 {
1239 arma_extra_debug_sigprint();
1240
1241 const SpMat<eT> tmp(expr);
1242
1243 return (*this).operator+=(tmp);
1244 }
1245
1246
1247
1248 template<typename eT>
1249 template<typename T1>
1250 inline
1251 SpMat<eT>&
operator -=(const Op<T1,op_diagmat> & expr)1252 SpMat<eT>::operator-=(const Op<T1, op_diagmat>& expr)
1253 {
1254 arma_extra_debug_sigprint();
1255
1256 const SpMat<eT> tmp(expr);
1257
1258 return (*this).operator-=(tmp);
1259 }
1260
1261
1262
1263 template<typename eT>
1264 template<typename T1>
1265 inline
1266 SpMat<eT>&
operator *=(const Op<T1,op_diagmat> & expr)1267 SpMat<eT>::operator*=(const Op<T1, op_diagmat>& expr)
1268 {
1269 arma_extra_debug_sigprint();
1270
1271 const SpMat<eT> tmp(expr);
1272
1273 return (*this).operator*=(tmp);
1274 }
1275
1276
1277
1278 template<typename eT>
1279 template<typename T1>
1280 inline
1281 SpMat<eT>&
operator /=(const Op<T1,op_diagmat> & expr)1282 SpMat<eT>::operator/=(const Op<T1, op_diagmat>& expr)
1283 {
1284 arma_extra_debug_sigprint();
1285
1286 const SpMat<eT> tmp(expr);
1287
1288 return (*this).operator/=(tmp);
1289 }
1290
1291
1292
1293 template<typename eT>
1294 template<typename T1>
1295 inline
1296 SpMat<eT>&
operator %=(const Op<T1,op_diagmat> & expr)1297 SpMat<eT>::operator%=(const Op<T1, op_diagmat>& expr)
1298 {
1299 arma_extra_debug_sigprint();
1300
1301 const SpMat<eT> tmp(expr);
1302
1303 return (*this).operator%=(tmp);
1304 }
1305
1306
1307
1308 /**
1309 * Functions on subviews.
1310 */
1311 template<typename eT>
1312 inline
SpMat(const SpSubview<eT> & X)1313 SpMat<eT>::SpMat(const SpSubview<eT>& X)
1314 : n_rows(0)
1315 , n_cols(0)
1316 , n_elem(0)
1317 , n_nonzero(0)
1318 , vec_state(0)
1319 , values(nullptr)
1320 , row_indices(nullptr)
1321 , col_ptrs(nullptr)
1322 {
1323 arma_extra_debug_sigprint_this(this);
1324
1325 (*this).operator=(X);
1326 }
1327
1328
1329
1330 template<typename eT>
1331 inline
1332 SpMat<eT>&
operator =(const SpSubview<eT> & X)1333 SpMat<eT>::operator=(const SpSubview<eT>& X)
1334 {
1335 arma_extra_debug_sigprint();
1336
1337 if(X.n_nonzero == 0) { zeros(X.n_rows, X.n_cols); return *this; }
1338
1339 X.m.sync_csc();
1340
1341 const bool alias = (this == &(X.m));
1342
1343 if(alias)
1344 {
1345 SpMat<eT> tmp(X);
1346
1347 steal_mem(tmp);
1348 }
1349 else
1350 {
1351 init(X.n_rows, X.n_cols, X.n_nonzero);
1352
1353 if(X.n_rows == X.m.n_rows)
1354 {
1355 const uword sv_col_start = X.aux_col1;
1356 const uword sv_col_end = X.aux_col1 + X.n_cols - 1;
1357
1358 typename SpMat<eT>::const_col_iterator m_it = X.m.begin_col(sv_col_start);
1359 typename SpMat<eT>::const_col_iterator m_it_end = X.m.end_col(sv_col_end);
1360
1361 uword count = 0;
1362
1363 while(m_it != m_it_end)
1364 {
1365 const uword m_it_col_adjusted = m_it.col() - sv_col_start;
1366
1367 access::rw(row_indices[count]) = m_it.row();
1368 access::rw(values[count]) = (*m_it);
1369 ++access::rw(col_ptrs[m_it_col_adjusted + 1]);
1370
1371 count++;
1372
1373 ++m_it;
1374 }
1375 }
1376 else
1377 {
1378 typename SpSubview<eT>::const_iterator it = X.begin();
1379 typename SpSubview<eT>::const_iterator it_end = X.end();
1380
1381 while(it != it_end)
1382 {
1383 const uword it_pos = it.pos();
1384
1385 access::rw(row_indices[it_pos]) = it.row();
1386 access::rw(values[it_pos]) = (*it);
1387 ++access::rw(col_ptrs[it.col() + 1]);
1388 ++it;
1389 }
1390 }
1391
1392 // Now sum column pointers.
1393 for(uword c = 1; c <= n_cols; ++c)
1394 {
1395 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
1396 }
1397 }
1398
1399 return *this;
1400 }
1401
1402
1403
1404 template<typename eT>
1405 inline
1406 SpMat<eT>&
operator +=(const SpSubview<eT> & X)1407 SpMat<eT>::operator+=(const SpSubview<eT>& X)
1408 {
1409 arma_extra_debug_sigprint();
1410
1411 sync_csc();
1412
1413 SpMat<eT> tmp = (*this) + X;
1414
1415 steal_mem(tmp);
1416
1417 return *this;
1418 }
1419
1420
1421
1422 template<typename eT>
1423 inline
1424 SpMat<eT>&
operator -=(const SpSubview<eT> & X)1425 SpMat<eT>::operator-=(const SpSubview<eT>& X)
1426 {
1427 arma_extra_debug_sigprint();
1428
1429 sync_csc();
1430
1431 SpMat<eT> tmp = (*this) - X;
1432
1433 steal_mem(tmp);
1434
1435 return *this;
1436 }
1437
1438
1439
1440 template<typename eT>
1441 inline
1442 SpMat<eT>&
operator *=(const SpSubview<eT> & y)1443 SpMat<eT>::operator*=(const SpSubview<eT>& y)
1444 {
1445 arma_extra_debug_sigprint();
1446
1447 sync_csc();
1448
1449 SpMat<eT> z = (*this) * y;
1450
1451 steal_mem(z);
1452
1453 return *this;
1454 }
1455
1456
1457
1458 template<typename eT>
1459 inline
1460 SpMat<eT>&
operator %=(const SpSubview<eT> & x)1461 SpMat<eT>::operator%=(const SpSubview<eT>& x)
1462 {
1463 arma_extra_debug_sigprint();
1464
1465 sync_csc();
1466
1467 SpMat<eT> tmp = (*this) % x;
1468
1469 steal_mem(tmp);
1470
1471 return *this;
1472 }
1473
1474
1475
1476 template<typename eT>
1477 inline
1478 SpMat<eT>&
operator /=(const SpSubview<eT> & x)1479 SpMat<eT>::operator/=(const SpSubview<eT>& x)
1480 {
1481 arma_extra_debug_sigprint();
1482
1483 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
1484
1485 // There is no pretty way to do this.
1486 for(uword elem = 0; elem < n_elem; elem++)
1487 {
1488 at(elem) /= x(elem);
1489 }
1490
1491 return *this;
1492 }
1493
1494
1495
1496 template<typename eT>
1497 template<typename T1>
1498 inline
SpMat(const SpSubview_col_list<eT,T1> & X)1499 SpMat<eT>::SpMat(const SpSubview_col_list<eT,T1>& X)
1500 : n_rows(0)
1501 , n_cols(0)
1502 , n_elem(0)
1503 , n_nonzero(0)
1504 , vec_state(0)
1505 , values(nullptr)
1506 , row_indices(nullptr)
1507 , col_ptrs(nullptr)
1508 {
1509 arma_extra_debug_sigprint_this(this);
1510
1511 SpSubview_col_list<eT,T1>::extract(*this, X);
1512 }
1513
1514
1515
1516 template<typename eT>
1517 template<typename T1>
1518 inline
1519 SpMat<eT>&
operator =(const SpSubview_col_list<eT,T1> & X)1520 SpMat<eT>::operator=(const SpSubview_col_list<eT,T1>& X)
1521 {
1522 arma_extra_debug_sigprint();
1523
1524 const bool alias = (this == &(X.m));
1525
1526 if(alias == false)
1527 {
1528 SpSubview_col_list<eT,T1>::extract(*this, X);
1529 }
1530 else
1531 {
1532 SpMat<eT> tmp(X);
1533
1534 steal_mem(tmp);
1535 }
1536
1537 return *this;
1538 }
1539
1540
1541
1542 template<typename eT>
1543 template<typename T1>
1544 inline
1545 SpMat<eT>&
operator +=(const SpSubview_col_list<eT,T1> & X)1546 SpMat<eT>::operator+=(const SpSubview_col_list<eT,T1>& X)
1547 {
1548 arma_extra_debug_sigprint();
1549
1550 SpSubview_col_list<eT,T1>::plus_inplace(*this, X);
1551
1552 return *this;
1553 }
1554
1555
1556
1557 template<typename eT>
1558 template<typename T1>
1559 inline
1560 SpMat<eT>&
operator -=(const SpSubview_col_list<eT,T1> & X)1561 SpMat<eT>::operator-=(const SpSubview_col_list<eT,T1>& X)
1562 {
1563 arma_extra_debug_sigprint();
1564
1565 SpSubview_col_list<eT,T1>::minus_inplace(*this, X);
1566
1567 return *this;
1568 }
1569
1570
1571
1572 template<typename eT>
1573 template<typename T1>
1574 inline
1575 SpMat<eT>&
operator *=(const SpSubview_col_list<eT,T1> & X)1576 SpMat<eT>::operator*=(const SpSubview_col_list<eT,T1>& X)
1577 {
1578 arma_extra_debug_sigprint();
1579
1580 sync_csc();
1581
1582 SpMat<eT> z = (*this) * X;
1583
1584 steal_mem(z);
1585
1586 return *this;
1587 }
1588
1589
1590
1591 template<typename eT>
1592 template<typename T1>
1593 inline
1594 SpMat<eT>&
operator %=(const SpSubview_col_list<eT,T1> & X)1595 SpMat<eT>::operator%=(const SpSubview_col_list<eT,T1>& X)
1596 {
1597 arma_extra_debug_sigprint();
1598
1599 SpSubview_col_list<eT,T1>::schur_inplace(*this, X);
1600
1601 return *this;
1602 }
1603
1604
1605
1606 template<typename eT>
1607 template<typename T1>
1608 inline
1609 SpMat<eT>&
operator /=(const SpSubview_col_list<eT,T1> & X)1610 SpMat<eT>::operator/=(const SpSubview_col_list<eT,T1>& X)
1611 {
1612 arma_extra_debug_sigprint();
1613
1614 SpSubview_col_list<eT,T1>::div_inplace(*this, X);
1615
1616 return *this;
1617 }
1618
1619
1620
1621 template<typename eT>
1622 inline
SpMat(const spdiagview<eT> & X)1623 SpMat<eT>::SpMat(const spdiagview<eT>& X)
1624 : n_rows(0)
1625 , n_cols(0)
1626 , n_elem(0)
1627 , n_nonzero(0)
1628 , vec_state(0)
1629 , values(nullptr)
1630 , row_indices(nullptr)
1631 , col_ptrs(nullptr)
1632 {
1633 arma_extra_debug_sigprint_this(this);
1634
1635 spdiagview<eT>::extract(*this, X);
1636 }
1637
1638
1639
1640 template<typename eT>
1641 inline
1642 SpMat<eT>&
operator =(const spdiagview<eT> & X)1643 SpMat<eT>::operator=(const spdiagview<eT>& X)
1644 {
1645 arma_extra_debug_sigprint();
1646
1647 spdiagview<eT>::extract(*this, X);
1648
1649 return *this;
1650 }
1651
1652
1653
1654 template<typename eT>
1655 inline
1656 SpMat<eT>&
operator +=(const spdiagview<eT> & X)1657 SpMat<eT>::operator+=(const spdiagview<eT>& X)
1658 {
1659 arma_extra_debug_sigprint();
1660
1661 const SpMat<eT> tmp(X);
1662
1663 return (*this).operator+=(tmp);
1664 }
1665
1666
1667
1668 template<typename eT>
1669 inline
1670 SpMat<eT>&
operator -=(const spdiagview<eT> & X)1671 SpMat<eT>::operator-=(const spdiagview<eT>& X)
1672 {
1673 arma_extra_debug_sigprint();
1674
1675 const SpMat<eT> tmp(X);
1676
1677 return (*this).operator-=(tmp);
1678 }
1679
1680
1681
1682 template<typename eT>
1683 inline
1684 SpMat<eT>&
operator *=(const spdiagview<eT> & X)1685 SpMat<eT>::operator*=(const spdiagview<eT>& X)
1686 {
1687 arma_extra_debug_sigprint();
1688
1689 const SpMat<eT> tmp(X);
1690
1691 return (*this).operator*=(tmp);
1692 }
1693
1694
1695
1696 template<typename eT>
1697 inline
1698 SpMat<eT>&
operator %=(const spdiagview<eT> & X)1699 SpMat<eT>::operator%=(const spdiagview<eT>& X)
1700 {
1701 arma_extra_debug_sigprint();
1702
1703 const SpMat<eT> tmp(X);
1704
1705 return (*this).operator%=(tmp);
1706 }
1707
1708
1709
1710 template<typename eT>
1711 inline
1712 SpMat<eT>&
operator /=(const spdiagview<eT> & X)1713 SpMat<eT>::operator/=(const spdiagview<eT>& X)
1714 {
1715 arma_extra_debug_sigprint();
1716
1717 const SpMat<eT> tmp(X);
1718
1719 return (*this).operator/=(tmp);
1720 }
1721
1722
1723
1724 template<typename eT>
1725 template<typename T1, typename spop_type>
1726 inline
SpMat(const SpOp<T1,spop_type> & X)1727 SpMat<eT>::SpMat(const SpOp<T1, spop_type>& X)
1728 : n_rows(0)
1729 , n_cols(0)
1730 , n_elem(0)
1731 , n_nonzero(0)
1732 , vec_state(0)
1733 , values(nullptr) // set in application of sparse operation
1734 , row_indices(nullptr)
1735 , col_ptrs(nullptr)
1736 {
1737 arma_extra_debug_sigprint_this(this);
1738
1739 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1740
1741 spop_type::apply(*this, X);
1742
1743 sync_csc(); // in case apply() used element accessors
1744 invalidate_cache(); // in case apply() modified the CSC representation
1745 }
1746
1747
1748
1749 template<typename eT>
1750 template<typename T1, typename spop_type>
1751 inline
1752 SpMat<eT>&
operator =(const SpOp<T1,spop_type> & X)1753 SpMat<eT>::operator=(const SpOp<T1, spop_type>& X)
1754 {
1755 arma_extra_debug_sigprint();
1756
1757 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1758
1759 spop_type::apply(*this, X);
1760
1761 sync_csc(); // in case apply() used element accessors
1762 invalidate_cache(); // in case apply() modified the CSC representation
1763
1764 return *this;
1765 }
1766
1767
1768
1769 template<typename eT>
1770 template<typename T1, typename spop_type>
1771 inline
1772 SpMat<eT>&
operator +=(const SpOp<T1,spop_type> & X)1773 SpMat<eT>::operator+=(const SpOp<T1, spop_type>& X)
1774 {
1775 arma_extra_debug_sigprint();
1776
1777 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1778
1779 sync_csc();
1780
1781 const SpMat<eT> m(X);
1782
1783 return (*this).operator+=(m);
1784 }
1785
1786
1787
1788 template<typename eT>
1789 template<typename T1, typename spop_type>
1790 inline
1791 SpMat<eT>&
operator -=(const SpOp<T1,spop_type> & X)1792 SpMat<eT>::operator-=(const SpOp<T1, spop_type>& X)
1793 {
1794 arma_extra_debug_sigprint();
1795
1796 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1797
1798 sync_csc();
1799
1800 const SpMat<eT> m(X);
1801
1802 return (*this).operator-=(m);
1803 }
1804
1805
1806
1807 template<typename eT>
1808 template<typename T1, typename spop_type>
1809 inline
1810 SpMat<eT>&
operator *=(const SpOp<T1,spop_type> & X)1811 SpMat<eT>::operator*=(const SpOp<T1, spop_type>& X)
1812 {
1813 arma_extra_debug_sigprint();
1814
1815 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1816
1817 sync_csc();
1818
1819 const SpMat<eT> m(X);
1820
1821 return (*this).operator*=(m);
1822 }
1823
1824
1825
1826 template<typename eT>
1827 template<typename T1, typename spop_type>
1828 inline
1829 SpMat<eT>&
operator %=(const SpOp<T1,spop_type> & X)1830 SpMat<eT>::operator%=(const SpOp<T1, spop_type>& X)
1831 {
1832 arma_extra_debug_sigprint();
1833
1834 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1835
1836 sync_csc();
1837
1838 const SpMat<eT> m(X);
1839
1840 return (*this).operator%=(m);
1841 }
1842
1843
1844
1845 template<typename eT>
1846 template<typename T1, typename spop_type>
1847 inline
1848 SpMat<eT>&
operator /=(const SpOp<T1,spop_type> & X)1849 SpMat<eT>::operator/=(const SpOp<T1, spop_type>& X)
1850 {
1851 arma_extra_debug_sigprint();
1852
1853 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1854
1855 sync_csc();
1856
1857 const SpMat<eT> m(X);
1858
1859 return (*this).operator/=(m);
1860 }
1861
1862
1863
1864 template<typename eT>
1865 template<typename T1, typename T2, typename spglue_type>
1866 inline
SpMat(const SpGlue<T1,T2,spglue_type> & X)1867 SpMat<eT>::SpMat(const SpGlue<T1, T2, spglue_type>& X)
1868 : n_rows(0)
1869 , n_cols(0)
1870 , n_elem(0)
1871 , n_nonzero(0)
1872 , vec_state(0)
1873 , values(nullptr)
1874 , row_indices(nullptr)
1875 , col_ptrs(nullptr)
1876 {
1877 arma_extra_debug_sigprint_this(this);
1878
1879 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1880
1881 spglue_type::apply(*this, X);
1882
1883 sync_csc(); // in case apply() used element accessors
1884 invalidate_cache(); // in case apply() modified the CSC representation
1885 }
1886
1887
1888
1889 template<typename eT>
1890 template<typename T1, typename T2, typename spglue_type>
1891 inline
1892 SpMat<eT>&
operator =(const SpGlue<T1,T2,spglue_type> & X)1893 SpMat<eT>::operator=(const SpGlue<T1, T2, spglue_type>& X)
1894 {
1895 arma_extra_debug_sigprint();
1896
1897 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1898
1899 spglue_type::apply(*this, X);
1900
1901 sync_csc(); // in case apply() used element accessors
1902 invalidate_cache(); // in case apply() modified the CSC representation
1903
1904 return *this;
1905 }
1906
1907
1908
1909 template<typename eT>
1910 template<typename T1, typename T2, typename spglue_type>
1911 inline
1912 SpMat<eT>&
operator +=(const SpGlue<T1,T2,spglue_type> & X)1913 SpMat<eT>::operator+=(const SpGlue<T1, T2, spglue_type>& X)
1914 {
1915 arma_extra_debug_sigprint();
1916
1917 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1918
1919 sync_csc();
1920
1921 const SpMat<eT> m(X);
1922
1923 return (*this).operator+=(m);
1924 }
1925
1926
1927
1928 template<typename eT>
1929 template<typename T1, typename T2, typename spglue_type>
1930 inline
1931 SpMat<eT>&
operator -=(const SpGlue<T1,T2,spglue_type> & X)1932 SpMat<eT>::operator-=(const SpGlue<T1, T2, spglue_type>& X)
1933 {
1934 arma_extra_debug_sigprint();
1935
1936 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1937
1938 sync_csc();
1939
1940 const SpMat<eT> m(X);
1941
1942 return (*this).operator-=(m);
1943 }
1944
1945
1946
1947 template<typename eT>
1948 template<typename T1, typename T2, typename spglue_type>
1949 inline
1950 SpMat<eT>&
operator *=(const SpGlue<T1,T2,spglue_type> & X)1951 SpMat<eT>::operator*=(const SpGlue<T1, T2, spglue_type>& X)
1952 {
1953 arma_extra_debug_sigprint();
1954
1955 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1956
1957 sync_csc();
1958
1959 const SpMat<eT> m(X);
1960
1961 return (*this).operator*=(m);
1962 }
1963
1964
1965
1966 template<typename eT>
1967 template<typename T1, typename T2, typename spglue_type>
1968 inline
1969 SpMat<eT>&
operator %=(const SpGlue<T1,T2,spglue_type> & X)1970 SpMat<eT>::operator%=(const SpGlue<T1, T2, spglue_type>& X)
1971 {
1972 arma_extra_debug_sigprint();
1973
1974 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1975
1976 sync_csc();
1977
1978 const SpMat<eT> m(X);
1979
1980 return (*this).operator%=(m);
1981 }
1982
1983
1984
1985 template<typename eT>
1986 template<typename T1, typename T2, typename spglue_type>
1987 inline
1988 SpMat<eT>&
operator /=(const SpGlue<T1,T2,spglue_type> & X)1989 SpMat<eT>::operator/=(const SpGlue<T1, T2, spglue_type>& X)
1990 {
1991 arma_extra_debug_sigprint();
1992
1993 arma_type_check(( is_same_type< eT, typename T1::elem_type >::no ));
1994
1995 sync_csc();
1996
1997 const SpMat<eT> m(X);
1998
1999 return (*this).operator/=(m);
2000 }
2001
2002
2003
2004 template<typename eT>
2005 template<typename T1, typename spop_type>
2006 inline
SpMat(const mtSpOp<eT,T1,spop_type> & X)2007 SpMat<eT>::SpMat(const mtSpOp<eT, T1, spop_type>& X)
2008 : n_rows(0)
2009 , n_cols(0)
2010 , n_elem(0)
2011 , n_nonzero(0)
2012 , vec_state(0)
2013 , values(nullptr)
2014 , row_indices(nullptr)
2015 , col_ptrs(nullptr)
2016 {
2017 arma_extra_debug_sigprint_this(this);
2018
2019 spop_type::apply(*this, X);
2020
2021 sync_csc(); // in case apply() used element accessors
2022 invalidate_cache(); // in case apply() modified the CSC representation
2023 }
2024
2025
2026
2027 template<typename eT>
2028 template<typename T1, typename spop_type>
2029 inline
2030 SpMat<eT>&
operator =(const mtSpOp<eT,T1,spop_type> & X)2031 SpMat<eT>::operator=(const mtSpOp<eT, T1, spop_type>& X)
2032 {
2033 arma_extra_debug_sigprint();
2034
2035 spop_type::apply(*this, X);
2036
2037 sync_csc(); // in case apply() used element accessors
2038 invalidate_cache(); // in case apply() modified the CSC representation
2039
2040 return *this;
2041 }
2042
2043
2044
2045 template<typename eT>
2046 template<typename T1, typename spop_type>
2047 inline
2048 SpMat<eT>&
operator +=(const mtSpOp<eT,T1,spop_type> & X)2049 SpMat<eT>::operator+=(const mtSpOp<eT, T1, spop_type>& X)
2050 {
2051 arma_extra_debug_sigprint();
2052
2053 sync_csc();
2054
2055 const SpMat<eT> m(X);
2056
2057 return (*this).operator+=(m);
2058 }
2059
2060
2061
2062 template<typename eT>
2063 template<typename T1, typename spop_type>
2064 inline
2065 SpMat<eT>&
operator -=(const mtSpOp<eT,T1,spop_type> & X)2066 SpMat<eT>::operator-=(const mtSpOp<eT, T1, spop_type>& X)
2067 {
2068 arma_extra_debug_sigprint();
2069
2070 sync_csc();
2071
2072 const SpMat<eT> m(X);
2073
2074 return (*this).operator-=(m);
2075 }
2076
2077
2078
2079 template<typename eT>
2080 template<typename T1, typename spop_type>
2081 inline
2082 SpMat<eT>&
operator *=(const mtSpOp<eT,T1,spop_type> & X)2083 SpMat<eT>::operator*=(const mtSpOp<eT, T1, spop_type>& X)
2084 {
2085 arma_extra_debug_sigprint();
2086
2087 sync_csc();
2088
2089 const SpMat<eT> m(X);
2090
2091 return (*this).operator*=(m);
2092 }
2093
2094
2095
2096 template<typename eT>
2097 template<typename T1, typename spop_type>
2098 inline
2099 SpMat<eT>&
operator %=(const mtSpOp<eT,T1,spop_type> & X)2100 SpMat<eT>::operator%=(const mtSpOp<eT, T1, spop_type>& X)
2101 {
2102 arma_extra_debug_sigprint();
2103
2104 sync_csc();
2105
2106 const SpMat<eT> m(X);
2107
2108 return (*this).operator%=(m);
2109 }
2110
2111
2112
2113 template<typename eT>
2114 template<typename T1, typename spop_type>
2115 inline
2116 SpMat<eT>&
operator /=(const mtSpOp<eT,T1,spop_type> & X)2117 SpMat<eT>::operator/=(const mtSpOp<eT, T1, spop_type>& X)
2118 {
2119 arma_extra_debug_sigprint();
2120
2121 sync_csc();
2122
2123 const SpMat<eT> m(X);
2124
2125 return (*this).operator/=(m);
2126 }
2127
2128
2129
2130 template<typename eT>
2131 template<typename T1, typename T2, typename spglue_type>
2132 inline
SpMat(const mtSpGlue<eT,T1,T2,spglue_type> & X)2133 SpMat<eT>::SpMat(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2134 : n_rows(0)
2135 , n_cols(0)
2136 , n_elem(0)
2137 , n_nonzero(0)
2138 , vec_state(0)
2139 , values(nullptr)
2140 , row_indices(nullptr)
2141 , col_ptrs(nullptr)
2142 {
2143 arma_extra_debug_sigprint_this(this);
2144
2145 spglue_type::apply(*this, X);
2146
2147 sync_csc(); // in case apply() used element accessors
2148 invalidate_cache(); // in case apply() modified the CSC representation
2149 }
2150
2151
2152
2153 template<typename eT>
2154 template<typename T1, typename T2, typename spglue_type>
2155 inline
2156 SpMat<eT>&
operator =(const mtSpGlue<eT,T1,T2,spglue_type> & X)2157 SpMat<eT>::operator=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2158 {
2159 arma_extra_debug_sigprint();
2160
2161 spglue_type::apply(*this, X);
2162
2163 sync_csc(); // in case apply() used element accessors
2164 invalidate_cache(); // in case apply() modified the CSC representation
2165
2166 return *this;
2167 }
2168
2169
2170
2171 template<typename eT>
2172 template<typename T1, typename T2, typename spglue_type>
2173 inline
2174 SpMat<eT>&
operator +=(const mtSpGlue<eT,T1,T2,spglue_type> & X)2175 SpMat<eT>::operator+=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2176 {
2177 arma_extra_debug_sigprint();
2178
2179 sync_csc();
2180
2181 const SpMat<eT> m(X);
2182
2183 return (*this).operator+=(m);
2184 }
2185
2186
2187
2188 template<typename eT>
2189 template<typename T1, typename T2, typename spglue_type>
2190 inline
2191 SpMat<eT>&
operator -=(const mtSpGlue<eT,T1,T2,spglue_type> & X)2192 SpMat<eT>::operator-=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2193 {
2194 arma_extra_debug_sigprint();
2195
2196 sync_csc();
2197
2198 const SpMat<eT> m(X);
2199
2200 return (*this).operator-=(m);
2201 }
2202
2203
2204
2205 template<typename eT>
2206 template<typename T1, typename T2, typename spglue_type>
2207 inline
2208 SpMat<eT>&
operator *=(const mtSpGlue<eT,T1,T2,spglue_type> & X)2209 SpMat<eT>::operator*=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2210 {
2211 arma_extra_debug_sigprint();
2212
2213 sync_csc();
2214
2215 const SpMat<eT> m(X);
2216
2217 return (*this).operator*=(m);
2218 }
2219
2220
2221
2222 template<typename eT>
2223 template<typename T1, typename T2, typename spglue_type>
2224 inline
2225 SpMat<eT>&
operator %=(const mtSpGlue<eT,T1,T2,spglue_type> & X)2226 SpMat<eT>::operator%=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2227 {
2228 arma_extra_debug_sigprint();
2229
2230 sync_csc();
2231
2232 const SpMat<eT> m(X);
2233
2234 return (*this).operator%=(m);
2235 }
2236
2237
2238
2239 template<typename eT>
2240 template<typename T1, typename T2, typename spglue_type>
2241 inline
2242 SpMat<eT>&
operator /=(const mtSpGlue<eT,T1,T2,spglue_type> & X)2243 SpMat<eT>::operator/=(const mtSpGlue<eT, T1, T2, spglue_type>& X)
2244 {
2245 arma_extra_debug_sigprint();
2246
2247 sync_csc();
2248
2249 const SpMat<eT> m(X);
2250
2251 return (*this).operator/=(m);
2252 }
2253
2254
2255
2256 template<typename eT>
2257 arma_inline
2258 SpSubview_row<eT>
row(const uword row_num)2259 SpMat<eT>::row(const uword row_num)
2260 {
2261 arma_extra_debug_sigprint();
2262
2263 arma_debug_check_bounds(row_num >= n_rows, "SpMat::row(): out of bounds");
2264
2265 return SpSubview_row<eT>(*this, row_num);
2266 }
2267
2268
2269
2270 template<typename eT>
2271 arma_inline
2272 const SpSubview_row<eT>
row(const uword row_num) const2273 SpMat<eT>::row(const uword row_num) const
2274 {
2275 arma_extra_debug_sigprint();
2276
2277 arma_debug_check_bounds(row_num >= n_rows, "SpMat::row(): out of bounds");
2278
2279 return SpSubview_row<eT>(*this, row_num);
2280 }
2281
2282
2283
2284 template<typename eT>
2285 inline
2286 SpSubview_row<eT>
operator ()(const uword row_num,const span & col_span)2287 SpMat<eT>::operator()(const uword row_num, const span& col_span)
2288 {
2289 arma_extra_debug_sigprint();
2290
2291 const bool col_all = col_span.whole;
2292
2293 const uword local_n_cols = n_cols;
2294
2295 const uword in_col1 = col_all ? 0 : col_span.a;
2296 const uword in_col2 = col_span.b;
2297 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
2298
2299 arma_debug_check_bounds
2300 (
2301 (row_num >= n_rows)
2302 ||
2303 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
2304 ,
2305 "SpMat::operator(): indices out of bounds or incorrectly used"
2306 );
2307
2308 return SpSubview_row<eT>(*this, row_num, in_col1, submat_n_cols);
2309 }
2310
2311
2312
2313 template<typename eT>
2314 inline
2315 const SpSubview_row<eT>
operator ()(const uword row_num,const span & col_span) const2316 SpMat<eT>::operator()(const uword row_num, const span& col_span) const
2317 {
2318 arma_extra_debug_sigprint();
2319
2320 const bool col_all = col_span.whole;
2321
2322 const uword local_n_cols = n_cols;
2323
2324 const uword in_col1 = col_all ? 0 : col_span.a;
2325 const uword in_col2 = col_span.b;
2326 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
2327
2328 arma_debug_check_bounds
2329 (
2330 (row_num >= n_rows)
2331 ||
2332 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
2333 ,
2334 "SpMat::operator(): indices out of bounds or incorrectly used"
2335 );
2336
2337 return SpSubview_row<eT>(*this, row_num, in_col1, submat_n_cols);
2338 }
2339
2340
2341
2342 template<typename eT>
2343 arma_inline
2344 SpSubview_col<eT>
col(const uword col_num)2345 SpMat<eT>::col(const uword col_num)
2346 {
2347 arma_extra_debug_sigprint();
2348
2349 arma_debug_check_bounds(col_num >= n_cols, "SpMat::col(): out of bounds");
2350
2351 return SpSubview_col<eT>(*this, col_num);
2352 }
2353
2354
2355
2356 template<typename eT>
2357 arma_inline
2358 const SpSubview_col<eT>
col(const uword col_num) const2359 SpMat<eT>::col(const uword col_num) const
2360 {
2361 arma_extra_debug_sigprint();
2362
2363 arma_debug_check_bounds(col_num >= n_cols, "SpMat::col(): out of bounds");
2364
2365 return SpSubview_col<eT>(*this, col_num);
2366 }
2367
2368
2369
2370 template<typename eT>
2371 inline
2372 SpSubview_col<eT>
operator ()(const span & row_span,const uword col_num)2373 SpMat<eT>::operator()(const span& row_span, const uword col_num)
2374 {
2375 arma_extra_debug_sigprint();
2376
2377 const bool row_all = row_span.whole;
2378
2379 const uword local_n_rows = n_rows;
2380
2381 const uword in_row1 = row_all ? 0 : row_span.a;
2382 const uword in_row2 = row_span.b;
2383 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
2384
2385 arma_debug_check_bounds
2386 (
2387 (col_num >= n_cols)
2388 ||
2389 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
2390 ,
2391 "SpMat::operator(): indices out of bounds or incorrectly used"
2392 );
2393
2394 return SpSubview_col<eT>(*this, col_num, in_row1, submat_n_rows);
2395 }
2396
2397
2398
2399 template<typename eT>
2400 inline
2401 const SpSubview_col<eT>
operator ()(const span & row_span,const uword col_num) const2402 SpMat<eT>::operator()(const span& row_span, const uword col_num) const
2403 {
2404 arma_extra_debug_sigprint();
2405
2406 const bool row_all = row_span.whole;
2407
2408 const uword local_n_rows = n_rows;
2409
2410 const uword in_row1 = row_all ? 0 : row_span.a;
2411 const uword in_row2 = row_span.b;
2412 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
2413
2414 arma_debug_check_bounds
2415 (
2416 (col_num >= n_cols)
2417 ||
2418 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
2419 ,
2420 "SpMat::operator(): indices out of bounds or incorrectly used"
2421 );
2422
2423 return SpSubview_col<eT>(*this, col_num, in_row1, submat_n_rows);
2424 }
2425
2426
2427
2428 template<typename eT>
2429 arma_inline
2430 SpSubview<eT>
rows(const uword in_row1,const uword in_row2)2431 SpMat<eT>::rows(const uword in_row1, const uword in_row2)
2432 {
2433 arma_extra_debug_sigprint();
2434
2435 arma_debug_check_bounds
2436 (
2437 (in_row1 > in_row2) || (in_row2 >= n_rows),
2438 "SpMat::rows(): indices out of bounds or incorrectly used"
2439 );
2440
2441 const uword subview_n_rows = in_row2 - in_row1 + 1;
2442
2443 return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols);
2444 }
2445
2446
2447
2448 template<typename eT>
2449 arma_inline
2450 const SpSubview<eT>
rows(const uword in_row1,const uword in_row2) const2451 SpMat<eT>::rows(const uword in_row1, const uword in_row2) const
2452 {
2453 arma_extra_debug_sigprint();
2454
2455 arma_debug_check_bounds
2456 (
2457 (in_row1 > in_row2) || (in_row2 >= n_rows),
2458 "SpMat::rows(): indices out of bounds or incorrectly used"
2459 );
2460
2461 const uword subview_n_rows = in_row2 - in_row1 + 1;
2462
2463 return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols);
2464 }
2465
2466
2467
2468 template<typename eT>
2469 arma_inline
2470 SpSubview<eT>
cols(const uword in_col1,const uword in_col2)2471 SpMat<eT>::cols(const uword in_col1, const uword in_col2)
2472 {
2473 arma_extra_debug_sigprint();
2474
2475 arma_debug_check_bounds
2476 (
2477 (in_col1 > in_col2) || (in_col2 >= n_cols),
2478 "SpMat::cols(): indices out of bounds or incorrectly used"
2479 );
2480
2481 const uword subview_n_cols = in_col2 - in_col1 + 1;
2482
2483 return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols);
2484 }
2485
2486
2487
2488 template<typename eT>
2489 arma_inline
2490 const SpSubview<eT>
cols(const uword in_col1,const uword in_col2) const2491 SpMat<eT>::cols(const uword in_col1, const uword in_col2) const
2492 {
2493 arma_extra_debug_sigprint();
2494
2495 arma_debug_check_bounds
2496 (
2497 (in_col1 > in_col2) || (in_col2 >= n_cols),
2498 "SpMat::cols(): indices out of bounds or incorrectly used"
2499 );
2500
2501 const uword subview_n_cols = in_col2 - in_col1 + 1;
2502
2503 return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols);
2504 }
2505
2506
2507
2508 template<typename eT>
2509 arma_inline
2510 SpSubview<eT>
submat(const uword in_row1,const uword in_col1,const uword in_row2,const uword in_col2)2511 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2)
2512 {
2513 arma_extra_debug_sigprint();
2514
2515 arma_debug_check_bounds
2516 (
2517 (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols),
2518 "SpMat::submat(): indices out of bounds or incorrectly used"
2519 );
2520
2521 const uword subview_n_rows = in_row2 - in_row1 + 1;
2522 const uword subview_n_cols = in_col2 - in_col1 + 1;
2523
2524 return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols);
2525 }
2526
2527
2528
2529 template<typename eT>
2530 arma_inline
2531 const SpSubview<eT>
submat(const uword in_row1,const uword in_col1,const uword in_row2,const uword in_col2) const2532 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const
2533 {
2534 arma_extra_debug_sigprint();
2535
2536 arma_debug_check_bounds
2537 (
2538 (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols),
2539 "SpMat::submat(): indices out of bounds or incorrectly used"
2540 );
2541
2542 const uword subview_n_rows = in_row2 - in_row1 + 1;
2543 const uword subview_n_cols = in_col2 - in_col1 + 1;
2544
2545 return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols);
2546 }
2547
2548
2549
2550 template<typename eT>
2551 arma_inline
2552 SpSubview<eT>
submat(const uword in_row1,const uword in_col1,const SizeMat & s)2553 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const SizeMat& s)
2554 {
2555 arma_extra_debug_sigprint();
2556
2557 const uword l_n_rows = n_rows;
2558 const uword l_n_cols = n_cols;
2559
2560 const uword s_n_rows = s.n_rows;
2561 const uword s_n_cols = s.n_cols;
2562
2563 arma_debug_check_bounds
2564 (
2565 ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)),
2566 "SpMat::submat(): indices or size out of bounds"
2567 );
2568
2569 return SpSubview<eT>(*this, in_row1, in_col1, s_n_rows, s_n_cols);
2570 }
2571
2572
2573
2574 template<typename eT>
2575 arma_inline
2576 const SpSubview<eT>
submat(const uword in_row1,const uword in_col1,const SizeMat & s) const2577 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const SizeMat& s) const
2578 {
2579 arma_extra_debug_sigprint();
2580
2581 const uword l_n_rows = n_rows;
2582 const uword l_n_cols = n_cols;
2583
2584 const uword s_n_rows = s.n_rows;
2585 const uword s_n_cols = s.n_cols;
2586
2587 arma_debug_check_bounds
2588 (
2589 ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)),
2590 "SpMat::submat(): indices or size out of bounds"
2591 );
2592
2593 return SpSubview<eT>(*this, in_row1, in_col1, s_n_rows, s_n_cols);
2594 }
2595
2596
2597
2598 template<typename eT>
2599 inline
2600 SpSubview<eT>
submat(const span & row_span,const span & col_span)2601 SpMat<eT>::submat(const span& row_span, const span& col_span)
2602 {
2603 arma_extra_debug_sigprint();
2604
2605 const bool row_all = row_span.whole;
2606 const bool col_all = col_span.whole;
2607
2608 const uword local_n_rows = n_rows;
2609 const uword local_n_cols = n_cols;
2610
2611 const uword in_row1 = row_all ? 0 : row_span.a;
2612 const uword in_row2 = row_span.b;
2613 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
2614
2615 const uword in_col1 = col_all ? 0 : col_span.a;
2616 const uword in_col2 = col_span.b;
2617 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
2618
2619 arma_debug_check_bounds
2620 (
2621 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
2622 ||
2623 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
2624 ,
2625 "SpMat::submat(): indices out of bounds or incorrectly used"
2626 );
2627
2628 return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols);
2629 }
2630
2631
2632
2633 template<typename eT>
2634 inline
2635 const SpSubview<eT>
submat(const span & row_span,const span & col_span) const2636 SpMat<eT>::submat(const span& row_span, const span& col_span) const
2637 {
2638 arma_extra_debug_sigprint();
2639
2640 const bool row_all = row_span.whole;
2641 const bool col_all = col_span.whole;
2642
2643 const uword local_n_rows = n_rows;
2644 const uword local_n_cols = n_cols;
2645
2646 const uword in_row1 = row_all ? 0 : row_span.a;
2647 const uword in_row2 = row_span.b;
2648 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
2649
2650 const uword in_col1 = col_all ? 0 : col_span.a;
2651 const uword in_col2 = col_span.b;
2652 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
2653
2654 arma_debug_check_bounds
2655 (
2656 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
2657 ||
2658 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
2659 ,
2660 "SpMat::submat(): indices out of bounds or incorrectly used"
2661 );
2662
2663 return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols);
2664 }
2665
2666
2667
2668 template<typename eT>
2669 inline
2670 SpSubview<eT>
operator ()(const span & row_span,const span & col_span)2671 SpMat<eT>::operator()(const span& row_span, const span& col_span)
2672 {
2673 arma_extra_debug_sigprint();
2674
2675 return submat(row_span, col_span);
2676 }
2677
2678
2679
2680 template<typename eT>
2681 inline
2682 const SpSubview<eT>
operator ()(const span & row_span,const span & col_span) const2683 SpMat<eT>::operator()(const span& row_span, const span& col_span) const
2684 {
2685 arma_extra_debug_sigprint();
2686
2687 return submat(row_span, col_span);
2688 }
2689
2690
2691
2692 template<typename eT>
2693 arma_inline
2694 SpSubview<eT>
operator ()(const uword in_row1,const uword in_col1,const SizeMat & s)2695 SpMat<eT>::operator()(const uword in_row1, const uword in_col1, const SizeMat& s)
2696 {
2697 arma_extra_debug_sigprint();
2698
2699 return (*this).submat(in_row1, in_col1, s);
2700 }
2701
2702
2703
2704 template<typename eT>
2705 arma_inline
2706 const SpSubview<eT>
operator ()(const uword in_row1,const uword in_col1,const SizeMat & s) const2707 SpMat<eT>::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const
2708 {
2709 arma_extra_debug_sigprint();
2710
2711 return (*this).submat(in_row1, in_col1, s);
2712 }
2713
2714
2715
2716 template<typename eT>
2717 inline
2718 SpSubview<eT>
head_rows(const uword N)2719 SpMat<eT>::head_rows(const uword N)
2720 {
2721 arma_extra_debug_sigprint();
2722
2723 arma_debug_check_bounds( (N > n_rows), "SpMat::head_rows(): size out of bounds" );
2724
2725 return SpSubview<eT>(*this, 0, 0, N, n_cols);
2726 }
2727
2728
2729
2730 template<typename eT>
2731 inline
2732 const SpSubview<eT>
head_rows(const uword N) const2733 SpMat<eT>::head_rows(const uword N) const
2734 {
2735 arma_extra_debug_sigprint();
2736
2737 arma_debug_check_bounds( (N > n_rows), "SpMat::head_rows(): size out of bounds" );
2738
2739 return SpSubview<eT>(*this, 0, 0, N, n_cols);
2740 }
2741
2742
2743
2744 template<typename eT>
2745 inline
2746 SpSubview<eT>
tail_rows(const uword N)2747 SpMat<eT>::tail_rows(const uword N)
2748 {
2749 arma_extra_debug_sigprint();
2750
2751 arma_debug_check_bounds( (N > n_rows), "SpMat::tail_rows(): size out of bounds" );
2752
2753 const uword start_row = n_rows - N;
2754
2755 return SpSubview<eT>(*this, start_row, 0, N, n_cols);
2756 }
2757
2758
2759
2760 template<typename eT>
2761 inline
2762 const SpSubview<eT>
tail_rows(const uword N) const2763 SpMat<eT>::tail_rows(const uword N) const
2764 {
2765 arma_extra_debug_sigprint();
2766
2767 arma_debug_check_bounds( (N > n_rows), "SpMat::tail_rows(): size out of bounds" );
2768
2769 const uword start_row = n_rows - N;
2770
2771 return SpSubview<eT>(*this, start_row, 0, N, n_cols);
2772 }
2773
2774
2775
2776 template<typename eT>
2777 inline
2778 SpSubview<eT>
head_cols(const uword N)2779 SpMat<eT>::head_cols(const uword N)
2780 {
2781 arma_extra_debug_sigprint();
2782
2783 arma_debug_check_bounds( (N > n_cols), "SpMat::head_cols(): size out of bounds" );
2784
2785 return SpSubview<eT>(*this, 0, 0, n_rows, N);
2786 }
2787
2788
2789
2790 template<typename eT>
2791 inline
2792 const SpSubview<eT>
head_cols(const uword N) const2793 SpMat<eT>::head_cols(const uword N) const
2794 {
2795 arma_extra_debug_sigprint();
2796
2797 arma_debug_check_bounds( (N > n_cols), "SpMat::head_cols(): size out of bounds" );
2798
2799 return SpSubview<eT>(*this, 0, 0, n_rows, N);
2800 }
2801
2802
2803
2804 template<typename eT>
2805 inline
2806 SpSubview<eT>
tail_cols(const uword N)2807 SpMat<eT>::tail_cols(const uword N)
2808 {
2809 arma_extra_debug_sigprint();
2810
2811 arma_debug_check_bounds( (N > n_cols), "SpMat::tail_cols(): size out of bounds" );
2812
2813 const uword start_col = n_cols - N;
2814
2815 return SpSubview<eT>(*this, 0, start_col, n_rows, N);
2816 }
2817
2818
2819
2820 template<typename eT>
2821 inline
2822 const SpSubview<eT>
tail_cols(const uword N) const2823 SpMat<eT>::tail_cols(const uword N) const
2824 {
2825 arma_extra_debug_sigprint();
2826
2827 arma_debug_check_bounds( (N > n_cols), "SpMat::tail_cols(): size out of bounds" );
2828
2829 const uword start_col = n_cols - N;
2830
2831 return SpSubview<eT>(*this, 0, start_col, n_rows, N);
2832 }
2833
2834
2835
2836 template<typename eT>
2837 template<typename T1>
2838 arma_inline
2839 SpSubview_col_list<eT, T1>
cols(const Base<uword,T1> & indices)2840 SpMat<eT>::cols(const Base<uword, T1>& indices)
2841 {
2842 arma_extra_debug_sigprint();
2843
2844 return SpSubview_col_list<eT, T1>(*this, indices);
2845 }
2846
2847
2848
2849 template<typename eT>
2850 template<typename T1>
2851 arma_inline
2852 const SpSubview_col_list<eT, T1>
cols(const Base<uword,T1> & indices) const2853 SpMat<eT>::cols(const Base<uword, T1>& indices) const
2854 {
2855 arma_extra_debug_sigprint();
2856
2857 return SpSubview_col_list<eT, T1>(*this, indices);
2858 }
2859
2860
2861
2862 //! creation of spdiagview (diagonal)
2863 template<typename eT>
2864 inline
2865 spdiagview<eT>
diag(const sword in_id)2866 SpMat<eT>::diag(const sword in_id)
2867 {
2868 arma_extra_debug_sigprint();
2869
2870 const uword row_offset = (in_id < 0) ? uword(-in_id) : 0;
2871 const uword col_offset = (in_id > 0) ? uword( in_id) : 0;
2872
2873 arma_debug_check_bounds
2874 (
2875 ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)),
2876 "SpMat::diag(): requested diagonal out of bounds"
2877 );
2878
2879 const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset);
2880
2881 return spdiagview<eT>(*this, row_offset, col_offset, len);
2882 }
2883
2884
2885
2886 //! creation of spdiagview (diagonal)
2887 template<typename eT>
2888 inline
2889 const spdiagview<eT>
diag(const sword in_id) const2890 SpMat<eT>::diag(const sword in_id) const
2891 {
2892 arma_extra_debug_sigprint();
2893
2894 const uword row_offset = uword( (in_id < 0) ? -in_id : 0 );
2895 const uword col_offset = uword( (in_id > 0) ? in_id : 0 );
2896
2897 arma_debug_check_bounds
2898 (
2899 ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)),
2900 "SpMat::diag(): requested diagonal out of bounds"
2901 );
2902
2903 const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset);
2904
2905 return spdiagview<eT>(*this, row_offset, col_offset, len);
2906 }
2907
2908
2909
2910 template<typename eT>
2911 inline
2912 void
swap_rows(const uword in_row1,const uword in_row2)2913 SpMat<eT>::swap_rows(const uword in_row1, const uword in_row2)
2914 {
2915 arma_extra_debug_sigprint();
2916
2917 arma_debug_check_bounds( ((in_row1 >= n_rows) || (in_row2 >= n_rows)), "SpMat::swap_rows(): out of bounds" );
2918
2919 if(in_row1 == in_row2) { return; }
2920
2921 sync_csc();
2922 invalidate_cache();
2923
2924 // The easier way to do this, instead of collecting all the elements in one row and then swapping with the other, will be
2925 // to iterate over each column of the matrix (since we store in column-major format) and then swap the two elements in the two rows at that time.
2926 // We will try to avoid using the at() call since it is expensive, instead preferring to use an iterator to track our position.
2927 uword col1 = (in_row1 < in_row2) ? in_row1 : in_row2;
2928 uword col2 = (in_row1 < in_row2) ? in_row2 : in_row1;
2929
2930 for(uword lcol = 0; lcol < n_cols; lcol++)
2931 {
2932 // If there is nothing in this column we can ignore it.
2933 if(col_ptrs[lcol] == col_ptrs[lcol + 1])
2934 {
2935 continue;
2936 }
2937
2938 // These will represent the positions of the items themselves.
2939 uword loc1 = n_nonzero + 1;
2940 uword loc2 = n_nonzero + 1;
2941
2942 for(uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++)
2943 {
2944 if(row_indices[search_pos] == col1)
2945 {
2946 loc1 = search_pos;
2947 }
2948
2949 if(row_indices[search_pos] == col2)
2950 {
2951 loc2 = search_pos;
2952 break; // No need to look any further.
2953 }
2954 }
2955
2956 // There are four cases: we found both elements; we found one element (loc1); we found one element (loc2); we found zero elements.
2957 // If we found zero elements no work needs to be done and we can continue to the next column.
2958 if((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1)))
2959 {
2960 // This is an easy case: just swap the values. No index modifying necessary.
2961 eT tmp = values[loc1];
2962 access::rw(values[loc1]) = values[loc2];
2963 access::rw(values[loc2]) = tmp;
2964 }
2965 else if(loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2.
2966 {
2967 // We need to find the correct place to move our value to. It will be forward (not backwards) because in_row2 > in_row1.
2968 // Each iteration of the loop swaps the current value (loc1) with (loc1 + 1); in this manner we move our value down to where it should be.
2969 while(((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2))
2970 {
2971 // Swap both the values and the indices. The column should not change.
2972 eT tmp = values[loc1];
2973 access::rw(values[loc1]) = values[loc1 + 1];
2974 access::rw(values[loc1 + 1]) = tmp;
2975
2976 uword tmp_index = row_indices[loc1];
2977 access::rw(row_indices[loc1]) = row_indices[loc1 + 1];
2978 access::rw(row_indices[loc1 + 1]) = tmp_index;
2979
2980 loc1++; // And increment the counter.
2981 }
2982
2983 // Now set the row index correctly.
2984 access::rw(row_indices[loc1]) = in_row2;
2985
2986 }
2987 else if(loc2 != (n_nonzero + 1))
2988 {
2989 // We need to find the correct place to move our value to. It will be backwards (not forwards) because in_row1 < in_row2.
2990 // Each iteration of the loop swaps the current value (loc2) with (loc2 - 1); in this manner we move our value up to where it should be.
2991 while(((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1))
2992 {
2993 // Swap both the values and the indices. The column should not change.
2994 eT tmp = values[loc2];
2995 access::rw(values[loc2]) = values[loc2 - 1];
2996 access::rw(values[loc2 - 1]) = tmp;
2997
2998 uword tmp_index = row_indices[loc2];
2999 access::rw(row_indices[loc2]) = row_indices[loc2 - 1];
3000 access::rw(row_indices[loc2 - 1]) = tmp_index;
3001
3002 loc2--; // And decrement the counter.
3003 }
3004
3005 // Now set the row index correctly.
3006 access::rw(row_indices[loc2]) = in_row1;
3007
3008 }
3009 /* else: no need to swap anything; both values are zero */
3010 }
3011 }
3012
3013
3014
3015 template<typename eT>
3016 inline
3017 void
swap_cols(const uword in_col1,const uword in_col2)3018 SpMat<eT>::swap_cols(const uword in_col1, const uword in_col2)
3019 {
3020 arma_extra_debug_sigprint();
3021
3022 arma_debug_check_bounds( ((in_col1 >= n_cols) || (in_col2 >= n_cols)), "SpMat::swap_cols(): out of bounds" );
3023
3024 if(in_col1 == in_col2) { return; }
3025
3026 // TODO: this is a rudimentary implementation
3027
3028 SpMat<eT> tmp = (*this);
3029
3030 tmp.col(in_col1) = (*this).col(in_col2);
3031 tmp.col(in_col2) = (*this).col(in_col1);
3032
3033 steal_mem(tmp);
3034
3035 // for(uword lrow = 0; lrow < n_rows; ++lrow)
3036 // {
3037 // const eT tmp = at(lrow, in_col1);
3038 // at(lrow, in_col1) = eT( at(lrow, in_col2) );
3039 // at(lrow, in_col2) = tmp;
3040 // }
3041 }
3042
3043
3044
3045 template<typename eT>
3046 inline
3047 void
shed_row(const uword row_num)3048 SpMat<eT>::shed_row(const uword row_num)
3049 {
3050 arma_extra_debug_sigprint();
3051
3052 arma_debug_check_bounds(row_num >= n_rows, "SpMat::shed_row(): out of bounds");
3053
3054 shed_rows (row_num, row_num);
3055 }
3056
3057
3058
3059 template<typename eT>
3060 inline
3061 void
shed_col(const uword col_num)3062 SpMat<eT>::shed_col(const uword col_num)
3063 {
3064 arma_extra_debug_sigprint();
3065
3066 arma_debug_check_bounds(col_num >= n_cols, "SpMat::shed_col(): out of bounds");
3067
3068 shed_cols(col_num, col_num);
3069 }
3070
3071
3072
3073 template<typename eT>
3074 inline
3075 void
shed_rows(const uword in_row1,const uword in_row2)3076 SpMat<eT>::shed_rows(const uword in_row1, const uword in_row2)
3077 {
3078 arma_extra_debug_sigprint();
3079
3080 arma_debug_check_bounds
3081 (
3082 (in_row1 > in_row2) || (in_row2 >= n_rows),
3083 "SpMat::shed_rows(): indices out of bounds or incorectly used"
3084 );
3085
3086 sync_csc();
3087
3088 SpMat<eT> newmat(n_rows - (in_row2 - in_row1 + 1), n_cols);
3089
3090 // First, count the number of elements we will be removing.
3091 uword removing = 0;
3092 for(uword i = 0; i < n_nonzero; ++i)
3093 {
3094 const uword lrow = row_indices[i];
3095 if(lrow >= in_row1 && lrow <= in_row2)
3096 {
3097 ++removing;
3098 }
3099 }
3100
3101 // Obtain counts of the number of points in each column and store them as the
3102 // (invalid) column pointers of the new matrix.
3103 for(uword i = 1; i < n_cols + 1; ++i)
3104 {
3105 access::rw(newmat.col_ptrs[i]) = col_ptrs[i] - col_ptrs[i - 1];
3106 }
3107
3108 // Now initialize memory for the new matrix.
3109 newmat.mem_resize(n_nonzero - removing);
3110
3111 // Now, copy over the elements.
3112 // i is the index in the old matrix; j is the index in the new matrix.
3113 const_iterator it = begin();
3114 const_iterator it_end = end();
3115
3116 uword j = 0; // The index in the new matrix.
3117 while(it != it_end)
3118 {
3119 const uword lrow = it.row();
3120 const uword lcol = it.col();
3121
3122 if(lrow >= in_row1 && lrow <= in_row2)
3123 {
3124 // This element is being removed. Subtract it from the column counts.
3125 --access::rw(newmat.col_ptrs[lcol + 1]);
3126 }
3127 else
3128 {
3129 // This element is being kept. We may need to map the row index,
3130 // if it is past the section of rows we are removing.
3131 if(lrow > in_row2)
3132 {
3133 access::rw(newmat.row_indices[j]) = lrow - (in_row2 - in_row1 + 1);
3134 }
3135 else
3136 {
3137 access::rw(newmat.row_indices[j]) = lrow;
3138 }
3139
3140 access::rw(newmat.values[j]) = (*it);
3141 ++j; // Increment index in new matrix.
3142 }
3143
3144 ++it;
3145 }
3146
3147 // Finally, sum the column counts so they are correct column pointers.
3148 for(uword i = 1; i < n_cols + 1; ++i)
3149 {
3150 access::rw(newmat.col_ptrs[i]) += newmat.col_ptrs[i - 1];
3151 }
3152
3153 // Now steal the memory of the new matrix.
3154 steal_mem(newmat);
3155 }
3156
3157
3158
3159 template<typename eT>
3160 inline
3161 void
shed_cols(const uword in_col1,const uword in_col2)3162 SpMat<eT>::shed_cols(const uword in_col1, const uword in_col2)
3163 {
3164 arma_extra_debug_sigprint();
3165
3166 arma_debug_check_bounds
3167 (
3168 (in_col1 > in_col2) || (in_col2 >= n_cols),
3169 "SpMat::shed_cols(): indices out of bounds or incorrectly used"
3170 );
3171
3172 sync_csc();
3173 invalidate_cache();
3174
3175 // First we find the locations in values and row_indices for the column entries.
3176 uword col_beg = col_ptrs[in_col1];
3177 uword col_end = col_ptrs[in_col2 + 1];
3178
3179 // Then we find the number of entries in the column.
3180 uword diff = col_end - col_beg;
3181
3182 if(diff > 0)
3183 {
3184 eT* new_values = memory::acquire<eT> (n_nonzero - diff);
3185 uword* new_row_indices = memory::acquire<uword>(n_nonzero - diff);
3186
3187 // Copy first part.
3188 if(col_beg != 0)
3189 {
3190 arrayops::copy(new_values, values, col_beg);
3191 arrayops::copy(new_row_indices, row_indices, col_beg);
3192 }
3193
3194 // Copy second part.
3195 if(col_end != n_nonzero)
3196 {
3197 arrayops::copy(new_values + col_beg, values + col_end, n_nonzero - col_end);
3198 arrayops::copy(new_row_indices + col_beg, row_indices + col_end, n_nonzero - col_end);
3199 }
3200
3201 if(values) { memory::release(access::rw(values)); }
3202 if(row_indices) { memory::release(access::rw(row_indices)); }
3203
3204 access::rw(values) = new_values;
3205 access::rw(row_indices) = new_row_indices;
3206
3207 // Update counts and such.
3208 access::rw(n_nonzero) -= diff;
3209 }
3210
3211 // Update column pointers.
3212 const uword new_n_cols = n_cols - ((in_col2 - in_col1) + 1);
3213
3214 uword* new_col_ptrs = memory::acquire<uword>(new_n_cols + 2);
3215 new_col_ptrs[new_n_cols + 1] = std::numeric_limits<uword>::max();
3216
3217 // Copy first set of columns (no manipulation required).
3218 if(in_col1 != 0)
3219 {
3220 arrayops::copy(new_col_ptrs, col_ptrs, in_col1);
3221 }
3222
3223 // Copy second set of columns (manipulation required).
3224 uword cur_col = in_col1;
3225 for(uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col)
3226 {
3227 new_col_ptrs[cur_col] = col_ptrs[i] - diff;
3228 }
3229
3230 if(col_ptrs) { memory::release(access::rw(col_ptrs)); }
3231 access::rw(col_ptrs) = new_col_ptrs;
3232
3233 // We update the element and column counts, and we're done.
3234 access::rw(n_cols) = new_n_cols;
3235 access::rw(n_elem) = n_cols * n_rows;
3236 }
3237
3238
3239
3240 /**
3241 * Element access; acces the i'th element (works identically to the Mat accessors).
3242 * If there is nothing at element i, 0 is returned.
3243 */
3244
3245 template<typename eT>
3246 arma_inline
3247 arma_warn_unused
3248 SpMat_MapMat_val<eT>
operator [](const uword i)3249 SpMat<eT>::operator[](const uword i)
3250 {
3251 const uword in_col = i / n_rows;
3252 const uword in_row = i % n_rows;
3253
3254 return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
3255 }
3256
3257
3258
3259 template<typename eT>
3260 arma_inline
3261 arma_warn_unused
3262 eT
operator [](const uword i) const3263 SpMat<eT>::operator[](const uword i) const
3264 {
3265 return get_value(i);
3266 }
3267
3268
3269
3270 template<typename eT>
3271 arma_inline
3272 arma_warn_unused
3273 SpMat_MapMat_val<eT>
at(const uword i)3274 SpMat<eT>::at(const uword i)
3275 {
3276 const uword in_col = i / n_rows;
3277 const uword in_row = i % n_rows;
3278
3279 return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
3280 }
3281
3282
3283
3284 template<typename eT>
3285 arma_inline
3286 arma_warn_unused
3287 eT
at(const uword i) const3288 SpMat<eT>::at(const uword i) const
3289 {
3290 return get_value(i);
3291 }
3292
3293
3294
3295 template<typename eT>
3296 arma_inline
3297 arma_warn_unused
3298 SpMat_MapMat_val<eT>
operator ()(const uword i)3299 SpMat<eT>::operator()(const uword i)
3300 {
3301 arma_debug_check_bounds( (i >= n_elem), "SpMat::operator(): out of bounds" );
3302
3303 const uword in_col = i / n_rows;
3304 const uword in_row = i % n_rows;
3305
3306 return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
3307 }
3308
3309
3310
3311 template<typename eT>
3312 arma_inline
3313 arma_warn_unused
3314 eT
operator ()(const uword i) const3315 SpMat<eT>::operator()(const uword i) const
3316 {
3317 arma_debug_check_bounds( (i >= n_elem), "SpMat::operator(): out of bounds" );
3318
3319 return get_value(i);
3320 }
3321
3322
3323
3324 /**
3325 * Element access; access the element at row in_rows and column in_col.
3326 * If there is nothing at that position, 0 is returned.
3327 */
3328
3329 template<typename eT>
3330 arma_inline
3331 arma_warn_unused
3332 SpMat_MapMat_val<eT>
at(const uword in_row,const uword in_col)3333 SpMat<eT>::at(const uword in_row, const uword in_col)
3334 {
3335 return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
3336 }
3337
3338
3339
3340 template<typename eT>
3341 arma_inline
3342 arma_warn_unused
3343 eT
at(const uword in_row,const uword in_col) const3344 SpMat<eT>::at(const uword in_row, const uword in_col) const
3345 {
3346 return get_value(in_row, in_col);
3347 }
3348
3349
3350
3351 template<typename eT>
3352 arma_inline
3353 arma_warn_unused
3354 SpMat_MapMat_val<eT>
operator ()(const uword in_row,const uword in_col)3355 SpMat<eT>::operator()(const uword in_row, const uword in_col)
3356 {
3357 arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds" );
3358
3359 return SpMat_MapMat_val<eT>((*this), cache, in_row, in_col);
3360 }
3361
3362
3363
3364 template<typename eT>
3365 arma_inline
3366 arma_warn_unused
3367 eT
operator ()(const uword in_row,const uword in_col) const3368 SpMat<eT>::operator()(const uword in_row, const uword in_col) const
3369 {
3370 arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds" );
3371
3372 return get_value(in_row, in_col);
3373 }
3374
3375
3376
3377 /**
3378 * Check if matrix is empty (no size, no values).
3379 */
3380 template<typename eT>
3381 arma_inline
3382 arma_warn_unused
3383 bool
is_empty() const3384 SpMat<eT>::is_empty() const
3385 {
3386 return (n_elem == 0);
3387 }
3388
3389
3390
3391 //! returns true if the object can be interpreted as a column or row vector
3392 template<typename eT>
3393 arma_inline
3394 arma_warn_unused
3395 bool
is_vec() const3396 SpMat<eT>::is_vec() const
3397 {
3398 return ( (n_rows == 1) || (n_cols == 1) );
3399 }
3400
3401
3402
3403 //! returns true if the object can be interpreted as a row vector
3404 template<typename eT>
3405 arma_inline
3406 arma_warn_unused
3407 bool
is_rowvec() const3408 SpMat<eT>::is_rowvec() const
3409 {
3410 return (n_rows == 1);
3411 }
3412
3413
3414
3415 //! returns true if the object can be interpreted as a column vector
3416 template<typename eT>
3417 arma_inline
3418 arma_warn_unused
3419 bool
is_colvec() const3420 SpMat<eT>::is_colvec() const
3421 {
3422 return (n_cols == 1);
3423 }
3424
3425
3426
3427 //! returns true if the object has the same number of non-zero rows and columnns
3428 template<typename eT>
3429 arma_inline
3430 arma_warn_unused
3431 bool
is_square() const3432 SpMat<eT>::is_square() const
3433 {
3434 return (n_rows == n_cols);
3435 }
3436
3437
3438
3439 //! returns true if all of the elements are finite
3440 template<typename eT>
3441 inline
3442 arma_warn_unused
3443 bool
is_finite() const3444 SpMat<eT>::is_finite() const
3445 {
3446 arma_extra_debug_sigprint();
3447
3448 sync_csc();
3449
3450 return arrayops::is_finite(values, n_nonzero);
3451 }
3452
3453
3454
3455 template<typename eT>
3456 inline
3457 arma_warn_unused
3458 bool
is_symmetric() const3459 SpMat<eT>::is_symmetric() const
3460 {
3461 arma_extra_debug_sigprint();
3462
3463 const SpMat<eT>& A = (*this);
3464
3465 if(A.n_rows != A.n_cols) { return false; }
3466
3467 const SpMat<eT> tmp = A - A.st();
3468
3469 return (tmp.n_nonzero == uword(0));
3470 }
3471
3472
3473
3474 template<typename eT>
3475 inline
3476 arma_warn_unused
3477 bool
is_symmetric(const typename get_pod_type<elem_type>::result tol) const3478 SpMat<eT>::is_symmetric(const typename get_pod_type<elem_type>::result tol) const
3479 {
3480 arma_extra_debug_sigprint();
3481
3482 typedef typename get_pod_type<eT>::result T;
3483
3484 if(tol == T(0)) { return (*this).is_symmetric(); }
3485
3486 arma_debug_check( (tol < T(0)), "is_symmetric(): parameter 'tol' must be >= 0" );
3487
3488 const SpMat<eT>& A = (*this);
3489
3490 if(A.n_rows != A.n_cols) { return false; }
3491
3492 const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) );
3493
3494 if(norm_A == T(0)) { return true; }
3495
3496 const T norm_A_Ast = as_scalar( arma::max(sum(abs(A - A.st()), 1), 0) );
3497
3498 return ( (norm_A_Ast / norm_A) <= tol );
3499 }
3500
3501
3502
3503 template<typename eT>
3504 inline
3505 arma_warn_unused
3506 bool
is_hermitian() const3507 SpMat<eT>::is_hermitian() const
3508 {
3509 arma_extra_debug_sigprint();
3510
3511 const SpMat<eT>& A = (*this);
3512
3513 if(A.n_rows != A.n_cols) { return false; }
3514
3515 const SpMat<eT> tmp = A - A.t();
3516
3517 return (tmp.n_nonzero == uword(0));
3518 }
3519
3520
3521
3522 template<typename eT>
3523 inline
3524 arma_warn_unused
3525 bool
is_hermitian(const typename get_pod_type<elem_type>::result tol) const3526 SpMat<eT>::is_hermitian(const typename get_pod_type<elem_type>::result tol) const
3527 {
3528 arma_extra_debug_sigprint();
3529
3530 typedef typename get_pod_type<eT>::result T;
3531
3532 if(tol == T(0)) { return (*this).is_hermitian(); }
3533
3534 arma_debug_check( (tol < T(0)), "is_hermitian(): parameter 'tol' must be >= 0" );
3535
3536 const SpMat<eT>& A = (*this);
3537
3538 if(A.n_rows != A.n_cols) { return false; }
3539
3540 const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) );
3541
3542 if(norm_A == T(0)) { return true; }
3543
3544 const T norm_A_At = as_scalar( arma::max(sum(abs(A - A.t()), 1), 0) );
3545
3546 return ( (norm_A_At / norm_A) <= tol );
3547 }
3548
3549
3550
3551 template<typename eT>
3552 inline
3553 arma_warn_unused
3554 bool
has_inf() const3555 SpMat<eT>::has_inf() const
3556 {
3557 arma_extra_debug_sigprint();
3558
3559 sync_csc();
3560
3561 return arrayops::has_inf(values, n_nonzero);
3562 }
3563
3564
3565
3566 template<typename eT>
3567 inline
3568 arma_warn_unused
3569 bool
has_nan() const3570 SpMat<eT>::has_nan() const
3571 {
3572 arma_extra_debug_sigprint();
3573
3574 sync_csc();
3575
3576 return arrayops::has_nan(values, n_nonzero);
3577 }
3578
3579
3580
3581 //! returns true if the given index is currently in range
3582 template<typename eT>
3583 arma_inline
3584 arma_warn_unused
3585 bool
in_range(const uword i) const3586 SpMat<eT>::in_range(const uword i) const
3587 {
3588 return (i < n_elem);
3589 }
3590
3591
3592 //! returns true if the given start and end indices are currently in range
3593 template<typename eT>
3594 arma_inline
3595 arma_warn_unused
3596 bool
in_range(const span & x) const3597 SpMat<eT>::in_range(const span& x) const
3598 {
3599 arma_extra_debug_sigprint();
3600
3601 if(x.whole)
3602 {
3603 return true;
3604 }
3605 else
3606 {
3607 const uword a = x.a;
3608 const uword b = x.b;
3609
3610 return ( (a <= b) && (b < n_elem) );
3611 }
3612 }
3613
3614
3615
3616 //! returns true if the given location is currently in range
3617 template<typename eT>
3618 arma_inline
3619 arma_warn_unused
3620 bool
in_range(const uword in_row,const uword in_col) const3621 SpMat<eT>::in_range(const uword in_row, const uword in_col) const
3622 {
3623 return ( (in_row < n_rows) && (in_col < n_cols) );
3624 }
3625
3626
3627
3628 template<typename eT>
3629 arma_inline
3630 arma_warn_unused
3631 bool
in_range(const span & row_span,const uword in_col) const3632 SpMat<eT>::in_range(const span& row_span, const uword in_col) const
3633 {
3634 arma_extra_debug_sigprint();
3635
3636 if(row_span.whole)
3637 {
3638 return (in_col < n_cols);
3639 }
3640 else
3641 {
3642 const uword in_row1 = row_span.a;
3643 const uword in_row2 = row_span.b;
3644
3645 return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) );
3646 }
3647 }
3648
3649
3650
3651 template<typename eT>
3652 arma_inline
3653 arma_warn_unused
3654 bool
in_range(const uword in_row,const span & col_span) const3655 SpMat<eT>::in_range(const uword in_row, const span& col_span) const
3656 {
3657 arma_extra_debug_sigprint();
3658
3659 if(col_span.whole)
3660 {
3661 return (in_row < n_rows);
3662 }
3663 else
3664 {
3665 const uword in_col1 = col_span.a;
3666 const uword in_col2 = col_span.b;
3667
3668 return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) );
3669 }
3670 }
3671
3672
3673
3674 template<typename eT>
3675 arma_inline
3676 arma_warn_unused
3677 bool
in_range(const span & row_span,const span & col_span) const3678 SpMat<eT>::in_range(const span& row_span, const span& col_span) const
3679 {
3680 arma_extra_debug_sigprint();
3681
3682 const uword in_row1 = row_span.a;
3683 const uword in_row2 = row_span.b;
3684
3685 const uword in_col1 = col_span.a;
3686 const uword in_col2 = col_span.b;
3687
3688 const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) );
3689 const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) );
3690
3691 return ( rows_ok && cols_ok );
3692 }
3693
3694
3695
3696 template<typename eT>
3697 arma_inline
3698 arma_warn_unused
3699 bool
in_range(const uword in_row,const uword in_col,const SizeMat & s) const3700 SpMat<eT>::in_range(const uword in_row, const uword in_col, const SizeMat& s) const
3701 {
3702 const uword l_n_rows = n_rows;
3703 const uword l_n_cols = n_cols;
3704
3705 if( (in_row >= l_n_rows) || (in_col >= l_n_cols) || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) )
3706 {
3707 return false;
3708 }
3709 else
3710 {
3711 return true;
3712 }
3713 }
3714
3715
3716
3717 //! Set the size to the size of another matrix.
3718 template<typename eT>
3719 template<typename eT2>
3720 inline
3721 void
copy_size(const SpMat<eT2> & m)3722 SpMat<eT>::copy_size(const SpMat<eT2>& m)
3723 {
3724 arma_extra_debug_sigprint();
3725
3726 set_size(m.n_rows, m.n_cols);
3727 }
3728
3729
3730
3731 template<typename eT>
3732 template<typename eT2>
3733 inline
3734 void
copy_size(const Mat<eT2> & m)3735 SpMat<eT>::copy_size(const Mat<eT2>& m)
3736 {
3737 arma_extra_debug_sigprint();
3738
3739 set_size(m.n_rows, m.n_cols);
3740 }
3741
3742
3743
3744 template<typename eT>
3745 inline
3746 void
set_size(const uword in_elem)3747 SpMat<eT>::set_size(const uword in_elem)
3748 {
3749 arma_extra_debug_sigprint();
3750
3751 // If this is a row vector, we resize to a row vector.
3752 if(vec_state == 2)
3753 {
3754 set_size(1, in_elem);
3755 }
3756 else
3757 {
3758 set_size(in_elem, 1);
3759 }
3760 }
3761
3762
3763
3764 template<typename eT>
3765 inline
3766 void
set_size(const uword in_rows,const uword in_cols)3767 SpMat<eT>::set_size(const uword in_rows, const uword in_cols)
3768 {
3769 arma_extra_debug_sigprint();
3770
3771 invalidate_cache(); // placed here, as set_size() is used during matrix modification
3772
3773 if( (n_rows == in_rows) && (n_cols == in_cols) )
3774 {
3775 return;
3776 }
3777 else
3778 {
3779 init(in_rows, in_cols);
3780 }
3781 }
3782
3783
3784
3785 template<typename eT>
3786 inline
3787 void
set_size(const SizeMat & s)3788 SpMat<eT>::set_size(const SizeMat& s)
3789 {
3790 arma_extra_debug_sigprint();
3791
3792 (*this).set_size(s.n_rows, s.n_cols);
3793 }
3794
3795
3796
3797 template<typename eT>
3798 inline
3799 void
resize(const uword in_rows,const uword in_cols)3800 SpMat<eT>::resize(const uword in_rows, const uword in_cols)
3801 {
3802 arma_extra_debug_sigprint();
3803
3804 if( (n_rows == in_rows) && (n_cols == in_cols) )
3805 {
3806 return;
3807 }
3808
3809 if( (n_elem == 0) || (n_nonzero == 0) )
3810 {
3811 set_size(in_rows, in_cols);
3812 return;
3813 }
3814
3815 SpMat<eT> tmp(in_rows, in_cols);
3816
3817 if(tmp.n_elem > 0)
3818 {
3819 sync_csc();
3820
3821 const uword last_row = (std::min)(in_rows, n_rows) - 1;
3822 const uword last_col = (std::min)(in_cols, n_cols) - 1;
3823
3824 tmp.submat(0, 0, last_row, last_col) = (*this).submat(0, 0, last_row, last_col);
3825 }
3826
3827 steal_mem(tmp);
3828 }
3829
3830
3831
3832 template<typename eT>
3833 inline
3834 void
resize(const SizeMat & s)3835 SpMat<eT>::resize(const SizeMat& s)
3836 {
3837 arma_extra_debug_sigprint();
3838
3839 (*this).resize(s.n_rows, s.n_cols);
3840 }
3841
3842
3843
3844 template<typename eT>
3845 inline
3846 void
reshape(const uword in_rows,const uword in_cols)3847 SpMat<eT>::reshape(const uword in_rows, const uword in_cols)
3848 {
3849 arma_extra_debug_sigprint();
3850
3851 arma_check( ((in_rows*in_cols) != n_elem), "SpMat::reshape(): changing the number of elements in a sparse matrix is currently not supported" );
3852
3853 if( (n_rows == in_rows) && (n_cols == in_cols) ) { return; }
3854
3855 if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::reshape(): object is a column vector; requested size is not compatible" ); }
3856 if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::reshape(): object is a row vector; requested size is not compatible" ); }
3857
3858 if(n_nonzero == 0)
3859 {
3860 (*this).zeros(in_rows, in_cols);
3861 return;
3862 }
3863
3864 if(in_cols == 1)
3865 {
3866 (*this).reshape_helper_intovec();
3867 }
3868 else
3869 {
3870 (*this).reshape_helper_generic(in_rows, in_cols);
3871 }
3872 }
3873
3874
3875
3876 template<typename eT>
3877 inline
3878 void
reshape(const SizeMat & s)3879 SpMat<eT>::reshape(const SizeMat& s)
3880 {
3881 arma_extra_debug_sigprint();
3882
3883 (*this).reshape(s.n_rows, s.n_cols);
3884 }
3885
3886
3887
3888 template<typename eT>
3889 inline
3890 void
reshape_helper_generic(const uword in_rows,const uword in_cols)3891 SpMat<eT>::reshape_helper_generic(const uword in_rows, const uword in_cols)
3892 {
3893 arma_extra_debug_sigprint();
3894
3895 sync_csc();
3896 invalidate_cache();
3897
3898 // We have to modify all of the relevant row indices and the relevant column pointers.
3899 // Iterate over all the points to do this. We won't be deleting any points, but we will be modifying
3900 // columns and rows. We'll have to store a new set of column vectors.
3901 uword* new_col_ptrs = memory::acquire<uword>(in_cols + 2);
3902 new_col_ptrs[in_cols + 1] = std::numeric_limits<uword>::max();
3903
3904 uword* new_row_indices = memory::acquire<uword>(n_nonzero + 1);
3905 access::rw(new_row_indices[n_nonzero]) = 0;
3906
3907 arrayops::fill_zeros(new_col_ptrs, in_cols + 1);
3908
3909 const_iterator it = begin();
3910 const_iterator it_end = end();
3911
3912 for(; it != it_end; ++it)
3913 {
3914 uword vector_position = (it.col() * n_rows) + it.row();
3915 new_row_indices[it.pos()] = vector_position % in_rows;
3916 ++new_col_ptrs[vector_position / in_rows + 1];
3917 }
3918
3919 // Now sum the column counts to get the new column pointers.
3920 for(uword i = 1; i <= in_cols; i++)
3921 {
3922 access::rw(new_col_ptrs[i]) += new_col_ptrs[i - 1];
3923 }
3924
3925 // Copy the new row indices.
3926 if(row_indices) { memory::release(access::rw(row_indices)); }
3927 if(col_ptrs) { memory::release(access::rw(col_ptrs)); }
3928
3929 access::rw(row_indices) = new_row_indices;
3930 access::rw(col_ptrs) = new_col_ptrs;
3931
3932 // Now set the size.
3933 access::rw(n_rows) = in_rows;
3934 access::rw(n_cols) = in_cols;
3935 }
3936
3937
3938
3939 template<typename eT>
3940 inline
3941 void
reshape_helper_intovec()3942 SpMat<eT>::reshape_helper_intovec()
3943 {
3944 arma_extra_debug_sigprint();
3945
3946 sync_csc();
3947 invalidate_cache();
3948
3949 const_iterator it = begin();
3950
3951 const uword t_n_rows = n_rows;
3952 const uword t_n_nonzero = n_nonzero;
3953
3954 for(uword i=0; i < t_n_nonzero; ++i)
3955 {
3956 const uword t_index = (it.col() * t_n_rows) + it.row();
3957
3958 // ensure the iterator is pointing to the next element
3959 // before we overwrite the row index of the current element
3960 ++it;
3961
3962 access::rw(row_indices[i]) = t_index;
3963 }
3964
3965 access::rw(row_indices[n_nonzero]) = 0;
3966
3967 access::rw(col_ptrs[0]) = 0;
3968 access::rw(col_ptrs[1]) = n_nonzero;
3969 access::rw(col_ptrs[2]) = std::numeric_limits<uword>::max();
3970
3971 access::rw(n_rows) = (n_rows * n_cols);
3972 access::rw(n_cols) = 1;
3973 }
3974
3975
3976
3977 //! apply a functor to each non-zero element
3978 template<typename eT>
3979 template<typename functor>
3980 inline
3981 const SpMat<eT>&
for_each(functor F)3982 SpMat<eT>::for_each(functor F)
3983 {
3984 arma_extra_debug_sigprint();
3985
3986 sync_csc();
3987
3988 const uword N = (*this).n_nonzero;
3989
3990 eT* rw_values = access::rwp(values);
3991
3992 bool modified = false;
3993 bool has_zero = false;
3994
3995 for(uword i=0; i < N; ++i)
3996 {
3997 eT& new_value = rw_values[i];
3998 const eT old_value = new_value;
3999
4000 F(new_value);
4001
4002 if(new_value != old_value) { modified = true; }
4003 if(new_value == eT(0) ) { has_zero = true; }
4004 }
4005
4006 if(modified) { invalidate_cache(); }
4007 if(has_zero) { remove_zeros(); }
4008
4009 return *this;
4010 }
4011
4012
4013
4014 template<typename eT>
4015 template<typename functor>
4016 inline
4017 const SpMat<eT>&
for_each(functor F) const4018 SpMat<eT>::for_each(functor F) const
4019 {
4020 arma_extra_debug_sigprint();
4021
4022 sync_csc();
4023
4024 const uword N = (*this).n_nonzero;
4025
4026 for(uword i=0; i < N; ++i)
4027 {
4028 F(values[i]);
4029 }
4030
4031 return *this;
4032 }
4033
4034
4035
4036 //! transform each non-zero element using a functor
4037 template<typename eT>
4038 template<typename functor>
4039 inline
4040 const SpMat<eT>&
transform(functor F)4041 SpMat<eT>::transform(functor F)
4042 {
4043 arma_extra_debug_sigprint();
4044
4045 sync_csc();
4046 invalidate_cache();
4047
4048 const uword N = (*this).n_nonzero;
4049
4050 eT* rw_values = access::rwp(values);
4051
4052 bool has_zero = false;
4053
4054 for(uword i=0; i < N; ++i)
4055 {
4056 eT& rw_values_i = rw_values[i];
4057
4058 rw_values_i = eT( F(rw_values_i) );
4059
4060 if(rw_values_i == eT(0)) { has_zero = true; }
4061 }
4062
4063 if(has_zero) { remove_zeros(); }
4064
4065 return *this;
4066 }
4067
4068
4069
4070 template<typename eT>
4071 inline
4072 const SpMat<eT>&
replace(const eT old_val,const eT new_val)4073 SpMat<eT>::replace(const eT old_val, const eT new_val)
4074 {
4075 arma_extra_debug_sigprint();
4076
4077 if(old_val == eT(0))
4078 {
4079 arma_debug_warn_level(1, "SpMat::replace(): replacement not done, as old_val = 0");
4080 }
4081 else
4082 {
4083 sync_csc();
4084 invalidate_cache();
4085
4086 arrayops::replace(access::rwp(values), n_nonzero, old_val, new_val);
4087
4088 if(new_val == eT(0)) { remove_zeros(); }
4089 }
4090
4091 return *this;
4092 }
4093
4094
4095
4096 template<typename eT>
4097 inline
4098 const SpMat<eT>&
clean(const typename get_pod_type<eT>::result threshold)4099 SpMat<eT>::clean(const typename get_pod_type<eT>::result threshold)
4100 {
4101 arma_extra_debug_sigprint();
4102
4103 if(n_nonzero == 0) { return *this; }
4104
4105 sync_csc();
4106 invalidate_cache();
4107
4108 arrayops::clean(access::rwp(values), n_nonzero, threshold);
4109
4110 remove_zeros();
4111
4112 return *this;
4113 }
4114
4115
4116
4117 template<typename eT>
4118 inline
4119 const SpMat<eT>&
clamp(const eT min_val,const eT max_val)4120 SpMat<eT>::clamp(const eT min_val, const eT max_val)
4121 {
4122 arma_extra_debug_sigprint();
4123
4124 if(is_cx<eT>::no)
4125 {
4126 arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpMat::clamp(): min_val must be less than max_val" );
4127 }
4128 else
4129 {
4130 arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpMat::clamp(): real(min_val) must be less than real(max_val)" );
4131 arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "SpMat::clamp(): imag(min_val) must be less than imag(max_val)" );
4132 }
4133
4134 if(n_nonzero == 0) { return *this; }
4135
4136 sync_csc();
4137 invalidate_cache();
4138
4139 arrayops::clamp(access::rwp(values), n_nonzero, min_val, max_val);
4140
4141 if( (min_val == eT(0)) || (max_val == eT(0)) ) { remove_zeros(); }
4142
4143 return *this;
4144 }
4145
4146
4147
4148 template<typename eT>
4149 inline
4150 const SpMat<eT>&
zeros()4151 SpMat<eT>::zeros()
4152 {
4153 arma_extra_debug_sigprint();
4154
4155 const bool already_done = ( (sync_state != 1) && (n_nonzero == 0) );
4156
4157 if(already_done == false)
4158 {
4159 init(n_rows, n_cols);
4160 }
4161
4162 return *this;
4163 }
4164
4165
4166
4167 template<typename eT>
4168 inline
4169 const SpMat<eT>&
zeros(const uword in_elem)4170 SpMat<eT>::zeros(const uword in_elem)
4171 {
4172 arma_extra_debug_sigprint();
4173
4174 if(vec_state == 2)
4175 {
4176 zeros(1, in_elem); // Row vector
4177 }
4178 else
4179 {
4180 zeros(in_elem, 1);
4181 }
4182
4183 return *this;
4184 }
4185
4186
4187
4188 template<typename eT>
4189 inline
4190 const SpMat<eT>&
zeros(const uword in_rows,const uword in_cols)4191 SpMat<eT>::zeros(const uword in_rows, const uword in_cols)
4192 {
4193 arma_extra_debug_sigprint();
4194
4195 const bool already_done = ( (sync_state != 1) && (n_nonzero == 0) && (n_rows == in_rows) && (n_cols == in_cols) );
4196
4197 if(already_done == false)
4198 {
4199 init(in_rows, in_cols);
4200 }
4201
4202 return *this;
4203 }
4204
4205
4206
4207 template<typename eT>
4208 inline
4209 const SpMat<eT>&
zeros(const SizeMat & s)4210 SpMat<eT>::zeros(const SizeMat& s)
4211 {
4212 arma_extra_debug_sigprint();
4213
4214 return (*this).zeros(s.n_rows, s.n_cols);
4215 }
4216
4217
4218
4219 template<typename eT>
4220 inline
4221 const SpMat<eT>&
eye()4222 SpMat<eT>::eye()
4223 {
4224 arma_extra_debug_sigprint();
4225
4226 return (*this).eye(n_rows, n_cols);
4227 }
4228
4229
4230
4231 template<typename eT>
4232 inline
4233 const SpMat<eT>&
eye(const uword in_rows,const uword in_cols)4234 SpMat<eT>::eye(const uword in_rows, const uword in_cols)
4235 {
4236 arma_extra_debug_sigprint();
4237
4238 const uword N = (std::min)(in_rows, in_cols);
4239
4240 init(in_rows, in_cols, N);
4241
4242 arrayops::inplace_set(access::rwp(values), eT(1), N);
4243
4244 for(uword i = 0; i < N; ++i) { access::rw(row_indices[i]) = i; }
4245
4246 for(uword i = 0; i <= N; ++i) { access::rw(col_ptrs[i]) = i; }
4247
4248 // take into account non-square matrices
4249 for(uword i = (N+1); i <= in_cols; ++i) { access::rw(col_ptrs[i]) = N; }
4250
4251 access::rw(n_nonzero) = N;
4252
4253 return *this;
4254 }
4255
4256
4257
4258 template<typename eT>
4259 inline
4260 const SpMat<eT>&
eye(const SizeMat & s)4261 SpMat<eT>::eye(const SizeMat& s)
4262 {
4263 arma_extra_debug_sigprint();
4264
4265 return (*this).eye(s.n_rows, s.n_cols);
4266 }
4267
4268
4269
4270 template<typename eT>
4271 inline
4272 const SpMat<eT>&
speye()4273 SpMat<eT>::speye()
4274 {
4275 arma_extra_debug_sigprint();
4276
4277 return (*this).eye(n_rows, n_cols);
4278 }
4279
4280
4281
4282 template<typename eT>
4283 inline
4284 const SpMat<eT>&
speye(const uword in_n_rows,const uword in_n_cols)4285 SpMat<eT>::speye(const uword in_n_rows, const uword in_n_cols)
4286 {
4287 arma_extra_debug_sigprint();
4288
4289 return (*this).eye(in_n_rows, in_n_cols);
4290 }
4291
4292
4293
4294 template<typename eT>
4295 inline
4296 const SpMat<eT>&
speye(const SizeMat & s)4297 SpMat<eT>::speye(const SizeMat& s)
4298 {
4299 arma_extra_debug_sigprint();
4300
4301 return (*this).eye(s.n_rows, s.n_cols);
4302 }
4303
4304
4305
4306 template<typename eT>
4307 inline
4308 const SpMat<eT>&
sprandu(const uword in_rows,const uword in_cols,const double density)4309 SpMat<eT>::sprandu(const uword in_rows, const uword in_cols, const double density)
4310 {
4311 arma_extra_debug_sigprint();
4312
4313 arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandu(): density must be in the [0,1] interval" );
4314
4315 const uword new_n_nonzero = uword(density * double(in_rows) * double(in_cols) + 0.5);
4316
4317 init(in_rows, in_cols, new_n_nonzero);
4318
4319 if(new_n_nonzero == 0) { return *this; }
4320
4321 arma_rng::randu<eT>::fill( access::rwp(values), new_n_nonzero );
4322
4323 uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, new_n_nonzero );
4324
4325 // perturb the indices
4326 for(uword i=1; i < new_n_nonzero-1; ++i)
4327 {
4328 const uword index_left = indices[i-1];
4329 const uword index_right = indices[i+1];
4330
4331 const uword center = (index_left + index_right) / 2;
4332
4333 const uword delta1 = center - index_left - 1;
4334 const uword delta2 = index_right - center - 1;
4335
4336 const uword min_delta = (std::min)(delta1, delta2);
4337
4338 uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) );
4339
4340 // paranoia, but better be safe than sorry
4341 if( (index_left < index_new) && (index_new < index_right) )
4342 {
4343 indices[i] = index_new;
4344 }
4345 }
4346
4347 uword cur_index = 0;
4348 uword count = 0;
4349
4350 for(uword lcol = 0; lcol < in_cols; ++lcol)
4351 for(uword lrow = 0; lrow < in_rows; ++lrow)
4352 {
4353 if(count == indices[cur_index])
4354 {
4355 access::rw(row_indices[cur_index]) = lrow;
4356 access::rw(col_ptrs[lcol + 1])++;
4357 ++cur_index;
4358 }
4359
4360 ++count;
4361 }
4362
4363 if(cur_index != new_n_nonzero)
4364 {
4365 // Fix size to correct size.
4366 mem_resize(cur_index);
4367 }
4368
4369 // Sum column pointers.
4370 for(uword lcol = 1; lcol <= in_cols; ++lcol)
4371 {
4372 access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1];
4373 }
4374
4375 return *this;
4376 }
4377
4378
4379
4380 template<typename eT>
4381 inline
4382 const SpMat<eT>&
sprandu(const SizeMat & s,const double density)4383 SpMat<eT>::sprandu(const SizeMat& s, const double density)
4384 {
4385 arma_extra_debug_sigprint();
4386
4387 return (*this).sprandu(s.n_rows, s.n_cols, density);
4388 }
4389
4390
4391
4392 template<typename eT>
4393 inline
4394 const SpMat<eT>&
sprandn(const uword in_rows,const uword in_cols,const double density)4395 SpMat<eT>::sprandn(const uword in_rows, const uword in_cols, const double density)
4396 {
4397 arma_extra_debug_sigprint();
4398
4399 arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandn(): density must be in the [0,1] interval" );
4400
4401 const uword new_n_nonzero = uword(density * double(in_rows) * double(in_cols) + 0.5);
4402
4403 init(in_rows, in_cols, new_n_nonzero);
4404
4405 if(new_n_nonzero == 0) { return *this; }
4406
4407 arma_rng::randn<eT>::fill( access::rwp(values), new_n_nonzero );
4408
4409 uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, new_n_nonzero );
4410
4411 // perturb the indices
4412 for(uword i=1; i < new_n_nonzero-1; ++i)
4413 {
4414 const uword index_left = indices[i-1];
4415 const uword index_right = indices[i+1];
4416
4417 const uword center = (index_left + index_right) / 2;
4418
4419 const uword delta1 = center - index_left - 1;
4420 const uword delta2 = index_right - center - 1;
4421
4422 const uword min_delta = (std::min)(delta1, delta2);
4423
4424 uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) );
4425
4426 // paranoia, but better be safe than sorry
4427 if( (index_left < index_new) && (index_new < index_right) )
4428 {
4429 indices[i] = index_new;
4430 }
4431 }
4432
4433 uword cur_index = 0;
4434 uword count = 0;
4435
4436 for(uword lcol = 0; lcol < in_cols; ++lcol)
4437 for(uword lrow = 0; lrow < in_rows; ++lrow)
4438 {
4439 if(count == indices[cur_index])
4440 {
4441 access::rw(row_indices[cur_index]) = lrow;
4442 access::rw(col_ptrs[lcol + 1])++;
4443 ++cur_index;
4444 }
4445
4446 ++count;
4447 }
4448
4449 if(cur_index != new_n_nonzero)
4450 {
4451 // Fix size to correct size.
4452 mem_resize(cur_index);
4453 }
4454
4455 // Sum column pointers.
4456 for(uword lcol = 1; lcol <= in_cols; ++lcol)
4457 {
4458 access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1];
4459 }
4460
4461 return *this;
4462 }
4463
4464
4465
4466 template<typename eT>
4467 inline
4468 const SpMat<eT>&
sprandn(const SizeMat & s,const double density)4469 SpMat<eT>::sprandn(const SizeMat& s, const double density)
4470 {
4471 arma_extra_debug_sigprint();
4472
4473 return (*this).sprandn(s.n_rows, s.n_cols, density);
4474 }
4475
4476
4477
4478 template<typename eT>
4479 inline
4480 void
reset()4481 SpMat<eT>::reset()
4482 {
4483 arma_extra_debug_sigprint();
4484
4485 switch(vec_state)
4486 {
4487 default:
4488 init(0, 0);
4489 break;
4490
4491 case 1:
4492 init(0, 1);
4493 break;
4494
4495 case 2:
4496 init(1, 0);
4497 break;
4498 }
4499 }
4500
4501
4502
4503 template<typename eT>
4504 inline
4505 void
reset_cache()4506 SpMat<eT>::reset_cache()
4507 {
4508 arma_extra_debug_sigprint();
4509
4510 sync_csc();
4511
4512 #if defined(ARMA_USE_OPENMP)
4513 {
4514 #pragma omp critical (arma_SpMat_cache)
4515 {
4516 cache.reset();
4517
4518 sync_state = 0;
4519 }
4520 }
4521 #elif (!defined(ARMA_DONT_USE_STD_MUTEX))
4522 {
4523 cache_mutex.lock();
4524
4525 cache.reset();
4526
4527 sync_state = 0;
4528
4529 cache_mutex.unlock();
4530 }
4531 #else
4532 {
4533 cache.reset();
4534
4535 sync_state = 0;
4536 }
4537 #endif
4538 }
4539
4540
4541
4542 template<typename eT>
4543 inline
4544 void
reserve(const uword in_rows,const uword in_cols,const uword new_n_nonzero)4545 SpMat<eT>::reserve(const uword in_rows, const uword in_cols, const uword new_n_nonzero)
4546 {
4547 arma_extra_debug_sigprint();
4548
4549 init(in_rows, in_cols, new_n_nonzero);
4550 }
4551
4552
4553
4554 template<typename eT>
4555 template<typename T1>
4556 inline
4557 void
set_real(const SpBase<typename SpMat<eT>::pod_type,T1> & X)4558 SpMat<eT>::set_real(const SpBase<typename SpMat<eT>::pod_type,T1>& X)
4559 {
4560 arma_extra_debug_sigprint();
4561
4562 SpMat_aux::set_real(*this, X);
4563 }
4564
4565
4566
4567 template<typename eT>
4568 template<typename T1>
4569 inline
4570 void
set_imag(const SpBase<typename SpMat<eT>::pod_type,T1> & X)4571 SpMat<eT>::set_imag(const SpBase<typename SpMat<eT>::pod_type,T1>& X)
4572 {
4573 arma_extra_debug_sigprint();
4574
4575 SpMat_aux::set_imag(*this, X);
4576 }
4577
4578
4579
4580 //! save the matrix to a file
4581 template<typename eT>
4582 inline
4583 arma_cold
4584 bool
save(const std::string name,const file_type type) const4585 SpMat<eT>::save(const std::string name, const file_type type) const
4586 {
4587 arma_extra_debug_sigprint();
4588
4589 sync_csc();
4590
4591 bool save_okay;
4592
4593 switch(type)
4594 {
4595 case csv_ascii:
4596 return (*this).save(csv_name(name), type);
4597 break;
4598
4599 case ssv_ascii:
4600 return (*this).save(csv_name(name), type);
4601 break;
4602
4603 case arma_binary:
4604 save_okay = diskio::save_arma_binary(*this, name);
4605 break;
4606
4607 case coord_ascii:
4608 save_okay = diskio::save_coord_ascii(*this, name);
4609 break;
4610
4611 default:
4612 arma_debug_warn_level(1, "SpMat::save(): unsupported file type");
4613 save_okay = false;
4614 }
4615
4616 if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): couldn't write; file: ", name); }
4617
4618 return save_okay;
4619 }
4620
4621
4622
4623 template<typename eT>
4624 inline
4625 arma_cold
4626 bool
save(const csv_name & spec,const file_type type) const4627 SpMat<eT>::save(const csv_name& spec, const file_type type) const
4628 {
4629 arma_extra_debug_sigprint();
4630
4631 if( (type != csv_ascii) && (type != ssv_ascii) )
4632 {
4633 arma_stop_runtime_error("SpMat::save(): unsupported file type for csv_name()");
4634 return false;
4635 }
4636
4637 const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans );
4638 const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header );
4639 bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header);
4640 const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii);
4641
4642 arma_extra_debug_print("SpMat::save(csv_name): enabled flags:");
4643
4644 if(do_trans ) { arma_extra_debug_print("trans"); }
4645 if(no_header ) { arma_extra_debug_print("no_header"); }
4646 if(with_header ) { arma_extra_debug_print("with_header"); }
4647 if(use_semicolon) { arma_extra_debug_print("semicolon"); }
4648
4649 const char separator = (use_semicolon) ? char(';') : char(',');
4650
4651 if(no_header) { with_header = false; }
4652
4653 if(with_header)
4654 {
4655 if( (spec.header_ro.n_cols != 1) && (spec.header_ro.n_rows != 1) )
4656 {
4657 arma_debug_warn_level(1, "SpMat::save(): given header must have a vector layout");
4658 return false;
4659 }
4660
4661 for(uword i=0; i < spec.header_ro.n_elem; ++i)
4662 {
4663 const std::string& token = spec.header_ro.at(i);
4664
4665 if(token.find(separator) != std::string::npos)
4666 {
4667 arma_debug_warn_level(1, "SpMat::save(): token within the header contains the separator character: '", token, "'");
4668 return false;
4669 }
4670 }
4671
4672 const uword save_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols;
4673
4674 if(spec.header_ro.n_elem != save_n_cols)
4675 {
4676 arma_debug_warn_level(1, "SpMat::save(): size mistmach between header and matrix");
4677 return false;
4678 }
4679 }
4680
4681 bool save_okay = false;
4682
4683 if(do_trans)
4684 {
4685 const SpMat<eT> tmp = (*this).st();
4686
4687 save_okay = diskio::save_csv_ascii(tmp, spec.filename, spec.header_ro, with_header, separator);
4688 }
4689 else
4690 {
4691 save_okay = diskio::save_csv_ascii(*this, spec.filename, spec.header_ro, with_header, separator);
4692 }
4693
4694 if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): couldn't write; file: ", spec.filename); }
4695
4696 return save_okay;
4697 }
4698
4699
4700
4701 //! save the matrix to a stream
4702 template<typename eT>
4703 inline
4704 arma_cold
4705 bool
save(std::ostream & os,const file_type type) const4706 SpMat<eT>::save(std::ostream& os, const file_type type) const
4707 {
4708 arma_extra_debug_sigprint();
4709
4710 sync_csc();
4711
4712 bool save_okay;
4713
4714 switch(type)
4715 {
4716 case csv_ascii:
4717 save_okay = diskio::save_csv_ascii(*this, os, char(','));
4718 break;
4719
4720 case ssv_ascii:
4721 save_okay = diskio::save_csv_ascii(*this, os, char(';'));
4722 break;
4723
4724 case arma_binary:
4725 save_okay = diskio::save_arma_binary(*this, os);
4726 break;
4727
4728 case coord_ascii:
4729 save_okay = diskio::save_coord_ascii(*this, os);
4730 break;
4731
4732 default:
4733 arma_debug_warn_level(1, "SpMat::save(): unsupported file type");
4734 save_okay = false;
4735 }
4736
4737 if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): couldn't write to stream"); }
4738
4739 return save_okay;
4740 }
4741
4742
4743
4744 //! load a matrix from a file
4745 template<typename eT>
4746 inline
4747 arma_cold
4748 bool
load(const std::string name,const file_type type)4749 SpMat<eT>::load(const std::string name, const file_type type)
4750 {
4751 arma_extra_debug_sigprint();
4752
4753 invalidate_cache();
4754
4755 bool load_okay;
4756 std::string err_msg;
4757
4758 switch(type)
4759 {
4760 // case auto_detect:
4761 // load_okay = diskio::load_auto_detect(*this, name, err_msg);
4762 // break;
4763
4764 case csv_ascii:
4765 return (*this).load(csv_name(name), type);
4766 break;
4767
4768 case ssv_ascii:
4769 return (*this).load(csv_name(name), type);
4770 break;
4771
4772 case arma_binary:
4773 load_okay = diskio::load_arma_binary(*this, name, err_msg);
4774 break;
4775
4776 case coord_ascii:
4777 load_okay = diskio::load_coord_ascii(*this, name, err_msg);
4778 break;
4779
4780 default:
4781 arma_debug_warn_level(1, "SpMat::load(): unsupported file type");
4782 load_okay = false;
4783 }
4784
4785 if(load_okay == false)
4786 {
4787 if(err_msg.length() > 0)
4788 {
4789 arma_debug_warn_level(3, "SpMat::load(): ", err_msg, "; file: ", name);
4790 }
4791 else
4792 {
4793 arma_debug_warn_level(3, "SpMat::load(): couldn't read; file: ", name);
4794 }
4795 }
4796
4797 if(load_okay == false) { (*this).reset(); }
4798
4799 return load_okay;
4800 }
4801
4802
4803
4804 template<typename eT>
4805 inline
4806 arma_cold
4807 bool
load(const csv_name & spec,const file_type type)4808 SpMat<eT>::load(const csv_name& spec, const file_type type)
4809 {
4810 arma_extra_debug_sigprint();
4811
4812 if( (type != csv_ascii) && (type != ssv_ascii) )
4813 {
4814 arma_stop_runtime_error("SpMat::load(): unsupported file type for csv_name()");
4815 return false;
4816 }
4817
4818 const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans );
4819 const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header );
4820 bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header);
4821 const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii);
4822
4823 arma_extra_debug_print("SpMat::load(csv_name): enabled flags:");
4824
4825 if(do_trans ) { arma_extra_debug_print("trans"); }
4826 if(no_header ) { arma_extra_debug_print("no_header"); }
4827 if(with_header ) { arma_extra_debug_print("with_header"); }
4828 if(use_semicolon) { arma_extra_debug_print("semicolon"); }
4829
4830 const char separator = (use_semicolon) ? char(';') : char(',');
4831
4832 if(no_header) { with_header = false; }
4833
4834 bool load_okay = false;
4835 std::string err_msg;
4836
4837 if(do_trans)
4838 {
4839 SpMat<eT> tmp_mat;
4840
4841 load_okay = diskio::load_csv_ascii(tmp_mat, spec.filename, err_msg, spec.header_rw, with_header, separator);
4842
4843 if(load_okay)
4844 {
4845 (*this) = tmp_mat.st();
4846
4847 if(with_header)
4848 {
4849 // field::set_size() preserves data if the number of elements hasn't changed
4850 spec.header_rw.set_size(spec.header_rw.n_elem, 1);
4851 }
4852 }
4853 }
4854 else
4855 {
4856 load_okay = diskio::load_csv_ascii(*this, spec.filename, err_msg, spec.header_rw, with_header, separator);
4857 }
4858
4859 if(load_okay == false)
4860 {
4861 if(err_msg.length() > 0)
4862 {
4863 arma_debug_warn_level(3, "SpMat::load(): ", err_msg, "; file: ", spec.filename);
4864 }
4865 else
4866 {
4867 arma_debug_warn_level(3, "SpMat::load(): couldn't read; file: ", spec.filename);
4868 }
4869 }
4870 else
4871 {
4872 const uword load_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols;
4873
4874 if(with_header && (spec.header_rw.n_elem != load_n_cols))
4875 {
4876 arma_debug_warn_level(3, "SpMat::load(): size mistmach between header and matrix");
4877 }
4878 }
4879
4880 if(load_okay == false)
4881 {
4882 (*this).reset();
4883
4884 if(with_header) { spec.header_rw.reset(); }
4885 }
4886
4887 return load_okay;
4888 }
4889
4890
4891
4892 //! load a matrix from a stream
4893 template<typename eT>
4894 inline
4895 arma_cold
4896 bool
load(std::istream & is,const file_type type)4897 SpMat<eT>::load(std::istream& is, const file_type type)
4898 {
4899 arma_extra_debug_sigprint();
4900
4901 invalidate_cache();
4902
4903 bool load_okay;
4904 std::string err_msg;
4905
4906 switch(type)
4907 {
4908 // case auto_detect:
4909 // load_okay = diskio::load_auto_detect(*this, is, err_msg);
4910 // break;
4911
4912 case csv_ascii:
4913 load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(','));
4914 break;
4915
4916 case ssv_ascii:
4917 load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(';'));
4918 break;
4919
4920 case arma_binary:
4921 load_okay = diskio::load_arma_binary(*this, is, err_msg);
4922 break;
4923
4924 case coord_ascii:
4925 load_okay = diskio::load_coord_ascii(*this, is, err_msg);
4926 break;
4927
4928 default:
4929 arma_debug_warn_level(1, "SpMat::load(): unsupported file type");
4930 load_okay = false;
4931 }
4932
4933 if(load_okay == false)
4934 {
4935 if(err_msg.length() > 0)
4936 {
4937 arma_debug_warn_level(3, "SpMat::load(): ", err_msg);
4938 }
4939 else
4940 {
4941 arma_debug_warn_level(3, "SpMat::load(): couldn't load from stream");
4942 }
4943 }
4944
4945 if(load_okay == false) { (*this).reset(); }
4946
4947 return load_okay;
4948 }
4949
4950
4951
4952 //! save the matrix to a file, without printing any error messages
4953 template<typename eT>
4954 inline
4955 arma_cold
4956 bool
quiet_save(const std::string name,const file_type type) const4957 SpMat<eT>::quiet_save(const std::string name, const file_type type) const
4958 {
4959 arma_extra_debug_sigprint();
4960
4961 return (*this).save(name, type);
4962 }
4963
4964
4965
4966 //! save the matrix to a stream, without printing any error messages
4967 template<typename eT>
4968 inline
4969 arma_cold
4970 bool
quiet_save(std::ostream & os,const file_type type) const4971 SpMat<eT>::quiet_save(std::ostream& os, const file_type type) const
4972 {
4973 arma_extra_debug_sigprint();
4974
4975 return (*this).save(os, type);
4976 }
4977
4978
4979
4980 //! load a matrix from a file, without printing any error messages
4981 template<typename eT>
4982 inline
4983 arma_cold
4984 bool
quiet_load(const std::string name,const file_type type)4985 SpMat<eT>::quiet_load(const std::string name, const file_type type)
4986 {
4987 arma_extra_debug_sigprint();
4988
4989 return (*this).load(name, type);
4990 }
4991
4992
4993
4994 //! load a matrix from a stream, without printing any error messages
4995 template<typename eT>
4996 inline
4997 arma_cold
4998 bool
quiet_load(std::istream & is,const file_type type)4999 SpMat<eT>::quiet_load(std::istream& is, const file_type type)
5000 {
5001 arma_extra_debug_sigprint();
5002
5003 return (*this).load(is, type);
5004 }
5005
5006
5007
5008 /**
5009 * Initialize the matrix to the specified size. Data is not preserved, so the matrix is assumed to be entirely sparse (empty).
5010 */
5011 template<typename eT>
5012 inline
5013 void
init(uword in_rows,uword in_cols,const uword new_n_nonzero)5014 SpMat<eT>::init(uword in_rows, uword in_cols, const uword new_n_nonzero)
5015 {
5016 arma_extra_debug_sigprint();
5017
5018 invalidate_cache(); // placed here, as init() is used during matrix modification
5019
5020 // Clean out the existing memory.
5021 if(values ) { memory::release(access::rw(values)); }
5022 if(row_indices) { memory::release(access::rw(row_indices)); }
5023 if(col_ptrs ) { memory::release(access::rw(col_ptrs)); }
5024
5025 // in case init_cold() throws an exception
5026 access::rw(n_rows) = 0;
5027 access::rw(n_cols) = 0;
5028 access::rw(n_elem) = 0;
5029 access::rw(n_nonzero) = 0;
5030 access::rw(values) = nullptr;
5031 access::rw(row_indices) = nullptr;
5032 access::rw(col_ptrs) = nullptr;
5033
5034 init_cold(in_rows, in_cols, new_n_nonzero);
5035 }
5036
5037
5038
5039 template<typename eT>
5040 inline
5041 void
5042 arma_cold
init_cold(uword in_rows,uword in_cols,const uword new_n_nonzero)5043 SpMat<eT>::init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero)
5044 {
5045 arma_extra_debug_sigprint();
5046
5047 // Verify that we are allowed to do this.
5048 if(vec_state > 0)
5049 {
5050 if((in_rows == 0) && (in_cols == 0))
5051 {
5052 if(vec_state == 1) { in_cols = 1; }
5053 if(vec_state == 2) { in_rows = 1; }
5054 }
5055 else
5056 {
5057 if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::init(): object is a column vector; requested size is not compatible" ); }
5058 if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::init(): object is a row vector; requested size is not compatible" ); }
5059 }
5060 }
5061
5062 #if defined(ARMA_64BIT_WORD)
5063 const char* error_message = "SpMat::init(): requested size is too large";
5064 #else
5065 const char* error_message = "SpMat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD";
5066 #endif
5067
5068 // Ensure that n_elem can hold the result of (n_rows * n_cols)
5069 arma_debug_check
5070 (
5071 (
5072 ( (in_rows > ARMA_MAX_UHWORD) || (in_cols > ARMA_MAX_UHWORD) )
5073 ? ( (double(in_rows) * double(in_cols)) > double(ARMA_MAX_UWORD) )
5074 : false
5075 ),
5076 error_message
5077 );
5078
5079 access::rw(col_ptrs) = memory::acquire<uword>(in_cols + 2);
5080 access::rw(values) = memory::acquire<eT> (new_n_nonzero + 1);
5081 access::rw(row_indices) = memory::acquire<uword>(new_n_nonzero + 1);
5082
5083 // fill column pointers with 0,
5084 // except for the last element which contains the maximum possible element
5085 // (so iterators terminate correctly).
5086 arrayops::fill_zeros(access::rwp(col_ptrs), in_cols + 1);
5087
5088 access::rw(col_ptrs[in_cols + 1]) = std::numeric_limits<uword>::max();
5089
5090 access::rw( values[new_n_nonzero]) = 0;
5091 access::rw(row_indices[new_n_nonzero]) = 0;
5092
5093 // Set the new size accordingly.
5094 access::rw(n_rows) = in_rows;
5095 access::rw(n_cols) = in_cols;
5096 access::rw(n_elem) = (in_rows * in_cols);
5097 access::rw(n_nonzero) = new_n_nonzero;
5098 }
5099
5100
5101
5102 template<typename eT>
5103 inline
5104 void
init(const std::string & text)5105 SpMat<eT>::init(const std::string& text)
5106 {
5107 arma_extra_debug_sigprint();
5108
5109 Mat<eT> tmp(text);
5110
5111 if(vec_state == 1)
5112 {
5113 if((tmp.n_elem > 0) && tmp.is_vec())
5114 {
5115 access::rw(tmp.n_rows) = tmp.n_elem;
5116 access::rw(tmp.n_cols) = 1;
5117 }
5118 }
5119
5120 if(vec_state == 2)
5121 {
5122 if((tmp.n_elem > 0) && tmp.is_vec())
5123 {
5124 access::rw(tmp.n_rows) = 1;
5125 access::rw(tmp.n_cols) = tmp.n_elem;
5126 }
5127 }
5128
5129 (*this).operator=(tmp);
5130 }
5131
5132
5133
5134 template<typename eT>
5135 inline
5136 void
init(const SpMat<eT> & x)5137 SpMat<eT>::init(const SpMat<eT>& x)
5138 {
5139 arma_extra_debug_sigprint();
5140
5141 if(this == &x) { return; }
5142
5143 bool init_done = false;
5144
5145 #if defined(ARMA_USE_OPENMP)
5146 if(x.sync_state == 1)
5147 {
5148 #pragma omp critical (arma_SpMat_init)
5149 if(x.sync_state == 1)
5150 {
5151 (*this).init(x.cache);
5152 init_done = true;
5153 }
5154 }
5155 #elif (!defined(ARMA_DONT_USE_STD_MUTEX))
5156 if(x.sync_state == 1)
5157 {
5158 x.cache_mutex.lock();
5159 if(x.sync_state == 1)
5160 {
5161 (*this).init(x.cache);
5162 init_done = true;
5163 }
5164 x.cache_mutex.unlock();
5165 }
5166 #else
5167 if(x.sync_state == 1)
5168 {
5169 (*this).init(x.cache);
5170 init_done = true;
5171 }
5172 #endif
5173
5174 if(init_done == false)
5175 {
5176 (*this).init_simple(x);
5177 }
5178 }
5179
5180
5181
5182 template<typename eT>
5183 inline
5184 void
init(const MapMat<eT> & x)5185 SpMat<eT>::init(const MapMat<eT>& x)
5186 {
5187 arma_extra_debug_sigprint();
5188
5189 const uword x_n_rows = x.n_rows;
5190 const uword x_n_cols = x.n_cols;
5191 const uword x_n_nz = x.get_n_nonzero();
5192
5193 init(x_n_rows, x_n_cols, x_n_nz);
5194
5195 if(x_n_nz == 0) { return; }
5196
5197 typename MapMat<eT>::map_type& x_map_ref = *(x.map_ptr);
5198
5199 typename MapMat<eT>::map_type::const_iterator x_it = x_map_ref.begin();
5200
5201 uword x_col = 0;
5202 uword x_col_index_start = 0;
5203 uword x_col_index_endp1 = x_n_rows;
5204
5205 for(uword i=0; i < x_n_nz; ++i)
5206 {
5207 const std::pair<uword, eT>& x_entry = (*x_it);
5208
5209 const uword x_index = x_entry.first;
5210 const eT x_val = x_entry.second;
5211
5212 // have we gone past the curent column?
5213 if(x_index >= x_col_index_endp1)
5214 {
5215 x_col = x_index / x_n_rows;
5216
5217 x_col_index_start = x_col * x_n_rows;
5218 x_col_index_endp1 = x_col_index_start + x_n_rows;
5219 }
5220
5221 const uword x_row = x_index - x_col_index_start;
5222
5223 // // sanity check
5224 //
5225 // const uword tmp_x_row = x_index % x_n_rows;
5226 // const uword tmp_x_col = x_index / x_n_rows;
5227 //
5228 // if(x_row != tmp_x_row) { cout << "x_row != tmp_x_row" << endl; exit(-1); }
5229 // if(x_col != tmp_x_col) { cout << "x_col != tmp_x_col" << endl; exit(-1); }
5230
5231 access::rw(values[i]) = x_val;
5232 access::rw(row_indices[i]) = x_row;
5233
5234 access::rw(col_ptrs[ x_col + 1 ])++;
5235
5236 ++x_it;
5237 }
5238
5239
5240 for(uword i = 0; i < x_n_cols; ++i)
5241 {
5242 access::rw(col_ptrs[i + 1]) += col_ptrs[i];
5243 }
5244
5245
5246 // // OLD METHOD
5247 //
5248 // for(uword i=0; i < x_n_nz; ++i)
5249 // {
5250 // const std::pair<uword, eT>& x_entry = (*x_it);
5251 //
5252 // const uword x_index = x_entry.first;
5253 // const eT x_val = x_entry.second;
5254 //
5255 // const uword x_row = x_index % x_n_rows;
5256 // const uword x_col = x_index / x_n_rows;
5257 //
5258 // access::rw(values[i]) = x_val;
5259 // access::rw(row_indices[i]) = x_row;
5260 //
5261 // access::rw(col_ptrs[ x_col + 1 ])++;
5262 //
5263 // ++x_it;
5264 // }
5265 //
5266 //
5267 // for(uword i = 0; i < x_n_cols; ++i)
5268 // {
5269 // access::rw(col_ptrs[i + 1]) += col_ptrs[i];
5270 // }
5271 }
5272
5273
5274
5275 template<typename eT>
5276 inline
5277 void
init_simple(const SpMat<eT> & x)5278 SpMat<eT>::init_simple(const SpMat<eT>& x)
5279 {
5280 arma_extra_debug_sigprint();
5281
5282 if(this == &x) { return; }
5283
5284 init(x.n_rows, x.n_cols, x.n_nonzero);
5285
5286 if(x.values ) { arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1); }
5287 if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); }
5288 if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); }
5289 }
5290
5291
5292
// Fill the CSC arrays from a batch of (row, col) locations and their values,
// using the matrix's current n_rows / n_cols.
//
// locs           : one location per column; locs.colptr(i)[0] is the row index and
//                  locs.colptr(i)[1] is the column index of location i
// vals           : vals[i] is the value stored at location i
// sort_locations : if true, locations may be given in any order; they are sorted into
//                  column-major order (the sort is skipped when they are already strictly
//                  increasing). If false, they must already be in column-major order.
//
// Out-of-range indices, duplicate locations and (when not sorting) out-of-order
// locations are reported via arma_debug_check.
// NOTE(review): vals.n_elem becomes n_nonzero and is assumed to equal locs.n_cols;
// this is not checked here -- presumably the caller verifies it.
// NOTE(review): zero values are stored as given; no zero removal happens here.
template<typename eT>
inline
void
SpMat<eT>::init_batch_std(const Mat<uword>& locs, const Mat<eT>& vals, const bool sort_locations)
  {
  arma_extra_debug_sigprint();
  
  // Resize to correct number of elements.
  mem_resize(vals.n_elem);
  
  // Reset column pointers to zero.
  arrayops::fill_zeros(access::rwp(col_ptrs), n_cols + 1);
  
  bool actually_sorted = true;
  
  if(sort_locations)
    {
    // check if we really need a time consuming sort
    // (only strictly increasing column-major order counts as sorted; an equal
    // neighbour also forces the sorted path, which then reports the duplicate)
    
    const uword locs_n_cols = locs.n_cols;
    
    for(uword i = 1; i < locs_n_cols; ++i)
      {
      const uword* locs_i   = locs.colptr(i  );
      const uword* locs_im1 = locs.colptr(i-1);
      
      const uword row_i = locs_i[0];
      const uword col_i = locs_i[1];
      
      const uword row_im1 = locs_im1[0];
      const uword col_im1 = locs_im1[1];
      
      if( (col_i < col_im1) || ((col_i == col_im1) && (row_i <= row_im1)) )
        {
        actually_sorted = false;
        break;
        }
      }
    
    if(actually_sorted == false)
      {
      // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet and arma_sort_index_helper_ascend
      
      std::vector< arma_sort_index_packet<uword> > packet_vec(locs_n_cols);
      
      const uword* locs_mem = locs.memptr();
      
      // Sort key is the column-major linear index (col * n_rows + row), paired with the
      // original position. locs memory is read as consecutive (row, col) pairs.
      // NOTE(review): this assumes locs has exactly 2 rows -- TODO confirm the caller guarantees it.
      // NOTE(review): keys are computed before the bounds check below; invalid indices are
      // still rejected by that check.
      for(uword i = 0; i < locs_n_cols; ++i)
        {
        const uword row = (*locs_mem); locs_mem++;
        const uword col = (*locs_mem); locs_mem++;
        
        packet_vec[i].val   = (col * n_rows) + row;
        packet_vec[i].index = i;
        }
      
      arma_sort_index_helper_ascend<uword> comparator;
      
      std::sort( packet_vec.begin(), packet_vec.end(), comparator );
      
      // insert the elements in the sorted order
      for(uword i = 0; i < locs_n_cols; ++i)
        {
        const uword index = packet_vec[i].index;
        
        const uword* locs_i = locs.colptr(index);
        
        const uword row_i = locs_i[0];
        const uword col_i = locs_i[1];
        
        arma_debug_check( ( (row_i >= n_rows) || (col_i >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
        
        if(i > 0)
          {
          // after sorting, duplicates are adjacent, so comparing with the previous entry suffices
          const uword prev_index = packet_vec[i-1].index;
          
          const uword* locs_im1 = locs.colptr(prev_index);
          
          const uword row_im1 = locs_im1[0];
          const uword col_im1 = locs_im1[1];
          
          arma_debug_check( ( (row_i == row_im1) && (col_i == col_im1) ), "SpMat::SpMat(): detected identical locations" );
          }
        
        access::rw(values[i])      = vals[index];
        access::rw(row_indices[i]) = row_i;
        
        // col_ptrs[c+1] temporarily holds the count of elements in column c
        access::rw(col_ptrs[ col_i + 1 ])++;
        }
      }
    }
  
  if( (sort_locations == false) || (actually_sorted == true) )
    {
    // Now set the values and row indices correctly.
    // Increment the column pointers in each column (so they are column "counts").
    
    const uword locs_n_cols = locs.n_cols;
    
    for(uword i=0; i < locs_n_cols; ++i)
      {
      const uword* locs_i = locs.colptr(i);
      
      const uword row_i = locs_i[0];
      const uword col_i = locs_i[1];
      
      arma_debug_check( ( (row_i >= n_rows) || (col_i >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
      
      if(i > 0)
        {
        const uword* locs_im1 = locs.colptr(i-1);
        
        const uword row_im1 = locs_im1[0];
        const uword col_im1 = locs_im1[1];
        
        arma_debug_check
          (
          ( (col_i < col_im1) || ((col_i == col_im1) && (row_i < row_im1)) ),
          "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering"
          );
        
        arma_debug_check( ( (col_i == col_im1) && (row_i == row_im1) ), "SpMat::SpMat(): detected identical locations" );
        }
      
      access::rw(values[i])      = vals[i];
      access::rw(row_indices[i]) = row_i;
      
      access::rw(col_ptrs[ col_i + 1 ])++;
      }
    }
  
  // Now fix the column pointers.
  // (prefix sum: turns the per-column counts into cumulative column offsets)
  for(uword i = 0; i < n_cols; ++i)
    {
    access::rw(col_ptrs[i + 1]) += col_ptrs[i];
    }
  }
5430
5431
5432
// Fill the CSC arrays from a batch of (row, col) locations and their values, using the
// matrix's current n_rows / n_cols. Unlike init_batch_std(), values at repeated
// locations are summed into a single element.
//
// locs           : one location per column; locs.colptr(i)[0] is the row index and
//                  locs.colptr(i)[1] is the column index of location i
// vals           : vals[i] is the value for location i
// sort_locations : if true, locations may be given in any order and are sorted into
//                  column-major order (skipped when already strictly increasing); if false,
//                  they must already be in column-major order (repeats must be adjacent)
//
// With fewer than 2 locations no duplicates are possible, so init_batch_std() is used.
// NOTE(review): vals.n_elem is assumed to equal locs.n_cols; not checked here.
// NOTE(review): summed values that cancel to zero are kept; no zero removal happens here.
template<typename eT>
inline
void
SpMat<eT>::init_batch_add(const Mat<uword>& locs, const Mat<eT>& vals, const bool sort_locations)
  {
  arma_extra_debug_sigprint();
  
  if(locs.n_cols < 2)
    {
    init_batch_std(locs, vals, false);
    return;
    }
  
  // Reset column pointers to zero.
  arrayops::fill_zeros(access::rwp(col_ptrs), n_cols + 1);
  
  bool actually_sorted = true;
  
  if(sort_locations)
    {
    // Skip the sort when the locations are already in strictly increasing column-major
    // order, as the sort is avoidable O(N log N) work. (An equal neighbour also counts as
    // unsorted, which routes repeats through the sorted path.)
    for(uword i = 1; i < locs.n_cols; ++i)
      {
      const uword* locs_i   = locs.colptr(i  );
      const uword* locs_im1 = locs.colptr(i-1);
      
      if( (locs_i[1] < locs_im1[1]) || (locs_i[1] == locs_im1[1]  &&  locs_i[0] <= locs_im1[0]) )
        {
        actually_sorted = false;
        break;
        }
      }
    
    if(actually_sorted == false)
      {
      // This may not be the fastest possible implementation but it maximizes code reuse.
      // Sort key: column-major linear index (col * n_rows + row).
      Col<uword> abslocs(locs.n_cols, arma_nozeros_indicator());
      
      for(uword i = 0; i < locs.n_cols; ++i)
        {
        const uword* locs_i = locs.colptr(i);
        
        abslocs[i] = locs_i[1] * n_rows + locs_i[0];
        }
      
      uvec sorted_indices = sort_index(abslocs); // Ascending sort.
      
      // work out the number of unique elements
      // (after sorting, repeated locations are adjacent)
      uword n_unique = 1;  // first element is unique
      
      for(uword i=1; i < sorted_indices.n_elem; ++i)
        {
        const uword* locs_i   = locs.colptr( sorted_indices[i  ] );
        const uword* locs_im1 = locs.colptr( sorted_indices[i-1] );
        
        if( (locs_i[1] != locs_im1[1]) || (locs_i[0] != locs_im1[0]) )  { ++n_unique; }
        }
      
      // resize to correct number of elements
      mem_resize(n_unique);
      
      // Now we add the elements in this sorted order.
      // count = position of the most recently written unique element
      uword count = 0;
      
      // first element
        {
        const uword  i      = 0;
        const uword* locs_i = locs.colptr( sorted_indices[i] );
        
        arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
        
        access::rw(values[count])      = vals[ sorted_indices[i] ];
        access::rw(row_indices[count]) = locs_i[0];
        
        access::rw(col_ptrs[ locs_i[1] + 1 ])++;
        }
      
      for(uword i=1; i < sorted_indices.n_elem; ++i)
        {
        const uword* locs_i   = locs.colptr( sorted_indices[i  ] );
        const uword* locs_im1 = locs.colptr( sorted_indices[i-1] );
        
        arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
        
        if( (locs_i[1] == locs_im1[1]) && (locs_i[0] == locs_im1[0]) )
          {
          // repeated location: accumulate into the current element
          access::rw(values[count]) += vals[ sorted_indices[i] ];
          }
        else
          {
          // new location: start the next element and bump its column count
          count++;
          access::rw(values[count])      = vals[ sorted_indices[i] ];
          access::rw(row_indices[count]) = locs_i[0];
          
          access::rw(col_ptrs[ locs_i[1] + 1 ])++;
          }
        }
      }
    }
  
  if( (sort_locations == false) || (actually_sorted == true) )
    {
    // work out the number of unique elements
    // (input is required to be in column-major order, so repeats are adjacent)
    uword n_unique = 1;  // first element is unique
    
    for(uword i=1; i < locs.n_cols; ++i)
      {
      const uword* locs_i   = locs.colptr(i  );
      const uword* locs_im1 = locs.colptr(i-1);
      
      if( (locs_i[1] != locs_im1[1]) || (locs_i[0] != locs_im1[0]) )  { ++n_unique; }
      }
    
    // resize to correct number of elements
    mem_resize(n_unique);
    
    // Now set the values and row indices correctly.
    // Increment the column pointers in each column (so they are column "counts").
    
    uword count = 0;
    
    // first element
      {
      const uword  i      = 0;
      const uword* locs_i = locs.colptr(i);
      
      arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
      
      access::rw(values[count])      = vals[i];
      access::rw(row_indices[count]) = locs_i[0];
      
      access::rw(col_ptrs[ locs_i[1] + 1 ])++;
      }
    
    for(uword i=1; i < locs.n_cols; ++i)
      {
      const uword* locs_i   = locs.colptr(i  );
      const uword* locs_im1 = locs.colptr(i-1);
      
      arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" );
      
      arma_debug_check
        (
        ( (locs_i[1] < locs_im1[1]) || (locs_i[1] == locs_im1[1]  &&  locs_i[0] < locs_im1[0]) ),
        "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering"
        );
      
      if( (locs_i[1] == locs_im1[1]) && (locs_i[0] == locs_im1[0]) )
        {
        // repeated location: accumulate
        access::rw(values[count]) += vals[i];
        }
      else
        {
        count++;
        
        access::rw(values[count])      = vals[i];
        access::rw(row_indices[count]) = locs_i[0];
        
        access::rw(col_ptrs[ locs_i[1] + 1 ])++;
        }
      }
    }
  
  // Now fix the column pointers.
  // (prefix sum: turns the per-column counts into cumulative column offsets)
  for(uword i = 0; i < n_cols; ++i)
    {
    access::rw(col_ptrs[i + 1]) += col_ptrs[i];
    }
  }
5603
5604
5605
5606 //! constructor used by SpRow and SpCol classes
5607 template<typename eT>
5608 inline
SpMat(const arma_vec_indicator &,const uword in_vec_state)5609 SpMat<eT>::SpMat(const arma_vec_indicator&, const uword in_vec_state)
5610 : n_rows(0)
5611 , n_cols(0)
5612 , n_elem(0)
5613 , n_nonzero(0)
5614 , vec_state(in_vec_state)
5615 , values(nullptr)
5616 , row_indices(nullptr)
5617 , col_ptrs(nullptr)
5618 {
5619 arma_extra_debug_sigprint_this(this);
5620
5621 const uword in_n_rows = (in_vec_state == 2) ? 1 : 0;
5622 const uword in_n_cols = (in_vec_state == 1) ? 1 : 0;
5623
5624 init_cold(in_n_rows, in_n_cols);
5625 }
5626
5627
5628
5629 //! constructor used by SpRow and SpCol classes
5630 template<typename eT>
5631 inline
SpMat(const arma_vec_indicator &,const uword in_n_rows,const uword in_n_cols,const uword in_vec_state)5632 SpMat<eT>::SpMat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_cols, const uword in_vec_state)
5633 : n_rows(0)
5634 , n_cols(0)
5635 , n_elem(0)
5636 , n_nonzero(0)
5637 , vec_state(in_vec_state)
5638 , values(nullptr)
5639 , row_indices(nullptr)
5640 , col_ptrs(nullptr)
5641 {
5642 arma_extra_debug_sigprint_this(this);
5643
5644 init_cold(in_n_rows, in_n_cols);
5645 }
5646
5647
5648
5649 template<typename eT>
5650 inline
5651 void
mem_resize(const uword new_n_nonzero)5652 SpMat<eT>::mem_resize(const uword new_n_nonzero)
5653 {
5654 arma_extra_debug_sigprint();
5655
5656 invalidate_cache(); // placed here, as mem_resize() is used during matrix modification
5657
5658 if(n_nonzero == new_n_nonzero) { return; }
5659
5660 eT* new_values = memory::acquire<eT> (new_n_nonzero + 1);
5661 uword* new_row_indices = memory::acquire<uword>(new_n_nonzero + 1);
5662
5663 if( (n_nonzero > 0 ) && (new_n_nonzero > 0) )
5664 {
5665 // Copy old elements.
5666 uword copy_len = (std::min)(n_nonzero, new_n_nonzero);
5667
5668 arrayops::copy(new_values, values, copy_len);
5669 arrayops::copy(new_row_indices, row_indices, copy_len);
5670 }
5671
5672 if(values) { memory::release(access::rw(values)); }
5673 if(row_indices) { memory::release(access::rw(row_indices)); }
5674
5675 access::rw(values) = new_values;
5676 access::rw(row_indices) = new_row_indices;
5677
5678 // Set the "fake end" of the matrix by setting the last value and row index to 0.
5679 // This helps the iterators work correctly.
5680 access::rw( values[new_n_nonzero]) = 0;
5681 access::rw(row_indices[new_n_nonzero]) = 0;
5682
5683 access::rw(n_nonzero) = new_n_nonzero;
5684 }
5685
5686
5687
5688 template<typename eT>
5689 inline
5690 void
sync() const5691 SpMat<eT>::sync() const
5692 {
5693 arma_extra_debug_sigprint();
5694
5695 sync_csc();
5696 }
5697
5698
5699
5700 template<typename eT>
5701 inline
5702 void
remove_zeros()5703 SpMat<eT>::remove_zeros()
5704 {
5705 arma_extra_debug_sigprint();
5706
5707 sync_csc();
5708
5709 invalidate_cache(); // placed here, as remove_zeros() is used during matrix modification
5710
5711 const uword old_n_nonzero = n_nonzero;
5712 uword new_n_nonzero = 0;
5713
5714 const eT* old_values = values;
5715
5716 for(uword i=0; i < old_n_nonzero; ++i)
5717 {
5718 new_n_nonzero += (old_values[i] != eT(0)) ? uword(1) : uword(0);
5719 }
5720
5721 if(new_n_nonzero != old_n_nonzero)
5722 {
5723 if(new_n_nonzero == 0) { init(n_rows, n_cols); return; }
5724
5725 SpMat<eT> tmp(arma_reserve_indicator(), n_rows, n_cols, new_n_nonzero);
5726
5727 uword new_index = 0;
5728
5729 const_iterator it = begin();
5730 const_iterator it_end = end();
5731
5732 for(; it != it_end; ++it)
5733 {
5734 const eT val = eT(*it);
5735
5736 if(val != eT(0))
5737 {
5738 access::rw(tmp.values[new_index]) = val;
5739 access::rw(tmp.row_indices[new_index]) = it.row();
5740 access::rw(tmp.col_ptrs[it.col() + 1])++;
5741 ++new_index;
5742 }
5743 }
5744
5745 for(uword i=0; i < n_cols; ++i)
5746 {
5747 access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i];
5748 }
5749
5750 steal_mem(tmp);
5751 }
5752 }
5753
5754
5755
5756 // Steal memory from another matrix.
5757 template<typename eT>
5758 inline
5759 void
steal_mem(SpMat<eT> & x)5760 SpMat<eT>::steal_mem(SpMat<eT>& x)
5761 {
5762 arma_extra_debug_sigprint();
5763
5764 if(this == &x) { return; }
5765
5766 bool layout_ok = false;
5767
5768 if((*this).vec_state == x.vec_state)
5769 {
5770 layout_ok = true;
5771 }
5772 else
5773 {
5774 if( ((*this).vec_state == 1) && (x.n_cols == 1) ) { layout_ok = true; }
5775 if( ((*this).vec_state == 2) && (x.n_rows == 1) ) { layout_ok = true; }
5776 }
5777
5778 if(layout_ok)
5779 {
5780 x.sync_csc();
5781
5782 steal_mem_simple(x);
5783
5784 x.invalidate_cache();
5785
5786 invalidate_cache();
5787 }
5788 else
5789 {
5790 (*this).operator=(x);
5791 }
5792 }
5793
5794
5795
5796 template<typename eT>
5797 inline
5798 void
steal_mem_simple(SpMat<eT> & x)5799 SpMat<eT>::steal_mem_simple(SpMat<eT>& x)
5800 {
5801 arma_extra_debug_sigprint();
5802
5803 if(this == &x) { return; }
5804
5805 if(values ) { memory::release(access::rw(values)); }
5806 if(row_indices) { memory::release(access::rw(row_indices)); }
5807 if(col_ptrs ) { memory::release(access::rw(col_ptrs)); }
5808
5809 access::rw(n_rows) = x.n_rows;
5810 access::rw(n_cols) = x.n_cols;
5811 access::rw(n_elem) = x.n_elem;
5812 access::rw(n_nonzero) = x.n_nonzero;
5813
5814 access::rw(values) = x.values;
5815 access::rw(row_indices) = x.row_indices;
5816 access::rw(col_ptrs) = x.col_ptrs;
5817
5818 // Set other matrix to empty.
5819 access::rw(x.n_rows) = 0;
5820 access::rw(x.n_cols) = 0;
5821 access::rw(x.n_elem) = 0;
5822 access::rw(x.n_nonzero) = 0;
5823
5824 access::rw(x.values) = nullptr;
5825 access::rw(x.row_indices) = nullptr;
5826 access::rw(x.col_ptrs) = nullptr;
5827 }
5828
5829
5830
5831 template<typename eT>
5832 template<typename T1, typename Functor>
5833 inline
5834 void
init_xform(const SpBase<eT,T1> & A,const Functor & func)5835 SpMat<eT>::init_xform(const SpBase<eT,T1>& A, const Functor& func)
5836 {
5837 arma_extra_debug_sigprint();
5838
5839 // if possible, avoid doing a copy and instead apply func to the generated elements
5840 if(SpProxy<T1>::Q_is_generated)
5841 {
5842 (*this) = A.get_ref();
5843
5844 const uword nnz = n_nonzero;
5845
5846 eT* t_values = access::rwp(values);
5847
5848 bool has_zero = false;
5849
5850 for(uword i=0; i < nnz; ++i)
5851 {
5852 eT& t_values_i = t_values[i];
5853
5854 t_values_i = func(t_values_i);
5855
5856 if(t_values_i == eT(0)) { has_zero = true; }
5857 }
5858
5859 if(has_zero) { remove_zeros(); }
5860 }
5861 else
5862 {
5863 init_xform_mt(A.get_ref(), func);
5864 }
5865 }
5866
5867
5868
// Initialise *this from expression A (element type eT2), applying func to every stored
// value. func must convert an eT2 value into an eT value. Values that func maps to
// zero are removed afterwards via remove_zeros().
template<typename eT>
template<typename eT2, typename T1, typename Functor>
inline
void
SpMat<eT>::init_xform_mt(const SpBase<eT2,T1>& A, const Functor& func)
  {
  arma_extra_debug_sigprint();
  
  const SpProxy<T1> P(A.get_ref());
  
  // Path 1: A aliases *this, or the proxy already stores an SpMat -- work directly from
  // the unwrapped matrix's CSC arrays.
  if( P.is_alias(*this) || (is_SpMat<typename SpProxy<T1>::stored_type>::value) )
    {
    // NOTE: unwrap_spmat will convert a submatrix to a matrix, which in effect takes care of aliasing with submatrices;
    // NOTE: however, when more delayed ops are implemented, more elaborate handling of aliasing will be necessary
    const unwrap_spmat<typename SpProxy<T1>::stored_type> tmp(P.Q);
    
    const SpMat<eT2>& x = tmp.M;
    
    // copy the sparsity structure, unless x is *this itself (then it is already in place)
    if(void_ptr(this) != void_ptr(&x))
      {
      init(x.n_rows, x.n_cols, x.n_nonzero);
      
      arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1);
      arrayops::copy(access::rwp(col_ptrs),    x.col_ptrs,    x.n_cols    + 1);
      }
    
    
    // initialise the elements array with a transformed version of the elements from x
    // (when x is *this, this transforms the values in place)
    
    const uword nnz = n_nonzero;
    
    const eT2* x_values = x.values;
          eT*  t_values = access::rwp(values);
    
    bool has_zero = false;
    
    for(uword i=0; i < nnz; ++i)
      {
      eT& t_values_i = t_values[i];
      
      t_values_i = func(x_values[i]);  // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT)
      
      if(t_values_i == eT(0))  { has_zero = true; }
      }
    
    if(has_zero)  { remove_zeros(); }
    }
  else
    {
    // Path 2: generic expression -- walk the proxy's non-zero elements. Each element is
    // written at its iterator position it.pos(); col_ptrs first accumulates per-column
    // counts, which are then summed into cumulative offsets.
    init(P.get_n_rows(), P.get_n_cols(), P.get_n_nonzero());
    
    typename SpProxy<T1>::const_iterator_type it     = P.begin();
    typename SpProxy<T1>::const_iterator_type it_end = P.end();
    
    bool has_zero = false;
    
    while(it != it_end)
      {
      const eT val = func(*it);  // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT)
      
      if(val == eT(0))  { has_zero = true; }
      
      const uword it_pos = it.pos();
      
      access::rw(row_indices[it_pos]) = it.row();
      access::rw(values[it_pos]) = val;
      ++access::rw(col_ptrs[it.col() + 1]);
      ++it;
      }
    
    // Now sum column pointers.
    for(uword c = 1; c <= n_cols; ++c)
      {
      access::rw(col_ptrs[c]) += col_ptrs[c - 1];
      }
    
    if(has_zero)  { remove_zeros(); }
    }
  }
5948
5949
5950
5951 template<typename eT>
5952 arma_inline
5953 bool
is_alias(const SpMat<eT> & X) const5954 SpMat<eT>::is_alias(const SpMat<eT>& X) const
5955 {
5956 return (&X == this);
5957 }
5958
5959
5960
5961 template<typename eT>
5962 inline
5963 typename SpMat<eT>::iterator
begin()5964 SpMat<eT>::begin()
5965 {
5966 arma_extra_debug_sigprint();
5967
5968 sync_csc();
5969
5970 return iterator(*this);
5971 }
5972
5973
5974
5975 template<typename eT>
5976 inline
5977 typename SpMat<eT>::const_iterator
begin() const5978 SpMat<eT>::begin() const
5979 {
5980 arma_extra_debug_sigprint();
5981
5982 sync_csc();
5983
5984 return const_iterator(*this);
5985 }
5986
5987
5988
5989 template<typename eT>
5990 inline
5991 typename SpMat<eT>::const_iterator
cbegin() const5992 SpMat<eT>::cbegin() const
5993 {
5994 arma_extra_debug_sigprint();
5995
5996 sync_csc();
5997
5998 return const_iterator(*this);
5999 }
6000
6001
6002
6003 template<typename eT>
6004 inline
6005 typename SpMat<eT>::iterator
end()6006 SpMat<eT>::end()
6007 {
6008 sync_csc();
6009
6010 return iterator(*this, 0, n_cols, n_nonzero);
6011 }
6012
6013
6014
6015 template<typename eT>
6016 inline
6017 typename SpMat<eT>::const_iterator
end() const6018 SpMat<eT>::end() const
6019 {
6020 sync_csc();
6021
6022 return const_iterator(*this, 0, n_cols, n_nonzero);
6023 }
6024
6025
6026
6027 template<typename eT>
6028 inline
6029 typename SpMat<eT>::const_iterator
cend() const6030 SpMat<eT>::cend() const
6031 {
6032 sync_csc();
6033
6034 return const_iterator(*this, 0, n_cols, n_nonzero);
6035 }
6036
6037
6038
6039 template<typename eT>
6040 inline
6041 typename SpMat<eT>::col_iterator
begin_col(const uword col_num)6042 SpMat<eT>::begin_col(const uword col_num)
6043 {
6044 sync_csc();
6045
6046 return col_iterator(*this, 0, col_num);
6047 }
6048
6049
6050
6051 template<typename eT>
6052 inline
6053 typename SpMat<eT>::const_col_iterator
begin_col(const uword col_num) const6054 SpMat<eT>::begin_col(const uword col_num) const
6055 {
6056 sync_csc();
6057
6058 return const_col_iterator(*this, 0, col_num);
6059 }
6060
6061
6062
6063 template<typename eT>
6064 inline
6065 typename SpMat<eT>::col_iterator
begin_col_no_sync(const uword col_num)6066 SpMat<eT>::begin_col_no_sync(const uword col_num)
6067 {
6068 return col_iterator(*this, 0, col_num);
6069 }
6070
6071
6072
6073 template<typename eT>
6074 inline
6075 typename SpMat<eT>::const_col_iterator
begin_col_no_sync(const uword col_num) const6076 SpMat<eT>::begin_col_no_sync(const uword col_num) const
6077 {
6078 return const_col_iterator(*this, 0, col_num);
6079 }
6080
6081
6082
6083 template<typename eT>
6084 inline
6085 typename SpMat<eT>::col_iterator
end_col(const uword col_num)6086 SpMat<eT>::end_col(const uword col_num)
6087 {
6088 sync_csc();
6089
6090 return col_iterator(*this, 0, col_num + 1);
6091 }
6092
6093
6094
6095 template<typename eT>
6096 inline
6097 typename SpMat<eT>::const_col_iterator
end_col(const uword col_num) const6098 SpMat<eT>::end_col(const uword col_num) const
6099 {
6100 sync_csc();
6101
6102 return const_col_iterator(*this, 0, col_num + 1);
6103 }
6104
6105
6106
6107 template<typename eT>
6108 inline
6109 typename SpMat<eT>::col_iterator
end_col_no_sync(const uword col_num)6110 SpMat<eT>::end_col_no_sync(const uword col_num)
6111 {
6112 return col_iterator(*this, 0, col_num + 1);
6113 }
6114
6115
6116
6117 template<typename eT>
6118 inline
6119 typename SpMat<eT>::const_col_iterator
end_col_no_sync(const uword col_num) const6120 SpMat<eT>::end_col_no_sync(const uword col_num) const
6121 {
6122 return const_col_iterator(*this, 0, col_num + 1);
6123 }
6124
6125
6126
6127 template<typename eT>
6128 inline
6129 typename SpMat<eT>::row_iterator
begin_row(const uword row_num)6130 SpMat<eT>::begin_row(const uword row_num)
6131 {
6132 sync_csc();
6133
6134 return row_iterator(*this, row_num, 0);
6135 }
6136
6137
6138
6139 template<typename eT>
6140 inline
6141 typename SpMat<eT>::const_row_iterator
begin_row(const uword row_num) const6142 SpMat<eT>::begin_row(const uword row_num) const
6143 {
6144 sync_csc();
6145
6146 return const_row_iterator(*this, row_num, 0);
6147 }
6148
6149
6150
6151 template<typename eT>
6152 inline
6153 typename SpMat<eT>::row_iterator
end_row()6154 SpMat<eT>::end_row()
6155 {
6156 sync_csc();
6157
6158 return row_iterator(*this, n_nonzero);
6159 }
6160
6161
6162
6163 template<typename eT>
6164 inline
6165 typename SpMat<eT>::const_row_iterator
end_row() const6166 SpMat<eT>::end_row() const
6167 {
6168 sync_csc();
6169
6170 return const_row_iterator(*this, n_nonzero);
6171 }
6172
6173
6174
6175 template<typename eT>
6176 inline
6177 typename SpMat<eT>::row_iterator
end_row(const uword row_num)6178 SpMat<eT>::end_row(const uword row_num)
6179 {
6180 sync_csc();
6181
6182 return row_iterator(*this, row_num + 1, 0);
6183 }
6184
6185
6186
6187 template<typename eT>
6188 inline
6189 typename SpMat<eT>::const_row_iterator
end_row(const uword row_num) const6190 SpMat<eT>::end_row(const uword row_num) const
6191 {
6192 sync_csc();
6193
6194 return const_row_iterator(*this, row_num + 1, 0);
6195 }
6196
6197
6198
6199 template<typename eT>
6200 inline
6201 typename SpMat<eT>::row_col_iterator
begin_row_col()6202 SpMat<eT>::begin_row_col()
6203 {
6204 sync_csc();
6205
6206 return begin();
6207 }
6208
6209
6210
6211 template<typename eT>
6212 inline
6213 typename SpMat<eT>::const_row_col_iterator
begin_row_col() const6214 SpMat<eT>::begin_row_col() const
6215 {
6216 sync_csc();
6217
6218 return begin();
6219 }
6220
6221
6222
6223 template<typename eT>
6224 inline typename SpMat<eT>::row_col_iterator
end_row_col()6225 SpMat<eT>::end_row_col()
6226 {
6227 sync_csc();
6228
6229 return end();
6230 }
6231
6232
6233
6234 template<typename eT>
6235 inline
6236 typename SpMat<eT>::const_row_col_iterator
end_row_col() const6237 SpMat<eT>::end_row_col() const
6238 {
6239 sync_csc();
6240
6241 return end();
6242 }
6243
6244
6245
6246 template<typename eT>
6247 inline
6248 void
clear()6249 SpMat<eT>::clear()
6250 {
6251 (*this).reset();
6252 }
6253
6254
6255
6256 template<typename eT>
6257 inline
6258 bool
empty() const6259 SpMat<eT>::empty() const
6260 {
6261 return (n_elem == 0);
6262 }
6263
6264
6265
6266 template<typename eT>
6267 inline
6268 uword
size() const6269 SpMat<eT>::size() const
6270 {
6271 return n_elem;
6272 }
6273
6274
6275
6276 template<typename eT>
6277 arma_inline
6278 arma_warn_unused
6279 SpMat_MapMat_val<eT>
front()6280 SpMat<eT>::front()
6281 {
6282 arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" );
6283
6284 return SpMat_MapMat_val<eT>((*this), cache, 0, 0);
6285 }
6286
6287
6288
6289 template<typename eT>
6290 arma_inline
6291 arma_warn_unused
6292 eT
front() const6293 SpMat<eT>::front() const
6294 {
6295 arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" );
6296
6297 return get_value(0,0);
6298 }
6299
6300
6301
6302 template<typename eT>
6303 arma_inline
6304 arma_warn_unused
6305 SpMat_MapMat_val<eT>
back()6306 SpMat<eT>::back()
6307 {
6308 arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" );
6309
6310 return SpMat_MapMat_val<eT>((*this), cache, n_rows-1, n_cols-1);
6311 }
6312
6313
6314
6315 template<typename eT>
6316 arma_inline
6317 arma_warn_unused
6318 eT
back() const6319 SpMat<eT>::back() const
6320 {
6321 arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" );
6322
6323 return get_value(n_rows-1, n_cols-1);
6324 }
6325
6326
6327
6328 template<typename eT>
6329 inline
6330 arma_hot
6331 arma_warn_unused
6332 eT
get_value(const uword i) const6333 SpMat<eT>::get_value(const uword i) const
6334 {
6335 const MapMat<eT>& const_cache = cache; // declare as const for clarity of intent
6336
6337 // get the element from the cache if it has more recent data than CSC
6338
6339 return (sync_state == 1) ? const_cache.operator[](i) : get_value_csc(i);
6340 }
6341
6342
6343
6344 template<typename eT>
6345 inline
6346 arma_hot
6347 arma_warn_unused
6348 eT
get_value(const uword in_row,const uword in_col) const6349 SpMat<eT>::get_value(const uword in_row, const uword in_col) const
6350 {
6351 const MapMat<eT>& const_cache = cache; // declare as const for clarity of intent
6352
6353 // get the element from the cache if it has more recent data than CSC
6354
6355 return (sync_state == 1) ? const_cache.at(in_row, in_col) : get_value_csc(in_row, in_col);
6356 }
6357
6358
6359
6360 template<typename eT>
6361 inline
6362 arma_hot
6363 arma_warn_unused
6364 eT
get_value_csc(const uword i) const6365 SpMat<eT>::get_value_csc(const uword i) const
6366 {
6367 // First convert to the actual location.
6368 uword lcol = i / n_rows; // Integer division.
6369 uword lrow = i % n_rows;
6370
6371 return get_value_csc(lrow, lcol);
6372 }
6373
6374
6375
6376 template<typename eT>
6377 inline
6378 arma_hot
6379 arma_warn_unused
6380 const eT*
find_value_csc(const uword in_row,const uword in_col) const6381 SpMat<eT>::find_value_csc(const uword in_row, const uword in_col) const
6382 {
6383 const uword col_offset = col_ptrs[in_col ];
6384 const uword next_col_offset = col_ptrs[in_col + 1];
6385
6386 const uword* start_ptr = &row_indices[ col_offset];
6387 const uword* end_ptr = &row_indices[next_col_offset];
6388
6389 const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, in_row); // binary search
6390
6391 if( (pos_ptr != end_ptr) && ((*pos_ptr) == in_row) )
6392 {
6393 const uword offset = uword(pos_ptr - start_ptr);
6394 const uword index = offset + col_offset;
6395
6396 return &(values[index]);
6397 }
6398
6399 return nullptr;
6400 }
6401
6402
6403
6404 template<typename eT>
6405 inline
6406 arma_hot
6407 arma_warn_unused
6408 eT
get_value_csc(const uword in_row,const uword in_col) const6409 SpMat<eT>::get_value_csc(const uword in_row, const uword in_col) const
6410 {
6411 const eT* val_ptr = find_value_csc(in_row, in_col);
6412
6413 return (val_ptr != nullptr) ? eT(*val_ptr) : eT(0);
6414 }
6415
6416
6417
6418 template<typename eT>
6419 inline
6420 arma_hot
6421 arma_warn_unused
6422 bool
try_set_value_csc(const uword in_row,const uword in_col,const eT in_val)6423 SpMat<eT>::try_set_value_csc(const uword in_row, const uword in_col, const eT in_val)
6424 {
6425 const eT* val_ptr = find_value_csc(in_row, in_col);
6426
6427 // element not found, ie. it's zero; fail if trying to set it to non-zero value
6428 if(val_ptr == nullptr) { return (in_val == eT(0)); }
6429
6430 // fail if trying to erase an existing element
6431 if(in_val == eT(0)) { return false; }
6432
6433 access::rw(*val_ptr) = in_val;
6434
6435 invalidate_cache();
6436
6437 return true;
6438 }
6439
6440
6441
6442 template<typename eT>
6443 inline
6444 arma_hot
6445 arma_warn_unused
6446 bool
try_add_value_csc(const uword in_row,const uword in_col,const eT in_val)6447 SpMat<eT>::try_add_value_csc(const uword in_row, const uword in_col, const eT in_val)
6448 {
6449 const eT* val_ptr = find_value_csc(in_row, in_col);
6450
6451 // element not found, ie. it's zero; fail if trying to add a non-zero value
6452 if(val_ptr == nullptr) { return (in_val == eT(0)); }
6453
6454 const eT new_val = eT(*val_ptr) + in_val;
6455
6456 // fail if trying to erase an existing element
6457 if(new_val == eT(0)) { return false; }
6458
6459 access::rw(*val_ptr) = new_val;
6460
6461 invalidate_cache();
6462
6463 return true;
6464 }
6465
6466
6467
6468 template<typename eT>
6469 inline
6470 arma_hot
6471 arma_warn_unused
6472 bool
try_sub_value_csc(const uword in_row,const uword in_col,const eT in_val)6473 SpMat<eT>::try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val)
6474 {
6475 const eT* val_ptr = find_value_csc(in_row, in_col);
6476
6477 // element not found, ie. it's zero; fail if trying to subtract a non-zero value
6478 if(val_ptr == nullptr) { return (in_val == eT(0)); }
6479
6480 const eT new_val = eT(*val_ptr) - in_val;
6481
6482 // fail if trying to erase an existing element
6483 if(new_val == eT(0)) { return false; }
6484
6485 access::rw(*val_ptr) = new_val;
6486
6487 invalidate_cache();
6488
6489 return true;
6490 }
6491
6492
6493
6494 template<typename eT>
6495 inline
6496 arma_hot
6497 arma_warn_unused
6498 bool
try_mul_value_csc(const uword in_row,const uword in_col,const eT in_val)6499 SpMat<eT>::try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val)
6500 {
6501 const eT* val_ptr = find_value_csc(in_row, in_col);
6502
6503 // element not found, ie. it's zero; succeed if given value is finite; zero multiplied by anything is zero, except for nan and inf
6504 if(val_ptr == nullptr) { return arma_isfinite(in_val); }
6505
6506 const eT new_val = eT(*val_ptr) * in_val;
6507
6508 // fail if trying to erase an existing element
6509 if(new_val == eT(0)) { return false; }
6510
6511 access::rw(*val_ptr) = new_val;
6512
6513 invalidate_cache();
6514
6515 return true;
6516 }
6517
6518
6519
6520 template<typename eT>
6521 inline
6522 arma_hot
6523 arma_warn_unused
6524 bool
try_div_value_csc(const uword in_row,const uword in_col,const eT in_val)6525 SpMat<eT>::try_div_value_csc(const uword in_row, const uword in_col, const eT in_val)
6526 {
6527 const eT* val_ptr = find_value_csc(in_row, in_col);
6528
6529 // element not found, ie. it's zero; succeed if given value is not zero and not nan; zero divided by anything is zero, except for zero and nan
6530 if(val_ptr == nullptr) { return ((in_val != eT(0)) && (arma_isnan(in_val) == false)); }
6531
6532 const eT new_val = eT(*val_ptr) / in_val;
6533
6534 // fail if trying to erase an existing element
6535 if(new_val == eT(0)) { return false; }
6536
6537 access::rw(*val_ptr) = new_val;
6538
6539 invalidate_cache();
6540
6541 return true;
6542 }
6543
6544
6545
6546 /**
6547 * Insert an element at the given position, and return a reference to it.
6548 * The element will be set to 0, unless otherwise specified.
6549 * If the element already exists, its value will be overwritten.
6550 */
6551 template<typename eT>
6552 inline
6553 arma_warn_unused
6554 eT&
insert_element(const uword in_row,const uword in_col,const eT val)6555 SpMat<eT>::insert_element(const uword in_row, const uword in_col, const eT val)
6556 {
6557 arma_extra_debug_sigprint();
6558
6559 sync_csc();
6560 invalidate_cache();
6561
6562 // We will assume the new element does not exist and begin the search for
6563 // where to insert it. If we find that it already exists, we will then
6564 // overwrite it.
6565 uword colptr = col_ptrs[in_col ];
6566 uword next_colptr = col_ptrs[in_col + 1];
6567
6568 uword pos = colptr; // The position in the matrix of this value.
6569
6570 if(colptr != next_colptr)
6571 {
6572 // There are other elements in this column, so we must find where this
6573 // element will fit as compared to those.
6574 while(pos < next_colptr && in_row > row_indices[pos])
6575 {
6576 pos++;
6577 }
6578
6579 // We aren't inserting into the last position, so it is still possible
6580 // that the element may exist.
6581 if(pos != next_colptr && row_indices[pos] == in_row)
6582 {
6583 // It already exists. Then, just overwrite it.
6584 access::rw(values[pos]) = val;
6585
6586 return access::rw(values[pos]);
6587 }
6588 }
6589
6590
6591 //
6592 // Element doesn't exist, so we have to insert it
6593 //
6594
6595 // We have to update the rest of the column pointers.
6596 for(uword i = in_col + 1; i < n_cols + 1; i++)
6597 {
6598 access::rw(col_ptrs[i])++; // We are only inserting one new element.
6599 }
6600
6601 const uword old_n_nonzero = n_nonzero;
6602
6603 access::rw(n_nonzero)++; // Add to count of nonzero elements.
6604
6605 // Allocate larger memory.
6606 eT* new_values = memory::acquire<eT> (n_nonzero + 1);
6607 uword* new_row_indices = memory::acquire<uword>(n_nonzero + 1);
6608
6609 // Copy things over, before the new element.
6610 if(pos > 0)
6611 {
6612 arrayops::copy(new_values, values, pos);
6613 arrayops::copy(new_row_indices, row_indices, pos);
6614 }
6615
6616 // Insert the new element.
6617 new_values[pos] = val;
6618 new_row_indices[pos] = in_row;
6619
6620 // Copy the rest of things over (including the extra element at the end).
6621 arrayops::copy(new_values + pos + 1, values + pos, (old_n_nonzero - pos) + 1);
6622 arrayops::copy(new_row_indices + pos + 1, row_indices + pos, (old_n_nonzero - pos) + 1);
6623
6624 // Assign new pointers.
6625 if(values) { memory::release(access::rw(values)); }
6626 if(row_indices) { memory::release(access::rw(row_indices)); }
6627
6628 access::rw(values) = new_values;
6629 access::rw(row_indices) = new_row_indices;
6630
6631 return access::rw(values[pos]);
6632 }
6633
6634
6635
6636 /**
6637 * Delete an element at the given position.
6638 */
6639 template<typename eT>
6640 inline
6641 void
delete_element(const uword in_row,const uword in_col)6642 SpMat<eT>::delete_element(const uword in_row, const uword in_col)
6643 {
6644 arma_extra_debug_sigprint();
6645
6646 sync_csc();
6647 invalidate_cache();
6648
6649 // We assume the element exists (although... it may not) and look for its
6650 // exact position. If it doesn't exist... well, we don't need to do anything.
6651 uword colptr = col_ptrs[in_col];
6652 uword next_colptr = col_ptrs[in_col + 1];
6653
6654 if(colptr != next_colptr)
6655 {
6656 // There's at least one element in this column.
6657 // Let's see if we are one of them.
6658 for(uword pos = colptr; pos < next_colptr; pos++)
6659 {
6660 if(in_row == row_indices[pos])
6661 {
6662 --access::rw(n_nonzero); // Remove one from the count of nonzero elements.
6663
6664 // Found it. Now remove it.
6665
6666 // Make new arrays.
6667 eT* new_values = memory::acquire<eT> (n_nonzero + 1);
6668 uword* new_row_indices = memory::acquire<uword>(n_nonzero + 1);
6669
6670 if(pos > 0)
6671 {
6672 arrayops::copy(new_values, values, pos);
6673 arrayops::copy(new_row_indices, row_indices, pos);
6674 }
6675
6676 arrayops::copy(new_values + pos, values + pos + 1, (n_nonzero - pos) + 1);
6677 arrayops::copy(new_row_indices + pos, row_indices + pos + 1, (n_nonzero - pos) + 1);
6678
6679 if(values) { memory::release(access::rw(values)); }
6680 if(row_indices) { memory::release(access::rw(row_indices)); }
6681
6682 access::rw(values) = new_values;
6683 access::rw(row_indices) = new_row_indices;
6684
6685 // And lastly, update all the column pointers (decrement by one).
6686 for(uword i = in_col + 1; i < n_cols + 1; i++)
6687 {
6688 --access::rw(col_ptrs[i]); // We only removed one element.
6689 }
6690
6691 return; // There is nothing left to do.
6692 }
6693 }
6694 }
6695
6696 return; // The element does not exist, so there's nothing for us to do.
6697 }
6698
6699
6700
6701 template<typename eT>
6702 arma_inline
6703 void
invalidate_cache() const6704 SpMat<eT>::invalidate_cache() const
6705 {
6706 arma_extra_debug_sigprint();
6707
6708 if(sync_state == 0) { return; }
6709
6710 cache.reset();
6711
6712 sync_state = 0;
6713 }
6714
6715
6716
6717 template<typename eT>
6718 arma_inline
6719 void
invalidate_csc() const6720 SpMat<eT>::invalidate_csc() const
6721 {
6722 arma_extra_debug_sigprint();
6723
6724 sync_state = 1;
6725 }
6726
6727
6728
6729 template<typename eT>
6730 inline
6731 void
sync_cache() const6732 SpMat<eT>::sync_cache() const
6733 {
6734 arma_extra_debug_sigprint();
6735
6736 // using approach adapted from http://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/
6737 //
6738 // OpenMP mode:
6739 // sync_state uses atomic read/write, which has an implied flush;
6740 // flush is also implicitly executed at the entrance and the exit of critical section;
6741 // data races are prevented by the 'critical' directive
6742 //
6743 // C++11 mode:
6744 // underlying type for sync_state is std::atomic<int>;
6745 // reading and writing to sync_state uses std::memory_order_seq_cst which has an implied fence;
6746 // data races are prevented via the mutex
6747
6748 #if defined(ARMA_USE_OPENMP)
6749 {
6750 if(sync_state == 0)
6751 {
6752 #pragma omp critical (arma_SpMat_cache)
6753 {
6754 sync_cache_simple();
6755 }
6756 }
6757 }
6758 #elif (!defined(ARMA_DONT_USE_STD_MUTEX))
6759 {
6760 if(sync_state == 0)
6761 {
6762 cache_mutex.lock();
6763
6764 sync_cache_simple();
6765
6766 cache_mutex.unlock();
6767 }
6768 }
6769 #else
6770 {
6771 sync_cache_simple();
6772 }
6773 #endif
6774 }
6775
6776
6777
6778
6779 template<typename eT>
6780 inline
6781 void
sync_cache_simple() const6782 SpMat<eT>::sync_cache_simple() const
6783 {
6784 arma_extra_debug_sigprint();
6785
6786 if(sync_state == 0)
6787 {
6788 cache = (*this);
6789 sync_state = 2;
6790 }
6791 }
6792
6793
6794
6795
6796 template<typename eT>
6797 inline
6798 void
sync_csc() const6799 SpMat<eT>::sync_csc() const
6800 {
6801 arma_extra_debug_sigprint();
6802
6803 #if defined(ARMA_USE_OPENMP)
6804 if(sync_state == 1)
6805 {
6806 #pragma omp critical (arma_SpMat_cache)
6807 {
6808 sync_csc_simple();
6809 }
6810 }
6811 #elif (!defined(ARMA_DONT_USE_STD_MUTEX))
6812 if(sync_state == 1)
6813 {
6814 cache_mutex.lock();
6815
6816 sync_csc_simple();
6817
6818 cache_mutex.unlock();
6819 }
6820 #else
6821 {
6822 sync_csc_simple();
6823 }
6824 #endif
6825 }
6826
6827
6828
6829 template<typename eT>
6830 inline
6831 void
sync_csc_simple() const6832 SpMat<eT>::sync_csc_simple() const
6833 {
6834 arma_extra_debug_sigprint();
6835
6836 // method:
6837 // 1. construct temporary matrix to prevent the cache from getting zapped
6838 // 2. steal memory from the temporary matrix
6839
6840 // sync_state is only set to 1 by non-const element access operators,
6841 // so the shenanigans with const_cast are to satisfy the compiler
6842
6843 // see also the note in sync_cache() above
6844
6845 if(sync_state == 1)
6846 {
6847 SpMat<eT>& x = const_cast< SpMat<eT>& >(*this);
6848
6849 SpMat<eT> tmp(cache);
6850
6851 x.steal_mem_simple(tmp);
6852
6853 sync_state = 2;
6854 }
6855 }
6856
6857
6858
6859
6860 //
6861 // SpMat_aux
6862
6863
6864
6865 template<typename eT, typename T1>
6866 inline
6867 void
set_real(SpMat<eT> & out,const SpBase<eT,T1> & X)6868 SpMat_aux::set_real(SpMat<eT>& out, const SpBase<eT,T1>& X)
6869 {
6870 arma_extra_debug_sigprint();
6871
6872 const unwrap_spmat<T1> tmp(X.get_ref());
6873 const SpMat<eT>& A = tmp.M;
6874
6875 arma_debug_assert_same_size( out, A, "SpMat::set_real()" );
6876
6877 out = A;
6878 }
6879
6880
6881
6882 template<typename eT, typename T1>
6883 inline
6884 void
set_imag(SpMat<eT> &,const SpBase<eT,T1> &)6885 SpMat_aux::set_imag(SpMat<eT>&, const SpBase<eT,T1>&)
6886 {
6887 arma_extra_debug_sigprint();
6888 }
6889
6890
6891
6892 template<typename T, typename T1>
6893 inline
6894 void
set_real(SpMat<std::complex<T>> & out,const SpBase<T,T1> & X)6895 SpMat_aux::set_real(SpMat< std::complex<T> >& out, const SpBase<T,T1>& X)
6896 {
6897 arma_extra_debug_sigprint();
6898
6899 typedef typename std::complex<T> eT;
6900
6901 const unwrap_spmat<T1> U(X.get_ref());
6902 const SpMat<T>& Y = U.M;
6903
6904 arma_debug_assert_same_size(out, Y, "SpMat::set_real()");
6905
6906 SpMat<eT> tmp(Y,arma::imag(out)); // arma:: prefix required due to bugs in GCC 4.4 - 4.6
6907
6908 out.steal_mem(tmp);
6909 }
6910
6911
6912
6913 template<typename T, typename T1>
6914 inline
6915 void
set_imag(SpMat<std::complex<T>> & out,const SpBase<T,T1> & X)6916 SpMat_aux::set_imag(SpMat< std::complex<T> >& out, const SpBase<T,T1>& X)
6917 {
6918 arma_extra_debug_sigprint();
6919
6920 typedef typename std::complex<T> eT;
6921
6922 const unwrap_spmat<T1> U(X.get_ref());
6923 const SpMat<T>& Y = U.M;
6924
6925 arma_debug_assert_same_size(out, Y, "SpMat::set_imag()");
6926
6927 SpMat<eT> tmp(arma::real(out),Y); // arma:: prefix required due to bugs in GCC 4.4 - 4.6
6928
6929 out.steal_mem(tmp);
6930 }
6931
6932
6933
6934 #ifdef ARMA_EXTRA_SPMAT_MEAT
6935 #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_MEAT)
6936 #endif
6937
6938
6939
6940 //! @}
6941