1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <math.h>
14 #include <stdio.h>
15 
16 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
17 
18 #include "config/av1_rtcd.h"
19 
20 #include "aom_ports/aom_timer.h"
21 #include "av1/encoder/cnn.h"
22 #include "av1/encoder/partition_cnn_weights.h"
23 #include "test/acm_random.h"
24 #include "test/function_equivalence_test.h"
25 #include "test/util.h"
26 
27 #define SQR(x) ((x) * (x))
28 
// Best possible pixelwise guaranteed precision given each float has at most
// 3 specified decimals.
31 #define PIXELWISE_FLOAT_TOL 1E-2
32 
33 #define MSE_FLOAT_TOL 1E-6
34 #define MSE_INT_TOL 0
35 
36 // CNN convolve pixelwise error threshold for functional equivalence.
37 #define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
38 
39 namespace {
40 
41 class CNNTest : public ::testing::Test {
42  protected:
RunCNNTest(int image_width,int image_height,const float * input,const float * expected,const CNN_CONFIG * cnn_config,int in_stride,CNN_THREAD_DATA * thread_data,double tolerance)43   static void RunCNNTest(int image_width, int image_height, const float *input,
44                          const float *expected, const CNN_CONFIG *cnn_config,
45                          int in_stride, CNN_THREAD_DATA *thread_data,
46                          double tolerance) {
47     int out_width, out_height, out_channels;
48     av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
49                              &out_height, &out_channels);
50 
51     const int out_size = out_width * out_height;
52     const int out_stride = out_width;
53 
54     float *output_ =
55         (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
56     float *output[CNN_MAX_CHANNELS] = { nullptr };
57     for (int channel = 0; channel < out_channels; ++channel) {
58       output[channel] = output_ + (channel * out_size);
59     }
60     const int num_outputs = 1;
61     const int output_chs[1] = { out_channels };
62     const int output_strides[1] = { out_stride };
63     CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
64                                     output };
65 
66     RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
67                        thread_data, &output_struct, &expected, tolerance);
68 
69     aom_free(output_);
70   }
71 
RunMultiOutCNNTest(const float ** input,int image_width,int image_height,int in_stride,const CNN_CONFIG * cnn_config,CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output,const float ** expected,double tolerance)72   static void RunMultiOutCNNTest(const float **input, int image_width,
73                                  int image_height, int in_stride,
74                                  const CNN_CONFIG *cnn_config,
75                                  CNN_THREAD_DATA *thread_data,
76                                  CNN_MULTI_OUT *output, const float **expected,
77                                  double tolerance) {
78     const int num_outputs = output->num_outputs;
79     const int *output_chs = output->output_channels;
80 
81     int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
82     int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
83     int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
84 
85     av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
86                              out_heights, not_used);
87     av1_cnn_predict(input, image_width, image_height, in_stride, cnn_config,
88                     thread_data, output);
89 
90     int channel_offset = 0;
91     for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
92       const float *expected_out = expected[output_idx];
93       const int curr_output_chs = output_chs[output_idx];
94       const int out_size = out_widths[output_idx] * out_heights[output_idx];
95 
96       double mse = 0;
97       int expected_ite = 0;
98       for (int channel = 0; channel < curr_output_chs; ++channel) {
99         const float *buf_out = output->output_buffer[channel_offset];
100 
101         for (int i = 0; i < out_size; ++i) {
102           EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
103                       PIXELWISE_FLOAT_TOL)
104               << " output " << output_idx << " channel " << channel << " pixel "
105               << expected_ite % out_size << ": " << expected_out[expected_ite]
106               << "/" << buf_out[i] << std::endl;
107           mse += SQR(expected_out[expected_ite] - buf_out[i]);
108           expected_ite++;
109         }
110 
111         channel_offset++;
112       }
113       mse /= (out_size * curr_output_chs);
114       EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
115     }
116 
117     aom_free(out_widths);
118     aom_free(out_heights);
119     aom_free(not_used);
120   }
121 
AssignLayerWeightsBiases(CNN_CONFIG * cnn_config,float * weights,float * bias)122   static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
123                                        float *bias) {
124     size_t weight_offset = 0;
125     size_t bias_offset = 0;
126     for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
127       CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
128       layer_config->weights = weights + weight_offset;
129       layer_config->bias = bias + bias_offset;
130       weight_offset += layer_config->filter_width *
131                        layer_config->filter_height * layer_config->in_channels *
132                        layer_config->out_channels;
133       bias_offset += layer_config->out_channels;
134 
135       ASSERT_NE(layer_config->weights, nullptr);
136       ASSERT_NE(layer_config->bias, nullptr);
137     }
138   }
139 };
140 
141 }  // namespace
142 
TEST_F(CNNTest,TestMultilayerConvolution)143 TEST_F(CNNTest, TestMultilayerConvolution) {
144   int image_height = 16;
145   int image_width = 16;
146   int filter_height = 5;
147   int filter_width = 4;
148 
149   float input[] = {
150     -3, 1,  -3, 2,  -2, -2, 2,  -2, 1,  -2, -3, 1,  2,  2,  2,  -2, 0,  1,  -1,
151     -3, -1, -1, 1,  0,  -3, 1,  0,  -1, 1,  0,  0,  -3, -3, -3, 0,  2,  1,  -1,
152     2,  0,  1,  -3, -1, 2,  2,  1,  -2, 0,  -1, 0,  -2, -2, -1, 1,  0,  0,  0,
153     -2, -2, -2, 1,  1,  -2, 1,  1,  -2, -2, 1,  -2, -1, -2, -3, 2,  -3, -1, 1,
154     0,  -2, -2, -2, 1,  -2, -2, -1, -1, 2,  2,  2,  -1, 1,  -3, -3, 0,  2,  0,
155     2,  1,  -3, -3, 1,  2,  2,  1,  -2, -3, 0,  -3, 0,  -3, -2, 0,  1,  1,  0,
156     -3, 2,  -1, 2,  1,  0,  1,  -2, 1,  -1, -1, 2,  0,  -2, -3, 1,  1,  -2, -1,
157     -3, -3, -1, 0,  -3, -2, 0,  0,  1,  0,  -3, -2, -1, 1,  0,  2,  1,  0,  -3,
158     -2, -3, -3, -1, 0,  -2, 2,  -1, -3, 0,  -1, -1, 2,  0,  -3, -2, -1, 0,  0,
159     1,  -2, 1,  2,  1,  2,  2,  -3, 2,  -1, 0,  0,  -1, 0,  2,  2,  -1, 2,  -2,
160     1,  1,  -3, -3, 1,  -1, -1, -2, 2,  -2, -2, 2,  -1, -3, 2,  -3, 1,  -1, -1,
161     -3, 1,  -1, 1,  0,  -3, -3, 1,  -3, -3, 0,  2,  2,  -2, -1, 2,  0,  2,  1,
162     -1, -3, 0,  0,  -1, -1, 1,  0,  2,  0,  -3, 2,  1,  0,  1,  -3, 2,  -3, -3,
163     -1, -3, -3, 2,  0,  2,  -2, 1,  -1,
164   };
165 
166   float weights[] = {
167     -2, 2,  -2, 2,  -1, -3, 2,  2,  0,  0,  -3, -1, -2, -3, 1,  -1, 0,  0,  0,
168     2,  -2, 2,  -2, -3, 1,  1,  1,  -3, -1, 0,  1,  2,  -2, 0,  -1, -3, -1, -2,
169     2,  -3, -3, 1,  -2, -3, 0,  2,  1,  -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
170     -1, -3, -1, -2, -2, -3, 2,  0,  -3, 0,  -3, -3, 1,  -3, -1, 0,  -1, 1,  1,
171     -1, 1,  -2, 0,  2,  0,  -3, 1,  -1, -1, 2,  0,  1,  -3, -3, 1,  2,  -3, -3,
172     1,  -3, 2,  0,  -3, 1,  2,  2,  -2, -1, -2, 1,  1,  0,  -2, -2, 1,  2,  -1,
173     -3, 1,  -2, 2,  -3, -2, -3, 2,  1,  0,  -2, 0,  1,  -3, 2,  -2, -2, 0,  2,
174     -3, 2,  0,  0,  1,  -2, 1,  1,  -2, -1, -2, 1,  -2, 0,  -2, -2, 0,  -1, -1,
175     -3, -3, -3, 1,  -3, -2, 2,  -1, 2,  0,  2,  -2, 2,  -2, 1,  -3, -3, -1, 0,
176     2,  2,  1,  -1, -3, -1, -3, 2,  1,  -2, 0,  -3, -1, -3, -1, 2,  1,  0,  2,
177     -1, 1,  0,  1,  2,  -1, -2, 2,  1,  -3, -1, -3, 0,  1,  -2, 0,  -2, -3, 0,
178     -2, 2,  2,  0,  0,  2,  -3, 2,  -3, -2, 1,  2,  -3, -3, -1, -3, 0,  -3, -3,
179     -2, -2, -2, 0,  0,  1,  0,  0,  -1, 0,  0,  -3, 0,  -3, -1, -2, 1,  -2, -1,
180     2,  -2, 0,  0,  1,  0,  -2, -1, 0,  -3, 1,  0,  -1, -3, 1,  -1, 1,  -1, -3,
181     1,  0,  1,  1,  -1, 2,  2,  0,  0,  1,  -3, 2,  -2, -2, -3, -2, -1, -2, 2,
182     0,  2,  -2, -3, -1, -3, 2,  2,  -1, 2,  2,  -1, 0,  -3, 1,
183   };
184 
185   float bias[] = {
186     1, -1, 0, 1, 1, 1, -2,
187   };
188 
189   float expected_same[] = {
190     -1125, 2926,  6406,  631,   -1244, 97,    -1454, 2526,  1065,  3292,  3464,
191     2553,  -330,  532,   1038,  1182,  -402,  3758,  3392,  9854,  4365,  1408,
192     4736,  3134,  3838,  2409,  3221,  4350,  6750,  4045,  815,   1188,  2959,
193     9802,  9590,  4572,  5740,  4253,  1701,  7974,  7012,  6854,  7093,  3907,
194     4539,  3886,  4267,  3505,  465,   7824,  9219,  10026, 7968,  957,   2295,
195     5594,  10811, 9641,  5950,  10043, 8783,  3132,  1421,  1110,  4108,  13929,
196     10660, -84,   -61,   3932,  -180,  6811,  13393, 15147, 15640, 9337,  6961,
197     3808,  1604,  1398,  1047,  6739,  10144, 6517,  4698,  2678,  7389,  2595,
198     5248,  12075, 11272, 13951, 8820,  1090,  2199,  2206,  2788,  12116, 6683,
199     2612,  -291,  3183,  9414,  12316, 14524, 12333, 13208, 7832,  4664,  4657,
200     3534,  1298,  -666,  4250,  7707,  9103,  5760,  688,   9571,  15782, 14203,
201     14878, 17339, 14684, 8690,  5671,  875,   1429,  1531,  6173,  2984,  5558,
202     2996,  7928,  6733,  16117, 15262, 12757, 7980,  3923,  4795,  5973,  2051,
203     455,   -1922, 1816,  5906,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
204     7451,  6666,  74,    -1645, -35,   -391,  3813,  7324,  892,   1656,  6095,
205     12193, 14648, 12156, 14663, 10251, 10325, 7821,  3925,  323,   697,   442,
206     1324,  4669,  7002,  5485,  5171,  5086,  10582, 11053, 9709,  11353, 8543,
207     5256,  2873,  235,   -628,  1496,  1878,  -867,  3420,  6865,  5937,  10182,
208     13277, 10069, 10789, 5998,  624,   -2082, 4417,  1258,  -1080, -819,  -1430,
209     1033,  5220,  6335,  8471,  8980,  11908, 14430, 12584, 8404,  1576,  -803,
210     985,   1481,  1367,  -193,  873,   3684,  2288,  6676,  9477,  11155, 9602,
211     9707,  10507, 4739,  3174,  -575,  -178,  3002,  1710,  423,   -477,  554,
212     3088,  2029,  5113,  5000,  3771,  6090,  5365,  1185,  2855,  399,   -312,
213     -1577, 176,   955,
214   };
215 
216   float expected_replicate[] = {
217     13768, 13528, 12999, 6906,  4618,  4043,  2611,  9955,  6685,  4776,  2753,
218     1036,  3063,  4544,  5183,  7349,  12451, 12501, 9131,  12753, 8908,  4058,
219     6299,  7542,  7115,  3307,  3360,  3543,  9754,  7808,  5991,  9019,  14320,
220     14919, 12492, 6871,  7373,  3336,  2085,  10604, 9377,  6882,  5009,  3103,
221     6220,  6278,  7588,  10196, 11045, 11563, 11842, 11911, 8279,  2030,  1858,
222     6368,  12123, 9909,  6347,  10345, 9365,  4038,  1673,  3051,  16492, 16649,
223     12276, 408,   -301,  4122,  -654,  7864,  14038, 15279, 15315, 9744,  8243,
224     5298,  746,   380,   9824,  9124,  10895, 6640,  4712,  2669,  6980,  2759,
225     5385,  12345, 11336, 13129, 8600,  2370,  3682,  5219,  12407, 13123, 6784,
226     2612,  -291,  3183,  9414,  12316, 14524, 12333, 13397, 7543,  3916,  4153,
227     4477,  4314,  7983,  8418,  9163,  9103,  5760,  688,   9571,  15782, 14203,
228     14878, 17718, 14570, 7940,  6642,  5094,  7133,  9964,  10219, 3224,  5558,
229     2996,  7928,  6733,  16117, 15262, 12757, 7958,  4401,  5187,  5476,  5529,
230     6055,  2206,  3909,  6015,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
231     6967,  6840,  481,   -1600, 274,   1,     10373, 8514,  1123,  2117,  6758,
232     12736, 16223, 13585, 15988, 11771, 10600, 7918,  4156,  2840,  3111,  3287,
233     6359,  7652,  8813,  6530,  6967,  7789,  13671, 13990, 13247, 13241, 9836,
234     5251,  3024,  2313,  1834,  4187,  2637,  -1312, 2139,  7378,  7665,  11933,
235     15591, 15314, 15678, 9531,  2820,  -1516, 3400,  1314,  22,    363,   -2896,
236     -898,  5906,  7308,  10650, 12975, 16978, 20370, 18817, 12381, 4118,  -861,
237     -137,  236,   1802,  1632,  -350,  2334,  3400,  8680,  14064, 18216, 18675,
238     21765, 22871, 11491, 4937,  -1555, -11,   1669,  2392,  3265,  -5254, -217,
239     5001,  8063,  13444, 18884, 19706, 22794, 21064, 9545,  6689,  -7,    289,
240     -2021, 504,   2347,
241   };
242 
243   float expected_valid[] = {
244     2612,  -291,  3183,  9414,  12316, 14524, 12333, 9103,  5760,  688,
245     9571,  15782, 14203, 14878, 5558,  2996,  7928,  6733,  16117, 15262,
246     12757, 3321,  10908, 10910, 7377,  12204, 12809, 11195,
247   };
248 
249   CNN_CONFIG cnn_config = { 3,
250                             0,
251                             0,
252                             0,
253                             0,
254                             {
255                                 {
256                                     1,
257                                     filter_width,
258                                     filter_height,
259                                     3,
260                                     1,
261                                     1,
262                                     0,
263                                     nullptr,
264                                     nullptr,
265                                     PADDING_SAME_ZERO,
266                                     NONE,
267                                     0,
268                                     0,
269                                     BRANCH_NO_COPY,
270                                     BRANCH_NOC,
271                                     {},
272                                     {},
273                                     -1,
274                                 },
275                                 {
276                                     3,
277                                     filter_width,
278                                     filter_height,
279                                     3,
280                                     1,
281                                     1,
282                                     0,
283                                     nullptr,
284                                     nullptr,
285                                     PADDING_SAME_ZERO,
286                                     NONE,
287                                     0,
288                                     0,
289                                     BRANCH_NO_COPY,
290                                     BRANCH_NOC,
291                                     {},
292                                     {},
293                                     -1,
294                                 },
295                                 {
296                                     3,
297                                     filter_width,
298                                     filter_height,
299                                     1,
300                                     1,
301                                     1,
302                                     0,
303                                     nullptr,
304                                     nullptr,
305                                     PADDING_SAME_ZERO,
306                                     NONE,
307                                     0,
308                                     0,
309                                     BRANCH_NO_COPY,
310                                     BRANCH_NOC,
311                                     {},
312                                     {},
313                                     0,
314                                 },
315                             } };
316 
317   // Weights and biases need to be specified separately because
318   // of the offset.
319   AssignLayerWeightsBiases(&cnn_config, weights, bias);
320 
321   CNN_THREAD_DATA thread_data = { 1, NULL };
322 
323   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
324              image_width, &thread_data, MSE_INT_TOL);
325 
326   for (int i = 0; i < cnn_config.num_layers; ++i) {
327     cnn_config.layer_config[i].pad = PADDING_SAME_REPLICATE;
328   }
329 
330   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
331              image_width, &thread_data, MSE_INT_TOL);
332 
333   for (int i = 0; i < cnn_config.num_layers; ++i) {
334     cnn_config.layer_config[i].pad = PADDING_VALID;
335   }
336 
337   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
338              image_width, &thread_data, MSE_INT_TOL);
339 }
340 
TEST_F(CNNTest,TestRELUSingleLayer)341 TEST_F(CNNTest, TestRELUSingleLayer) {
342   int image_width = 8;
343   int image_height = 8;
344   int filter_height = 5;
345   int filter_width = 4;
346   float input[] = {
347     0, -2, -3, 1,  -1, 2,  -2, 1,  -3, -1, 0,  1,  -2, -3, -2, -2,
348     1, -3, 2,  -3, -1, -1, 2,  0,  -2, -3, 0,  -2, -3, 1,  -1, -1,
349     2, -2, 0,  -2, -3, -3, 1,  1,  -1, 1,  0,  1,  -3, 0,  2,  2,
350     0, -3, 1,  -3, 2,  -2, 1,  -1, -1, -2, -3, -2, -1, -3, -2, -1,
351   };
352   float expected_same[] = {
353     9,  0,  1,  1,  0,  3,  0,  19, 0,  12, 10, 0,  0,  0,  5, 0,
354     0,  18, 21, 7,  19, 4,  3,  0,  0,  9,  16, 0,  11, 16, 0, 11,
355     12, 2,  0,  11, 0,  16, 6,  0,  8,  22, 13, 10, 12, 0,  0, 0,
356     0,  1,  2,  12, 29, 6,  10, 0,  13, 0,  0,  5,  8,  10, 0, 0,
357   };
358   float expected_replicate[] = {
359     18, 17, 12, 2,  0,  0,  5,  11, 0,  17, 22, 6,  0,  0,  17, 0,
360     0,  18, 21, 7,  19, 4,  3,  5,  3,  9,  16, 0,  11, 16, 0,  3,
361     3,  2,  0,  11, 0,  16, 6,  0,  17, 22, 13, 10, 12, 0,  0,  0,
362     0,  4,  1,  10, 30, 7,  10, 0,  23, 8,  0,  13, 15, 19, 8,  10,
363   };
364   float expected_valid[] = {
365     18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
366   };
367   float weights[] = {
368     -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
369   };
370   float bias[] = { -3 };
371 
372   CNN_CONFIG cnn_config = { 1,
373                             0,
374                             0,
375                             0,
376                             0,
377                             { {
378                                 1,
379                                 filter_width,
380                                 filter_height,
381                                 1,
382                                 1,
383                                 1,
384                                 0,
385                                 weights,
386                                 bias,
387                                 PADDING_SAME_ZERO,
388                                 RELU,
389                                 0,
390                                 0,
391                                 BRANCH_NO_COPY,
392                                 BRANCH_NOC,
393                                 {},
394                                 {},
395                                 0,
396                             } } };
397 
398   CNN_THREAD_DATA thread_data = { 1, NULL };
399 
400   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
401              image_width, &thread_data, MSE_INT_TOL);
402 
403   cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
404 
405   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
406              image_width, &thread_data, MSE_INT_TOL);
407 
408   cnn_config.layer_config[0].pad = PADDING_VALID;
409 
410   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
411              image_width, &thread_data, MSE_INT_TOL);
412 }
413 
TEST_F(CNNTest,TestVaryingStridesVaryingDimImages)414 TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
415   float weights[] = {
416     1,  -5, -3, -4, -1, 1,  2,  -3, 2,  2,  -1, 1,  -5, 1,  1,
417     -3, -5, 3,  1,  4,  -2, -5, -2, -3, -5, 0,  -1, -5, 2,  -2,
418     -2, 1,  -2, -4, 1,  3,  -2, 2,  0,  -3, 2,  -3, -2, -3,
419   };
420   float bias[] = { 2 };
421 
422   CNN_CONFIG cnn_config = { 1,
423                             0,
424                             0,
425                             0,
426                             0,
427                             {
428                                 {
429                                     1,
430                                     4,
431                                     11,
432                                     1,
433                                     7,
434                                     6,
435                                     0,
436                                     weights,
437                                     bias,
438                                     PADDING_SAME_ZERO,
439                                     NONE,
440                                     0,
441                                     0,
442                                     BRANCH_NO_COPY,
443                                     BRANCH_NOC,
444                                     {},
445                                     {},
446                                     0,
447                                 },
448                             } };
449 
450   int image_height = 24;
451   int image_width = 17;
452   float input[] = {
453     -1, -3, 4,  4,  -5, 4,  3,  -5, -1, -3, 4,  -4, 2,  -3, 3,  -5, 2,  -1, -5,
454     1,  -1, 3,  1,  -3, -3, 4,  0,  2,  -3, -5, -5, -4, 0,  -5, -2, -3, -1, -2,
455     2,  -5, 4,  4,  0,  -4, -3, 1,  -3, -5, -4, -4, 1,  -2, -3, 3,  -3, -3, -1,
456     -5, -5, -2, 3,  1,  -1, -5, -5, 1,  -4, -2, -1, -2, -4, -4, 2,  -2, 2,  1,
457     -2, -4, -1, 1,  -2, -5, 3,  -2, -1, -1, -5, -3, 1,  -2, -2, -3, -1, -2, -4,
458     -2, 1,  -4, -1, 4,  3,  -4, 0,  4,  2,  2,  4,  -3, -5, 2,  2,  1,  -1, -4,
459     -2, 1,  3,  2,  0,  4,  -1, -3, 2,  1,  -4, 2,  2,  -4, -2, 0,  -2, -1, 4,
460     4,  2,  3,  -4, 2,  -4, -5, 4,  -1, -3, -1, 0,  -4, 1,  3,  -1, -3, -5, 3,
461     -2, -4, 1,  2,  -2, -3, -3, -5, 1,  -3, -1, 0,  -1, 3,  -4, -1, -5, -5, 1,
462     0,  0,  -2, -2, 2,  -2, 0,  0,  2,  0,  -3, 0,  -1, -4, -4, -1, 3,  -4, -4,
463     -1, 0,  -5, -3, -2, 4,  -3, -4, -4, 0,  -5, 1,  -2, -3, -3, -4, 4,  3,  4,
464     3,  3,  -1, 3,  1,  -3, -2, 3,  3,  0,  2,  -4, -3, 2,  2,  0,  -2, 4,  -2,
465     2,  -2, -1, -4, -2, 2,  -4, 3,  -1, 4,  1,  1,  4,  -1, -4, -4, 1,  1,  -2,
466     4,  -1, 3,  2,  -3, 4,  3,  1,  4,  0,  -4, 2,  0,  2,  4,  -2, -2, 4,  2,
467     -1, -2, 1,  -3, 2,  3,  -5, -3, 4,  4,  2,  -5, -4, -5, -2, -4, 2,  0,  2,
468     -5, 4,  -4, -2, -5, 2,  1,  0,  4,  1,  -2, -3, -4, -3, -4, 3,  3,  2,  0,
469     -3, 1,  -5, 4,  0,  4,  -1, 3,  -5, -5, -2, -1, -1, 4,  3,  3,  4,  3,  -4,
470     4,  -3, -3, -1, -4, -1, -4, -1, -2, 4,  -2, -4, 4,  4,  -3, -4, -1, 1,  2,
471     -1, -2, -2, 3,  2,  2,  -3, 0,  -1, 0,  3,  2,  -5, 0,  -4, 0,  0,  2,  -4,
472     -1, -1, 0,  -2, 0,  1,  0,  0,  4,  -5, -1, -5, 2,  -1, 0,  2,  -1, 1,  3,
473     -3, -5, -2, -3, 4,  -2, -2, -1, -3, -4, -1, -2, -4, 1,  4,  -3, -2, -1, 3,
474     -3, -2, 3,  2,  1,  -4, -3, -5, 1,
475   };
476   float expected_1[] = {
477     41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
478   };
479 
480   CNN_THREAD_DATA thread_data = { 1, NULL };
481 
482   RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
483              image_width, &thread_data, MSE_INT_TOL);
484 
485   cnn_config.layer_config[0].skip_width = 6;
486   cnn_config.layer_config[0].skip_height = 7;
487 
488   float expected_2[] = {
489     21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
490   };
491   RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
492              image_width, &thread_data, MSE_INT_TOL);
493 
494   cnn_config.layer_config[0].skip_width = 3;
495   cnn_config.layer_config[0].skip_height = 10;
496 
497   float expected_3[] = {
498     -26, -21, -35, 69, 49,  4,  -51, -43, -56,
499     -41, 15,  -44, 40, -62, 63, 38,  27,  47,
500   };
501   RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
502              image_width, &thread_data, MSE_INT_TOL);
503 
504   cnn_config.layer_config[0].skip_width = 10;
505   cnn_config.layer_config[0].skip_height = 3;
506 
507   float expected_4[] = {
508     21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
509   };
510 
511   RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
512              image_width, &thread_data, MSE_INT_TOL);
513 }
514 
// Runs a single layer with a 3x3 filter over an 8x8 image, passing |stride|
// for both skip entries, and checks the result against a 9-value reference.
TEST_F(CNNTest, TestMaxPool) {
  const int image_width = 8;
  const int image_height = 8;
  const int stride = 3;
  const float input[] = {
    1,  -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8,  5,  -1, -1, 9,
    -3, 0,  -2, 0, 6, 3, -4, 8,  7, 8, 7, -1, 4,  -1, 0,  2,
    -5, -2, 8,  5, 5, 4, 2,  7,  4, 6, 2, 8,  8,  -4, -3, -4,
    -3, -1, 2,  3, 3, 6, -5, 8,  9, 5, 0, -2, -1, 6,  5,  7,
  };

  const float expected[] = {
    49, 58, 70, 68, 68, 70, 48, 57, 88,
  };

  float weights[] = {
    3, 1, 3, 4, -1, 5, -2, 1, -4,
  };

  float bias[] = {
    -3,
  };

  // NOTE(review): the seventh layer field is 1 here and 0 in the other tests;
  // presumably it is the max-pool switch -- confirm against CNN_LAYER_CONFIG.
  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            { {
                                1,
                                3,
                                3,
                                1,
                                stride,
                                stride,
                                1,
                                weights,
                                bias,
                                PADDING_SAME_ZERO,
                                NONE,
                                0,
                                0,
                                BRANCH_NO_COPY,
                                BRANCH_NOC,
                                {},
                                {},
                                0,
                            } } };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
