1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29
#include <algorithm>
#include <istream>
#include <ostream>
32
33 #include "Array-util.h"
34 #include "lo-blas-proto.h"
35 #include "lo-error.h"
36 #include "mx-base.h"
37 #include "mx-inlines.cc"
38 #include "oct-cmplx.h"
39
40 // Complex Column Vector class
41
ComplexColumnVector(const ColumnVector & a)42 ComplexColumnVector::ComplexColumnVector (const ColumnVector& a)
43 : MArray<Complex> (a)
44 { }
45
46 bool
operator ==(const ComplexColumnVector & a) const47 ComplexColumnVector::operator == (const ComplexColumnVector& a) const
48 {
49 octave_idx_type len = numel ();
50 if (len != a.numel ())
51 return 0;
52 return mx_inline_equal (len, data (), a.data ());
53 }
54
55 bool
operator !=(const ComplexColumnVector & a) const56 ComplexColumnVector::operator != (const ComplexColumnVector& a) const
57 {
58 return !(*this == a);
59 }
60
61 // destructive insert/delete/reorder operations
62
63 ComplexColumnVector&
insert(const ColumnVector & a,octave_idx_type r)64 ComplexColumnVector::insert (const ColumnVector& a, octave_idx_type r)
65 {
66 octave_idx_type a_len = a.numel ();
67
68 if (r < 0 || r + a_len > numel ())
69 (*current_liboctave_error_handler) ("range error for insert");
70
71 if (a_len > 0)
72 {
73 make_unique ();
74
75 for (octave_idx_type i = 0; i < a_len; i++)
76 xelem (r+i) = a.elem (i);
77 }
78
79 return *this;
80 }
81
82 ComplexColumnVector&
insert(const ComplexColumnVector & a,octave_idx_type r)83 ComplexColumnVector::insert (const ComplexColumnVector& a, octave_idx_type r)
84 {
85 octave_idx_type a_len = a.numel ();
86
87 if (r < 0 || r + a_len > numel ())
88 (*current_liboctave_error_handler) ("range error for insert");
89
90 if (a_len > 0)
91 {
92 make_unique ();
93
94 for (octave_idx_type i = 0; i < a_len; i++)
95 xelem (r+i) = a.elem (i);
96 }
97
98 return *this;
99 }
100
101 ComplexColumnVector&
fill(double val)102 ComplexColumnVector::fill (double val)
103 {
104 octave_idx_type len = numel ();
105
106 if (len > 0)
107 {
108 make_unique ();
109
110 for (octave_idx_type i = 0; i < len; i++)
111 xelem (i) = val;
112 }
113
114 return *this;
115 }
116
117 ComplexColumnVector&
fill(const Complex & val)118 ComplexColumnVector::fill (const Complex& val)
119 {
120 octave_idx_type len = numel ();
121
122 if (len > 0)
123 {
124 make_unique ();
125
126 for (octave_idx_type i = 0; i < len; i++)
127 xelem (i) = val;
128 }
129
130 return *this;
131 }
132
133 ComplexColumnVector&
fill(double val,octave_idx_type r1,octave_idx_type r2)134 ComplexColumnVector::fill (double val, octave_idx_type r1, octave_idx_type r2)
135 {
136 octave_idx_type len = numel ();
137
138 if (r1 < 0 || r2 < 0 || r1 >= len || r2 >= len)
139 (*current_liboctave_error_handler) ("range error for fill");
140
141 if (r1 > r2) { std::swap (r1, r2); }
142
143 if (r2 >= r1)
144 {
145 make_unique ();
146
147 for (octave_idx_type i = r1; i <= r2; i++)
148 xelem (i) = val;
149 }
150
151 return *this;
152 }
153
154 ComplexColumnVector&
fill(const Complex & val,octave_idx_type r1,octave_idx_type r2)155 ComplexColumnVector::fill (const Complex& val,
156 octave_idx_type r1, octave_idx_type r2)
157 {
158 octave_idx_type len = numel ();
159
160 if (r1 < 0 || r2 < 0 || r1 >= len || r2 >= len)
161 (*current_liboctave_error_handler) ("range error for fill");
162
163 if (r1 > r2) { std::swap (r1, r2); }
164
165 if (r2 >= r1)
166 {
167 make_unique ();
168
169 for (octave_idx_type i = r1; i <= r2; i++)
170 xelem (i) = val;
171 }
172
173 return *this;
174 }
175
176 ComplexColumnVector
stack(const ColumnVector & a) const177 ComplexColumnVector::stack (const ColumnVector& a) const
178 {
179 octave_idx_type len = numel ();
180 octave_idx_type nr_insert = len;
181 ComplexColumnVector retval (len + a.numel ());
182 retval.insert (*this, 0);
183 retval.insert (a, nr_insert);
184 return retval;
185 }
186
187 ComplexColumnVector
stack(const ComplexColumnVector & a) const188 ComplexColumnVector::stack (const ComplexColumnVector& a) const
189 {
190 octave_idx_type len = numel ();
191 octave_idx_type nr_insert = len;
192 ComplexColumnVector retval (len + a.numel ());
193 retval.insert (*this, 0);
194 retval.insert (a, nr_insert);
195 return retval;
196 }
197
198 ComplexRowVector
hermitian(void) const199 ComplexColumnVector::hermitian (void) const
200 {
201 return MArray<Complex>::hermitian (std::conj);
202 }
203
204 ComplexRowVector
transpose(void) const205 ComplexColumnVector::transpose (void) const
206 {
207 return MArray<Complex>::transpose ();
208 }
209
210 ColumnVector
abs(void) const211 ComplexColumnVector::abs (void) const
212 {
213 return do_mx_unary_map<double, Complex, std::abs> (*this);
214 }
215
216 ComplexColumnVector
conj(const ComplexColumnVector & a)217 conj (const ComplexColumnVector& a)
218 {
219 return do_mx_unary_map<Complex, Complex, std::conj<double>> (a);
220 }
221
// resize () is the destructive (in-place) counterpart of extract ().
223
224 ComplexColumnVector
extract(octave_idx_type r1,octave_idx_type r2) const225 ComplexColumnVector::extract (octave_idx_type r1, octave_idx_type r2) const
226 {
227 if (r1 > r2) { std::swap (r1, r2); }
228
229 octave_idx_type new_r = r2 - r1 + 1;
230
231 ComplexColumnVector result (new_r);
232
233 for (octave_idx_type i = 0; i < new_r; i++)
234 result.elem (i) = elem (r1+i);
235
236 return result;
237 }
238
239 ComplexColumnVector
extract_n(octave_idx_type r1,octave_idx_type n) const240 ComplexColumnVector::extract_n (octave_idx_type r1, octave_idx_type n) const
241 {
242 ComplexColumnVector result (n);
243
244 for (octave_idx_type i = 0; i < n; i++)
245 result.elem (i) = elem (r1+i);
246
247 return result;
248 }
249
250 // column vector by column vector -> column vector operations
251
252 ComplexColumnVector&
operator +=(const ColumnVector & a)253 ComplexColumnVector::operator += (const ColumnVector& a)
254 {
255 octave_idx_type len = numel ();
256
257 octave_idx_type a_len = a.numel ();
258
259 if (len != a_len)
260 octave::err_nonconformant ("operator +=", len, a_len);
261
262 if (len == 0)
263 return *this;
264
265 Complex *d = fortran_vec (); // Ensures only one reference to my privates!
266
267 mx_inline_add2 (len, d, a.data ());
268 return *this;
269 }
270
271 ComplexColumnVector&
operator -=(const ColumnVector & a)272 ComplexColumnVector::operator -= (const ColumnVector& a)
273 {
274 octave_idx_type len = numel ();
275
276 octave_idx_type a_len = a.numel ();
277
278 if (len != a_len)
279 octave::err_nonconformant ("operator -=", len, a_len);
280
281 if (len == 0)
282 return *this;
283
284 Complex *d = fortran_vec (); // Ensures only one reference to my privates!
285
286 mx_inline_sub2 (len, d, a.data ());
287 return *this;
288 }
289
290 // matrix by column vector -> column vector operations
291
292 ComplexColumnVector
operator *(const ComplexMatrix & m,const ColumnVector & a)293 operator * (const ComplexMatrix& m, const ColumnVector& a)
294 {
295 ComplexColumnVector tmp (a);
296 return m * tmp;
297 }
298
299 ComplexColumnVector
operator *(const ComplexMatrix & m,const ComplexColumnVector & a)300 operator * (const ComplexMatrix& m, const ComplexColumnVector& a)
301 {
302 ComplexColumnVector retval;
303
304 F77_INT nr = octave::to_f77_int (m.rows ());
305 F77_INT nc = octave::to_f77_int (m.cols ());
306
307 F77_INT a_len = octave::to_f77_int (a.numel ());
308
309 if (nc != a_len)
310 octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
311
312 retval.clear (nr);
313
314 if (nr != 0)
315 {
316 if (nc == 0)
317 retval.fill (0.0);
318 else
319 {
320 Complex *y = retval.fortran_vec ();
321
322 F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 ("N", 1),
323 nr, nc, 1.0,
324 F77_CONST_DBLE_CMPLX_ARG (m.data ()), nr,
325 F77_CONST_DBLE_CMPLX_ARG (a.data ()), 1, 0.0,
326 F77_DBLE_CMPLX_ARG (y), 1
327 F77_CHAR_ARG_LEN (1)));
328 }
329 }
330
331 return retval;
332 }
333
334 // matrix by column vector -> column vector operations
335
336 ComplexColumnVector
operator *(const Matrix & m,const ComplexColumnVector & a)337 operator * (const Matrix& m, const ComplexColumnVector& a)
338 {
339 ComplexMatrix tmp (m);
340 return tmp * a;
341 }
342
343 // diagonal matrix by column vector -> column vector operations
344
345 ComplexColumnVector
operator *(const DiagMatrix & m,const ComplexColumnVector & a)346 operator * (const DiagMatrix& m, const ComplexColumnVector& a)
347 {
348 F77_INT nr = octave::to_f77_int (m.rows ());
349 F77_INT nc = octave::to_f77_int (m.cols ());
350
351 F77_INT a_len = octave::to_f77_int (a.numel ());
352
353 if (nc != a_len)
354 octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
355
356 if (nc == 0 || nr == 0)
357 return ComplexColumnVector (0);
358
359 ComplexColumnVector result (nr);
360
361 for (octave_idx_type i = 0; i < a_len; i++)
362 result.elem (i) = a.elem (i) * m.elem (i, i);
363
364 for (octave_idx_type i = a_len; i < nr; i++)
365 result.elem (i) = 0.0;
366
367 return result;
368 }
369
370 ComplexColumnVector
operator *(const ComplexDiagMatrix & m,const ColumnVector & a)371 operator * (const ComplexDiagMatrix& m, const ColumnVector& a)
372 {
373 F77_INT nr = octave::to_f77_int (m.rows ());
374 F77_INT nc = octave::to_f77_int (m.cols ());
375
376 F77_INT a_len = octave::to_f77_int (a.numel ());
377
378 if (nc != a_len)
379 octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
380
381 if (nc == 0 || nr == 0)
382 return ComplexColumnVector (0);
383
384 ComplexColumnVector result (nr);
385
386 for (octave_idx_type i = 0; i < a_len; i++)
387 result.elem (i) = a.elem (i) * m.elem (i, i);
388
389 for (octave_idx_type i = a_len; i < nr; i++)
390 result.elem (i) = 0.0;
391
392 return result;
393 }
394
395 ComplexColumnVector
operator *(const ComplexDiagMatrix & m,const ComplexColumnVector & a)396 operator * (const ComplexDiagMatrix& m, const ComplexColumnVector& a)
397 {
398 F77_INT nr = octave::to_f77_int (m.rows ());
399 F77_INT nc = octave::to_f77_int (m.cols ());
400
401 F77_INT a_len = octave::to_f77_int (a.numel ());
402
403 if (nc != a_len)
404 octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
405
406 if (nc == 0 || nr == 0)
407 return ComplexColumnVector (0);
408
409 ComplexColumnVector result (nr);
410
411 for (octave_idx_type i = 0; i < a_len; i++)
412 result.elem (i) = a.elem (i) * m.elem (i, i);
413
414 for (octave_idx_type i = a_len; i < nr; i++)
415 result.elem (i) = 0.0;
416
417 return result;
418 }
419
420 // other operations
421
422 Complex
min(void) const423 ComplexColumnVector::min (void) const
424 {
425 octave_idx_type len = numel ();
426 if (len == 0)
427 return 0.0;
428
429 Complex res = elem (0);
430 double absres = std::abs (res);
431
432 for (octave_idx_type i = 1; i < len; i++)
433 if (std::abs (elem (i)) < absres)
434 {
435 res = elem (i);
436 absres = std::abs (res);
437 }
438
439 return res;
440 }
441
442 Complex
max(void) const443 ComplexColumnVector::max (void) const
444 {
445 octave_idx_type len = numel ();
446 if (len == 0)
447 return 0.0;
448
449 Complex res = elem (0);
450 double absres = std::abs (res);
451
452 for (octave_idx_type i = 1; i < len; i++)
453 if (std::abs (elem (i)) > absres)
454 {
455 res = elem (i);
456 absres = std::abs (res);
457 }
458
459 return res;
460 }
461
462 // i/o
463
464 std::ostream&
operator <<(std::ostream & os,const ComplexColumnVector & a)465 operator << (std::ostream& os, const ComplexColumnVector& a)
466 {
467 // int field_width = os.precision () + 7;
468 for (octave_idx_type i = 0; i < a.numel (); i++)
469 os << /* setw (field_width) << */ a.elem (i) << "\n";
470 return os;
471 }
472
473 std::istream&
operator >>(std::istream & is,ComplexColumnVector & a)474 operator >> (std::istream& is, ComplexColumnVector& a)
475 {
476 octave_idx_type len = a.numel ();
477
478 if (len > 0)
479 {
480 double tmp;
481 for (octave_idx_type i = 0; i < len; i++)
482 {
483 is >> tmp;
484 if (is)
485 a.elem (i) = tmp;
486 else
487 break;
488 }
489 }
490 return is;
491 }
492