1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2008-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 // This file should not include config.h.  It is only included in other
27 // C++ source files that should have included config.h before including
28 // this file.
29 
30 #include <istream>
31 #include <ostream>
32 #include <sstream>
33 
34 #include "mach-info.h"
35 #include "lo-ieee.h"
36 
37 #include "ov-base-diag.h"
38 #include "mxarray.h"
39 #include "ov-base.h"
40 #include "ov-base-mat.h"
41 #include "pr-output.h"
42 #include "error.h"
43 #include "errwarn.h"
44 #include "oct-stream.h"
45 #include "ops.h"
46 
47 #include "ls-oct-text.h"
48 
49 template <typename DMT, typename MT>
50 octave_value
subsref(const std::string & type,const std::list<octave_value_list> & idx)51 octave_base_diag<DMT, MT>::subsref (const std::string& type,
52                                     const std::list<octave_value_list>& idx)
53 {
54   octave_value retval;
55 
56   switch (type[0])
57     {
58     case '(':
59       retval = do_index_op (idx.front ());
60       break;
61 
62     case '{':
63     case '.':
64       {
65         std::string nm = type_name ();
66         error ("%s cannot be indexed with %c", nm.c_str (), type[0]);
67       }
68       break;
69 
70     default:
71       panic_impossible ();
72     }
73 
74   return retval.next_subsref (type, idx);
75 }
76 
77 template <typename DMT, typename MT>
78 octave_value
diag(octave_idx_type k) const79 octave_base_diag<DMT,MT>::diag (octave_idx_type k) const
80 {
81   octave_value retval;
82   if (matrix.rows () == 1 || matrix.cols () == 1)
83     {
84       // Rather odd special case.  This is a row or column vector
85       // represented as a diagonal matrix with a single nonzero entry, but
86       // Fdiag semantics are to product a diagonal matrix for vector
87       // inputs.
88       if (k == 0)
89         // Returns Diag2Array<T> with nnz <= 1.
90         retval = matrix.build_diag_matrix ();
91       else
92         // Returns Array<T> matrix
93         retval = matrix.array_value ().diag (k);
94     }
95   else
96     // Returns Array<T> vector
97     retval = matrix.extract_diag (k);
98   return retval;
99 }
100 
101 template <typename DMT, typename MT>
102 octave_value
do_index_op(const octave_value_list & idx,bool resize_ok)103 octave_base_diag<DMT, MT>::do_index_op (const octave_value_list& idx,
104                                         bool resize_ok)
105 {
106   octave_value retval;
107 
108   if (idx.length () == 2 && ! resize_ok)
109     {
110       int k = 0;        // index we're accessing when index_vector throws
111       try
112         {
113           idx_vector idx0 = idx(0).index_vector ();
114           k = 1;
115           idx_vector idx1 = idx(1).index_vector ();
116 
117           if (idx0.is_scalar () && idx1.is_scalar ())
118             {
119               retval = matrix.checkelem (idx0(0), idx1(0));
120             }
121           else
122             {
123               octave_idx_type m = idx0.length (matrix.rows ());
124               octave_idx_type n = idx1.length (matrix.columns ());
125               if (idx0.is_colon_equiv (m) && idx1.is_colon_equiv (n)
126                   && m <= matrix.rows () && n <= matrix.rows ())
127                 {
128                   DMT rm (matrix);
129                   rm.resize (m, n);
130                   retval = rm;
131                 }
132               else
133                 retval = to_dense ().do_index_op (idx, resize_ok);
134             }
135         }
136       catch (octave::index_exception& e)
137         {
138           // Rethrow to allow more info to be reported later.
139           e.set_pos_if_unset (2, k+1);
140           throw;
141         }
142     }
143   else
144     retval = to_dense ().do_index_op (idx, resize_ok);
145 
146   return retval;
147 }
148 
149 template <typename DMT, typename MT>
150 octave_value
subsasgn(const std::string & type,const std::list<octave_value_list> & idx,const octave_value & rhs)151 octave_base_diag<DMT, MT>::subsasgn (const std::string& type,
152                                      const std::list<octave_value_list>& idx,
153                                      const octave_value& rhs)
154 {
155   octave_value retval;
156 
157   switch (type[0])
158     {
159     case '(':
160       {
161         if (type.length () != 1)
162           {
163             std::string nm = type_name ();
164             error ("in indexed assignment of %s, last lhs index must be ()",
165                    nm.c_str ());
166           }
167 
168         octave_value_list jdx = idx.front ();
169 
170         // FIXME: Mostly repeated code for cases 1 and 2 could be
171         //        consolidated for DRY (Don't Repeat Yourself).
172         // Check for assignments to diagonal elements which should not
173         // destroy the diagonal property of the matrix.
174         // If D is a diagonal matrix then the assignment can be
175         // 1) linear, D(i) = x, where ind2sub results in case #2 below
176         // 2) subscript D(i,i) = x, where both indices are equal.
177         if (jdx.length () == 1 && jdx(0).is_scalar_type ())
178           {
179             typename DMT::element_type val;
180             int k = 0;
181             try
182               {
183                 idx_vector ind = jdx(0).index_vector ();
184                 k = 1;
185                 dim_vector dv (matrix.rows (), matrix.cols ());
186                 Array<idx_vector> ivec = ind2sub (dv, ind);
187                 idx_vector i0 = ivec(0);
188                 idx_vector i1 = ivec(1);
189 
190                 if (i0(0) == i1(0)
191                     && chk_valid_scalar (rhs, val))
192                   {
193                     matrix.dgelem (i0(0)) = val;
194                     retval = this;
195                     this->count++;
196                     // invalidate cache
197                     dense_cache = octave_value ();
198                   }
199               }
200             catch (octave::index_exception& e)
201               {
202                 // Rethrow to allow more info to be reported later.
203                 e.set_pos_if_unset (2, k+1);
204                 throw;
205               }
206           }
207         else if (jdx.length () == 2
208                  && jdx(0).is_scalar_type () && jdx(1).is_scalar_type ())
209           {
210             typename DMT::element_type val;
211             int k = 0;
212             try
213               {
214                 idx_vector i0 = jdx(0).index_vector ();
215                 k = 1;
216                 idx_vector i1 = jdx(1).index_vector ();
217                 if (i0(0) == i1(0)
218                     && i0(0) < matrix.rows () && i1(0) < matrix.cols ()
219                     && chk_valid_scalar (rhs, val))
220                   {
221                     matrix.dgelem (i0(0)) = val;
222                     retval = this;
223                     this->count++;
224                     // invalidate cache
225                     dense_cache = octave_value ();
226                   }
227               }
228             catch (octave::index_exception& e)
229               {
230                 // Rethrow to allow more info to be reported later.
231                 e.set_pos_if_unset (2, k+1);
232                 throw;
233               }
234           }
235 
236         if (! retval.is_defined ())
237           retval = numeric_assign (type, idx, rhs);
238       }
239       break;
240 
241     case '{':
242     case '.':
243       {
244         if (! isempty ())
245           {
246             std::string nm = type_name ();
247             error ("%s cannot be indexed with %c", nm.c_str (), type[0]);
248           }
249 
250         octave_value tmp = octave_value::empty_conv (type, rhs);
251 
252         retval = tmp.subsasgn (type, idx, rhs);
253       }
254       break;
255 
256     default:
257       panic_impossible ();
258     }
259 
260   return retval;
261 }
262 
263 template <typename DMT, typename MT>
264 octave_value
resize(const dim_vector & dv,bool fill) const265 octave_base_diag<DMT, MT>::resize (const dim_vector& dv, bool fill) const
266 {
267   octave_value retval;
268   if (dv.ndims () == 2)
269     {
270       DMT rm (matrix);
271       rm.resize (dv(0), dv(1));
272       retval = rm;
273     }
274   else
275     retval = to_dense ().resize (dv, fill);
276   return retval;
277 }
278 
279 // Return true if this matrix has all true elements (non-zero, not NA/NaN).
280 template <typename DMT, typename MT>
281 bool
is_true(void) const282 octave_base_diag<DMT, MT>::is_true (void) const
283 {
284   if (dims ().numel () > 1)
285     {
286       warn_array_as_logical (dims ());
287       // Throw error if any NaN or NA by calling is_true().
288       octave_value (matrix.extract_diag ()).is_true ();
289       return false;                 // > 1x1 diagonal always has zeros
290     }
291   else
292     return to_dense ().is_true ();  // 0x0 or 1x1, handle NaN etc.
293 }
294 
// FIXME: This should be achievable using ::real

