//
//  MobileNetExpr.cpp
//  MNN
//  Reference papers: https://arxiv.org/pdf/1704.04861.pdf https://arxiv.org/pdf/1801.04381.pdf
//
//  Created by MNN on 2019/06/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <map>
#include "MobileNetExpr.hpp"
#include <MNN/expr/ExprCreator.hpp>

using namespace MNN::Express;

// When MNNConverter converts other MobileNet models to MNN models,
// {Conv3x3Depthwise + BN + Relu + Conv1x1 + BN + Relu} will be converted
// and optimized to {Conv3x3Depthwise + Conv1x1}, so this block only builds
// the two fused convolutions.
static VARP convBlock(VARP x, INTS channels, int stride) {
    int inputChannel = channels[0], outputChannel = channels[1];
    int group = inputChannel;
    // 3x3 depthwise convolution (group == inputChannel), then 1x1 pointwise convolution
    x = _Conv(0.0f, 0.0f, x, {inputChannel, inputChannel}, {3, 3}, SAME, {stride, stride}, {1, 1}, group);
    x = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {1, 1}, {1, 1}, 1);
    return x;
}

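// Cost note for convBlock above (from the MobileNetV1 paper linked in the header):
// a standard KxK convolution costs K*K*Cin*Cout*H*W multiply-adds, while the
// depthwise KxK + pointwise 1x1 pair costs K*K*Cin*H*W + Cin*Cout*H*W, a
// reduction of roughly 1/Cout + 1/(K*K), i.e. about 8-9x fewer operations for K = 3.
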
VARP mobileNetV1Expr(MobileNetWidthType alpha, MobileNetResolutionType beta, int numClass) {
    int inputSize, poolSize; // MobileNet_224, MobileNet_192, MobileNet_160, MobileNet_128
    {
        auto inputSizeMap = std::map<MobileNetResolutionType, int>({
            {MobileNet_224, 224},
            {MobileNet_192, 192},
            {MobileNet_160, 160},
            {MobileNet_128, 128}
        });
        if (inputSizeMap.find(beta) == inputSizeMap.end()) {
            MNN_ERROR("MobileNetResolutionType (%d) is not supported, only [MobileNet_224, MobileNet_192, MobileNet_160, MobileNet_128] are supported\n", beta);
            return VARP(nullptr);
        }
        inputSize = inputSizeMap[beta];
        poolSize  = inputSize / 32; // the network downsamples by 32x overall
    }

    int channels[6]; // MobileNet_100, MobileNet_075, MobileNet_050, MobileNet_025
    {
        auto channelsMap = std::map<MobileNetWidthType, int>({
            {MobileNet_100, 32},
            {MobileNet_075, 24},
            {MobileNet_050, 16},
            {MobileNet_025, 8}
        });
        if (channelsMap.find(alpha) == channelsMap.end()) {
            MNN_ERROR("MobileNetWidthType (%d) is not supported, only [MobileNet_100, MobileNet_075, MobileNet_050, MobileNet_025] are supported\n", alpha);
            return VARP(nullptr);
        }
        channels[0] = channelsMap[alpha];
    }

    for (int i = 1; i < 6; ++i) {
        channels[i] = channels[0] * (1 << i);
    }
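    // As a worked example (assuming MobileNet_100, i.e. base width 32), the
    // stage widths become {32, 64, 128, 256, 512, 1024}.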

    auto x = _Input({1, 3, inputSize, inputSize}, NC4HW4);
    x = _Conv(0.0f, 0.0f, x, {3, channels[0]}, {3, 3}, SAME, {2, 2}, {1, 1}, 1);
    x = convBlock(x, {channels[0], channels[1]}, 1);
    x = convBlock(x, {channels[1], channels[2]}, 2);
    x = convBlock(x, {channels[2], channels[2]}, 1);
    x = convBlock(x, {channels[2], channels[3]}, 2);
    x = convBlock(x, {channels[3], channels[3]}, 1);
    x = convBlock(x, {channels[3], channels[4]}, 2);
    x = convBlock(x, {channels[4], channels[4]}, 1);
    x = convBlock(x, {channels[4], channels[4]}, 1);
    x = convBlock(x, {channels[4], channels[4]}, 1);
    x = convBlock(x, {channels[4], channels[4]}, 1);
    x = convBlock(x, {channels[4], channels[4]}, 1);
    x = convBlock(x, {channels[4], channels[5]}, 2);
    x = convBlock(x, {channels[5], channels[5]}, 1);
    x = _AvePool(x, {poolSize, poolSize}, {1, 1}, VALID);
    x = _Conv(0.0f, 0.0f, x, {channels[5], numClass}, {1, 1}, VALID, {1, 1}, {1, 1}, 1); // reshape FC with Conv1x1
    x = _Softmax(x, -1);
    return x;
}
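
// A minimal usage sketch for mobileNetV1Expr (illustrative only; the class count
// and file name are example values): the returned VARP is the softmax output and
// can be serialized with the Express API, e.g.
//     auto output = mobileNetV1Expr(MobileNet_100, MobileNet_224, 1001);
//     Variable::save({output}, "mobilenet_v1.mnn");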

static VARP bottleNeck(VARP x, INTS channels, int stride, int expansionRatio) {
    int inputChannel = channels[0], outputChannel = channels[1];
    int expansionChannel = inputChannel * expansionRatio, group = expansionChannel;
    // Inverted residual: 1x1 expansion, 3x3 depthwise, then 1x1 projection.
    auto y = _Conv(0.0f, 0.0f, x, {inputChannel, expansionChannel}, {1, 1}, VALID, {1, 1}, {1, 1}, 1);
    y = _Conv(0.0f, 0.0f, y, {expansionChannel, expansionChannel}, {3, 3}, SAME, {stride, stride}, {1, 1}, group);
    y = _Conv(0.0f, 0.0f, y, {expansionChannel, outputChannel}, {1, 1}, VALID, {1, 1}, {1, 1}, 1);
    // Project the shortcut with a strided 1x1 convolution whenever the shapes differ,
    // so the residual add below is always valid.
    if (inputChannel != outputChannel || stride != 1) {
        x = _Conv(0.0f, 0.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {stride, stride}, {1, 1}, 1);
    }
    y = _Add(x, y);
    return y;
}

static VARP bottleNeckBlock(VARP x, INTS channels, int stride, int expansionRatio, int number) {
    // The first bottleneck applies the stride and channel change; the remaining
    // (number - 1) bottlenecks keep the shape.
    x = bottleNeck(x, {channels[0], channels[1]}, stride, expansionRatio);
    for (int i = 1; i < number; ++i) {
        x = bottleNeck(x, {channels[1], channels[1]}, 1, expansionRatio);
    }
    return x;
}

VARP mobileNetV2Expr(int numClass) {
    auto x = _Input({1, 3, 224, 224}, NC4HW4);
    x = _Conv(0.0f, 0.0f, x, {3, 32}, {3, 3}, SAME, {2, 2}, {1, 1}, 1);
    x = bottleNeckBlock(x,   {32, 16}, 1, 1, 1);
    x = bottleNeckBlock(x,   {16, 24}, 2, 6, 2);
    x = bottleNeckBlock(x,   {24, 32}, 2, 6, 3);
    x = bottleNeckBlock(x,   {32, 64}, 2, 6, 4);
    x = bottleNeckBlock(x,   {64, 96}, 1, 6, 3);
    x = bottleNeckBlock(x,  {96, 160}, 2, 6, 3);
    x = bottleNeckBlock(x, {160, 320}, 1, 6, 1);
    x = _Conv(0.0f, 0.0f, x, {320, 1280}, {1, 1}, VALID, {1, 1}, {1, 1}, 1);
    x = _AvePool(x, {7, 7}, {1, 1}, VALID); // 224 / 32 = 7, so this is a global average pool
    x = _Conv(0.0f, 0.0f, x, {1280, numClass}, {1, 1}, VALID, {1, 1}, {1, 1}, 1); // reshape FC with Conv1x1
    return x;
}
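
// As with mobileNetV1Expr, a usage sketch (illustrative only): this variant returns
// the raw classifier output without a softmax, so probabilities are left to the
// caller, e.g.
//     auto logits = mobileNetV2Expr(1001);
//     auto probs  = _Softmax(logits, -1);
//     Variable::save({probs}, "mobilenet_v2.mnn");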