// Clippy lints relaxed for this test file (e.g. short math-style names such
// as `m`, `n`, `k` and long numeric literals such as `197846.`).
#![allow(
    clippy::many_single_char_names,
    clippy::deref_addrof,
    clippy::unreadable_literal
)]
// These tests are only compiled when the `std` feature is enabled.
#![cfg(feature = "std")]
8 use ndarray::linalg::general_mat_mul;
9 use ndarray::prelude::*;
10 use ndarray::{rcarr1, rcarr2};
11 use ndarray::{Data, LinalgScalar};
12 use ndarray::{Ix, Ixs};
13 use num_traits::Zero;
14 
15 use approx::assert_abs_diff_eq;
16 use defmac::defmac;
17 
test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32])18 fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) {
19     let aa = rcarr1(a);
20     let bb = rcarr1(b);
21     let cc = rcarr1(c);
22     test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
23     let dim = (2, 2);
24     let aa = aa.reshape(dim);
25     let bb = bb.reshape(dim);
26     let cc = cc.reshape(dim);
27     test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
28     let dim = (1, 2, 1, 2);
29     let aa = aa.reshape(dim);
30     let bb = bb.reshape(dim);
31     let cc = cc.reshape(dim);
32     test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
33 }
34 
35 
test_oper_arr<A, D>(op: &str, mut aa: ArcArray<f32, D>, bb: ArcArray<f32, D>, cc: ArcArray<f32, D>) where D: Dimension,36 fn test_oper_arr<A, D>(op: &str, mut aa: ArcArray<f32, D>, bb: ArcArray<f32, D>, cc: ArcArray<f32, D>)
37 where
38     D: Dimension,
39 {
40     match op {
41         "+" => {
42             assert_eq!(&aa + &bb, cc);
43             aa += &bb;
44             assert_eq!(aa, cc);
45         }
46         "-" => {
47             assert_eq!(&aa - &bb, cc);
48             aa -= &bb;
49             assert_eq!(aa, cc);
50         }
51         "*" => {
52             assert_eq!(&aa * &bb, cc);
53             aa *= &bb;
54             assert_eq!(aa, cc);
55         }
56         "/" => {
57             assert_eq!(&aa / &bb, cc);
58             aa /= &bb;
59             assert_eq!(aa, cc);
60         }
61         "%" => {
62             assert_eq!(&aa % &bb, cc);
63             aa %= &bb;
64             assert_eq!(aa, cc);
65         }
66         "neg" => {
67             assert_eq!(-&aa, cc);
68             assert_eq!(-aa.clone(), cc);
69         }
70         _ => panic!(),
71     }
72 }
73 
// Exercises every operator supported by `test_oper` on small fixed inputs.
#[test]
fn operations() {
    // Each row is (operator, left operand, right operand, expected result).
    let cases: [(&str, [f32; 4], [f32; 4], [f32; 4]); 6] = [
        ("+", [1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]),
        ("-", [1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0], [1.0, 1.0, 1.0, 1.0]),
        ("*", [1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0], [0.0, 2.0, 6.0, 12.0]),
        ("/", [1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 2.0, 3.0], [1.0, 2.0, 3.0 / 2.0, 4.0 / 3.0]),
        ("%", [1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 2.0, 3.0], [0.0, 0.0, 1.0, 1.0]),
        ("neg", [1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 2.0, 3.0], [-1.0, -2.0, -3.0, -4.0]),
    ];

    for (op, lhs, rhs, expected) in cases.iter() {
        test_oper(*op, &lhs[..], &rhs[..], &expected[..]);
    }
}
113 
// Adding a scalar in place must agree with adding a zero-dimensional array
// and with an array filled with the expected value, for 0-D, 1-D and 2-D
// arrays.
#[test]
fn scalar_operations() {
    let zero_d = arr0::<f32>(1.);
    let one_d = rcarr1::<f32>(&[1., 1.]);
    let two_d = rcarr2(&[[1., 1.], [1., 1.]]);

    // Zero-dimensional owned array.
    {
        let mut incremented = zero_d.clone();
        let mut expected = arr0(0.);
        incremented += 1.;
        expected.fill(2.);
        assert_eq!(incremented, zero_d + arr0(1.));
        assert_eq!(incremented, expected);
    }

    // One-dimensional shared array.
    {
        let mut incremented = one_d.clone();
        let mut expected = rcarr1(&[0., 0.]);
        incremented += 1.;
        expected.fill(2.);
        assert_eq!(incremented, one_d + arr0(1.));
        assert_eq!(incremented, expected);
    }

    // Two-dimensional shared array.
    {
        let mut incremented = two_d.clone();
        let mut expected = ArcArray::zeros((2, 2));
        incremented += 1.;
        expected.fill(2.);
        assert_eq!(incremented, two_d + arr0(1.));
        assert_eq!(incremented, expected);
    }
}
147 
reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32 where V1: AsArray<'a, f32>, V2: AsArray<'a, f32>,148 fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32
149 where
150     V1: AsArray<'a, f32>,
151     V2: AsArray<'a, f32>,
152 {
153     let a = a.into();
154     let b = b.into();
155     a.iter()
156         .zip(b.iter())
157         .fold(f32::zero(), |acc, (&x, &y)| acc + x * y)
158 }
159 
// `dot` on 1-D arrays must agree with `reference_dot`, with a precomputed
// value, and must keep agreeing when the operands start at shifted offsets.
#[test]
fn dot_product() {
    // Compares `dot` against the reference on sub-slices of `a` and `b` whose
    // start (or end) is shifted by 1 through 7 elements.
    fn check_alignments(a: &Array1<f32>, b: &Array1<f32>) {
        let max: Ixs = 8;
        for i in 1..max {
            let a1 = a.slice(s![i..]);
            let b1 = b.slice(s![i..]);
            assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5);
            let a2 = a.slice(s![..-i]);
            let b2 = b.slice(s![i..]);
            assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5);
        }
    }

    let a = Array::range(0., 69., 1.);
    let b = &a * 2. - 7.;
    let dot = 197846.;
    assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);

    // test different alignments
    check_alignments(&a, &b);

    let a = a.map(|f| *f as f32);
    let b = b.map(|f| *f as f32);
    assert_abs_diff_eq!(a.dot(&b), dot as f32, epsilon = 1e-5);

    check_alignments(&a, &b);

    // Integer elements: the result must match the precomputed value exactly.
    let a = a.map(|f| *f as i32);
    let b = b.map(|f| *f as i32);
    assert_eq!(a.dot(&b), dot as i32);
}
196 
// test that we can dot product with a broadcast array
#[test]
fn dot_product_0() {
    let values = Array::range(0., 69., 1.);
    let scalar = 1.5;
    let scalar_view = aview0(&scalar);
    // Every element of `repeated` refers to the single value `scalar`.
    let repeated = scalar_view.broadcast(values.dim()).unwrap();
    assert_abs_diff_eq!(
        values.dot(&repeated),
        reference_dot(&values, &repeated),
        epsilon = 1e-5
    );

    // test different alignments
    let max: Ixs = 8;
    for i in 1..max {
        let v1 = values.slice(s![i..]);
        let r1 = repeated.slice(s![i..]);
        assert_abs_diff_eq!(v1.dot(&r1), reference_dot(&v1, &r1), epsilon = 1e-5);
        let v2 = values.slice(s![..-i]);
        let r2 = repeated.slice(s![i..]);
        assert_abs_diff_eq!(v2.dot(&r2), reference_dot(&v2, &r2), epsilon = 1e-5);
    }
}
217 
// test that we can dot with negative stride
#[test]
fn dot_product_neg_stride() {
    let a = Array::range(0., 69., 1.);
    let b = &a * 2. - 7.;

    // both negative
    for stride in -10..0 {
        let a_rev = a.slice(s![..;stride]);
        let b_rev = b.slice(s![..;stride]);
        assert_abs_diff_eq!(a_rev.dot(&b_rev), reference_dot(&a_rev, &b_rev), epsilon = 1e-5);
    }

    // mixed: `a` is stepped forwards, `b` backwards
    for stride in -10..0 {
        let a_fwd = a.slice(s![..;-stride]);
        let b_rev = b.slice(s![..;stride]);
        assert_abs_diff_eq!(a_fwd.dot(&b_rev), reference_dot(&a_fwd, &b_rev), epsilon = 1e-5);
    }
}
236 
// `fold` and `sum` must agree with a plain element-by-element loop, on the
// full array, on strided slices, and on partially consumed iterators.
#[test]
fn fold_and_sum() {
    let max: Ixs = 8;
    let a = Array::linspace(0., 127., 128).into_shape((8, 16)).unwrap();
    assert_abs_diff_eq!(a.fold(0., |acc, &x| acc + x), a.sum(), epsilon = 1e-5);

    // test different strides
    for i in 1..max {
        for j in 1..max {
            let strided = a.slice(s![..;i, ..;j]);
            // Reference value accumulated one `next()` at a time.
            let mut expected = 0.;
            for &elt in strided.iter() {
                expected += elt;
            }
            assert_abs_diff_eq!(strided.fold(0., |acc, &x| acc + x), expected, epsilon = 1e-5);
            assert_abs_diff_eq!(expected, strided.sum(), epsilon = 1e-5);
        }
    }

    // skip a few elements
    for i in 1..max {
        for skip in 1..max {
            let strided = a.slice(s![.., ..;i]);
            let mut remaining = strided.iter();
            for _ in 0..skip {
                remaining.next();
            }
            // Clone after skipping so both iterators start at the same element.
            let remaining_for_fold = remaining.clone();

            let mut expected = 0.;
            for &elt in remaining {
                expected += elt;
            }
            assert_abs_diff_eq!(
                remaining_for_fold.fold(0., |acc, &x| acc + x),
                expected,
                epsilon = 1e-5
            );
        }
    }
}
275 
// `product` must agree with `fold` and with a plain element-by-element loop,
// on the full array and on strided slices.
#[test]
fn product() {
    let max: Ixs = 8;
    let a = Array::linspace(0.5, 2., 128).into_shape((8, 16)).unwrap();
    assert_abs_diff_eq!(a.fold(1., |acc, &x| acc * x), a.product(), epsilon = 1e-5);

    // test different strides
    for i in 1..max {
        for j in 1..max {
            let strided = a.slice(s![..;i, ..;j]);
            // Reference value accumulated one `next()` at a time.
            let mut expected = 1.;
            for &elt in strided.iter() {
                expected *= elt;
            }
            assert_abs_diff_eq!(strided.fold(1., |acc, &x| acc * x), expected, epsilon = 1e-5);
            assert_abs_diff_eq!(expected, strided.product(), epsilon = 1e-5);
        }
    }
}
295 
range_mat(m: Ix, n: Ix) -> Array2<f32>296 fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
297     Array::linspace(0., (m * n) as f32 - 1., m * n)
298         .into_shape((m, n))
299         .unwrap()
300 }
301 
range_mat64(m: Ix, n: Ix) -> Array2<f64>302 fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
303     Array::linspace(0., (m * n) as f64 - 1., m * n)
304         .into_shape((m, n))
305         .unwrap()
306 }
307 
// Builds a length-`m` `f64` vector of evenly spaced values running from 0 to
// `m - 1`.
#[cfg(feature = "approx")]
fn range1_mat64(m: Ix) -> Array1<f64> {
    let last = m as f64 - 1.;
    Array::linspace(0., last, m)
}
312 
range_i32(m: Ix, n: Ix) -> Array2<i32>313 fn range_i32(m: Ix, n: Ix) -> Array2<i32> {
314     Array::from_iter(0..(m * n) as i32)
315         .into_shape((m, n))
316         .unwrap()
317 }
318 
319 // simple, slow, correct (hopefully) mat mul
reference_mat_mul<A, S, S2>(lhs: &ArrayBase<S, Ix2>, rhs: &ArrayBase<S2, Ix2>) -> Array2<A> where A: LinalgScalar, S: Data<Elem = A>, S2: Data<Elem = A>,320 fn reference_mat_mul<A, S, S2>(lhs: &ArrayBase<S, Ix2>, rhs: &ArrayBase<S2, Ix2>) -> Array2<A>
321 where
322     A: LinalgScalar,
323     S: Data<Elem = A>,
324     S2: Data<Elem = A>,
325 {
326     let ((m, k), (k2, n)) = (lhs.dim(), rhs.dim());
327     assert!(m.checked_mul(n).is_some());
328     assert_eq!(k, k2);
329     let mut res_elems = Vec::<A>::with_capacity(m * n);
330     unsafe {
331         res_elems.set_len(m * n);
332     }
333 
334     let mut i = 0;
335     let mut j = 0;
336     for rr in &mut res_elems {
337         unsafe {
338             *rr = (0..k).fold(A::zero(), move |s, x| {
339                 s + *lhs.uget((i, x)) * *rhs.uget((x, j))
340             });
341         }
342         j += 1;
343         if j == n {
344             j = 0;
345             i += 1;
346         }
347     }
348     unsafe { ArrayBase::from_shape_vec_unchecked((m, n), res_elems) }
349 }
350 
// The matrix product must be identical whichever mix of row-major and
// column-major operands is used, for several shapes.
#[test]
fn mat_mul() {
    // (m, n, k): `a` is m x n and `b` is n x k.
    for &(m, n, k) in &[(8, 8, 8), (10, 5, 11), (10, 8, 1)] {
        let a = range_mat(m, n);
        let mut b = range_mat(n, k) / 4.;
        {
            let mut first_col = b.column_mut(0);
            first_col += 1.0;
        }
        let ab = a.dot(&b);

        // Copies of `a` and `b` stored in column-major (Fortran) order.
        let mut af = Array::zeros(a.dim().f());
        let mut bf = Array::zeros(b.dim().f());
        af.assign(&a);
        bf.assign(&b);

        assert_eq!(ab, a.dot(&bf));
        assert_eq!(ab, af.dot(&b));
        assert_eq!(ab, af.dot(&bf));
    }
}
410 
// Check that matrix multiplication of contiguous matrices returns a
// matrix with the same order
#[test]
fn mat_mul_order() {
    let (m, n, k) = (8, 8, 8);
    let row_major_a = range_mat(m, n);
    let row_major_b = range_mat(n, k);

    // Column-major copies of the same data.
    let mut col_major_a = Array::zeros(row_major_a.dim().f());
    let mut col_major_b = Array::zeros(row_major_b.dim().f());
    col_major_a.assign(&row_major_a);
    col_major_b.assign(&row_major_b);

    let cc = row_major_a.dot(&row_major_b);
    let ff = col_major_a.dot(&col_major_b);

    // Unit stride along the last axis for the row-major product, and along
    // the first axis for the column-major product.
    assert_eq!(cc.strides()[1], 1);
    assert_eq!(ff.strides()[0], 1);
}
429 
// test matrix multiplication shape mismatch: the inner dimensions of the
// operands differ (8 vs 9), so `dot` must panic
#[test]
#[should_panic]
fn mat_mul_shape_mismatch() {
    let (rows, inner_lhs, inner_rhs, cols) = (8, 8, 9, 8);
    let lhs = range_mat(rows, inner_lhs);
    let rhs = range_mat(inner_rhs, cols);
    lhs.dot(&rhs);
}
439 
// test matrix multiplication shape mismatch: the operands are compatible but
// the output array has one column too many, so `general_mat_mul` must panic
#[test]
#[should_panic]
fn mat_mul_shape_mismatch_2() {
    let (rows, inner, cols) = (8, 8, 8);
    let lhs = range_mat(rows, inner);
    let rhs = range_mat(inner, cols);
    let mut out = range_mat(rows, cols + 1);
    general_mat_mul(1., &lhs, &rhs, 1., &mut out);
}
450 
// Check that matrix multiplication
// supports broadcast arrays.
#[test]
fn mat_mul_broadcast() {
    let (m, n, k) = (16, 16, 16);
    let a = range_mat(m, n);
    let fill = 1.;

    // Right-hand side broadcast from a single element.
    let single = Array::from(vec![fill]);
    let b0 = single.broadcast((n, k)).unwrap();
    // Right-hand side broadcast from a one-dimensional array.
    let line = Array::from_elem(n, fill);
    let b1 = line.broadcast((n, k)).unwrap();
    // Right-hand side fully materialized.
    let b2 = Array::from_elem((n, k), fill);

    let c2 = a.dot(&b2);
    let c1 = a.dot(&b1);
    let c0 = a.dot(&b0);
    assert_eq!(c2, c1);
    assert_eq!(c2, c0);
}
470 
// Check that matrix multiplication supports reversed axes
#[test]
fn mat_mul_rev() {
    let (m, n, k) = (16, 16, 16);
    let a = range_mat(m, n);
    let b = range_mat(n, k);

    // `rev` holds the same values as `b`, viewed through a slice that walks
    // the first axis of its backing storage backwards.
    let mut storage = Array::zeros(b.dim());
    let mut rev = storage.slice_mut(s![..;-1, ..]);
    rev.assign(&b);
    println!("{:.?}", rev);

    let expected = a.dot(&b);
    let actual = a.dot(&rev);
    assert_eq!(expected, actual);
}
486 
// Check that matrix multiplication supports arrays with zero rows or columns
#[test]
fn mat_mut_zero_len() {
    // `make_mat` is one of the `range_*` constructors; the macro repeats the
    // same checks for each element type.
    defmac!(check_zero_len make_mat => {
        for inner in 0..4 {
            // Right-hand side with zero columns.
            for rows in 0..4 {
                let lhs = make_mat(rows, inner);
                let rhs = make_mat(inner, 0);
                assert_eq!(lhs.dot(&rhs), Array2::zeros((rows, 0)));
            }
            // Left-hand side with zero rows.
            for cols in 0..4 {
                let lhs = make_mat(0, inner);
                let rhs = make_mat(inner, cols);
                assert_eq!(lhs.dot(&rhs), Array2::zeros((0, cols)));
            }
        }
    });
    check_zero_len!(range_mat);
    check_zero_len!(range_mat64);
    check_zero_len!(range_i32);
}
508 
// `scaled_add(alpha, &b)` must equal `alpha * &b + &a` computed with operators.
#[test]
fn scaled_add() {
    let base = range_mat(16, 15);
    let mut addend = range_mat(16, 15);
    addend.mapv_inplace(f32::exp);

    let alpha = 0.2_f32;
    let mut in_place = base.clone();
    in_place.scaled_add(alpha, &addend);

    let with_operators = alpha * &addend + &base;
    assert_eq!(in_place, with_operators);
}
522 
// `scaled_add` on strided views, with a right-hand side that either matches
// the target shape or has a single row, must agree with `+=` of the scaled
// right-hand side.
#[cfg(feature = "approx")]
#[test]
fn scaled_add_2() {
    let beta = -2.3;
    // (m, k, n, q): the target is m x k and the right-hand side is n x q.
    let sizes = [
        (4, 4, 1, 4),
        (8, 8, 1, 8),
        (17, 15, 17, 15),
        (4, 17, 4, 17),
        (17, 3, 1, 3),
        (19, 18, 19, 18),
        (16, 17, 16, 17),
        (15, 16, 15, 16),
        (67, 63, 1, 63),
    ];
    let steps = [1, 2, -1, -2];
    // test different strides
    for &s1 in &steps {
        for &s2 in &steps {
            for &(m, k, n, q) in &sizes {
                let mut target = range_mat64(m, k);
                let mut expected = target.clone();
                let addend = range_mat64(n, q);

                {
                    let mut target_view = target.slice_mut(s![..;s1, ..;s2]);
                    let addend_view = addend.slice(s![..;s1, ..;s2]);

                    let mut expected_view = expected.slice_mut(s![..;s1, ..;s2]);
                    expected_view += &(beta * &addend_view);
                    target_view.scaled_add(beta, &addend_view);
                }
                approx::assert_relative_eq!(
                    target,
                    expected,
                    epsilon = 1e-12,
                    max_relative = 1e-7
                );
            }
        }
    }
}
559 
// Like `scaled_add_2`, but the right-hand side is a dynamic-dimensional
// array: one-dimensional when `n == 1`, two-dimensional otherwise.
#[cfg(feature = "approx")]
#[test]
fn scaled_add_3() {
    use approx::assert_relative_eq;
    use ndarray::{Slice, SliceInfo, SliceInfoElem};
    use std::convert::TryFrom;

    let beta = -2.3;
    // (m, k, n, q): the target is m x k and the right-hand side is n x q.
    let sizes = [
        (4, 4, 1, 4),
        (8, 8, 1, 8),
        (17, 15, 17, 15),
        (4, 17, 4, 17),
        (17, 3, 1, 3),
        (19, 18, 19, 18),
        (16, 17, 16, 17),
        (15, 16, 15, 16),
        (67, 63, 1, 63),
    ];
    let steps = [1, 2, -1, -2];
    // test different strides
    for &s1 in &steps {
        for &s2 in &steps {
            for &(m, k, n, q) in &sizes {
                let mut target = range_mat64(m, k);
                let mut expected = target.clone();

                // Shape and slicing of the right-hand side: a single row is
                // collapsed to one dimension.
                let (addend_dim, addend_slice): (Vec<Ix>, Vec<SliceInfoElem>) = if n == 1 {
                    (vec![q], vec![Slice::from(..).step_by(s2).into()])
                } else {
                    (
                        vec![n, q],
                        vec![
                            Slice::from(..).step_by(s1).into(),
                            Slice::from(..).step_by(s2).into(),
                        ],
                    )
                };

                let addend = range_mat64(n, q).into_shape(addend_dim).unwrap();

                {
                    let mut target_view = target.slice_mut(s![..;s1, ..;s2]);
                    let addend_view = addend
                        .slice(SliceInfo::<_, IxDyn, IxDyn>::try_from(addend_slice).unwrap());

                    let mut expected_view = expected.slice_mut(s![..;s1, ..;s2]);
                    expected_view += &(beta * &addend_view);
                    target_view.scaled_add(beta, &addend_view);
                }
                assert_relative_eq!(target, expected, epsilon = 1e-12, max_relative = 1e-7);
            }
        }
    }
}
610 
// `general_mat_mul` (c = alpha * a * b + beta * c) on strided views must agree
// with the same expression built from `reference_mat_mul`.
#[cfg(feature = "approx")]
#[test]
fn gen_mat_mul() {
    let alpha = -2.3;
    let beta = 3.14;
    // (m, k, n): `a` is m x k, `b` is k x n and `c` is m x n before slicing.
    let sizes = [
        (4, 4, 4),
        (8, 8, 8),
        (17, 15, 16),
        (4, 17, 3),
        (17, 3, 22),
        (19, 18, 2),
        (16, 17, 15),
        (15, 16, 17),
        (67, 63, 62),
    ];
    let steps = [1, 2, -1, -2];
    // test different strides
    for &s1 in &steps {
        for &s2 in &steps {
            for &(m, k, n) in &sizes {
                let lhs = range_mat64(m, k);
                let rhs = range_mat64(k, n);
                let mut out = range_mat64(m, n);
                let mut expected = out.clone();

                {
                    let lhs_view = lhs.slice(s![..;s1, ..;s2]);
                    // Both axes of `rhs` use `s2`: its rows pair with the
                    // columns of `lhs_view` and its columns with `out_view`.
                    let rhs_view = rhs.slice(s![..;s2, ..;s2]);
                    let mut out_view = out.slice_mut(s![..;s1, ..;s2]);

                    let expected_part =
                        alpha * reference_mat_mul(&lhs_view, &rhs_view) + beta * &out_view;
                    expected.slice_mut(s![..;s1, ..;s2]).assign(&expected_part);

                    general_mat_mul(alpha, &lhs_view, &rhs_view, beta, &mut out_view);
                }
                approx::assert_relative_eq!(out, expected, epsilon = 1e-12, max_relative = 1e-7);
            }
        }
    }
}
651 
// Test y = A x where A is f-order
#[cfg(feature = "approx")]
#[test]
fn gemm_64_1_f() {
    let matrix = range_mat64(64, 64).reversed_axes();
    let (m, n) = matrix.dim();
    // m x n  times n x 1  == m x 1
    let x = range_mat64(n, 1);
    let mut y = range_mat64(m, 1);
    // Expected value of `1.0 * matrix * x + 1.0 * y`, computed before `y` is
    // overwritten.
    let expected = reference_mat_mul(&matrix, &x) + &y;
    general_mat_mul(1.0, &matrix, &x, 1.0, &mut y);
    approx::assert_relative_eq!(y, expected, epsilon = 1e-12, max_relative = 1e-7);
}
665 
// `general_mat_mul` on `i32` matrices must agree exactly with the same
// expression built from `reference_mat_mul`.
#[test]
fn gen_mat_mul_i32() {
    let alpha = -1;
    let beta = 2;
    // Use only two small shapes when running under Miri.
    let sizes = if cfg!(miri) {
        vec![(4, 4, 4), (4, 7, 3)]
    } else {
        vec![
            (4, 4, 4),
            (8, 8, 8),
            (17, 15, 16),
            (4, 17, 3),
            (17, 3, 22),
            (19, 18, 2),
            (16, 17, 15),
            (15, 16, 17),
            (67, 63, 62),
        ]
    };
    for &(m, k, n) in &sizes {
        let lhs = range_i32(m, k);
        let rhs = range_i32(k, n);
        let mut out = range_i32(m, n);

        // Computed before `out` is overwritten by `general_mat_mul`.
        let expected = alpha * reference_mat_mul(&lhs, &rhs) + beta * &out;
        general_mat_mul(alpha, &lhs, &rhs, beta, &mut out);
        assert_eq!(&out, &expected);
    }
}
695 
// `general_mat_vec_mul` (y = alpha * A * x + beta * y) on strided views, with
// `A` in either axis order, must agree with the reference implementation.
#[cfg(feature = "approx")]
#[test]
fn gen_mat_vec_mul() {
    use approx::assert_relative_eq;
    use ndarray::linalg::general_mat_vec_mul;

    // simple, slow, correct (hopefully) mat mul
    //
    // Treats `rhs` as a k x 1 column, multiplies with `reference_mat_mul` and
    // flattens the m x 1 result back to one dimension.
    fn reference_mat_vec_mul<A, S, S2>(
        lhs: &ArrayBase<S, Ix2>,
        rhs: &ArrayBase<S2, Ix1>,
    ) -> Array1<A>
    where
        A: LinalgScalar,
        S: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        let (m, _) = lhs.dim();
        let k = rhs.dim();
        let column = rhs.as_standard_layout().into_shape((k, 1)).unwrap();
        reference_mat_mul(lhs, &column).into_shape(m).unwrap()
    }

    let alpha = -2.3;
    let beta = 3.14;
    // (m, k): shape of the matrix before the optional axis reversal.
    let sizes = [
        (4, 4),
        (8, 8),
        (17, 15),
        (4, 17),
        (17, 3),
        (19, 18),
        (16, 17),
        (15, 16),
        (67, 63),
    ];
    let steps = [1, 2, -1, -2];
    // test different strides
    for &s1 in &steps {
        for &s2 in &steps {
            for &(m, k) in &sizes {
                for &rev in &[false, true] {
                    let mut matrix = range_mat64(m, k);
                    if rev {
                        matrix = matrix.reversed_axes();
                    }
                    // Re-read the shape: it is swapped when `rev` is set.
                    let (m, k) = matrix.dim();
                    let x = range1_mat64(k);
                    let mut y = range1_mat64(m);
                    let mut expected = y.clone();

                    {
                        let matrix_view = matrix.slice(s![..;s1, ..;s2]);
                        let x_view = x.slice(s![..;s2]);
                        let mut y_view = y.slice_mut(s![..;s1]);

                        let expected_part =
                            alpha * reference_mat_vec_mul(&matrix_view, &x_view) + beta * &y_view;
                        expected.slice_mut(s![..;s1]).assign(&expected_part);

                        general_mat_vec_mul(alpha, &matrix_view, &x_view, beta, &mut y_view);
                    }
                    assert_relative_eq!(y, expected, epsilon = 1e-12, max_relative = 1e-7);
                }
            }
        }
    }
}
761 
// Vector-times-matrix `dot` on strided views, with the matrix in either axis
// order, must agree with the reference implementation.
#[cfg(feature = "approx")]
#[test]
fn vec_mat_mul() {
    use approx::assert_relative_eq;

    // simple, slow, correct (hopefully) mat mul
    //
    // Treats `lhs` as a 1 x m row, multiplies with `reference_mat_mul` and
    // flattens the 1 x n result back to one dimension.
    fn reference_vec_mat_mul<A, S, S2>(
        lhs: &ArrayBase<S, Ix1>,
        rhs: &ArrayBase<S2, Ix2>,
    ) -> Array1<A>
    where
        A: LinalgScalar,
        S: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        let m = lhs.dim();
        let (_, n) = rhs.dim();
        let row = lhs.as_standard_layout().into_shape((1, m)).unwrap();
        reference_mat_mul(&row, rhs).into_shape(n).unwrap()
    }

    // (m, n): shape of the matrix before the optional axis reversal.
    let sizes = [
        (4, 4),
        (8, 8),
        (17, 15),
        (4, 17),
        (17, 3),
        (19, 18),
        (16, 17),
        (15, 16),
        (67, 63),
    ];
    let steps = [1, 2, -1, -2];
    // test different strides
    for &s1 in &steps {
        for &s2 in &steps {
            for &(m, n) in &sizes {
                for &rev in &[false, true] {
                    let mut matrix = range_mat64(m, n);
                    if rev {
                        matrix = matrix.reversed_axes();
                    }
                    // Re-read the shape: it is swapped when `rev` is set.
                    let (m, n) = matrix.dim();
                    let vector = range1_mat64(m);
                    let mut out = range1_mat64(n);
                    let mut expected = out.clone();

                    {
                        let matrix_view = matrix.slice(s![..;s1, ..;s2]);
                        let vector_view = vector.slice(s![..;s1]);

                        let expected_part = reference_vec_mat_mul(&vector_view, &matrix_view);
                        expected.slice_mut(s![..;s2]).assign(&expected_part);

                        out.slice_mut(s![..;s2])
                            .assign(&vector_view.dot(&matrix_view));
                    }
                    assert_relative_eq!(out, expected, epsilon = 1e-12, max_relative = 1e-7);
                }
            }
        }
    }
}
823