1 // Copyright 2014-2020 bluss and ndarray developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 use crate::imp_prelude::*;
10 use crate::numeric_util;
11 
12 use crate::{LinalgScalar, Zip};
13 
14 use std::any::TypeId;
15 use alloc::vec::Vec;
16 
17 #[cfg(feature = "blas")]
18 use std::cmp;
19 #[cfg(feature = "blas")]
20 use std::mem::swap;
21 #[cfg(feature = "blas")]
22 use libc::c_int;
23 
24 #[cfg(feature = "blas")]
25 use cblas_sys as blas_sys;
26 #[cfg(feature = "blas")]
27 use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
28 
/// Minimum vector length before we defer to BLAS `dot`.
30 #[cfg(feature = "blas")]
31 const DOT_BLAS_CUTOFF: usize = 32;
/// Minimum matrix side length before we defer to BLAS `gemm`.
33 #[cfg(feature = "blas")]
34 const GEMM_BLAS_CUTOFF: usize = 7;
35 #[cfg(feature = "blas")]
36 #[allow(non_camel_case_types)]
37 type blas_index = c_int; // blas index type
38 
impl<A, S> ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
{
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, then the operation is a vector dot
    /// product, which is the sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product). In this case,
    /// `self` and `rhs` must be the same length.
    ///
    /// If `Rhs` is two-dimensional, then the operation is matrix
    /// multiplication, where `self` is treated as a row vector. In this case,
    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
    /// shape *N*.
    ///
    /// **Panics** if the array shapes are incompatible.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where
        Self: Dot<Rhs>,
    {
        // Dispatch through the `Dot` trait so both 1-D and 2-D `rhs` work.
        Dot::dot(self, rhs)
    }

    /// Pure-Rust dot-product kernel, used when BLAS is disabled or does not
    /// apply to the element type / memory layout.
    ///
    /// Uses the unrolled slice kernel when both operands are contiguous,
    /// otherwise accumulates element by element.
    ///
    /// **Panics** if `self` and `rhs` differ in length.
    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        // Fast path: both operands are contiguous slices.
        if let Some(self_s) = self.as_slice() {
            if let Some(rhs_s) = rhs.as_slice() {
                return numeric_util::unrolled_dot(self_s, rhs_s);
            }
        }
        // Strided fallback.
        let mut sum = A::zero();
        for i in 0..self.len() {
            // SAFETY: `i < self.len()` and the assert above guarantees
            // `rhs.len() == self.len()`, so both unchecked reads are in bounds.
            unsafe {
                sum = sum + *self.uget(i) * *rhs.uget(i);
            }
        }
        sum
    }

    /// Non-BLAS build: always use the generic kernel.
    #[cfg(not(feature = "blas"))]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        self.dot_generic(rhs)
    }

    /// BLAS build: use `cblas_sdot`/`cblas_ddot` for sufficiently long
    /// `f32`/`f64` vectors with BLAS-compatible layout, otherwise fall back
    /// to the generic kernel.
    #[cfg(feature = "blas")]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        // Use only if the vector is large enough to be worth it
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            macro_rules! dot {
                ($ty:ty, $func:ident) => {{
                    // Taken only when `A` is exactly `$ty` and both length and
                    // strides fit the BLAS index type (see `blas_compat_1d`).
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            let (lhs_ptr, n, incx) =
                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
                            let (rhs_ptr, _, incy) =
                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
                            let ret = blas_sys::$func(
                                n,
                                lhs_ptr as *const $ty,
                                incx,
                                rhs_ptr as *const $ty,
                                incy,
                            );
                            return cast_as::<$ty, A>(&ret);
                        }
                    }
                }};
            }

            dot! {f32, cblas_sdot};
            dot! {f64, cblas_ddot};
        }
        // Short vectors, other element types, or incompatible layout.
        self.dot_generic(rhs)
    }
}
134 
/// Return a pointer to the starting element in BLAS's view.
///
/// BLAS wants a pointer to the element with the lowest address, which agrees
/// with our pointer for non-negative strides, but sits at the opposite end of
/// the axis for negative strides.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(
    ptr: *const A,
    len: usize,
    stride: isize,
) -> (*const A, blas_index, blas_index) {
    // [x x x x]
    //        ^--ptr
    //        stride = -1
    //  ^--blas_ptr = ptr + (len - 1) * stride
    let blas_ptr = if stride < 0 && len > 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (blas_ptr, len as blas_index, stride as blas_index)
}
157 
/// Matrix Multiplication
///
/// For two-dimensional arrays, the dot method computes the matrix
/// multiplication.
pub trait Dot<Rhs> {
    /// The result of the operation.
    ///
    /// For two-dimensional arrays: a rectangular array.
    type Output;
    /// Compute the dot product / matrix product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
169 
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = A;

