1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2008-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29
30 #include "byte-swap.h"
31 #include "dim-vector.h"
32
33 #include "mxarray.h"
34 #include "ov-perm.h"
35 #include "ov-re-mat.h"
36 #include "ov-scalar.h"
37 #include "error.h"
38 #include "errwarn.h"
39 #include "ops.h"
40 #include "pr-output.h"
41
42 #include "ls-oct-text.h"
43
44 octave_value
subsref(const std::string & type,const std::list<octave_value_list> & idx)45 octave_perm_matrix::subsref (const std::string& type,
46 const std::list<octave_value_list>& idx)
47 {
48 octave_value retval;
49
50 switch (type[0])
51 {
52 case '(':
53 retval = do_index_op (idx.front ());
54 break;
55
56 case '{':
57 case '.':
58 {
59 std::string nm = type_name ();
60 error ("%s cannot be indexed with %c", nm.c_str (), type[0]);
61 }
62 break;
63
64 default:
65 panic_impossible ();
66 }
67
68 return retval.next_subsref (type, idx);
69 }
70
71 octave_value
do_index_op(const octave_value_list & idx,bool resize_ok)72 octave_perm_matrix::do_index_op (const octave_value_list& idx,
73 bool resize_ok)
74 {
75 octave_value retval;
76 octave_idx_type nidx = idx.length ();
77 idx_vector idx0, idx1;
78 if (nidx == 2)
79 {
80 int k = 0; // index we're processing when index_vector throws
81 try
82 {
83 idx0 = idx(0).index_vector ();
84 k = 1;
85 idx1 = idx(1).index_vector ();
86 }
87 catch (octave::index_exception& e)
88 {
89 // Rethrow to allow more info to be reported later.
90 e.set_pos_if_unset (2, k+1);
91 throw;
92 }
93 }
94
95 // This hack is to allow constructing permutation matrices using
96 // eye(n)(p,:), eye(n)(:,q) && eye(n)(p,q) where p & q are permutation
97 // vectors.
98 // Note that, for better consistency, eye(n)(:,:) still converts to a full
99 // matrix.
100 if (nidx == 2)
101 {
102 bool left = idx0.is_permutation (matrix.rows ());
103 bool right = idx1.is_permutation (matrix.cols ());
104
105 if (left && right)
106 {
107 if (idx0.is_colon ()) left = false;
108 if (idx1.is_colon ()) right = false;
109 if (left || right)
110 {
111 PermMatrix p = matrix;
112 if (left)
113 p = PermMatrix (idx0, false) * p;
114 if (right)
115 p = p * PermMatrix (idx1, true);
116 retval = p;
117 }
118 else
119 {
120 retval = this;
121 this->count++;
122 }
123 }
124 }
125
126 if (! retval.is_defined ())
127 {
128 if (nidx == 2 && ! resize_ok && idx0.is_scalar () && idx1.is_scalar ())
129 retval = matrix.checkelem (idx0(0), idx1(0));
130 else
131 retval = to_dense ().do_index_op (idx, resize_ok);
132 }
133
134 return retval;
135 }
136
137 // Return true if this matrix has all true elements (non-zero, not NaN/NA).
138 // A permutation cannot have NaN/NA.
139 bool
is_true(void) const140 octave_perm_matrix::is_true (void) const
141 {
142 if (dims ().numel () > 1)
143 {
144 warn_array_as_logical (dims ());
145 return false; // > 1x1 permutation always has zeros, and no NaN.
146 }
147 else
148 return dims ().numel (); // 1x1 is [1] == true, 0x0 == false.
149 }
150
151 double
double_value(bool) const152 octave_perm_matrix::double_value (bool) const
153 {
154 if (isempty ())
155 err_invalid_conversion (type_name (), "real scalar");
156
157 warn_implicit_conversion ("Octave:array-to-scalar",
158 type_name (), "real scalar");
159
160 return matrix(0, 0);
161 }
162
163 float
float_value(bool) const164 octave_perm_matrix::float_value (bool) const
165 {
166 if (isempty ())
167 err_invalid_conversion (type_name (), "real scalar");
168
169 warn_implicit_conversion ("Octave:array-to-scalar",
170 type_name (), "real scalar");
171
172 return matrix(0, 0);
173 }
174
175 Complex
complex_value(bool) const176 octave_perm_matrix::complex_value (bool) const
177 {
178 if (rows () == 0 || columns () == 0)
179 err_invalid_conversion (type_name (), "complex scalar");
180
181 warn_implicit_conversion ("Octave:array-to-scalar",
182 type_name (), "complex scalar");
183
184 return Complex (matrix(0, 0), 0);
185 }
186
187 FloatComplex
float_complex_value(bool) const188 octave_perm_matrix::float_complex_value (bool) const
189 {
190 float tmp = lo_ieee_float_nan_value ();
191
192 FloatComplex retval (tmp, tmp);
193
194 if (rows () == 0 || columns () == 0)
195 err_invalid_conversion (type_name (), "complex scalar");
196
197 warn_implicit_conversion ("Octave:array-to-scalar",
198 type_name (), "complex scalar");
199
200 retval = matrix(0, 0);
201
202 return retval;
203 }
204
// Define octave_perm_matrix::NAME_value (bool) so that it returns
// RESULT_T by calling the member of the same name on the value
// produced by to_dense ().
#define FORWARD_MATRIX_VALUE(RESULT_T, NAME)                            \
  RESULT_T                                                              \
  octave_perm_matrix::NAME ## _value (bool force_conv) const            \
  {                                                                     \
    return to_dense ().NAME ## _value (force_conv);                     \
  }
