1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29
30 #include <ostream>
31
32 #include "Array-util.h"
33 #include "lo-error.h"
34 #include "mx-base.h"
35 #include "mx-inlines.cc"
36 #include "oct-cmplx.h"
37
38 // Diagonal Matrix class.
39
40 bool
operator ==(const FloatDiagMatrix & a) const41 FloatDiagMatrix::operator == (const FloatDiagMatrix& a) const
42 {
43 if (rows () != a.rows () || cols () != a.cols ())
44 return 0;
45
46 return mx_inline_equal (length (), data (), a.data ());
47 }
48
49 bool
operator !=(const FloatDiagMatrix & a) const50 FloatDiagMatrix::operator != (const FloatDiagMatrix& a) const
51 {
52 return !(*this == a);
53 }
54
55 FloatDiagMatrix&
fill(float val)56 FloatDiagMatrix::fill (float val)
57 {
58 for (octave_idx_type i = 0; i < length (); i++)
59 elem (i, i) = val;
60 return *this;
61 }
62
63 FloatDiagMatrix&
fill(float val,octave_idx_type beg,octave_idx_type end)64 FloatDiagMatrix::fill (float val, octave_idx_type beg, octave_idx_type end)
65 {
66 if (beg < 0 || end >= length () || end < beg)
67 (*current_liboctave_error_handler) ("range error for fill");
68
69 for (octave_idx_type i = beg; i <= end; i++)
70 elem (i, i) = val;
71
72 return *this;
73 }
74
75 FloatDiagMatrix&
fill(const FloatColumnVector & a)76 FloatDiagMatrix::fill (const FloatColumnVector& a)
77 {
78 octave_idx_type len = length ();
79 if (a.numel () != len)
80 (*current_liboctave_error_handler) ("range error for fill");
81
82 for (octave_idx_type i = 0; i < len; i++)
83 elem (i, i) = a.elem (i);
84
85 return *this;
86 }
87
88 FloatDiagMatrix&
fill(const FloatRowVector & a)89 FloatDiagMatrix::fill (const FloatRowVector& a)
90 {
91 octave_idx_type len = length ();
92 if (a.numel () != len)
93 (*current_liboctave_error_handler) ("range error for fill");
94
95 for (octave_idx_type i = 0; i < len; i++)
96 elem (i, i) = a.elem (i);
97
98 return *this;
99 }
100
101 FloatDiagMatrix&
fill(const FloatColumnVector & a,octave_idx_type beg)102 FloatDiagMatrix::fill (const FloatColumnVector& a, octave_idx_type beg)
103 {
104 octave_idx_type a_len = a.numel ();
105 if (beg < 0 || beg + a_len >= length ())
106 (*current_liboctave_error_handler) ("range error for fill");
107
108 for (octave_idx_type i = 0; i < a_len; i++)
109 elem (i+beg, i+beg) = a.elem (i);
110
111 return *this;
112 }
113
114 FloatDiagMatrix&
fill(const FloatRowVector & a,octave_idx_type beg)115 FloatDiagMatrix::fill (const FloatRowVector& a, octave_idx_type beg)
116 {
117 octave_idx_type a_len = a.numel ();
118 if (beg < 0 || beg + a_len >= length ())
119 (*current_liboctave_error_handler) ("range error for fill");
120
121 for (octave_idx_type i = 0; i < a_len; i++)
122 elem (i+beg, i+beg) = a.elem (i);
123
124 return *this;
125 }
126
127 FloatDiagMatrix
abs(void) const128 FloatDiagMatrix::abs (void) const
129 {
130 return FloatDiagMatrix (extract_diag ().abs (), rows (), columns ());
131 }
132
133 FloatDiagMatrix
real(const FloatComplexDiagMatrix & a)134 real (const FloatComplexDiagMatrix& a)
135 {
136 return FloatDiagMatrix (real (a.extract_diag ()), a.rows (), a.columns ());
137 }
138
139 FloatDiagMatrix
imag(const FloatComplexDiagMatrix & a)140 imag (const FloatComplexDiagMatrix& a)
141 {
142 return FloatDiagMatrix (imag (a.extract_diag ()), a.rows (), a.columns ());
143 }
144
145 FloatMatrix
extract(octave_idx_type r1,octave_idx_type c1,octave_idx_type r2,octave_idx_type c2) const146 FloatDiagMatrix::extract (octave_idx_type r1, octave_idx_type c1,
147 octave_idx_type r2, octave_idx_type c2) const
148 {
149 if (r1 > r2) { std::swap (r1, r2); }
150 if (c1 > c2) { std::swap (c1, c2); }
151
152 octave_idx_type new_r = r2 - r1 + 1;
153 octave_idx_type new_c = c2 - c1 + 1;
154
155 FloatMatrix result (new_r, new_c);
156
157 for (octave_idx_type j = 0; j < new_c; j++)
158 for (octave_idx_type i = 0; i < new_r; i++)
159 result.elem (i, j) = elem (r1+i, c1+j);
160
161 return result;
162 }
163
164 // extract row or column i.
165
166 FloatRowVector
row(octave_idx_type i) const167 FloatDiagMatrix::row (octave_idx_type i) const
168 {
169 octave_idx_type r = rows ();
170 octave_idx_type c = cols ();
171 if (i < 0 || i >= r)
172 (*current_liboctave_error_handler) ("invalid row selection");
173
174 FloatRowVector retval (c, 0.0);
175 if (r <= c || i < c)
176 retval.elem (i) = elem (i, i);
177
178 return retval;
179 }
180
181 FloatRowVector
row(char * s) const182 FloatDiagMatrix::row (char *s) const
183 {
184 if (! s)
185 (*current_liboctave_error_handler) ("invalid row selection");
186
187 char c = s[0];
188 if (c == 'f' || c == 'F')
189 return row (static_cast<octave_idx_type> (0));
190 else if (c == 'l' || c == 'L')
191 return row (rows () - 1);
192 else
193 (*current_liboctave_error_handler) ("invalid row selection");
194 }
195
196 FloatColumnVector
column(octave_idx_type i) const197 FloatDiagMatrix::column (octave_idx_type i) const
198 {
199 octave_idx_type r = rows ();
200 octave_idx_type c = cols ();
201 if (i < 0 || i >= c)
202 (*current_liboctave_error_handler) ("invalid column selection");
203
204 FloatColumnVector retval (r, 0.0);
205 if (r >= c || i < r)
206 retval.elem (i) = elem (i, i);
207
208 return retval;
209 }
210
211 FloatColumnVector
column(char * s) const212 FloatDiagMatrix::column (char *s) const
213 {
214 if (! s)
215 (*current_liboctave_error_handler) ("invalid column selection");
216
217 char c = s[0];
218 if (c == 'f' || c == 'F')
219 return column (static_cast<octave_idx_type> (0));
220 else if (c == 'l' || c == 'L')
221 return column (cols () - 1);
222 else
223 (*current_liboctave_error_handler) ("invalid column selection");
224 }
225
226 FloatDiagMatrix
inverse(void) const227 FloatDiagMatrix::inverse (void) const
228 {
229 octave_idx_type info;
230 return inverse (info);
231 }
232
233 FloatDiagMatrix
inverse(octave_idx_type & info) const234 FloatDiagMatrix::inverse (octave_idx_type& info) const
235 {
236 octave_idx_type r = rows ();
237 octave_idx_type c = cols ();
238 octave_idx_type len = length ();
239 if (r != c)
240 (*current_liboctave_error_handler) ("inverse requires square matrix");
241
242 FloatDiagMatrix retval (r, c);
243
244 info = 0;
245 for (octave_idx_type i = 0; i < len; i++)
246 {
247 if (elem (i, i) == 0.0)
248 retval.elem (i, i) = octave::numeric_limits<float>::Inf ();
249 else
250 retval.elem (i, i) = 1.0 / elem (i, i);
251 }
252
253 return retval;
254 }
255
256 FloatDiagMatrix
pseudo_inverse(float tol) const257 FloatDiagMatrix::pseudo_inverse (float tol) const
258 {
259 octave_idx_type r = rows ();
260 octave_idx_type c = cols ();
261 octave_idx_type len = length ();
262
263 FloatDiagMatrix retval (c, r);
264
265 for (octave_idx_type i = 0; i < len; i++)
266 {
267 float val = std::abs (elem (i, i));
268 if (val < tol || val == 0.0f)
269 retval.elem (i, i) = 0.0f;
270 else
271 retval.elem (i, i) = 1.0f / elem (i, i);
272 }
273
274 return retval;
275 }
276
// diagonal matrix by diagonal matrix -> diagonal matrix operations
280
281 FloatDiagMatrix
operator *(const FloatDiagMatrix & a,const FloatDiagMatrix & b)282 operator * (const FloatDiagMatrix& a, const FloatDiagMatrix& b)
283 {
284 octave_idx_type a_nr = a.rows ();
285 octave_idx_type a_nc = a.cols ();
286
287 octave_idx_type b_nr = b.rows ();
288 octave_idx_type b_nc = b.cols ();
289
290 if (a_nc != b_nr)
291 octave::err_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
292
293 FloatDiagMatrix c (a_nr, b_nc);
294
295 octave_idx_type len = c.length ();
296 octave_idx_type lenm = (len < a_nc ? len : a_nc);
297
298 for (octave_idx_type i = 0; i < lenm; i++)
299 c.dgxelem (i) = a.dgelem (i) * b.dgelem (i);
300 for (octave_idx_type i = lenm; i < len; i++)
301 c.dgxelem (i) = 0.0f;
302
303 return c;
304 }
305
306 // other operations
307
308 FloatDET
determinant(void) const309 FloatDiagMatrix::determinant (void) const
310 {
311 FloatDET det (1.0f);
312 if (rows () != cols ())
313 (*current_liboctave_error_handler) ("determinant requires square matrix");
314
315 octave_idx_type len = length ();
316 for (octave_idx_type i = 0; i < len; i++)
317 det *= elem (i, i);
318
319 return det;
320 }
321
322 float
rcond(void) const323 FloatDiagMatrix::rcond (void) const
324 {
325 FloatColumnVector av = extract_diag (0).map<float> (fabsf);
326 float amx = av.max ();
327 float amn = av.min ();
328 return amx == 0 ? 0.0f : amn / amx;
329 }
330
331 std::ostream&
operator <<(std::ostream & os,const FloatDiagMatrix & a)332 operator << (std::ostream& os, const FloatDiagMatrix& a)
333 {
334 // int field_width = os.precision () + 7;
335
336 for (octave_idx_type i = 0; i < a.rows (); i++)
337 {
338 for (octave_idx_type j = 0; j < a.cols (); j++)
339 {
340 if (i == j)
341 os << ' ' /* setw (field_width) */ << a.elem (i, i);
342 else
343 os << ' ' /* setw (field_width) */ << 0.0;
344 }
345 os << "\n";
346 }
347 return os;
348 }
349