#include "onnx_converter.h"

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <random>
4
// Minimal assertion helpers for this standalone test binary (no test framework
// is linked). On failure they print the failing check and its source location
// to stderr, then terminate the process with a non-zero exit status.
//
// Each expands to a single `do { ... } while (0)` statement so it composes
// safely with surrounding control flow, e.g. `if (c) EXPECT_EQ(a, b); else ...`.
// Every argument is evaluated exactly once.

// Fails unless (a) == (b), i.e. when (a) != (b).
#define EXPECT_EQ(a, b)                                               \
    do {                                                              \
        if ((a) != (b)) {                                             \
            std::fprintf(stderr, "%s:%d: EXPECT_EQ(%s, %s) failed\n", \
                         __FILE__, __LINE__, #a, #b);                 \
            std::exit(EXIT_FAILURE);                                  \
        }                                                             \
    } while (0)

// Fails unless |(a) - (b)| <= (c). The check is written as a negated `<=` so a
// NaN difference (e.g. a NaN output) is reported as a failure; the plain
// `> (c)` form evaluates to false for NaN and would let it pass.
#define EXPECT_NEAR(a, b, c)                                                \
    do {                                                                    \
        if (!(std::abs((a) - (b)) <= (c))) {                                \
            std::fprintf(stderr, "%s:%d: EXPECT_NEAR(%s, %s, %s) failed\n", \
                         __FILE__, __LINE__, #a, #b, #c);                   \
            std::exit(EXIT_FAILURE);                                        \
        }                                                                   \
    } while (0)
13
test_abs()14 static void test_abs() {
15 onnx::NodeProto abs_node;
16 abs_node.set_name("abs_node");
17 abs_node.set_op_type("Abs");
18 abs_node.add_input("x");
19 abs_node.add_output("y");
20
21 std::vector<Tensor> node_inputs;
22 node_inputs.resize(1);
23 node_inputs[0].shape = {200};
24 Halide::Buffer<float> input(200);
25 std::uniform_real_distribution<float> dis(-1.0, 1.0);
26 std::mt19937 rnd;
27 input.for_each_value([&](float &f) { f = dis(rnd); });
28 Halide::Var index;
29 node_inputs[0].rep(index) = input(index);
30
31 Node converted = convert_node(abs_node, node_inputs);
32
33 GOOGLE_CHECK_EQ(1, converted.outputs.size());
34 Halide::Buffer<float> output = converted.outputs[0].rep.realize(200);
35 for (int i = 0; i < 200; ++i) {
36 EXPECT_EQ(output(i), std::abs(input(i)));
37 }
38 }
39
test_activation_function()40 static void test_activation_function() {
41 onnx::NodeProto relu_node;
42 relu_node.set_name("relu_node");
43 relu_node.set_op_type("Relu");
44 relu_node.add_input("x");
45 relu_node.add_output("y");
46
47 std::vector<Tensor> node_inputs;
48 node_inputs.resize(1);
49 node_inputs[0].shape = {200};
50 Halide::Buffer<float> input(200);
51 std::mt19937 rnd;
52 std::uniform_real_distribution<float> dis(-1.0, 1.0);
53 input.for_each_value([&](float &f) { f = dis(rnd); });
54 Halide::Var index;
55 node_inputs[0].rep(index) = input(index);
56
57 Node converted = convert_node(relu_node, node_inputs);
58
59 GOOGLE_CHECK_EQ(1, converted.outputs.size());
60 Halide::Buffer<float> output = converted.outputs[0].rep.realize(200);
61 for (int i = 0; i < 200; ++i) {
62 EXPECT_EQ(output(i), std::max(input(i), 0.0f));
63 }
64 }
65
test_cast()66 static void test_cast() {
67 onnx::NodeProto cast_node;
68 cast_node.set_name("relu_node");
69 cast_node.set_op_type("Cast");
70 cast_node.add_input("x");
71 cast_node.add_output("y");
72
73 std::vector<Tensor> node_inputs;
74 onnx::AttributeProto *attr = cast_node.add_attribute();
75 attr->set_name("to");
76 attr->set_i(onnx::TensorProto_DataType_FLOAT);
77 node_inputs.resize(1);
78 node_inputs[0].shape = {200};
79 Halide::Buffer<int> input(200);
80 std::mt19937 rnd;
81 std::uniform_int_distribution<int> dis(-100, 100);
82 input.for_each_value([&](int &f) { f = dis(rnd); });
83 Halide::Var index;
84 node_inputs[0].rep(index) = input(index);
85
86 Node converted = convert_node(cast_node, node_inputs);
87
88 GOOGLE_CHECK_EQ(1, converted.outputs.size());
89 Halide::Buffer<float> output = converted.outputs[0].rep.realize(200);
90 for (int i = 0; i < 200; ++i) {
91 EXPECT_EQ(output(i), static_cast<float>(input(i)));
92 }
93 }
94
test_add()95 static void test_add() {
96 onnx::NodeProto add_node;
97 add_node.set_name("add_node");
98 add_node.set_op_type("Add");
99 add_node.add_input("x");
100 add_node.add_input("y");
101 add_node.add_output("z");
102
103 std::vector<Tensor> node_inputs;
104 node_inputs.resize(2);
105 node_inputs[0].shape = {200};
106 node_inputs[1].shape = node_inputs[0].shape;
107 Halide::Buffer<float> in1(200);
108 std::mt19937 rnd;
109 std::uniform_real_distribution<float> dis(-1.0, 1.0);
110 std::uniform_real_distribution<float> dis10(-10.0, 10.0);
111 in1.for_each_value([&](float &f) { f = dis(rnd); });
112 Halide::Buffer<float> in2(200);
113 in2.for_each_value([&](float &f) { f = dis10(rnd); });
114 Halide::Var index;
115 node_inputs[0].rep(index) = in1(index);
116 node_inputs[1].rep(index) = in2(index);
117
118 Node converted = convert_node(add_node, node_inputs);
119
120 GOOGLE_CHECK_EQ(1, converted.outputs.size());
121 Halide::Buffer<float> output = converted.outputs[0].rep.realize(200);
122 for (int i = 0; i < 200; ++i) {
123 EXPECT_NEAR(output(i), in1(i) + in2(i), 1e-6);
124 }
125 }
126
test_constant()127 static void test_constant() {
128 onnx::NodeProto add_node;
129 add_node.set_name("constant_node");
130 add_node.set_op_type("Constant");
131 add_node.add_output("y");
132 onnx::AttributeProto *attr = add_node.add_attribute();
133 attr->set_name("value");
134
135 onnx::TensorProto &value = *attr->mutable_t();
136 value.set_data_type(onnx::TensorProto_DataType_FLOAT);
137 value.add_dims(3);
138 value.add_dims(7);
139 std::mt19937 rnd;
140 std::uniform_real_distribution<float> dis(-10.0, 10.0);
141 for (int i = 0; i < 3 * 7; ++i) {
142 value.add_float_data(dis(rnd));
143 }
144
145 Node converted = convert_node(add_node, {});
146
147 GOOGLE_CHECK_EQ(1, converted.outputs.size());
148 Halide::Buffer<float> output = converted.outputs[0].rep.realize({3, 7});
149 for (int i = 0; i < 3; ++i) {
150 for (int j = 0; j < 7; ++j) {
151 EXPECT_EQ(output(i, j), value.float_data(j + 7 * i));
152 }
153 }
154 }
155
test_gemm()156 static void test_gemm() {
157 onnx::NodeProto add_node;
158 add_node.set_name("gemm_node");
159 add_node.set_op_type("Gemm");
160 add_node.add_input("a");
161 add_node.add_input("b");
162 add_node.add_input("c");
163 add_node.add_output("y");
164
165 std::vector<Tensor> node_inputs;
166 node_inputs.resize(3);
167 node_inputs[0].shape = {32, 100};
168 node_inputs[1].shape = {100, 64};
169 node_inputs[2].shape = {32, 64};
170
171 std::uniform_real_distribution<float> dis(-1.0, 1.0);
172 std::uniform_real_distribution<float> dis10(-10.0, 10.0);
173
174 std::mt19937 rnd;
175 Halide::Buffer<float> in1(32, 100);
176 in1.for_each_value([&](float &f) { f = dis(rnd); });
177 Halide::Buffer<float> in2(100, 64);
178 in2.for_each_value([&](float &f) { f = dis10(rnd); });
179 Halide::Buffer<float> in3(32, 64);
180 in3.for_each_value([&](float &f) { f = dis(rnd); });
181 Halide::Var i1, j1;
182 node_inputs[0].rep(i1, j1) = in1(i1, j1);
183 Halide::Var i2, j2;
184 node_inputs[1].rep(i2, j2) = in2(i2, j2);
185 Halide::Var i3, j3;
186 node_inputs[2].rep(i3, j3) = in3(i3, j3);
187 Node converted = convert_node(add_node, node_inputs);
188
189 GOOGLE_CHECK_EQ(1, converted.outputs.size());
190 Halide::Buffer<float> output = converted.outputs[0].rep.realize(32, 64);
191
192 for (int i = 0; i < 32; ++i) {
193 for (int j = 0; j < 64; ++j) {
194 float expected = in3(i, j);
195 for (int k = 0; k < 100; ++k) {
196 expected += in1(i, k) * in2(k, j);
197 }
198 EXPECT_NEAR(output(i, j), expected, 5e-5f);
199 }
200 }
201 }
202
test_conv()203 static void test_conv() {
204 onnx::NodeProto add_node;
205 add_node.set_name("conv_node");
206 add_node.set_op_type("Conv");
207 add_node.add_input("x");
208 add_node.add_input("w");
209 add_node.add_output("y");
210
211 std::vector<Tensor> node_inputs;
212 node_inputs.resize(2);
213 node_inputs[0].shape = {3, 5, 6, 6};
214 node_inputs[1].shape = {7, 5, 3, 3};
215
216 std::uniform_real_distribution<float> dis(-1.0, 1.0);
217 std::uniform_real_distribution<float> dis10(-10.0, 10.0);
218
219 std::mt19937 rnd;
220 Halide::Buffer<float> weights(7, 5, 3, 3);
221 weights.for_each_value([&](float &f) { f = dis10(rnd); });
222 Halide::Var i2, j2, k2, l2;
223 node_inputs[1].rep(i2, j2, k2, l2) = weights(i2, j2, k2, l2);
224
225 const std::vector<int> in_shape[2] = {{3, 5, 6, 11}, {3, 5, 10, 14}};
226 const std::vector<int> out_shape[2] = {{3, 7, 4, 9}, {3, 7, 8, 12}};
227
228 for (int trial = 0; trial < 2; ++trial) {
229 node_inputs[0].shape.resize(4);
230 for (int dim = 0; dim < 4; ++dim) {
231 node_inputs[0].shape[dim] = in_shape[trial][dim];
232 }
233
234 Halide::Buffer<float> in(in_shape[trial]);
235 in.for_each_value([&](float &f) { f = dis(rnd); });
236 Halide::Var i1, j1, k1, l1;
237 node_inputs[0].rep = Halide::Func();
238 node_inputs[0].rep(i1, j1, k1, l1) = in(i1, j1, k1, l1);
239
240 Node converted = convert_node(add_node, node_inputs);
241
242 GOOGLE_CHECK_EQ(1, converted.outputs.size());
243 Halide::Buffer<float> output =
244 converted.outputs[0].rep.realize(out_shape[trial]);
245
246 for (int i = 0; i < 3; ++i) {
247 for (int j = 0; j < 7; ++j) {
248 for (int k = 0; k < out_shape[trial][2]; ++k) {
249 for (int l = 0; l < out_shape[trial][3]; ++l) {
250 float expected = 0;
251 for (int c = 0; c < 5; ++c) {
252 for (int w = 0; w < 3; ++w) {
253 for (int h = 0; h < 3; ++h) {
254 expected += in(i, c, k + w, l + h) * weights(j, c, w, h);
255 }
256 }
257 }
258 EXPECT_NEAR(output(i, j, k, l), expected, 5e-4f);
259 }
260 }
261 }
262 }
263 }
264 }
265
test_sum()266 static void test_sum() {
267 onnx::NodeProto sum_node;
268 sum_node.set_name("sum_node");
269 sum_node.set_op_type("ReduceSum");
270 sum_node.add_input("x");
271 sum_node.add_output("y");
272
273 onnx::AttributeProto *attr = sum_node.add_attribute();
274 attr->set_name("axes");
275 attr->add_ints(0);
276 attr->add_ints(2);
277
278 std::vector<Tensor> node_inputs;
279 node_inputs.resize(1);
280 node_inputs[0].shape = {7, 3, 5, 11};
281 Halide::Buffer<float> in1(7, 3, 5, 11);
282 std::uniform_real_distribution<float> dis(-1.0, 1.0);
283 std::mt19937 rnd;
284 in1.for_each_value([&](float &f) { f = dis(rnd); });
285 Halide::Var i, j, k, l;
286 node_inputs[0].rep(i, j, k, l) = in1(i, j, k, l);
287
288 Node converted = convert_node(sum_node, node_inputs);
289
290 GOOGLE_CHECK_EQ(1, converted.outputs.size());
291 Halide::Buffer<float> output = converted.outputs[0].rep.realize(1, 3, 1, 11);
292 for (int i = 0; i < 3; ++i) {
293 for (int j = 0; j < 11; ++j) {
294 float expected = 0.0f;
295 for (int k = 0; k < 7; ++k) {
296 for (int l = 0; l < 5; ++l) {
297 expected += in1(k, i, l, j);
298 }
299 }
300 EXPECT_NEAR(expected, output(0, i, 0, j), 1e-5);
301 }
302 }
303 }
304
test_where_broadcast()305 static void test_where_broadcast() {
306 onnx::NodeProto where_node;
307 where_node.set_name("where_node");
308 where_node.set_op_type("Where");
309 where_node.add_input("c");
310 where_node.add_input("x");
311 where_node.add_input("y");
312 where_node.add_output("z");
313
314 std::vector<Tensor> node_inputs;
315 node_inputs.resize(3);
316 node_inputs[0].shape = {2, 2, 2};
317 node_inputs[1].shape = {2};
318 node_inputs[2].shape = {2, 2};
319 Halide::Buffer<bool> in_c(2, 2, 2);
320 in_c.for_each_element(
321 [&](int x, int y, int z) { in_c(x, y, z) = (x == y && x == z); });
322 Halide::Buffer<float> in_x(2);
323 Halide::Buffer<float> in_y(2, 2);
324 std::uniform_real_distribution<float> dis(-1.0, 1.0);
325 std::mt19937 rnd;
326 in_x.for_each_value([&](float &f) { f = dis(rnd); });
327 in_y.for_each_value([&](float &f) { f = dis(rnd); });
328 Halide::Var i("i"), j("j"), k("k");
329 node_inputs[0].rep(i, j, k) = in_c(i, j, k);
330 node_inputs[1].rep(i) = in_x(i);
331 node_inputs[2].rep(i, j) = in_y(i, j);
332
333 Node converted = convert_node(where_node, node_inputs);
334 GOOGLE_CHECK_EQ(1, converted.outputs.size());
335 Halide::Buffer<float> output = converted.outputs[0].rep.realize(2, 2, 2);
336
337 for (int i = 0; i < 2; ++i) {
338 for (int j = 0; j < 2; ++j) {
339 for (int k = 0; k < 2; ++k) {
340 if (in_c(i, j, k)) {
341 EXPECT_EQ(output(i, j, k), in_x(k));
342 } else {
343 EXPECT_EQ(output(i, j, k), in_y(j, k));
344 }
345 }
346 }
347 }
348 }
349
test_concat()350 static void test_concat() {
351 onnx::NodeProto concat_node;
352 concat_node.set_name("concat_node");
353 concat_node.set_op_type("Concat");
354 concat_node.add_input("x");
355 concat_node.add_output("y");
356
357 onnx::AttributeProto *attr = concat_node.add_attribute();
358 attr->set_name("axis");
359 attr->add_ints(0);
360
361 std::vector<Tensor> node_inputs;
362 node_inputs.resize(2);
363 node_inputs[0].shape = {7, 3};
364 Halide::Buffer<float> in1(7, 3);
365 std::uniform_real_distribution<float> dis(-1.0, 1.0);
366 std::mt19937 rnd;
367 in1.for_each_value([&](float &f) { f = dis(rnd); });
368 Halide::Var i, j;
369 node_inputs[0].rep(i, j) = in1(i, j);
370
371 node_inputs[1].shape = {5, 3};
372 Halide::Buffer<float> in2(5, 3);
373 in2.for_each_value([&](float &f) { f = dis(rnd); });
374 node_inputs[1].rep(i, j) = in2(i, j);
375
376 Node converted = convert_node(concat_node, node_inputs);
377
378 GOOGLE_CHECK_EQ(1, converted.outputs.size());
379 Halide::Buffer<float> output = converted.outputs[0].rep.realize(7 + 5, 3);
380 for (int i = 0; i < 3; ++i) {
381 for (int j = 0; j < 7; ++j) {
382 EXPECT_EQ(in1(j, i), output(j, i));
383 }
384 for (int j = 0; j < 5; ++j) {
385 EXPECT_EQ(in2(j, i), output(j + 7, i));
386 }
387 }
388 }
389
test_constant_fill()390 static void test_constant_fill() {
391 constexpr float const_value = 2.0f;
392 onnx::NodeProto concat_node;
393 concat_node.set_name("constant_fill_node");
394 concat_node.set_op_type("ConstantFill");
395 concat_node.add_output("y");
396 onnx::AttributeProto *shape_attr = concat_node.add_attribute();
397 shape_attr->set_name("shape");
398 shape_attr->add_ints(3);
399 shape_attr->add_ints(4);
400 onnx::AttributeProto *val_attr = concat_node.add_attribute();
401 val_attr->set_name("value");
402 val_attr->set_f(const_value);
403 onnx::AttributeProto *dtype_attr = concat_node.add_attribute();
404 dtype_attr->set_name("dtype");
405 dtype_attr->set_i(4);
406
407 Node converted = convert_node(concat_node, {});
408 GOOGLE_CHECK_EQ(1, converted.outputs.size());
409 Halide::Buffer<uint16_t> output = converted.outputs[0].rep.realize(3, 4);
410 for (int i = 0; i < 3; ++i) {
411 for (int j = 0; j < 4; ++j) {
412 EXPECT_EQ(2u, output(i, j));
413 }
414 }
415 }
416
test_model()417 static void test_model() {
418 onnx::ModelProto model;
419 onnx::ValueInfoProto *input_def = model.mutable_graph()->add_input();
420 input_def->set_name("model_input");
421 input_def->mutable_type()->mutable_tensor_type()->set_elem_type(
422 onnx::TensorProto_DataType_FLOAT);
423 input_def->mutable_type()
424 ->mutable_tensor_type()
425 ->mutable_shape()
426 ->add_dim()
427 ->set_dim_value(3);
428 input_def->mutable_type()
429 ->mutable_tensor_type()
430 ->mutable_shape()
431 ->add_dim()
432 ->set_dim_value(7);
433
434 model.mutable_graph()->add_output()->set_name("model_output");
435 model.mutable_graph()->add_output()->set_name("output_shape");
436 model.mutable_graph()->add_output()->set_name("output_size");
437
438 onnx::NodeProto *first_node = model.mutable_graph()->add_node();
439 first_node->set_name("exp_of_input");
440 first_node->set_op_type("Exp");
441 first_node->add_input("model_input");
442 first_node->add_output("input_exp");
443
444 onnx::NodeProto *second_node = model.mutable_graph()->add_node();
445 second_node->set_name("log_of_exp");
446 second_node->set_op_type("Log");
447 second_node->add_input("input_exp");
448 second_node->add_output("log_exp");
449
450 onnx::NodeProto *third_node = model.mutable_graph()->add_node();
451 third_node->set_name("sum");
452 third_node->set_op_type("Add");
453 third_node->add_input("input_exp");
454 third_node->add_input("log_exp");
455 third_node->add_output("model_output");
456
457 onnx::NodeProto *fourth_node = model.mutable_graph()->add_node();
458 fourth_node->set_name("shape");
459 fourth_node->set_op_type("Shape");
460 fourth_node->add_input("model_output");
461 fourth_node->add_output("output_shape");
462
463 onnx::NodeProto *fifth_node = model.mutable_graph()->add_node();
464 fifth_node->set_name("size");
465 fifth_node->set_op_type("Size");
466 fifth_node->add_input("model_output");
467 fifth_node->add_output("output_size");
468
469 std::unordered_map<std::string, int> dummy;
470 Model converted = convert_model(model, dummy, IOLayout::Native);
471
472 Halide::Buffer<float> input_values(3, 7);
473 std::uniform_real_distribution<float> dis(-1.0, 1.0);
474 std::mt19937 rnd;
475 input_values.for_each_value([&](float &f) { f = dis(rnd); });
476
477 Halide::ImageParam &input = converted.inputs.at("model_input");
478 input.set(input_values);
479 Tensor node = converted.outputs.at("model_output");
480 Halide::Buffer<float> output_values = node.rep.realize({3, 7});
481
482 for (int i = 0; i < 3; ++i) {
483 for (int j = 0; j < 7; ++j) {
484 float expected =
485 std::exp(input_values(i, j)) + std::log(std::exp(input_values(i, j)));
486 float actual = output_values(i, j);
487 EXPECT_NEAR(actual, expected, 1e-6f);
488 }
489 }
490
491 Tensor size = converted.outputs.at("output_size");
492 Halide::Buffer<int64_t> output_size = size.rep.realize();
493 EXPECT_EQ(21, output_size());
494
495 Tensor shape = converted.outputs.at("output_shape");
496 Halide::Buffer<int64_t> output_shape = shape.rep.realize(2);
497 EXPECT_EQ(3, output_shape(0));
498 EXPECT_EQ(7, output_shape(1));
499 }
500
main()501 int main() {
502 test_abs();
503 test_activation_function();
504 test_cast();
505 test_add();
506 test_constant();
507 test_gemm();
508 test_conv();
509 test_sum();
510 test_where_broadcast();
511 test_concat();
512 test_constant_fill();
513 test_model();
514 printf("Success!\n");
515 return 0;
516 }
517