211
212 SparseMatrix
sparse_matrix_value(bool) const213 octave_perm_matrix::sparse_matrix_value (bool) const
214 {
215 return SparseMatrix (matrix);
216 }
217
218 SparseBoolMatrix
sparse_bool_matrix_value(bool) const219 octave_perm_matrix::sparse_bool_matrix_value (bool) const
220 {
221 return SparseBoolMatrix (matrix);
222 }
223
224 SparseComplexMatrix
sparse_complex_matrix_value(bool) const225 octave_perm_matrix::sparse_complex_matrix_value (bool) const
226 {
227 return SparseComplexMatrix (sparse_matrix_value ());
228 }
229
// Dense value extractors, each generated by FORWARD_MATRIX_VALUE.

// Two-dimensional matrix results.
FORWARD_MATRIX_VALUE (Matrix, matrix)
FORWARD_MATRIX_VALUE (FloatMatrix, float_matrix)
FORWARD_MATRIX_VALUE (ComplexMatrix, complex_matrix)
FORWARD_MATRIX_VALUE (FloatComplexMatrix, float_complex_matrix)

// N-d array results.
FORWARD_MATRIX_VALUE (NDArray, array)
FORWARD_MATRIX_VALUE (FloatNDArray, float_array)
FORWARD_MATRIX_VALUE (ComplexNDArray, complex_array)
FORWARD_MATRIX_VALUE (FloatComplexNDArray, float_complex_array)

