// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;
use crate::numeric_util;

use crate::{LinalgScalar, Zip};

use std::any::TypeId;
use alloc::vec::Vec;

#[cfg(feature = "blas")]
use std::cmp;
#[cfg(feature = "blas")]
use std::mem::swap;
#[cfg(feature = "blas")]
use libc::c_int;

#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};

/// Minimum length of a vector before we use BLAS `dot`
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Minimum side of a matrix before we use BLAS `gemm`
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int; // blas index type

impl<A, S> ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
{
    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is one-dimensional, then the operation is a vector dot
    /// product, which is the sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product). In this case,
    /// `self` and `rhs` must be the same length.
    ///
    /// If `Rhs` is two-dimensional, then the operation is matrix
    /// multiplication, where `self` is treated as a row vector. In this case,
    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
    /// shape *N*.
    ///
    /// **Panics** if the array shapes are incompatible.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
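    ///
    /// A minimal example of the vector dot product case (values chosen only
    /// for illustration):
    ///
    /// ```
    /// use ndarray::arr1;
    ///
    /// let a = arr1(&[1., 2., 3.]);
    /// let b = arr1(&[4., 5., 6.]);
    /// assert_eq!(a.dot(&b), 32.);
    /// ```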
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where
        Self: Dot<Rhs>,
    {
        Dot::dot(self, rhs)
    }

    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        if let Some(self_s) = self.as_slice() {
            if let Some(rhs_s) = rhs.as_slice() {
                return numeric_util::unrolled_dot(self_s, rhs_s);
            }
        }
        let mut sum = A::zero();
        for i in 0..self.len() {
            unsafe {
                sum = sum + *self.uget(i) * *rhs.uget(i);
            }
        }
        sum
    }

    #[cfg(not(feature = "blas"))]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        self.dot_generic(rhs)
    }

    #[cfg(feature = "blas")]
    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
    where
        S2: Data<Elem = A>,
        A: LinalgScalar,
    {
        // Use only if the vector is large enough to be worth it
        if self.len() >= DOT_BLAS_CUTOFF {
            debug_assert_eq!(self.len(), rhs.len());
            assert!(self.len() == rhs.len());
            macro_rules! dot {
                ($ty:ty, $func:ident) => {{
                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
                        unsafe {
                            let (lhs_ptr, n, incx) =
                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
                            let (rhs_ptr, _, incy) =
                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
                            let ret = blas_sys::$func(
                                n,
                                lhs_ptr as *const $ty,
                                incx,
                                rhs_ptr as *const $ty,
                                incy,
                            );
                            return cast_as::<$ty, A>(&ret);
                        }
                    }
                }};
            }

            dot! {f32, cblas_sdot};
            dot! {f64, cblas_ddot};
        }
        self.dot_generic(rhs)
    }
}

/// Return a pointer to the starting element in BLAS's view.
///
/// BLAS wants a pointer to the element with lowest address,
/// which agrees with our pointer for non-negative strides, but
/// is at the opposite end for negative strides.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(
    ptr: *const A,
    len: usize,
    stride: isize,
) -> (*const A, blas_index, blas_index) {
    // [x x x x]
    //        ^--ptr
    //  stride = -1
    //  ^--blas_ptr = ptr + (len - 1) * stride
    if stride >= 0 || len == 0 {
        (ptr, len as blas_index, stride as blas_index)
    } else {
        let ptr = ptr.offset((len - 1) as isize * stride);
        (ptr, len as blas_index, stride as blas_index)
    }
}

/// Matrix Multiplication
///
/// For two-dimensional arrays, the dot method computes the matrix
/// multiplication.
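///
/// A brief usage sketch (the trait method, called with fully qualified
/// syntax, is equivalent to the inherent `dot` methods):
///
/// ```
/// use ndarray::arr2;
/// use ndarray::linalg::Dot;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let b = arr2(&[[0., 1.],
///                [1., 0.]]);
/// assert_eq!(Dot::dot(&a, &b), a.dot(&b));
/// ```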
pub trait Dot<Rhs> {
    /// The result of the operation.
    ///
    /// For two-dimensional arrays: a rectangular array.
    type Output;
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = A;

    /// Compute the dot product of one-dimensional arrays.
    ///
    /// The dot product is a sum of the elementwise products (no conjugation
    /// of complex operands, and thus not their inner product).
    ///
    /// **Panics** if the arrays are not of the same length.<br>
    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
    /// layout allows.
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A {
        self.dot_impl(rhs)
    }
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;

    /// Perform the matrix multiplication of the row vector `self` and
    /// rectangular matrix `rhs`.
    ///
    /// The array shapes must agree in the way that
    /// if `self` is *M*, then `rhs` is *M* × *N*.
    ///
    /// Return a result array with shape *N*.
    ///
    /// **Panics** if shapes are incompatible.
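    ///
    /// A small shape-checked example (values are only illustrative):
    ///
    /// ```
    /// use ndarray::{arr1, arr2};
    ///
    /// let v = arr1(&[1., 2., 3.]);            // shape 3
    /// let m = arr2(&[[1., 0.],
    ///                [0., 1.],
    ///                [1., 1.]]);              // shape 3 × 2
    /// assert_eq!(v.dot(&m), arr1(&[4., 5.])); // shape 2
    /// ```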
    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1> {
        rhs.t().dot(self)
    }
}

impl<A, S> ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
{
    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
    ///
    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
    ///
    /// If `Rhs` is two-dimensional, the array shapes must agree in the way that
    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
    ///
    /// Return a result array with shape *M* × *K*.
    ///
    /// **Panics** if shapes are incompatible or the number of elements in the
    /// result would overflow `isize`.
    ///
    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
    /// when memory layout allows. The default matrixmultiply backend
    /// is otherwise used for `f32, f64` for all memory layouts.
    ///
    /// ```
    /// use ndarray::arr2;
    ///
    /// let a = arr2(&[[1., 2.],
    ///                [0., 1.]]);
    /// let b = arr2(&[[1., 2.],
    ///                [2., 3.]]);
    ///
    /// assert!(
    ///     a.dot(&b) == arr2(&[[5., 8.],
    ///                         [2., 3.]])
    /// );
    /// ```
    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
    where
        Self: Dot<Rhs>,
    {
        Dot::dot(self, rhs)
    }
}

impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array2<A>;
    fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A> {
        let a = self.view();
        let b = b.view();
        let ((m, k), (k2, n)) = (a.dim(), b.dim());
        if k != k2 || m.checked_mul(n).is_none() {
            dot_shape_error(m, k, k2, n);
        }

        let lhs_s0 = a.strides()[0];
        let rhs_s0 = b.strides()[0];
        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
        // A is Copy so this is safe
        let mut v = Vec::with_capacity(m * n);
        let mut c;
        unsafe {
            v.set_len(m * n);
            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
        }
        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
        c
    }
}

/// Assumes that `m` and `n` are ≤ `isize::MAX`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! {
    match m.checked_mul(n) {
        Some(len) if len <= ::std::isize::MAX as usize => {}
        _ => panic!("ndarray: shape {} × {} overflows isize", m, n),
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}

#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! {
    panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
           m, k, k2, n, c1, c2);
}

/// Perform the matrix multiplication of the rectangular array `self` and
/// column vector `rhs`.
///
/// The array shapes must agree in the way that
/// if `self` is *M* × *N*, then `rhs` is *N*.
///
/// Return a result array with shape *M*.
///
/// **Panics** if shapes are incompatible.
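///
/// A minimal example of the matrix × column-vector case (illustrative values):
///
/// ```
/// use ndarray::{arr1, arr2};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.],
///                [5., 6.]]);  // 3 × 2
/// let x = arr1(&[1., 1.]);    // 2
/// assert_eq!(a.dot(&x), arr1(&[3., 7., 11.]));
/// ```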
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
where
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    type Output = Array<A, Ix1>;
    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1> {
        let ((m, a), n) = (self.dim(), rhs.dim());
        if a != n {
            dot_shape_error(m, a, n, 1);
        }

        // Avoid initializing the memory in vec -- set it during iteration
        unsafe {
            let mut c = Array1::uninit(m);
            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
            c.assume_init()
        }
    }
}

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Perform the operation `self += alpha * rhs` efficiently, where
    /// `alpha` is a scalar and `rhs` is another array. This operation is
    /// also known as `axpy` in BLAS.
    ///
    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
    ///
    /// **Panics** if broadcasting isn’t possible.
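    ///
    /// A short example (values chosen for illustration):
    ///
    /// ```
    /// use ndarray::arr1;
    ///
    /// let mut y = arr1(&[1., 1., 1.]);
    /// let x = arr1(&[1., 2., 3.]);
    /// y.scaled_add(2., &x); // y ← y + 2 · x
    /// assert_eq!(y, arr1(&[3., 5., 7.]));
    /// ```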
    pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
    where
        S: DataMut,
        S2: Data<Elem = A>,
        A: LinalgScalar,
        E: Dimension,
    {
        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
    }
}

// mat_mul_impl uses ArrayView arguments to send all array kinds into
// the same instantiated implementation.
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;

#[cfg(feature = "blas")]
fn mat_mul_impl<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    // size cutoff for using BLAS
    let cut = GEMM_BLAS_CUTOFF;
    let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
    if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
        return mat_mul_general(alpha, lhs, rhs, beta, c);
    }
    {
        // Use `c` for a c-order and `f` for an f-order matrix
        // We can handle c * c, f * f generally and
        // c * f and f * c if the `f` matrix is square.
        let mut lhs_ = lhs.view();
        let mut rhs_ = rhs.view();
        let mut c_ = c.view_mut();
        let lhs_s0 = lhs_.strides()[0];
        let rhs_s0 = rhs_.strides()[0];
        let both_f = lhs_s0 == 1 && rhs_s0 == 1;
        let mut lhs_trans = CblasNoTrans;
        let mut rhs_trans = CblasNoTrans;
        if both_f {
            // A^t B^t = C^t => B A = C
            let lhs_t = lhs_.reversed_axes();
            lhs_ = rhs_.reversed_axes();
            rhs_ = lhs_t;
            c_ = c_.reversed_axes();
            swap(&mut m, &mut n);
        } else if lhs_s0 == 1 && m == a {
            lhs_ = lhs_.reversed_axes();
            lhs_trans = CblasTrans;
        } else if rhs_s0 == 1 && a == n {
            rhs_ = rhs_.reversed_axes();
            rhs_trans = CblasTrans;
        }

        macro_rules! gemm {
            ($ty:ty, $gemm:ident) => {
                if blas_row_major_2d::<$ty, _>(&lhs_)
                    && blas_row_major_2d::<$ty, _>(&rhs_)
                    && blas_row_major_2d::<$ty, _>(&c_)
                {
                    let (m, k) = match lhs_trans {
                        CblasNoTrans => lhs_.dim(),
                        _ => {
                            let (rows, cols) = lhs_.dim();
                            (cols, rows)
                        }
                    };
                    let n = match rhs_trans {
                        CblasNoTrans => rhs_.raw_dim()[1],
                        _ => rhs_.raw_dim()[0],
                    };
                    // adjust strides, these may be [1, 1] for column matrices
                    let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
                    let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
                    let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);

                    // gemm is C ← αA^Op B^Op + βC
                    // where Op is notrans/trans/conjtrans
                    unsafe {
                        blas_sys::$gemm(
                            CblasRowMajor,
                            lhs_trans,
                            rhs_trans,
                            m as blas_index,               // m, rows of Op(a)
                            n as blas_index,               // n, cols of Op(b)
                            k as blas_index,               // k, cols of Op(a)
                            cast_as(&alpha),               // alpha
                            lhs_.ptr.as_ptr() as *const _, // a
                            lhs_stride,                    // lda
                            rhs_.ptr.as_ptr() as *const _, // b
                            rhs_stride,                    // ldb
                            cast_as(&beta),                // beta
                            c_.ptr.as_ptr() as *mut _,     // c
                            c_stride,                      // ldc
                        );
                    }
                    return;
                }
            };
        }
        gemm!(f32, cblas_sgemm);
        gemm!(f64, cblas_dgemm);
    }
    mat_mul_general(alpha, lhs, rhs, beta, c)
}

/// C ← α A B + β C
fn mat_mul_general<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());

    // common parameters for gemm
    let ap = lhs.as_ptr();
    let bp = rhs.as_ptr();
    let cp = c.as_mut_ptr();
    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
    if same_type::<A, f32>() {
        unsafe {
            ::matrixmultiply::sgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else if same_type::<A, f64>() {
        unsafe {
            ::matrixmultiply::dgemm(
                m,
                k,
                n,
                cast_as(&alpha),
                ap as *const _,
                lhs.strides()[0],
                lhs.strides()[1],
                bp as *const _,
                rhs.strides()[0],
                rhs.strides()[1],
                cast_as(&beta),
                cp as *mut _,
                rsc,
                csc,
            );
        }
    } else {
        // It's a no-op if `c` has zero length.
        if c.is_empty() {
            return;
        }

        // initialize memory if beta is zero
        if beta.is_zero() {
            c.fill(beta);
        }

        let mut i = 0;
        let mut j = 0;
        loop {
            unsafe {
                let elt = c.uget_mut((i, j));
                *elt = *elt * beta
                    + alpha
                        * (0..k).fold(A::zero(), move |s, x| {
                            s + *lhs.uget((i, x)) * *rhs.uget((x, j))
                        });
            }
            j += 1;
            if j == n {
                j = 0;
                i += 1;
                if i == m {
                    break;
                }
            }
        }
    }
}

/// General matrix-matrix multiplication.
///
/// Compute C ← α A B + β C
///
/// The array shapes must agree in the way that
/// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
/// layout allows. The default matrixmultiply backend is otherwise used for
/// `f32, f64` for all memory layouts.
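///
/// A small example (illustrative values; multiplying by the identity with
/// β = 0 leaves `c` equal to `a`):
///
/// ```
/// use ndarray::{arr2, Array2};
/// use ndarray::linalg::general_mat_mul;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let b = arr2(&[[1., 0.],
///                [0., 1.]]);
/// let mut c = Array2::zeros((2, 2));
/// general_mat_mul(1., &a, &b, 0., &mut c); // C ← 1 · A B + 0 · C
/// assert_eq!(c, a);
/// ```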
pub fn general_mat_mul<A, S1, S2, S3>(
    alpha: A,
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    beta: A,
    c: &mut ArrayBase<S3, Ix2>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), (k2, n)) = (a.dim(), b.dim());
    let (m2, n2) = c.dim();
    if k != k2 || m != m2 || n != n2 {
        general_dot_shape_error(m, k, k2, n, m2, n2);
    } else {
        mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
    }
}

/// General matrix-vector multiplication.
///
/// Compute y ← α A x + β y
///
/// where A is an *M* × *N* matrix, x is an *N*-element column vector, and
/// y is an *M*-element column vector (one-dimensional arrays).
///
/// ***Panics*** if array shapes are not compatible<br>
/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
/// layout allows.
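///
/// A minimal example (illustrative values):
///
/// ```
/// use ndarray::{arr1, arr2, Array1};
/// use ndarray::linalg::general_mat_vec_mul;
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// let x = arr1(&[1., 1.]);
/// let mut y = Array1::zeros(2);
/// general_mat_vec_mul(1., &a, &x, 0., &mut y); // y ← 1 · A x + 0 · y
/// assert_eq!(y, arr1(&[3., 7.]));
/// ```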
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A, S1, S2, S3>(
    alpha: A,
    a: &ArrayBase<S1, Ix2>,
    x: &ArrayBase<S2, Ix1>,
    beta: A,
    y: &mut ArrayBase<S3, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    S3: DataMut<Elem = A>,
    A: LinalgScalar,
{
    unsafe {
        general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
    }
}

/// General matrix-vector multiplication
///
/// Use a raw view for the destination vector, so that it can be uninitialized.
///
/// ## Safety
///
/// The caller must ensure that the raw view is valid for writing.
/// The destination may be uninitialized iff beta is zero.
#[allow(clippy::collapsible_else_if)]
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
    alpha: A,
    a: &ArrayBase<S1, Ix2>,
    x: &ArrayBase<S2, Ix1>,
    beta: A,
    y: RawArrayViewMut<A, Ix1>,
) where
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    A: LinalgScalar,
{
    let ((m, k), k2) = (a.dim(), x.dim());
    let m2 = y.dim();
    if k != k2 || m != m2 {
        general_dot_shape_error(m, k, k2, 1, m2, 1);
    } else {
        #[cfg(feature = "blas")]
        macro_rules! gemv {
            ($ty:ty, $gemv:ident) => {
                if let Some(layout) = blas_layout::<$ty, _>(&a) {
                    if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
                        // Determine stride between rows or columns. Note that the stride is
                        // adjusted to at least `k` or `m` to handle the case of a matrix with a
                        // trivial (length 1) dimension, since the stride for the trivial dimension
                        // may be arbitrary.
                        let a_trans = CblasNoTrans;
                        let a_stride = match layout {
                            CBLAS_LAYOUT::CblasRowMajor => {
                                a.strides()[0].max(k as isize) as blas_index
                            }
                            CBLAS_LAYOUT::CblasColMajor => {
                                a.strides()[1].max(m as isize) as blas_index
                            }
                        };

                        let x_stride = x.strides()[0] as blas_index;
                        let y_stride = y.strides()[0] as blas_index;

                        blas_sys::$gemv(
                            layout,
                            a_trans,
                            m as blas_index,            // m, rows of Op(a)
                            k as blas_index,            // n, cols of Op(a)
                            cast_as(&alpha),            // alpha
                            a.ptr.as_ptr() as *const _, // a
                            a_stride,                   // lda
                            x.ptr.as_ptr() as *const _, // x
                            x_stride,
                            cast_as(&beta),             // beta
                            y.ptr.as_ptr() as *mut _,   // y
                            y_stride,
                        );
                        return;
                    }
                }
            };
        }
        #[cfg(feature = "blas")]
        gemv!(f32, cblas_sgemv);
        #[cfg(feature = "blas")]
        gemv!(f64, cblas_dgemv);

        /* general */

        if beta.is_zero() {
            // when beta is zero, c may be uninitialized
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                elt.write(row.dot(x) * alpha);
            });
        } else {
            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
                *elt = *elt * beta + row.dot(x) * alpha;
            });
        }
    }
}

#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

// Read pointer to type `A` as type `B`.
//
// **Panics** if `A` and `B` are not the same type
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
    assert!(same_type::<A, B>());
    unsafe { ::std::ptr::read(a as *const _ as *const B) }
}

#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    if a.len() > blas_index::max_value() as usize {
        return false;
    }
    let stride = a.strides()[0];
    if stride > blas_index::max_value() as isize || stride < blas_index::min_value() as isize {
        return false;
    }
    true
}

#[cfg(feature = "blas")]
enum MemoryOrder {
    C,
    F,
}

#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
}

#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}

#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool {
    let (m, n) = dim.into_pattern();
    let s0 = stride[0] as isize;
    let s1 = stride[1] as isize;
    let (inner_stride, outer_dim) = match order {
        MemoryOrder::C => (s1, n),
        MemoryOrder::F => (s0, m),
    };
    if !(inner_stride == 1 || outer_dim == 1) {
        return false;
    }
    if s0 < 1 || s1 < 1 {
        return false;
    }
    if (s0 > blas_index::max_value() as isize || s0 < blas_index::min_value() as isize)
        || (s1 > blas_index::max_value() as isize || s1 < blas_index::min_value() as isize)
    {
        return false;
    }
    if m > blas_index::max_value() as usize || n > blas_index::max_value() as usize {
        return false;
    }
    true
}

#[cfg(feature = "blas")]
fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if blas_row_major_2d::<A, _>(a) {
        Some(CBLAS_LAYOUT::CblasRowMajor)
    } else if blas_column_major_2d::<A, _>(a) {
        Some(CBLAS_LAYOUT::CblasColMajor)
    } else {
        None
    }
}

#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests {
    use super::*;

    #[test]
    fn blas_row_major_2d_normal_matrix() {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(!blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_row_matrix() {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_column_matrix() {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert!(blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix() {
        let m: Array2<f32> = Array2::zeros((1, 5));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix() {
        let m: Array2<f32> = Array2::zeros((5, 1));
        let m_t = m.t();
        assert!(blas_row_major_2d::<f32, _>(&m_t));
        assert!(blas_column_major_2d::<f32, _>(&m_t));
    }

    #[test]
    fn blas_column_major_2d_normal_matrix() {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert!(!blas_row_major_2d::<f32, _>(&m));
        assert!(blas_column_major_2d::<f32, _>(&m));
    }
}