1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 #  include "config.h"
28 #endif
29 
30 #include <istream>
31 #include <ostream>
32 #include <type_traits>
33 
34 #include "Array-util.h"
35 #include "lo-blas-proto.h"
36 #include "lo-error.h"
37 #include "mx-base.h"
38 #include "mx-inlines.cc"
39 #include "oct-cmplx.h"
40 
41 // Row Vector class.
42 
43 bool
operator ==(const RowVector & a) const44 RowVector::operator == (const RowVector& a) const
45 {
46   octave_idx_type len = numel ();
47   if (len != a.numel ())
48     return 0;
49   return mx_inline_equal (len, data (), a.data ());
50 }
51 
52 bool
operator !=(const RowVector & a) const53 RowVector::operator != (const RowVector& a) const
54 {
55   return !(*this == a);
56 }
57 
58 RowVector&
insert(const RowVector & a,octave_idx_type c)59 RowVector::insert (const RowVector& a, octave_idx_type c)
60 {
61   octave_idx_type a_len = a.numel ();
62 
63   if (c < 0 || c + a_len > numel ())
64     (*current_liboctave_error_handler) ("range error for insert");
65 
66   if (a_len > 0)
67     {
68       make_unique ();
69 
70       for (octave_idx_type i = 0; i < a_len; i++)
71         xelem (c+i) = a.elem (i);
72     }
73 
74   return *this;
75 }
76 
77 RowVector&
fill(double val)78 RowVector::fill (double val)
79 {
80   octave_idx_type len = numel ();
81 
82   if (len > 0)
83     {
84       make_unique ();
85 
86       for (octave_idx_type i = 0; i < len; i++)
87         xelem (i) = val;
88     }
89 
90   return *this;
91 }
92 
93 RowVector&
fill(double val,octave_idx_type c1,octave_idx_type c2)94 RowVector::fill (double val, octave_idx_type c1, octave_idx_type c2)
95 {
96   octave_idx_type len = numel ();
97 
98   if (c1 < 0 || c2 < 0 || c1 >= len || c2 >= len)
99     (*current_liboctave_error_handler) ("range error for fill");
100 
101   if (c1 > c2) { std::swap (c1, c2); }
102 
103   if (c2 >= c1)
104     {
105       make_unique ();
106 
107       for (octave_idx_type i = c1; i <= c2; i++)
108         xelem (i) = val;
109     }
110 
111   return *this;
112 }
113 
114 RowVector
append(const RowVector & a) const115 RowVector::append (const RowVector& a) const
116 {
117   octave_idx_type len = numel ();
118   octave_idx_type nc_insert = len;
119   RowVector retval (len + a.numel ());
120   retval.insert (*this, 0);
121   retval.insert (a, nc_insert);
122   return retval;
123 }
124 
125 ColumnVector
transpose(void) const126 RowVector::transpose (void) const
127 {
128   return MArray<double>::transpose ();
129 }
130 
131 RowVector
real(const ComplexRowVector & a)132 real (const ComplexRowVector& a)
133 {
134   return do_mx_unary_op<double, Complex> (a, mx_inline_real);
135 }
136 
137 RowVector
imag(const ComplexRowVector & a)138 imag (const ComplexRowVector& a)
139 {
140   return do_mx_unary_op<double, Complex> (a, mx_inline_imag);
141 }
142 
143 RowVector
extract(octave_idx_type c1,octave_idx_type c2) const144 RowVector::extract (octave_idx_type c1, octave_idx_type c2) const
145 {
146   if (c1 > c2) { std::swap (c1, c2); }
147 
148   octave_idx_type new_c = c2 - c1 + 1;
149 
150   RowVector result (new_c);
151 
152   for (octave_idx_type i = 0; i < new_c; i++)
153     result.xelem (i) = elem (c1+i);
154 
155   return result;
156 }
157 
158 RowVector
extract_n(octave_idx_type r1,octave_idx_type n) const159 RowVector::extract_n (octave_idx_type r1, octave_idx_type n) const
160 {
161   RowVector result (n);
162 
163   for (octave_idx_type i = 0; i < n; i++)
164     result.xelem (i) = elem (r1+i);
165 
166   return result;
167 }
168 
169 // row vector by matrix -> row vector
170 
171 RowVector
operator *(const RowVector & v,const Matrix & a)172 operator * (const RowVector& v, const Matrix& a)
173 {
174   RowVector retval;
175 
176   F77_INT len = octave::to_f77_int (v.numel ());
177 
178   F77_INT a_nr = octave::to_f77_int (a.rows ());
179   F77_INT a_nc = octave::to_f77_int (a.cols ());
180 
181   if (a_nr != len)
182     octave::err_nonconformant ("operator *", 1, len, a_nr, a_nc);
183 
184   if (len == 0)
185     retval.resize (a_nc, 0.0);
186   else
187     {
188       // Transpose A to form A'*x == (x'*A)'
189 
190       F77_INT ld = a_nr;
191 
192       retval.resize (a_nc);
193       double *y = retval.fortran_vec ();
194 
195       F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 ("T", 1),
196                                a_nr, a_nc, 1.0, a.data (),
197                                ld, v.data (), 1, 0.0, y, 1
198                                F77_CHAR_ARG_LEN (1)));
199     }
200 
201   return retval;
202 }
203 
204 // other operations
205 
206 double
min(void) const207 RowVector::min (void) const
208 {
209   octave_idx_type len = numel ();
210   if (len == 0)
211     return 0;
212 
213   double res = elem (0);
214 
215   for (octave_idx_type i = 1; i < len; i++)
216     if (elem (i) < res)
217       res = elem (i);
218 
219   return res;
220 }
221 
222 double
max(void) const223 RowVector::max (void) const
224 {
225   octave_idx_type len = numel ();
226   if (len == 0)
227     return 0;
228 
229   double res = elem (0);
230 
231   for (octave_idx_type i = 1; i < len; i++)
232     if (elem (i) > res)
233       res = elem (i);
234 
235   return res;
236 }
237 
238 std::ostream&
operator <<(std::ostream & os,const RowVector & a)239 operator << (std::ostream& os, const RowVector& a)
240 {
241 //  int field_width = os.precision () + 7;
242 
243   for (octave_idx_type i = 0; i < a.numel (); i++)
244     os << ' ' /* setw (field_width) */ << a.elem (i);
245   return os;
246 }
247 
248 std::istream&
operator >>(std::istream & is,RowVector & a)249 operator >> (std::istream& is, RowVector& a)
250 {
251   octave_idx_type len = a.numel ();
252 
253   if (len > 0)
254     {
255       double tmp;
256       for (octave_idx_type i = 0; i < len; i++)
257         {
258           is >> tmp;
259           if (is)
260             a.elem (i) = tmp;
261           else
262             break;
263         }
264     }
265   return is;
266 }
267 
268 // other operations
269 
270 RowVector
linspace(double x1,double x2,octave_idx_type n_in)271 linspace (double x1, double x2, octave_idx_type n_in)
272 {
273   RowVector retval;
274 
275   if (n_in < 1)
276     return retval;
277   else if (n_in == 1)
278     {
279       retval.resize (1, x2);
280       return retval;
281     }
282 
283   // Use unsigned type (guaranteed n_in > 1 at this point) so that divisions
284   // by 2 can be replaced by compiler with shift right instructions.
285   typedef std::make_unsigned<octave_idx_type>::type unsigned_octave_idx_type;
286 
287   unsigned_octave_idx_type n = n_in;
288 
289   // Set endpoints, rather than calculate, for maximum accuracy.
290   retval.clear (n);
291   retval.xelem (0) = x1;
292   retval.xelem (n-1) = x2;
293 
294   // Construct linspace symmetrically from both ends.
295   double delta = (x2 - x1) / (n - 1);
296   unsigned_octave_idx_type n2 = n/2;
297   for (unsigned_octave_idx_type i = 1; i < n2; i++)
298     {
299       retval.xelem (i) = x1 + i*delta;
300       retval.xelem (n-1-i) = x2 - i*delta;
301     }
302   if (n % 2 == 1)  // Middle element if number of elements is odd.
303     retval.xelem (n2) = (x1 == -x2 ? 0 : (x1 + x2) / 2);
304 
305   return retval;
306 }
307 
308 // row vector by column vector -> scalar
309 
310 double
operator *(const RowVector & v,const ColumnVector & a)311 operator * (const RowVector& v, const ColumnVector& a)
312 {
313   double retval = 0.0;
314 
315   F77_INT len = octave::to_f77_int (v.numel ());
316 
317   F77_INT a_len = octave::to_f77_int (a.numel ());
318 
319   if (len != a_len)
320     octave::err_nonconformant ("operator *", len, a_len);
321 
322   if (len != 0)
323     F77_FUNC (xddot, XDDOT) (len, v.data (), 1, a.data (), 1, retval);
324 
325   return retval;
326 }
327 
328 Complex
operator *(const RowVector & v,const ComplexColumnVector & a)329 operator * (const RowVector& v, const ComplexColumnVector& a)
330 {
331   ComplexRowVector tmp (v);
332   return tmp * a;
333 }
334