1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 #  include "config.h"
28 #endif
29 
30 #include <ostream>
31 
32 #include "Array-util.h"
33 #include "lo-error.h"
34 #include "mx-base.h"
35 #include "mx-inlines.cc"
36 #include "oct-cmplx.h"
37 
38 // Diagonal Matrix class.
39 
40 bool
operator ==(const DiagMatrix & a) const41 DiagMatrix::operator == (const DiagMatrix& a) const
42 {
43   if (rows () != a.rows () || cols () != a.cols ())
44     return 0;
45 
46   return mx_inline_equal (length (), data (), a.data ());
47 }
48 
49 bool
operator !=(const DiagMatrix & a) const50 DiagMatrix::operator != (const DiagMatrix& a) const
51 {
52   return !(*this == a);
53 }
54 
55 DiagMatrix&
fill(double val)56 DiagMatrix::fill (double val)
57 {
58   for (octave_idx_type i = 0; i < length (); i++)
59     elem (i, i) = val;
60   return *this;
61 }
62 
63 DiagMatrix&
fill(double val,octave_idx_type beg,octave_idx_type end)64 DiagMatrix::fill (double val, octave_idx_type beg, octave_idx_type end)
65 {
66   if (beg < 0 || end >= length () || end < beg)
67     (*current_liboctave_error_handler) ("range error for fill");
68 
69   for (octave_idx_type i = beg; i <= end; i++)
70     elem (i, i) = val;
71 
72   return *this;
73 }
74 
75 DiagMatrix&
fill(const ColumnVector & a)76 DiagMatrix::fill (const ColumnVector& a)
77 {
78   octave_idx_type len = length ();
79   if (a.numel () != len)
80     (*current_liboctave_error_handler) ("range error for fill");
81 
82   for (octave_idx_type i = 0; i < len; i++)
83     elem (i, i) = a.elem (i);
84 
85   return *this;
86 }
87 
88 DiagMatrix&
fill(const RowVector & a)89 DiagMatrix::fill (const RowVector& a)
90 {
91   octave_idx_type len = length ();
92   if (a.numel () != len)
93     (*current_liboctave_error_handler) ("range error for fill");
94 
95   for (octave_idx_type i = 0; i < len; i++)
96     elem (i, i) = a.elem (i);
97 
98   return *this;
99 }
100 
101 DiagMatrix&
fill(const ColumnVector & a,octave_idx_type beg)102 DiagMatrix::fill (const ColumnVector& a, octave_idx_type beg)
103 {
104   octave_idx_type a_len = a.numel ();
105   if (beg < 0 || beg + a_len >= length ())
106     (*current_liboctave_error_handler) ("range error for fill");
107 
108   for (octave_idx_type i = 0; i < a_len; i++)
109     elem (i+beg, i+beg) = a.elem (i);
110 
111   return *this;
112 }
113 
114 DiagMatrix&
fill(const RowVector & a,octave_idx_type beg)115 DiagMatrix::fill (const RowVector& a, octave_idx_type beg)
116 {
117   octave_idx_type a_len = a.numel ();
118   if (beg < 0 || beg + a_len >= length ())
119     (*current_liboctave_error_handler) ("range error for fill");
120 
121   for (octave_idx_type i = 0; i < a_len; i++)
122     elem (i+beg, i+beg) = a.elem (i);
123 
124   return *this;
125 }
126 
127 DiagMatrix
abs(void) const128 DiagMatrix::abs (void) const
129 {
130   return DiagMatrix (extract_diag ().abs (), rows (), columns ());
131 }
132 
133 DiagMatrix
real(const ComplexDiagMatrix & a)134 real (const ComplexDiagMatrix& a)
135 {
136   return DiagMatrix (real (a.extract_diag ()), a.rows (), a.cols ());
137 }
138 
139 DiagMatrix
imag(const ComplexDiagMatrix & a)140 imag (const ComplexDiagMatrix& a)
141 {
142   return DiagMatrix (imag (a.extract_diag ()), a.rows (), a.cols ());
143 }
144 
145 Matrix
extract(octave_idx_type r1,octave_idx_type c1,octave_idx_type r2,octave_idx_type c2) const146 DiagMatrix::extract (octave_idx_type r1, octave_idx_type c1,
147                      octave_idx_type r2, octave_idx_type c2) const
148 {
149   if (r1 > r2) { std::swap (r1, r2); }
150   if (c1 > c2) { std::swap (c1, c2); }
151 
152   octave_idx_type new_r = r2 - r1 + 1;
153   octave_idx_type new_c = c2 - c1 + 1;
154 
155   Matrix result (new_r, new_c);
156 
157   for (octave_idx_type j = 0; j < new_c; j++)
158     for (octave_idx_type i = 0; i < new_r; i++)
159       result.elem (i, j) = elem (r1+i, c1+j);
160 
161   return result;
162 }
163 
164 // extract row or column i.
165 
166 RowVector
row(octave_idx_type i) const167 DiagMatrix::row (octave_idx_type i) const
168 {
169   octave_idx_type r = rows ();
170   octave_idx_type c = cols ();
171   if (i < 0 || i >= r)
172     (*current_liboctave_error_handler) ("invalid row selection");
173 
174   RowVector retval (c, 0.0);
175   if (r <= c || i < c)
176     retval.elem (i) = elem (i, i);
177 
178   return retval;
179 }
180 
181 RowVector
row(char * s) const182 DiagMatrix::row (char *s) const
183 {
184   if (! s)
185     (*current_liboctave_error_handler) ("invalid row selection");
186 
187   char c = s[0];
188   if (c == 'f' || c == 'F')
189     return row (static_cast<octave_idx_type> (0));
190   else if (c == 'l' || c == 'L')
191     return row (rows () - 1);
192   else
193     (*current_liboctave_error_handler) ("invalid row selection");
194 }
195 
196 ColumnVector
column(octave_idx_type i) const197 DiagMatrix::column (octave_idx_type i) const
198 {
199   octave_idx_type r = rows ();
200   octave_idx_type c = cols ();
201   if (i < 0 || i >= c)
202     (*current_liboctave_error_handler) ("invalid column selection");
203 
204   ColumnVector retval (r, 0.0);
205   if (r >= c || i < r)
206     retval.elem (i) = elem (i, i);
207 
208   return retval;
209 }
210 
211 ColumnVector
column(char * s) const212 DiagMatrix::column (char *s) const
213 {
214   if (! s)
215     (*current_liboctave_error_handler) ("invalid column selection");
216 
217   char c = s[0];
218   if (c == 'f' || c == 'F')
219     return column (static_cast<octave_idx_type> (0));
220   else if (c == 'l' || c == 'L')
221     return column (cols () - 1);
222   else
223     (*current_liboctave_error_handler) ("invalid column selection");
224 }
225 
226 DiagMatrix
inverse(void) const227 DiagMatrix::inverse (void) const
228 {
229   octave_idx_type info;
230   return inverse (info);
231 }
232 
233 DiagMatrix
inverse(octave_idx_type & info) const234 DiagMatrix::inverse (octave_idx_type& info) const
235 {
236   octave_idx_type r = rows ();
237   octave_idx_type c = cols ();
238   if (r != c)
239     (*current_liboctave_error_handler) ("inverse requires square matrix");
240 
241   DiagMatrix retval (r, c);
242 
243   info = 0;
244   octave_idx_type len = r;        // alias for readability
245   octave_idx_type z_count  = 0;   // zeros
246   octave_idx_type nz_count = 0;   // non-zeros
247   for (octave_idx_type i = 0; i < len; i++)
248     {
249       if (xelem (i, i) == 0.0)
250         {
251           z_count++;
252           if (nz_count > 0)
253             break;
254         }
255       else
256         {
257           nz_count++;
258           if (z_count > 0)
259             break;
260           retval.elem (i, i) = 1.0 / xelem (i, i);
261         }
262     }
263   if (nz_count == 0)
264     {
265       (*current_liboctave_error_handler)
266         ("inverse of the null matrix not defined");
267     }
268   else if (z_count > 0)
269     {
270       info = -1;
271       element_type *data = retval.fortran_vec ();
272       std::fill (data, data + len, octave::numeric_limits<double>::Inf ());
273     }
274 
275   return retval;
276 }
277 
278 DiagMatrix
pseudo_inverse(double tol) const279 DiagMatrix::pseudo_inverse (double tol) const
280 {
281   octave_idx_type r = rows ();
282   octave_idx_type c = cols ();
283   octave_idx_type len = length ();
284 
285   DiagMatrix retval (c, r);
286 
287   for (octave_idx_type i = 0; i < len; i++)
288     {
289       double val = std::abs (elem (i, i));
290       if (val < tol || val == 0.0)
291         retval.elem (i, i) = 0.0;
292       else
293         retval.elem (i, i) = 1.0 / elem (i, i);
294     }
295 
296   return retval;
297 }
298 
// diagonal matrix by diagonal matrix -> diagonal matrix operations
302 
303 DiagMatrix
operator *(const DiagMatrix & a,const DiagMatrix & b)304 operator * (const DiagMatrix& a, const DiagMatrix& b)
305 {
306   octave_idx_type a_nr = a.rows ();
307   octave_idx_type a_nc = a.cols ();
308 
309   octave_idx_type b_nr = b.rows ();
310   octave_idx_type b_nc = b.cols ();
311 
312   if (a_nc != b_nr)
313     octave::err_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
314 
315   DiagMatrix c (a_nr, b_nc);
316 
317   octave_idx_type len = c.length ();
318   octave_idx_type lenm = (len < a_nc ? len : a_nc);
319 
320   for (octave_idx_type i = 0; i < lenm; i++)
321     c.dgxelem (i) = a.dgelem (i) * b.dgelem (i);
322   for (octave_idx_type i = lenm; i < len; i++)
323     c.dgxelem (i) = 0.0;
324 
325   return c;
326 }
327 
328 // other operations
329 
330 DET
determinant(void) const331 DiagMatrix::determinant (void) const
332 {
333   DET det (1.0);
334   if (rows () != cols ())
335     (*current_liboctave_error_handler) ("determinant requires square matrix");
336 
337   octave_idx_type len = length ();
338   for (octave_idx_type i = 0; i < len; i++)
339     det *= elem (i, i);
340 
341   return det;
342 }
343 
344 double
rcond(void) const345 DiagMatrix::rcond (void) const
346 {
347   ColumnVector av = extract_diag (0).map<double> (fabs);
348   double amx = av.max ();
349   double amn = av.min ();
350   return amx == 0 ? 0.0 : amn / amx;
351 }
352 
353 std::ostream&
operator <<(std::ostream & os,const DiagMatrix & a)354 operator << (std::ostream& os, const DiagMatrix& a)
355 {
356 //  int field_width = os.precision () + 7;
357 
358   for (octave_idx_type i = 0; i < a.rows (); i++)
359     {
360       for (octave_idx_type j = 0; j < a.cols (); j++)
361         {
362           if (i == j)
363             os << ' ' /* setw (field_width) */ << a.elem (i, i);
364           else
365             os << ' ' /* setw (field_width) */ << 0.0;
366         }
367       os << "\n";
368     }
369   return os;
370 }
371