1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 #  include "config.h"
28 #endif
29 
30 #include <ostream>
31 
32 #include "Array-util.h"
33 #include "lo-error.h"
34 #include "mx-base.h"
35 #include "mx-inlines.cc"
36 #include "oct-cmplx.h"
37 
// Single-precision (float) diagonal matrix class.
39 
40 bool
operator ==(const FloatDiagMatrix & a) const41 FloatDiagMatrix::operator == (const FloatDiagMatrix& a) const
42 {
43   if (rows () != a.rows () || cols () != a.cols ())
44     return 0;
45 
46   return mx_inline_equal (length (), data (), a.data ());
47 }
48 
49 bool
operator !=(const FloatDiagMatrix & a) const50 FloatDiagMatrix::operator != (const FloatDiagMatrix& a) const
51 {
52   return !(*this == a);
53 }
54 
55 FloatDiagMatrix&
fill(float val)56 FloatDiagMatrix::fill (float val)
57 {
58   for (octave_idx_type i = 0; i < length (); i++)
59     elem (i, i) = val;
60   return *this;
61 }
62 
63 FloatDiagMatrix&
fill(float val,octave_idx_type beg,octave_idx_type end)64 FloatDiagMatrix::fill (float val, octave_idx_type beg, octave_idx_type end)
65 {
66   if (beg < 0 || end >= length () || end < beg)
67     (*current_liboctave_error_handler) ("range error for fill");
68 
69   for (octave_idx_type i = beg; i <= end; i++)
70     elem (i, i) = val;
71 
72   return *this;
73 }
74 
75 FloatDiagMatrix&
fill(const FloatColumnVector & a)76 FloatDiagMatrix::fill (const FloatColumnVector& a)
77 {
78   octave_idx_type len = length ();
79   if (a.numel () != len)
80     (*current_liboctave_error_handler) ("range error for fill");
81 
82   for (octave_idx_type i = 0; i < len; i++)
83     elem (i, i) = a.elem (i);
84 
85   return *this;
86 }
87 
88 FloatDiagMatrix&
fill(const FloatRowVector & a)89 FloatDiagMatrix::fill (const FloatRowVector& a)
90 {
91   octave_idx_type len = length ();
92   if (a.numel () != len)
93     (*current_liboctave_error_handler) ("range error for fill");
94 
95   for (octave_idx_type i = 0; i < len; i++)
96     elem (i, i) = a.elem (i);
97 
98   return *this;
99 }
100 
101 FloatDiagMatrix&
fill(const FloatColumnVector & a,octave_idx_type beg)102 FloatDiagMatrix::fill (const FloatColumnVector& a, octave_idx_type beg)
103 {
104   octave_idx_type a_len = a.numel ();
105   if (beg < 0 || beg + a_len >= length ())
106     (*current_liboctave_error_handler) ("range error for fill");
107 
108   for (octave_idx_type i = 0; i < a_len; i++)
109     elem (i+beg, i+beg) = a.elem (i);
110 
111   return *this;
112 }
113 
114 FloatDiagMatrix&
fill(const FloatRowVector & a,octave_idx_type beg)115 FloatDiagMatrix::fill (const FloatRowVector& a, octave_idx_type beg)
116 {
117   octave_idx_type a_len = a.numel ();
118   if (beg < 0 || beg + a_len >= length ())
119     (*current_liboctave_error_handler) ("range error for fill");
120 
121   for (octave_idx_type i = 0; i < a_len; i++)
122     elem (i+beg, i+beg) = a.elem (i);
123 
124   return *this;
125 }
126 
127 FloatDiagMatrix
abs(void) const128 FloatDiagMatrix::abs (void) const
129 {
130   return FloatDiagMatrix (extract_diag ().abs (), rows (), columns ());
131 }
132 
133 FloatDiagMatrix
real(const FloatComplexDiagMatrix & a)134 real (const FloatComplexDiagMatrix& a)
135 {
136   return FloatDiagMatrix (real (a.extract_diag ()), a.rows (), a.columns ());
137 }
138 
139 FloatDiagMatrix
imag(const FloatComplexDiagMatrix & a)140 imag (const FloatComplexDiagMatrix& a)
141 {
142   return FloatDiagMatrix (imag (a.extract_diag ()), a.rows (), a.columns ());
143 }
144 
145 FloatMatrix
extract(octave_idx_type r1,octave_idx_type c1,octave_idx_type r2,octave_idx_type c2) const146 FloatDiagMatrix::extract (octave_idx_type r1, octave_idx_type c1,
147                           octave_idx_type r2, octave_idx_type c2) const
148 {
149   if (r1 > r2) { std::swap (r1, r2); }
150   if (c1 > c2) { std::swap (c1, c2); }
151 
152   octave_idx_type new_r = r2 - r1 + 1;
153   octave_idx_type new_c = c2 - c1 + 1;
154 
155   FloatMatrix result (new_r, new_c);
156 
157   for (octave_idx_type j = 0; j < new_c; j++)
158     for (octave_idx_type i = 0; i < new_r; i++)
159       result.elem (i, j) = elem (r1+i, c1+j);
160 
161   return result;
162 }
163 
// Extract row or column i as a dense vector.
165 
166 FloatRowVector
row(octave_idx_type i) const167 FloatDiagMatrix::row (octave_idx_type i) const
168 {
169   octave_idx_type r = rows ();
170   octave_idx_type c = cols ();
171   if (i < 0 || i >= r)
172     (*current_liboctave_error_handler) ("invalid row selection");
173 
174   FloatRowVector retval (c, 0.0);
175   if (r <= c || i < c)
176     retval.elem (i) = elem (i, i);
177 
178   return retval;
179 }
180 
181 FloatRowVector
row(char * s) const182 FloatDiagMatrix::row (char *s) const
183 {
184   if (! s)
185     (*current_liboctave_error_handler) ("invalid row selection");
186 
187   char c = s[0];
188   if (c == 'f' || c == 'F')
189     return row (static_cast<octave_idx_type> (0));
190   else if (c == 'l' || c == 'L')
191     return row (rows () - 1);
192   else
193     (*current_liboctave_error_handler) ("invalid row selection");
194 }
195 
196 FloatColumnVector
column(octave_idx_type i) const197 FloatDiagMatrix::column (octave_idx_type i) const
198 {
199   octave_idx_type r = rows ();
200   octave_idx_type c = cols ();
201   if (i < 0 || i >= c)
202     (*current_liboctave_error_handler) ("invalid column selection");
203 
204   FloatColumnVector retval (r, 0.0);
205   if (r >= c || i < r)
206     retval.elem (i) = elem (i, i);
207 
208   return retval;
209 }
210 
211 FloatColumnVector
column(char * s) const212 FloatDiagMatrix::column (char *s) const
213 {
214   if (! s)
215     (*current_liboctave_error_handler) ("invalid column selection");
216 
217   char c = s[0];
218   if (c == 'f' || c == 'F')
219     return column (static_cast<octave_idx_type> (0));
220   else if (c == 'l' || c == 'L')
221     return column (cols () - 1);
222   else
223     (*current_liboctave_error_handler) ("invalid column selection");
224 }
225 
226 FloatDiagMatrix
inverse(void) const227 FloatDiagMatrix::inverse (void) const
228 {
229   octave_idx_type info;
230   return inverse (info);
231 }
232 
233 FloatDiagMatrix
inverse(octave_idx_type & info) const234 FloatDiagMatrix::inverse (octave_idx_type& info) const
235 {
236   octave_idx_type r = rows ();
237   octave_idx_type c = cols ();
238   octave_idx_type len = length ();
239   if (r != c)
240     (*current_liboctave_error_handler) ("inverse requires square matrix");
241 
242   FloatDiagMatrix retval (r, c);
243 
244   info = 0;
245   for (octave_idx_type i = 0; i < len; i++)
246     {
247       if (elem (i, i) == 0.0)
248         retval.elem (i, i) = octave::numeric_limits<float>::Inf ();
249       else
250         retval.elem (i, i) = 1.0 / elem (i, i);
251     }
252 
253   return retval;
254 }
255 
256 FloatDiagMatrix
pseudo_inverse(float tol) const257 FloatDiagMatrix::pseudo_inverse (float tol) const
258 {
259   octave_idx_type r = rows ();
260   octave_idx_type c = cols ();
261   octave_idx_type len = length ();
262 
263   FloatDiagMatrix retval (c, r);
264 
265   for (octave_idx_type i = 0; i < len; i++)
266     {
267       float val = std::abs (elem (i, i));
268       if (val < tol || val == 0.0f)
269         retval.elem (i, i) = 0.0f;
270       else
271         retval.elem (i, i) = 1.0f / elem (i, i);
272     }
273 
274   return retval;
275 }
276 
// diagonal matrix by diagonal matrix -> diagonal matrix operations
280 
281 FloatDiagMatrix
operator *(const FloatDiagMatrix & a,const FloatDiagMatrix & b)282 operator * (const FloatDiagMatrix& a, const FloatDiagMatrix& b)
283 {
284   octave_idx_type a_nr = a.rows ();
285   octave_idx_type a_nc = a.cols ();
286 
287   octave_idx_type b_nr = b.rows ();
288   octave_idx_type b_nc = b.cols ();
289 
290   if (a_nc != b_nr)
291     octave::err_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
292 
293   FloatDiagMatrix c (a_nr, b_nc);
294 
295   octave_idx_type len = c.length ();
296   octave_idx_type lenm = (len < a_nc ? len : a_nc);
297 
298   for (octave_idx_type i = 0; i < lenm; i++)
299     c.dgxelem (i) = a.dgelem (i) * b.dgelem (i);
300   for (octave_idx_type i = lenm; i < len; i++)
301     c.dgxelem (i) = 0.0f;
302 
303   return c;
304 }
305 
306 // other operations
307 
308 FloatDET
determinant(void) const309 FloatDiagMatrix::determinant (void) const
310 {
311   FloatDET det (1.0f);
312   if (rows () != cols ())
313     (*current_liboctave_error_handler) ("determinant requires square matrix");
314 
315   octave_idx_type len = length ();
316   for (octave_idx_type i = 0; i < len; i++)
317     det *= elem (i, i);
318 
319   return det;
320 }
321 
322 float
rcond(void) const323 FloatDiagMatrix::rcond (void) const
324 {
325   FloatColumnVector av = extract_diag (0).map<float> (fabsf);
326   float amx = av.max ();
327   float amn = av.min ();
328   return amx == 0 ? 0.0f : amn / amx;
329 }
330 
331 std::ostream&
operator <<(std::ostream & os,const FloatDiagMatrix & a)332 operator << (std::ostream& os, const FloatDiagMatrix& a)
333 {
334 //  int field_width = os.precision () + 7;
335 
336   for (octave_idx_type i = 0; i < a.rows (); i++)
337     {
338       for (octave_idx_type j = 0; j < a.cols (); j++)
339         {
340           if (i == j)
341             os << ' ' /* setw (field_width) */ << a.elem (i, i);
342           else
343             os << ' ' /* setw (field_width) */ << 0.0;
344         }
345       os << "\n";
346     }
347   return os;
348 }
349