1 //
2 //  resnetExpr.cpp
3 //  MNN
4 //  Reference paper: https://arxiv.org/pdf/1512.03385.pdf
5 //
6 //  Created by MNN on 2019/06/25.
7 //  Copyright © 2018, Alibaba Group Holding Limited
8 //
9 
10 #include <map>
11 #include <vector>
12 #include "ResNetExpr.hpp"
13 #include <MNN/expr/ExprCreator.hpp>
14 
15 using namespace MNN::Express;
16 
// When MNNConverter converts a ResNet model from another framework into an MNN model,
// each {Conv + BN + Relu} sequence is converted and optimized into a single {Conv}.
residual(VARP x,INTS channels,int stride)19 static VARP residual(VARP x, INTS channels, int stride) {
20     int inputChannel = x->getInfo()->dim[1], outputChannel = channels[1];
21     auto y = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {3, 3}, SAME, {stride, stride}, {1, 1}, 1);
22     y = _Conv(0.0f, 0.0f, y, {outputChannel, outputChannel}, {3, 3}, SAME, {1, 1}, {1, 1}, 1);
23     if (inputChannel != outputChannel || stride != 1) {
24         x = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
25     }
26     y = _Add(x, y);
27     return y;
28 }
29 
residualBlock(VARP x,INTS channels,int stride,int number)30 static VARP residualBlock(VARP x, INTS channels, int stride, int number) {
31     x = residual(x, {channels[0], channels[1]}, stride);
32     for (int i = 1; i < number; ++i) {
33         x = residual(x, {channels[1], channels[1]}, 1);
34     }
35     return x;
36 }
37 
bottleNeck(VARP x,INTS channels,int stride)38 static VARP bottleNeck(VARP x, INTS channels, int stride) {
39     int inputChannel = x->getInfo()->dim[1], narrowChannel = channels[1], outputChannel = channels[2];
40     auto y = _Conv(0.0f, 0.0f, x, {inputChannel, narrowChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
41     y = _Conv(0.0f, 0.0f, y, {narrowChannel, narrowChannel}, {3, 3}, SAME, {1, 1}, {1, 1}, 1);
42     y = _Conv(0.0f, 0.0f, y, {narrowChannel, outputChannel}, {1, 1}, VALID, {1, 1}, {1, 1}, 1);
43     if (inputChannel != outputChannel || stride != 1) {
44         x = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
45     }
46     y = _Add(x, y);
47     return y;
48 }
49 
bottleNeckBlock(VARP x,INTS channels,int stride,int number)50 static VARP bottleNeckBlock(VARP x, INTS channels, int stride, int number) {
51     x = bottleNeck(x, {channels[0], channels[1], channels[2]}, stride);
52     for (int i = 1; i < number; ++i) {
53         x = bottleNeck(x, {channels[2], channels[1], channels[2]}, 1);
54     }
55     return x;
56 }
57 
resNetExpr(ResNetType resNetType,int numClass)58 VARP resNetExpr(ResNetType resNetType, int numClass) {
59     std::vector<int> numbers;
60     {
61         auto numbersMap = std::map<ResNetType, std::vector<int>>({
62             {ResNet18, {2, 2, 2, 2}},
63             {ResNet34, {3, 4, 6, 3}},
64             {ResNet50, {3, 4, 6, 3}},
65             {ResNet101, {3, 4, 23, 3}},
66             {ResNet152, {3, 8, 36, 3}}
67         });
68         if (numbersMap.find(resNetType) == numbersMap.end()) {
69             MNN_ERROR("resNetType (%d) not support, only support [ResNet18, ResNet34, ResNet50, ResNet101, ResNet152]\n", resNetType);
70             return VARP(nullptr);
71         }
72         numbers = numbersMap[resNetType];
73     }
74     std::vector<int> channels({64, 64, 128, 256, 512});
75     {
76         if (resNetType != ResNet18 && resNetType != ResNet34) {
77             channels[0] = 16;
78         }
79     }
80     std::vector<int> strides({1, 2, 2, 2});
81     int finalChannel = channels[4] * 4;
82     auto x = _Input({1, 3, 224, 224}, NC4HW4);
83     x = _Conv(0.0f, 0.0f, x, {3, 64}, {7, 7}, SAME, {2, 2}, {1, 1}, 1);
84     x = _MaxPool(x, {3, 3}, {2, 2}, SAME);
85     for (int i = 0; i < 4; ++i) {
86         if (resNetType == ResNet18 || resNetType == ResNet34) {
87             x = residualBlock(x, {channels[i], channels[i+1]}, strides[i], numbers[i]);
88         } else {
89             x = bottleNeckBlock(x, {channels[i] * 4, channels[i+1], channels[i+1] * 4}, strides[i], numbers[i]);
90         }
91     }
92     x = _AvePool(x, {7, 7}, {1, 1}, VALID);
93     x = _Conv(0.0f, 0.0f, x, {x->getInfo()->dim[1], numClass}, {1, 1}, VALID, {1, 1}, {1, 1}, 1); // reshape FC with Conv1x1
94     x = _Softmax(x, -1);
95     return x;
96 }
97