1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2008-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 // This file should not include config.h. It is only included in other
27 // C++ source files that should have included config.h before including
28 // this file.
29
30 #include <istream>
31 #include <ostream>
32 #include <sstream>
33
34 #include "mach-info.h"
35 #include "lo-ieee.h"
36
37 #include "ov-base-diag.h"
38 #include "mxarray.h"
39 #include "ov-base.h"
40 #include "ov-base-mat.h"
41 #include "pr-output.h"
42 #include "error.h"
43 #include "errwarn.h"
44 #include "oct-stream.h"
45 #include "ops.h"
46
47 #include "ls-oct-text.h"
48
49 template <typename DMT, typename MT>
50 octave_value
subsref(const std::string & type,const std::list<octave_value_list> & idx)51 octave_base_diag<DMT, MT>::subsref (const std::string& type,
52 const std::list<octave_value_list>& idx)
53 {
54 octave_value retval;
55
56 switch (type[0])
57 {
58 case '(':
59 retval = do_index_op (idx.front ());
60 break;
61
62 case '{':
63 case '.':
64 {
65 std::string nm = type_name ();
66 error ("%s cannot be indexed with %c", nm.c_str (), type[0]);
67 }
68 break;
69
70 default:
71 panic_impossible ();
72 }
73
74 return retval.next_subsref (type, idx);
75 }
76
77 template <typename DMT, typename MT>
78 octave_value
diag(octave_idx_type k) const79 octave_base_diag<DMT,MT>::diag (octave_idx_type k) const
80 {
81 octave_value retval;
82 if (matrix.rows () == 1 || matrix.cols () == 1)
83 {
84 // Rather odd special case. This is a row or column vector
85 // represented as a diagonal matrix with a single nonzero entry, but
86 // Fdiag semantics are to product a diagonal matrix for vector
87 // inputs.
88 if (k == 0)
89 // Returns Diag2Array<T> with nnz <= 1.
90 retval = matrix.build_diag_matrix ();
91 else
92 // Returns Array<T> matrix
93 retval = matrix.array_value ().diag (k);
94 }
95 else
96 // Returns Array<T> vector
97 retval = matrix.extract_diag (k);
98 return retval;
99 }
100
101 template <typename DMT, typename MT>
102 octave_value
do_index_op(const octave_value_list & idx,bool resize_ok)103 octave_base_diag<DMT, MT>::do_index_op (const octave_value_list& idx,
104 bool resize_ok)
105 {
106 octave_value retval;
107
108 if (idx.length () == 2 && ! resize_ok)
109 {
110 int k = 0; // index we're accessing when index_vector throws
111 try
112 {
113 idx_vector idx0 = idx(0).index_vector ();
114 k = 1;
115 idx_vector idx1 = idx(1).index_vector ();
116
117 if (idx0.is_scalar () && idx1.is_scalar ())
118 {
119 retval = matrix.checkelem (idx0(0), idx1(0));
120 }
121 else
122 {
123 octave_idx_type m = idx0.length (matrix.rows ());
124 octave_idx_type n = idx1.length (matrix.columns ());
125 if (idx0.is_colon_equiv (m) && idx1.is_colon_equiv (n)
126 && m <= matrix.rows () && n <= matrix.rows ())
127 {
128 DMT rm (matrix);
129 rm.resize (m, n);
130 retval = rm;
131 }
132 else
133 retval = to_dense ().do_index_op (idx, resize_ok);
134 }
135 }
136 catch (octave::index_exception& e)
137 {
138 // Rethrow to allow more info to be reported later.
139 e.set_pos_if_unset (2, k+1);
140 throw;
141 }
142 }
143 else
144 retval = to_dense ().do_index_op (idx, resize_ok);
145
146 return retval;
147 }
148
149 template <typename DMT, typename MT>
150 octave_value
subsasgn(const std::string & type,const std::list<octave_value_list> & idx,const octave_value & rhs)151 octave_base_diag<DMT, MT>::subsasgn (const std::string& type,
152 const std::list<octave_value_list>& idx,
153 const octave_value& rhs)
154 {
155 octave_value retval;
156
157 switch (type[0])
158 {
159 case '(':
160 {
161 if (type.length () != 1)
162 {
163 std::string nm = type_name ();
164 error ("in indexed assignment of %s, last lhs index must be ()",
165 nm.c_str ());
166 }
167
168 octave_value_list jdx = idx.front ();
169
170 // FIXME: Mostly repeated code for cases 1 and 2 could be
171 // consolidated for DRY (Don't Repeat Yourself).
172 // Check for assignments to diagonal elements which should not
173 // destroy the diagonal property of the matrix.
174 // If D is a diagonal matrix then the assignment can be
175 // 1) linear, D(i) = x, where ind2sub results in case #2 below
176 // 2) subscript D(i,i) = x, where both indices are equal.
177 if (jdx.length () == 1 && jdx(0).is_scalar_type ())
178 {
179 typename DMT::element_type val;
180 int k = 0;
181 try
182 {
183 idx_vector ind = jdx(0).index_vector ();
184 k = 1;
185 dim_vector dv (matrix.rows (), matrix.cols ());
186 Array<idx_vector> ivec = ind2sub (dv, ind);
187 idx_vector i0 = ivec(0);
188 idx_vector i1 = ivec(1);
189
190 if (i0(0) == i1(0)
191 && chk_valid_scalar (rhs, val))
192 {
193 matrix.dgelem (i0(0)) = val;
194 retval = this;
195 this->count++;
196 // invalidate cache
197 dense_cache = octave_value ();
198 }
199 }
200 catch (octave::index_exception& e)
201 {
202 // Rethrow to allow more info to be reported later.
203 e.set_pos_if_unset (2, k+1);
204 throw;
205 }
206 }
207 else if (jdx.length () == 2
208 && jdx(0).is_scalar_type () && jdx(1).is_scalar_type ())
209 {
210 typename DMT::element_type val;
211 int k = 0;
212 try
213 {
214 idx_vector i0 = jdx(0).index_vector ();
215 k = 1;
216 idx_vector i1 = jdx(1).index_vector ();
217 if (i0(0) == i1(0)
218 && i0(0) < matrix.rows () && i1(0) < matrix.cols ()
219 && chk_valid_scalar (rhs, val))
220 {
221 matrix.dgelem (i0(0)) = val;
222 retval = this;
223 this->count++;
224 // invalidate cache
225 dense_cache = octave_value ();
226 }
227 }
228 catch (octave::index_exception& e)
229 {
230 // Rethrow to allow more info to be reported later.
231 e.set_pos_if_unset (2, k+1);
232 throw;
233 }
234 }
235
236 if (! retval.is_defined ())
237 retval = numeric_assign (type, idx, rhs);
238 }
239 break;
240
241 case '{':
242 case '.':
243 {
244 if (! isempty ())
245 {
246 std::string nm = type_name ();
247 error ("%s cannot be indexed with %c", nm.c_str (), type[0]);
248 }
249
250 octave_value tmp = octave_value::empty_conv (type, rhs);
251
252 retval = tmp.subsasgn (type, idx, rhs);
253 }
254 break;
255
256 default:
257 panic_impossible ();
258 }
259
260 return retval;
261 }
262
263 template <typename DMT, typename MT>
264 octave_value
resize(const dim_vector & dv,bool fill) const265 octave_base_diag<DMT, MT>::resize (const dim_vector& dv, bool fill) const
266 {
267 octave_value retval;
268 if (dv.ndims () == 2)
269 {
270 DMT rm (matrix);
271 rm.resize (dv(0), dv(1));
272 retval = rm;
273 }
274 else
275 retval = to_dense ().resize (dv, fill);
276 return retval;
277 }
278
279 // Return true if this matrix has all true elements (non-zero, not NA/NaN).
280 template <typename DMT, typename MT>
281 bool
is_true(void) const282 octave_base_diag<DMT, MT>::is_true (void) const
283 {
284 if (dims ().numel () > 1)
285 {
286 warn_array_as_logical (dims ());
287 // Throw error if any NaN or NA by calling is_true().
288 octave_value (matrix.extract_diag ()).is_true ();
289 return false; // > 1x1 diagonal always has zeros
290 }
291 else
292 return to_dense ().is_true (); // 0x0 or 1x1, handle NaN etc.
293 }
294
// Return the real part of X; for a non-complex X this is X itself.
// FIXME: This should be achievable using ::real
template <typename T> inline T helper_getreal (T x) { return x; }
template <typename T> inline T helper_getreal (std::complex<T> x)
{ return x.real (); }