569 
TEST_F(CNNTest,TestDeconvolveNonActivationSingleLayerSingleKernel)570 TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
571   int image_width = 4;
572   int image_height = 7;
573   float input[] = {
574     9,  6,   181, 9,  218, 30, 80,  108, 68,  216, 70, 128, 179, 228,
575     33, 212, 34,  14, 48,  27, 230, 23,  202, 113, 80, 56,  122, 112,
576   };
577 
578   float expected_1_same[] = {
579     15,   -30,  36,   -525,  377, -193, 558, 531,  6,   -24,  -15,  124,
580     166,  -561, -356, -754,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
581     433,  -311, 711,  381,   247, -317, 453, 129,  215, -627, -409, -885,
582     17,   -255, -55,  -647,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
583     133,  -719, 633,  -225,  785, 191,  463, 79,   65,  9,    77,   -853,
584     -365, -949, -15,  -667,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
585     355,  -866, 990,  207,   747, 12,   520, -116, 176, -312, -133, -1370,
586     -426, -802, 143,  -771,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
587     65,   -79,  127,  -59,   135, -90,  195, 114,  31,  -91,  -57,  -133,
588     17,   -176, -72,  -276,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
589     457,  -302, 733,  58,    470, -475, 829, 490,  227, -670, -440, -790,
590     153,  -588, -294, -1150, -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
591     157,  -251, 349,  -185,  409, -293, 587, 251,  77,  -187, -107, -369,
592     7,    -481, -135, -827,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
593   };
594   float expected_1_valid[] = {
595     -30,  15,   -30,  36,   -525,  377,  -193,  558,  531,  24,   24,   6,
596     6,    -24,  -15,  124,  166,   -561, -356,  -754, -21,  -39,  -3,   -3,
597     -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -657, 433,  -311,
598     711,  381,  247,  -317, 453,   129,  321,   321,  215,  215,  -627, -409,
599     -885, 17,   -255, -55,  -647,  -219, -435,  -3,   -3,   -3,   -3,   -3,
600     -3,   -3,   -3,   -3,   -3,    -3,   -207,  133,  -719, 633,  -225, 785,
601     191,  463,  79,   381,  381,   65,   65,    9,    77,   -853, -365, -949,
602     -15,  -667, -259, -515, -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
603     -3,   -3,   -3,   -540, 355,   -866, 990,   207,  747,  12,   520,  -116,
604     633,  633,  176,  176,  -312,  -133, -1370, -426, -802, 143,  -771, -427,
605     -851, -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
606     -105, 65,   -79,  127,  -59,   135,  -90,   195,  114,  78,   78,   31,
607     31,   -91,  -57,  -133, 17,    -176, -72,   -276, -57,  -111, -3,   -3,
608     -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -693, 457,  -302,
609     733,  58,   470,  -475, 829,   490,  336,   336,  227,  227,  -670, -440,
610     -790, 153,  -588, -294, -1150, -229, -455,  -3,   -3,   -3,   -3,   -3,
611     -3,   -3,   -3,   -3,   -3,    -3,   -243,  157,  -251, 349,  -185, 409,
612     -293, 587,  251,  333,  333,   77,   77,    -187, -107, -369, 7,    -481,
613     -135, -827, -227, -451,
614   };
615   float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
616   float bias_1[] = { -3 };
617 
618   CNN_CONFIG cnn_config = { 1,
619                             0,
620                             0,
621                             0,
622                             0,
623                             { {
624                                 1,
625                                 5,
626                                 2,
627                                 1,
628                                 2,
629                                 3,
630                                 0,
631                                 weights_1,
632                                 bias_1,
633                                 PADDING_SAME_ZERO,
634                                 NONE,
635                                 1,
636                                 0,
637                                 BRANCH_NO_COPY,
638                                 BRANCH_NOC,
639                                 {},
640                                 {},
641                                 0,
642                             } } };
643 
644   CNN_THREAD_DATA thread_data = { 1, NULL };
645 
646   RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
647              image_width, &thread_data, MSE_INT_TOL);
648 
649   // Change padding to valid
650   cnn_config.layer_config[0].pad = PADDING_VALID;
651 
652   RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
653              image_width, &thread_data, MSE_INT_TOL);
654 
655   float expected_12_same[] = {
656     15,  -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,  24,
657     6,   -30,  -15,  -33,  -21,  166,  154,  -546, -356, -718, -30,  -21,
658     433, -221, 561,  711,  -33,  -153, 247,  -83,  -87,  453,  -111, 321,
659     215, -657, -409, -845, -93,  17,   -43,  -243, -55,  -215, -327, -219,
660     133, -71,  -447, 633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,
661     65,  -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259,
662     355, -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215, 633,
663     176, -540, -133, -491, -687, -426, -882, -102, 143,  77,   -639, -427,
664     65,  -37,  57,   127,  -17,  -105, 135,  -51,  60,   195,  -30,  78,
665     31,  -105, -57,  -125, -45,  17,   -11,  -147, -72,  -168, -84,  -57,
666     457, -233, 618,  733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,
667     227, -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229,
668     157, -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115, 333,
669     77,  -243, -107, -267, -171, 7,    -105, -369, -135, -379, -339, -227,
670   };
671   float expected_12_valid[] = {
672     -30,  15,   -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,
673     24,   24,   6,    6,    -30,  -15,  -33,  -21,  166,  154,  -546, -356,
674     -718, -30,  -21,  -39,  -657, 433,  -221, 561,  711,  -33,  -153, 247,
675     -83,  -87,  453,  -111, 321,  321,  215,  215,  -657, -409, -845, -93,
676     17,   -43,  -243, -55,  -215, -327, -219, -435, -207, 133,  -71,  -447,
677     633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,  381,  65,   65,
678     -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259, -515,
679     -540, 355,  -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215,
680     633,  633,  176,  176,  -540, -133, -491, -687, -426, -882, -102, 143,
681     77,   -639, -427, -851, -105, 65,   -37,  57,   127,  -17,  -105, 135,
682     -51,  60,   195,  -30,  78,   78,   31,   31,   -105, -57,  -125, -45,
683     17,   -11,  -147, -72,  -168, -84,  -57,  -111, -693, 457,  -233, 618,
684     733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,  336,  227,  227,
685     -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229, -455,
686     -243, 157,  -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115,
687     333,  333,  77,   77,   -243, -107, -267, -171, 7,    -105, -369, -135,
688     -379, -339, -227, -451,
689   };
690 
691   // Change skip_width, skip_height to {2, 3}
692   cnn_config.layer_config[0].skip_width = 3;
693   cnn_config.layer_config[0].skip_height = 2;
694   // Set padding to same
695   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
696 
697   RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
698              image_width, &thread_data, MSE_INT_TOL);
699 
700   // Change padding to valid
701   cnn_config.layer_config[0].pad = PADDING_VALID;
702   RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
703              image_width, &thread_data, MSE_INT_TOL);
704 
705   cnn_config.layer_config[0].filter_width = 4;
706   cnn_config.layer_config[0].filter_height = 3;
707   float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
708   float bias_2[] = { -4 };
709   cnn_config.layer_config[0].weights = weights_2;
710   cnn_config.layer_config[0].bias = bias_2;
711 
712   cnn_config.layer_config[0].skip_width = 5;
713   cnn_config.layer_config[0].skip_height = 2;
714   float expected_2_same[] = {
715     -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
716     -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   -4,   14,   -22,  32,
717     -4,   -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,
718     14,   -22,  32,   -4,   -195, -658, -213, -622, -4,   -16,  -94,  -28,
719     -70,  -4,   459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,
720     -4,   432,  -440, 868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,
721     -164, 316,  -4,   -4,   212,  -220, 428,  -4,   582,  -208, 146,  664,
722     -4,   -130, -652, -190, -532, -4,   166,  -214, 6,    106,  -4,   192,
723     -388, -24,  44,   -4,   -4,   132,  -140, 268,  -4,   -4,   428,  -436,
724     860,  -4,   -4,   136,  -144, 276,  -4,   -4,   252,  -260, 508,  -4,
725     21,   -541, -115, -269, -4,   416,  -688, -16,  176,  -4,   173,  -103,
726     33,   177,  -4,   168,  -640, -88,  -128, -4,   -4,   354,  -362, 712,
727     -4,   -4,   452,  -460, 908,  -4,   -4,   62,   -70,  128,  -4,   -4,
728     420,  -428, 844,  -4,   499,  -106, 141,  610,  -4,   666,  -46,  210,
729     866,  -4,   47,   -148, -19,  -16,  -4,   605,  -85,  181,  763,  -4,
730     -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,   -4,   -4,   92,
731     -100, 188,  -4,   -4,   50,   -58,  104,  -4,   -132, -694, -200, -558,
732     -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418, -4,   -36,
733     -343, -90,  -235, -4,   -4,   456,  -464, 916,  -4,   -4,   42,   -50,
734     88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,  -4,
735     606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
736     76,   438,  -4,   223,  -340, -3,   112,  -4,   -4,   156,  -164, 316,
737     -4,   -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,
738     220,  -228, 444,  -4,
739   };
740   float expected_2_valid[] = {
741     -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
742     -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   14,   -22,  32,   -4,
743     -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,   14,
744     -22,  32,   -195, -658, -213, -622, -4,   -16,  -94,  -28,  -70,  -4,
745     459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,   432,  -440,
746     868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,  -164, 316,  -4,
747     -4,   212,  -220, 428,  582,  -208, 146,  664,  -4,   -130, -652, -190,
748     -532, -4,   166,  -214, 6,    106,  -4,   192,  -388, -24,  44,   -4,
749     132,  -140, 268,  -4,   -4,   428,  -436, 860,  -4,   -4,   136,  -144,
750     276,  -4,   -4,   252,  -260, 508,  21,   -541, -115, -269, -4,   416,
751     -688, -16,  176,  -4,   173,  -103, 33,   177,  -4,   168,  -640, -88,
752     -128, -4,   354,  -362, 712,  -4,   -4,   452,  -460, 908,  -4,   -4,
753     62,   -70,  128,  -4,   -4,   420,  -428, 844,  499,  -106, 141,  610,
754     -4,   666,  -46,  210,  866,  -4,   47,   -148, -19,  -16,  -4,   605,
755     -85,  181,  763,  -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,
756     -4,   -4,   92,   -100, 188,  -4,   -4,   50,   -58,  104,  -132, -694,
757     -200, -558, -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418,
758     -4,   -36,  -343, -90,  -235, -4,   456,  -464, 916,  -4,   -4,   42,
759     -50,  88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,
760     606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
761     76,   438,  -4,   223,  -340, -3,   112,  -4,   156,  -164, 316,  -4,
762     -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,   220,
763     -228, 444,  236,  -4,   76,   316,  -4,   164,  -4,   52,   220,  -4,
764     362,  -4,   118,  484,  -4,   332,  -4,   108,  444,
765   };
766   // Set padding to same
767   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
768 
769   RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
770              image_width, &thread_data, MSE_INT_TOL);
771 
772   cnn_config.layer_config[0].pad = PADDING_VALID;
773 
774   RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
775              image_width, &thread_data, MSE_INT_TOL);
776 
777   cnn_config.layer_config[0].skip_width = 2;
778   cnn_config.layer_config[0].skip_height = 5;
779   float expected_21_same[] = {
780     -31,  -19,  -49,   -191, -565, -194, -574, -13,  14,   -22,  44,   -16,
781     382,  -366, 738,   -22,  -4,   23,   32,   545,  20,   204,  720,  5,
782     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
783     -4,   -4,   -4,    -4,   -658, -252, -748, -114, -334, -192, -568, -112,
784     432,  -440, 928,   -64,  276,  -164, 532,  -220, -4,   304,  868,  266,
785     116,  400,  316,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
786     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -208, -288, -856, -290,
787     -862, -202, -598,  -132, 132,  -140, 700,  -436, 1000, -144, 532,  -260,
788     -4,   712,  268,   422,  860,  450,  276,  124,  -4,   -4,   -4,   -4,
789     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
790     -541, -411, -1225, -265, -787, -249, -739, -216, 354,  -362, 1168, -460,
791     974,  -70,  552,   -428, -4,   859,  712,  323,  908,  665,  128,  208,
792     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
793     -4,   -4,   -4,    -4,   -106, -52,  -148, -66,  -190, -79,  -229, -31,
794     64,   -72,  160,   -32,  148,  -100, 242,  -58,  -4,   72,   132,  154,
795     52,   125,  188,   23,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
796     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -694, -257, -763, -229,
797     -679, -319, -949,  -117, 456,  -464, 962,  -50,  492,  -408, 1030, -230,
798     -4,   295,  916,   625,  88,   537,  804,  109,  -4,   -4,   -4,   -4,
799     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
800     -244, -140, -412,  -182, -538, -238, -706, -116, 156,  -164, 428,  -116,
801     464,  -248, 708,   -228, -4,   244,  316,  418,  220,  454,  484,  108,
802     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
803     -4,   -4,   -4,    -4,
804   };
805   float expected_21_valid[] = {
806     -13,  -31,  -19,  -49,  -191, -565, -194, -574, -13,  -31,   -4,   14,
807     -22,  44,   -16,  382,  -366, 738,  -22,  32,   23,   -4,    23,   32,
808     545,  20,   204,  720,  5,    32,   -4,   -4,   -4,   -4,    -4,   -4,
809     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
810     -4,   -4,   -222, -658, -252, -748, -114, -334, -192, -568,  -112, -328,
811     -4,   432,  -440, 928,  -64,  276,  -164, 532,  -220, 428,   650,  -4,
812     304,  868,  266,  116,  400,  316,  104,  428,  -4,   -4,    -4,   -4,
813     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
814     -4,   -4,   -4,   -4,   -72,  -208, -288, -856, -290, -862,  -202, -598,
815     -132, -388, -4,   132,  -140, 700,  -436, 1000, -144, 532,   -260, 508,
816     200,  -4,   712,  268,  422,  860,  450,  276,  124,  508,   -4,   -4,
817     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
818     -4,   -4,   -4,   -4,   -4,   -4,   -183, -541, -411, -1225, -265, -787,
819     -249, -739, -216, -640, -4,   354,  -362, 1168, -460, 974,   -70,  552,
820     -428, 844,  533,  -4,   859,  712,  323,  908,  665,  128,   208,  844,
821     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
822     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -38,  -106,  -52,  -148,
823     -66,  -190, -79,  -229, -31,  -85,  -4,   64,   -72,  160,   -32,  148,
824     -100, 242,  -58,  104,  98,   -4,   72,   132,  154,  52,    125,  188,
825     23,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
826     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -234, -694,
827     -257, -763, -229, -679, -319, -949, -117, -343, -4,   456,   -464, 962,
828     -50,  492,  -408, 1030, -230, 448,  686,  -4,   295,  916,   625,  88,
829     537,  804,  109,  448,  -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
830     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
831     -84,  -244, -140, -412, -182, -538, -238, -706, -116, -340,  -4,   156,
832     -164, 428,  -116, 464,  -248, 708,  -228, 444,  236,  -4,    244,  316,
833     418,  220,  454,  484,  108,  444,
834   };
835 
836   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
837 
838   RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
839              image_width, &thread_data, MSE_INT_TOL);
840 
841   cnn_config.layer_config[0].pad = PADDING_VALID;
842 
843   RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
844              image_width, &thread_data, MSE_INT_TOL);
845 }
846 
TEST_F(CNNTest,TestLargeKernelsAndStrides)847 TEST_F(CNNTest, TestLargeKernelsAndStrides) {
848   float input_10x11[] = {
849     4,  4,  2,  4,  2,  -5, -2, 3, -1, 0,  0,  1,  2,  0,  -5, -2, -5, 1,  -3,
850     -1, 4,  -3, 2,  -2, 1,  0,  1, -3, -3, -4, -2, -2, 1,  -4, -1, 4,  1,  -4,
851     -4, -4, 3,  2,  -5, 3,  -5, 1, 2,  -4, 1,  -1, 3,  4,  -2, 3,  -3, 3,  0,
852     2,  -4, -5, -5, -2, -1, -2, 1, 1,  1,  -2, 4,  -5, 4,  -1, -1, 2,  3,  -4,
853     2,  2,  3,  0,  0,  1,  0,  3, 2,  3,  1,  -2, 3,  -4, 3,  2,  4,  -2, 0,
854     4,  -4, 1,  -3, -3, -3, -5, 1, -3, -5, 0,  4,  -1, -3, 2,
855   };
856 
857   float weights_10x11[] = {
858     -3, 4,  -4, -3, -5, 1,  -2, 3,  1,  -4, -4, 0,  -1, 0,  3,  1,  -3, -2, 0,
859     -1, 1,  3,  -4, -4, -3, -3, -2, 4,  3,  -5, 4,  2,  -3, 4,  -2, -1, 2,  -1,
860     -5, 0,  -3, 0,  3,  -5, -5, 3,  -4, -1, -5, 3,  4,  0,  4,  -5, 2,  -1, 2,
861     -1, -1, -1, -5, 0,  -4, 3,  -1, 1,  1,  -1, 3,  2,  -5, -4, 0,  -4, 4,  -5,
862     -3, 4,  -5, 2,  -5, -4, -4, -1, 3,  3,  0,  2,  -4, 1,  -2, 1,  1,  0,  3,
863     -2, 0,  1,  2,  4,  -3, -1, -5, -5, 2,  -4, 1,  1,  2,  -4, -2, -2, 2,  1,
864     3,  4,  -5, 1,  -1, -3, -3, -1, -2, -5, 1,  -1, 0,  1,  4,  4,  0,  0,  4,
865     -3, -1, -5, -3, 0,  1,  1,  1,  -5, 3,  4,  3,  -5, 3,  -2, -2, 0,  -4, 0,
866     0,  -2, 1,  -4, -1, 0,  -5, -2, -2, -5, -3, -3, 1,  1,  -3, 2,  4,  2,  4,
867     -4, -3, 3,  1,  1,  3,  -4, 4,  -2, -3, -3, -3, -3, -4, -2, 3,  -5, 2,  4,
868     -1, -4, -4, 4,  -2, -1, 3,  -3, -4, -4, -2, 4,  1,  0,  2,  -1, 4,  -3, 1,
869     4,  -3, 4,  4,  0,  -4, 3,  -2, -3, 2,  3,  -1, -3, 2,  1,  4,  -2, -3, 1,
870     4,  -2, 2,  -2, -5, -2, 1,  4,  -1, -4, 4,  -5, 2,  -5, -4, -1, -2, 3,  1,
871     2,  1,  -5, 1,  -5, -4, -1, -2, 2,  -2, -4, -3, -2, -2, 4,  -1, 2,  2,  -4,
872     2,  -2, 4,  -4, -2, -2, 1,  -1, 1,  1,  1,  -4, -5, -2, 3,  -4, -1, 3,  -2,
873     3,  2,  -5, -4, 0,  3,  -2, -4, -5, 3,  -2, -4, 2,  -2, 1,  -4, 0,  2,  -5,
874     1,  -4, -1, -1, 4,  -5, -4, 0,  -5, -4, -3, -5, -4, 0,  2,  0,  -4, 2,  -2,
875     1,  1,  -3, 2,  0,  -4, 0,  -4, 1,  0,  -5, -1, -1, -1, -5, 4,  2,  2,  -4,
876     3,  -2, -2, 2,  -3, -2, -1, 2,  -4, -5, 2,  -2, -4, -5, -5, -1, 2,  -1, 0,
877     -5, -2, -2, -5, 0,  1,  -1, -5, 0,  3,  2,  3,  0,  -3, -2, 0,  -5, -1, -2,
878     2,  -4, -1, 2,  2,  -5, 2,  -4, 0,  3,  -3, 1,  0,  0,  1,  -5, -3, 1,  -1,
879     0,  -4, -3, 2,  -4, -4, 4,  -1, 0,  1,  2,  -4, -5, 4,  -2, 1,  -4, -4, -3,
880     -1, -1, 1,  -1, -4, -1, -4, -3, 2,  -1, -2, -4, 1,  1,  0,  -2, 0,  -4, 3,
881     -3, 0,  -4, -1, -4, 2,  -1, -2, -5, -1, -2, -3, 3,  -1, 0,  -3, 0,  1,  -5,
882     1,  -5, 0,  1,
883   };
884 
885   float bias_10x11[] = { 3 };
886 
887   float expected_10x11[] = {
888     118,
889   };
890 
891   CNN_CONFIG cnn_config = { 1,
892                             0,
893                             0,
894                             0,
895                             0,
896                             { {
897                                 1,
898                                 23,
899                                 20,
900                                 1,
901                                 15,
902                                 20,
903                                 0,
904                                 weights_10x11,
905                                 bias_10x11,
906                                 PADDING_SAME_ZERO,
907                                 NONE,
908                                 0,
909                                 0,
910                                 BRANCH_NO_COPY,
911                                 BRANCH_NOC,
912                                 {},
913                                 {},
914                                 0,
915                             } } };
916 
917   int image_height = 10;
918   int image_width = 11;
919 
920   CNN_THREAD_DATA thread_data = { 1, NULL };
921 
922   RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
923              &cnn_config, image_width, &thread_data, MSE_INT_TOL);
924 
925   float input_11x10[] = {
926     -2, -2, 3,  -5, -1, -3, 1,  3,  2,  1,  1,  -5, 4,  1,  3,  -5, 3,  -3, -5,
927     0,  -1, -3, -3, 1,  1,  -5, -1, -5, -5, -3, 0,  1,  -3, -1, -3, -3, 0,  3,
928     4,  -4, -1, 3,  -3, -1, -3, 1,  -3, -2, -1, -4, -3, 2,  -4, 1,  -4, -1, -3,
929     -5, -1, 2,  3,  0,  2,  2,  -5, 4,  1,  2,  -1, -4, 4,  -4, -4, 0,  -1, 1,
930     -1, 1,  -3, -3, -2, 1,  2,  4,  4,  4,  -3, -3, 0,  1,  0,  1,  4,  1,  3,
931     4,  -3, -2, -4, 4,  2,  0,  3,  4,  -1, 2,  -2, 1,  -3, -2,
932   };
933 
934   float weights_11x10[] = {
935     4,  -1, 1,  -1, 2,  4,  3,  3,  -4, 3,  -5, 1,  -1, -1, -2, -2, 0,  2,  -3,
936     -2, 3,  -5, -1, 0,  -1, -2, -2, -1, 2,  4,  3,  1,  0,  0,  -3, 3,  -4, -1,
937     -5, 4,  -2, -2, 1,  2,  -1, -3, 1,  2,  -5, 1,  -3, 3,  3,  0,  -4, -4, -5,
938     -3, -4, -4, 4,  -2, 4,  4,  -2, 2,  -5, -1, -2, -5, -1, 4,  -3, 3,  -2, 0,
939     -4, -3, 0,  -1, -2, 4,  2,  0,  -2, -5, -4, 1,  4,  -4, -2, 2,  -2, 1,  1,
940     -4, 1,  -4, -4, -2, 4,  2,  -1, -5, -5, 1,  -3, -3, 3,  -3, -5, -3, 4,  -1,
941     -1, -3, 0,  -4, 3,  -1, 0,  -2, 0,  -5, -2, -5, 2,  0,  -5, 2,  3,  -2, 2,
942     4,  -1, 1,  -3, 2,  3,  2,  0,  -5, -4, -5, 2,  1,  1,  -1, -2, 3,  4,  2,
943     -2, 4,  -2, 3,  1,  -4, -3, -1, 4,  4,  -3, -5, -2, 2,  0,  3,  -2, 3,  -1,
944     -4, 0,  -2, 0,  3,  4,  -2, -3, -2, 0,  3,  4,  2,  -4, 0,  1,  2,  2,  -1,
945     -1, 4,  1,  4,  -2, -1, -1, -5, 1,  -3, 3,  3,  -1, -4, 3,  -5, 0,  0,  -1,
946     -4, -1, -2, 4,  -2, 3,  3,  -3, 1,  -1, 2,  -1, 4,  4,  -2, -2, 4,  -2, 0,
947     3,  -3, -5, -1, -2, 4,  -4, 2,  -4, 0,  -2, 3,  -3, 2,  2,  -2, -5, -1, 4,
948     3,  -2, -1, 3,  3,  -1, 3,  0,  -3, 0,  4,  2,  0,  -1, 4,  1,  1,  2,  1,
949     3,  1,  1,  1,  -3, -5, -4, 4,  -4, 2,  0,  0,  -4, 1,  4,  -5, 4,  4,  0,
950     1,  0,  -2, -4, -4, -3, 0,  1,  -5, 4,  0,  -3, -2, -4, 2,  4,  1,  -5, 1,
951     -4, 1,  0,  -3, -3, 0,  2,  -5, 4,  3,  -2, -5, 3,  1,  -1, 0,  3,  -2, -2,
952     3,  -2, -5, 4,  1,  -2, 2,  -1, 0,  4,  0,  -5, 3,  -2, 1,  2,  1,  -5, -3,
953     -2, -5, 4,  -4, 0,  3,  2,  -1, -4, -1, 2,  1,  -2, 3,  -1, -4, 2,  0,  -3,
954     1,  -1, 2,  -5, -4, -1, -5, 1,  4,  3,  4,  2,  -3, 1,  -5, -1, 3,  0,  -1,
955     -4, 3,  4,  -5, 4,  4,  -3, 2,  -3, -1, -3, -5, -3, 2,  -3, -2, 1,  1,  0,
956     -5, 3,  2,  1,  -5, 1,  1,  1,  3,  4,  -4, -1, -2, 0,  -5, -3, -5, -2, -4,
957     3,  3,  3,  4,  0,  -4, -1, -5, 0,  -3, 1,  4,  4,  -4, 4,  -5, -5, -1, -2,
958     -5, 3,  -4, 4,  3,  0,  -3, 2,  -2, 0,  0,  4,  4,  0,  -2, 1,  -1, -3, 2,
959     -1, 1,  -3, -5,
960   };
961 
962   float bias_11x10[] = {
963     -5,
964   };
965 
966   float expected_11x10[] = {
967     36,  -84,  95,   45,  18,   46,   77,  -54, -99,  -149, 66,  49,  161, 11,
968     39,  61,   -66,  61,  4,    -3,   34,  -44, -23,  31,   64,  29,  47,  72,
969     -27, -27,  121,  -3,  100,  1,    30,  -78, -12,  -89,  -59, 8,   -16, 112,
970     91,  -102, -26,  -4,  30,   54,   4,   -84, -24,  -58,  27,  -53, -33, 5,
971     53,  -26,  63,   50,  -103, -130, -23, 6,   -104, -207, 73,  23,  77,  132,
972     38,  32,   -130, -44, -60,  7,    27,  176, 45,   -32,  -2,  99,  -97, 63,
973     69,  126,  47,   63,  136,  -57,  5,   16,  -40,  -157, 8,   38,  -44, -10,
974     91,  7,    122,  140, 30,   -105, 4,   -1,  113,  64,   180, 141,
975   };
976 
977   cnn_config.layer_config[0].weights = weights_11x10;
978   cnn_config.layer_config[0].bias = bias_11x10;
979   cnn_config.layer_config[0].filter_width = 20;
980   cnn_config.layer_config[0].filter_height = 23;
981   cnn_config.layer_config[0].skip_width = 1;
982   cnn_config.layer_config[0].skip_height = 1;
983   image_height = 11;
984   image_width = 10;
985 
986   RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
987              &cnn_config, image_width, &thread_data, MSE_INT_TOL);
988 }
989 
TEST_F(CNNTest,TestSoftsignSingleLayer)990 TEST_F(CNNTest, TestSoftsignSingleLayer) {
991   int image_width = 8;
992   int image_height = 8;
993   int filter_height = 5;
994   int filter_width = 4;
995   float input[] = {
996     -0.5220f, 0.8410f,  -0.8990f, -0.0090f, 0.6710f,  -0.9470f, -0.8240f,
997     -0.0870f, 0.5380f,  0.4750f,  0.570f,   -0.3760f, -0.6960f, -0.5940f,
998     -0.3830f, 0.080f,   -0.0980f, -0.4940f, -0.4030f, 0.9460f,  -0.6020f,
999     0.4220f,  0.6190f,  0.6640f,  -0.9210f, -0.1470f, -0.2480f, -0.1120f,
1000     -0.580f,  -0.0650f, 0.3330f,  0.9860f,  -0.7430f, 0.7610f,  0.4840f,
1001     0.1030f,  0.9570f,  0.6120f,  -0.5240f, -0.1220f, -0.5850f, -0.270f,
1002     0.7840f,  -0.9790f, 0.7290f,  -0.30f,   -0.6460f, 0.0780f,  0.4750f,
1003     -0.0510f, 0.4550f,  0.3850f,  -0.7230f, 0.4460f,  -0.6260f, -0.810f,
1004     0.8720f,  -0.2120f, -0.580f,  -0.9510f, -0.8430f, -0.1340f, -0.0850f,
1005     0.9190f,
1006   };
1007   float expected_same[] = {
1008     0.430f,   0.660f,  0.5510f,  -0.610f,  0.450f,  -0.1610f, 0.0520f,  0.3240f,
1009     0.6820f,  0.3820f, 0.6360f,  0.7480f,  0.3080f, 0.090f,   0.3910f,  0.1730f,
1010     0.340f,   0.6660f, -0.4990f, 0.4280f,  0.1540f, 0.120f,   0.4670f,  0.6150f,
1011     -0.3880f, 0.7590f, 0.4190f,  0.7350f,  0.5310f, -0.5160f, -0.1760f, 0.6790f,
1012     -0.6780f, 0.5470f, 0.5750f,  -0.6420f, 0.7210f, -0.4620f, 0.5430f,  0.770f,
1013     -0.1990f, 0.3950f, 0.7860f,  -0.4380f, 0.7540f, 0.2640f,  -0.6430f, 0.4510f,
1014     -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f,   0.5870f,  0.4720f,
1015     0.4040f,  0.3630f, 0.670f,   0.2360f,  0.410f,  0.6980f,  -0.5350f, 0.3940f,
1016   };
1017   float expected_replicate[] = {
1018     0.540f,   0.7230f,  -0.3530f, -0.2130f, 0.7440f,  -0.4470f, -0.6260f,
1019     -0.2050f, 0.7230f,  0.4630f,  0.5920f,  0.7440f,  0.6080f,  0.3130f,
1020     -0.5670f, -0.4720f, 0.5480f,  0.6660f,  -0.4990f, 0.4280f,  0.1540f,
1021     0.120f,   0.3390f,  0.6090f,  0.4160f,  0.7590f,  0.4190f,  0.7350f,
1022     0.5310f,  -0.5160f, -0.490f,  0.4450f,  -0.610f,  0.5470f,  0.5750f,
1023     -0.6420f, 0.7210f,  -0.4620f, 0.3150f,  0.7370f,  -0.5820f, 0.3950f,
1024     0.7860f,  -0.4380f, 0.7540f,  0.2640f,  -0.7430f, -0.5340f, -0.6270f,
1025     0.4430f,  0.4730f,  0.4570f,  0.7450f,  0.630f,   0.2620f,  0.3140f,
1026     -0.1840f, 0.1810f,  0.7210f,  0.2760f,  0.6430f,  0.6720f,  -0.4390f,
1027     0.2040f,
1028   };
1029   float expected_valid[] = {
1030     0.6660f,  -0.4990f, 0.4280f,  0.1540f,  0.120f,  0.7590f,  0.4190f,
1031     0.7350f,  0.5310f,  -0.5160f, 0.5470f,  0.5750f, -0.6420f, 0.7210f,
1032     -0.4620f, 0.3950f,  0.7860f,  -0.4380f, 0.7540f, 0.2640f,
1033   };
1034   float weights[] = {
1035     0.6210f,  0.3710f,  -0.2770f, -0.7230f, -0.2450f, 0.6770f,  0.3080f,
1036     -0.9880f, -0.080f,  0.7190f,  -0.6760f, -0.0170f, -0.8970f, 0.8260f,
1037     0.7390f,  -0.4550f, -0.4260f, -0.6330f, 0.0880f,  -0.9390f,
1038   };
1039   float bias[] = {
1040     0.750f,
1041   };
1042 
1043   CNN_CONFIG cnn_config = { 1,
1044                             0,
1045                             0,
1046                             0,
1047                             0,
1048                             { {
1049                                 1,
1050                                 filter_width,
1051                                 filter_height,
1052                                 1,
1053                                 1,
1054                                 1,
1055                                 0,
1056                                 weights,
1057                                 bias,
1058                                 PADDING_SAME_ZERO,
1059                                 SOFTSIGN,
1060                                 0,
1061                                 0,
1062                                 BRANCH_NO_COPY,
1063                                 BRANCH_NOC,
1064                                 {},
1065                                 {},
1066                                 0,
1067                             } } };
1068 
1069   CNN_THREAD_DATA thread_data = { 1, NULL };
1070 
1071   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
1072              image_width, &thread_data, MSE_FLOAT_TOL);
1073 
1074   cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
1075 
1076   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
1077              image_width, &thread_data, MSE_FLOAT_TOL);
1078 
1079   cnn_config.layer_config[0].pad = PADDING_VALID;
1080 
1081   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
1082              image_width, &thread_data, MSE_FLOAT_TOL);
1083 }
1084 
TEST_F(CNNTest,TestBranchTensorAdd)1085 TEST_F(CNNTest, TestBranchTensorAdd) {
1086   int filter_width = 2;
1087   int filter_height = 3;
1088 
1089   int image_width = 4;
1090   int image_height = 4;
1091 
1092   float input[] = {
1093     -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1094   };
1095 
1096   float weights[] = {
1097     -3, -1, 4,  -1, -3, 3,  3,  0,  2,  0,  3,  2,  4,  4, 4,  -5, 1, -4,
1098     2,  -4, 1,  -3, 0,  4,  -5, 4,  0,  -4, -3, -1, 0,  0, -2, 0,  0, 2,
1099     -5, -1, 1,  -3, 3,  4,  3,  0,  1,  -1, 1,  1,  2,  4, -2, -5, 2, -2,
1100     3,  -2, 4,  -1, 0,  2,  3,  2,  -2, -1, -3, 1,  3,  4, -1, -3, 0, -4,
1101     4,  2,  -3, -3, -1, 0,  1,  0,  3,  3,  -3, 0,  3,  2, -5, -3, 4, -5,
1102     3,  -1, -1, -3, 0,  1,  -1, -4, 2,  4,  -1, 4,  -1, 1, 3,  4,  4, 4,
1103     0,  -1, -3, -3, -3, -3, 2,  -3, -2, 2,  3,  -3,
1104   };
1105 
1106   float bias[] = {
1107     3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
1108   };
1109 
1110   float expected[] = {
1111     -11502, -4101, -3424, 668,   -17950, -5470, -5504, 626,
1112     4835,   446,   1779,  -3483, 3679,   -4214, 4578,  -105,
1113   };
1114 
1115   int channels = 2;
1116 
1117   CNN_CONFIG cnn_config = { 6,
1118                             0,
1119                             0,
1120                             0,
1121                             0,
1122                             { {
1123                                   1,
1124                                   filter_width,
1125                                   filter_height,
1126                                   channels,
1127                                   1,
1128                                   1,
1129                                   0,
1130                                   weights,
1131                                   bias,
1132                                   PADDING_SAME_ZERO,
1133                                   NONE,
1134                                   0,
1135                                   0,
1136                                   BRANCH_NO_COPY,
1137                                   BRANCH_NOC,
1138                                   {},
1139                                   {},
1140                                   -1,
1141                               },
1142                               {
1143                                   channels,
1144                                   filter_width,
1145                                   filter_height,
1146                                   channels,
1147                                   1,
1148                                   1,
1149                                   0,
1150                                   nullptr,
1151                                   nullptr,
1152                                   PADDING_SAME_ZERO,
1153                                   NONE,
1154                                   0,
1155                                   0,
1156                                   BRANCH_INPUT,
1157                                   BRANCH_NOC,
1158                                   {
1159                                       0x02,
1160                                       0,
1161                                       0x00,
1162                                   },
1163                                   {},
1164                                   -1,
1165                               },
1166                               {
1167                                   channels,
1168                                   filter_width,
1169                                   filter_height,
1170                                   channels,
1171                                   1,
1172                                   1,
1173                                   0,
1174                                   nullptr,
1175                                   nullptr,
1176                                   PADDING_SAME_ZERO,
1177                                   NONE,
1178                                   0,
1179                                   1,
1180                                   BRANCH_NO_COPY,
1181                                   BRANCH_NOC,
1182                                   {},
1183                                   {},
1184                                   -1,
1185                               },
1186                               {
1187                                   channels,
1188                                   filter_width,
1189                                   filter_height,
1190                                   channels,
1191                                   1,
1192                                   1,
1193                                   0,
1194                                   nullptr,
1195                                   nullptr,
1196                                   PADDING_SAME_ZERO,
1197                                   NONE,
1198                                   0,
1199                                   1,
1200                                   BRANCH_NO_COPY,
1201                                   BRANCH_NOC,
1202                                   {},
1203                                   {},
1204                                   -1,
1205                               },
1206                               {
1207                                   channels,
1208                                   filter_width,
1209                                   filter_height,
1210                                   channels,
1211                                   1,
1212                                   1,
1213                                   0,
1214                                   nullptr,
1215                                   nullptr,
1216                                   PADDING_SAME_ZERO,
1217                                   NONE,
1218                                   0,
1219                                   0,
1220                                   BRANCH_NO_COPY,
1221                                   BRANCH_ADD,
1222                                   {
1223                                       0x00,
1224                                       0,
1225                                       0x02,
1226                                   },
1227                                   {},
1228                                   -1,
1229                               },
1230                               {
1231                                   channels,
1232                                   filter_width,
1233                                   filter_height,
1234                                   1,
1235                                   1,
1236                                   1,
1237                                   0,
1238                                   nullptr,
1239                                   nullptr,
1240                                   PADDING_SAME_ZERO,
1241                                   NONE,
1242                                   0,
1243                                   0,
1244                                   BRANCH_NO_COPY,
1245                                   BRANCH_NOC,
1246                                   {},
1247                                   {},
1248                                   0,
1249                               } } };
1250 
1251   // Weights and biases need to be specified separately because
1252   // of the offset.
1253   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1254 
1255   CNN_THREAD_DATA thread_data = { 1, NULL };
1256 
1257   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1258              image_width, &thread_data, MSE_INT_TOL);
1259 }
1260 
TEST_F(CNNTest,TestBranchTensorConcatenation)1261 TEST_F(CNNTest, TestBranchTensorConcatenation) {
1262   int filter_width = 2;
1263   int filter_height = 3;
1264 
1265   int image_width = 4;
1266   int image_height = 4;
1267 
1268   float input[] = {
1269     -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1270   };
1271 
1272   float weights[] = {
1273     3,  0,  2,  0,  2,  3,  1,  -3, 1,  -5, -3, 0,  -4, 4,  0,  -5, 0,  -5, -1,
1274     -2, -5, 0,  -3, 2,  -4, 2,  0,  2,  -1, 0,  -4, 3,  0,  0,  -1, -5, 2,  -1,
1275     4,  -4, -2, -3, -3, 3,  4,  -2, -1, -4, -1, 4,  4,  -1, 4,  3,  -4, 2,  -2,
1276     -4, -3, -2, 3,  -3, -5, -1, 3,  -2, 4,  1,  -4, -3, -5, -5, -3, 4,  -2, -2,
1277     -1, -5, -5, 0,  -1, -2, -3, 3,  -4, -5, 2,  -3, 1,  0,  -5, 2,  2,  -2, 0,
1278     2,  2,  -2, 4,  2,  2,  0,  1,  -5, -3, 0,  2,  -2, 1,  2,  -5, 2,  3,  3,
1279     -1, 3,  0,  -3, 3,  -4, -4, 3,  3,  -4, -2, 2,  -2, 2,  -2, -1, 3,  0,
1280   };
1281 
1282   float bias[] = {
1283     -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
1284   };
1285 
1286   float expected[] = {
1287     -33533, -32087, -6741,  -2124, 39979, 41453, 14034, 689,
1288     -22611, -42203, -14882, -239,  15781, 15963, 9524,  837,
1289   };
1290 
1291   int channels = 2;
1292 
1293   CNN_CONFIG cnn_config = { 6,
1294                             0,
1295                             0,
1296                             0,
1297                             0,
1298                             { {
1299                                   1,
1300                                   filter_width,
1301                                   filter_height,
1302                                   channels,
1303                                   1,
1304                                   1,
1305                                   0,
1306                                   weights,
1307                                   bias,
1308                                   PADDING_SAME_ZERO,
1309                                   NONE,
1310                                   0,
1311                                   0,
1312                                   BRANCH_NO_COPY,
1313                                   BRANCH_NOC,
1314                                   {},
1315                                   {},
1316                                   -1,
1317                               },
1318                               {
1319                                   channels,
1320                                   filter_width,
1321                                   filter_height,
1322                                   channels,
1323                                   1,
1324                                   1,
1325                                   0,
1326                                   nullptr,
1327                                   nullptr,
1328                                   PADDING_SAME_ZERO,
1329                                   NONE,
1330                                   0,
1331                                   0,
1332                                   BRANCH_INPUT,
1333                                   BRANCH_NOC,
1334                                   {
1335                                       0x02,
1336                                       0,
1337                                       0x00,
1338                                   },
1339                                   {},
1340                                   -1,
1341                               },
1342                               {
1343                                   channels,
1344                                   filter_width,
1345                                   filter_height,
1346                                   channels,
1347                                   1,
1348                                   1,
1349                                   0,
1350                                   nullptr,
1351                                   nullptr,
1352                                   PADDING_SAME_ZERO,
1353                                   NONE,
1354                                   0,
1355                                   1,
1356                                   BRANCH_NO_COPY,
1357                                   BRANCH_NOC,
1358                                   {},
1359                                   {},
1360                                   -1,
1361                               },
1362                               {
1363                                   channels,
1364                                   filter_width,
1365                                   filter_height,
1366                                   channels,
1367                                   1,
1368                                   1,
1369                                   0,
1370                                   nullptr,
1371                                   nullptr,
1372                                   PADDING_SAME_ZERO,
1373                                   NONE,
1374                                   0,
1375                                   1,
1376                                   BRANCH_NO_COPY,
1377                                   BRANCH_NOC,
1378                                   {},
1379                                   {},
1380                                   -1,
1381                               },
1382                               {
1383                                   channels,
1384                                   filter_width,
1385                                   filter_height,
1386                                   channels,
1387                                   1,
1388                                   1,
1389                                   0,
1390                                   nullptr,
1391                                   nullptr,
1392                                   PADDING_SAME_ZERO,
1393                                   NONE,
1394                                   0,
1395                                   0,
1396                                   BRANCH_NO_COPY,
1397                                   BRANCH_CAT,
1398                                   {
1399                                       0x00,
1400                                       0,
1401                                       0x02,
1402                                   },
1403                                   {},
1404                                   -1,
1405                               },
1406                               {
1407                                   channels + channels,
1408                                   filter_width,
1409                                   filter_height,
1410                                   1,
1411                                   1,
1412                                   1,
1413                                   0,
1414                                   nullptr,
1415                                   nullptr,
1416                                   PADDING_SAME_ZERO,
1417                                   NONE,
1418                                   0,
1419                                   0,
1420                                   BRANCH_NO_COPY,
1421                                   BRANCH_NOC,
1422                                   {},
1423                                   {},
1424                                   0,
1425                               } } };
1426 
1427   // Weights and biases need to be specified separately because
1428   // of the offset.
1429   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1430 
1431   CNN_THREAD_DATA thread_data = { 1, NULL };
1432 
1433   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1434              image_width, &thread_data, MSE_INT_TOL);
1435 }
1436 
1437 // TODO(logangw): Add test to test all combinations of branch_copy_type.
1438 
TEST_F(CNNTest,TestBranchCombinations)1439 TEST_F(CNNTest, TestBranchCombinations) {
1440   int filter_width = 2;
1441   int filter_height = 3;
1442 
1443   int image_width = 4;
1444   int image_height = 4;
1445 
1446   float input[] = {
1447     3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
1448   };
1449 
1450   float weights[] = {
1451     2,  3,  0,  4,  4,  3,  1,  0,  1,  -5, 4,  -3, 3,  0,  4,  -1, -1, -5,
1452     2,  1,  -3, -5, 3,  -1, -3, -2, 0,  -2, 3,  0,  -2, -4, -2, -2, 2,  -5,
1453     4,  -5, 0,  1,  -5, -4, -3, -4, 2,  -2, 1,  0,  3,  -2, -4, 3,  4,  -4,
1454     -1, -1, -3, -2, -2, -1, 2,  0,  2,  -1, 2,  -4, -4, -1, 2,  0,  3,  -2,
1455     -2, 3,  -3, 4,  -2, 4,  3,  4,  1,  0,  -2, -3, -5, 1,  -3, 2,  0,  -2,
1456     -2, -1, -1, -5, -2, -3, -1, 3,  3,  4,  4,  0,  2,  1,  3,  -3, 2,  -5,
1457     -5, 1,  -5, -1, 3,  3,  2,  -4, -1, 3,  -4, -2, -5, -2, 1,  3,  2,  2,
1458     -5, -2, -3, -1, -2, -4, -1, -2, 2,  1,  -4, -4, 2,  0,  2,  0,  2,  -3,
1459     -2, -4, 4,  0,  1,  -3, -5, 4,  -1, 2,  3,  -5, -1, 0,  4,  -1, -1, 3,
1460     -1, -3, 3,  1,  4,  3,  4,  3,  -4, -5, -1, 3,  3,  -4, 3,  1,  3,  -5,
1461     3,  4,  -5, 4,  2,  -1, -5, 2,  1,  0,  4,  0,  -3, 2,  0,  2,  -2, 1,
1462     -1, -2, -1, -5, 4,  3,  3,  -2, 2,  4,  -5, -5, -3, -2, 4,  0,  -4, 1,
1463   };
1464 
1465   float bias[] = {
1466     -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
1467   };
1468 
1469   float expected[] = {
1470     149496, 15553,  -24193, -20956, 134094, 86432,  -68283, -6366,
1471     -53031, 133739, 67407,  -13539, -53205, -58635, -20033, 1979,
1472   };
1473 
1474   int channels = 2;
1475 
1476   CNN_CONFIG cnn_config = { 10,
1477                             0,
1478                             0,
1479                             0,
1480                             0,
1481                             {
1482                                 {
1483                                     1,
1484                                     filter_width,
1485                                     filter_height,
1486                                     channels,
1487                                     1,
1488                                     1,
1489                                     0,
1490                                     weights,
1491                                     bias,
1492                                     PADDING_SAME_ZERO,
1493                                     NONE,
1494                                     0,
1495                                     0,
1496                                     BRANCH_NO_COPY,
1497                                     BRANCH_NOC,
1498                                     {},
1499                                     {},
1500                                     -1,
1501                                 },
1502                                 {
1503                                     channels,
1504                                     filter_width,
1505                                     filter_height,
1506                                     channels,
1507                                     1,
1508                                     1,
1509                                     0,
1510                                     nullptr,
1511                                     nullptr,
1512                                     PADDING_SAME_ZERO,
1513                                     NONE,
1514                                     0,
1515                                     0,
1516                                     BRANCH_INPUT,
1517                                     BRANCH_NOC,
1518                                     {
1519                                         0x06,
1520                                         0,
1521                                         0x00,
1522                                     },
1523                                     {},
1524                                     -1,
1525                                 },
1526                                 {
1527                                     channels,
1528                                     filter_width,
1529                                     filter_height,
1530                                     channels,
1531                                     1,
1532                                     1,
1533                                     0,
1534                                     nullptr,
1535                                     nullptr,
1536                                     PADDING_SAME_ZERO,
1537                                     NONE,
1538                                     0,
1539                                     2,
1540                                     BRANCH_OUTPUT,
1541                                     BRANCH_NOC,
1542                                     {
1543                                         0x08,
1544                                         0,
1545                                         0x00,
1546                                     },
1547                                     {},
1548                                     -1,
1549                                 },
1550                                 {
1551                                     channels,
1552                                     filter_width,
1553                                     filter_height,
1554                                     channels,
1555                                     1,
1556                                     1,
1557                                     0,
1558                                     nullptr,
1559                                     nullptr,
1560                                     PADDING_SAME_ZERO,
1561                                     NONE,
1562                                     0,
1563                                     3,
1564                                     BRANCH_NO_COPY,
1565                                     BRANCH_NOC,
1566                                     {},
1567                                     {},
1568                                     -1,
1569                                 },
1570                                 {
1571                                     channels,
1572                                     filter_width,
1573                                     filter_height,
1574                                     channels,
1575                                     1,
1576                                     1,
1577                                     0,
1578                                     nullptr,
1579                                     nullptr,
1580                                     PADDING_SAME_ZERO,
1581                                     NONE,
1582                                     0,
1583                                     2,
1584                                     BRANCH_NO_COPY,
1585                                     BRANCH_ADD,
1586                                     {
1587                                         0x00,
1588                                         0,
1589                                         0x08,
1590                                     },
1591                                     {},
1592                                     -1,
1593                                 },
1594                                 {
1595                                     channels,
1596                                     filter_width,
1597                                     filter_height,
1598                                     channels,
1599                                     1,
1600                                     1,
1601                                     0,
1602                                     nullptr,
1603                                     nullptr,
1604                                     PADDING_SAME_ZERO,
1605                                     NONE,
1606                                     0,
1607                                     2,
1608                                     BRANCH_NO_COPY,
1609                                     BRANCH_NOC,
1610                                     {},
1611                                     {},
1612                                     -1,
1613                                 },
1614                                 {
1615                                     channels,
1616                                     filter_width,
1617                                     filter_height,
1618                                     channels,
1619                                     1,
1620                                     1,
1621                                     0,
1622                                     nullptr,
1623                                     nullptr,
1624                                     PADDING_SAME_ZERO,
1625                                     NONE,
1626                                     0,
1627                                     1,
1628                                     BRANCH_NO_COPY,
1629                                     BRANCH_NOC,
1630                                     {},
1631                                     {},
1632                                     -1,
1633                                 },
1634                                 {
1635                                     channels,
1636                                     filter_width,
1637                                     filter_height,
1638                                     channels,
1639                                     1,
1640                                     1,
1641                                     0,
1642                                     nullptr,
1643                                     nullptr,
1644                                     PADDING_SAME_ZERO,
1645                                     NONE,
1646                                     0,
1647                                     1,
1648                                     BRANCH_NO_COPY,
1649                                     BRANCH_ADD,
1650                                     {
1651                                         0x00,
1652                                         0,
1653                                         0x0C,
1654                                     },
1655                                     {},
1656                                     -1,
1657                                 },
1658                                 {
1659                                     channels,
1660                                     filter_width,
1661                                     filter_height,
1662                                     channels,
1663                                     1,
1664                                     1,
1665                                     0,
1666                                     nullptr,
1667                                     nullptr,
1668                                     PADDING_SAME_ZERO,
1669                                     NONE,
1670                                     0,
1671                                     0,
1672                                     BRANCH_NO_COPY,
1673                                     BRANCH_ADD,
1674                                     {
1675                                         0x00,
1676                                         0,
1677                                         0x02,
1678                                     },
1679                                     {},
1680                                     -1,
1681                                 },
1682                                 {
1683                                     channels,
1684                                     filter_width,
1685                                     filter_height,
1686                                     1,
1687                                     1,
1688                                     1,
1689                                     0,
1690                                     nullptr,
1691                                     nullptr,
1692                                     PADDING_SAME_ZERO,
1693                                     NONE,
1694                                     0,
1695                                     0,
1696                                     BRANCH_NO_COPY,
1697                                     BRANCH_NOC,
1698                                     {},
1699                                     {},
1700                                     0,
1701                                 },
1702                             } };
1703 
1704   // Weights and biases need to be specified separately because
1705   // of the offset.
1706   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1707 
1708   CNN_THREAD_DATA thread_data = { 1, NULL };
1709 
1710   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1711              image_width, &thread_data, MSE_INT_TOL);
1712 }
1713 
TEST_F(CNNTest,TestSplittingTensors)1714 TEST_F(CNNTest, TestSplittingTensors) {
1715   int filter_width = 2;
1716   int filter_height = 3;
1717 
1718   int image_width = 4;
1719   int image_height = 4;
1720 
1721   float input[] = {
1722     -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
1723   };
1724 
1725   float weights[] = {
1726     -4, 1,  0,  2,  3,  4,  4,  -4, -5, -3, 2,  2,  -4, -3, 3,  2,
1727     4,  -4, -3, -4, -4, 1,  -3, -5, -3, 4,  2,  -2, 2,  -1, -4, -1,
1728     -2, -3, 1,  1,  0,  -5, -1, 3,  3,  -5, -3, 0,  -3, 1,  -3, -1,
1729     1,  -3, -2, -2, 4,  -2, 0,  1,  2,  2,  -4, 2,  4,  0,  -5, -2,
1730     4,  4,  -5, 1,  0,  2,  -2, -5, -5, -3, -5, -5, 4,  -3, 0,  0,
1731     -4, -4, 0,  -5, -4, 0,  0,  -3, -5, -3, -1, 2,  -1, 4,  -1, 2,
1732   };
1733 
1734   float bias[] = {
1735     -4, -2, -3, -3, 3, 1, -2,
1736   };
1737 
1738   float expected[] = {
1739     530,  -762,  1469, 777,  849,   -771, -1698, 600,
1740     -658, -1821, 98,   -668, -1798, 30,   887,   -971,
1741   };
1742 
1743   CNN_CONFIG cnn_config = { 3,
1744                             0,
1745                             0,
1746                             0,
1747                             0,
1748                             {
1749                                 {
1750                                     1,
1751                                     filter_width,
1752                                     filter_height,
1753                                     4,
1754                                     1,
1755                                     1,
1756                                     0,
1757                                     nullptr,
1758                                     nullptr,
1759                                     PADDING_SAME_ZERO,
1760                                     NONE,
1761                                     0,
1762                                     0,
1763                                     BRANCH_OUTPUT,
1764                                     BRANCH_NOC,
1765                                     {
1766                                         0x02,
1767                                         2,
1768                                         0x00,
1769                                     },
1770                                     {},
1771                                     -1,
1772                                 },
1773                                 {
1774                                     4,
1775                                     filter_width,
1776                                     filter_height,
1777                                     2,
1778                                     1,
1779                                     1,
1780                                     0,
1781                                     nullptr,
1782                                     nullptr,
1783                                     PADDING_SAME_ZERO,
1784                                     NONE,
1785                                     0,
1786                                     0,
1787                                     BRANCH_NO_COPY,
1788                                     BRANCH_CAT,
1789                                     {
1790                                         0x00,
1791                                         0,
1792                                         0x02,
1793                                     },
1794                                     {},
1795                                     -1,
1796                                 },
1797                                 {
1798                                     4,
1799                                     filter_width,
1800                                     filter_height,
1801                                     1,
1802                                     1,
1803                                     1,
1804                                     0,
1805                                     nullptr,
1806                                     nullptr,
1807                                     PADDING_SAME_ZERO,
1808                                     NONE,
1809                                     0,
1810                                     0,
1811                                     BRANCH_NO_COPY,
1812                                     BRANCH_NOC,
1813                                     {},
1814                                     {},
1815                                     0,
1816                                 },
1817                             } };
1818 
1819   // Weights and biases need to be specified separately because
1820   // of the offset.
1821   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1822 
1823   CNN_THREAD_DATA thread_data = { 1, NULL };
1824 
1825   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1826              image_width, &thread_data, MSE_INT_TOL);
1827 }
1828 
TEST_F(CNNTest,TestOutputChannelsCount)1829 TEST_F(CNNTest, TestOutputChannelsCount) {
1830   int filter_width = 1;
1831   int filter_height = 1;
1832 
1833   int image_width = 2;
1834   int image_height = 2;
1835 
1836   float input[] = { 0, 0, 0, 0 };
1837 
1838   float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
1839 
1840   float bias[] = { 0, 0, 0, 0, 0, 0 };
1841 
1842   float expected[] = {
1843     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1844     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1845   };
1846 
1847   CNN_CONFIG cnn_config = { 3,
1848                             0,
1849                             0,
1850                             0,
1851                             0,
1852                             {
1853                                 {
1854                                     1,
1855                                     filter_width,
1856                                     filter_height,
1857                                     2,
1858                                     1,
1859                                     1,
1860                                     0,
1861                                     weights,
1862                                     bias,
1863                                     PADDING_SAME_ZERO,
1864                                     NONE,
1865                                     0,
1866                                     0,
1867                                     BRANCH_INPUT,
1868                                     BRANCH_NOC,
1869                                     {
1870                                         0x06,
1871                                         0,
1872                                         0x00,
1873                                     },
1874                                     {},
1875                                     -1,
1876                                 },
1877                                 {
1878                                     1,
1879                                     filter_width,
1880                                     filter_height,
1881                                     2,
1882                                     1,
1883                                     1,
1884                                     0,
1885                                     weights,
1886                                     bias,
1887                                     PADDING_SAME_ZERO,
1888                                     NONE,
1889                                     0,
1890                                     2,
1891                                     BRANCH_NO_COPY,
1892                                     BRANCH_CAT,
1893                                     {
1894                                         0x00,
1895                                         0,
1896                                         0x03,
1897                                     },
1898                                     {},
1899                                     -1,
1900                                 },
1901                                 {
1902                                     2,
1903                                     filter_width,
1904                                     filter_height,
1905                                     2,
1906                                     1,
1907                                     1,
1908                                     0,
1909                                     weights,
1910                                     bias,
1911                                     PADDING_SAME_ZERO,
1912                                     NONE,
1913                                     0,
1914                                     0,
1915                                     BRANCH_NO_COPY,
1916                                     BRANCH_CAT,
1917                                     {
1918                                         0x00,
1919                                         0,
1920                                         0x04,
1921                                     },
1922                                     {},
1923                                     0,
1924                                 },
1925                             } };
1926 
1927   // Weights and biases need to be specified separately because
1928   // of the offset.
1929   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1930 
1931   CNN_THREAD_DATA thread_data = { 1, NULL };
1932 
1933   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1934              image_width, &thread_data, MSE_FLOAT_TOL);
1935 }
1936 
TEST_F(CNNTest,TestBatchNorm)1937 TEST_F(CNNTest, TestBatchNorm) {
1938   int image_width = 28;
1939   int image_height = 28;
1940   int filter_height = 7;
1941   int filter_width = 7;
1942   float input[] = {
1943     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1944     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1945     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1946     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1947     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1948     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1949     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1950     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1951     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1952     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1953     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1954     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1955     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1956     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1957     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1958     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1959     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1960     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1961     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1962     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1963     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1964     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1965     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1966     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1967     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1968     0.0f,       0.0f,       0.0117647f,  0.0705882f,  0.0705882f,  0.0705882f,
1969     0.494118f,  0.533333f,  0.686275f,   0.101961f,   0.65098f,    1.0f,
1970     0.968627f,  0.498039f,  0.0f,        0.0f,        0.0f,        0.0f,
1971     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1972     0.0f,       0.0f,       0.117647f,   0.141176f,   0.368627f,   0.603922f,
1973     0.666667f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
1974     0.882353f,  0.67451f,   0.992157f,   0.94902f,    0.764706f,   0.25098f,
1975     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1976     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.192157f,
1977     0.933333f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
1978     0.992157f,  0.992157f,  0.992157f,   0.984314f,   0.364706f,   0.321569f,
1979     0.321569f,  0.219608f,  0.152941f,   0.0f,        0.0f,        0.0f,
1980     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1981     0.0f,       0.0f,       0.0f,        0.0705882f,  0.858824f,   0.992157f,
1982     0.992157f,  0.992157f,  0.992157f,   0.992157f,   0.776471f,   0.713725f,
1983     0.968627f,  0.945098f,  0.0f,        0.0f,        0.0f,        0.0f,
1984     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1985     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1986     0.0f,       0.0f,       0.313725f,   0.611765f,   0.419608f,   0.992157f,
1987     0.992157f,  0.803922f,  0.0431373f,  0.0f,        0.168627f,   0.603922f,
1988     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1989     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1990     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1991     0.0f,       0.054902f,  0.00392157f, 0.603922f,   0.992157f,   0.352941f,
1992     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1993     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1994     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1995     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1996     0.0f,       0.545098f,  0.992157f,   0.745098f,   0.00784314f, 0.0f,
1997     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1998     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1999     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2000     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0431373f,
2001     0.745098f,  0.992157f,  0.27451f,    0.0f,        0.0f,        0.0f,
2002     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2003     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2004     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2005     0.0f,       0.0f,       0.0f,        0.0f,        0.137255f,   0.945098f,
2006     0.882353f,  0.627451f,  0.423529f,   0.00392157f, 0.0f,        0.0f,
2007     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2008     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2009     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2010     0.0f,       0.0f,       0.0f,        0.317647f,   0.941176f,   0.992157f,
2011     0.992157f,  0.466667f,  0.0980392f,  0.0f,        0.0f,        0.0f,
2012     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2013     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2014     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2015     0.0f,       0.0f,       0.176471f,   0.729412f,   0.992157f,   0.992157f,
2016     0.588235f,  0.105882f,  0.0f,        0.0f,        0.0f,        0.0f,
2017     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2018     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2019     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2020     0.0f,       0.0627451f, 0.364706f,   0.988235f,   0.992157f,   0.733333f,
2021     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2022     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2023     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2024     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2025     0.0f,       0.976471f,  0.992157f,   0.976471f,   0.25098f,    0.0f,
2026     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2027     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2028     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2029     0.0f,       0.0f,       0.180392f,   0.509804f,   0.717647f,   0.992157f,
2030     0.992157f,  0.811765f,  0.00784314f, 0.0f,        0.0f,        0.0f,
2031     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2032     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2033     0.0f,       0.0f,       0.0f,        0.0f,        0.152941f,   0.580392f,
2034     0.898039f,  0.992157f,  0.992157f,   0.992157f,   0.980392f,   0.713725f,
2035     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2036     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2037     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2038     0.0941176f, 0.447059f,  0.866667f,   0.992157f,   0.992157f,   0.992157f,
2039     0.992157f,  0.788235f,  0.305882f,   0.0f,        0.0f,        0.0f,
2040     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2041     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2042     0.0f,       0.0f,       0.0901961f,  0.258824f,   0.835294f,   0.992157f,
2043     0.992157f,  0.992157f,  0.992157f,   0.776471f,   0.317647f,   0.00784314f,
2044     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2045     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2046     0.0f,       0.0f,       0.0f,        0.0f,        0.0705882f,  0.670588f,
2047     0.858824f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.764706f,
2048     0.313725f,  0.0352941f, 0.0f,        0.0f,        0.0f,        0.0f,
2049     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2050     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2051     0.215686f,  0.67451f,   0.886275f,   0.992157f,   0.992157f,   0.992157f,
2052     0.992157f,  0.956863f,  0.521569f,   0.0431373f,  0.0f,        0.0f,
2053     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2054     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2055     0.0f,       0.0f,       0.0f,        0.0f,        0.533333f,   0.992157f,
2056     0.992157f,  0.992157f,  0.831373f,   0.529412f,   0.517647f,   0.0627451f,
2057     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2058     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2059     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2060     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2061     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2062     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2063     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2064     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2065     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2066     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2067     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2068     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2069     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2070     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2071     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2072     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2073     0.0f,       0.0f,       0.0f,        0.0f
2074   };
2075   float expected[] = {
2076     -0.836424f, -0.857365f, -1.62739f,  -1.62739f,  -0.836424f, 5.40742f,
2077     0.920853f,  -0.692567f, -0.836424f, -0.534405f, -1.62739f,  -0.836424f,
2078     1.32602f,   1.36312f,   0.112766f,  -0.836424f, -0.192962f, 1.56975f,
2079     2.45777f,   0.944414f,  -0.192962f, -1.5519f,   -1.5519f,   -0.554006f,
2080     -0.192962f, 1.4231f,    -1.5519f,   -0.192962f, 1.3661f,    -1.5519f,
2081     -1.5519f,   -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
2082     -0.843708f, 4.53065f,   0.0429584f, -0.796804f, -0.843708f, 0.3473f,
2083     -0.843708f, -0.843708f, -0.114439f, 3.14817f,   0.0811934f, -0.843708f
2084   };
2085   float kernel[] = {
2086     0.119643f,    -0.237864f,   0.0462892f,   0.0502297f,   -0.0134528f,
2087     0.146347f,    0.153133f,    0.0513307f,   0.0752369f,   0.0135557f,
2088     -0.111434f,   0.0941854f,   0.0788362f,   0.0299412f,   0.111762f,
2089     0.144066f,    0.00431504f,  -0.0177954f,  0.0738092f,   -0.0344215f,
2090     0.0832582f,   0.053989f,    -0.112691f,   0.0962145f,   0.0186525f,
2091     -0.00660205f, -0.111962f,   -0.126801f,   -0.231625f,   0.17309f,
2092     0.0748875f,   -0.179569f,   -0.00513812f, -0.156579f,   -0.147322f,
2093     0.184168f,    0.189308f,    -0.200359f,   -0.0156733f,  0.140649f,
2094     0.0858496f,   -0.0263217f,  -0.0740749f,  -0.112563f,   0.107528f,
2095     0.0609729f,   -0.221625f,   0.0769944f,   -0.00900815f, -0.00136441f,
2096     -0.0236521f,  -0.0418025f,  -0.00286299f, 0.12241f,     0.0964093f,
2097     -0.0150897f,  0.0532171f,   0.0625916f,   0.116939f,    0.118024f,
2098     0.161918f,    -0.00909767f, 0.100897f,    -0.054563f,   -0.175179f,
2099     -0.0687892f,  0.00734235f,  0.109833f,    -0.113776f,   0.0595405f,
2100     -0.170255f,   0.0124815f,   -0.0363301f,  -0.0127038f,  0.0445554f,
2101     -0.0729894f,  0.107428f,    -0.0341417f,  0.132619f,    0.00984557f,
2102     -0.00443654f, 0.202929f,    0.0945134f,   0.0148725f,   0.00998574f,
2103     -0.0226449f,  0.0478197f,   -0.0793442f,  0.0707599f,   -0.084225f,
2104     0.0865795f,   0.071104f,    -0.047894f,   0.0838322f,   0.0635493f,
2105     -0.00370265f, -0.157247f,   -0.0289622f,  -0.0590963f,  0.13207f,
2106     0.00468011f,  -0.0345372f,  0.217939f,    0.18861f,     -0.0290393f,
2107     -0.0440664f,  0.0126197f,   -0.129132f,   -0.124943f,   0.0968156f,
2108     -0.0853643f,  -0.182305f,   0.00461618f,  -0.147095f,   -0.230282f,
2109     0.00856019f,  0.0278893f,   -0.0300229f,  0.0417871f,   0.0804717f,
2110     -0.0768571f,  -0.0397085f,  -0.0601096f,  0.100901f,    -0.0184926f,
2111     0.0350673f,   0.0971094f,   -0.0171837f,  -0.289644f,   -0.0899041f,
2112     0.08998f,     -0.160319f,   -0.0195103f,  0.0392167f,   -0.137864f,
2113     -0.0136294f,  0.0330886f,   -0.0409244f,  -0.092533f,   -0.0427934f,
2114     -0.191144f,   -0.0969461f,  0.112035f,    0.138611f,    0.128717f,
2115     0.191184f,    0.197462f
2116   };
2117   float bias[] = { 0.186703f, 0.204358f, -0.0230452f };
2118 
2119   float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
2120   float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
2121   float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
2122   float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
2123 
2124   CNN_BATCHNORM_PARAMS bn_params = {
2125     bn_gamma,
2126     bn_beta,
2127     bn_mean,
2128     bn_std,
2129   };
2130 
2131   CNN_CONFIG cnn_config = {
2132     1,
2133     0,
2134     0,
2135     0,
2136     0,
2137     {
2138         {
2139             1,
2140             filter_width,
2141             filter_height,
2142             3,
2143             7,
2144             7,
2145             0,
2146             kernel,
2147             bias,
2148             PADDING_VALID,
2149             RELU,
2150             0,
2151             0,
2152             BRANCH_NO_COPY,
2153             BRANCH_NOC,
2154             {},
2155             bn_params,
2156             0,
2157         },
2158     },
2159   };
2160 
2161   CNN_THREAD_DATA thread_data = { 1, NULL };
2162 
2163   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2164              image_width, &thread_data, MSE_FLOAT_TOL);
2165 }
2166 
TEST_F(CNNTest,TestMultithreading)2167 TEST_F(CNNTest, TestMultithreading) {
2168   int image_height = 2;
2169   int image_width = 2;
2170   int filter_height = 3;
2171   int filter_width = 3;
2172 
2173   float input[] = {
2174     -2,
2175     4,
2176     1,
2177     0,
2178   };
2179 
2180   float weights[] = {
2181     -4, 2, -2, 0,  -4, 4, -3, -3, -3, -1, 1,  0,  -5, -3, 0, -5, 0, 0,
2182     -1, 0, 2,  -5, 0,  1, 4,  2,  1,  0,  -2, -1, -5, -3, 2, -2, 1, -5,
2183   };
2184 
2185   float bias[] = {
2186     -4,
2187     -3,
2188     -2,
2189     3,
2190   };
2191 
2192   float expected[] = {
2193     2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
2194   };
2195 
2196   CNN_CONFIG cnn_config = {
2197     1,
2198     0,
2199     0,
2200     0,
2201     0,
2202     {
2203         {
2204             1,
2205             filter_width,
2206             filter_height,
2207             4,
2208             1,
2209             1,
2210             0,
2211             weights,
2212             bias,
2213             PADDING_SAME_ZERO,
2214             NONE,
2215             0,
2216             0,
2217             BRANCH_NO_COPY,
2218             BRANCH_NOC,
2219             {},
2220             {},
2221             0,
2222         },
2223     },
2224   };
2225 
2226   CNN_THREAD_DATA thread_data = { 1, NULL };
2227 
2228   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2229              image_width, &thread_data, MSE_FLOAT_TOL);
2230 
2231   const AVxWorkerInterface *const winterface = aom_get_worker_interface();
2232   AVxWorker workers[4];
2233 
2234   for (int i = 0; i < 4; ++i) {
2235     winterface->init(&workers[i]);
2236   }
2237 
2238   thread_data = { 4, workers };
2239 
2240   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2241              image_width, &thread_data, MSE_FLOAT_TOL);
2242 
2243   for (int i = 0; i < 4; ++i) {
2244     winterface->end(&workers[i]);
2245   }
2246 }
2247 
TEST_F(CNNTest,TestMultiOutput)2248 TEST_F(CNNTest, TestMultiOutput) {
2249   const int image_dim = 8;
2250   const int image_ch = 3;
2251   const int filter_dim = 2;
2252   const int stride = 2;
2253   const int num_filters = 2;
2254 
2255   const float input_[] = {
2256     1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
2257     0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
2258     -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
2259     0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
2260     0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
2261     -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
2262     -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
2263     -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
2264     -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
2265     -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
2266     -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
2267     0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
2268     0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
2269     -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
2270     0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
2271     0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
2272     -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
2273     -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
2274     -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
2275     0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
2276     0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
2277     -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
2278     0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
2279     0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
2280     1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
2281     -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
2282     0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
2283     -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
2284     0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
2285     0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
2286     -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
2287     0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
2288     -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
2289     1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
2290     -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
2291     -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
2292     -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
2293     0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
2294     0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
2295     -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
2296     1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
2297     0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
2298     0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
2299     -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
2300     1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
2301     -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
2302     1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
2303     1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
2304   };
2305   const float *input[3] = { input_, &input_[image_dim * image_dim],
2306                             &input_[2 * image_dim * image_dim] };
2307 
2308   const float bias[] = { 0.0f, 0.0f };
2309 
2310   const float weights_1[] = {
2311     -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
2312     0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
2313     -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
2314     0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
2315     -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
2316     -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
2317   };
2318 
2319   const float weights_2[] = {
2320     0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
2321     0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
2322     -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
2323     0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
2324   };
2325 
2326   const float weights_3[] = {
2327     -0.0690452401448f, -0.356657338763f,   -0.219464031809f, 0.551288365843f,
2328     0.181372090853f,   -0.00245268542109f, 0.409000696276f,  -0.593209108763f,
2329     0.587352566749f,   -0.243720660227f,   0.266232713887f,  -0.00439285245097f,
2330     0.252883228305f,   0.152646192631f,    0.0918944932026f, 0.398853715057f,
2331   };
2332 
2333   const float weights_4[] = {
2334     0.207560791573f,   0.194201350401f,   0.227802322443f,  0.206533663345f,
2335     0.0557331066805f,  0.0224159800424f,  -0.143939197467f, -0.27703361602f,
2336     0.130643888389f,   -0.269456557461f,  0.186242862864f,  -0.162879944774f,
2337     -0.145503996718f,  -0.0768822987581f, -0.203127976359f, -0.238119922873f,
2338     -0.258806479994f,  0.0357957680385f,  -0.1027606976f,   -0.287920082345f,
2339     0.189047820993f,   0.250711538481f,   -0.272815714175f, -0.0431449742024f,
2340     0.207261230996f,   -0.0396472677451f, 0.131236557412f,  0.174291832499f,
2341     -0.251515885765f,  -0.107164007499f,  0.185824534748f,  -0.00561585838161f,
2342     0.273393799578f,   -0.139563699075f,  -0.263922456031f, -0.118859844081f,
2343     0.109230982597f,   -0.170170294794f,  0.0123025648515f, -0.0839368964355f,
2344     -0.0774058234297f, 0.255847138286f,   -0.208430879637f, 0.279170114319f,
2345     -0.272890330712f,  -0.217725903006f,  -0.295923275459f, -0.17008723953f,
2346     -0.284281803405f,  0.281406323629f,   0.266910044663f,  -0.209963914338f,
2347     0.271980962964f,   0.142013581699f,   -0.143896509026f, -0.290509242975f,
2348     -0.305768180935f,  0.196902832117f,   -0.090424189662f, -0.147460802346f,
2349     0.217722016651f,   0.12353848977f,    -0.169177363577f, -0.0454230918512f,
2350   };
2351 
2352   const float expected_0[] = {
2353     -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
2354     -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
2355     -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
2356     0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
2357     -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
2358     -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
2359     -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
2360     -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
2361   };
2362 
2363   const float expected_1[] = {
2364     0.0f, 0.0f,           0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
2365     0.0f, 1.22013465602f,
2366   };
2367 
2368   const float expected_2[] = {
2369     0.156119444687f,
2370     0.517385299817f,
2371   };
2372 
2373   const float expected_3[] = {
2374     0.224177852984f,
2375     0.503384419034f,
2376     0.156119444687f,
2377     0.517385299817f,
2378   };
2379 
2380   const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };
2381 
2382   CNN_CONFIG cnn_config = {
2383     4,  // num_layers
2384     0,  // is_residue
2385     0,  // ext_width
2386     0,  // ext_height
2387     0,  // strict_bounds
2388     {
2389         // layer_config
2390         {
2391             image_ch,           // in_channels
2392             filter_dim,         // filter_width
2393             filter_dim,         // filter_height
2394             num_filters,        // out_channels
2395             stride,             // skip_width
2396             stride,             // skip_height
2397             0,                  // max_pool
2398             weights_1,          // weights
2399             bias,               // bias
2400             PADDING_SAME_ZERO,  // pad
2401             NONE,               // activation
2402             0,                  // deconvolve
2403             0,                  // branch
2404             BRANCH_OUTPUT,      // branch_copy_type
2405             BRANCH_NOC,         // branch_combine_type
2406             { 2, 0, 0 },        // branch_config
2407             {},                 // bn_params
2408             0,                  // output_num
2409         },
2410         {
2411             num_filters,        // in_channels
2412             filter_dim,         // filter_width
2413             filter_dim,         // filter_height
2414             num_filters,        // out_channels
2415             stride,             // skip_width
2416             stride,             // skip_height
2417             0,                  // max_pool
2418             weights_2,          // weights
2419             bias,               // bias
2420             PADDING_SAME_ZERO,  // pad
2421             RELU,               // activation
2422             0,                  // deconvolve
2423             0,                  // branch
2424             BRANCH_NO_COPY,     // branch_copy_type
2425             BRANCH_NOC,         // branch_combine_type
2426             {},                 // branch_config
2427             {},                 // bn_params
2428             1,                  // output_num
2429         },
2430         {
2431             num_filters,        // in_channels
2432             filter_dim,         // filter_width
2433             filter_dim,         // filter_height
2434             num_filters,        // out_channels
2435             stride,             // skip_width
2436             stride,             // skip_height
2437             0,                  // max_pool
2438             weights_3,          // weights
2439             bias,               // bias
2440             PADDING_SAME_ZERO,  // pad
2441             RELU,               // activation
2442             0,                  // deconvolve
2443             0,                  // branch
2444             BRANCH_NO_COPY,     // branch_copy_type
2445             BRANCH_NOC,         // branch_combine_type
2446             {},                 // branch_config
2447             {},                 // bn_params
2448             2,                  // output_num
2449         },
2450         {
2451             num_filters,     // in_channels
2452             2 * filter_dim,  // filter_width
2453             2 * filter_dim,  // filter_height
2454             num_filters,     // out_channels
2455             2 * stride,      // skip_width
2456             2 * stride,      // skip_height
2457             0,               // max_pool
2458             weights_4,       // weights
2459             bias,            // bias
2460             PADDING_VALID,   // pad
2461             RELU,            // activation
2462             0,               // deconvolve
2463             1,               // branch
2464             BRANCH_NO_COPY,  // branch_copy_type
2465             BRANCH_CAT,      // branch_combine_type
2466             { 0, 0, 1 },     // branch_config
2467             {},              // bn_params
2468             3,               // output_num
2469         },
2470     },
2471   };
2472 
2473   CNN_THREAD_DATA thread_data = { 1, NULL };
2474 
2475   const int num_outputs = 4;
2476   const int output_chs[4] = { filter_dim, filter_dim, filter_dim,
2477                               2 * filter_dim };
2478   const int output_dims[4] = { 4, 2, 1, 1 };
2479   const int output_sizes[4] = {
2480     output_chs[0] * output_dims[0] * output_dims[0],
2481     output_chs[1] * output_dims[1] * output_dims[1],
2482     output_chs[2] * output_dims[2] * output_dims[2],
2483     output_chs[3] * output_dims[3] * output_dims[3],
2484   };
2485   float *const output_ = (float *)aom_malloc(
2486       sizeof(*output_) *
2487       (output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3]));
2488   float *output[CNN_MAX_CHANNELS] = { nullptr };
2489   int ch_ite = 0;
2490   float *output_ite = output_;
2491   for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
2492     for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
2493       output[ch_ite++] = output_ite;
2494       output_ite += output_dims[output_idx] * output_dims[output_idx];
2495     }
2496   }
2497   CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
2498                                   output };
2499 
2500   RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
2501                      &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
2502 
2503   aom_free(output_);
2504 }
2505 
2506 namespace {
2507 
2508 typedef void (*CNNConvolveNoMaxpoolPaddingValidFunc)(
2509     const float **input, int in_width, int in_height, int in_stride,
2510     const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
2511     int start_idx, int cstep, int channel_step);
2512 
2513 typedef libaom_test::FuncParam<CNNConvolveNoMaxpoolPaddingValidFunc>
2514     CNNConvolveTestFuncs;
2515 
2516 class CNNConvolveTest : public ::testing::TestWithParam<CNNConvolveTestFuncs> {
2517  protected:
SetUp()2518   virtual void SetUp() { params_ = GetParam(); }
2519 
RunCNNConvolveSetup(int run_times)2520   void RunCNNConvolveSetup(int run_times) {
2521     int in_width = 65;
2522     int in_height = 65;
2523 
2524     const CNN_CONFIG *cnn_config = &av1_intra_mode_cnn_partition_cnn_config;
2525 
2526     for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
2527       int out_width = 0, out_height = 0;
2528       int in_size = in_width * in_height;
2529       // Get current layer output width and height.
2530       av1_find_cnn_layer_output_size(in_height, in_width,
2531                                      &cnn_config->layer_config[layer],
2532                                      &out_width, &out_height);
2533 
2534       int out_size = out_width * out_height;
2535       float *input[20], *output_ref[20], *output_mod[20];
2536 
2537       float *input_data =
2538           (float *)aom_malloc(sizeof(*input_data) * in_size *
2539                               cnn_config->layer_config[layer].in_channels);
2540       float *temp_ptr = input_data;
2541       for (int i = 0; i < cnn_config->layer_config[layer].in_channels; ++i) {
2542         input[i] = temp_ptr;
2543         for (int j = 0; j < in_size; j++) {
2544           *(temp_ptr++) = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
2545         }
2546       }
2547 
2548       float *out_data_ref = (float *)aom_calloc(
2549           sizeof(*out_data_ref),
2550           out_size * cnn_config->layer_config[layer].out_channels);
2551       float *out_data_mod = (float *)aom_calloc(
2552           sizeof(*out_data_mod),
2553           out_size * cnn_config->layer_config[layer].out_channels);
2554       float *temp_ptr1 = out_data_ref;
2555       float *temp_ptr2 = out_data_mod;
2556       for (int i = 0; i < cnn_config->layer_config[layer].out_channels; ++i) {
2557         output_ref[i] = temp_ptr1;
2558         output_mod[i] = temp_ptr2;
2559         temp_ptr1 += out_size;
2560         temp_ptr2 += out_size;
2561       }
2562 
2563       RunCNNConvolveTest(input, in_width, in_height, out_size,
2564                          &cnn_config->layer_config[layer], 0, 1, run_times,
2565                          layer, output_ref, output_mod, out_width);
2566 
2567       // Set current layer output width and height as next layer input width and
2568       // height.
2569       in_width = out_width;
2570       in_height = out_height;
2571 
2572       aom_free(input_data);
2573       aom_free(out_data_ref);
2574       aom_free(out_data_mod);
2575     }
2576   }
2577 
RunCNNConvolveTest(float ** input,int in_width,int in_height,int out_size,const CNN_LAYER_CONFIG * layer_config,int start_idx,int step,int run_times,int layer,float ** output_ref,float ** output_mod,int out_stride)2578   void RunCNNConvolveTest(float **input, int in_width, int in_height,
2579                           int out_size, const CNN_LAYER_CONFIG *layer_config,
2580                           int start_idx, int step, int run_times, int layer,
2581                           float **output_ref, float **output_mod,
2582                           int out_stride) {
2583     const int cstep = layer_config->in_channels * layer_config->out_channels;
2584     const int channel_step = AOMMAX(step, 1);
2585     aom_usec_timer timer;
2586     aom_usec_timer_start(&timer);
2587     for (int i = 0; i < run_times; ++i) {
2588       params_.ref_func((const float **)input, in_width, in_height, in_width,
2589                        layer_config, output_ref, out_stride, start_idx, cstep,
2590                        channel_step);
2591     }
2592     aom_usec_timer_mark(&timer);
2593     const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2594 
2595     aom_usec_timer_start(&timer);
2596     for (int i = 0; i < run_times; ++i) {
2597       params_.tst_func((const float **)input, in_width, in_height, in_width,
2598                        layer_config, output_mod, out_stride, start_idx, cstep,
2599                        channel_step);
2600     }
2601     aom_usec_timer_mark(&timer);
2602     const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2603 
2604     if (run_times > 1) {
2605       printf("layer : %d \n", layer);
2606       printf("%7.2f/%7.2fns (%3.2f)\n", time1, time2, time1 / time2);
2607     } else {
2608       for (int channel = 0; channel < layer_config->out_channels; ++channel) {
2609         const float *buf_ref = output_ref[channel];
2610         const float *buf_mod = output_mod[channel];
2611 
2612         for (int i = 0; i < out_size; ++i) {
2613           if (buf_ref[i] < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL) {
2614             ASSERT_LE(buf_ref[i], CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2615                 << "Reference output was near-zero, test output was not ("
2616                 << buf_mod[i] << ")";
2617           } else {
2618             const float error = buf_ref[i] - buf_mod[i];
2619             const float relative_error = fabsf(error / buf_ref[i]);
2620             ASSERT_LE(relative_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2621                 << " channel " << channel << " pixel " << i << ": "
2622                 << buf_ref[i] << "/" << buf_mod[i] << std::endl;
2623           }
2624         }
2625       }
2626     }
2627   }
2628 
2629  private:
2630   CNNConvolveTestFuncs params_;
2631   libaom_test::ACMRandom rng_;
2632 };
2633 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest);
2634 
// Functional check. The argument is presumably forwarded as run_times; with a
// value of 1, RunCNNConvolveTest compares test output against the reference
// instead of printing timings.
TEST_P(CNNConvolveTest, CheckOutput) {
  const int kSingleRun = 1;
  RunCNNConvolveSetup(kSingleRun);
}
2636 
// Benchmark variant. The DISABLED_ prefix keeps googletest from running it by
// default. The argument is presumably forwarded as run_times; any value above
// 1 makes RunCNNConvolveTest print timings rather than compare outputs.
TEST_P(CNNConvolveTest, DISABLED_Speed) {
  const int kSpeedRuns = 100000;
  RunCNNConvolveSetup(kSpeedRuns);
}
2638 
#if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
// Pairs the C implementation (reference) with the AVX2 implementation (under
// test). Only built when AVX2 is available and SIMD-mismatch exclusion is
// off.
INSTANTIATE_TEST_SUITE_P(
    AVX2, CNNConvolveTest,
    ::testing::Values(CNNConvolveTestFuncs(
        av1_cnn_convolve_no_maxpool_padding_valid_c,
        av1_cnn_convolve_no_maxpool_padding_valid_avx2)));
#endif  // HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
2645 
2646 }  // namespace
2647