1 //
2 //  MergeBNToConvolution.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/11/27.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../PostTreatUtils.hpp"
10 #include "MergeToConvolution.hpp"
11 
12 using namespace MNN;
13 
14 class MergeBNToConvolution : public MergeToConvolution {
15 public:
merge2Convolution(const MNN::OpT * inplaceOp,MNN::OpT * convolutionOp) const16     bool merge2Convolution(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const {
17         const auto& convCommon = convolutionOp->main.AsConvolution2D()->common;
18         if (convCommon->relu || convCommon->relu6 || convolutionOp->inputIndexes.size() > 1) {
19             return false;
20         }
21 
22         if (inplaceOp->type == MNN::OpType_BatchNorm) {
23             std::vector<float> alpha;
24             std::vector<float> bias;
25 
26             auto l = inplaceOp->main.AsBatchNorm();
27             alpha.resize(l->channels);
28             bias.resize(l->channels);
29             const float* slopePtr    = l->slopeData.data();
30             const float* meanDataPtr = l->meanData.data();
31             const float* varDataPtr  = l->varData.data();
32             const float* biasDataPtr = l->biasData.data();
33             const float eps          = l->epsilon;
34 
35             for (int i = 0; i < l->channels; i++) {
36                 float sqrt_var = sqrt(varDataPtr[i] + eps);
37                 bias[i]        = biasDataPtr[i] - slopePtr[i] * meanDataPtr[i] / sqrt_var;
38                 alpha[i]       = slopePtr[i] / sqrt_var;
39             }
40 
41             auto conv2D     = convolutionOp->main.AsConvolution2D();
42             int outputCount = conv2D->common->outputCount;
43             for (int i = 0; i < outputCount; ++i) {
44                 conv2D->bias[i] = conv2D->bias[i] * alpha[i] + bias[i];
45             }
46 
47             if (nullptr != conv2D->quanParameter.get()) {
48                 for (int i = 0; i < outputCount; ++i) {
49                     conv2D->quanParameter->alpha[i] *= alpha[i];
50                 }
51             } else {
52                 int weightPartSize = conv2D->weight.size() / outputCount;
53                 if (convolutionOp->type == OpType_Deconvolution) {
54                     int inputCount =
55                         conv2D->weight.size() / outputCount / conv2D->common->kernelX / conv2D->common->kernelY;
56                     for (int i = 0; i < inputCount; ++i) {
57                         auto dstPos = i * outputCount * conv2D->common->kernelY * conv2D->common->kernelX;
58                         for (int j = 0; j < outputCount; ++j) {
59                             auto dstPosJ = dstPos + j * conv2D->common->kernelY * conv2D->common->kernelX;
60                             float a      = alpha[j];
61                             for (int k = 0; k < conv2D->common->kernelY * conv2D->common->kernelX; ++k) {
62                                 conv2D->weight[dstPosJ + k] *= a;
63                             }
64                         }
65                     }
66                 } else {
67                     for (int i = 0; i < outputCount; ++i) {
68                         float a = alpha[i];
69                         for (int j = 0; j < weightPartSize; ++j) {
70                             conv2D->weight[i * weightPartSize + j] *= a;
71                         }
72                     }
73                 }
74             }
75             return true;
76         }
77         return false;
78     }
79 
merge2Convolution3D(const MNN::OpT * inplaceOp,MNN::OpT * convolutionOp) const80     bool merge2Convolution3D(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const {
81         const auto& convCommon = convolutionOp->main.AsConvolution3D()->common;
82         if (convCommon->relu || convCommon->relu6) {
83             return false;
84         }
85 
86         if (inplaceOp->type == MNN::OpType_BatchNorm) {
87             std::vector<float> alpha;
88             std::vector<float> bias;
89 
90             auto l = inplaceOp->main.AsBatchNorm();
91             alpha.resize(l->channels);
92             bias.resize(l->channels);
93             const float* slopePtr    = l->slopeData.data();
94             const float* meanDataPtr = l->meanData.data();
95             const float* varDataPtr  = l->varData.data();
96             const float* biasDataPtr = l->biasData.data();
97             const float eps          = l->epsilon;
98 
99             for (int i = 0; i < l->channels; i++) {
100                 float sqrt_var = sqrt(varDataPtr[i] + eps);
101                 bias[i]        = biasDataPtr[i] - slopePtr[i] * meanDataPtr[i] / sqrt_var;
102                 alpha[i]       = slopePtr[i] / sqrt_var;
103             }
104 
105             auto conv3D     = convolutionOp->main.AsConvolution3D();
106             int outputCount = conv3D->common->outputCount;
107             for (int i = 0; i < outputCount; ++i) {
108                 conv3D->bias[i] = conv3D->bias[i] * alpha[i] + bias[i];
109             }
110 
111             int weightPartSize = conv3D->weight.size() / outputCount;
112             for (int i = 0; i < outputCount; ++i) {
113                 float a = alpha[i];
114                 for (int j = 0; j < weightPartSize; ++j) {
115                     conv3D->weight[i * weightPartSize + j] *= a;
116                 }
117             }
118             return true;
119         }
120         return false;
121     }
122 };
// File-local static instance of PostConverterRegister, instantiated for
// MergeBNToConvolution and constructed with the name "MergeBNToConvolution".
// NOTE(review): presumably its constructor registers the pass under that name at
// static-initialization time — confirm against PostTreatUtils.hpp.
123 static PostConverterRegister<MergeBNToConvolution> __l("MergeBNToConvolution");
124