1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
test_simple_patch()16 void test_simple_patch()
17 {
18   Tensor<float, 4> tensor(2,3,5,7);
19   tensor.setRandom();
20   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
21   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
22   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
23   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
24   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
25 
26   // Single pixel patch: ColMajor
27   Tensor<float, 5> single_pixel_patch;
28   single_pixel_patch = tensor.extract_image_patches(1, 1);
29   VERIFY_IS_EQUAL(single_pixel_patch.dimension(0), 2);
30   VERIFY_IS_EQUAL(single_pixel_patch.dimension(1), 1);
31   VERIFY_IS_EQUAL(single_pixel_patch.dimension(2), 1);
32   VERIFY_IS_EQUAL(single_pixel_patch.dimension(3), 3*5);
33   VERIFY_IS_EQUAL(single_pixel_patch.dimension(4), 7);
34 
35   // Single pixel patch: RowMajor
36   Tensor<float, 5, RowMajor> single_pixel_patch_row_major;
37   single_pixel_patch_row_major = tensor_row_major.extract_image_patches(1, 1);
38   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(0), 7);
39   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(1), 3*5);
40   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(2), 1);
41   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(3), 1);
42   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(4), 2);
43 
44   for (int i = 0; i < tensor.size(); ++i) {
45     // ColMajor
46     if (tensor.data()[i] != single_pixel_patch.data()[i]) {
47       std::cout << "Mismatch detected at index " << i << " : "
48            << tensor.data()[i] << " vs " << single_pixel_patch.data()[i]
49            << std::endl;
50     }
51     VERIFY_IS_EQUAL(single_pixel_patch.data()[i], tensor.data()[i]);
52     // RowMajor
53     if (tensor_row_major.data()[i] != single_pixel_patch_row_major.data()[i]) {
54       std::cout << "Mismatch detected at index " << i << " : "
55            << tensor.data()[i] << " vs "
56            << single_pixel_patch_row_major.data()[i] << std::endl;
57     }
58     VERIFY_IS_EQUAL(single_pixel_patch_row_major.data()[i],
59                     tensor_row_major.data()[i]);
60     VERIFY_IS_EQUAL(tensor.data()[i], tensor_row_major.data()[i]);
61     VERIFY_IS_EQUAL(single_pixel_patch.data()[i],
62                     single_pixel_patch_row_major.data()[i]);
63   }
64 
65   // Entire image patch: ColMajor
66   Tensor<float, 5> entire_image_patch;
67   entire_image_patch = tensor.extract_image_patches(3, 5);
68   VERIFY_IS_EQUAL(entire_image_patch.dimension(0), 2);
69   VERIFY_IS_EQUAL(entire_image_patch.dimension(1), 3);
70   VERIFY_IS_EQUAL(entire_image_patch.dimension(2), 5);
71   VERIFY_IS_EQUAL(entire_image_patch.dimension(3), 3*5);
72   VERIFY_IS_EQUAL(entire_image_patch.dimension(4), 7);
73 
74   // Entire image patch: RowMajor
75   Tensor<float, 5, RowMajor> entire_image_patch_row_major;
76   entire_image_patch_row_major = tensor_row_major.extract_image_patches(3, 5);
77   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(0), 7);
78   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(1), 3*5);
79   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(2), 5);
80   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(3), 3);
81   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(4), 2);
82 
83   for (int i = 0; i < 3; ++i) {
84     for (int j = 0; j < 5; ++j) {
85       int patchId = i+3*j;
86       for (int r = 0; r < 3; ++r) {
87         for (int c = 0; c < 5; ++c) {
88           for (int d = 0; d < 2; ++d) {
89             for (int b = 0; b < 7; ++b) {
90               float expected = 0.0f;
91               float expected_row_major = 0.0f;
92               if (r-1+i >= 0 && c-2+j >= 0 && r-1+i < 3 && c-2+j < 5) {
93                 expected = tensor(d, r-1+i, c-2+j, b);
94                 expected_row_major = tensor_row_major(b, c-2+j, r-1+i, d);
95               }
96               // ColMajor
97               if (entire_image_patch(d, r, c, patchId, b) != expected) {
98                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
99               }
100               VERIFY_IS_EQUAL(entire_image_patch(d, r, c, patchId, b), expected);
101               // RowMajor
102               if (entire_image_patch_row_major(b, patchId, c, r, d) !=
103                   expected_row_major) {
104                 std::cout << "Mismatch detected at index i=" << i << " j=" << j
105                      << " r=" << r << " c=" << c << " d=" << d << " b=" << b
106                      << std::endl;
107               }
108               VERIFY_IS_EQUAL(entire_image_patch_row_major(b, patchId, c, r, d),
109                               expected_row_major);
110               // Check that ColMajor and RowMajor agree.
111               VERIFY_IS_EQUAL(expected, expected_row_major);
112             }
113           }
114         }
115       }
116     }
117   }
118 
119   // 2D patch: ColMajor
120   Tensor<float, 5> twod_patch;
121   twod_patch = tensor.extract_image_patches(2, 2);
122   VERIFY_IS_EQUAL(twod_patch.dimension(0), 2);
123   VERIFY_IS_EQUAL(twod_patch.dimension(1), 2);
124   VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
125   VERIFY_IS_EQUAL(twod_patch.dimension(3), 3*5);
126   VERIFY_IS_EQUAL(twod_patch.dimension(4), 7);
127 
128   // 2D patch: RowMajor
129   Tensor<float, 5, RowMajor> twod_patch_row_major;
130   twod_patch_row_major = tensor_row_major.extract_image_patches(2, 2);
131   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(0), 7);
132   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(1), 3*5);
133   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(2), 2);
134   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(3), 2);
135   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(4), 2);
136 
137 
138   // Based on the calculation described in TensorTraits.h, padding happens to be 0.
139   int row_padding = 0;
140   int col_padding = 0;
141   int stride = 1;
142 
143   for (int i = 0; i < 3; ++i) {
144     for (int j = 0; j < 5; ++j) {
145       int patchId = i+3*j;
146       for (int r = 0; r < 2; ++r) {
147         for (int c = 0; c < 2; ++c) {
148           for (int d = 0; d < 2; ++d) {
149             for (int b = 0; b < 7; ++b) {
150               float expected = 0.0f;
151               float expected_row_major = 0.0f;
152               int row_offset = r*stride + i - row_padding;
153               int col_offset = c*stride + j - col_padding;
154               // ColMajor
155               if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor.dimension(1) && col_offset < tensor.dimension(2)) {
156                 expected = tensor(d, row_offset, col_offset, b);
157               }
158               if (twod_patch(d, r, c, patchId, b) != expected) {
159                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
160               }
161               VERIFY_IS_EQUAL(twod_patch(d, r, c, patchId, b), expected);
162 
163               // RowMajor
164               if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor_row_major.dimension(2) && col_offset < tensor_row_major.dimension(1)) {
165                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
166 
167               }
168               if (twod_patch_row_major(b, patchId, c, r, d) != expected_row_major) {
169                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
170               }
171               VERIFY_IS_EQUAL(twod_patch_row_major(b, patchId, c, r, d), expected_row_major);
172               // Check that ColMajor and RowMajor agree.
173               VERIFY_IS_EQUAL(expected, expected_row_major);
174             }
175           }
176         }
177       }
178     }
179   }
180 }
181 
182 // Verifies VALID padding (no padding) with incrementing values.
test_patch_padding_valid()183 void test_patch_padding_valid()
184 {
185   int input_depth = 3;
186   int input_rows = 3;
187   int input_cols = 3;
188   int input_batches = 1;
189   int ksize = 2;  // Corresponds to the Rows and Cols for tensor.extract_image_patches<>.
190   int stride = 2;  // Only same stride is supported.
191   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
192   // Initializes tensor with incrementing numbers.
193   for (int i = 0; i < tensor.size(); ++i) {
194     tensor.data()[i] = i + 1;
195   }
196   // ColMajor
197   Tensor<float, 5> result = tensor.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
198 
199   VERIFY_IS_EQUAL(result.dimension(0), input_depth);  // depth
200   VERIFY_IS_EQUAL(result.dimension(1), ksize);  // kernel rows
201   VERIFY_IS_EQUAL(result.dimension(2), ksize);  // kernel cols
202   VERIFY_IS_EQUAL(result.dimension(3), 1);  // number of patches
203   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
204 
205   // RowMajor
206   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
207   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
208   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
209   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
210   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
211 
212   Tensor<float, 5, RowMajor> result_row_major = tensor_row_major.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
213   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
214   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
215   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
216   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
217   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
218 
219   // No padding is carried out.
220   int row_padding = 0;
221   int col_padding = 0;
222 
223   for (int i = 0; (i+stride+ksize-1) < input_rows; i += stride) {  // input rows
224     for (int j = 0; (j+stride+ksize-1) < input_cols; j += stride) {  // input cols
225       int patchId = i+input_rows*j;
226       for (int r = 0; r < ksize; ++r) {  // patch rows
227         for (int c = 0; c < ksize; ++c) {  // patch cols
228           for (int d = 0; d < input_depth; ++d) {  // depth
229             for (int b = 0; b < input_batches; ++b) {  // batch
230               float expected = 0.0f;
231               float expected_row_major = 0.0f;
232               int row_offset = r + i - row_padding;
233               int col_offset = c + j - col_padding;
234               if (row_offset >= 0 && col_offset >= 0 && row_offset < input_rows && col_offset < input_cols) {
235                 expected = tensor(d, row_offset, col_offset, b);
236                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
237               }
238               // ColMajor
239               if (result(d, r, c, patchId, b) != expected) {
240                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
241               }
242               VERIFY_IS_EQUAL(result(d, r, c, patchId, b), expected);
243               // RowMajor
244               if (result_row_major(b, patchId, c, r, d) != expected_row_major) {
245                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
246               }
247               VERIFY_IS_EQUAL(result_row_major(b, patchId, c, r, d), expected_row_major);
248               // Check that ColMajor and RowMajor agree.
249               VERIFY_IS_EQUAL(expected, expected_row_major);
250             }
251           }
252         }
253       }
254     }
255   }
256 }
257 
258 // Verifies VALID padding (no padding) with the same value.
test_patch_padding_valid_same_value()259 void test_patch_padding_valid_same_value()
260 {
261   int input_depth = 1;
262   int input_rows = 5;
263   int input_cols = 5;
264   int input_batches = 2;
265   int ksize = 3;  // Corresponds to the Rows and Cols for tensor.extract_image_patches<>.
266   int stride = 2;  // Only same stride is supported.
267   // ColMajor
268   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
269   tensor = tensor.constant(11.0f);
270   Tensor<float, 5> result = tensor.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
271 
272   VERIFY_IS_EQUAL(result.dimension(0), input_depth);  // depth
273   VERIFY_IS_EQUAL(result.dimension(1), ksize);  // kernel rows
274   VERIFY_IS_EQUAL(result.dimension(2), ksize);  // kernel cols
275   VERIFY_IS_EQUAL(result.dimension(3), 4);  // number of patches
276   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
277 
278   // RowMajor
279   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
280   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
281   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
282   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
283   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
284 
285   Tensor<float, 5, RowMajor> result_row_major = tensor_row_major.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
286   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
287   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
288   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
289   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
290   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
291 
292   // No padding is carried out.
293   int row_padding = 0;
294   int col_padding = 0;
295 
296   for (int i = 0; (i+stride+ksize-1) <= input_rows; i += stride) {  // input rows
297     for (int j = 0; (j+stride+ksize-1) <= input_cols; j += stride) {  // input cols
298       int patchId = i+input_rows*j;
299       for (int r = 0; r < ksize; ++r) {  // patch rows
300         for (int c = 0; c < ksize; ++c) {  // patch cols
301           for (int d = 0; d < input_depth; ++d) {  // depth
302             for (int b = 0; b < input_batches; ++b) {  // batch
303               float expected = 0.0f;
304               float expected_row_major = 0.0f;
305               int row_offset = r + i - row_padding;
306               int col_offset = c + j - col_padding;
307               if (row_offset >= 0 && col_offset >= 0 && row_offset < input_rows && col_offset < input_cols) {
308                 expected = tensor(d, row_offset, col_offset, b);
309                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
310               }
311               // ColMajor
312               if (result(d, r, c, patchId, b) != expected) {
313                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
314               }
315               VERIFY_IS_EQUAL(result(d, r, c, patchId, b), expected);
316               // RowMajor
317               if (result_row_major(b, patchId, c, r, d) != expected_row_major) {
318                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
319               }
320               VERIFY_IS_EQUAL(result_row_major(b, patchId, c, r, d), expected_row_major);
321               // Check that ColMajor and RowMajor agree.
322               VERIFY_IS_EQUAL(expected, expected_row_major);
323             }
324           }
325         }
326       }
327     }
328   }
329 }
330 
331 // Verifies SAME padding.
test_patch_padding_same()332 void test_patch_padding_same()
333 {
334   int input_depth = 3;
335   int input_rows = 4;
336   int input_cols = 2;
337   int input_batches = 1;
338   int ksize = 2;  // Corresponds to the Rows and Cols for tensor.extract_image_patches<>.
339   int stride = 2;  // Only same stride is supported.
340   // ColMajor
341   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
342   // Initializes tensor with incrementing numbers.
343   for (int i = 0; i < tensor.size(); ++i) {
344     tensor.data()[i] = i + 1;
345   }
346   Tensor<float, 5> result = tensor.extract_image_patches(ksize, ksize, stride, stride, PADDING_SAME);
347 
348   VERIFY_IS_EQUAL(result.dimension(0), input_depth);  // depth
349   VERIFY_IS_EQUAL(result.dimension(1), ksize);  // kernel rows
350   VERIFY_IS_EQUAL(result.dimension(2), ksize);  // kernel cols
351   VERIFY_IS_EQUAL(result.dimension(3), 2);  // number of patches
352   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
353 
354   // RowMajor
355   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
356   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
357   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
358   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
359   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
360 
361   Tensor<float, 5, RowMajor> result_row_major = tensor_row_major.extract_image_patches(ksize, ksize, stride, stride, PADDING_SAME);
362   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
363   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
364   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
365   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
366   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
367 
368   // Based on the calculation described in TensorTraits.h, padding happens to be
369   // 0.
370   int row_padding = 0;
371   int col_padding = 0;
372 
373   for (int i = 0; (i+stride+ksize-1) <= input_rows; i += stride) {  // input rows
374     for (int j = 0; (j+stride+ksize-1) <= input_cols; j += stride) {  // input cols
375       int patchId = i+input_rows*j;
376       for (int r = 0; r < ksize; ++r) {  // patch rows
377         for (int c = 0; c < ksize; ++c) {  // patch cols
378           for (int d = 0; d < input_depth; ++d) {  // depth
379             for (int b = 0; b < input_batches; ++b) {  // batch
380               float expected = 0.0f;
381               float expected_row_major = 0.0f;
382               int row_offset = r*stride + i - row_padding;
383               int col_offset = c*stride + j - col_padding;
384               if (row_offset >= 0 && col_offset >= 0 && row_offset < input_rows && col_offset < input_cols) {
385                 expected = tensor(d, row_offset, col_offset, b);
386                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
387               }
388               // ColMajor
389               if (result(d, r, c, patchId, b) != expected) {
390                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
391               }
392               VERIFY_IS_EQUAL(result(d, r, c, patchId, b), expected);
393               // RowMajor
394               if (result_row_major(b, patchId, c, r, d) != expected_row_major) {
395                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
396               }
397               VERIFY_IS_EQUAL(result_row_major(b, patchId, c, r, d), expected_row_major);
398               // Check that ColMajor and RowMajor agree.
399               VERIFY_IS_EQUAL(expected, expected_row_major);
400             }
401           }
402         }
403       }
404     }
405   }
406 }
407 
test_patch_no_extra_dim()408 void test_patch_no_extra_dim()
409 {
410   Tensor<float, 3> tensor(2,3,5);
411   tensor.setRandom();
412   Tensor<float, 3, RowMajor> tensor_row_major = tensor.swap_layout();
413   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(2));
414   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(1));
415   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(0));
416 
417   // Single pixel patch: ColMajor
418   Tensor<float, 4> single_pixel_patch;
419   single_pixel_patch = tensor.extract_image_patches(1, 1);
420   VERIFY_IS_EQUAL(single_pixel_patch.dimension(0), 2);
421   VERIFY_IS_EQUAL(single_pixel_patch.dimension(1), 1);
422   VERIFY_IS_EQUAL(single_pixel_patch.dimension(2), 1);
423   VERIFY_IS_EQUAL(single_pixel_patch.dimension(3), 3*5);
424 
425   // Single pixel patch: RowMajor
426   Tensor<float, 4, RowMajor> single_pixel_patch_row_major;
427   single_pixel_patch_row_major = tensor_row_major.extract_image_patches(1, 1);
428   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(0), 3*5);
429   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(1), 1);
430   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(2), 1);
431   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(3), 2);
432 
433   for (int i = 0; i < tensor.size(); ++i) {
434     // ColMajor
435     if (tensor.data()[i] != single_pixel_patch.data()[i]) {
436       std::cout << "Mismatch detected at index " << i << " : " << tensor.data()[i] << " vs " << single_pixel_patch.data()[i] << std::endl;
437     }
438     VERIFY_IS_EQUAL(single_pixel_patch.data()[i], tensor.data()[i]);
439     // RowMajor
440     if (tensor_row_major.data()[i] != single_pixel_patch_row_major.data()[i]) {
441       std::cout << "Mismatch detected at index " << i << " : "
442            << tensor.data()[i] << " vs "
443            << single_pixel_patch_row_major.data()[i] << std::endl;
444     }
445     VERIFY_IS_EQUAL(single_pixel_patch_row_major.data()[i],
446                     tensor_row_major.data()[i]);
447     VERIFY_IS_EQUAL(tensor.data()[i], tensor_row_major.data()[i]);
448     VERIFY_IS_EQUAL(single_pixel_patch.data()[i],
449                     single_pixel_patch_row_major.data()[i]);
450   }
451 
452   // Entire image patch: ColMajor
453   Tensor<float, 4> entire_image_patch;
454   entire_image_patch = tensor.extract_image_patches(3, 5);
455   VERIFY_IS_EQUAL(entire_image_patch.dimension(0), 2);
456   VERIFY_IS_EQUAL(entire_image_patch.dimension(1), 3);
457   VERIFY_IS_EQUAL(entire_image_patch.dimension(2), 5);
458   VERIFY_IS_EQUAL(entire_image_patch.dimension(3), 3*5);
459 
460   // Entire image patch: RowMajor
461   Tensor<float, 4, RowMajor> entire_image_patch_row_major;
462   entire_image_patch_row_major = tensor_row_major.extract_image_patches(3, 5);
463   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(0), 3*5);
464   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(1), 5);
465   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(2), 3);
466   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(3), 2);
467 
468   for (int i = 0; i < 3; ++i) {
469     for (int j = 0; j < 5; ++j) {
470       int patchId = i+3*j;
471       for (int r = 0; r < 3; ++r) {
472         for (int c = 0; c < 5; ++c) {
473           for (int d = 0; d < 2; ++d) {
474             float expected = 0.0f;
475             float expected_row_major = 0.0f;
476             if (r-1+i >= 0 && c-2+j >= 0 && r-1+i < 3 && c-2+j < 5) {
477               expected = tensor(d, r-1+i, c-2+j);
478               expected_row_major = tensor_row_major(c-2+j, r-1+i, d);
479             }
480             // ColMajor
481             if (entire_image_patch(d, r, c, patchId) != expected) {
482               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
483             }
484             VERIFY_IS_EQUAL(entire_image_patch(d, r, c, patchId), expected);
485             // RowMajor
486             if (entire_image_patch_row_major(patchId, c, r, d) !=
487                 expected_row_major) {
488               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
489             }
490             VERIFY_IS_EQUAL(entire_image_patch_row_major(patchId, c, r, d),
491                             expected_row_major);
492             // Check that ColMajor and RowMajor agree.
493             VERIFY_IS_EQUAL(expected, expected_row_major);
494           }
495         }
496       }
497     }
498   }
499 
500   // 2D patch: ColMajor
501   Tensor<float, 4> twod_patch;
502   twod_patch = tensor.extract_image_patches(2, 2);
503   VERIFY_IS_EQUAL(twod_patch.dimension(0), 2);
504   VERIFY_IS_EQUAL(twod_patch.dimension(1), 2);
505   VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
506   VERIFY_IS_EQUAL(twod_patch.dimension(3), 3*5);
507 
508   // 2D patch: RowMajor
509   Tensor<float, 4, RowMajor> twod_patch_row_major;
510   twod_patch_row_major = tensor_row_major.extract_image_patches(2, 2);
511   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(0), 3*5);
512   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(1), 2);
513   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(2), 2);
514   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(3), 2);
515 
516   // Based on the calculation described in TensorTraits.h, padding happens to be 0.
517   int row_padding = 0;
518   int col_padding = 0;
519   int stride = 1;
520 
521   for (int i = 0; i < 3; ++i) {
522     for (int j = 0; j < 5; ++j) {
523       int patchId = i+3*j;
524       for (int r = 0; r < 2; ++r) {
525         for (int c = 0; c < 2; ++c) {
526           for (int d = 0; d < 2; ++d) {
527             float expected = 0.0f;
528             float expected_row_major = 0.0f;
529             int row_offset = r*stride + i - row_padding;
530             int col_offset = c*stride + j - col_padding;
531             // ColMajor
532             if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor.dimension(1) && col_offset < tensor.dimension(2)) {
533               expected = tensor(d, row_offset, col_offset);
534             }
535             if (twod_patch(d, r, c, patchId) != expected) {
536               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
537             }
538             VERIFY_IS_EQUAL(twod_patch(d, r, c, patchId), expected);
539             // RowMajor
540             if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor_row_major.dimension(1) && col_offset < tensor_row_major.dimension(0)) {
541               expected_row_major = tensor_row_major(col_offset, row_offset, d);
542             }
543             if (twod_patch_row_major(patchId, c, r, d) != expected_row_major) {
544               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
545             }
546             VERIFY_IS_EQUAL(twod_patch_row_major(patchId, c, r, d), expected_row_major);
547             // Check that ColMajor and RowMajor agree.
548             VERIFY_IS_EQUAL(expected, expected_row_major);
549           }
550         }
551       }
552     }
553   }
554 }
555 
test_imagenet_patches()556 void test_imagenet_patches()
557 {
558   // Test the code on typical configurations used by the 'imagenet' benchmarks at
559   // https://github.com/soumith/convnet-benchmarks
560   // ColMajor
561   Tensor<float, 4> l_in(3, 128, 128, 16);
562   l_in.setRandom();
563   Tensor<float, 5> l_out = l_in.extract_image_patches(11, 11);
564   VERIFY_IS_EQUAL(l_out.dimension(0), 3);
565   VERIFY_IS_EQUAL(l_out.dimension(1), 11);
566   VERIFY_IS_EQUAL(l_out.dimension(2), 11);
567   VERIFY_IS_EQUAL(l_out.dimension(3), 128*128);
568   VERIFY_IS_EQUAL(l_out.dimension(4), 16);
569 
570   // RowMajor
571   Tensor<float, 5, RowMajor> l_out_row_major = l_in.swap_layout().extract_image_patches(11, 11);
572   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 16);
573   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 128*128);
574   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 11);
575   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 11);
576   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 3);
577 
578   for (int b = 0; b < 16; ++b) {
579     for (int i = 0; i < 128; ++i) {
580       for (int j = 0; j < 128; ++j) {
581         int patchId = i+128*j;
582         for (int c = 0; c < 11; ++c) {
583           for (int r = 0; r < 11; ++r) {
584             for (int d = 0; d < 3; ++d) {
585               float expected = 0.0f;
586               if (r-5+i >= 0 && c-5+j >= 0 && r-5+i < 128 && c-5+j < 128) {
587                 expected = l_in(d, r-5+i, c-5+j, b);
588               }
589               // ColMajor
590               if (l_out(d, r, c, patchId, b) != expected) {
591                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
592               }
593               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
594               // RowMajor
595               if (l_out_row_major(b, patchId, c, r, d) !=
596                   expected) {
597                 std::cout << "Mismatch detected at index i=" << i << " j=" << j
598                      << " r=" << r << " c=" << c << " d=" << d << " b=" << b
599                      << std::endl;
600               }
601               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d),
602                               expected);
603             }
604           }
605         }
606       }
607     }
608   }
609 
610   // ColMajor
611   l_in.resize(16, 64, 64, 32);
612   l_in.setRandom();
613   l_out = l_in.extract_image_patches(9, 9);
614   VERIFY_IS_EQUAL(l_out.dimension(0), 16);
615   VERIFY_IS_EQUAL(l_out.dimension(1), 9);
616   VERIFY_IS_EQUAL(l_out.dimension(2), 9);
617   VERIFY_IS_EQUAL(l_out.dimension(3), 64*64);
618   VERIFY_IS_EQUAL(l_out.dimension(4), 32);
619 
620   // RowMajor
621   l_out_row_major = l_in.swap_layout().extract_image_patches(9, 9);
622   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
623   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 64*64);
624   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 9);
625   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 9);
626   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 16);
627 
628   for (int b = 0; b < 32; ++b) {
629     for (int i = 0; i < 64; ++i) {
630       for (int j = 0; j < 64; ++j) {
631         int patchId = i+64*j;
632         for (int c = 0; c < 9; ++c) {
633           for (int r = 0; r < 9; ++r) {
634             for (int d = 0; d < 16; ++d) {
635               float expected = 0.0f;
636               if (r-4+i >= 0 && c-4+j >= 0 && r-4+i < 64 && c-4+j < 64) {
637                 expected = l_in(d, r-4+i, c-4+j, b);
638               }
639               // ColMajor
640               if (l_out(d, r, c, patchId, b) != expected) {
641                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
642               }
643               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
644               // RowMajor
645               if (l_out_row_major(b, patchId, c, r, d) != expected) {
646                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
647               }
648               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
649             }
650           }
651         }
652       }
653     }
654   }
655 
656   // ColMajor
657   l_in.resize(32, 16, 16, 32);
658   l_in.setRandom();
659   l_out = l_in.extract_image_patches(7, 7);
660   VERIFY_IS_EQUAL(l_out.dimension(0), 32);
661   VERIFY_IS_EQUAL(l_out.dimension(1), 7);
662   VERIFY_IS_EQUAL(l_out.dimension(2), 7);
663   VERIFY_IS_EQUAL(l_out.dimension(3), 16*16);
664   VERIFY_IS_EQUAL(l_out.dimension(4), 32);
665 
666   // RowMajor
667   l_out_row_major = l_in.swap_layout().extract_image_patches(7, 7);
668   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
669   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 16*16);
670   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 7);
671   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 7);
672   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 32);
673 
674   for (int b = 0; b < 32; ++b) {
675     for (int i = 0; i < 16; ++i) {
676       for (int j = 0; j < 16; ++j) {
677         int patchId = i+16*j;
678         for (int c = 0; c < 7; ++c) {
679           for (int r = 0; r < 7; ++r) {
680             for (int d = 0; d < 32; ++d) {
681               float expected = 0.0f;
682               if (r-3+i >= 0 && c-3+j >= 0 && r-3+i < 16 && c-3+j < 16) {
683                 expected = l_in(d, r-3+i, c-3+j, b);
684               }
685               // ColMajor
686               if (l_out(d, r, c, patchId, b) != expected) {
687                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
688               }
689               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
690               // RowMajor
691               if (l_out_row_major(b, patchId, c, r, d) != expected) {
692                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
693               }
694               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
695             }
696           }
697         }
698       }
699     }
700   }
701 
702   // ColMajor
703   l_in.resize(64, 13, 13, 32);
704   l_in.setRandom();
705   l_out = l_in.extract_image_patches(3, 3);
706   VERIFY_IS_EQUAL(l_out.dimension(0), 64);
707   VERIFY_IS_EQUAL(l_out.dimension(1), 3);
708   VERIFY_IS_EQUAL(l_out.dimension(2), 3);
709   VERIFY_IS_EQUAL(l_out.dimension(3), 13*13);
710   VERIFY_IS_EQUAL(l_out.dimension(4), 32);
711 
712   // RowMajor
713   l_out_row_major = l_in.swap_layout().extract_image_patches(3, 3);
714   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
715   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 13*13);
716   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 3);
717   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 3);
718   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 64);
719 
720   for (int b = 0; b < 32; ++b) {
721     for (int i = 0; i < 13; ++i) {
722       for (int j = 0; j < 13; ++j) {
723         int patchId = i+13*j;
724         for (int c = 0; c < 3; ++c) {
725           for (int r = 0; r < 3; ++r) {
726             for (int d = 0; d < 64; ++d) {
727               float expected = 0.0f;
728               if (r-1+i >= 0 && c-1+j >= 0 && r-1+i < 13 && c-1+j < 13) {
729                 expected = l_in(d, r-1+i, c-1+j, b);
730               }
731               // ColMajor
732               if (l_out(d, r, c, patchId, b) != expected) {
733                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
734               }
735               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
736               // RowMajor
737               if (l_out_row_major(b, patchId, c, r, d) != expected) {
738                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
739               }
740               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
741             }
742           }
743         }
744       }
745     }
746   }
747 }
748 
// Test-driver entry point for the extract_image_patches tests in this file.
// Each test function is invoked through its own numbered CALL_SUBTEST_N macro
// (1..6), which comes from the test header "main.h".
// NOTE(review): presumably the numbering lets the Eigen test harness build and
// run each subtest independently (one per test part) — confirm against main.h.
test_cxx11_tensor_image_patch()749 void test_cxx11_tensor_image_patch()
750 {
751   CALL_SUBTEST_1(test_simple_patch());  // 1x1 patches, ColMajor and RowMajor layouts
752   CALL_SUBTEST_2(test_patch_no_extra_dim());
753   CALL_SUBTEST_3(test_patch_padding_valid());
754   CALL_SUBTEST_4(test_patch_padding_valid_same_value());
755   CALL_SUBTEST_5(test_patch_padding_same());
756   CALL_SUBTEST_6(test_imagenet_patches());  // batched inputs with 9x9, 7x7 and 3x3 patches, checked element-wise
757 }
758