1 #![allow(
2 clippy::many_single_char_names,
3 clippy::deref_addrof,
4 clippy::unreadable_literal,
5 clippy::many_single_char_names
6 )]
7 #![cfg(feature = "std")]
8 use ndarray::linalg::general_mat_mul;
9 use ndarray::prelude::*;
10 use ndarray::{rcarr1, rcarr2};
11 use ndarray::{Data, LinalgScalar};
12 use ndarray::{Ix, Ixs};
13 use num_traits::Zero;
14
15 use approx::assert_abs_diff_eq;
16 use defmac::defmac;
17
test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32])18 fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) {
19 let aa = rcarr1(a);
20 let bb = rcarr1(b);
21 let cc = rcarr1(c);
22 test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
23 let dim = (2, 2);
24 let aa = aa.reshape(dim);
25 let bb = bb.reshape(dim);
26 let cc = cc.reshape(dim);
27 test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
28 let dim = (1, 2, 1, 2);
29 let aa = aa.reshape(dim);
30 let bb = bb.reshape(dim);
31 let cc = cc.reshape(dim);
32 test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
33 }
34
35
test_oper_arr<A, D>(op: &str, mut aa: ArcArray<f32, D>, bb: ArcArray<f32, D>, cc: ArcArray<f32, D>) where D: Dimension,36 fn test_oper_arr<A, D>(op: &str, mut aa: ArcArray<f32, D>, bb: ArcArray<f32, D>, cc: ArcArray<f32, D>)
37 where
38 D: Dimension,
39 {
40 match op {
41 "+" => {
42 assert_eq!(&aa + &bb, cc);
43 aa += &bb;
44 assert_eq!(aa, cc);
45 }
46 "-" => {
47 assert_eq!(&aa - &bb, cc);
48 aa -= &bb;
49 assert_eq!(aa, cc);
50 }
51 "*" => {
52 assert_eq!(&aa * &bb, cc);
53 aa *= &bb;
54 assert_eq!(aa, cc);
55 }
56 "/" => {
57 assert_eq!(&aa / &bb, cc);
58 aa /= &bb;
59 assert_eq!(aa, cc);
60 }
61 "%" => {
62 assert_eq!(&aa % &bb, cc);
63 aa %= &bb;
64 assert_eq!(aa, cc);
65 }
66 "neg" => {
67 assert_eq!(-&aa, cc);
68 assert_eq!(-aa.clone(), cc);
69 }
70 _ => panic!(),
71 }
72 }
73
/// Exercises every operator understood by `test_oper` on fixed four-element
/// operands.
#[test]
fn operations() {
    let lhs: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    // Right-hand side used for "+", "-" and "*".
    let rhs_from_zero: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    // Right-hand side used for "/", "%" and "neg"; it contains no zero.
    let rhs_nonzero: [f32; 4] = [1.0, 1.0, 2.0, 3.0];

    // (operator, right-hand side, expected result)
    let cases: [(&str, &[f32; 4], [f32; 4]); 6] = [
        ("+", &rhs_from_zero, [1.0, 3.0, 5.0, 7.0]),
        ("-", &rhs_from_zero, [1.0, 1.0, 1.0, 1.0]),
        ("*", &rhs_from_zero, [0.0, 2.0, 6.0, 12.0]),
        ("/", &rhs_nonzero, [1.0, 2.0, 3.0 / 2.0, 4.0 / 3.0]),
        ("%", &rhs_nonzero, [0.0, 0.0, 1.0, 1.0]),
        ("neg", &rhs_nonzero, [-1.0, -2.0, -3.0, -4.0]),
    ];

    for (op, rhs, expected) in cases.iter() {
        test_oper(*op, &lhs, &rhs[..], &expected[..]);
    }
}
113
/// Adding a scalar in place must give the same result as adding a
/// zero-dimensional array, and as an array filled with the expected value,
/// for 0-D, 1-D and 2-D left-hand sides.
#[test]
fn scalar_operations() {
    let zero_dim = arr0::<f32>(1.);
    let one_dim = rcarr1::<f32>(&[1., 1.]);
    let two_dim = rcarr2(&[[1., 1.], [1., 1.]]);

    // 0-D owned array.
    {
        let mut incremented = zero_dim.clone();
        let mut filled = arr0(0.);
        incremented += 1.;
        filled.fill(2.);
        assert_eq!(incremented, zero_dim + arr0(1.));
        assert_eq!(incremented, filled);
    }

    // 1-D shared array.
    {
        let mut incremented = one_dim.clone();
        let mut filled = rcarr1(&[0., 0.]);
        incremented += 1.;
        filled.fill(2.);
        assert_eq!(incremented, one_dim + arr0(1.));
        assert_eq!(incremented, filled);
    }

    // 2-D shared array.
    {
        let mut incremented = two_dim.clone();
        let mut filled = ArcArray::zeros((2, 2));
        incremented += 1.;
        filled.fill(2.);
        assert_eq!(incremented, two_dim + arr0(1.));
        assert_eq!(incremented, filled);
    }
}
147
reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32 where V1: AsArray<'a, f32>, V2: AsArray<'a, f32>,148 fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32
149 where
150 V1: AsArray<'a, f32>,
151 V2: AsArray<'a, f32>,
152 {
153 let a = a.into();
154 let b = b.into();
155 a.iter()
156 .zip(b.iter())
157 .fold(f32::zero(), |acc, (&x, &y)| acc + x * y)
158 }
159
/// Compares `dot` with the reference implementation and with a known exact
/// value, on the full vectors and on slices with shifted starts and ends.
#[test]
fn dot_product() {
    /// Checks `dot` against `reference_dot` on slices that drop 1 to 7
    /// elements, to exercise different alignments.
    fn check_shifted(a: &Array1<f32>, b: &Array1<f32>) {
        let max: Ixs = 8;
        for shift in 1..max {
            // Both operands lose their first `shift` elements.
            let a_tail = a.slice(s![shift..]);
            let b_tail = b.slice(s![shift..]);
            assert_abs_diff_eq!(
                a_tail.dot(&b_tail),
                reference_dot(&a_tail, &b_tail),
                epsilon = 1e-5
            );
            // `a` loses its last `shift` elements, `b` its first `shift`.
            let a_head = a.slice(s![..-shift]);
            let b_tail = b.slice(s![shift..]);
            assert_abs_diff_eq!(
                a_head.dot(&b_tail),
                reference_dot(&a_head, &b_tail),
                epsilon = 1e-5
            );
        }
    }

    let a = Array::range(0., 69., 1.);
    let b = &a * 2. - 7.;
    // Sum of i * (2 * i - 7) for i in 0..69.
    let expected = 197846.;
    assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);

    // test different alignments
    check_shifted(&a, &b);

    let a = a.map(|&v| v as f32);
    let b = b.map(|&v| v as f32);
    assert_abs_diff_eq!(a.dot(&b), expected as f32, epsilon = 1e-5);

    check_shifted(&a, &b);

    // Integer element type: the result must match exactly.
    let a = a.map(|&v| v as i32);
    let b = b.map(|&v| v as i32);
    assert_eq!(a.dot(&b), expected as i32);
}
196
/// Test that the dot product works when one operand is a broadcast of a
/// zero-dimensional view.
#[test]
fn dot_product_0() {
    let a = Array::range(0., 69., 1.);
    let scalar = 1.5;
    // The 0-D view must outlive the broadcast view that borrows from it.
    let scalar_view = aview0(&scalar);
    let b = scalar_view.broadcast(a.dim()).unwrap();
    assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);

    // test different alignments
    let max: Ixs = 8;
    for shift in 1..max {
        // Both operands lose their first `shift` elements.
        let a_tail = a.slice(s![shift..]);
        let b_tail = b.slice(s![shift..]);
        assert_abs_diff_eq!(
            a_tail.dot(&b_tail),
            reference_dot(&a_tail, &b_tail),
            epsilon = 1e-5
        );
        // `a` loses its last `shift` elements, `b` its first `shift`.
        let a_head = a.slice(s![..-shift]);
        let b_tail = b.slice(s![shift..]);
        assert_abs_diff_eq!(
            a_head.dot(&b_tail),
            reference_dot(&a_head, &b_tail),
            epsilon = 1e-5
        );
    }
}
217
/// Test that the dot product works on views with negative strides.
#[test]
fn dot_product_neg_stride() {
    let a = Array::range(0., 69., 1.);
    let b = &a * 2. - 7.;

    // Both operands use the same negative step.
    for step in -10..0 {
        let a_rev = a.slice(s![..;step]);
        let b_rev = b.slice(s![..;step]);
        assert_abs_diff_eq!(a_rev.dot(&b_rev), reference_dot(&a_rev, &b_rev), epsilon = 1e-5);
    }

    // Mixed: `a` steps forwards while `b` steps backwards.
    for step in -10..0 {
        let a_fwd = a.slice(s![..;-step]);
        let b_rev = b.slice(s![..;step]);
        assert_abs_diff_eq!(a_fwd.dot(&b_rev), reference_dot(&a_fwd, &b_rev), epsilon = 1e-5);
    }
}
236
/// Checks `fold` and `sum` against a hand-written accumulation loop, on the
/// whole array, on strided slices and on partially consumed iterators.
#[test]
fn fold_and_sum() {
    let a = Array::linspace(0., 127., 128).into_shape((8, 16)).unwrap();
    assert_abs_diff_eq!(a.fold(0., |acc, &x| acc + x), a.sum(), epsilon = 1e-5);

    // test different strides
    let max: Ixs = 8;
    for row_step in 1..max {
        for col_step in 1..max {
            let view = a.slice(s![..;row_step, ..;col_step]);
            // Reference value from an explicit loop over the elements.
            let mut manual = 0.;
            for elt in view.iter() {
                manual += *elt;
            }
            assert_abs_diff_eq!(view.fold(0., |acc, &x| acc + x), manual, epsilon = 1e-5);
            assert_abs_diff_eq!(manual, view.sum(), epsilon = 1e-5);
        }
    }

    // skip a few elements
    let max: Ixs = 8;
    for col_step in 1..max {
        for skip in 1..max {
            let view = a.slice(s![.., ..;col_step]);
            let mut remaining = view.iter();
            for _ in 0..skip {
                remaining.next();
            }
            // Fold a clone of the advanced iterator; sum the original by hand.
            let folded = remaining.clone();

            let mut manual = 0.;
            for elt in remaining {
                manual += *elt;
            }
            assert_abs_diff_eq!(folded.fold(0., |acc, &x| acc + x), manual, epsilon = 1e-5);
        }
    }
}
275
/// Checks `product` against `fold` and against a hand-written loop, on the
/// whole array and on strided slices.
#[test]
fn product() {
    let a = Array::linspace(0.5, 2., 128).into_shape((8, 16)).unwrap();
    assert_abs_diff_eq!(a.fold(1., |acc, &x| acc * x), a.product(), epsilon = 1e-5);

    // test different strides
    let max: Ixs = 8;
    for row_step in 1..max {
        for col_step in 1..max {
            let view = a.slice(s![..;row_step, ..;col_step]);
            // Reference value from an explicit loop over the elements.
            let mut manual = 1.;
            for elt in view.iter() {
                manual *= *elt;
            }
            assert_abs_diff_eq!(view.fold(1., |acc, &x| acc * x), manual, epsilon = 1e-5);
            assert_abs_diff_eq!(manual, view.product(), epsilon = 1e-5);
        }
    }
}
295
range_mat(m: Ix, n: Ix) -> Array2<f32>296 fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
297 Array::linspace(0., (m * n) as f32 - 1., m * n)
298 .into_shape((m, n))
299 .unwrap()
300 }
301
range_mat64(m: Ix, n: Ix) -> Array2<f64>302 fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
303 Array::linspace(0., (m * n) as f64 - 1., m * n)
304 .into_shape((m, n))
305 .unwrap()
306 }
307
/// Builds a length-`m` `f64` vector from `linspace(0, m - 1, m)`.
#[cfg(feature = "approx")]
fn range1_mat64(m: Ix) -> Array1<f64> {
    let last = m as f64 - 1.;
    Array::linspace(0., last, m)
}
312
range_i32(m: Ix, n: Ix) -> Array2<i32>313 fn range_i32(m: Ix, n: Ix) -> Array2<i32> {
314 Array::from_iter(0..(m * n) as i32)
315 .into_shape((m, n))
316 .unwrap()
317 }
318
319 // simple, slow, correct (hopefully) mat mul
reference_mat_mul<A, S, S2>(lhs: &ArrayBase<S, Ix2>, rhs: &ArrayBase<S2, Ix2>) -> Array2<A> where A: LinalgScalar, S: Data<Elem = A>, S2: Data<Elem = A>,320 fn reference_mat_mul<A, S, S2>(lhs: &ArrayBase<S, Ix2>, rhs: &ArrayBase<S2, Ix2>) -> Array2<A>
321 where
322 A: LinalgScalar,
323 S: Data<Elem = A>,
324 S2: Data<Elem = A>,
325 {
326 let ((m, k), (k2, n)) = (lhs.dim(), rhs.dim());
327 assert!(m.checked_mul(n).is_some());
328 assert_eq!(k, k2);
329 let mut res_elems = Vec::<A>::with_capacity(m * n);
330 unsafe {
331 res_elems.set_len(m * n);
332 }
333
334 let mut i = 0;
335 let mut j = 0;
336 for rr in &mut res_elems {
337 unsafe {
338 *rr = (0..k).fold(A::zero(), move |s, x| {
339 s + *lhs.uget((i, x)) * *rhs.uget((x, j))
340 });
341 }
342 j += 1;
343 if j == n {
344 j = 0;
345 i += 1;
346 }
347 }
348 unsafe { ArrayBase::from_shape_vec_unchecked((m, n), res_elems) }
349 }
350
/// `dot` must give the same result whether its operands are stored in
/// row-major or column-major order, for several matrix sizes.
#[test]
fn mat_mul() {
    // (m, n, k): `a` is m × n and `b` is n × k.
    let sizes = [(8, 8, 8), (10, 5, 11), (10, 8, 1)];

    for &(m, n, k) in &sizes {
        let a = range_mat(m, n);
        let mut b = range_mat(n, k) / 4.;
        {
            // Offset the first column of `b` by one.
            let mut first_col = b.column_mut(0);
            first_col += 1.0;
        }
        let ab = a.dot(&b);

        // Column-major copies of both operands.
        let mut af = Array::zeros(a.dim().f());
        let mut bf = Array::zeros(b.dim().f());
        af.assign(&a);
        bf.assign(&b);

        assert_eq!(ab, a.dot(&bf));
        assert_eq!(ab, af.dot(&b));
        assert_eq!(ab, af.dot(&bf));
    }
}
410
/// Check that matrix multiplication of contiguous matrices returns a matrix
/// with the same memory order as its operands.
#[test]
fn mat_mul_order() {
    let (m, n, k) = (8, 8, 8);
    let a = range_mat(m, n);
    let b = range_mat(n, k);

    // Column-major copies of both operands.
    let mut af = Array::zeros(a.dim().f());
    let mut bf = Array::zeros(b.dim().f());
    af.assign(&a);
    bf.assign(&b);

    let row_major_product = a.dot(&b);
    let col_major_product = af.dot(&bf);

    // Unit stride on the last axis for row-major operands ...
    assert_eq!(row_major_product.strides()[1], 1);
    // ... and on the first axis for column-major operands.
    assert_eq!(col_major_product.strides()[0], 1);
}
429
/// Test that `dot` panics when the inner dimensions of the operands differ
/// (8 × 8 times 9 × 8).
#[test]
#[should_panic]
fn mat_mul_shape_mismatch() {
    let (m, k, k2, n) = (8, 8, 9, 8);
    let lhs = range_mat(m, k);
    let rhs = range_mat(k2, n);
    let _ = lhs.dot(&rhs);
}
439
/// Test that `general_mat_mul` panics when the output matrix has the wrong
/// shape: the operands are compatible (8 × 8 times 8 × 8) but `c` is 8 × 9.
#[test]
#[should_panic]
fn mat_mul_shape_mismatch_2() {
    let (m, k, k2, n) = (8, 8, 8, 8);
    let lhs = range_mat(m, k);
    let rhs = range_mat(k2, n);
    let mut out = range_mat(m, n + 1);
    general_mat_mul(1., &lhs, &rhs, 1., &mut out);
}
450
/// Check that matrix multiplication supports broadcast arrays.
///
/// The right-hand side is an `n × k` matrix of ones built three ways: as a
/// broadcast of a single element, as a broadcast of a length-`n` vector, and
/// as an ordinary owned array. All three products must be equal.
#[test]
fn mat_mul_broadcast() {
    let (m, n, k) = (16, 16, 16);
    let a = range_mat(m, n);
    let fill = 1.;

    // Broadcast of a one-element array.
    let single = Array::from(vec![fill]);
    let from_single = single.broadcast((n, k)).unwrap();
    // Broadcast of a length-`n` vector.
    let vector = Array::from_elem(n, fill);
    let from_vector = vector.broadcast((n, k)).unwrap();
    // Fully materialized matrix.
    let owned = Array::from_elem((n, k), fill);

    let c2 = a.dot(&owned);
    let c1 = a.dot(&from_vector);
    let c0 = a.dot(&from_single);
    assert_eq!(c2, c1);
    assert_eq!(c2, c0);
}
470
/// Check that matrix multiplication supports operands with a reversed axis.
#[test]
fn mat_mul_rev() {
    let (m, n, k) = (16, 16, 16);
    let a = range_mat(m, n);
    let b = range_mat(n, k);

    // Copy `b` into a view whose first axis is reversed: same values as `b`,
    // but with a negative stride along the rows.
    let mut storage = Array::zeros(b.dim());
    let mut rev = storage.slice_mut(s![..;-1, ..]);
    rev.assign(&b);
    println!("{:.?}", rev);

    let expected = a.dot(&b);
    let actual = a.dot(&rev);
    assert_eq!(expected, actual);
}
486
/// Check that matrix multiplication supports arrays with zero rows or columns.
///
/// The same checks run for the `f32`, `f64` and `i32` matrix builders.
#[test]
fn mat_mut_zero_len() {
    defmac!(mat_mul_zero_len range_mat_fn => {
        for inner in 0..4 {
            // Right-hand side with zero columns: the product is `rows × 0`.
            for rows in 0..4 {
                let lhs = range_mat_fn(rows, inner);
                let rhs = range_mat_fn(inner, 0);
                assert_eq!(lhs.dot(&rhs), Array2::zeros((rows, 0)));
            }
            // Left-hand side with zero rows: the product is `0 × cols`.
            for cols in 0..4 {
                let lhs = range_mat_fn(0, inner);
                let rhs = range_mat_fn(inner, cols);
                assert_eq!(lhs.dot(&rhs), Array2::zeros((0, cols)));
            }
        }
    });
    mat_mul_zero_len!(range_mat);
    mat_mul_zero_len!(range_mat64);
    mat_mul_zero_len!(range_i32);
}
508
/// `scaled_add(alpha, &b)` must equal `alpha * &b + &a` exactly for
/// same-shaped `f32` matrices.
#[test]
fn scaled_add() {
    let a = range_mat(16, 15);
    let mut b = range_mat(16, 15);
    b.mapv_inplace(f32::exp);

    let alpha = 0.2_f32;

    // Expected value from the operator expression.
    let expected = alpha * &b + &a;

    let mut actual = a.clone();
    actual.scaled_add(alpha, &b);

    assert_eq!(actual, expected);
}
522
/// Checks `scaled_add` on strided 2-D views, including right-hand sides with
/// a single row that are broadcast against the left-hand side.
#[cfg(feature = "approx")]
#[test]
fn scaled_add_2() {
    let beta = -2.3;
    // (m, k, n, q): `a` is m × k and `c` is n × q.
    let sizes = [
        (4, 4, 1, 4),
        (8, 8, 1, 8),
        (17, 15, 17, 15),
        (4, 17, 4, 17),
        (17, 3, 1, 3),
        (19, 18, 19, 18),
        (16, 17, 16, 17),
        (15, 16, 15, 16),
        (67, 63, 1, 63),
    ];
    // test different strides
    let strides = [1, 2, -1, -2];
    for &s1 in &strides {
        for &s2 in &strides {
            for &(m, k, n, q) in &sizes {
                let mut a = range_mat64(m, k);
                let mut answer = a.clone();
                let c = range_mat64(n, q);

                {
                    let mut a_view = a.slice_mut(s![..;s1, ..;s2]);
                    let c_view = c.slice(s![..;s1, ..;s2]);

                    // Expected result via `+=` with an explicitly scaled operand.
                    let mut answer_view = answer.slice_mut(s![..;s1, ..;s2]);
                    answer_view += &(beta * &c_view);
                    a_view.scaled_add(beta, &c_view);
                }
                approx::assert_relative_eq!(a, answer, epsilon = 1e-12, max_relative = 1e-7);
            }
        }
    }
}
559
/// Checks `scaled_add` with a dynamic-dimensional right-hand side that is
/// either 1-D (when `n == 1`) or 2-D, sliced with the strides under test.
#[cfg(feature = "approx")]
#[test]
fn scaled_add_3() {
    use approx::assert_relative_eq;
    use ndarray::{Slice, SliceInfo, SliceInfoElem};
    use std::convert::TryFrom;

    let beta = -2.3;
    // (m, k, n, q): `a` is m × k; `c` has n * q elements.
    let sizes = [
        (4, 4, 1, 4),
        (8, 8, 1, 8),
        (17, 15, 17, 15),
        (4, 17, 4, 17),
        (17, 3, 1, 3),
        (19, 18, 19, 18),
        (16, 17, 16, 17),
        (15, 16, 15, 16),
        (67, 63, 1, 63),
    ];
    // test different strides
    let strides = [1, 2, -1, -2];
    for &s1 in &strides {
        for &s2 in &strides {
            for &(m, k, n, q) in &sizes {
                let mut a = range_mat64(m, k);
                let mut answer = a.clone();

                // A single-row `c` becomes a 1-D array of length `q`;
                // otherwise it stays `n × q`. The slice description has one
                // entry per axis of `c`.
                let (cdim, cslice): (Vec<Ix>, Vec<SliceInfoElem>) = if n == 1 {
                    (
                        vec![q],
                        vec![SliceInfoElem::from(Slice::from(..).step_by(s2))],
                    )
                } else {
                    (
                        vec![n, q],
                        vec![
                            SliceInfoElem::from(Slice::from(..).step_by(s1)),
                            SliceInfoElem::from(Slice::from(..).step_by(s2)),
                        ],
                    )
                };

                let c = range_mat64(n, q).into_shape(cdim).unwrap();

                {
                    let mut a_view = a.slice_mut(s![..;s1, ..;s2]);
                    let c_view = c.slice(SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());

                    // Expected result via `+=` with an explicitly scaled operand.
                    let mut answer_view = answer.slice_mut(s![..;s1, ..;s2]);
                    answer_view += &(beta * &c_view);
                    a_view.scaled_add(beta, &c_view);
                }
                assert_relative_eq!(a, answer, epsilon = 1e-12, max_relative = 1e-7);
            }
        }
    }
}
610
/// Checks `general_mat_mul` (`c = alpha * a * b + beta * c`) against the
/// reference multiplication on strided `f64` views.
#[cfg(feature = "approx")]
#[test]
fn gen_mat_mul() {
    let alpha = -2.3;
    let beta = 3.14;
    // (m, k, n): `a` is m × k, `b` is k × n and `c` is m × n.
    let sizes = [
        (4, 4, 4),
        (8, 8, 8),
        (17, 15, 16),
        (4, 17, 3),
        (17, 3, 22),
        (19, 18, 2),
        (16, 17, 15),
        (15, 16, 17),
        (67, 63, 62),
    ];
    // test different strides
    let strides = [1, 2, -1, -2];
    for &s1 in &strides {
        for &s2 in &strides {
            for &(m, k, n) in &sizes {
                let a = range_mat64(m, k);
                let b = range_mat64(k, n);
                let mut c = range_mat64(m, n);
                let mut answer = c.clone();

                {
                    let a_view = a.slice(s![..;s1, ..;s2]);
                    let b_view = b.slice(s![..;s2, ..;s2]);
                    let mut c_view = c.slice_mut(s![..;s1, ..;s2]);

                    // Expected value, computed before `c_view` is overwritten.
                    let expected = alpha * reference_mat_mul(&a_view, &b_view) + beta * &c_view;
                    answer.slice_mut(s![..;s1, ..;s2]).assign(&expected);

                    general_mat_mul(alpha, &a_view, &b_view, beta, &mut c_view);
                }
                approx::assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
            }
        }
    }
}
651
/// Test `y = A x + y` where `A` is a 64 × 64 matrix with reversed axes
/// (f-order) and `x`, `y` are single-column matrices.
#[cfg(feature = "approx")]
#[test]
fn gemm_64_1_f() {
    let a = range_mat64(64, 64).reversed_axes();
    let (rows, cols) = a.dim();
    // rows x cols times cols x 1 == rows x 1
    let x = range_mat64(cols, 1);
    let mut y = range_mat64(rows, 1);

    // Expected value, computed before `y` is overwritten.
    let expected = reference_mat_mul(&a, &x) + &y;
    general_mat_mul(1.0, &a, &x, 1.0, &mut y);
    approx::assert_relative_eq!(y, expected, epsilon = 1e-12, max_relative = 1e-7);
}
665
/// Checks `general_mat_mul` on `i32` matrices against the reference
/// multiplication; the results must match exactly.
///
/// A shorter list of sizes is used when running under Miri.
#[test]
fn gen_mat_mul_i32() {
    let alpha = -1;
    let beta = 2;
    // (m, k, n): `a` is m × k, `b` is k × n and `c` is m × n.
    let sizes: &[(Ix, Ix, Ix)] = if cfg!(miri) {
        &[(4, 4, 4), (4, 7, 3)]
    } else {
        &[
            (4, 4, 4),
            (8, 8, 8),
            (17, 15, 16),
            (4, 17, 3),
            (17, 3, 22),
            (19, 18, 2),
            (16, 17, 15),
            (15, 16, 17),
            (67, 63, 62),
        ]
    };
    for &(m, k, n) in sizes {
        let a = range_i32(m, k);
        let b = range_i32(k, n);
        let mut c = range_i32(m, n);

        // Expected value, computed before `c` is overwritten.
        let expected = alpha * reference_mat_mul(&a, &b) + beta * &c;
        general_mat_mul(alpha, &a, &b, beta, &mut c);
        assert_eq!(&c, &expected);
    }
}
695
/// Checks `general_mat_vec_mul` (`c = alpha * a * b + beta * c`) against the
/// reference multiplication, on strided views and on matrices with reversed
/// axes.
#[cfg(feature = "approx")]
#[test]
fn gen_mat_vec_mul() {
    use approx::assert_relative_eq;
    use ndarray::linalg::general_mat_vec_mul;

    /// Reference matrix-vector product: treats `rhs` as a single-column
    /// matrix, multiplies with `reference_mat_mul` and flattens the result.
    fn reference_mat_vec_mul<A, S, S2>(
        lhs: &ArrayBase<S, Ix2>,
        rhs: &ArrayBase<S2, Ix1>,
    ) -> Array1<A>
    where
        A: LinalgScalar,
        S: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        let (rows, _) = lhs.dim();
        let len = rhs.dim();
        let column = rhs.as_standard_layout().into_shape((len, 1)).unwrap();
        reference_mat_mul(lhs, &column).into_shape(rows).unwrap()
    }

    let alpha = -2.3;
    let beta = 3.14;
    // (m, k): `a` is m × k before the optional axis reversal.
    let sizes = [
        (4, 4),
        (8, 8),
        (17, 15),
        (4, 17),
        (17, 3),
        (19, 18),
        (16, 17),
        (15, 16),
        (67, 63),
    ];
    // test different strides
    let strides = [1, 2, -1, -2];
    for &s1 in &strides {
        for &s2 in &strides {
            for &(m, k) in &sizes {
                for &rev in &[false, true] {
                    let mut a = range_mat64(m, k);
                    if rev {
                        a = a.reversed_axes();
                    }
                    // Re-read the shape: it is swapped when `rev` is set.
                    let (m, k) = a.dim();
                    let b = range1_mat64(k);
                    let mut c = range1_mat64(m);
                    let mut answer = c.clone();

                    {
                        let a_view = a.slice(s![..;s1, ..;s2]);
                        let b_view = b.slice(s![..;s2]);
                        let mut c_view = c.slice_mut(s![..;s1]);

                        // Expected value, computed before `c_view` is overwritten.
                        let expected =
                            alpha * reference_mat_vec_mul(&a_view, &b_view) + beta * &c_view;
                        answer.slice_mut(s![..;s1]).assign(&expected);

                        general_mat_vec_mul(alpha, &a_view, &b_view, beta, &mut c_view);
                    }
                    assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
                }
            }
        }
    }
}
761
/// Checks the vector-matrix `dot` product against the reference
/// multiplication, on strided views and on matrices with reversed axes.
#[cfg(feature = "approx")]
#[test]
fn vec_mat_mul() {
    use approx::assert_relative_eq;

    /// Reference vector-matrix product: treats `lhs` as a single-row matrix,
    /// multiplies with `reference_mat_mul` and flattens the result.
    fn reference_vec_mat_mul<A, S, S2>(
        lhs: &ArrayBase<S, Ix1>,
        rhs: &ArrayBase<S2, Ix2>,
    ) -> Array1<A>
    where
        A: LinalgScalar,
        S: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        let len = lhs.dim();
        let (_, cols) = rhs.dim();
        let row = lhs.as_standard_layout().into_shape((1, len)).unwrap();
        reference_mat_mul(&row, rhs).into_shape(cols).unwrap()
    }

    // (m, n): `b` is m × n before the optional axis reversal.
    let sizes = [
        (4, 4),
        (8, 8),
        (17, 15),
        (4, 17),
        (17, 3),
        (19, 18),
        (16, 17),
        (15, 16),
        (67, 63),
    ];
    // test different strides
    let strides = [1, 2, -1, -2];
    for &s1 in &strides {
        for &s2 in &strides {
            for &(m, n) in &sizes {
                for &rev in &[false, true] {
                    let mut b = range_mat64(m, n);
                    if rev {
                        b = b.reversed_axes();
                    }
                    // Re-read the shape: it is swapped when `rev` is set.
                    let (m, n) = b.dim();
                    let a = range1_mat64(m);
                    let mut c = range1_mat64(n);
                    let mut answer = c.clone();

                    {
                        let b_view = b.slice(s![..;s1, ..;s2]);
                        let a_view = a.slice(s![..;s1]);

                        let expected = reference_vec_mat_mul(&a_view, &b_view);
                        answer.slice_mut(s![..;s2]).assign(&expected);

                        c.slice_mut(s![..;s2]).assign(&a_view.dot(&b_view));
                    }
                    assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
                }
            }
        }
    }
}
823