1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2008-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 #  include "config.h"
28 #endif
29 
#include <limits>

#include "byte-swap.h"

#include "ov-re-diag.h"
#include "ov-flt-re-diag.h"
#include "ov-base-diag.cc"
#include "ov-scalar.h"
#include "ov-re-mat.h"
#include "ls-utils.h"
38 
39 
40 template class octave_base_diag<DiagMatrix, Matrix>;
41 
42 DEFINE_OV_TYPEID_FUNCTIONS_AND_DATA (octave_diag_matrix, "diagonal matrix",
43                                      "double");
44 
45 static octave_base_value *
default_numeric_conversion_function(const octave_base_value & a)46 default_numeric_conversion_function (const octave_base_value& a)
47 {
48   const octave_diag_matrix& v = dynamic_cast<const octave_diag_matrix&> (a);
49 
50   return new octave_matrix (v.matrix_value ());
51 }
52 
53 octave_base_value::type_conv_info
numeric_conversion_function(void) const54 octave_diag_matrix::numeric_conversion_function (void) const
55 {
56   return octave_base_value::type_conv_info (default_numeric_conversion_function,
57                                             octave_matrix::static_type_id ());
58 }
59 
60 static octave_base_value *
default_numeric_demotion_function(const octave_base_value & a)61 default_numeric_demotion_function (const octave_base_value& a)
62 {
63   const octave_diag_matrix& v = dynamic_cast<const octave_diag_matrix&> (a);
64 
65   return new octave_float_diag_matrix (v.float_diag_matrix_value ());
66 }
67 
68 octave_base_value::type_conv_info
numeric_demotion_function(void) const69 octave_diag_matrix::numeric_demotion_function (void) const
70 {
71   return octave_base_value::type_conv_info
72            (default_numeric_demotion_function,
73             octave_float_diag_matrix::static_type_id ());
74 }
75 
76 octave_base_value *
try_narrowing_conversion(void)77 octave_diag_matrix::try_narrowing_conversion (void)
78 {
79   octave_base_value *retval = nullptr;
80 
81   if (matrix.nelem () == 1)
82     retval = new octave_scalar (matrix (0, 0));
83 
84   return retval;
85 }
86 
87 octave_value
do_index_op(const octave_value_list & idx,bool resize_ok)88 octave_diag_matrix::do_index_op (const octave_value_list& idx,
89                                  bool resize_ok)
90 {
91   octave_value retval;
92 
93   // This hack is to allow constructing permutation matrices using
94   // eye(n)(p,:), eye(n)(:,q) && eye(n)(p,q) where p & q are permutation
95   // vectors.
96   if (! resize_ok && idx.length () == 2 && matrix.is_multiple_of_identity (1))
97     {
98       int k = 0;        // index we're accessing when index_vector throws
99       try
100         {
101           idx_vector idx0 = idx(0).index_vector ();
102           k = 1;
103           idx_vector idx1 = idx(1).index_vector ();
104 
105           bool left = idx0.is_permutation (matrix.rows ());
106           bool right = idx1.is_permutation (matrix.cols ());
107 
108           if (left && right)
109             {
110               if (idx0.is_colon ()) left = false;
111               if (idx1.is_colon ()) right = false;
112               if (left && right)
113                 retval = PermMatrix (idx0, false) * PermMatrix (idx1, true);
114               else if (left)
115                 retval = PermMatrix (idx0, false);
116               else if (right)
117                 retval = PermMatrix (idx1, true);
118               else
119                 {
120                   retval = this;
121                   this->count++;
122                 }
123             }
124         }
125       catch (octave::index_exception& e)
126         {
127           // Rethrow to allow more info to be reported later.
128           e.set_pos_if_unset (2, k+1);
129           throw;
130         }
131     }
132 
133   if (retval.is_undefined ())
134     retval = octave_base_diag<DiagMatrix, Matrix>::do_index_op (idx, resize_ok);
135 
136   return retval;
137 }
138 
139 DiagMatrix
diag_matrix_value(bool) const140 octave_diag_matrix::diag_matrix_value (bool) const
141 {
142   return matrix;
143 }
144 
145 FloatDiagMatrix
float_diag_matrix_value(bool) const146 octave_diag_matrix::float_diag_matrix_value (bool) const
147 {
148   return FloatDiagMatrix (matrix);
149 }
150 
151 ComplexDiagMatrix
complex_diag_matrix_value(bool) const152 octave_diag_matrix::complex_diag_matrix_value (bool) const
153 {
154   return ComplexDiagMatrix (matrix);
155 }
156 
157 FloatComplexDiagMatrix
float_complex_diag_matrix_value(bool) const158 octave_diag_matrix::float_complex_diag_matrix_value (bool) const
159 {
160   return FloatComplexDiagMatrix (matrix);
161 }
162 
163 octave_value
as_double(void) const164 octave_diag_matrix::as_double (void) const
165 {
166   return matrix;
167 }
168 
169 octave_value
as_single(void) const170 octave_diag_matrix::as_single (void) const
171 {
172   return FloatDiagMatrix (matrix);
173 }
174 
175 octave_value
as_int8(void) const176 octave_diag_matrix::as_int8 (void) const
177 {
178   return int8_array_value ();
179 }
180 
181 octave_value
as_int16(void) const182 octave_diag_matrix::as_int16 (void) const
183 {
184   return int16_array_value ();
185 }
186 
187 octave_value
as_int32(void) const188 octave_diag_matrix::as_int32 (void) const
189 {
190   return int32_array_value ();
191 }
192 
193 octave_value
as_int64(void) const194 octave_diag_matrix::as_int64 (void) const
195 {
196   return int64_array_value ();
197 }
198 
199 octave_value
as_uint8(void) const200 octave_diag_matrix::as_uint8 (void) const
201 {
202   return uint8_array_value ();
203 }
204 
205 octave_value
as_uint16(void) const206 octave_diag_matrix::as_uint16 (void) const
207 {
208   return uint16_array_value ();
209 }
210 
211 octave_value
as_uint32(void) const212 octave_diag_matrix::as_uint32 (void) const
213 {
214   return uint32_array_value ();
215 }
216 
217 octave_value
as_uint64(void) const218 octave_diag_matrix::as_uint64 (void) const
219 {
220   return uint64_array_value ();
221 }
222 
223 octave_value
map(unary_mapper_t umap) const224 octave_diag_matrix::map (unary_mapper_t umap) const
225 {
226   switch (umap)
227     {
228     case umap_abs:
229       return matrix.abs ();
230     case umap_real:
231     case umap_conj:
232       return matrix;
233     case umap_imag:
234       return DiagMatrix (matrix.rows (), matrix.cols (), 0.0);
235     case umap_sqrt:
236       {
237         ComplexColumnVector tmp;
238         tmp = matrix.extract_diag ().map<Complex> (octave::math::rc_sqrt);
239         ComplexDiagMatrix retval (tmp);
240         retval.resize (matrix.rows (), matrix.columns ());
241         return retval;
242       }
243     default:
244       return to_dense ().map (umap);
245     }
246 }
247 
248 bool
save_binary(std::ostream & os,bool save_as_floats)249 octave_diag_matrix::save_binary (std::ostream& os, bool save_as_floats)
250 {
251 
252   int32_t r = matrix.rows ();
253   int32_t c = matrix.cols ();
254   os.write (reinterpret_cast<char *> (&r), 4);
255   os.write (reinterpret_cast<char *> (&c), 4);
256 
257   Matrix m = Matrix (matrix.extract_diag ());
258   save_type st = LS_DOUBLE;
259   if (save_as_floats)
260     {
261       if (m.too_large_for_float ())
262         {
263           warning ("save: some values too large to save as floats --");
264           warning ("save: saving as doubles instead");
265         }
266       else
267         st = LS_FLOAT;
268     }
269   else if (matrix.length () > 8192) // FIXME: make this configurable.
270     {
271       double max_val, min_val;
272       if (m.all_integers (max_val, min_val))
273         st = get_save_type (max_val, min_val);
274     }
275 
276   const double *mtmp = m.data ();
277   write_doubles (os, mtmp, st, m.numel ());
278 
279   return true;
280 }
281 
282 bool
load_binary(std::istream & is,bool swap,octave::mach_info::float_format fmt)283 octave_diag_matrix::load_binary (std::istream& is, bool swap,
284                                  octave::mach_info::float_format fmt)
285 {
286   int32_t r, c;
287   char tmp;
288   if (! (is.read (reinterpret_cast<char *> (&r), 4)
289          && is.read (reinterpret_cast<char *> (&c), 4)
290          && is.read (reinterpret_cast<char *> (&tmp), 1)))
291     return false;
292   if (swap)
293     {
294       swap_bytes<4> (&r);
295       swap_bytes<4> (&c);
296     }
297 
298   DiagMatrix m (r, c);
299   double *re = m.fortran_vec ();
300   octave_idx_type len = m.length ();
301   read_doubles (is, re, static_cast<save_type> (tmp), len, swap, fmt);
302 
303   if (! is)
304     return false;
305 
306   matrix = m;
307 
308   return true;
309 }
310 
311 bool
chk_valid_scalar(const octave_value & val,double & x) const312 octave_diag_matrix::chk_valid_scalar (const octave_value& val,
313                                       double& x) const
314 {
315   bool retval = val.is_real_scalar ();
316   if (retval)
317     x = val.double_value ();
318   return retval;
319 }
320