1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Ke Yang <yangke@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template<int DataLayout>
test_simple_inflation()17 static void test_simple_inflation()
18 {
19   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
20   tensor.setRandom();
21   array<ptrdiff_t, 4> strides;
22 
23   strides[0] = 1;
24   strides[1] = 1;
25   strides[2] = 1;
26   strides[3] = 1;
27 
28   Tensor<float, 4, DataLayout> no_stride;
29   no_stride = tensor.inflate(strides);
30 
31   VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
32   VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
33   VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
34   VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
35 
36   for (int i = 0; i < 2; ++i) {
37     for (int j = 0; j < 3; ++j) {
38       for (int k = 0; k < 5; ++k) {
39         for (int l = 0; l < 7; ++l) {
40           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
41         }
42       }
43     }
44   }
45 
46   strides[0] = 2;
47   strides[1] = 4;
48   strides[2] = 2;
49   strides[3] = 3;
50   Tensor<float, 4, DataLayout> inflated;
51   inflated = tensor.inflate(strides);
52 
53   VERIFY_IS_EQUAL(inflated.dimension(0), 3);
54   VERIFY_IS_EQUAL(inflated.dimension(1), 9);
55   VERIFY_IS_EQUAL(inflated.dimension(2), 9);
56   VERIFY_IS_EQUAL(inflated.dimension(3), 19);
57 
58   for (int i = 0; i < 3; ++i) {
59     for (int j = 0; j < 9; ++j) {
60       for (int k = 0; k < 9; ++k) {
61         for (int l = 0; l < 19; ++l) {
62           if (i % 2 == 0 &&
63               j % 4 == 0 &&
64               k % 2 == 0 &&
65               l % 3 == 0) {
66             VERIFY_IS_EQUAL(inflated(i,j,k,l),
67                             tensor(i/2, j/4, k/2, l/3));
68           } else {
69             VERIFY_IS_EQUAL(0, inflated(i,j,k,l));
70           }
71         }
72       }
73     }
74   }
75 }
76 
// Test entry point for the tensor inflation suite: runs
// test_simple_inflation once for ColMajor and once for RowMajor layout.
// NOTE(review): CALL_SUBTEST presumably records its argument text as the
// subtest name, so keep the arguments verbatim -- confirm against main.h.
void test_cxx11_tensor_inflation()
{
  CALL_SUBTEST(test_simple_inflation<ColMajor>());
  CALL_SUBTEST(test_simple_inflation<RowMajor>());
}
82