1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 #  include "config.h"
28 #endif
29 
30 #include <istream>
31 #include <ostream>
32 
33 #include "Array-util.h"
34 #include "lo-blas-proto.h"
35 #include "lo-error.h"
36 #include "mx-base.h"
37 #include "mx-inlines.cc"
38 #include "oct-cmplx.h"
39 
40 // Complex Column Vector class
41 
ComplexColumnVector(const ColumnVector & a)42 ComplexColumnVector::ComplexColumnVector (const ColumnVector& a)
43   : MArray<Complex> (a)
44 { }
45 
46 bool
operator ==(const ComplexColumnVector & a) const47 ComplexColumnVector::operator == (const ComplexColumnVector& a) const
48 {
49   octave_idx_type len = numel ();
50   if (len != a.numel ())
51     return 0;
52   return mx_inline_equal (len, data (), a.data ());
53 }
54 
55 bool
operator !=(const ComplexColumnVector & a) const56 ComplexColumnVector::operator != (const ComplexColumnVector& a) const
57 {
58   return !(*this == a);
59 }
60 
61 // destructive insert/delete/reorder operations
62 
63 ComplexColumnVector&
insert(const ColumnVector & a,octave_idx_type r)64 ComplexColumnVector::insert (const ColumnVector& a, octave_idx_type r)
65 {
66   octave_idx_type a_len = a.numel ();
67 
68   if (r < 0 || r + a_len > numel ())
69     (*current_liboctave_error_handler) ("range error for insert");
70 
71   if (a_len > 0)
72     {
73       make_unique ();
74 
75       for (octave_idx_type i = 0; i < a_len; i++)
76         xelem (r+i) = a.elem (i);
77     }
78 
79   return *this;
80 }
81 
82 ComplexColumnVector&
insert(const ComplexColumnVector & a,octave_idx_type r)83 ComplexColumnVector::insert (const ComplexColumnVector& a, octave_idx_type r)
84 {
85   octave_idx_type a_len = a.numel ();
86 
87   if (r < 0 || r + a_len > numel ())
88     (*current_liboctave_error_handler) ("range error for insert");
89 
90   if (a_len > 0)
91     {
92       make_unique ();
93 
94       for (octave_idx_type i = 0; i < a_len; i++)
95         xelem (r+i) = a.elem (i);
96     }
97 
98   return *this;
99 }
100 
101 ComplexColumnVector&
fill(double val)102 ComplexColumnVector::fill (double val)
103 {
104   octave_idx_type len = numel ();
105 
106   if (len > 0)
107     {
108       make_unique ();
109 
110       for (octave_idx_type i = 0; i < len; i++)
111         xelem (i) = val;
112     }
113 
114   return *this;
115 }
116 
117 ComplexColumnVector&
fill(const Complex & val)118 ComplexColumnVector::fill (const Complex& val)
119 {
120   octave_idx_type len = numel ();
121 
122   if (len > 0)
123     {
124       make_unique ();
125 
126       for (octave_idx_type i = 0; i < len; i++)
127         xelem (i) = val;
128     }
129 
130   return *this;
131 }
132 
133 ComplexColumnVector&
fill(double val,octave_idx_type r1,octave_idx_type r2)134 ComplexColumnVector::fill (double val, octave_idx_type r1, octave_idx_type r2)
135 {
136   octave_idx_type len = numel ();
137 
138   if (r1 < 0 || r2 < 0 || r1 >= len || r2 >= len)
139     (*current_liboctave_error_handler) ("range error for fill");
140 
141   if (r1 > r2) { std::swap (r1, r2); }
142 
143   if (r2 >= r1)
144     {
145       make_unique ();
146 
147       for (octave_idx_type i = r1; i <= r2; i++)
148         xelem (i) = val;
149     }
150 
151   return *this;
152 }
153 
154 ComplexColumnVector&
fill(const Complex & val,octave_idx_type r1,octave_idx_type r2)155 ComplexColumnVector::fill (const Complex& val,
156                            octave_idx_type r1, octave_idx_type r2)
157 {
158   octave_idx_type len = numel ();
159 
160   if (r1 < 0 || r2 < 0 || r1 >= len || r2 >= len)
161     (*current_liboctave_error_handler) ("range error for fill");
162 
163   if (r1 > r2) { std::swap (r1, r2); }
164 
165   if (r2 >= r1)
166     {
167       make_unique ();
168 
169       for (octave_idx_type i = r1; i <= r2; i++)
170         xelem (i) = val;
171     }
172 
173   return *this;
174 }
175 
176 ComplexColumnVector
stack(const ColumnVector & a) const177 ComplexColumnVector::stack (const ColumnVector& a) const
178 {
179   octave_idx_type len = numel ();
180   octave_idx_type nr_insert = len;
181   ComplexColumnVector retval (len + a.numel ());
182   retval.insert (*this, 0);
183   retval.insert (a, nr_insert);
184   return retval;
185 }
186 
187 ComplexColumnVector
stack(const ComplexColumnVector & a) const188 ComplexColumnVector::stack (const ComplexColumnVector& a) const
189 {
190   octave_idx_type len = numel ();
191   octave_idx_type nr_insert = len;
192   ComplexColumnVector retval (len + a.numel ());
193   retval.insert (*this, 0);
194   retval.insert (a, nr_insert);
195   return retval;
196 }
197 
198 ComplexRowVector
hermitian(void) const199 ComplexColumnVector::hermitian (void) const
200 {
201   return MArray<Complex>::hermitian (std::conj);
202 }
203 
204 ComplexRowVector
transpose(void) const205 ComplexColumnVector::transpose (void) const
206 {
207   return MArray<Complex>::transpose ();
208 }
209 
210 ColumnVector
abs(void) const211 ComplexColumnVector::abs (void) const
212 {
213   return do_mx_unary_map<double, Complex, std::abs> (*this);
214 }
215 
216 ComplexColumnVector
conj(const ComplexColumnVector & a)217 conj (const ComplexColumnVector& a)
218 {
219   return do_mx_unary_map<Complex, Complex, std::conj<double>> (a);
220 }
221 
// extract returns a copy of a row range; resize is its destructive (in-place) counterpart
223 
224 ComplexColumnVector
extract(octave_idx_type r1,octave_idx_type r2) const225 ComplexColumnVector::extract (octave_idx_type r1, octave_idx_type r2) const
226 {
227   if (r1 > r2) { std::swap (r1, r2); }
228 
229   octave_idx_type new_r = r2 - r1 + 1;
230 
231   ComplexColumnVector result (new_r);
232 
233   for (octave_idx_type i = 0; i < new_r; i++)
234     result.elem (i) = elem (r1+i);
235 
236   return result;
237 }
238 
239 ComplexColumnVector
extract_n(octave_idx_type r1,octave_idx_type n) const240 ComplexColumnVector::extract_n (octave_idx_type r1, octave_idx_type n) const
241 {
242   ComplexColumnVector result (n);
243 
244   for (octave_idx_type i = 0; i < n; i++)
245     result.elem (i) = elem (r1+i);
246 
247   return result;
248 }
249 
250 // column vector by column vector -> column vector operations
251 
252 ComplexColumnVector&
operator +=(const ColumnVector & a)253 ComplexColumnVector::operator += (const ColumnVector& a)
254 {
255   octave_idx_type len = numel ();
256 
257   octave_idx_type a_len = a.numel ();
258 
259   if (len != a_len)
260     octave::err_nonconformant ("operator +=", len, a_len);
261 
262   if (len == 0)
263     return *this;
264 
265   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
266 
267   mx_inline_add2 (len, d, a.data ());
268   return *this;
269 }
270 
271 ComplexColumnVector&
operator -=(const ColumnVector & a)272 ComplexColumnVector::operator -= (const ColumnVector& a)
273 {
274   octave_idx_type len = numel ();
275 
276   octave_idx_type a_len = a.numel ();
277 
278   if (len != a_len)
279     octave::err_nonconformant ("operator -=", len, a_len);
280 
281   if (len == 0)
282     return *this;
283 
284   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
285 
286   mx_inline_sub2 (len, d, a.data ());
287   return *this;
288 }
289 
290 // matrix by column vector -> column vector operations
291 
292 ComplexColumnVector
operator *(const ComplexMatrix & m,const ColumnVector & a)293 operator * (const ComplexMatrix& m, const ColumnVector& a)
294 {
295   ComplexColumnVector tmp (a);
296   return m * tmp;
297 }
298 
299 ComplexColumnVector
operator *(const ComplexMatrix & m,const ComplexColumnVector & a)300 operator * (const ComplexMatrix& m, const ComplexColumnVector& a)
301 {
302   ComplexColumnVector retval;
303 
304   F77_INT nr = octave::to_f77_int (m.rows ());
305   F77_INT nc = octave::to_f77_int (m.cols ());
306 
307   F77_INT a_len = octave::to_f77_int (a.numel ());
308 
309   if (nc != a_len)
310     octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
311 
312   retval.clear (nr);
313 
314   if (nr != 0)
315     {
316       if (nc == 0)
317         retval.fill (0.0);
318       else
319         {
320           Complex *y = retval.fortran_vec ();
321 
322           F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 ("N", 1),
323                                    nr, nc, 1.0,
324                                    F77_CONST_DBLE_CMPLX_ARG (m.data ()), nr,
325                                    F77_CONST_DBLE_CMPLX_ARG (a.data ()), 1, 0.0,
326                                    F77_DBLE_CMPLX_ARG (y), 1
327                                    F77_CHAR_ARG_LEN (1)));
328         }
329     }
330 
331   return retval;
332 }
333 
334 // matrix by column vector -> column vector operations
335 
336 ComplexColumnVector
operator *(const Matrix & m,const ComplexColumnVector & a)337 operator * (const Matrix& m, const ComplexColumnVector& a)
338 {
339   ComplexMatrix tmp (m);
340   return tmp * a;
341 }
342 
343 // diagonal matrix by column vector -> column vector operations
344 
345 ComplexColumnVector
operator *(const DiagMatrix & m,const ComplexColumnVector & a)346 operator * (const DiagMatrix& m, const ComplexColumnVector& a)
347 {
348   F77_INT nr = octave::to_f77_int (m.rows ());
349   F77_INT nc = octave::to_f77_int (m.cols ());
350 
351   F77_INT a_len = octave::to_f77_int (a.numel ());
352 
353   if (nc != a_len)
354     octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
355 
356   if (nc == 0 || nr == 0)
357     return ComplexColumnVector (0);
358 
359   ComplexColumnVector result (nr);
360 
361   for (octave_idx_type i = 0; i < a_len; i++)
362     result.elem (i) = a.elem (i) * m.elem (i, i);
363 
364   for (octave_idx_type i = a_len; i < nr; i++)
365     result.elem (i) = 0.0;
366 
367   return result;
368 }
369 
370 ComplexColumnVector
operator *(const ComplexDiagMatrix & m,const ColumnVector & a)371 operator * (const ComplexDiagMatrix& m, const ColumnVector& a)
372 {
373   F77_INT nr = octave::to_f77_int (m.rows ());
374   F77_INT nc = octave::to_f77_int (m.cols ());
375 
376   F77_INT a_len = octave::to_f77_int (a.numel ());
377 
378   if (nc != a_len)
379     octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
380 
381   if (nc == 0 || nr == 0)
382     return ComplexColumnVector (0);
383 
384   ComplexColumnVector result (nr);
385 
386   for (octave_idx_type i = 0; i < a_len; i++)
387     result.elem (i) = a.elem (i) * m.elem (i, i);
388 
389   for (octave_idx_type i = a_len; i < nr; i++)
390     result.elem (i) = 0.0;
391 
392   return result;
393 }
394 
395 ComplexColumnVector
operator *(const ComplexDiagMatrix & m,const ComplexColumnVector & a)396 operator * (const ComplexDiagMatrix& m, const ComplexColumnVector& a)
397 {
398   F77_INT nr = octave::to_f77_int (m.rows ());
399   F77_INT nc = octave::to_f77_int (m.cols ());
400 
401   F77_INT a_len = octave::to_f77_int (a.numel ());
402 
403   if (nc != a_len)
404     octave::err_nonconformant ("operator *", nr, nc, a_len, 1);
405 
406   if (nc == 0 || nr == 0)
407     return ComplexColumnVector (0);
408 
409   ComplexColumnVector result (nr);
410 
411   for (octave_idx_type i = 0; i < a_len; i++)
412     result.elem (i) = a.elem (i) * m.elem (i, i);
413 
414   for (octave_idx_type i = a_len; i < nr; i++)
415     result.elem (i) = 0.0;
416 
417   return result;
418 }
419 
420 // other operations
421 
422 Complex
min(void) const423 ComplexColumnVector::min (void) const
424 {
425   octave_idx_type len = numel ();
426   if (len == 0)
427     return 0.0;
428 
429   Complex res = elem (0);
430   double absres = std::abs (res);
431 
432   for (octave_idx_type i = 1; i < len; i++)
433     if (std::abs (elem (i)) < absres)
434       {
435         res = elem (i);
436         absres = std::abs (res);
437       }
438 
439   return res;
440 }
441 
442 Complex
max(void) const443 ComplexColumnVector::max (void) const
444 {
445   octave_idx_type len = numel ();
446   if (len == 0)
447     return 0.0;
448 
449   Complex res = elem (0);
450   double absres = std::abs (res);
451 
452   for (octave_idx_type i = 1; i < len; i++)
453     if (std::abs (elem (i)) > absres)
454       {
455         res = elem (i);
456         absres = std::abs (res);
457       }
458 
459   return res;
460 }
461 
462 // i/o
463 
464 std::ostream&
operator <<(std::ostream & os,const ComplexColumnVector & a)465 operator << (std::ostream& os, const ComplexColumnVector& a)
466 {
467 //  int field_width = os.precision () + 7;
468   for (octave_idx_type i = 0; i < a.numel (); i++)
469     os << /* setw (field_width) << */ a.elem (i) << "\n";
470   return os;
471 }
472 
473 std::istream&
operator >>(std::istream & is,ComplexColumnVector & a)474 operator >> (std::istream& is, ComplexColumnVector& a)
475 {
476   octave_idx_type len = a.numel ();
477 
478   if (len > 0)
479     {
480       double tmp;
481       for (octave_idx_type i = 0; i < len; i++)
482         {
483           is >> tmp;
484           if (is)
485             a.elem (i) = tmp;
486           else
487             break;
488         }
489     }
490   return is;
491 }
492