1 //
2 // resnetExpr.cpp
3 // MNN
4 // Reference paper: https://arxiv.org/pdf/1512.03385.pdf
5 //
6 // Created by MNN on 2019/06/25.
7 // Copyright © 2018, Alibaba Group Holding Limited
8 //
9
10 #include <map>
11 #include <vector>
12 #include "ResNetExpr.hpp"
13 #include <MNN/expr/ExprCreator.hpp>
14
15 using namespace MNN::Express;
16
17 // When we use MNNConverter to convert other resnet model to MNN model,
18 // {Conv + BN + Relu} will be converted and optimized to {Conv}
residual(VARP x,INTS channels,int stride)19 static VARP residual(VARP x, INTS channels, int stride) {
20 int inputChannel = x->getInfo()->dim[1], outputChannel = channels[1];
21 auto y = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {3, 3}, SAME, {stride, stride}, {1, 1}, 1);
22 y = _Conv(0.0f, 0.0f, y, {outputChannel, outputChannel}, {3, 3}, SAME, {1, 1}, {1, 1}, 1);
23 if (inputChannel != outputChannel || stride != 1) {
24 x = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
25 }
26 y = _Add(x, y);
27 return y;
28 }
29
residualBlock(VARP x,INTS channels,int stride,int number)30 static VARP residualBlock(VARP x, INTS channels, int stride, int number) {
31 x = residual(x, {channels[0], channels[1]}, stride);
32 for (int i = 1; i < number; ++i) {
33 x = residual(x, {channels[1], channels[1]}, 1);
34 }
35 return x;
36 }
37
bottleNeck(VARP x,INTS channels,int stride)38 static VARP bottleNeck(VARP x, INTS channels, int stride) {
39 int inputChannel = x->getInfo()->dim[1], narrowChannel = channels[1], outputChannel = channels[2];
40 auto y = _Conv(0.0f, 0.0f, x, {inputChannel, narrowChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
41 y = _Conv(0.0f, 0.0f, y, {narrowChannel, narrowChannel}, {3, 3}, SAME, {1, 1}, {1, 1}, 1);
42 y = _Conv(0.0f, 0.0f, y, {narrowChannel, outputChannel}, {1, 1}, VALID, {1, 1}, {1, 1}, 1);
43 if (inputChannel != outputChannel || stride != 1) {
44 x = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
45 }
46 y = _Add(x, y);
47 return y;
48 }
49
bottleNeckBlock(VARP x,INTS channels,int stride,int number)50 static VARP bottleNeckBlock(VARP x, INTS channels, int stride, int number) {
51 x = bottleNeck(x, {channels[0], channels[1], channels[2]}, stride);
52 for (int i = 1; i < number; ++i) {
53 x = bottleNeck(x, {channels[2], channels[1], channels[2]}, 1);
54 }
55 return x;
56 }
57
resNetExpr(ResNetType resNetType,int numClass)58 VARP resNetExpr(ResNetType resNetType, int numClass) {
59 std::vector<int> numbers;
60 {
61 auto numbersMap = std::map<ResNetType, std::vector<int>>({
62 {ResNet18, {2, 2, 2, 2}},
63 {ResNet34, {3, 4, 6, 3}},
64 {ResNet50, {3, 4, 6, 3}},
65 {ResNet101, {3, 4, 23, 3}},
66 {ResNet152, {3, 8, 36, 3}}
67 });
68 if (numbersMap.find(resNetType) == numbersMap.end()) {
69 MNN_ERROR("resNetType (%d) not support, only support [ResNet18, ResNet34, ResNet50, ResNet101, ResNet152]\n", resNetType);
70 return VARP(nullptr);
71 }
72 numbers = numbersMap[resNetType];
73 }
74 std::vector<int> channels({64, 64, 128, 256, 512});
75 {
76 if (resNetType != ResNet18 && resNetType != ResNet34) {
77 channels[0] = 16;
78 }
79 }
80 std::vector<int> strides({1, 2, 2, 2});
81 int finalChannel = channels[4] * 4;
82 auto x = _Input({1, 3, 224, 224}, NC4HW4);
83 x = _Conv(0.0f, 0.0f, x, {3, 64}, {7, 7}, SAME, {2, 2}, {1, 1}, 1);
84 x = _MaxPool(x, {3, 3}, {2, 2}, SAME);
85 for (int i = 0; i < 4; ++i) {
86 if (resNetType == ResNet18 || resNetType == ResNet34) {
87 x = residualBlock(x, {channels[i], channels[i+1]}, strides[i], numbers[i]);
88 } else {
89 x = bottleNeckBlock(x, {channels[i] * 4, channels[i+1], channels[i+1] * 4}, strides[i], numbers[i]);
90 }
91 }
92 x = _AvePool(x, {7, 7}, {1, 1}, VALID);
93 x = _Conv(0.0f, 0.0f, x, {x->getInfo()->dim[1], numClass}, {1, 1}, VALID, {1, 1}, {1, 1}, 1); // reshape FC with Conv1x1
94 x = _Softmax(x, -1);
95 return x;
96 }
97