    /// Compute the dot product of one-dimensional arrays.
    ///
    /// The dot product is a sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product).
    ///
    /// **Panics** if the arrays are not of the same length.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A {
        // `dot_impl` picks BLAS or the generic kernel depending on features.
        self.dot_impl(rhs)
    }
}
190 
impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;

    /// Perform the matrix multiplication of the row vector `self` and
    /// rectangular matrix `rhs`.
    ///
    /// The array shapes must agree in the way that
    /// if `self` is *M*, then `rhs` is *M* × *N*.
    ///
    /// Return a result array with shape *N*.
    ///
    /// **Panics** if shapes are incompatible.
    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1> {
        // x · A == Aᵀ · x: reuse the matrix × column-vector implementation.
        rhs.t().dot(self)
    }
}
212 
impl<A, S> ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
{
    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If Rhs is two-dimensional, the array shapes must agree in the way that
    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
    ///
    /// Return a result array with shape *M* × *K*.
    ///
    /// **Panics** if shapes are incompatible or the number of elements in the
    /// result would overflow `isize`.
    ///
    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
    /// when memory layout allows. The default matrixmultiply backend
    /// is otherwise used for `f32, f64` for all memory layouts.
    ///
    /// ```
    /// use ndarray::arr2;
    ///
    /// let a = arr2(&[[1., 2.],
    ///                [0., 1.]]);
    /// let b = arr2(&[[1., 2.],
    ///                [2., 3.]]);
    ///
    /// assert!(
    ///     a.dot(&b) == arr2(&[[5., 8.],
    ///                         [2., 3.]])
    /// );
    /// ```
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where
        Self: Dot<Rhs>,
    {
        // Dispatch through the `Dot` trait so both 1-D and 2-D `rhs` work.
        Dot::dot(self, rhs)
    }
}
253 
impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array2<A>;
    /// Matrix × matrix product; shapes must agree as *M* × *K* · *K* × *N*.
    ///
    /// **Panics** (via `dot_shape_error`) on shape mismatch or if the
    /// result's element count would overflow.
    fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A> {
        let a = self.view();
        let b = b.view();
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        if k != k2 || m.checked_mul(n).is_none() {
            dot_shape_error(m, k, k2, n);
        }

        // Produce an F-order result when both inputs have stride 1 along
        // their first axis (column-major-style inputs).
        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // A is Copy so this is safe
        // NOTE(review): `set_len` exposes uninitialized elements; this relies
        // on `mat_mul_impl` with beta == 0 writing every element of `c`
        // before any is read — confirm that invariant holds for all backends.
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }
        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
        c
    }
}
283 
/// Panic with a message describing why `m` × `k` and `k2` × `n` operands
/// cannot be multiplied: either the shapes are incompatible, or the result's
/// element count would overflow `isize`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! {
    // Distinguish overflow from a plain shape mismatch so the panic message
    // points at the actual problem.
    let fits = m
        .checked_mul(n)
        .map_or(false, |len| len <= ::std::isize::MAX as usize);
    if !fits {
        panic!("ndarray: shape {} × {} overflows isize", m, n);
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}
297 
/// Panic with a message describing incompatible input/output shapes for a
/// general matrix multiplication (`C` of `c1` × `c2` from `m` × `k` · `k2` × `n`).
#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! {
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}
304 
/// Perform the matrix multiplication of the rectangular array `self` and
/// column vector `rhs`.
///
/// The array shapes must agree in the way that
/// if `self` is *M* × *N*, then `rhs` is *N*.
///
/// Return a result array with shape *M*.
///
/// **Panics** if shapes are incompatible.
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1> {
        let ((m, a), n) = (self.dim(), rhs.dim());
        if a != n {
            dot_shape_error(m, a, n, 1);
        }

        // Avoid initializing the memory in vec -- set it during iteration
        unsafe {
            // SAFETY: beta is zero, so `general_mat_vec_mul_impl` is allowed
            // to receive an uninitialized destination; it writes every
            // element before `assume_init`.
            let mut c = Array1::uninit(m);
            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
            c.assume_init()
        }
    }
}
335 
impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Perform the operation `self += alpha * rhs` efficiently, where
    /// `alpha` is a scalar and `rhs` is another array. This operation is
    /// also known as `axpy` in BLAS.
    ///
    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
    ///
    /// **Panics** if broadcasting isn’t possible.
    pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
    where
        S: DataMut,
        S2: Data<Elem = A>,
        A: LinalgScalar,
        E: Dimension,
    {
        // Elementwise y ← y + α·x; `zip_mut_with` handles broadcasting `rhs`.
        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
    }
}
358 
359 // mat_mul_impl uses ArrayView arguments to send all array kinds into
360 // the same instantiated implementation.
361 #[cfg(not(feature = "blas"))]
362 use self::mat_mul_general as mat_mul_impl;
363 
/// BLAS-dispatching matrix product: C ← α A B + β C.
///
/// Falls back to `mat_mul_general` when all operand dimensions are at or
/// below the cutoff, the element type is not `f32`/`f64`, or the memory
/// layouts turn out not to be BLAS-compatible.
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    // size cutoff for using BLAS
    let cut = GEMM_BLAS_CUTOFF;
    let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
    if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
        return mat_mul_general(alpha, lhs, rhs, beta, c);
    }
    {
        // Use `c` for c-order and `f` for an f-order matrix
        // We can handle c * c, f * f generally and
        // c * f and f * c if the `f` matrix is square.
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let lhs_s0 = lhs_.strides()[0];
        let rhs_s0 = rhs_.strides()[0];
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;
        let mut lhs_trans = CblasNoTrans;
        let mut rhs_trans = CblasNoTrans;
        if both_f {
            // A^t B^t = C^t => B A = C
            let lhs_t = lhs_.reversed_axes();
            lhs_ = rhs_.reversed_axes();
            rhs_ = lhs_t;
            c_ = c_.reversed_axes();
            swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            // Square f-order lhs: present it as a transposed c-order operand.
            lhs_ = lhs_.reversed_axes();
            lhs_trans = CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            // Square f-order rhs: same trick on the right-hand side.
            rhs_ = rhs_.reversed_axes();
            rhs_trans = CblasTrans;
        }

        macro_rules! gemm {
            ($ty:ty, $gemm:ident) => {
                // All three operands must have the element type `$ty` and a
                // BLAS-compatible row-major layout before calling into BLAS.
                if blas_row_major_2d::<$ty, _>(&lhs_)
                    && blas_row_major_2d::<$ty, _>(&rhs_)
                    && blas_row_major_2d::<$ty, _>(&c_)
                {
                    let (m, k) = match lhs_trans {
                        CblasNoTrans => lhs_.dim(),
                        _ => {
                            let (rows, cols) = lhs_.dim();
                            (cols, rows)
                        }
                    };
                    let n = match rhs_trans {
                        CblasNoTrans => rhs_.raw_dim()[1],
                        _ => rhs_.raw_dim()[0],
                    };
                    // adjust strides, these may be [1, 1] for column matrices
                    let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
                    let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
                    let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);

                    // gemm is C ← αA^Op B^Op + βC
                    // Where Op is notrans/trans/conjtrans
                    unsafe {
                        blas_sys::$gemm(
                            CblasRowMajor,
                            lhs_trans,
                            rhs_trans,
                            m as blas_index,               // m, rows of Op(a)
                            n as blas_index,               // n, cols of Op(b)
                            k as blas_index,               // k, cols of Op(a)
                            cast_as(&alpha),               // alpha
                            lhs_.ptr.as_ptr() as *const _, // a
                            lhs_stride,                    // lda
                            rhs_.ptr.as_ptr() as *const _, // b
                            rhs_stride,                    // ldb
                            cast_as(&beta),                // beta
                            c_.ptr.as_ptr() as *mut _,     // c
                            c_stride,                      // ldc
                        );
                    }
                    return;
                }
            };
        }
        gemm!(f32, cblas_sgemm);
        gemm!(f64, cblas_dgemm);
    }
    // Layout was not BLAS-compatible after all; use the generic path.
    mat_mul_general(alpha, lhs, rhs, beta, c)
}
458 
/// C ← α A B + β C
///
/// Generic fallback matrix product: `f32`/`f64` go through the
/// `matrixmultiply` crate's kernels; every other element type uses a
/// straightforward element-by-element loop.
fn mat_mul_general<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    // common parameters for gemm
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    if same_type::<A, f32>() {
        // `A` is exactly f32 here (checked above), so the pointer casts
        // merely rename the same type.
        unsafe {
            ::matrixmultiply::sgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, f64>() {
        unsafe {
            ::matrixmultiply::dgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else {
        // It's a no-op if `c` has zero length.
        if c.is_empty() {
            return;
        }

        // initialize memory if beta is zero
        // (filling with zero also makes the `*elt * beta` read below defined
        // when the caller handed over an uninitialized `c`)
        if beta.is_zero() {
            c.fill(beta);
        }

        // Row-major walk over `c`: each element gets α · (row of lhs · column
        // of rhs) plus β times its previous value.
        let mut i = 0;
        let mut j = 0;
        loop {
            // SAFETY: `i < m` and `j < n` by construction of the loop below,
            // and shapes were validated by the callers (see `general_mat_mul`).
            unsafe {
                let elt = c.uget_mut((i, j));
                *elt = *elt * beta
                    + alpha
                        * (0..k).fold(A::zero(), move |s, x| {
                            s + *lhs.uget((i, x)) * *rhs.uget((x, j))
                        });
            }
            j += 1;
            if j == n {
                j = 0;
                i += 1;
                if i == m {
                    break;
                }
            }
        }
    }
}
547 
548 /// General matrix-matrix multiplication.
549 ///
550 /// Compute C ← α A B + β C
551 ///
552 /// The array shapes must agree in the way that
553 /// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
554 ///
555 /// ***Panics*** if array shapes are not compatible<br>
556 /// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
557 /// layout allows.  The default matrixmultiply backend is otherwise used for
558 /// `f32, f64` for all memory layouts.
general_mat_mul<A, S1, S2, S3>( alpha: A, a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>, beta: A, c: &mut ArrayBase<S3, Ix2>, ) where S1: Data<Elem = A>, S2: Data<Elem = A>, S3: DataMut<Elem = A>, A: LinalgScalar,559 pub fn general_mat_mul<A, S1, S2, S3>(
560     alpha: A,
561     a: &ArrayBase<S1, Ix2>,
562     b: &ArrayBase<S2, Ix2>,
563     beta: A,
564     c: &mut ArrayBase<S3, Ix2>,
565 ) where
566     S1: Data<Elem = A>,
567     S2: Data<Elem = A>,
568     S3: DataMut<Elem = A>,
569     A: LinalgScalar,
570 {
571     let ((m, k), (k2, n)) = (a.dim(), b.dim());
572     let (m2, n2) = c.dim();
573     if k != k2 || m != m2 || n != n2 {
574         general_dot_shape_error(m, k, k2, n, m2, n2);
575     } else {
576         mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
577     }
578 }
579 
/// General matrix-vector multiplication.
///
/// Compute y ← α A x + β y
///
/// where A is a *M* × *N* matrix and x is an *N*-element column vector and
/// y an *M*-element column vector (one dimensional arrays).
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
/// layout allows.
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A, S1, S2, S3>(
    alpha: A,
    a: &ArrayBase<S1, Ix2>,
    x: &ArrayBase<S2, Ix1>,
    beta: A,
    y: &mut ArrayBase<S3, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    // SAFETY: `y` is a fully initialized array, so its raw view is valid for
    // writing (and for reading when beta != 0).
    unsafe {
        general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
    }
}
607 
/// General matrix-vector multiplication
///
/// Use a raw view for the destination vector, so that it can be uninitialized.
///
/// ## Safety
///
/// The caller must ensure that the raw view is valid for writing.
/// The destination may be uninitialized iff beta is zero.
#[allow(clippy::collapsible_else_if)]
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
    alpha: A,
    a: &ArrayBase<S1, Ix2>,
    x: &ArrayBase<S2, Ix1>,
    beta: A,
    y: RawArrayViewMut<A, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        // BLAS path: used when `A` is exactly `$ty`, `a` has a BLAS layout,
        // and both vectors are BLAS-compatible.
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if let Some(layout) = blas_layout::<$ty, _>(&a) {
                    if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
                        // Determine stride between rows or columns. Note that the stride is
                        // adjusted to at least `k` or `m` to handle the case of a matrix with a
                        // trivial (length 1) dimension, since the stride for the trivial dimension
                        // may be arbitrary.
                        let a_trans = CblasNoTrans;
                        let a_stride = match layout {
                            CBLAS_LAYOUT::CblasRowMajor => {
                                a.strides()[0].max(k as isize) as blas_index
                            }
                            CBLAS_LAYOUT::CblasColMajor => {
                                a.strides()[1].max(m as isize) as blas_index
                            }
                        };

                        let x_stride = x.strides()[0] as blas_index;
                        let y_stride = y.strides()[0] as blas_index;

                        blas_sys::$gemv(
                            layout,
                            a_trans,
                            m as blas_index,            // m, rows of Op(a)
                            k as blas_index,            // n, cols of Op(a)
                            cast_as(&alpha),            // alpha
                            a.ptr.as_ptr() as *const _, // a
                            a_stride,                   // lda
                            x.ptr.as_ptr() as *const _, // x
                            x_stride,
                            cast_as(&beta),           // beta
                            y.ptr.as_ptr() as *mut _, // y
                            y_stride,
                        );
                        return;
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);

        /* general */

        if beta.is_zero() {
            // when beta is zero, c may be uninitialized
            // (write-only path: never reads the destination)
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}
693 
#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type<A: 'static, B: 'static>() -> bool {
    let (a, b) = (TypeId::of::<A>(), TypeId::of::<B>());
    a == b
}
699 
700 // Read pointer to type `A` as type `B`.
701 //
702 // **Panics** if `A` and `B` are not the same type
cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B703 fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
704     assert!(same_type::<A, B>());
705     unsafe { ::std::ptr::read(a as *const _ as *const B) }
706 }
707 
/// Return `true` if the 1-D array can be handed to BLAS: exact element-type
/// match, and both length and stride fit in the BLAS index type.
#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    let stride = a.strides()[0];
    same_type::<A, S::Elem>()
        && a.len() <= blas_index::max_value() as usize
        && stride <= blas_index::max_value() as isize
        && stride >= blas_index::min_value() as isize
}
727 
/// Memory-layout tag used when checking whether a 2-D array is BLAS-compatible.
#[cfg(feature = "blas")]
enum MemoryOrder {
    /// Row major (C order).
    C,
    /// Column major (Fortran order).
    F,
}
733 
/// Return `true` if `a` has an exact element-type match with `A` and a
/// BLAS-compatible row-major (C order) layout.
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
}
746 
/// Return `true` if `a` has an exact element-type match with `A` and a
/// BLAS-compatible column-major (Fortran order) layout.
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
759 
/// Check whether a 2-D shape/stride pair is usable by BLAS in the given
/// memory order: contiguous inner axis, positive strides, and all sizes and
/// strides representable in the BLAS index type.
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool {
    let (m, n) = dim.into_pattern();
    let (s0, s1) = (stride[0] as isize, stride[1] as isize);
    // The axis BLAS iterates contiguously must have stride 1, unless the
    // other axis has length 1 (then its stride is irrelevant).
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (s1, n),
        MemoryOrder::F => (s0, m),
    };
    let index_fits = |s: isize| {
        s <= blas_index::max_value() as isize && s >= blas_index::min_value() as isize
    };
    (inner_stride == 1 || outer_dim == 1)
        && s0 >= 1
        && s1 >= 1
        && index_fits(s0)
        && index_fits(s1)
        && m <= blas_index::max_value() as usize
        && n <= blas_index::max_value() as usize
}
785 
/// Classify `a`'s layout for BLAS, or `None` if it is not BLAS-compatible.
#[cfg(feature = "blas")]
fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    // Row-major is reported first when both orders apply (e.g. 1 × n arrays).
    if blas_row_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasRowMajor);
    }
    if blas_column_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasColMajor);
    }
    None
}
801 
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests {
    //! Sanity checks for the BLAS layout-detection helpers
    //! `blas_row_major_2d` / `blas_column_major_2d`.
    use super::*;

    #[test]
    fn blas_row_major_2d_normal_matrix() {
        // A default (C order) matrix is row-major only.
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_row_matrix() {
        // A 1 × n matrix qualifies as both orders.
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_column_matrix() {
        // An n × 1 matrix also qualifies as both orders.
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix() {
        // Transposing a vector-shaped matrix keeps it compatible both ways.
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix() {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_column_major_2d_normal_matrix() {
        // An F-order matrix is column-major only.
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }
}
851