// Boolean and character array results.
FORWARD_MATRIX_VALUE (boolNDArray, bool_array)
FORWARD_MATRIX_VALUE (charNDArray, char_array)
242
243 idx_vector
244 octave_perm_matrix::index_vector (bool require_integers) const
245 {
246 return to_dense ().index_vector (require_integers);
247 }
248
249 octave_value
convert_to_str_internal(bool pad,bool force,char type) const250 octave_perm_matrix::convert_to_str_internal (bool pad, bool force,
251 char type) const
252 {
253 return to_dense ().convert_to_str_internal (pad, force, type);
254 }
255
256 octave_value
as_double(void) const257 octave_perm_matrix::as_double (void) const
258 {
259 return matrix;
260 }
261
262 octave_value
as_single(void) const263 octave_perm_matrix::as_single (void) const
264 {
265 return float_array_value ();
266 }
267
268 octave_value
as_int8(void) const269 octave_perm_matrix::as_int8 (void) const
270 {
271 return int8_array_value ();
272 }
273
274 octave_value
as_int16(void) const275 octave_perm_matrix::as_int16 (void) const
276 {
277 return int16_array_value ();
278 }
279
280 octave_value
as_int32(void) const281 octave_perm_matrix::as_int32 (void) const
282 {
283 return int32_array_value ();
284 }
285
286 octave_value
as_int64(void) const287 octave_perm_matrix::as_int64 (void) const
288 {
289 return int64_array_value ();
290 }
291
292 octave_value
as_uint8(void) const293 octave_perm_matrix::as_uint8 (void) const
294 {
295 return uint8_array_value ();
296 }
297
298 octave_value
as_uint16(void) const299 octave_perm_matrix::as_uint16 (void) const
300 {
301 return uint16_array_value ();
302 }
303
304 octave_value
as_uint32(void) const305 octave_perm_matrix::as_uint32 (void) const
306 {
307 return uint32_array_value ();
308 }
309
310 octave_value
as_uint64(void) const311 octave_perm_matrix::as_uint64 (void) const
312 {
313 return uint64_array_value ();
314 }
315
316 float_display_format
get_edit_display_format(void) const317 octave_perm_matrix::get_edit_display_format (void) const
318 {
319 return float_display_format (float_format (1, 0, 0));
320 }
321
322 std::string
edit_display(const float_display_format & fmt,octave_idx_type i,octave_idx_type j) const323 octave_perm_matrix::edit_display (const float_display_format& fmt,
324 octave_idx_type i,
325 octave_idx_type j) const
326 {
327 std::ostringstream buf;
328 octave_print_internal (buf, fmt, octave_int<octave_idx_type> (matrix(i,j)));
329 return buf.str ();
330 }
331
332 bool
save_ascii(std::ostream & os)333 octave_perm_matrix::save_ascii (std::ostream& os)
334 {
335 os << "# size: " << matrix.rows () << "\n";
336 os << "# orient: c\n";
337
338 Array<octave_idx_type> pvec = matrix.col_perm_vec ();
339 octave_idx_type n = pvec.numel ();
340 ColumnVector tmp (n);
341 for (octave_idx_type i = 0; i < n; i++) tmp(i) = pvec(i) + 1;
342 os << tmp;
343
344 return true;
345 }
346
347 bool
load_ascii(std::istream & is)348 octave_perm_matrix::load_ascii (std::istream& is)
349 {
350 octave_idx_type n;
351 char orient;
352
353 if (! extract_keyword (is, "size", n, true)
354 || ! extract_keyword (is, "orient", orient, true))
355 error ("load: failed to extract size & orientation");
356
357 bool colp = orient == 'c';
358 ColumnVector tmp (n);
359 is >> tmp;
360 if (! is)
361 error ("load: failed to load permutation matrix constant");
362
363 Array<octave_idx_type> pvec (dim_vector (n, 1));
364 for (octave_idx_type i = 0; i < n; i++) pvec(i) = tmp(i) - 1;
365 matrix = PermMatrix (pvec, colp);
366
367 // Invalidate cache. Probably not necessary, but safe.
368 dense_cache = octave_value ();
369
370 return true;
371 }
372
373 bool
save_binary(std::ostream & os,bool)374 octave_perm_matrix::save_binary (std::ostream& os, bool)
375 {
376
377 int32_t sz = matrix.rows ();
378 bool colp = true;
379 os.write (reinterpret_cast<char *> (&sz), 4);
380 os.write (reinterpret_cast<char *> (&colp), 1);
381 const Array<octave_idx_type>& col_perm = matrix.col_perm_vec ();
382 os.write (reinterpret_cast<const char *> (col_perm.data ()),
383 col_perm.byte_size ());
384
385 return true;
386 }
387
388 bool
load_binary(std::istream & is,bool swap,octave::mach_info::float_format)389 octave_perm_matrix::load_binary (std::istream& is, bool swap,
390 octave::mach_info::float_format)
391 {
392 int32_t sz;
393 bool colp;
394 if (! (is.read (reinterpret_cast<char *> (&sz), 4)
395 && is.read (reinterpret_cast<char *> (&colp), 1)))
396 return false;
397
398 MArray<octave_idx_type> m (dim_vector (sz, 1));
399
400 if (! is.read (reinterpret_cast<char *> (m.fortran_vec ()), m.byte_size ()))
401 return false;
402
403 if (swap)
404 {
405 int nel = m.numel ();
406 for (int i = 0; i < nel; i++)
407 switch (sizeof (octave_idx_type))
408 {
409 case 8:
410 swap_bytes<8> (&m(i));
411 break;
412 case 4:
413 swap_bytes<4> (&m(i));
414 break;
415 case 2:
416 swap_bytes<2> (&m(i));
417 break;
418 case 1:
419 default:
420 break;
421 }
422 }
423
424 matrix = PermMatrix (m, colp);
425 return true;
426 }
427
428 void
print_raw(std::ostream & os,bool pr_as_read_syntax) const429 octave_perm_matrix::print_raw (std::ostream& os,
430 bool pr_as_read_syntax) const
431 {
432 return octave_print_internal (os, matrix, pr_as_read_syntax,
433 current_print_indent_level ());
434 }
435
436 mxArray *
as_mxArray(void) const437 octave_perm_matrix::as_mxArray (void) const
438 {
439 return to_dense ().as_mxArray ();
440 }
441
442 bool
print_as_scalar(void) const443 octave_perm_matrix::print_as_scalar (void) const
444 {
445 dim_vector dv = dims ();
446
447 return (dv.all_ones () || dv.any_zero ());
448 }
449
450 void
print(std::ostream & os,bool pr_as_read_syntax)451 octave_perm_matrix::print (std::ostream& os, bool pr_as_read_syntax)
452 {
453 print_raw (os, pr_as_read_syntax);
454 newline (os);
455 }
456
457 int
write(octave::stream & os,int block_size,oct_data_conv::data_type output_type,int skip,octave::mach_info::float_format flt_fmt) const458 octave_perm_matrix::write (octave::stream& os, int block_size,
459 oct_data_conv::data_type output_type, int skip,
460 octave::mach_info::float_format flt_fmt) const
461 {
462 return to_dense ().write (os, block_size, output_type, skip, flt_fmt);
463 }
464
465 void
print_info(std::ostream & os,const std::string & prefix) const466 octave_perm_matrix::print_info (std::ostream& os,
467 const std::string& prefix) const
468 {
469 matrix.print_info (os, prefix);
470 }
471
472 octave_value
to_dense(void) const473 octave_perm_matrix::to_dense (void) const
474 {
475 if (! dense_cache.is_defined ())
476 dense_cache = Matrix (matrix);
477
478 return dense_cache;
479 }
480
481 DEFINE_OV_TYPEID_FUNCTIONS_AND_DATA (octave_perm_matrix,
482 "permutation matrix", "double");
483
484 static octave_base_value *
default_numeric_conversion_function(const octave_base_value & a)485 default_numeric_conversion_function (const octave_base_value& a)
486 {
487 const octave_perm_matrix& v = dynamic_cast<const octave_perm_matrix&> (a);
488
489 return new octave_matrix (v.matrix_value ());
490 }
491
492 octave_base_value::type_conv_info
numeric_conversion_function(void) const493 octave_perm_matrix::numeric_conversion_function (void) const
494 {
495 return octave_base_value::type_conv_info (default_numeric_conversion_function,
496 octave_matrix::static_type_id ());
497 }
498
499 // FIXME: This is duplicated from octave_base_matrix<T>. Could
500 // octave_perm_matrix be derived from octave_base_matrix<T>?
501
502 void
short_disp(std::ostream & os) const503 octave_perm_matrix::short_disp (std::ostream& os) const
504 {
505 if (matrix.isempty ())
506 os << "[]";
507 else if (matrix.ndims () == 2)
508 {
509 // FIXME: should this be configurable?
510 octave_idx_type max_elts = 10;
511 octave_idx_type elts = 0;
512
513 octave_idx_type nel = matrix.numel ();
514
515 octave_idx_type nr = matrix.rows ();
516 octave_idx_type nc = matrix.columns ();
517
518 os << '[';
519
520 for (octave_idx_type i = 0; i < nr; i++)
521 {
522 for (octave_idx_type j = 0; j < nc; j++)
523 {
524 std::ostringstream buf;
525 octave_int<octave_idx_type> tval (matrix(i,j));
526 octave_print_internal (buf, tval);
527 std::string tmp = buf.str ();
528 std::size_t pos = tmp.find_first_not_of (' ');
529 if (pos != std::string::npos)
530 os << tmp.substr (pos);
531 else if (! tmp.empty ())
532 os << tmp[0];
533
534 if (++elts >= max_elts)
535 goto done;
536
537 if (j < nc - 1)
538 os << ", ";
539 }
540
541 if (i < nr - 1 && elts < max_elts)
542 os << "; ";
543 }
544
545 done:
546
547 if (nel <= max_elts)
548 os << ']';
549 }
550 else
551 os << "...";
552 }
553
554 octave_base_value *
try_narrowing_conversion(void)555 octave_perm_matrix::try_narrowing_conversion (void)
556 {
557 octave_base_value *retval = nullptr;
558
559 if (matrix.numel () == 1)
560 retval = new octave_scalar (matrix (0, 0));
561
562 return retval;
563 }
564
565 octave_value
fast_elem_extract(octave_idx_type n) const566 octave_perm_matrix::fast_elem_extract (octave_idx_type n) const
567 {
568 if (n < matrix.numel ())
569 {
570 octave_idx_type nr = matrix.rows ();
571
572 octave_idx_type r = n % nr;
573 octave_idx_type c = n / nr;
574
575 return octave_value (matrix.elem (r, c));
576 }
577 else
578 return octave_value ();
579 }
580