1 // 2 // MergeBNToConvolution.cpp 3 // MNNConverter 4 // 5 // Created by MNN on 2019/11/27. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "../PostTreatUtils.hpp" 10 #include "MergeToConvolution.hpp" 11 12 using namespace MNN; 13 14 class MergeBNToConvolution : public MergeToConvolution { 15 public: merge2Convolution(const MNN::OpT * inplaceOp,MNN::OpT * convolutionOp) const16 bool merge2Convolution(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const { 17 const auto& convCommon = convolutionOp->main.AsConvolution2D()->common; 18 if (convCommon->relu || convCommon->relu6 || convolutionOp->inputIndexes.size() > 1) { 19 return false; 20 } 21 22 if (inplaceOp->type == MNN::OpType_BatchNorm) { 23 std::vector<float> alpha; 24 std::vector<float> bias; 25 26 auto l = inplaceOp->main.AsBatchNorm(); 27 alpha.resize(l->channels); 28 bias.resize(l->channels); 29 const float* slopePtr = l->slopeData.data(); 30 const float* meanDataPtr = l->meanData.data(); 31 const float* varDataPtr = l->varData.data(); 32 const float* biasDataPtr = l->biasData.data(); 33 const float eps = l->epsilon; 34 35 for (int i = 0; i < l->channels; i++) { 36 float sqrt_var = sqrt(varDataPtr[i] + eps); 37 bias[i] = biasDataPtr[i] - slopePtr[i] * meanDataPtr[i] / sqrt_var; 38 alpha[i] = slopePtr[i] / sqrt_var; 39 } 40 41 auto conv2D = convolutionOp->main.AsConvolution2D(); 42 int outputCount = conv2D->common->outputCount; 43 for (int i = 0; i < outputCount; ++i) { 44 conv2D->bias[i] = conv2D->bias[i] * alpha[i] + bias[i]; 45 } 46 47 if (nullptr != conv2D->quanParameter.get()) { 48 for (int i = 0; i < outputCount; ++i) { 49 conv2D->quanParameter->alpha[i] *= alpha[i]; 50 } 51 } else { 52 int weightPartSize = conv2D->weight.size() / outputCount; 53 if (convolutionOp->type == OpType_Deconvolution) { 54 int inputCount = 55 conv2D->weight.size() / outputCount / conv2D->common->kernelX / conv2D->common->kernelY; 56 for (int i = 0; i < inputCount; ++i) { 57 auto dstPos = i * outputCount * conv2D->common->kernelY * conv2D->common->kernelX; 58 for (int j = 0; j < outputCount; ++j) { 59 auto dstPosJ = dstPos + j * conv2D->common->kernelY * conv2D->common->kernelX; 60 float a = alpha[j]; 61 for (int k = 0; k < conv2D->common->kernelY * conv2D->common->kernelX; ++k) { 62 conv2D->weight[dstPosJ + k] *= a; 63 } 64 } 65 } 66 } else { 67 for (int i = 0; i < outputCount; ++i) { 68 float a = alpha[i]; 69 for (int j = 0; j < weightPartSize; ++j) { 70 conv2D->weight[i * weightPartSize + j] *= a; 71 } 72 } 73 } 74 } 75 return true; 76 } 77 return false; 78 } 79 merge2Convolution3D(const MNN::OpT * inplaceOp,MNN::OpT * convolutionOp) const80 bool merge2Convolution3D(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const { 81 const auto& convCommon = convolutionOp->main.AsConvolution3D()->common; 82 if (convCommon->relu || convCommon->relu6) { 83 return false; 84 } 85 86 if (inplaceOp->type == MNN::OpType_BatchNorm) { 87 std::vector<float> alpha; 88 std::vector<float> bias; 89 90 auto l = inplaceOp->main.AsBatchNorm(); 91 alpha.resize(l->channels); 92 bias.resize(l->channels); 93 const float* slopePtr = l->slopeData.data(); 94 const float* meanDataPtr = l->meanData.data(); 95 const float* varDataPtr = l->varData.data(); 96 const float* biasDataPtr = l->biasData.data(); 97 const float eps = l->epsilon; 98 99 for (int i = 0; i < l->channels; i++) { 100 float sqrt_var = sqrt(varDataPtr[i] + eps); 101 bias[i] = biasDataPtr[i] - slopePtr[i] * meanDataPtr[i] / sqrt_var; 102 alpha[i] = slopePtr[i] / sqrt_var; 103 } 104 105 auto conv3D = convolutionOp->main.AsConvolution3D(); 106 int outputCount = conv3D->common->outputCount; 107 for (int i = 0; i < outputCount; ++i) { 108 conv3D->bias[i] = conv3D->bias[i] * alpha[i] + bias[i]; 109 } 110 111 int weightPartSize = conv3D->weight.size() / outputCount; 112 for (int i = 0; i < outputCount; ++i) { 113 float a = alpha[i]; 114 for (int j = 0; j < weightPartSize; ++j) { 115 conv3D->weight[i * weightPartSize + j] *= a; 116 } 117 } 118 return true; 119 } 120 return false; 121 } 122 }; 123 static PostConverterRegister<MergeBNToConvolution> __l("MergeBNToConvolution"); 124