1 /*
2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <assert.h>
13 #include <math.h>
14 #include <stdio.h>
15
16 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
17
18 #include "config/av1_rtcd.h"
19
20 #include "aom_ports/aom_timer.h"
21 #include "av1/encoder/cnn.h"
22 #include "av1/encoder/partition_cnn_weights.h"
23 #include "test/acm_random.h"
24 #include "test/function_equivalence_test.h"
25 #include "test/util.h"
26
27 #define SQR(x) ((x) * (x))
28
// Best possible pixelwise guaranteed precision given each float has at most
// 3 specified decimals.
31 #define PIXELWISE_FLOAT_TOL 1E-2
32
33 #define MSE_FLOAT_TOL 1E-6
34 #define MSE_INT_TOL 0
35
36 // CNN convolve pixelwise error threshold for functional equivalence.
37 #define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
38
39 namespace {
40
41 class CNNTest : public ::testing::Test {
42 protected:
RunCNNTest(int image_width,int image_height,const float * input,const float * expected,const CNN_CONFIG * cnn_config,int in_stride,CNN_THREAD_DATA * thread_data,double tolerance)43 static void RunCNNTest(int image_width, int image_height, const float *input,
44 const float *expected, const CNN_CONFIG *cnn_config,
45 int in_stride, CNN_THREAD_DATA *thread_data,
46 double tolerance) {
47 int out_width, out_height, out_channels;
48 av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
49 &out_height, &out_channels);
50
51 const int out_size = out_width * out_height;
52 const int out_stride = out_width;
53
54 float *output_ =
55 (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
56 float *output[CNN_MAX_CHANNELS] = { nullptr };
57 for (int channel = 0; channel < out_channels; ++channel) {
58 output[channel] = output_ + (channel * out_size);
59 }
60 const int num_outputs = 1;
61 const int output_chs[1] = { out_channels };
62 const int output_strides[1] = { out_stride };
63 CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
64 output };
65
66 RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
67 thread_data, &output_struct, &expected, tolerance);
68
69 aom_free(output_);
70 }
71
RunMultiOutCNNTest(const float ** input,int image_width,int image_height,int in_stride,const CNN_CONFIG * cnn_config,CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output,const float ** expected,double tolerance)72 static void RunMultiOutCNNTest(const float **input, int image_width,
73 int image_height, int in_stride,
74 const CNN_CONFIG *cnn_config,
75 CNN_THREAD_DATA *thread_data,
76 CNN_MULTI_OUT *output, const float **expected,
77 double tolerance) {
78 const int num_outputs = output->num_outputs;
79 const int *output_chs = output->output_channels;
80
81 int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
82 int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
83 int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
84
85 av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
86 out_heights, not_used);
87 av1_cnn_predict(input, image_width, image_height, in_stride, cnn_config,
88 thread_data, output);
89
90 int channel_offset = 0;
91 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
92 const float *expected_out = expected[output_idx];
93 const int curr_output_chs = output_chs[output_idx];
94 const int out_size = out_widths[output_idx] * out_heights[output_idx];
95
96 double mse = 0;
97 int expected_ite = 0;
98 for (int channel = 0; channel < curr_output_chs; ++channel) {
99 const float *buf_out = output->output_buffer[channel_offset];
100
101 for (int i = 0; i < out_size; ++i) {
102 EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
103 PIXELWISE_FLOAT_TOL)
104 << " output " << output_idx << " channel " << channel << " pixel "
105 << expected_ite % out_size << ": " << expected_out[expected_ite]
106 << "/" << buf_out[i] << std::endl;
107 mse += SQR(expected_out[expected_ite] - buf_out[i]);
108 expected_ite++;
109 }
110
111 channel_offset++;
112 }
113 mse /= (out_size * curr_output_chs);
114 EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
115 }
116
117 aom_free(out_widths);
118 aom_free(out_heights);
119 aom_free(not_used);
120 }
121
AssignLayerWeightsBiases(CNN_CONFIG * cnn_config,float * weights,float * bias)122 static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
123 float *bias) {
124 size_t weight_offset = 0;
125 size_t bias_offset = 0;
126 for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
127 CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
128 layer_config->weights = weights + weight_offset;
129 layer_config->bias = bias + bias_offset;
130 weight_offset += layer_config->filter_width *
131 layer_config->filter_height * layer_config->in_channels *
132 layer_config->out_channels;
133 bias_offset += layer_config->out_channels;
134
135 ASSERT_NE(layer_config->weights, nullptr);
136 ASSERT_NE(layer_config->bias, nullptr);
137 }
138 }
139 };
140
141 } // namespace
142
// Runs a 16x16 image through three stacked layer configurations that all use a
// 4x5 (width x height) filter, once for each padding mode, and requires the
// output to match the reference values exactly.
TEST_F(CNNTest, TestMultilayerConvolution) {
  const int image_height = 16;
  const int image_width = 16;
  const int filter_height = 5;
  const int filter_width = 4;

  const float input[] = {
    -3, 1, -3, 2, -2, -2, 2, -2, 1, -2, -3, 1, 2, 2, 2, -2, 0, 1, -1,
    -3, -1, -1, 1, 0, -3, 1, 0, -1, 1, 0, 0, -3, -3, -3, 0, 2, 1, -1,
    2, 0, 1, -3, -1, 2, 2, 1, -2, 0, -1, 0, -2, -2, -1, 1, 0, 0, 0,
    -2, -2, -2, 1, 1, -2, 1, 1, -2, -2, 1, -2, -1, -2, -3, 2, -3, -1, 1,
    0, -2, -2, -2, 1, -2, -2, -1, -1, 2, 2, 2, -1, 1, -3, -3, 0, 2, 0,
    2, 1, -3, -3, 1, 2, 2, 1, -2, -3, 0, -3, 0, -3, -2, 0, 1, 1, 0,
    -3, 2, -1, 2, 1, 0, 1, -2, 1, -1, -1, 2, 0, -2, -3, 1, 1, -2, -1,
    -3, -3, -1, 0, -3, -2, 0, 0, 1, 0, -3, -2, -1, 1, 0, 2, 1, 0, -3,
    -2, -3, -3, -1, 0, -2, 2, -1, -3, 0, -1, -1, 2, 0, -3, -2, -1, 0, 0,
    1, -2, 1, 2, 1, 2, 2, -3, 2, -1, 0, 0, -1, 0, 2, 2, -1, 2, -2,
    1, 1, -3, -3, 1, -1, -1, -2, 2, -2, -2, 2, -1, -3, 2, -3, 1, -1, -1,
    -3, 1, -1, 1, 0, -3, -3, 1, -3, -3, 0, 2, 2, -2, -1, 2, 0, 2, 1,
    -1, -3, 0, 0, -1, -1, 1, 0, 2, 0, -3, 2, 1, 0, 1, -3, 2, -3, -3,
    -1, -3, -3, 2, 0, 2, -2, 1, -1,
  };

  // Flat weight array shared by all layers; sliced per layer below.
  float weights[] = {
    -2, 2, -2, 2, -1, -3, 2, 2, 0, 0, -3, -1, -2, -3, 1, -1, 0, 0, 0,
    2, -2, 2, -2, -3, 1, 1, 1, -3, -1, 0, 1, 2, -2, 0, -1, -3, -1, -2,
    2, -3, -3, 1, -2, -3, 0, 2, 1, -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
    -1, -3, -1, -2, -2, -3, 2, 0, -3, 0, -3, -3, 1, -3, -1, 0, -1, 1, 1,
    -1, 1, -2, 0, 2, 0, -3, 1, -1, -1, 2, 0, 1, -3, -3, 1, 2, -3, -3,
    1, -3, 2, 0, -3, 1, 2, 2, -2, -1, -2, 1, 1, 0, -2, -2, 1, 2, -1,
    -3, 1, -2, 2, -3, -2, -3, 2, 1, 0, -2, 0, 1, -3, 2, -2, -2, 0, 2,
    -3, 2, 0, 0, 1, -2, 1, 1, -2, -1, -2, 1, -2, 0, -2, -2, 0, -1, -1,
    -3, -3, -3, 1, -3, -2, 2, -1, 2, 0, 2, -2, 2, -2, 1, -3, -3, -1, 0,
    2, 2, 1, -1, -3, -1, -3, 2, 1, -2, 0, -3, -1, -3, -1, 2, 1, 0, 2,
    -1, 1, 0, 1, 2, -1, -2, 2, 1, -3, -1, -3, 0, 1, -2, 0, -2, -3, 0,
    -2, 2, 2, 0, 0, 2, -3, 2, -3, -2, 1, 2, -3, -3, -1, -3, 0, -3, -3,
    -2, -2, -2, 0, 0, 1, 0, 0, -1, 0, 0, -3, 0, -3, -1, -2, 1, -2, -1,
    2, -2, 0, 0, 1, 0, -2, -1, 0, -3, 1, 0, -1, -3, 1, -1, 1, -1, -3,
    1, 0, 1, 1, -1, 2, 2, 0, 0, 1, -3, 2, -2, -2, -3, -2, -1, -2, 2,
    0, 2, -2, -3, -1, -3, 2, 2, -1, 2, 2, -1, 0, -3, 1,
  };

  // Flat bias array shared by all layers; sliced per layer below.
  float bias[] = {
    1, -1, 0, 1, 1, 1, -2,
  };

  // Reference output for PADDING_SAME_ZERO.
  const float expected_same[] = {
    -1125, 2926, 6406, 631, -1244, 97, -1454, 2526, 1065, 3292, 3464,
    2553, -330, 532, 1038, 1182, -402, 3758, 3392, 9854, 4365, 1408,
    4736, 3134, 3838, 2409, 3221, 4350, 6750, 4045, 815, 1188, 2959,
    9802, 9590, 4572, 5740, 4253, 1701, 7974, 7012, 6854, 7093, 3907,
    4539, 3886, 4267, 3505, 465, 7824, 9219, 10026, 7968, 957, 2295,
    5594, 10811, 9641, 5950, 10043, 8783, 3132, 1421, 1110, 4108, 13929,
    10660, -84, -61, 3932, -180, 6811, 13393, 15147, 15640, 9337, 6961,
    3808, 1604, 1398, 1047, 6739, 10144, 6517, 4698, 2678, 7389, 2595,
    5248, 12075, 11272, 13951, 8820, 1090, 2199, 2206, 2788, 12116, 6683,
    2612, -291, 3183, 9414, 12316, 14524, 12333, 13208, 7832, 4664, 4657,
    3534, 1298, -666, 4250, 7707, 9103, 5760, 688, 9571, 15782, 14203,
    14878, 17339, 14684, 8690, 5671, 875, 1429, 1531, 6173, 2984, 5558,
    2996, 7928, 6733, 16117, 15262, 12757, 7980, 3923, 4795, 5973, 2051,
    455, -1922, 1816, 5906, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
    7451, 6666, 74, -1645, -35, -391, 3813, 7324, 892, 1656, 6095,
    12193, 14648, 12156, 14663, 10251, 10325, 7821, 3925, 323, 697, 442,
    1324, 4669, 7002, 5485, 5171, 5086, 10582, 11053, 9709, 11353, 8543,
    5256, 2873, 235, -628, 1496, 1878, -867, 3420, 6865, 5937, 10182,
    13277, 10069, 10789, 5998, 624, -2082, 4417, 1258, -1080, -819, -1430,
    1033, 5220, 6335, 8471, 8980, 11908, 14430, 12584, 8404, 1576, -803,
    985, 1481, 1367, -193, 873, 3684, 2288, 6676, 9477, 11155, 9602,
    9707, 10507, 4739, 3174, -575, -178, 3002, 1710, 423, -477, 554,
    3088, 2029, 5113, 5000, 3771, 6090, 5365, 1185, 2855, 399, -312,
    -1577, 176, 955,
  };

  // Reference output for PADDING_SAME_REPLICATE.
  const float expected_replicate[] = {
    13768, 13528, 12999, 6906, 4618, 4043, 2611, 9955, 6685, 4776, 2753,
    1036, 3063, 4544, 5183, 7349, 12451, 12501, 9131, 12753, 8908, 4058,
    6299, 7542, 7115, 3307, 3360, 3543, 9754, 7808, 5991, 9019, 14320,
    14919, 12492, 6871, 7373, 3336, 2085, 10604, 9377, 6882, 5009, 3103,
    6220, 6278, 7588, 10196, 11045, 11563, 11842, 11911, 8279, 2030, 1858,
    6368, 12123, 9909, 6347, 10345, 9365, 4038, 1673, 3051, 16492, 16649,
    12276, 408, -301, 4122, -654, 7864, 14038, 15279, 15315, 9744, 8243,
    5298, 746, 380, 9824, 9124, 10895, 6640, 4712, 2669, 6980, 2759,
    5385, 12345, 11336, 13129, 8600, 2370, 3682, 5219, 12407, 13123, 6784,
    2612, -291, 3183, 9414, 12316, 14524, 12333, 13397, 7543, 3916, 4153,
    4477, 4314, 7983, 8418, 9163, 9103, 5760, 688, 9571, 15782, 14203,
    14878, 17718, 14570, 7940, 6642, 5094, 7133, 9964, 10219, 3224, 5558,
    2996, 7928, 6733, 16117, 15262, 12757, 7958, 4401, 5187, 5476, 5529,
    6055, 2206, 3909, 6015, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
    6967, 6840, 481, -1600, 274, 1, 10373, 8514, 1123, 2117, 6758,
    12736, 16223, 13585, 15988, 11771, 10600, 7918, 4156, 2840, 3111, 3287,
    6359, 7652, 8813, 6530, 6967, 7789, 13671, 13990, 13247, 13241, 9836,
    5251, 3024, 2313, 1834, 4187, 2637, -1312, 2139, 7378, 7665, 11933,
    15591, 15314, 15678, 9531, 2820, -1516, 3400, 1314, 22, 363, -2896,
    -898, 5906, 7308, 10650, 12975, 16978, 20370, 18817, 12381, 4118, -861,
    -137, 236, 1802, 1632, -350, 2334, 3400, 8680, 14064, 18216, 18675,
    21765, 22871, 11491, 4937, -1555, -11, 1669, 2392, 3265, -5254, -217,
    5001, 8063, 13444, 18884, 19706, 22794, 21064, 9545, 6689, -7, 289,
    -2021, 504, 2347,
  };

  // Reference output for PADDING_VALID.
  const float expected_valid[] = {
    2612, -291, 3183, 9414, 12316, 14524, 12333, 9103, 5760, 688,
    9571, 15782, 14203, 14878, 5558, 2996, 7928, 6733, 16117, 15262,
    12757, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
  };

  // Positional initializers: the value order follows the CNN_CONFIG and
  // CNN_LAYER_CONFIG declarations. Weights and bias start out as nullptr and
  // are filled in by AssignLayerWeightsBiases() below.
  CNN_CONFIG cnn_config = {
    3, 0, 0, 0, 0,
    {
        { 1, filter_width, filter_height, 3, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        { 3, filter_width, filter_height, 3, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        { 3, filter_width, filter_height, 1, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          0 },
    }
  };

  // Each layer needs its own offset into the flat arrays, so the pointers are
  // assigned here rather than inside the initializer above.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Zero padding, as configured in the initializer.
  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Same network with replicate padding on every layer.
  for (int layer = 0; layer < cnn_config.num_layers; ++layer) {
    cnn_config.layer_config[layer].pad = PADDING_SAME_REPLICATE;
  }
  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Same network with valid padding on every layer.
  for (int layer = 0; layer < cnn_config.num_layers; ++layer) {
    cnn_config.layer_config[layer].pad = PADDING_VALID;
  }
  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
340
// Runs an 8x8 image through a single layer with RELU activation and a 4x5
// (width x height) filter, once for each padding mode, and requires the output
// to match the reference values exactly.
TEST_F(CNNTest, TestRELUSingleLayer) {
  const int image_width = 8;
  const int image_height = 8;
  const int filter_height = 5;
  const int filter_width = 4;

  const float input[] = {
    0, -2, -3, 1, -1, 2, -2, 1, -3, -1, 0, 1, -2, -3, -2, -2,
    1, -3, 2, -3, -1, -1, 2, 0, -2, -3, 0, -2, -3, 1, -1, -1,
    2, -2, 0, -2, -3, -3, 1, 1, -1, 1, 0, 1, -3, 0, 2, 2,
    0, -3, 1, -3, 2, -2, 1, -1, -1, -2, -3, -2, -1, -3, -2, -1,
  };

  // Reference output for PADDING_SAME_ZERO.
  const float expected_same[] = {
    9, 0, 1, 1, 0, 3, 0, 19, 0, 12, 10, 0, 0, 0, 5, 0,
    0, 18, 21, 7, 19, 4, 3, 0, 0, 9, 16, 0, 11, 16, 0, 11,
    12, 2, 0, 11, 0, 16, 6, 0, 8, 22, 13, 10, 12, 0, 0, 0,
    0, 1, 2, 12, 29, 6, 10, 0, 13, 0, 0, 5, 8, 10, 0, 0,
  };

  // Reference output for PADDING_SAME_REPLICATE.
  const float expected_replicate[] = {
    18, 17, 12, 2, 0, 0, 5, 11, 0, 17, 22, 6, 0, 0, 17, 0,
    0, 18, 21, 7, 19, 4, 3, 5, 3, 9, 16, 0, 11, 16, 0, 3,
    3, 2, 0, 11, 0, 16, 6, 0, 17, 22, 13, 10, 12, 0, 0, 0,
    0, 4, 1, 10, 30, 7, 10, 0, 23, 8, 0, 13, 15, 19, 8, 10,
  };

  // Reference output for PADDING_VALID.
  const float expected_valid[] = {
    18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
  };

  float weights[] = {
    -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
  };
  float bias[] = { -3 };

  // Positional initializer: the value order follows the CNN_CONFIG and
  // CNN_LAYER_CONFIG declarations.
  CNN_CONFIG cnn_config = {
    1, 0, 0, 0, 0,
    { { 1, filter_width, filter_height, 1, 1, 1, 0, weights, bias,
        PADDING_SAME_ZERO, RELU, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
        0 } }
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Zero padding, as configured in the initializer.
  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Replicate padding.
  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Valid padding.
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
413
// Runs one single-layer network over a 17x24 (width x height) image several
// times, changing skip_width / skip_height between runs, and requires each
// output to match its reference values exactly.
TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
  float weights[] = {
    1, -5, -3, -4, -1, 1, 2, -3, 2, 2, -1, 1, -5, 1, 1,
    -3, -5, 3, 1, 4, -2, -5, -2, -3, -5, 0, -1, -5, 2, -2,
    -2, 1, -2, -4, 1, 3, -2, 2, 0, -3, 2, -3, -2, -3,
  };
  float bias[] = { 2 };

  // Positional initializer: the value order follows the CNN_CONFIG and
  // CNN_LAYER_CONFIG declarations.
  CNN_CONFIG cnn_config = {
    1, 0, 0, 0, 0,
    {
        { 1, 4, 11, 1, 7, 6, 0, weights, bias, PADDING_SAME_ZERO, NONE, 0, 0,
          BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 },
    }
  };

  const int image_height = 24;
  const int image_width = 17;
  const float input[] = {
    -1, -3, 4, 4, -5, 4, 3, -5, -1, -3, 4, -4, 2, -3, 3, -5, 2, -1, -5,
    1, -1, 3, 1, -3, -3, 4, 0, 2, -3, -5, -5, -4, 0, -5, -2, -3, -1, -2,
    2, -5, 4, 4, 0, -4, -3, 1, -3, -5, -4, -4, 1, -2, -3, 3, -3, -3, -1,
    -5, -5, -2, 3, 1, -1, -5, -5, 1, -4, -2, -1, -2, -4, -4, 2, -2, 2, 1,
    -2, -4, -1, 1, -2, -5, 3, -2, -1, -1, -5, -3, 1, -2, -2, -3, -1, -2, -4,
    -2, 1, -4, -1, 4, 3, -4, 0, 4, 2, 2, 4, -3, -5, 2, 2, 1, -1, -4,
    -2, 1, 3, 2, 0, 4, -1, -3, 2, 1, -4, 2, 2, -4, -2, 0, -2, -1, 4,
    4, 2, 3, -4, 2, -4, -5, 4, -1, -3, -1, 0, -4, 1, 3, -1, -3, -5, 3,
    -2, -4, 1, 2, -2, -3, -3, -5, 1, -3, -1, 0, -1, 3, -4, -1, -5, -5, 1,
    0, 0, -2, -2, 2, -2, 0, 0, 2, 0, -3, 0, -1, -4, -4, -1, 3, -4, -4,
    -1, 0, -5, -3, -2, 4, -3, -4, -4, 0, -5, 1, -2, -3, -3, -4, 4, 3, 4,
    3, 3, -1, 3, 1, -3, -2, 3, 3, 0, 2, -4, -3, 2, 2, 0, -2, 4, -2,
    2, -2, -1, -4, -2, 2, -4, 3, -1, 4, 1, 1, 4, -1, -4, -4, 1, 1, -2,
    4, -1, 3, 2, -3, 4, 3, 1, 4, 0, -4, 2, 0, 2, 4, -2, -2, 4, 2,
    -1, -2, 1, -3, 2, 3, -5, -3, 4, 4, 2, -5, -4, -5, -2, -4, 2, 0, 2,
    -5, 4, -4, -2, -5, 2, 1, 0, 4, 1, -2, -3, -4, -3, -4, 3, 3, 2, 0,
    -3, 1, -5, 4, 0, 4, -1, 3, -5, -5, -2, -1, -1, 4, 3, 3, 4, 3, -4,
    4, -3, -3, -1, -4, -1, -4, -1, -2, 4, -2, -4, 4, 4, -3, -4, -1, 1, 2,
    -1, -2, -2, 3, 2, 2, -3, 0, -1, 0, 3, 2, -5, 0, -4, 0, 0, 2, -4,
    -1, -1, 0, -2, 0, 1, 0, 0, 4, -5, -1, -5, 2, -1, 0, 2, -1, 1, 3,
    -3, -5, -2, -3, 4, -2, -2, -1, -3, -4, -1, -2, -4, 1, 4, -3, -2, -1, 3,
    -3, -2, 3, 2, 1, -4, -3, -5, 1,
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Run 1: configuration exactly as initialized above.
  const float expected_1[] = {
    41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
  };
  RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Run 2: skip_width = 6, skip_height = 7.
  cnn_config.layer_config[0].skip_width = 6;
  cnn_config.layer_config[0].skip_height = 7;
  const float expected_2[] = {
    21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
  };
  RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Run 3: skip_width = 3, skip_height = 10.
  cnn_config.layer_config[0].skip_width = 3;
  cnn_config.layer_config[0].skip_height = 10;
  const float expected_3[] = {
    -26, -21, -35, 69, 49, 4, -51, -43, -56,
    -41, 15, -44, 40, -62, 63, 38, 27, 47,
  };
  RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Run 4: skip_width = 10, skip_height = 3.
  cnn_config.layer_config[0].skip_width = 10;
  cnn_config.layer_config[0].skip_height = 3;
  const float expected_4[] = {
    21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
  };
  RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
514
// Runs an 8x8 image through a single 3x3 layer that uses the same stride in
// both directions and has the value following the two strides in the
// initializer set to 1 (presumably the max-pool switch, going by the test
// name), and requires the output to match the reference values exactly.
TEST_F(CNNTest, TestMaxPool) {
  const int image_width = 8;
  const int image_height = 8;
  const int stride = 3;

  const float input[] = {
    1, -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8, 5, -1, -1, 9,
    -3, 0, -2, 0, 6, 3, -4, 8, 7, 8, 7, -1, 4, -1, 0, 2,
    -5, -2, 8, 5, 5, 4, 2, 7, 4, 6, 2, 8, 8, -4, -3, -4,
    -3, -1, 2, 3, 3, 6, -5, 8, 9, 5, 0, -2, -1, 6, 5, 7,
  };

  const float expected[] = {
    49, 58, 70, 68, 68, 70, 48, 57, 88,
  };

  float weights[] = {
    3, 1, 3, 4, -1, 5, -2, 1, -4,
  };

  float bias[] = {
    -3,
  };

  // Positional initializer: the value order follows the CNN_CONFIG and
  // CNN_LAYER_CONFIG declarations.
  CNN_CONFIG cnn_config = {
    1, 0, 0, 0, 0,
    { { 1, 3, 3, 1, stride, stride, 1, weights, bias, PADDING_SAME_ZERO, NONE,
        0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } }
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
569
// Runs a 4x7 (width x height) image through a single layer configured with
// activation NONE and requires exact agreement with the reference output.
// The run is repeated for several filter size / skip_width / skip_height
// combinations, each with both PADDING_SAME_ZERO and PADDING_VALID.
//
// The configuration object is mutated between runs, so the order of the
// statements below matters.
TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
  const int image_width = 4;
  const int image_height = 7;
  const float input[] = {
    9, 6, 181, 9, 218, 30, 80, 108, 68, 216, 70, 128, 179, 228,
    33, 212, 34, 14, 48, 27, 230, 23, 202, 113, 80, 56, 122, 112,
  };

  const float expected_1_same[] = {
    15, -30, 36, -525, 377, -193, 558, 531, 6, -24, -15, 124,
    166, -561, -356, -754, -3, -3, -3, -3, -3, -3, -3, -3,
    433, -311, 711, 381, 247, -317, 453, 129, 215, -627, -409, -885,
    17, -255, -55, -647, -3, -3, -3, -3, -3, -3, -3, -3,
    133, -719, 633, -225, 785, 191, 463, 79, 65, 9, 77, -853,
    -365, -949, -15, -667, -3, -3, -3, -3, -3, -3, -3, -3,
    355, -866, 990, 207, 747, 12, 520, -116, 176, -312, -133, -1370,
    -426, -802, 143, -771, -3, -3, -3, -3, -3, -3, -3, -3,
    65, -79, 127, -59, 135, -90, 195, 114, 31, -91, -57, -133,
    17, -176, -72, -276, -3, -3, -3, -3, -3, -3, -3, -3,
    457, -302, 733, 58, 470, -475, 829, 490, 227, -670, -440, -790,
    153, -588, -294, -1150, -3, -3, -3, -3, -3, -3, -3, -3,
    157, -251, 349, -185, 409, -293, 587, 251, 77, -187, -107, -369,
    7, -481, -135, -827, -3, -3, -3, -3, -3, -3, -3, -3,
  };
  const float expected_1_valid[] = {
    -30, 15, -30, 36, -525, 377, -193, 558, 531, 24, 24, 6,
    6, -24, -15, 124, 166, -561, -356, -754, -21, -39, -3, -3,
    -3, -3, -3, -3, -3, -3, -3, -3, -3, -657, 433, -311,
    711, 381, 247, -317, 453, 129, 321, 321, 215, 215, -627, -409,
    -885, 17, -255, -55, -647, -219, -435, -3, -3, -3, -3, -3,
    -3, -3, -3, -3, -3, -3, -207, 133, -719, 633, -225, 785,
    191, 463, 79, 381, 381, 65, 65, 9, 77, -853, -365, -949,
    -15, -667, -259, -515, -3, -3, -3, -3, -3, -3, -3, -3,
    -3, -3, -3, -540, 355, -866, 990, 207, 747, 12, 520, -116,
    633, 633, 176, 176, -312, -133, -1370, -426, -802, 143, -771, -427,
    -851, -3, -3, -3, -3, -3, -3, -3, -3, -3, -3, -3,
    -105, 65, -79, 127, -59, 135, -90, 195, 114, 78, 78, 31,
    31, -91, -57, -133, 17, -176, -72, -276, -57, -111, -3, -3,
    -3, -3, -3, -3, -3, -3, -3, -3, -3, -693, 457, -302,
    733, 58, 470, -475, 829, 490, 336, 336, 227, 227, -670, -440,
    -790, 153, -588, -294, -1150, -229, -455, -3, -3, -3, -3, -3,
    -3, -3, -3, -3, -3, -3, -243, 157, -251, 349, -185, 409,
    -293, 587, 251, 333, 333, 77, 77, -187, -107, -369, 7, -481,
    -135, -827, -227, -451,
  };
  float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
  float bias_1[] = { -3 };

  // Positional initializer: the value order follows the CNN_CONFIG and
  // CNN_LAYER_CONFIG declarations.
  CNN_CONFIG cnn_config = {
    1, 0, 0, 0, 0,
    { { 1, 5, 2, 1, 2, 3, 0, weights_1, bias_1, PADDING_SAME_ZERO, NONE, 1, 0,
        BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } }
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  const float expected_12_same[] = {
    15, -12, 6, 36, -9, -528, 377, -184, 513, 558, -12, 24,
    6, -30, -15, -33, -21, 166, 154, -546, -356, -718, -30, -21,
    433, -221, 561, 711, -33, -153, 247, -83, -87, 453, -111, 321,
    215, -657, -409, -845, -93, 17, -43, -243, -55, -215, -327, -219,
    133, -71, -447, 633, -219, 435, 785, -73, -177, 463, -131, 381,
    65, -207, 77, -59, -651, -365, -797, -213, -15, -155, -387, -259,
    355, -182, -150, 990, -231, 582, 747, -36, -540, 520, -215, 633,
    176, -540, -133, -491, -687, -426, -882, -102, 143, 77, -639, -427,
    65, -37, 57, 127, -17, -105, 135, -51, 60, 195, -30, 78,
    31, -105, -57, -125, -45, 17, -11, -147, -72, -168, -84, -57,
    457, -233, 618, 733, -26, -540, 470, -205, 264, 829, -116, 336,
    227, -693, -440, -900, -72, 153, 107, -609, -294, -698, -342, -229,
    157, -83, 69, 349, -59, -201, 409, -125, 27, 587, -115, 333,
    77, -243, -107, -267, -171, 7, -105, -369, -135, -379, -339, -227,
  };
  const float expected_12_valid[] = {
    -30, 15, -12, 6, 36, -9, -528, 377, -184, 513, 558, -12,
    24, 24, 6, 6, -30, -15, -33, -21, 166, 154, -546, -356,
    -718, -30, -21, -39, -657, 433, -221, 561, 711, -33, -153, 247,
    -83, -87, 453, -111, 321, 321, 215, 215, -657, -409, -845, -93,
    17, -43, -243, -55, -215, -327, -219, -435, -207, 133, -71, -447,
    633, -219, 435, 785, -73, -177, 463, -131, 381, 381, 65, 65,
    -207, 77, -59, -651, -365, -797, -213, -15, -155, -387, -259, -515,
    -540, 355, -182, -150, 990, -231, 582, 747, -36, -540, 520, -215,
    633, 633, 176, 176, -540, -133, -491, -687, -426, -882, -102, 143,
    77, -639, -427, -851, -105, 65, -37, 57, 127, -17, -105, 135,
    -51, 60, 195, -30, 78, 78, 31, 31, -105, -57, -125, -45,
    17, -11, -147, -72, -168, -84, -57, -111, -693, 457, -233, 618,
    733, -26, -540, 470, -205, 264, 829, -116, 336, 336, 227, 227,
    -693, -440, -900, -72, 153, 107, -609, -294, -698, -342, -229, -455,
    -243, 157, -83, 69, 349, -59, -201, 409, -125, 27, 587, -115,
    333, 333, 77, 77, -243, -107, -267, -171, 7, -105, -369, -135,
    -379, -339, -227, -451,
  };

  // Change skip_width, skip_height to {3, 2}
  cnn_config.layer_config[0].skip_width = 3;
  cnn_config.layer_config[0].skip_height = 2;
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;

  RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Switch to a 4x3 (width x height) filter with its own weights and bias.
  cnn_config.layer_config[0].filter_width = 4;
  cnn_config.layer_config[0].filter_height = 3;
  float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
  float bias_2[] = { -4 };
  cnn_config.layer_config[0].weights = weights_2;
  cnn_config.layer_config[0].bias = bias_2;

  // Change skip_width, skip_height to {5, 2}
  cnn_config.layer_config[0].skip_width = 5;
  cnn_config.layer_config[0].skip_height = 2;
  const float expected_2_same[] = {
    -13, -31, -13, -31, -4, -10, -22, -10, -22, -4, -185, -547,
    -185, -547, -4, -13, -31, -13, -31, -4, -4, 14, -22, 32,
    -4, -4, 8, -16, 20, -4, -4, 358, -366, 720, -4, -4,
    14, -22, 32, -4, -195, -658, -213, -622, -4, -16, -94, -28,
    -70, -4, 459, -244, 97, 480, -4, -85, -328, -103, -292, -4,
    -4, 432, -440, 868, -4, -4, 56, -64, 116, -4, -4, 156,
    -164, 316, -4, -4, 212, -220, 428, -4, 582, -208, 146, 664,
    -4, -130, -652, -190, -532, -4, 166, -214, 6, 106, -4, 192,
    -388, -24, 44, -4, -4, 132, -140, 268, -4, -4, 428, -436,
    860, -4, -4, 136, -144, 276, -4, -4, 252, -260, 508, -4,
    21, -541, -115, -269, -4, 416, -688, -16, 176, -4, 173, -103,
    33, 177, -4, 168, -640, -88, -128, -4, -4, 354, -362, 712,
    -4, -4, 452, -460, 908, -4, -4, 62, -70, 128, -4, -4,
    420, -428, 844, -4, 499, -106, 141, 610, -4, 666, -46, 210,
    866, -4, 47, -148, -19, -16, -4, 605, -85, 181, 763, -4,
    -4, 64, -72, 132, -4, -4, 24, -32, 52, -4, -4, 92,
    -100, 188, -4, -4, 50, -58, 104, -4, -132, -694, -200, -558,
    -4, 15, -73, -13, -17, -4, -62, -610, -158, -418, -4, -36,
    -343, -90, -235, -4, -4, 456, -464, 916, -4, -4, 42, -50,
    88, -4, -4, 400, -408, 804, -4, -4, 222, -230, 448, -4,
    606, -244, 146, 676, -4, 9, -172, -37, -80, -4, 480, -370,
    76, 438, -4, 223, -340, -3, 112, -4, -4, 156, -164, 316,
    -4, -4, 108, -116, 220, -4, -4, 240, -248, 484, -4, -4,
    220, -228, 444, -4,
  };
  const float expected_2_valid[] = {
    -13, -31, -13, -31, -4, -10, -22, -10, -22, -4, -185, -547,
    -185, -547, -4, -13, -31, -13, -31, -4, 14, -22, 32, -4,
    -4, 8, -16, 20, -4, -4, 358, -366, 720, -4, -4, 14,
    -22, 32, -195, -658, -213, -622, -4, -16, -94, -28, -70, -4,
    459, -244, 97, 480, -4, -85, -328, -103, -292, -4, 432, -440,
    868, -4, -4, 56, -64, 116, -4, -4, 156, -164, 316, -4,
    -4, 212, -220, 428, 582, -208, 146, 664, -4, -130, -652, -190,
    -532, -4, 166, -214, 6, 106, -4, 192, -388, -24, 44, -4,
    132, -140, 268, -4, -4, 428, -436, 860, -4, -4, 136, -144,
    276, -4, -4, 252, -260, 508, 21, -541, -115, -269, -4, 416,
    -688, -16, 176, -4, 173, -103, 33, 177, -4, 168, -640, -88,
    -128, -4, 354, -362, 712, -4, -4, 452, -460, 908, -4, -4,
    62, -70, 128, -4, -4, 420, -428, 844, 499, -106, 141, 610,
    -4, 666, -46, 210, 866, -4, 47, -148, -19, -16, -4, 605,
    -85, 181, 763, -4, 64, -72, 132, -4, -4, 24, -32, 52,
    -4, -4, 92, -100, 188, -4, -4, 50, -58, 104, -132, -694,
    -200, -558, -4, 15, -73, -13, -17, -4, -62, -610, -158, -418,
    -4, -36, -343, -90, -235, -4, 456, -464, 916, -4, -4, 42,
    -50, 88, -4, -4, 400, -408, 804, -4, -4, 222, -230, 448,
    606, -244, 146, 676, -4, 9, -172, -37, -80, -4, 480, -370,
    76, 438, -4, 223, -340, -3, 112, -4, 156, -164, 316, -4,
    -4, 108, -116, 220, -4, -4, 240, -248, 484, -4, -4, 220,
    -228, 444, 236, -4, 76, 316, -4, 164, -4, 52, 220, -4,
    362, -4, 118, 484, -4, 332, -4, 108, 444,
  };
  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;

  RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change skip_width, skip_height to {2, 5}
  cnn_config.layer_config[0].skip_width = 2;
  cnn_config.layer_config[0].skip_height = 5;
  const float expected_21_same[] = {
    -31, -19, -49, -191, -565, -194, -574, -13, 14, -22, 44, -16,
    382, -366, 738, -22, -4, 23, 32, 545, 20, 204, 720, 5,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -658, -252, -748, -114, -334, -192, -568, -112,
    432, -440, 928, -64, 276, -164, 532, -220, -4, 304, 868, 266,
    116, 400, 316, 104, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -208, -288, -856, -290,
    -862, -202, -598, -132, 132, -140, 700, -436, 1000, -144, 532, -260,
    -4, 712, 268, 422, 860, 450, 276, 124, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -541, -411, -1225, -265, -787, -249, -739, -216, 354, -362, 1168, -460,
    974, -70, 552, -428, -4, 859, 712, 323, 908, 665, 128, 208,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -106, -52, -148, -66, -190, -79, -229, -31,
    64, -72, 160, -32, 148, -100, 242, -58, -4, 72, 132, 154,
    52, 125, 188, 23, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -694, -257, -763, -229,
    -679, -319, -949, -117, 456, -464, 962, -50, 492, -408, 1030, -230,
    -4, 295, 916, 625, 88, 537, 804, 109, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -244, -140, -412, -182, -538, -238, -706, -116, 156, -164, 428, -116,
    464, -248, 708, -228, -4, 244, 316, 418, 220, 454, 484, 108,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4,
  };
  const float expected_21_valid[] = {
    -13, -31, -19, -49, -191, -565, -194, -574, -13, -31, -4, 14,
    -22, 44, -16, 382, -366, 738, -22, 32, 23, -4, 23, 32,
    545, 20, 204, 720, 5, 32, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -222, -658, -252, -748, -114, -334, -192, -568, -112, -328,
    -4, 432, -440, 928, -64, 276, -164, 532, -220, 428, 650, -4,
    304, 868, 266, 116, 400, 316, 104, 428, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -72, -208, -288, -856, -290, -862, -202, -598,
    -132, -388, -4, 132, -140, 700, -436, 1000, -144, 532, -260, 508,
    200, -4, 712, 268, 422, 860, 450, 276, 124, 508, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -183, -541, -411, -1225, -265, -787,
    -249, -739, -216, -640, -4, 354, -362, 1168, -460, 974, -70, 552,
    -428, 844, 533, -4, 859, 712, 323, 908, 665, 128, 208, 844,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -38, -106, -52, -148,
    -66, -190, -79, -229, -31, -85, -4, 64, -72, 160, -32, 148,
    -100, 242, -58, 104, 98, -4, 72, 132, 154, 52, 125, 188,
    23, 104, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -234, -694,
    -257, -763, -229, -679, -319, -949, -117, -343, -4, 456, -464, 962,
    -50, 492, -408, 1030, -230, 448, 686, -4, 295, 916, 625, 88,
    537, 804, 109, 448, -4, -4, -4, -4, -4, -4, -4, -4,
    -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
    -84, -244, -140, -412, -182, -538, -238, -706, -116, -340, -4, 156,
    -164, 428, -116, 464, -248, 708, -228, 444, 236, -4, 244, 316,
    418, 220, 454, 484, 108, 444,
  };

  // Set padding to same
  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;

  RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Change padding to valid
  cnn_config.layer_config[0].pad = PADDING_VALID;

  RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
846
// Runs a single-layer network twice with filters that are larger than the
// image. The first run uses filter sizes 23 and 20 together with the values 15
// and 20 in what are presumably the skip (stride) slots, and expects a single
// output value. The second run swaps the filter dimensions, resets both skips
// to 1 and expects 110 output values.
TEST_F(CNNTest, TestLargeKernelsAndStrides) {
  // Dimensions of the first image; both are reassigned for the second run.
  int image_height = 10;
  int image_width = 11;

  const float input_10x11[] = {
    4, 4, 2, 4, 2, -5, -2, 3, -1, 0, 0, 1, 2, 0, -5, -2, -5, 1, -3,
    -1, 4, -3, 2, -2, 1, 0, 1, -3, -3, -4, -2, -2, 1, -4, -1, 4, 1, -4,
    -4, -4, 3, 2, -5, 3, -5, 1, 2, -4, 1, -1, 3, 4, -2, 3, -3, 3, 0,
    2, -4, -5, -5, -2, -1, -2, 1, 1, 1, -2, 4, -5, 4, -1, -1, 2, 3, -4,
    2, 2, 3, 0, 0, 1, 0, 3, 2, 3, 1, -2, 3, -4, 3, 2, 4, -2, 0,
    4, -4, 1, -3, -3, -3, -5, 1, -3, -5, 0, 4, -1, -3, 2,
  };

  float weights_10x11[] = {
    -3, 4, -4, -3, -5, 1, -2, 3, 1, -4, -4, 0, -1, 0, 3, 1, -3, -2, 0,
    -1, 1, 3, -4, -4, -3, -3, -2, 4, 3, -5, 4, 2, -3, 4, -2, -1, 2, -1,
    -5, 0, -3, 0, 3, -5, -5, 3, -4, -1, -5, 3, 4, 0, 4, -5, 2, -1, 2,
    -1, -1, -1, -5, 0, -4, 3, -1, 1, 1, -1, 3, 2, -5, -4, 0, -4, 4, -5,
    -3, 4, -5, 2, -5, -4, -4, -1, 3, 3, 0, 2, -4, 1, -2, 1, 1, 0, 3,
    -2, 0, 1, 2, 4, -3, -1, -5, -5, 2, -4, 1, 1, 2, -4, -2, -2, 2, 1,
    3, 4, -5, 1, -1, -3, -3, -1, -2, -5, 1, -1, 0, 1, 4, 4, 0, 0, 4,
    -3, -1, -5, -3, 0, 1, 1, 1, -5, 3, 4, 3, -5, 3, -2, -2, 0, -4, 0,
    0, -2, 1, -4, -1, 0, -5, -2, -2, -5, -3, -3, 1, 1, -3, 2, 4, 2, 4,
    -4, -3, 3, 1, 1, 3, -4, 4, -2, -3, -3, -3, -3, -4, -2, 3, -5, 2, 4,
    -1, -4, -4, 4, -2, -1, 3, -3, -4, -4, -2, 4, 1, 0, 2, -1, 4, -3, 1,
    4, -3, 4, 4, 0, -4, 3, -2, -3, 2, 3, -1, -3, 2, 1, 4, -2, -3, 1,
    4, -2, 2, -2, -5, -2, 1, 4, -1, -4, 4, -5, 2, -5, -4, -1, -2, 3, 1,
    2, 1, -5, 1, -5, -4, -1, -2, 2, -2, -4, -3, -2, -2, 4, -1, 2, 2, -4,
    2, -2, 4, -4, -2, -2, 1, -1, 1, 1, 1, -4, -5, -2, 3, -4, -1, 3, -2,
    3, 2, -5, -4, 0, 3, -2, -4, -5, 3, -2, -4, 2, -2, 1, -4, 0, 2, -5,
    1, -4, -1, -1, 4, -5, -4, 0, -5, -4, -3, -5, -4, 0, 2, 0, -4, 2, -2,
    1, 1, -3, 2, 0, -4, 0, -4, 1, 0, -5, -1, -1, -1, -5, 4, 2, 2, -4,
    3, -2, -2, 2, -3, -2, -1, 2, -4, -5, 2, -2, -4, -5, -5, -1, 2, -1, 0,
    -5, -2, -2, -5, 0, 1, -1, -5, 0, 3, 2, 3, 0, -3, -2, 0, -5, -1, -2,
    2, -4, -1, 2, 2, -5, 2, -4, 0, 3, -3, 1, 0, 0, 1, -5, -3, 1, -1,
    0, -4, -3, 2, -4, -4, 4, -1, 0, 1, 2, -4, -5, 4, -2, 1, -4, -4, -3,
    -1, -1, 1, -1, -4, -1, -4, -3, 2, -1, -2, -4, 1, 1, 0, -2, 0, -4, 3,
    -3, 0, -4, -1, -4, 2, -1, -2, -5, -1, -2, -3, 3, -1, 0, -3, 0, 1, -5,
    1, -5, 0, 1,
  };

  float bias_10x11[] = { 3 };

  const float expected_10x11[] = { 118 };

  // One layer; the positional values follow the same order as in the other
  // tests of this file (filter width 23, filter height 20, then 15 and 20).
  CNN_CONFIG cnn_config = { 1, 0, 0, 0, 0,
                            { { 1, 23, 20, 1, 15, 20, 0, weights_10x11,
                                bias_10x11, PADDING_SAME_ZERO, NONE, 0, 0,
                                BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } } };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
             &cnn_config, image_width, &thread_data, MSE_INT_TOL);

  // Second run: transposed image and filter, unit skips.
  const float input_11x10[] = {
    -2, -2, 3, -5, -1, -3, 1, 3, 2, 1, 1, -5, 4, 1, 3, -5, 3, -3, -5,
    0, -1, -3, -3, 1, 1, -5, -1, -5, -5, -3, 0, 1, -3, -1, -3, -3, 0, 3,
    4, -4, -1, 3, -3, -1, -3, 1, -3, -2, -1, -4, -3, 2, -4, 1, -4, -1, -3,
    -5, -1, 2, 3, 0, 2, 2, -5, 4, 1, 2, -1, -4, 4, -4, -4, 0, -1, 1,
    -1, 1, -3, -3, -2, 1, 2, 4, 4, 4, -3, -3, 0, 1, 0, 1, 4, 1, 3,
    4, -3, -2, -4, 4, 2, 0, 3, 4, -1, 2, -2, 1, -3, -2,
  };

  float weights_11x10[] = {
    4, -1, 1, -1, 2, 4, 3, 3, -4, 3, -5, 1, -1, -1, -2, -2, 0, 2, -3,
    -2, 3, -5, -1, 0, -1, -2, -2, -1, 2, 4, 3, 1, 0, 0, -3, 3, -4, -1,
    -5, 4, -2, -2, 1, 2, -1, -3, 1, 2, -5, 1, -3, 3, 3, 0, -4, -4, -5,
    -3, -4, -4, 4, -2, 4, 4, -2, 2, -5, -1, -2, -5, -1, 4, -3, 3, -2, 0,
    -4, -3, 0, -1, -2, 4, 2, 0, -2, -5, -4, 1, 4, -4, -2, 2, -2, 1, 1,
    -4, 1, -4, -4, -2, 4, 2, -1, -5, -5, 1, -3, -3, 3, -3, -5, -3, 4, -1,
    -1, -3, 0, -4, 3, -1, 0, -2, 0, -5, -2, -5, 2, 0, -5, 2, 3, -2, 2,
    4, -1, 1, -3, 2, 3, 2, 0, -5, -4, -5, 2, 1, 1, -1, -2, 3, 4, 2,
    -2, 4, -2, 3, 1, -4, -3, -1, 4, 4, -3, -5, -2, 2, 0, 3, -2, 3, -1,
    -4, 0, -2, 0, 3, 4, -2, -3, -2, 0, 3, 4, 2, -4, 0, 1, 2, 2, -1,
    -1, 4, 1, 4, -2, -1, -1, -5, 1, -3, 3, 3, -1, -4, 3, -5, 0, 0, -1,
    -4, -1, -2, 4, -2, 3, 3, -3, 1, -1, 2, -1, 4, 4, -2, -2, 4, -2, 0,
    3, -3, -5, -1, -2, 4, -4, 2, -4, 0, -2, 3, -3, 2, 2, -2, -5, -1, 4,
    3, -2, -1, 3, 3, -1, 3, 0, -3, 0, 4, 2, 0, -1, 4, 1, 1, 2, 1,
    3, 1, 1, 1, -3, -5, -4, 4, -4, 2, 0, 0, -4, 1, 4, -5, 4, 4, 0,
    1, 0, -2, -4, -4, -3, 0, 1, -5, 4, 0, -3, -2, -4, 2, 4, 1, -5, 1,
    -4, 1, 0, -3, -3, 0, 2, -5, 4, 3, -2, -5, 3, 1, -1, 0, 3, -2, -2,
    3, -2, -5, 4, 1, -2, 2, -1, 0, 4, 0, -5, 3, -2, 1, 2, 1, -5, -3,
    -2, -5, 4, -4, 0, 3, 2, -1, -4, -1, 2, 1, -2, 3, -1, -4, 2, 0, -3,
    1, -1, 2, -5, -4, -1, -5, 1, 4, 3, 4, 2, -3, 1, -5, -1, 3, 0, -1,
    -4, 3, 4, -5, 4, 4, -3, 2, -3, -1, -3, -5, -3, 2, -3, -2, 1, 1, 0,
    -5, 3, 2, 1, -5, 1, 1, 1, 3, 4, -4, -1, -2, 0, -5, -3, -5, -2, -4,
    3, 3, 3, 4, 0, -4, -1, -5, 0, -3, 1, 4, 4, -4, 4, -5, -5, -1, -2,
    -5, 3, -4, 4, 3, 0, -3, 2, -2, 0, 0, 4, 4, 0, -2, 1, -1, -3, 2,
    -1, 1, -3, -5,
  };

  float bias_11x10[] = { -5 };

  const float expected_11x10[] = {
    36, -84, 95, 45, 18, 46, 77, -54, -99, -149, 66, 49, 161, 11,
    39, 61, -66, 61, 4, -3, 34, -44, -23, 31, 64, 29, 47, 72,
    -27, -27, 121, -3, 100, 1, 30, -78, -12, -89, -59, 8, -16, 112,
    91, -102, -26, -4, 30, 54, 4, -84, -24, -58, 27, -53, -33, 5,
    53, -26, 63, 50, -103, -130, -23, 6, -104, -207, 73, 23, 77, 132,
    38, 32, -130, -44, -60, 7, 27, 176, 45, -32, -2, 99, -97, 63,
    69, 126, 47, 63, 136, -57, 5, 16, -40, -157, 8, 38, -44, -10,
    91, 7, 122, 140, 30, -105, 4, -1, 113, 64, 180, 141,
  };

  // Reuse the same config object, overriding only what differs.
  cnn_config.layer_config[0].weights = weights_11x10;
  cnn_config.layer_config[0].bias = bias_11x10;
  cnn_config.layer_config[0].filter_width = 20;
  cnn_config.layer_config[0].filter_height = 23;
  cnn_config.layer_config[0].skip_width = 1;
  cnn_config.layer_config[0].skip_height = 1;
  image_height = 11;
  image_width = 10;

  RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
             &cnn_config, image_width, &thread_data, MSE_INT_TOL);
}
989
// Runs one SOFTSIGN-activated layer with a 4 (width) x 5 (height) filter on an
// 8x8 image, once for each of PADDING_SAME_ZERO, PADDING_SAME_REPLICATE and
// PADDING_VALID, comparing against the matching expected array with
// MSE_FLOAT_TOL.
TEST_F(CNNTest, TestSoftsignSingleLayer) {
  const int image_width = 8;
  const int image_height = 8;
  const int filter_height = 5;
  const int filter_width = 4;

  const float input[] = {
    -0.5220f, 0.8410f, -0.8990f, -0.0090f, 0.6710f, -0.9470f, -0.8240f,
    -0.0870f, 0.5380f, 0.4750f, 0.570f, -0.3760f, -0.6960f, -0.5940f,
    -0.3830f, 0.080f, -0.0980f, -0.4940f, -0.4030f, 0.9460f, -0.6020f,
    0.4220f, 0.6190f, 0.6640f, -0.9210f, -0.1470f, -0.2480f, -0.1120f,
    -0.580f, -0.0650f, 0.3330f, 0.9860f, -0.7430f, 0.7610f, 0.4840f,
    0.1030f, 0.9570f, 0.6120f, -0.5240f, -0.1220f, -0.5850f, -0.270f,
    0.7840f, -0.9790f, 0.7290f, -0.30f, -0.6460f, 0.0780f, 0.4750f,
    -0.0510f, 0.4550f, 0.3850f, -0.7230f, 0.4460f, -0.6260f, -0.810f,
    0.8720f, -0.2120f, -0.580f, -0.9510f, -0.8430f, -0.1340f, -0.0850f,
    0.9190f,
  };

  // Expected output for PADDING_SAME_ZERO.
  const float expected_same[] = {
    0.430f, 0.660f, 0.5510f, -0.610f, 0.450f, -0.1610f, 0.0520f, 0.3240f,
    0.6820f, 0.3820f, 0.6360f, 0.7480f, 0.3080f, 0.090f, 0.3910f, 0.1730f,
    0.340f, 0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.4670f, 0.6150f,
    -0.3880f, 0.7590f, 0.4190f, 0.7350f, 0.5310f, -0.5160f, -0.1760f, 0.6790f,
    -0.6780f, 0.5470f, 0.5750f, -0.6420f, 0.7210f, -0.4620f, 0.5430f, 0.770f,
    -0.1990f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.6430f, 0.4510f,
    -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f, 0.5870f, 0.4720f,
    0.4040f, 0.3630f, 0.670f, 0.2360f, 0.410f, 0.6980f, -0.5350f, 0.3940f,
  };

  // Expected output for PADDING_SAME_REPLICATE.
  const float expected_replicate[] = {
    0.540f, 0.7230f, -0.3530f, -0.2130f, 0.7440f, -0.4470f, -0.6260f,
    -0.2050f, 0.7230f, 0.4630f, 0.5920f, 0.7440f, 0.6080f, 0.3130f,
    -0.5670f, -0.4720f, 0.5480f, 0.6660f, -0.4990f, 0.4280f, 0.1540f,
    0.120f, 0.3390f, 0.6090f, 0.4160f, 0.7590f, 0.4190f, 0.7350f,
    0.5310f, -0.5160f, -0.490f, 0.4450f, -0.610f, 0.5470f, 0.5750f,
    -0.6420f, 0.7210f, -0.4620f, 0.3150f, 0.7370f, -0.5820f, 0.3950f,
    0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.7430f, -0.5340f, -0.6270f,
    0.4430f, 0.4730f, 0.4570f, 0.7450f, 0.630f, 0.2620f, 0.3140f,
    -0.1840f, 0.1810f, 0.7210f, 0.2760f, 0.6430f, 0.6720f, -0.4390f,
    0.2040f,
  };

  // Expected output for PADDING_VALID (20 values).
  const float expected_valid[] = {
    0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.7590f, 0.4190f,
    0.7350f, 0.5310f, -0.5160f, 0.5470f, 0.5750f, -0.6420f, 0.7210f,
    -0.4620f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f,
  };

  float weights[] = {
    0.6210f, 0.3710f, -0.2770f, -0.7230f, -0.2450f, 0.6770f, 0.3080f,
    -0.9880f, -0.080f, 0.7190f, -0.6760f, -0.0170f, -0.8970f, 0.8260f,
    0.7390f, -0.4550f, -0.4260f, -0.6330f, 0.0880f, -0.9390f,
  };

  float bias[] = { 0.750f };

  CNN_CONFIG cnn_config = { 1, 0, 0, 0, 0,
                            { { 1, filter_width, filter_height, 1, 1, 1, 0,
                                weights, bias, PADDING_SAME_ZERO, SOFTSIGN, 0,
                                0, BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 } } };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // The config starts out with PADDING_SAME_ZERO.
  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);

  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);

  cnn_config.layer_config[0].pad = PADDING_VALID;
  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);
}
1084
// Runs a six-layer network on a 4x4 image in which layer 1 is configured with
// BRANCH_INPUT (mask 0x02) and layer 4 with BRANCH_ADD (mask 0x02), and checks
// the result against integer-valued expectations with MSE_INT_TOL.
TEST_F(CNNTest, TestBranchTensorAdd) {
  const int filter_width = 2;
  const int filter_height = 3;
  const int image_width = 4;
  const int image_height = 4;
  const int channels = 2;

  const float input[] = {
    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
  };

  float weights[] = {
    -3, -1, 4, -1, -3, 3, 3, 0, 2, 0, 3, 2, 4, 4, 4, -5, 1, -4,
    2, -4, 1, -3, 0, 4, -5, 4, 0, -4, -3, -1, 0, 0, -2, 0, 0, 2,
    -5, -1, 1, -3, 3, 4, 3, 0, 1, -1, 1, 1, 2, 4, -2, -5, 2, -2,
    3, -2, 4, -1, 0, 2, 3, 2, -2, -1, -3, 1, 3, 4, -1, -3, 0, -4,
    4, 2, -3, -3, -1, 0, 1, 0, 3, 3, -3, 0, 3, 2, -5, -3, 4, -5,
    3, -1, -1, -3, 0, 1, -1, -4, 2, 4, -1, 4, -1, 1, 3, 4, 4, 4,
    0, -1, -3, -3, -3, -3, 2, -3, -2, 2, 3, -3,
  };

  float bias[] = {
    3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
  };

  const float expected[] = {
    -11502, -4101, -3424, 668, -17950, -5470, -5504, 626,
    4835, 446, 1779, -3483, 3679, -4214, 4578, -105,
  };

  // One initializer group per layer, fields in declaration order.
  CNN_CONFIG cnn_config = {
    6, 0, 0, 0, 0,
    {
        // Layer 0.
        { 1, filter_width, filter_height, channels, 1, 1, 0, weights, bias,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        // Layer 1: BRANCH_INPUT.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x02, 0, 0x00 }, {}, -1 },
        // Layer 2.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 3.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 4: BRANCH_ADD.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 5: single output channel.
        { channels, filter_width, filter_height, 1, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          0 },
    }
  };

  // The weights and biases are assigned in a separate step because each
  // layer needs its own offset into the arrays.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1260
// Runs a six-layer network on a 4x4 image in which layer 1 is configured with
// BRANCH_INPUT (mask 0x02) and layer 4 with BRANCH_CAT (mask 0x02); the final
// layer therefore takes `channels + channels` input channels.
TEST_F(CNNTest, TestBranchTensorConcatenation) {
  const int filter_width = 2;
  const int filter_height = 3;
  const int image_width = 4;
  const int image_height = 4;
  const int channels = 2;

  const float input[] = {
    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
  };

  float weights[] = {
    3, 0, 2, 0, 2, 3, 1, -3, 1, -5, -3, 0, -4, 4, 0, -5, 0, -5, -1,
    -2, -5, 0, -3, 2, -4, 2, 0, 2, -1, 0, -4, 3, 0, 0, -1, -5, 2, -1,
    4, -4, -2, -3, -3, 3, 4, -2, -1, -4, -1, 4, 4, -1, 4, 3, -4, 2, -2,
    -4, -3, -2, 3, -3, -5, -1, 3, -2, 4, 1, -4, -3, -5, -5, -3, 4, -2, -2,
    -1, -5, -5, 0, -1, -2, -3, 3, -4, -5, 2, -3, 1, 0, -5, 2, 2, -2, 0,
    2, 2, -2, 4, 2, 2, 0, 1, -5, -3, 0, 2, -2, 1, 2, -5, 2, 3, 3,
    -1, 3, 0, -3, 3, -4, -4, 3, 3, -4, -2, 2, -2, 2, -2, -1, 3, 0,
  };

  float bias[] = {
    -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
  };

  const float expected[] = {
    -33533, -32087, -6741, -2124, 39979, 41453, 14034, 689,
    -22611, -42203, -14882, -239, 15781, 15963, 9524, 837,
  };

  // One initializer group per layer, fields in declaration order.
  CNN_CONFIG cnn_config = {
    6, 0, 0, 0, 0,
    {
        // Layer 0.
        { 1, filter_width, filter_height, channels, 1, 1, 0, weights, bias,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        // Layer 1: BRANCH_INPUT.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x02, 0, 0x00 }, {}, -1 },
        // Layer 2.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 3.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 4: BRANCH_CAT.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 5: consumes the concatenated channels, one output channel.
        { channels + channels, filter_width, filter_height, 1, 1, 1, 0,
          nullptr, nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY,
          BRANCH_NOC, {}, {}, 0 },
    }
  };

  // The weights and biases are assigned in a separate step because each
  // layer needs its own offset into the arrays.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1436
1437 // TODO(logangw): Add test to test all combinations of branch_copy_type.
1438
// Runs a ten-layer network on a 4x4 image that mixes BRANCH_INPUT,
// BRANCH_OUTPUT and several BRANCH_ADD layers with different masks, and
// checks the result against integer-valued expectations with MSE_INT_TOL.
TEST_F(CNNTest, TestBranchCombinations) {
  const int filter_width = 2;
  const int filter_height = 3;
  const int image_width = 4;
  const int image_height = 4;
  const int channels = 2;

  const float input[] = {
    3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
  };

  float weights[] = {
    2, 3, 0, 4, 4, 3, 1, 0, 1, -5, 4, -3, 3, 0, 4, -1, -1, -5,
    2, 1, -3, -5, 3, -1, -3, -2, 0, -2, 3, 0, -2, -4, -2, -2, 2, -5,
    4, -5, 0, 1, -5, -4, -3, -4, 2, -2, 1, 0, 3, -2, -4, 3, 4, -4,
    -1, -1, -3, -2, -2, -1, 2, 0, 2, -1, 2, -4, -4, -1, 2, 0, 3, -2,
    -2, 3, -3, 4, -2, 4, 3, 4, 1, 0, -2, -3, -5, 1, -3, 2, 0, -2,
    -2, -1, -1, -5, -2, -3, -1, 3, 3, 4, 4, 0, 2, 1, 3, -3, 2, -5,
    -5, 1, -5, -1, 3, 3, 2, -4, -1, 3, -4, -2, -5, -2, 1, 3, 2, 2,
    -5, -2, -3, -1, -2, -4, -1, -2, 2, 1, -4, -4, 2, 0, 2, 0, 2, -3,
    -2, -4, 4, 0, 1, -3, -5, 4, -1, 2, 3, -5, -1, 0, 4, -1, -1, 3,
    -1, -3, 3, 1, 4, 3, 4, 3, -4, -5, -1, 3, 3, -4, 3, 1, 3, -5,
    3, 4, -5, 4, 2, -1, -5, 2, 1, 0, 4, 0, -3, 2, 0, 2, -2, 1,
    -1, -2, -1, -5, 4, 3, 3, -2, 2, 4, -5, -5, -3, -2, 4, 0, -4, 1,
  };

  float bias[] = {
    -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
  };

  const float expected[] = {
    149496, 15553, -24193, -20956, 134094, 86432, -68283, -6366,
    -53031, 133739, 67407, -13539, -53205, -58635, -20033, 1979,
  };

  // One initializer group per layer, fields in declaration order.
  CNN_CONFIG cnn_config = {
    10, 0, 0, 0, 0,
    {
        // Layer 0.
        { 1, filter_width, filter_height, channels, 1, 1, 0, weights, bias,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          -1 },
        // Layer 1: BRANCH_INPUT with mask 0x06.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x06, 0, 0x00 }, {}, -1 },
        // Layer 2: BRANCH_OUTPUT with mask 0x08.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 2, BRANCH_OUTPUT, BRANCH_NOC,
          { 0x08, 0, 0x00 }, {}, -1 },
        // Layer 3.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 3, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 4: BRANCH_ADD with mask 0x08.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 2, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x08 }, {}, -1 },
        // Layer 5.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 2, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 6.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 1, BRANCH_NO_COPY, BRANCH_NOC,
          {}, {}, -1 },
        // Layer 7: BRANCH_ADD with mask 0x0C.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 1, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x0C }, {}, -1 },
        // Layer 8: BRANCH_ADD with mask 0x02.
        { channels, filter_width, filter_height, channels, 1, 1, 0, nullptr,
          nullptr, PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_ADD,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 9: single output channel.
        { channels, filter_width, filter_height, 1, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          0 },
    }
  };

  // The weights and biases are assigned in a separate step because each
  // layer needs its own offset into the arrays.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1713
// Runs a three-layer network on a 4x4 image in which layer 0 is configured
// with BRANCH_OUTPUT and the branch settings { 0x02, 2, 0x00 }, and layer 1
// recombines with BRANCH_CAT (mask 0x02).
TEST_F(CNNTest, TestSplittingTensors) {
  const int filter_width = 2;
  const int filter_height = 3;
  const int image_width = 4;
  const int image_height = 4;

  const float input[] = {
    -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
  };

  float weights[] = {
    -4, 1, 0, 2, 3, 4, 4, -4, -5, -3, 2, 2, -4, -3, 3, 2,
    4, -4, -3, -4, -4, 1, -3, -5, -3, 4, 2, -2, 2, -1, -4, -1,
    -2, -3, 1, 1, 0, -5, -1, 3, 3, -5, -3, 0, -3, 1, -3, -1,
    1, -3, -2, -2, 4, -2, 0, 1, 2, 2, -4, 2, 4, 0, -5, -2,
    4, 4, -5, 1, 0, 2, -2, -5, -5, -3, -5, -5, 4, -3, 0, 0,
    -4, -4, 0, -5, -4, 0, 0, -3, -5, -3, -1, 2, -1, 4, -1, 2,
  };

  float bias[] = {
    -4, -2, -3, -3, 3, 1, -2,
  };

  const float expected[] = {
    530, -762, 1469, 777, 849, -771, -1698, 600,
    -658, -1821, 98, -668, -1798, 30, 887, -971,
  };

  // One initializer group per layer, fields in declaration order.
  CNN_CONFIG cnn_config = {
    3, 0, 0, 0, 0,
    {
        // Layer 0: 1 -> 4 channels, BRANCH_OUTPUT.
        { 1, filter_width, filter_height, 4, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_OUTPUT, BRANCH_NOC,
          { 0x02, 2, 0x00 }, {}, -1 },
        // Layer 1: 4 -> 2 channels, BRANCH_CAT.
        { 4, filter_width, filter_height, 2, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x02 }, {}, -1 },
        // Layer 2: 4 -> 1 channel.
        { 4, filter_width, filter_height, 1, 1, 1, 0, nullptr, nullptr,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
          0 },
    }
  };

  // The weights and biases are assigned in a separate step because each
  // layer needs its own offset into the arrays.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
1828
// Runs a three-layer network with 1x1 filters on an all-zero 2x2 image using
// all-zero weights and biases. Two of the layers use BRANCH_CAT, and the
// expected output is 28 zeros, so the test mainly exercises how many output
// values the concatenations produce.
TEST_F(CNNTest, TestOutputChannelsCount) {
  const int filter_width = 1;
  const int filter_height = 1;
  const int image_width = 2;
  const int image_height = 2;

  const float input[] = { 0, 0, 0, 0 };

  float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };

  float bias[] = { 0, 0, 0, 0, 0, 0 };

  const float expected[] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };

  // One initializer group per layer, fields in declaration order.
  CNN_CONFIG cnn_config = {
    3, 0, 0, 0, 0,
    {
        // Layer 0: BRANCH_INPUT with mask 0x06.
        { 1, filter_width, filter_height, 2, 1, 1, 0, weights, bias,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_INPUT, BRANCH_NOC,
          { 0x06, 0, 0x00 }, {}, -1 },
        // Layer 1: BRANCH_CAT with mask 0x03.
        { 1, filter_width, filter_height, 2, 1, 1, 0, weights, bias,
          PADDING_SAME_ZERO, NONE, 0, 2, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x03 }, {}, -1 },
        // Layer 2: BRANCH_CAT with mask 0x04.
        { 2, filter_width, filter_height, 2, 1, 1, 0, weights, bias,
          PADDING_SAME_ZERO, NONE, 0, 0, BRANCH_NO_COPY, BRANCH_CAT,
          { 0x00, 0, 0x04 }, {}, 0 },
    }
  };

  // The weights and biases are assigned in a separate step because each
  // layer needs its own offset into the arrays.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);
}
1936
TEST_F(CNNTest,TestBatchNorm)1937 TEST_F(CNNTest, TestBatchNorm) {
1938 int image_width = 28;
1939 int image_height = 28;
1940 int filter_height = 7;
1941 int filter_width = 7;
1942 float input[] = {
1943 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1944 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1945 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1946 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1947 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1948 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1949 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1950 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1951 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1952 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1953 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1954 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1955 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1956 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1957 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1958 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1959 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1960 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1961 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1962 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1963 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1964 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1965 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1966 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1967 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1968 0.0f, 0.0f, 0.0117647f, 0.0705882f, 0.0705882f, 0.0705882f,
1969 0.494118f, 0.533333f, 0.686275f, 0.101961f, 0.65098f, 1.0f,
1970 0.968627f, 0.498039f, 0.0f, 0.0f, 0.0f, 0.0f,
1971 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1972 0.0f, 0.0f, 0.117647f, 0.141176f, 0.368627f, 0.603922f,
1973 0.666667f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.992157f,
1974 0.882353f, 0.67451f, 0.992157f, 0.94902f, 0.764706f, 0.25098f,
1975 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1976 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.192157f,
1977 0.933333f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.992157f,
1978 0.992157f, 0.992157f, 0.992157f, 0.984314f, 0.364706f, 0.321569f,
1979 0.321569f, 0.219608f, 0.152941f, 0.0f, 0.0f, 0.0f,
1980 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1981 0.0f, 0.0f, 0.0f, 0.0705882f, 0.858824f, 0.992157f,
1982 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.776471f, 0.713725f,
1983 0.968627f, 0.945098f, 0.0f, 0.0f, 0.0f, 0.0f,
1984 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1985 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1986 0.0f, 0.0f, 0.313725f, 0.611765f, 0.419608f, 0.992157f,
1987 0.992157f, 0.803922f, 0.0431373f, 0.0f, 0.168627f, 0.603922f,
1988 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1989 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1990 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1991 0.0f, 0.054902f, 0.00392157f, 0.603922f, 0.992157f, 0.352941f,
1992 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1993 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1994 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1995 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1996 0.0f, 0.545098f, 0.992157f, 0.745098f, 0.00784314f, 0.0f,
1997 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1998 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1999 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2000 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0431373f,
2001 0.745098f, 0.992157f, 0.27451f, 0.0f, 0.0f, 0.0f,
2002 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2003 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2004 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2005 0.0f, 0.0f, 0.0f, 0.0f, 0.137255f, 0.945098f,
2006 0.882353f, 0.627451f, 0.423529f, 0.00392157f, 0.0f, 0.0f,
2007 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2008 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2009 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2010 0.0f, 0.0f, 0.0f, 0.317647f, 0.941176f, 0.992157f,
2011 0.992157f, 0.466667f, 0.0980392f, 0.0f, 0.0f, 0.0f,
2012 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2013 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2014 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2015 0.0f, 0.0f, 0.176471f, 0.729412f, 0.992157f, 0.992157f,
2016 0.588235f, 0.105882f, 0.0f, 0.0f, 0.0f, 0.0f,
2017 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2018 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2019 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2020 0.0f, 0.0627451f, 0.364706f, 0.988235f, 0.992157f, 0.733333f,
2021 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2022 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2023 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2024 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2025 0.0f, 0.976471f, 0.992157f, 0.976471f, 0.25098f, 0.0f,
2026 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2027 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2028 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2029 0.0f, 0.0f, 0.180392f, 0.509804f, 0.717647f, 0.992157f,
2030 0.992157f, 0.811765f, 0.00784314f, 0.0f, 0.0f, 0.0f,
2031 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2032 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2033 0.0f, 0.0f, 0.0f, 0.0f, 0.152941f, 0.580392f,
2034 0.898039f, 0.992157f, 0.992157f, 0.992157f, 0.980392f, 0.713725f,
2035 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2036 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2037 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2038 0.0941176f, 0.447059f, 0.866667f, 0.992157f, 0.992157f, 0.992157f,
2039 0.992157f, 0.788235f, 0.305882f, 0.0f, 0.0f, 0.0f,
2040 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2041 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2042 0.0f, 0.0f, 0.0901961f, 0.258824f, 0.835294f, 0.992157f,
2043 0.992157f, 0.992157f, 0.992157f, 0.776471f, 0.317647f, 0.00784314f,
2044 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2045 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2046 0.0f, 0.0f, 0.0f, 0.0f, 0.0705882f, 0.670588f,
2047 0.858824f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.764706f,
2048 0.313725f, 0.0352941f, 0.0f, 0.0f, 0.0f, 0.0f,
2049 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2050 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2051 0.215686f, 0.67451f, 0.886275f, 0.992157f, 0.992157f, 0.992157f,
2052 0.992157f, 0.956863f, 0.521569f, 0.0431373f, 0.0f, 0.0f,
2053 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2054 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2055 0.0f, 0.0f, 0.0f, 0.0f, 0.533333f, 0.992157f,
2056 0.992157f, 0.992157f, 0.831373f, 0.529412f, 0.517647f, 0.0627451f,
2057 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2058 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2059 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2060 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2061 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2062 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2063 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2064 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2065 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2066 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2067 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2068 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2069 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2070 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2071 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2072 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2073 0.0f, 0.0f, 0.0f, 0.0f
2074 };
2075 float expected[] = {
2076 -0.836424f, -0.857365f, -1.62739f, -1.62739f, -0.836424f, 5.40742f,
2077 0.920853f, -0.692567f, -0.836424f, -0.534405f, -1.62739f, -0.836424f,
2078 1.32602f, 1.36312f, 0.112766f, -0.836424f, -0.192962f, 1.56975f,
2079 2.45777f, 0.944414f, -0.192962f, -1.5519f, -1.5519f, -0.554006f,
2080 -0.192962f, 1.4231f, -1.5519f, -0.192962f, 1.3661f, -1.5519f,
2081 -1.5519f, -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
2082 -0.843708f, 4.53065f, 0.0429584f, -0.796804f, -0.843708f, 0.3473f,
2083 -0.843708f, -0.843708f, -0.114439f, 3.14817f, 0.0811934f, -0.843708f
2084 };
2085 float kernel[] = {
2086 0.119643f, -0.237864f, 0.0462892f, 0.0502297f, -0.0134528f,
2087 0.146347f, 0.153133f, 0.0513307f, 0.0752369f, 0.0135557f,
2088 -0.111434f, 0.0941854f, 0.0788362f, 0.0299412f, 0.111762f,
2089 0.144066f, 0.00431504f, -0.0177954f, 0.0738092f, -0.0344215f,
2090 0.0832582f, 0.053989f, -0.112691f, 0.0962145f, 0.0186525f,
2091 -0.00660205f, -0.111962f, -0.126801f, -0.231625f, 0.17309f,
2092 0.0748875f, -0.179569f, -0.00513812f, -0.156579f, -0.147322f,
2093 0.184168f, 0.189308f, -0.200359f, -0.0156733f, 0.140649f,
2094 0.0858496f, -0.0263217f, -0.0740749f, -0.112563f, 0.107528f,
2095 0.0609729f, -0.221625f, 0.0769944f, -0.00900815f, -0.00136441f,
2096 -0.0236521f, -0.0418025f, -0.00286299f, 0.12241f, 0.0964093f,
2097 -0.0150897f, 0.0532171f, 0.0625916f, 0.116939f, 0.118024f,
2098 0.161918f, -0.00909767f, 0.100897f, -0.054563f, -0.175179f,
2099 -0.0687892f, 0.00734235f, 0.109833f, -0.113776f, 0.0595405f,
2100 -0.170255f, 0.0124815f, -0.0363301f, -0.0127038f, 0.0445554f,
2101 -0.0729894f, 0.107428f, -0.0341417f, 0.132619f, 0.00984557f,
2102 -0.00443654f, 0.202929f, 0.0945134f, 0.0148725f, 0.00998574f,
2103 -0.0226449f, 0.0478197f, -0.0793442f, 0.0707599f, -0.084225f,
2104 0.0865795f, 0.071104f, -0.047894f, 0.0838322f, 0.0635493f,
2105 -0.00370265f, -0.157247f, -0.0289622f, -0.0590963f, 0.13207f,
2106 0.00468011f, -0.0345372f, 0.217939f, 0.18861f, -0.0290393f,
2107 -0.0440664f, 0.0126197f, -0.129132f, -0.124943f, 0.0968156f,
2108 -0.0853643f, -0.182305f, 0.00461618f, -0.147095f, -0.230282f,
2109 0.00856019f, 0.0278893f, -0.0300229f, 0.0417871f, 0.0804717f,
2110 -0.0768571f, -0.0397085f, -0.0601096f, 0.100901f, -0.0184926f,
2111 0.0350673f, 0.0971094f, -0.0171837f, -0.289644f, -0.0899041f,
2112 0.08998f, -0.160319f, -0.0195103f, 0.0392167f, -0.137864f,
2113 -0.0136294f, 0.0330886f, -0.0409244f, -0.092533f, -0.0427934f,
2114 -0.191144f, -0.0969461f, 0.112035f, 0.138611f, 0.128717f,
2115 0.191184f, 0.197462f
2116 };
2117 float bias[] = { 0.186703f, 0.204358f, -0.0230452f };
2118
2119 float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
2120 float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
2121 float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
2122 float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
2123
2124 CNN_BATCHNORM_PARAMS bn_params = {
2125 bn_gamma,
2126 bn_beta,
2127 bn_mean,
2128 bn_std,
2129 };
2130
2131 CNN_CONFIG cnn_config = {
2132 1,
2133 0,
2134 0,
2135 0,
2136 0,
2137 {
2138 {
2139 1,
2140 filter_width,
2141 filter_height,
2142 3,
2143 7,
2144 7,
2145 0,
2146 kernel,
2147 bias,
2148 PADDING_VALID,
2149 RELU,
2150 0,
2151 0,
2152 BRANCH_NO_COPY,
2153 BRANCH_NOC,
2154 {},
2155 bn_params,
2156 0,
2157 },
2158 },
2159 };
2160
2161 CNN_THREAD_DATA thread_data = { 1, NULL };
2162
2163 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2164 image_width, &thread_data, MSE_FLOAT_TOL);
2165 }
2166
// Runs the same single-layer network twice -- once with a single-thread
// descriptor and once with four workers -- and expects both runs to match the
// same precomputed output.
TEST_F(CNNTest, TestMultithreading) {
  const int kImageWidth = 2;
  const int kImageHeight = 2;
  const int kFilterWidth = 3;
  const int kFilterHeight = 3;
  const int kNumWorkers = 4;

  // 2x2 single-channel input.
  const float input[] = { -2, 4, 1, 0 };

  // 3x3 kernels, 1 input channel, 4 output channels (36 values).
  const float weights[] = {
    -4, 2, -2, 0,  -4, 4,  -3, -3, -3, -1, 1, 0,  -5, -3, 0, -5, 0, 0,
    -1, 0, 2,  -5, 0,  1,  4,  2,  1,  0,  -2, -1, -5, -3, 2, -2, 1, -5,
  };

  // One bias per output channel.
  const float bias[] = { -4, -3, -2, 3 };

  // 4 output channels of 2x2 values.
  const float expected[] = {
    2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
  };

  CNN_CONFIG cnn_config = {
    1,  // num_layers
    0,  // is_residue
    0,  // ext_width
    0,  // ext_height
    0,  // strict_bounds
    {
        // layer_config
        {
            1,                  // in_channels
            kFilterWidth,       // filter_width
            kFilterHeight,      // filter_height
            4,                  // out_channels
            1,                  // skip_width
            1,                  // skip_height
            0,                  // max_pool
            weights,            // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            NONE,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            0,                  // output_num
        },
    },
  };

  // First pass: one thread, no worker array.
  CNN_THREAD_DATA single_thread = { 1, nullptr };
  RunCNNTest(kImageWidth, kImageHeight, input, expected, &cnn_config,
             kImageWidth, &single_thread, MSE_FLOAT_TOL);

  // Second pass: hand the network a pool of initialized workers.
  const AVxWorkerInterface *const winterface = aom_get_worker_interface();
  AVxWorker workers[kNumWorkers];
  for (AVxWorker &worker : workers) {
    winterface->init(&worker);
  }

  CNN_THREAD_DATA multi_thread = { kNumWorkers, workers };
  RunCNNTest(kImageWidth, kImageHeight, input, expected, &cnn_config,
             kImageWidth, &multi_thread, MSE_FLOAT_TOL);

  for (AVxWorker &worker : workers) {
    winterface->end(&worker);
  }
}
2247
// Runs a 4-layer network in which every layer writes to its own output slot
// (output_num 0..3) and compares each of the four outputs against precomputed
// values. Layer 0 uses BRANCH_OUTPUT and layer 3 runs on branch 1 with
// BRANCH_CAT, so the last output is expected to carry twice as many channels
// as the others.
TEST_F(CNNTest, TestMultiOutput) {
  const int image_dim = 8;
  const int image_ch = 3;
  const int filter_dim = 2;
  const int stride = 2;
  const int num_filters = 2;

  // image_ch planes of image_dim x image_dim values, stored back to back.
  const float input_[] = {
    1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
    0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
    -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
    0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
    0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
    -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
    -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
    -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
    -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
    -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
    -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
    0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
    0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
    -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
    0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
    0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
    -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
    -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
    -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
    0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
    0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
    -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
    0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
    0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
    1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
    -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
    0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
    -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
    0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
    0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
    -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
    0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
    -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
    1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
    -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
    -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
    -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
    0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
    0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
    -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
    1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
    0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
    0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
    -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
    1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
    -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
    1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
    1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
  };
  // One pointer per input plane.
  const float *input[3] = { input_, &input_[image_dim * image_dim],
                            &input_[2 * image_dim * image_dim] };

  // All four layers share the same zero bias.
  const float bias[] = { 0.0f, 0.0f };

  const float weights_1[] = {
    -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
    0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
    -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
    0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
    -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
    -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
  };

  const float weights_2[] = {
    0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
    0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
    -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
    0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
  };

  const float weights_3[] = {
    -0.0690452401448f, -0.356657338763f,   -0.219464031809f, 0.551288365843f,
    0.181372090853f,   -0.00245268542109f, 0.409000696276f,  -0.593209108763f,
    0.587352566749f,   -0.243720660227f,   0.266232713887f,  -0.00439285245097f,
    0.252883228305f,   0.152646192631f,    0.0918944932026f, 0.398853715057f,
  };

  const float weights_4[] = {
    0.207560791573f,   0.194201350401f,   0.227802322443f,  0.206533663345f,
    0.0557331066805f,  0.0224159800424f,  -0.143939197467f, -0.27703361602f,
    0.130643888389f,   -0.269456557461f,  0.186242862864f,  -0.162879944774f,
    -0.145503996718f,  -0.0768822987581f, -0.203127976359f, -0.238119922873f,
    -0.258806479994f,  0.0357957680385f,  -0.1027606976f,   -0.287920082345f,
    0.189047820993f,   0.250711538481f,   -0.272815714175f, -0.0431449742024f,
    0.207261230996f,   -0.0396472677451f, 0.131236557412f,  0.174291832499f,
    -0.251515885765f,  -0.107164007499f,  0.185824534748f,  -0.00561585838161f,
    0.273393799578f,   -0.139563699075f,  -0.263922456031f, -0.118859844081f,
    0.109230982597f,   -0.170170294794f,  0.0123025648515f, -0.0839368964355f,
    -0.0774058234297f, 0.255847138286f,   -0.208430879637f, 0.279170114319f,
    -0.272890330712f,  -0.217725903006f,  -0.295923275459f, -0.17008723953f,
    -0.284281803405f,  0.281406323629f,   0.266910044663f,  -0.209963914338f,
    0.271980962964f,   0.142013581699f,   -0.143896509026f, -0.290509242975f,
    -0.305768180935f,  0.196902832117f,   -0.090424189662f, -0.147460802346f,
    0.217722016651f,   0.12353848977f,    -0.169177363577f, -0.0454230918512f,
  };

  const float expected_0[] = {
    -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
    -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
    -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
    0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
    -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
    -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
    -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
    -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
  };

  const float expected_1[] = {
    0.0f, 0.0f, 0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
    0.0f, 1.22013465602f,
  };

  const float expected_2[] = {
    0.156119444687f,
    0.517385299817f,
  };

  const float expected_3[] = {
    0.224177852984f,
    0.503384419034f,
    0.156119444687f,
    0.517385299817f,
  };

  const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };

  CNN_CONFIG cnn_config = {
    4,  // num_layers
    0,  // is_residue
    0,  // ext_width
    0,  // ext_height
    0,  // strict_bounds
    {
        // layer_config
        {
            image_ch,           // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_1,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            NONE,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_OUTPUT,      // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            { 2, 0, 0 },        // branch_config
            {},                 // bn_params
            0,                  // output_num
        },
        {
            num_filters,        // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_2,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            RELU,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            1,                  // output_num
        },
        {
            num_filters,        // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_3,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            RELU,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            2,                  // output_num
        },
        {
            num_filters,     // in_channels
            2 * filter_dim,  // filter_width
            2 * filter_dim,  // filter_height
            num_filters,     // out_channels
            2 * stride,      // skip_width
            2 * stride,      // skip_height
            0,               // max_pool
            weights_4,       // weights
            bias,            // bias
            PADDING_VALID,   // pad
            RELU,            // activation
            0,               // deconvolve
            1,               // branch
            BRANCH_NO_COPY,  // branch_copy_type
            BRANCH_CAT,      // branch_combine_type
            { 0, 0, 1 },     // branch_config
            {},              // bn_params
            3,               // output_num
        },
    },
  };

  CNN_THREAD_DATA thread_data = { 1, NULL };

  const int num_outputs = 4;
  // Channel counts per output. These are filter counts, not filter sizes; the
  // last output holds two concatenated branches.
  const int output_chs[4] = { num_filters, num_filters, num_filters,
                              2 * num_filters };
  // Each output plane is output_dims[i] x output_dims[i].
  const int output_dims[4] = { 4, 2, 1, 1 };
  const int output_sizes[4] = {
    output_chs[0] * output_dims[0] * output_dims[0],
    output_chs[1] * output_dims[1] * output_dims[1],
    output_chs[2] * output_dims[2] * output_dims[2],
    output_chs[3] * output_dims[3] * output_dims[3],
  };
  const int total_output_size =
      output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3];

  // One contiguous buffer backs every channel of every output.
  float *const output_ =
      static_cast<float *>(aom_malloc(sizeof(*output_) * total_output_size));
  // Stop here on allocation failure instead of writing through a null pointer.
  ASSERT_NE(output_, nullptr);

  // Carve the buffer into per-channel planes, output after output.
  float *output[CNN_MAX_CHANNELS] = { nullptr };
  int ch_ite = 0;
  float *output_ite = output_;
  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
    const int plane_size = output_dims[output_idx] * output_dims[output_idx];
    for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
      output[ch_ite++] = output_ite;
      output_ite += plane_size;
    }
  }
  CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
                                  output };

  RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
                     &thread_data, &output_struct, expected, MSE_FLOAT_TOL);

  aom_free(output_);
}
2505
2506 namespace {
2507
2508 typedef void (*CNNConvolveNoMaxpoolPaddingValidFunc)(
2509 const float **input, int in_width, int in_height, int in_stride,
2510 const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
2511 int start_idx, int cstep, int channel_step);
2512
2513 typedef libaom_test::FuncParam<CNNConvolveNoMaxpoolPaddingValidFunc>
2514 CNNConvolveTestFuncs;
2515
2516 class CNNConvolveTest : public ::testing::TestWithParam<CNNConvolveTestFuncs> {
2517 protected:
SetUp()2518 virtual void SetUp() { params_ = GetParam(); }
2519
RunCNNConvolveSetup(int run_times)2520 void RunCNNConvolveSetup(int run_times) {
2521 int in_width = 65;
2522 int in_height = 65;
2523
2524 const CNN_CONFIG *cnn_config = &av1_intra_mode_cnn_partition_cnn_config;
2525
2526 for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
2527 int out_width = 0, out_height = 0;
2528 int in_size = in_width * in_height;
2529 // Get current layer output width and height.
2530 av1_find_cnn_layer_output_size(in_height, in_width,
2531 &cnn_config->layer_config[layer],
2532 &out_width, &out_height);
2533
2534 int out_size = out_width * out_height;
2535 float *input[20], *output_ref[20], *output_mod[20];
2536
2537 float *input_data =
2538 (float *)aom_malloc(sizeof(*input_data) * in_size *
2539 cnn_config->layer_config[layer].in_channels);
2540 float *temp_ptr = input_data;
2541 for (int i = 0; i < cnn_config->layer_config[layer].in_channels; ++i) {
2542 input[i] = temp_ptr;
2543 for (int j = 0; j < in_size; j++) {
2544 *(temp_ptr++) = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
2545 }
2546 }
2547
2548 float *out_data_ref = (float *)aom_calloc(
2549 sizeof(*out_data_ref),
2550 out_size * cnn_config->layer_config[layer].out_channels);
2551 float *out_data_mod = (float *)aom_calloc(
2552 sizeof(*out_data_mod),
2553 out_size * cnn_config->layer_config[layer].out_channels);
2554 float *temp_ptr1 = out_data_ref;
2555 float *temp_ptr2 = out_data_mod;
2556 for (int i = 0; i < cnn_config->layer_config[layer].out_channels; ++i) {
2557 output_ref[i] = temp_ptr1;
2558 output_mod[i] = temp_ptr2;
2559 temp_ptr1 += out_size;
2560 temp_ptr2 += out_size;
2561 }
2562
2563 RunCNNConvolveTest(input, in_width, in_height, out_size,
2564 &cnn_config->layer_config[layer], 0, 1, run_times,
2565 layer, output_ref, output_mod, out_width);
2566
2567 // Set current layer output width and height as next layer input width and
2568 // height.
2569 in_width = out_width;
2570 in_height = out_height;
2571
2572 aom_free(input_data);
2573 aom_free(out_data_ref);
2574 aom_free(out_data_mod);
2575 }
2576 }
2577
RunCNNConvolveTest(float ** input,int in_width,int in_height,int out_size,const CNN_LAYER_CONFIG * layer_config,int start_idx,int step,int run_times,int layer,float ** output_ref,float ** output_mod,int out_stride)2578 void RunCNNConvolveTest(float **input, int in_width, int in_height,
2579 int out_size, const CNN_LAYER_CONFIG *layer_config,
2580 int start_idx, int step, int run_times, int layer,
2581 float **output_ref, float **output_mod,
2582 int out_stride) {
2583 const int cstep = layer_config->in_channels * layer_config->out_channels;
2584 const int channel_step = AOMMAX(step, 1);
2585 aom_usec_timer timer;
2586 aom_usec_timer_start(&timer);
2587 for (int i = 0; i < run_times; ++i) {
2588 params_.ref_func((const float **)input, in_width, in_height, in_width,
2589 layer_config, output_ref, out_stride, start_idx, cstep,
2590 channel_step);
2591 }
2592 aom_usec_timer_mark(&timer);
2593 const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2594
2595 aom_usec_timer_start(&timer);
2596 for (int i = 0; i < run_times; ++i) {
2597 params_.tst_func((const float **)input, in_width, in_height, in_width,
2598 layer_config, output_mod, out_stride, start_idx, cstep,
2599 channel_step);
2600 }
2601 aom_usec_timer_mark(&timer);
2602 const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2603
2604 if (run_times > 1) {
2605 printf("layer : %d \n", layer);
2606 printf("%7.2f/%7.2fns (%3.2f)\n", time1, time2, time1 / time2);
2607 } else {
2608 for (int channel = 0; channel < layer_config->out_channels; ++channel) {
2609 const float *buf_ref = output_ref[channel];
2610 const float *buf_mod = output_mod[channel];
2611
2612 for (int i = 0; i < out_size; ++i) {
2613 if (buf_ref[i] < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL) {
2614 ASSERT_LE(buf_ref[i], CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2615 << "Reference output was near-zero, test output was not ("
2616 << buf_mod[i] << ")";
2617 } else {
2618 const float error = buf_ref[i] - buf_mod[i];
2619 const float relative_error = fabsf(error / buf_ref[i]);
2620 ASSERT_LE(relative_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2621 << " channel " << channel << " pixel " << i << ": "
2622 << buf_ref[i] << "/" << buf_mod[i] << std::endl;
2623 }
2624 }
2625 }
2626 }
2627 }
2628
2629 private:
2630 CNNConvolveTestFuncs params_;
2631 libaom_test::ACMRandom rng_;
2632 };
2633 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest);
2634
// Functional check: a single run per layer, which makes the fixture compare
// the two implementations' outputs instead of printing timings.
TEST_P(CNNConvolveTest, CheckOutput) {
  const int kRunTimes = 1;
  RunCNNConvolveSetup(kRunTimes);
}
2636
// Benchmark (disabled by default): many runs per layer, which makes the
// fixture print timings instead of comparing outputs.
TEST_P(CNNConvolveTest, DISABLED_Speed) {
  const int kRunTimes = 100000;
  RunCNNConvolveSetup(kRunTimes);
}
2638
#if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
// Reference = C implementation, tested = AVX2 implementation.
const CNNConvolveTestFuncs kCNNConvolveAvx2Funcs(
    &av1_cnn_convolve_no_maxpool_padding_valid_c,
    &av1_cnn_convolve_no_maxpool_padding_valid_avx2);

INSTANTIATE_TEST_SUITE_P(AVX2, CNNConvolveTest,
                         ::testing::Values(kCNNConvolveAvx2Funcs));
#endif
2645
2646 } // namespace
2647