1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2013 Christian Seiler <christian@iwakd.de>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15 using Eigen::RowMajor;
16
test_0d()17 static void test_0d()
18 {
19 Tensor<int, 0> scalar1;
20 Tensor<int, 0, RowMajor> scalar2;
21 Tensor<int, 0> scalar3;
22 Tensor<int, 0, RowMajor> scalar4;
23
24 scalar3.resize();
25 scalar4.resize();
26
27 scalar1() = 7;
28 scalar2() = 13;
29 scalar3.setValues(17);
30 scalar4.setZero();
31
32 VERIFY_IS_EQUAL(scalar1.rank(), 0);
33 VERIFY_IS_EQUAL(scalar1.size(), 1);
34
35 VERIFY_IS_EQUAL(scalar1(), 7);
36 VERIFY_IS_EQUAL(scalar2(), 13);
37 VERIFY_IS_EQUAL(scalar3(), 17);
38 VERIFY_IS_EQUAL(scalar4(), 0);
39
40 Tensor<int, 0> scalar5(scalar1);
41
42 VERIFY_IS_EQUAL(scalar5(), 7);
43 VERIFY_IS_EQUAL(scalar5.data()[0], 7);
44 }
45
test_1d()46 static void test_1d()
47 {
48 Tensor<int, 1> vec1(6);
49 Tensor<int, 1, RowMajor> vec2(6);
50 Tensor<int, 1> vec3;
51 Tensor<int, 1, RowMajor> vec4;
52
53 vec3.resize(6);
54 vec4.resize(6);
55
56 vec1(0) = 4; vec2(0) = 0; vec3(0) = 5;
57 vec1(1) = 8; vec2(1) = 1; vec3(1) = 4;
58 vec1(2) = 15; vec2(2) = 2; vec3(2) = 3;
59 vec1(3) = 16; vec2(3) = 3; vec3(3) = 2;
60 vec1(4) = 23; vec2(4) = 4; vec3(4) = 1;
61 vec1(5) = 42; vec2(5) = 5; vec3(5) = 0;
62 vec4.setZero();
63
64 VERIFY_IS_EQUAL((vec1.rank()), 1);
65 VERIFY_IS_EQUAL((vec1.size()), 6);
66 VERIFY_IS_EQUAL((vec1.dimensions()[0]), 6);
67
68 VERIFY_IS_EQUAL((vec1[0]), 4);
69 VERIFY_IS_EQUAL((vec1[1]), 8);
70 VERIFY_IS_EQUAL((vec1[2]), 15);
71 VERIFY_IS_EQUAL((vec1[3]), 16);
72 VERIFY_IS_EQUAL((vec1[4]), 23);
73 VERIFY_IS_EQUAL((vec1[5]), 42);
74
75 VERIFY_IS_EQUAL((vec2[0]), 0);
76 VERIFY_IS_EQUAL((vec2[1]), 1);
77 VERIFY_IS_EQUAL((vec2[2]), 2);
78 VERIFY_IS_EQUAL((vec2[3]), 3);
79 VERIFY_IS_EQUAL((vec2[4]), 4);
80 VERIFY_IS_EQUAL((vec2[5]), 5);
81
82 VERIFY_IS_EQUAL((vec3[0]), 5);
83 VERIFY_IS_EQUAL((vec3[1]), 4);
84 VERIFY_IS_EQUAL((vec3[2]), 3);
85 VERIFY_IS_EQUAL((vec3[3]), 2);
86 VERIFY_IS_EQUAL((vec3[4]), 1);
87 VERIFY_IS_EQUAL((vec3[5]), 0);
88
89 VERIFY_IS_EQUAL((vec4[0]), 0);
90 VERIFY_IS_EQUAL((vec4[1]), 0);
91 VERIFY_IS_EQUAL((vec4[2]), 0);
92 VERIFY_IS_EQUAL((vec4[3]), 0);
93 VERIFY_IS_EQUAL((vec4[4]), 0);
94 VERIFY_IS_EQUAL((vec4[5]), 0);
95
96 Tensor<int, 1> vec5(vec1);
97
98 VERIFY_IS_EQUAL((vec5(0)), 4);
99 VERIFY_IS_EQUAL((vec5(1)), 8);
100 VERIFY_IS_EQUAL((vec5(2)), 15);
101 VERIFY_IS_EQUAL((vec5(3)), 16);
102 VERIFY_IS_EQUAL((vec5(4)), 23);
103 VERIFY_IS_EQUAL((vec5(5)), 42);
104
105 VERIFY_IS_EQUAL((vec5.data()[0]), 4);
106 VERIFY_IS_EQUAL((vec5.data()[1]), 8);
107 VERIFY_IS_EQUAL((vec5.data()[2]), 15);
108 VERIFY_IS_EQUAL((vec5.data()[3]), 16);
109 VERIFY_IS_EQUAL((vec5.data()[4]), 23);
110 VERIFY_IS_EQUAL((vec5.data()[5]), 42);
111 }
112
test_2d()113 static void test_2d()
114 {
115 Tensor<int, 2> mat1(2,3);
116 Tensor<int, 2, RowMajor> mat2(2,3);
117
118 mat1(0,0) = 0;
119 mat1(0,1) = 1;
120 mat1(0,2) = 2;
121 mat1(1,0) = 3;
122 mat1(1,1) = 4;
123 mat1(1,2) = 5;
124
125 mat2(0,0) = 0;
126 mat2(0,1) = 1;
127 mat2(0,2) = 2;
128 mat2(1,0) = 3;
129 mat2(1,1) = 4;
130 mat2(1,2) = 5;
131
132 VERIFY_IS_EQUAL((mat1.rank()), 2);
133 VERIFY_IS_EQUAL((mat1.size()), 6);
134 VERIFY_IS_EQUAL((mat1.dimensions()[0]), 2);
135 VERIFY_IS_EQUAL((mat1.dimensions()[1]), 3);
136
137 VERIFY_IS_EQUAL((mat2.rank()), 2);
138 VERIFY_IS_EQUAL((mat2.size()), 6);
139 VERIFY_IS_EQUAL((mat2.dimensions()[0]), 2);
140 VERIFY_IS_EQUAL((mat2.dimensions()[1]), 3);
141
142 VERIFY_IS_EQUAL((mat1.data()[0]), 0);
143 VERIFY_IS_EQUAL((mat1.data()[1]), 3);
144 VERIFY_IS_EQUAL((mat1.data()[2]), 1);
145 VERIFY_IS_EQUAL((mat1.data()[3]), 4);
146 VERIFY_IS_EQUAL((mat1.data()[4]), 2);
147 VERIFY_IS_EQUAL((mat1.data()[5]), 5);
148
149 VERIFY_IS_EQUAL((mat2.data()[0]), 0);
150 VERIFY_IS_EQUAL((mat2.data()[1]), 1);
151 VERIFY_IS_EQUAL((mat2.data()[2]), 2);
152 VERIFY_IS_EQUAL((mat2.data()[3]), 3);
153 VERIFY_IS_EQUAL((mat2.data()[4]), 4);
154 VERIFY_IS_EQUAL((mat2.data()[5]), 5);
155 }
156
test_3d()157 static void test_3d()
158 {
159 Tensor<int, 3> epsilon(3,3,3);
160 epsilon.setZero();
161 epsilon(0,1,2) = epsilon(2,0,1) = epsilon(1,2,0) = 1;
162 epsilon(2,1,0) = epsilon(0,2,1) = epsilon(1,0,2) = -1;
163
164 VERIFY_IS_EQUAL((epsilon.size()), 27);
165 VERIFY_IS_EQUAL((epsilon.dimensions()[0]), 3);
166 VERIFY_IS_EQUAL((epsilon.dimensions()[1]), 3);
167 VERIFY_IS_EQUAL((epsilon.dimensions()[2]), 3);
168
169 VERIFY_IS_EQUAL((epsilon(0,0,0)), 0);
170 VERIFY_IS_EQUAL((epsilon(0,0,1)), 0);
171 VERIFY_IS_EQUAL((epsilon(0,0,2)), 0);
172 VERIFY_IS_EQUAL((epsilon(0,1,0)), 0);
173 VERIFY_IS_EQUAL((epsilon(0,1,1)), 0);
174 VERIFY_IS_EQUAL((epsilon(0,2,0)), 0);
175 VERIFY_IS_EQUAL((epsilon(0,2,2)), 0);
176 VERIFY_IS_EQUAL((epsilon(1,0,0)), 0);
177 VERIFY_IS_EQUAL((epsilon(1,0,1)), 0);
178 VERIFY_IS_EQUAL((epsilon(1,1,0)), 0);
179 VERIFY_IS_EQUAL((epsilon(1,1,1)), 0);
180 VERIFY_IS_EQUAL((epsilon(1,1,2)), 0);
181 VERIFY_IS_EQUAL((epsilon(1,2,1)), 0);
182 VERIFY_IS_EQUAL((epsilon(1,2,2)), 0);
183 VERIFY_IS_EQUAL((epsilon(2,0,0)), 0);
184 VERIFY_IS_EQUAL((epsilon(2,0,2)), 0);
185 VERIFY_IS_EQUAL((epsilon(2,1,1)), 0);
186 VERIFY_IS_EQUAL((epsilon(2,1,2)), 0);
187 VERIFY_IS_EQUAL((epsilon(2,2,0)), 0);
188 VERIFY_IS_EQUAL((epsilon(2,2,1)), 0);
189 VERIFY_IS_EQUAL((epsilon(2,2,2)), 0);
190
191 VERIFY_IS_EQUAL((epsilon(0,1,2)), 1);
192 VERIFY_IS_EQUAL((epsilon(2,0,1)), 1);
193 VERIFY_IS_EQUAL((epsilon(1,2,0)), 1);
194 VERIFY_IS_EQUAL((epsilon(2,1,0)), -1);
195 VERIFY_IS_EQUAL((epsilon(0,2,1)), -1);
196 VERIFY_IS_EQUAL((epsilon(1,0,2)), -1);
197
198 array<Eigen::DenseIndex, 3> dims;
199 dims[0] = 2;
200 dims[1] = 3;
201 dims[2] = 4;
202 Tensor<int, 3> t1(dims);
203 Tensor<int, 3, RowMajor> t2(dims);
204
205 VERIFY_IS_EQUAL((t1.size()), 24);
206 VERIFY_IS_EQUAL((t1.dimensions()[0]), 2);
207 VERIFY_IS_EQUAL((t1.dimensions()[1]), 3);
208 VERIFY_IS_EQUAL((t1.dimensions()[2]), 4);
209
210 VERIFY_IS_EQUAL((t2.size()), 24);
211 VERIFY_IS_EQUAL((t2.dimensions()[0]), 2);
212 VERIFY_IS_EQUAL((t2.dimensions()[1]), 3);
213 VERIFY_IS_EQUAL((t2.dimensions()[2]), 4);
214
215 for (int i = 0; i < 2; i++) {
216 for (int j = 0; j < 3; j++) {
217 for (int k = 0; k < 4; k++) {
218 t1(i, j, k) = 100 * i + 10 * j + k;
219 t2(i, j, k) = 100 * i + 10 * j + k;
220 }
221 }
222 }
223
224 VERIFY_IS_EQUAL((t1.data()[0]), 0);
225 VERIFY_IS_EQUAL((t1.data()[1]), 100);
226 VERIFY_IS_EQUAL((t1.data()[2]), 10);
227 VERIFY_IS_EQUAL((t1.data()[3]), 110);
228 VERIFY_IS_EQUAL((t1.data()[4]), 20);
229 VERIFY_IS_EQUAL((t1.data()[5]), 120);
230 VERIFY_IS_EQUAL((t1.data()[6]), 1);
231 VERIFY_IS_EQUAL((t1.data()[7]), 101);
232 VERIFY_IS_EQUAL((t1.data()[8]), 11);
233 VERIFY_IS_EQUAL((t1.data()[9]), 111);
234 VERIFY_IS_EQUAL((t1.data()[10]), 21);
235 VERIFY_IS_EQUAL((t1.data()[11]), 121);
236 VERIFY_IS_EQUAL((t1.data()[12]), 2);
237 VERIFY_IS_EQUAL((t1.data()[13]), 102);
238 VERIFY_IS_EQUAL((t1.data()[14]), 12);
239 VERIFY_IS_EQUAL((t1.data()[15]), 112);
240 VERIFY_IS_EQUAL((t1.data()[16]), 22);
241 VERIFY_IS_EQUAL((t1.data()[17]), 122);
242 VERIFY_IS_EQUAL((t1.data()[18]), 3);
243 VERIFY_IS_EQUAL((t1.data()[19]), 103);
244 VERIFY_IS_EQUAL((t1.data()[20]), 13);
245 VERIFY_IS_EQUAL((t1.data()[21]), 113);
246 VERIFY_IS_EQUAL((t1.data()[22]), 23);
247 VERIFY_IS_EQUAL((t1.data()[23]), 123);
248
249 VERIFY_IS_EQUAL((t2.data()[0]), 0);
250 VERIFY_IS_EQUAL((t2.data()[1]), 1);
251 VERIFY_IS_EQUAL((t2.data()[2]), 2);
252 VERIFY_IS_EQUAL((t2.data()[3]), 3);
253 VERIFY_IS_EQUAL((t2.data()[4]), 10);
254 VERIFY_IS_EQUAL((t2.data()[5]), 11);
255 VERIFY_IS_EQUAL((t2.data()[6]), 12);
256 VERIFY_IS_EQUAL((t2.data()[7]), 13);
257 VERIFY_IS_EQUAL((t2.data()[8]), 20);
258 VERIFY_IS_EQUAL((t2.data()[9]), 21);
259 VERIFY_IS_EQUAL((t2.data()[10]), 22);
260 VERIFY_IS_EQUAL((t2.data()[11]), 23);
261 VERIFY_IS_EQUAL((t2.data()[12]), 100);
262 VERIFY_IS_EQUAL((t2.data()[13]), 101);
263 VERIFY_IS_EQUAL((t2.data()[14]), 102);
264 VERIFY_IS_EQUAL((t2.data()[15]), 103);
265 VERIFY_IS_EQUAL((t2.data()[16]), 110);
266 VERIFY_IS_EQUAL((t2.data()[17]), 111);
267 VERIFY_IS_EQUAL((t2.data()[18]), 112);
268 VERIFY_IS_EQUAL((t2.data()[19]), 113);
269 VERIFY_IS_EQUAL((t2.data()[20]), 120);
270 VERIFY_IS_EQUAL((t2.data()[21]), 121);
271 VERIFY_IS_EQUAL((t2.data()[22]), 122);
272 VERIFY_IS_EQUAL((t2.data()[23]), 123);
273 }
274
test_simple_assign()275 static void test_simple_assign()
276 {
277 Tensor<int, 3> epsilon(3,3,3);
278 epsilon.setZero();
279 epsilon(0,1,2) = epsilon(2,0,1) = epsilon(1,2,0) = 1;
280 epsilon(2,1,0) = epsilon(0,2,1) = epsilon(1,0,2) = -1;
281
282 Tensor<int, 3> e2(3,3,3);
283 e2.setZero();
284 VERIFY_IS_EQUAL((e2(1,2,0)), 0);
285
286 e2 = epsilon;
287 VERIFY_IS_EQUAL((e2(1,2,0)), 1);
288 VERIFY_IS_EQUAL((e2(0,1,2)), 1);
289 VERIFY_IS_EQUAL((e2(2,0,1)), 1);
290 VERIFY_IS_EQUAL((e2(2,1,0)), -1);
291 VERIFY_IS_EQUAL((e2(0,2,1)), -1);
292 VERIFY_IS_EQUAL((e2(1,0,2)), -1);
293 }
294
test_resize()295 static void test_resize()
296 {
297 Tensor<int, 3> epsilon;
298 epsilon.resize(2,3,7);
299 VERIFY_IS_EQUAL(epsilon.dimension(0), 2);
300 VERIFY_IS_EQUAL(epsilon.dimension(1), 3);
301 VERIFY_IS_EQUAL(epsilon.dimension(2), 7);
302 VERIFY_IS_EQUAL(epsilon.size(), 2*3*7);
303
304 const int* old_data = epsilon.data();
305 epsilon.resize(3,2,7);
306 VERIFY_IS_EQUAL(epsilon.dimension(0), 3);
307 VERIFY_IS_EQUAL(epsilon.dimension(1), 2);
308 VERIFY_IS_EQUAL(epsilon.dimension(2), 7);
309 VERIFY_IS_EQUAL(epsilon.size(), 2*3*7);
310 VERIFY_IS_EQUAL(epsilon.data(), old_data);
311
312 epsilon.resize(3,5,7);
313 VERIFY_IS_EQUAL(epsilon.dimension(0), 3);
314 VERIFY_IS_EQUAL(epsilon.dimension(1), 5);
315 VERIFY_IS_EQUAL(epsilon.dimension(2), 7);
316 VERIFY_IS_EQUAL(epsilon.size(), 3*5*7);
317 }
318
test_cxx11_tensor_simple()319 void test_cxx11_tensor_simple()
320 {
321 CALL_SUBTEST(test_0d());
322 CALL_SUBTEST(test_1d());
323 CALL_SUBTEST(test_2d());
324 CALL_SUBTEST(test_3d());
325 CALL_SUBTEST(test_simple_assign());
326 CALL_SUBTEST(test_resize());
327 }
328