// Real part of a non-complex value: the value itself.
template <typename T>
inline T
helper_getreal (T value)
{
  return value;
}

// Real part of a std::complex value.
template <typename T>
inline T
helper_getreal (std::complex<T> value)
{
  return value.real ();
}
// FIXME: We really need some traits so that ad hoc hooks like this
//        are not necessary.

// Report whether the argument's type is a std::complex specialization.
// Only the type of the argument matters; its value is ignored.  The
// result is a bool rather than a value of the element type.
template <typename T>
inline bool
helper_iscomplex (T)
{
  return false;
}

template <typename T>
inline bool
helper_iscomplex (std::complex<T>)
{
  return true;
}
303 
304 template <typename DMT, typename MT>
305 double
double_value(bool force_conversion) const306 octave_base_diag<DMT, MT>::double_value (bool force_conversion) const
307 {
308   typedef typename DMT::element_type el_type;
309 
310   if (helper_iscomplex (el_type ()) && ! force_conversion)
311     warn_implicit_conversion ("Octave:imag-to-real",
312                               "complex matrix", "real scalar");
313 
314   if (isempty ())
315     err_invalid_conversion (type_name (), "real scalar");
316 
317   warn_implicit_conversion ("Octave:array-to-scalar",
318                             type_name (), "real scalar");
319 
320   return helper_getreal (el_type (matrix (0, 0)));
321 }
322 
323 template <typename DMT, typename MT>
324 float
float_value(bool force_conversion) const325 octave_base_diag<DMT, MT>::float_value (bool force_conversion) const
326 {
327   typedef typename DMT::element_type el_type;
328 
329   if (helper_iscomplex (el_type ()) && ! force_conversion)
330     warn_implicit_conversion ("Octave:imag-to-real",
331                               "complex matrix", "real scalar");
332 
333   if (! (numel () > 0))
334     err_invalid_conversion (type_name (), "real scalar");
335 
336   warn_implicit_conversion ("Octave:array-to-scalar",
337                             type_name (), "real scalar");
338 
339   return helper_getreal (el_type (matrix (0, 0)));
340 }
341 
342 template <typename DMT, typename MT>
343 Complex
complex_value(bool) const344 octave_base_diag<DMT, MT>::complex_value (bool) const
345 {
346   if (rows () == 0 || columns () == 0)
347     err_invalid_conversion (type_name (), "complex scalar");
348 
349   warn_implicit_conversion ("Octave:array-to-scalar",
350                             type_name (), "complex scalar");
351 
352   return matrix(0, 0);
353 }
354 
355 template <typename DMT, typename MT>
356 FloatComplex
float_complex_value(bool) const357 octave_base_diag<DMT, MT>::float_complex_value (bool) const
358 {
359   float tmp = lo_ieee_float_nan_value ();
360 
361   FloatComplex retval (tmp, tmp);
362 
363   if (rows () == 0 || columns () == 0)
364     err_invalid_conversion (type_name (), "complex scalar");
365 
366   warn_implicit_conversion ("Octave:array-to-scalar",
367                             type_name (), "complex scalar");
368 
369   retval = matrix (0, 0);
370 
371   return retval;
372 }
373 
374 template <typename DMT, typename MT>
375 Matrix
matrix_value(bool) const376 octave_base_diag<DMT, MT>::matrix_value (bool) const
377 {
378   return Matrix (diag_matrix_value ());
379 }
380 
381 template <typename DMT, typename MT>
382 FloatMatrix
float_matrix_value(bool) const383 octave_base_diag<DMT, MT>::float_matrix_value (bool) const
384 {
385   return FloatMatrix (float_diag_matrix_value ());
386 }
387 
388 template <typename DMT, typename MT>
389 ComplexMatrix
complex_matrix_value(bool) const390 octave_base_diag<DMT, MT>::complex_matrix_value (bool) const
391 {
392   return ComplexMatrix (complex_diag_matrix_value ());
393 }
394 
395 template <typename DMT, typename MT>
396 FloatComplexMatrix
float_complex_matrix_value(bool) const397 octave_base_diag<DMT, MT>::float_complex_matrix_value (bool) const
398 {
399   return FloatComplexMatrix (float_complex_diag_matrix_value ());
400 }
401 
402 template <typename DMT, typename MT>
403 NDArray
array_value(bool) const404 octave_base_diag<DMT, MT>::array_value (bool) const
405 {
406   return NDArray (matrix_value ());
407 }
408 
409 template <typename DMT, typename MT>
410 FloatNDArray
float_array_value(bool) const411 octave_base_diag<DMT, MT>::float_array_value (bool) const
412 {
413   return FloatNDArray (float_matrix_value ());
414 }
415 
416 template <typename DMT, typename MT>
417 ComplexNDArray
complex_array_value(bool) const418 octave_base_diag<DMT, MT>::complex_array_value (bool) const
419 {
420   return ComplexNDArray (complex_matrix_value ());
421 }
422 
423 template <typename DMT, typename MT>
424 FloatComplexNDArray
float_complex_array_value(bool) const425 octave_base_diag<DMT, MT>::float_complex_array_value (bool) const
426 {
427   return FloatComplexNDArray (float_complex_matrix_value ());
428 }
429 
430 template <typename DMT, typename MT>
431 boolNDArray
bool_array_value(bool warn) const432 octave_base_diag<DMT, MT>::bool_array_value (bool warn) const
433 {
434   return to_dense ().bool_array_value (warn);
435 }
436 
437 template <typename DMT, typename MT>
438 charNDArray
char_array_value(bool warn) const439 octave_base_diag<DMT, MT>::char_array_value (bool warn) const
440 {
441   return to_dense ().char_array_value (warn);
442 }
443 
444 template <typename DMT, typename MT>
445 SparseMatrix
sparse_matrix_value(bool) const446 octave_base_diag<DMT, MT>::sparse_matrix_value (bool) const
447 {
448   return SparseMatrix (diag_matrix_value ());
449 }
450 
451 template <typename DMT, typename MT>
452 SparseComplexMatrix
sparse_complex_matrix_value(bool) const453 octave_base_diag<DMT, MT>::sparse_complex_matrix_value (bool) const
454 {
455   return SparseComplexMatrix (complex_diag_matrix_value ());
456 }
457 
458 template <typename DMT, typename MT>
459 idx_vector
index_vector(bool require_integers) const460 octave_base_diag<DMT, MT>::index_vector (bool require_integers) const
461 {
462   return to_dense ().index_vector (require_integers);
463 }
464 
465 template <typename DMT, typename MT>
466 octave_value
convert_to_str_internal(bool pad,bool force,char type) const467 octave_base_diag<DMT, MT>::convert_to_str_internal (bool pad, bool force,
468                                                     char type) const
469 {
470   return to_dense ().convert_to_str_internal (pad, force, type);
471 }
472 
473 template <typename DMT, typename MT>
474 float_display_format
get_edit_display_format(void) const475 octave_base_diag<DMT, MT>::get_edit_display_format (void) const
476 {
477   // FIXME
478   return float_display_format ();
479 }
480 
481 template <typename DMT, typename MT>
482 std::string
edit_display(const float_display_format & fmt,octave_idx_type i,octave_idx_type j) const483 octave_base_diag<DMT, MT>::edit_display (const float_display_format& fmt,
484                                          octave_idx_type i,
485                                          octave_idx_type j) const
486 {
487   std::ostringstream buf;
488   octave_print_internal (buf, fmt, matrix(i,j));
489   return buf.str ();
490 }
491 
492 template <typename DMT, typename MT>
493 bool
save_ascii(std::ostream & os)494 octave_base_diag<DMT, MT>::save_ascii (std::ostream& os)
495 {
496   os << "# rows: " << matrix.rows () << "\n"
497      << "# columns: " << matrix.columns () << "\n";
498 
499   os << matrix.extract_diag ();
500 
501   return true;
502 }
503 
504 template <typename DMT, typename MT>
505 bool
load_ascii(std::istream & is)506 octave_base_diag<DMT, MT>::load_ascii (std::istream& is)
507 {
508   octave_idx_type r = 0;
509   octave_idx_type c = 0;
510 
511   if (! extract_keyword (is, "rows", r, true)
512       || ! extract_keyword (is, "columns", c, true))
513     error ("load: failed to extract number of rows and columns");
514 
515   octave_idx_type l = (r < c ? r : c);
516   MT tmp (l, 1);
517   is >> tmp;
518 
519   if (! is)
520     error ("load: failed to load diagonal matrix constant");
521 
522   // This is a little tricky, as we have the Matrix type, but
523   // not ColumnVector type.  We need to help the compiler get
524   // through the inheritance tree.
525   typedef typename DMT::element_type el_type;
526   matrix = DMT (MDiagArray2<el_type> (MArray<el_type> (tmp)));
527   matrix.resize (r, c);
528 
529   // Invalidate cache.  Probably not necessary, but safe.
530   dense_cache = octave_value ();
531 
532   return true;
533 }
534 
535 template <typename DMT, typename MT>
536 void
print_raw(std::ostream & os,bool pr_as_read_syntax) const537 octave_base_diag<DMT, MT>::print_raw (std::ostream& os,
538                                       bool pr_as_read_syntax) const
539 {
540   return octave_print_internal (os, matrix, pr_as_read_syntax,
541                                 current_print_indent_level ());
542 }
543 
544 template <typename DMT, typename MT>
545 mxArray *
as_mxArray(void) const546 octave_base_diag<DMT, MT>::as_mxArray (void) const
547 {
548   return to_dense ().as_mxArray ();
549 }
550 
551 template <typename DMT, typename MT>
552 bool
print_as_scalar(void) const553 octave_base_diag<DMT, MT>::print_as_scalar (void) const
554 {
555   dim_vector dv = dims ();
556 
557   return (dv.all_ones () || dv.any_zero ());
558 }
559 
560 template <typename DMT, typename MT>
561 void
print(std::ostream & os,bool pr_as_read_syntax)562 octave_base_diag<DMT, MT>::print (std::ostream& os, bool pr_as_read_syntax)
563 {
564   print_raw (os, pr_as_read_syntax);
565   newline (os);
566 }
567 template <typename DMT, typename MT>
568 int
write(octave::stream & os,int block_size,oct_data_conv::data_type output_type,int skip,octave::mach_info::float_format flt_fmt) const569 octave_base_diag<DMT, MT>::write (octave::stream& os, int block_size,
570                                   oct_data_conv::data_type output_type,
571                                   int skip,
572                                   octave::mach_info::float_format flt_fmt) const
573 {
574   return to_dense ().write (os, block_size, output_type, skip, flt_fmt);
575 }
576 
577 template <typename DMT, typename MT>
578 void
print_info(std::ostream & os,const std::string & prefix) const579 octave_base_diag<DMT, MT>::print_info (std::ostream& os,
580                                        const std::string& prefix) const
581 {
582   matrix.print_info (os, prefix);
583 }
584 
585 // FIXME: this function is duplicated in octave_base_matrix<T>.  Could
586 // it somehow be shared instead?
587 
588 template <typename DMT, typename MT>
589 void
short_disp(std::ostream & os) const590 octave_base_diag<DMT, MT>::short_disp (std::ostream& os) const
591 {
592   if (matrix.isempty ())
593     os << "[]";
594   else if (matrix.ndims () == 2)
595     {
596       // FIXME: should this be configurable?
597       octave_idx_type max_elts = 10;
598       octave_idx_type elts = 0;
599 
600       octave_idx_type nel = matrix.numel ();
601 
602       octave_idx_type nr = matrix.rows ();
603       octave_idx_type nc = matrix.columns ();
604 
605       os << '[';
606 
607       for (octave_idx_type i = 0; i < nr; i++)
608         {
609           for (octave_idx_type j = 0; j < nc; j++)
610             {
611               std::ostringstream buf;
612               octave_print_internal (buf, matrix(i,j));
613               std::string tmp = buf.str ();
614               std::size_t pos = tmp.find_first_not_of (' ');
615               if (pos != std::string::npos)
616                 os << tmp.substr (pos);
617               else if (! tmp.empty ())
618                 os << tmp[0];
619 
620               if (++elts >= max_elts)
621                 goto done;
622 
623               if (j < nc - 1)
624                 os << ", ";
625             }
626 
627           if (i < nr - 1 && elts < max_elts)
628             os << "; ";
629         }
630 
631     done:
632 
633       if (nel <= max_elts)
634         os << ']';
635     }
636   else
637     os << "...";
638 }
639 
640 template <typename DMT, typename MT>
641 octave_value
fast_elem_extract(octave_idx_type n) const642 octave_base_diag<DMT, MT>::fast_elem_extract (octave_idx_type n) const
643 {
644   if (n < matrix.numel ())
645     {
646       octave_idx_type nr = matrix.rows ();
647 
648       octave_idx_type r = n % nr;
649       octave_idx_type c = n / nr;
650 
651       return octave_value (matrix.elem (r, c));
652     }
653   else
654     return octave_value ();
655 }
656 
657 template <typename DMT, typename MT>
658 octave_value
to_dense(void) const659 octave_base_diag<DMT, MT>::to_dense (void) const
660 {
661   if (! dense_cache.is_defined ())
662     dense_cache = MT (matrix);
663 
664   return dense_cache;
665 }
666