1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2008-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29
30 #include "byte-swap.h"
31
32 #include "ov-re-diag.h"
33 #include "ov-flt-re-diag.h"
34 #include "ov-base-diag.cc"
35 #include "ov-scalar.h"
36 #include "ov-re-mat.h"
37 #include "ls-utils.h"
38
39
// Explicitly instantiate the shared diagonal-matrix base class for the
// real double-precision case (DiagMatrix storage, Matrix as the dense
// counterpart).  The template member definitions come from
// ov-base-diag.cc, which this file #includes.
template class octave_base_diag<DiagMatrix, Matrix>;

// Type registration boilerplate for octave_diag_matrix; the string
// arguments are presumably the type name ("diagonal matrix") and the
// class name ("double") reported to users -- see the macro definition.
DEFINE_OV_TYPEID_FUNCTIONS_AND_DATA (octave_diag_matrix, "diagonal matrix",
                                     "double");
44
45 static octave_base_value *
default_numeric_conversion_function(const octave_base_value & a)46 default_numeric_conversion_function (const octave_base_value& a)
47 {
48 const octave_diag_matrix& v = dynamic_cast<const octave_diag_matrix&> (a);
49
50 return new octave_matrix (v.matrix_value ());
51 }
52
53 octave_base_value::type_conv_info
numeric_conversion_function(void) const54 octave_diag_matrix::numeric_conversion_function (void) const
55 {
56 return octave_base_value::type_conv_info (default_numeric_conversion_function,
57 octave_matrix::static_type_id ());
58 }
59
60 static octave_base_value *
default_numeric_demotion_function(const octave_base_value & a)61 default_numeric_demotion_function (const octave_base_value& a)
62 {
63 const octave_diag_matrix& v = dynamic_cast<const octave_diag_matrix&> (a);
64
65 return new octave_float_diag_matrix (v.float_diag_matrix_value ());
66 }
67
68 octave_base_value::type_conv_info
numeric_demotion_function(void) const69 octave_diag_matrix::numeric_demotion_function (void) const
70 {
71 return octave_base_value::type_conv_info
72 (default_numeric_demotion_function,
73 octave_float_diag_matrix::static_type_id ());
74 }
75
76 octave_base_value *
try_narrowing_conversion(void)77 octave_diag_matrix::try_narrowing_conversion (void)
78 {
79 octave_base_value *retval = nullptr;
80
81 if (matrix.nelem () == 1)
82 retval = new octave_scalar (matrix (0, 0));
83
84 return retval;
85 }
86
87 octave_value
do_index_op(const octave_value_list & idx,bool resize_ok)88 octave_diag_matrix::do_index_op (const octave_value_list& idx,
89 bool resize_ok)
90 {
91 octave_value retval;
92
93 // This hack is to allow constructing permutation matrices using
94 // eye(n)(p,:), eye(n)(:,q) && eye(n)(p,q) where p & q are permutation
95 // vectors.
96 if (! resize_ok && idx.length () == 2 && matrix.is_multiple_of_identity (1))
97 {
98 int k = 0; // index we're accessing when index_vector throws
99 try
100 {
101 idx_vector idx0 = idx(0).index_vector ();
102 k = 1;
103 idx_vector idx1 = idx(1).index_vector ();
104
105 bool left = idx0.is_permutation (matrix.rows ());
106 bool right = idx1.is_permutation (matrix.cols ());
107
108 if (left && right)
109 {
110 if (idx0.is_colon ()) left = false;
111 if (idx1.is_colon ()) right = false;
112 if (left && right)
113 retval = PermMatrix (idx0, false) * PermMatrix (idx1, true);
114 else if (left)
115 retval = PermMatrix (idx0, false);
116 else if (right)
117 retval = PermMatrix (idx1, true);
118 else
119 {
120 retval = this;
121 this->count++;
122 }
123 }
124 }
125 catch (octave::index_exception& e)
126 {
127 // Rethrow to allow more info to be reported later.
128 e.set_pos_if_unset (2, k+1);
129 throw;
130 }
131 }
132
133 if (retval.is_undefined ())
134 retval = octave_base_diag<DiagMatrix, Matrix>::do_index_op (idx, resize_ok);
135
136 return retval;
137 }
138
139 DiagMatrix
diag_matrix_value(bool) const140 octave_diag_matrix::diag_matrix_value (bool) const
141 {
142 return matrix;
143 }
144
145 FloatDiagMatrix
float_diag_matrix_value(bool) const146 octave_diag_matrix::float_diag_matrix_value (bool) const
147 {
148 return FloatDiagMatrix (matrix);
149 }
150
151 ComplexDiagMatrix
complex_diag_matrix_value(bool) const152 octave_diag_matrix::complex_diag_matrix_value (bool) const
153 {
154 return ComplexDiagMatrix (matrix);
155 }
156
157 FloatComplexDiagMatrix
float_complex_diag_matrix_value(bool) const158 octave_diag_matrix::float_complex_diag_matrix_value (bool) const
159 {
160 return FloatComplexDiagMatrix (matrix);
161 }
162
163 octave_value
as_double(void) const164 octave_diag_matrix::as_double (void) const
165 {
166 return matrix;
167 }
168
169 octave_value
as_single(void) const170 octave_diag_matrix::as_single (void) const
171 {
172 return FloatDiagMatrix (matrix);
173 }
174
175 octave_value
as_int8(void) const176 octave_diag_matrix::as_int8 (void) const
177 {
178 return int8_array_value ();
179 }
180
181 octave_value
as_int16(void) const182 octave_diag_matrix::as_int16 (void) const
183 {
184 return int16_array_value ();
185 }
186
187 octave_value
as_int32(void) const188 octave_diag_matrix::as_int32 (void) const
189 {
190 return int32_array_value ();
191 }
192
193 octave_value
as_int64(void) const194 octave_diag_matrix::as_int64 (void) const
195 {
196 return int64_array_value ();
197 }
198
199 octave_value
as_uint8(void) const200 octave_diag_matrix::as_uint8 (void) const
201 {
202 return uint8_array_value ();
203 }
204
205 octave_value
as_uint16(void) const206 octave_diag_matrix::as_uint16 (void) const
207 {
208 return uint16_array_value ();
209 }
210
211 octave_value
as_uint32(void) const212 octave_diag_matrix::as_uint32 (void) const
213 {
214 return uint32_array_value ();
215 }
216
217 octave_value
as_uint64(void) const218 octave_diag_matrix::as_uint64 (void) const
219 {
220 return uint64_array_value ();
221 }
222
223 octave_value
map(unary_mapper_t umap) const224 octave_diag_matrix::map (unary_mapper_t umap) const
225 {
226 switch (umap)
227 {
228 case umap_abs:
229 return matrix.abs ();
230 case umap_real:
231 case umap_conj:
232 return matrix;
233 case umap_imag:
234 return DiagMatrix (matrix.rows (), matrix.cols (), 0.0);
235 case umap_sqrt:
236 {
237 ComplexColumnVector tmp;
238 tmp = matrix.extract_diag ().map<Complex> (octave::math::rc_sqrt);
239 ComplexDiagMatrix retval (tmp);
240 retval.resize (matrix.rows (), matrix.columns ());
241 return retval;
242 }
243 default:
244 return to_dense ().map (umap);
245 }
246 }
247
248 bool
save_binary(std::ostream & os,bool save_as_floats)249 octave_diag_matrix::save_binary (std::ostream& os, bool save_as_floats)
250 {
251
252 int32_t r = matrix.rows ();
253 int32_t c = matrix.cols ();
254 os.write (reinterpret_cast<char *> (&r), 4);
255 os.write (reinterpret_cast<char *> (&c), 4);
256
257 Matrix m = Matrix (matrix.extract_diag ());
258 save_type st = LS_DOUBLE;
259 if (save_as_floats)
260 {
261 if (m.too_large_for_float ())
262 {
263 warning ("save: some values too large to save as floats --");
264 warning ("save: saving as doubles instead");
265 }
266 else
267 st = LS_FLOAT;
268 }
269 else if (matrix.length () > 8192) // FIXME: make this configurable.
270 {
271 double max_val, min_val;
272 if (m.all_integers (max_val, min_val))
273 st = get_save_type (max_val, min_val);
274 }
275
276 const double *mtmp = m.data ();
277 write_doubles (os, mtmp, st, m.numel ());
278
279 return true;
280 }
281
282 bool
load_binary(std::istream & is,bool swap,octave::mach_info::float_format fmt)283 octave_diag_matrix::load_binary (std::istream& is, bool swap,
284 octave::mach_info::float_format fmt)
285 {
286 int32_t r, c;
287 char tmp;
288 if (! (is.read (reinterpret_cast<char *> (&r), 4)
289 && is.read (reinterpret_cast<char *> (&c), 4)
290 && is.read (reinterpret_cast<char *> (&tmp), 1)))
291 return false;
292 if (swap)
293 {
294 swap_bytes<4> (&r);
295 swap_bytes<4> (&c);
296 }
297
298 DiagMatrix m (r, c);
299 double *re = m.fortran_vec ();
300 octave_idx_type len = m.length ();
301 read_doubles (is, re, static_cast<save_type> (tmp), len, swap, fmt);
302
303 if (! is)
304 return false;
305
306 matrix = m;
307
308 return true;
309 }
310
311 bool
chk_valid_scalar(const octave_value & val,double & x) const312 octave_diag_matrix::chk_valid_scalar (const octave_value& val,
313 double& x) const
314 {
315 bool retval = val.is_real_scalar ();
316 if (retval)
317 x = val.double_value ();
318 return retval;
319 }
320