// Return true if the argument's type is a std::complex specialization.
// Only the argument's type matters; its value is ignored.
// FIXME: We really need some traits so that ad hoc hooks like this
// are not necessary.
template <typename T> inline bool helper_iscomplex (T) { return false; }
template <typename T> inline bool helper_iscomplex (std::complex<T>) { return true; }
303
304 template <typename DMT, typename MT>
305 double
double_value(bool force_conversion) const306 octave_base_diag<DMT, MT>::double_value (bool force_conversion) const
307 {
308 typedef typename DMT::element_type el_type;
309
310 if (helper_iscomplex (el_type ()) && ! force_conversion)
311 warn_implicit_conversion ("Octave:imag-to-real",
312 "complex matrix", "real scalar");
313
314 if (isempty ())
315 err_invalid_conversion (type_name (), "real scalar");
316
317 warn_implicit_conversion ("Octave:array-to-scalar",
318 type_name (), "real scalar");
319
320 return helper_getreal (el_type (matrix (0, 0)));
321 }
322
323 template <typename DMT, typename MT>
324 float
float_value(bool force_conversion) const325 octave_base_diag<DMT, MT>::float_value (bool force_conversion) const
326 {
327 typedef typename DMT::element_type el_type;
328
329 if (helper_iscomplex (el_type ()) && ! force_conversion)
330 warn_implicit_conversion ("Octave:imag-to-real",
331 "complex matrix", "real scalar");
332
333 if (! (numel () > 0))
334 err_invalid_conversion (type_name (), "real scalar");
335
336 warn_implicit_conversion ("Octave:array-to-scalar",
337 type_name (), "real scalar");
338
339 return helper_getreal (el_type (matrix (0, 0)));
340 }
341
342 template <typename DMT, typename MT>
343 Complex
complex_value(bool) const344 octave_base_diag<DMT, MT>::complex_value (bool) const
345 {
346 if (rows () == 0 || columns () == 0)
347 err_invalid_conversion (type_name (), "complex scalar");
348
349 warn_implicit_conversion ("Octave:array-to-scalar",
350 type_name (), "complex scalar");
351
352 return matrix(0, 0);
353 }
354
355 template <typename DMT, typename MT>
356 FloatComplex
float_complex_value(bool) const357 octave_base_diag<DMT, MT>::float_complex_value (bool) const
358 {
359 float tmp = lo_ieee_float_nan_value ();
360
361 FloatComplex retval (tmp, tmp);
362
363 if (rows () == 0 || columns () == 0)
364 err_invalid_conversion (type_name (), "complex scalar");
365
366 warn_implicit_conversion ("Octave:array-to-scalar",
367 type_name (), "complex scalar");
368
369 retval = matrix (0, 0);
370
371 return retval;
372 }
373
374 template <typename DMT, typename MT>
375 Matrix
matrix_value(bool) const376 octave_base_diag<DMT, MT>::matrix_value (bool) const
377 {
378 return Matrix (diag_matrix_value ());
379 }
380
381 template <typename DMT, typename MT>
382 FloatMatrix
float_matrix_value(bool) const383 octave_base_diag<DMT, MT>::float_matrix_value (bool) const
384 {
385 return FloatMatrix (float_diag_matrix_value ());
386 }
387
388 template <typename DMT, typename MT>
389 ComplexMatrix
complex_matrix_value(bool) const390 octave_base_diag<DMT, MT>::complex_matrix_value (bool) const
391 {
392 return ComplexMatrix (complex_diag_matrix_value ());
393 }
394
395 template <typename DMT, typename MT>
396 FloatComplexMatrix
float_complex_matrix_value(bool) const397 octave_base_diag<DMT, MT>::float_complex_matrix_value (bool) const
398 {
399 return FloatComplexMatrix (float_complex_diag_matrix_value ());
400 }
401
402 template <typename DMT, typename MT>
403 NDArray
array_value(bool) const404 octave_base_diag<DMT, MT>::array_value (bool) const
405 {
406 return NDArray (matrix_value ());
407 }
408
409 template <typename DMT, typename MT>
410 FloatNDArray
float_array_value(bool) const411 octave_base_diag<DMT, MT>::float_array_value (bool) const
412 {
413 return FloatNDArray (float_matrix_value ());
414 }
415
416 template <typename DMT, typename MT>
417 ComplexNDArray
complex_array_value(bool) const418 octave_base_diag<DMT, MT>::complex_array_value (bool) const
419 {
420 return ComplexNDArray (complex_matrix_value ());
421 }
422
423 template <typename DMT, typename MT>
424 FloatComplexNDArray
float_complex_array_value(bool) const425 octave_base_diag<DMT, MT>::float_complex_array_value (bool) const
426 {
427 return FloatComplexNDArray (float_complex_matrix_value ());
428 }
429
430 template <typename DMT, typename MT>
431 boolNDArray
bool_array_value(bool warn) const432 octave_base_diag<DMT, MT>::bool_array_value (bool warn) const
433 {
434 return to_dense ().bool_array_value (warn);
435 }
436
437 template <typename DMT, typename MT>
438 charNDArray
char_array_value(bool warn) const439 octave_base_diag<DMT, MT>::char_array_value (bool warn) const
440 {
441 return to_dense ().char_array_value (warn);
442 }
443
444 template <typename DMT, typename MT>
445 SparseMatrix
sparse_matrix_value(bool) const446 octave_base_diag<DMT, MT>::sparse_matrix_value (bool) const
447 {
448 return SparseMatrix (diag_matrix_value ());
449 }
450
451 template <typename DMT, typename MT>
452 SparseComplexMatrix
sparse_complex_matrix_value(bool) const453 octave_base_diag<DMT, MT>::sparse_complex_matrix_value (bool) const
454 {
455 return SparseComplexMatrix (complex_diag_matrix_value ());
456 }
457
458 template <typename DMT, typename MT>
459 idx_vector
index_vector(bool require_integers) const460 octave_base_diag<DMT, MT>::index_vector (bool require_integers) const
461 {
462 return to_dense ().index_vector (require_integers);
463 }
464
465 template <typename DMT, typename MT>
466 octave_value
convert_to_str_internal(bool pad,bool force,char type) const467 octave_base_diag<DMT, MT>::convert_to_str_internal (bool pad, bool force,
468 char type) const
469 {
470 return to_dense ().convert_to_str_internal (pad, force, type);
471 }
472
473 template <typename DMT, typename MT>
474 float_display_format
get_edit_display_format(void) const475 octave_base_diag<DMT, MT>::get_edit_display_format (void) const
476 {
477 // FIXME
478 return float_display_format ();
479 }
480
481 template <typename DMT, typename MT>
482 std::string
edit_display(const float_display_format & fmt,octave_idx_type i,octave_idx_type j) const483 octave_base_diag<DMT, MT>::edit_display (const float_display_format& fmt,
484 octave_idx_type i,
485 octave_idx_type j) const
486 {
487 std::ostringstream buf;
488 octave_print_internal (buf, fmt, matrix(i,j));
489 return buf.str ();
490 }
491
492 template <typename DMT, typename MT>
493 bool
save_ascii(std::ostream & os)494 octave_base_diag<DMT, MT>::save_ascii (std::ostream& os)
495 {
496 os << "# rows: " << matrix.rows () << "\n"
497 << "# columns: " << matrix.columns () << "\n";
498
499 os << matrix.extract_diag ();
500
501 return true;
502 }
503
504 template <typename DMT, typename MT>
505 bool
load_ascii(std::istream & is)506 octave_base_diag<DMT, MT>::load_ascii (std::istream& is)
507 {
508 octave_idx_type r = 0;
509 octave_idx_type c = 0;
510
511 if (! extract_keyword (is, "rows", r, true)
512 || ! extract_keyword (is, "columns", c, true))
513 error ("load: failed to extract number of rows and columns");
514
515 octave_idx_type l = (r < c ? r : c);
516 MT tmp (l, 1);
517 is >> tmp;
518
519 if (! is)
520 error ("load: failed to load diagonal matrix constant");
521
522 // This is a little tricky, as we have the Matrix type, but
523 // not ColumnVector type. We need to help the compiler get
524 // through the inheritance tree.
525 typedef typename DMT::element_type el_type;
526 matrix = DMT (MDiagArray2<el_type> (MArray<el_type> (tmp)));
527 matrix.resize (r, c);
528
529 // Invalidate cache. Probably not necessary, but safe.
530 dense_cache = octave_value ();
531
532 return true;
533 }
534
535 template <typename DMT, typename MT>
536 void
print_raw(std::ostream & os,bool pr_as_read_syntax) const537 octave_base_diag<DMT, MT>::print_raw (std::ostream& os,
538 bool pr_as_read_syntax) const
539 {
540 return octave_print_internal (os, matrix, pr_as_read_syntax,
541 current_print_indent_level ());
542 }
543
544 template <typename DMT, typename MT>
545 mxArray *
as_mxArray(void) const546 octave_base_diag<DMT, MT>::as_mxArray (void) const
547 {
548 return to_dense ().as_mxArray ();
549 }
550
551 template <typename DMT, typename MT>
552 bool
print_as_scalar(void) const553 octave_base_diag<DMT, MT>::print_as_scalar (void) const
554 {
555 dim_vector dv = dims ();
556
557 return (dv.all_ones () || dv.any_zero ());
558 }
559
560 template <typename DMT, typename MT>
561 void
print(std::ostream & os,bool pr_as_read_syntax)562 octave_base_diag<DMT, MT>::print (std::ostream& os, bool pr_as_read_syntax)
563 {
564 print_raw (os, pr_as_read_syntax);
565 newline (os);
566 }
567 template <typename DMT, typename MT>
568 int
write(octave::stream & os,int block_size,oct_data_conv::data_type output_type,int skip,octave::mach_info::float_format flt_fmt) const569 octave_base_diag<DMT, MT>::write (octave::stream& os, int block_size,
570 oct_data_conv::data_type output_type,
571 int skip,
572 octave::mach_info::float_format flt_fmt) const
573 {
574 return to_dense ().write (os, block_size, output_type, skip, flt_fmt);
575 }
576
577 template <typename DMT, typename MT>
578 void
print_info(std::ostream & os,const std::string & prefix) const579 octave_base_diag<DMT, MT>::print_info (std::ostream& os,
580 const std::string& prefix) const
581 {
582 matrix.print_info (os, prefix);
583 }
584
585 // FIXME: this function is duplicated in octave_base_matrix<T>. Could
586 // it somehow be shared instead?
587
588 template <typename DMT, typename MT>
589 void
short_disp(std::ostream & os) const590 octave_base_diag<DMT, MT>::short_disp (std::ostream& os) const
591 {
592 if (matrix.isempty ())
593 os << "[]";
594 else if (matrix.ndims () == 2)
595 {
596 // FIXME: should this be configurable?
597 octave_idx_type max_elts = 10;
598 octave_idx_type elts = 0;
599
600 octave_idx_type nel = matrix.numel ();
601
602 octave_idx_type nr = matrix.rows ();
603 octave_idx_type nc = matrix.columns ();
604
605 os << '[';
606
607 for (octave_idx_type i = 0; i < nr; i++)
608 {
609 for (octave_idx_type j = 0; j < nc; j++)
610 {
611 std::ostringstream buf;
612 octave_print_internal (buf, matrix(i,j));
613 std::string tmp = buf.str ();
614 std::size_t pos = tmp.find_first_not_of (' ');
615 if (pos != std::string::npos)
616 os << tmp.substr (pos);
617 else if (! tmp.empty ())
618 os << tmp[0];
619
620 if (++elts >= max_elts)
621 goto done;
622
623 if (j < nc - 1)
624 os << ", ";
625 }
626
627 if (i < nr - 1 && elts < max_elts)
628 os << "; ";
629 }
630
631 done:
632
633 if (nel <= max_elts)
634 os << ']';
635 }
636 else
637 os << "...";
638 }
639
640 template <typename DMT, typename MT>
641 octave_value
fast_elem_extract(octave_idx_type n) const642 octave_base_diag<DMT, MT>::fast_elem_extract (octave_idx_type n) const
643 {
644 if (n < matrix.numel ())
645 {
646 octave_idx_type nr = matrix.rows ();
647
648 octave_idx_type r = n % nr;
649 octave_idx_type c = n / nr;
650
651 return octave_value (matrix.elem (r, c));
652 }
653 else
654 return octave_value ();
655 }
656
657 template <typename DMT, typename MT>
658 octave_value
to_dense(void) const659 octave_base_diag<DMT, MT>::to_dense (void) const
660 {
661 if (! dense_cache.is_defined ())
662 dense_cache = MT (matrix);
663
664 return dense_cache;
665 }
666