//
//  UnaryGrad.cpp
//  MNN
//
//  Created by MNN on 2019/05/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "OpGrad.hpp"
#include "core/Macro.h"
using namespace std;
using namespace MNN;
using namespace MNN::Express;

class UnaryGrad : public OpGrad {
public:
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        std::unique_ptr<OpT> forwardOp(expr->get()->UnPack());
        auto outputDiff = backwardOutput[0];
        auto input      = expr->inputs()[0];
        std::vector<Express::VARP> res(1, nullptr);
        std::vector<Express::VARP> output{Variable::create(expr, 0)};

        switch (forwardOp->main.AsUnaryOp()->opType) {
            case MNN::UnaryOpOperation_LOG1P: {
                // d log(1+x) = 1/(1+x) * dx = dx / (1+x)
                auto oneConst = _Const(1.0f, {}, NHWC);
                auto addOne   = _Add(input, oneConst);
                res[0]        = _Divide(outputDiff, addOne);
                break;
            }
            case MNN::UnaryOpOperation_EXP: {
                // d Exp(x) = Exp(x) * dx
                res[0] = _Multiply(outputDiff, output[0]);
                break;
            }
            case MNN::UnaryOpOperation_LOG: {
                // d Log(x) = dx / x
                res[0] = _Divide(outputDiff, input);
                break;
            }
            case MNN::UnaryOpOperation_COS: {
                // d Cos(x) = -Sin(x) * dx
                res[0] = _Negative(outputDiff) * _Sin(input);
                break;
            }
            case MNN::UnaryOpOperation_SIN: {
                // d Sin(x) = dx * Cos(x)
                res[0] = outputDiff * _Cos(input);
                break;
            }
            case MNN::UnaryOpOperation_ABS: {
                // d Abs(x) = dx * (x > 0 ? 1 : -1)
                res[0] = outputDiff * _Sign(input);
                break;
            }
            case MNN::UnaryOpOperation_NEG: {
                // d (-x) = -dx
                res[0] = _Negative(outputDiff);
                break;
            }
            case MNN::UnaryOpOperation_SQRT: {
                // d Sqrt(x) = 0.5 / sqrt(x) * dx
                auto halfConst = _Const(0.5f, {}, NHWC);
                auto mul       = _Multiply(outputDiff, halfConst);
                res[0]         = _Divide(mul, output[0]);
                break;
            }
            case MNN::UnaryOpOperation_SQUARE: {
                // d (x^2) = (x*dx + x*dx)
                auto mul = _Multiply(input, outputDiff);
                res[0]   = _Add(mul, mul);
                break;
            }
            default:
                return res;
        }

        res[0]->setName(expr->name() + "_Grad");
        return res;
    }
};
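
// Illustrative sketch (not part of the original file): a finite-difference
// sanity check for one of the cases above (EXP), written with the same
// Express API the file already uses. It assumes _Input/_Exp from
// ExprCreator.hpp and an OpGrad::get lookup as declared in OpGrad.hpp;
// the helper name itself is hypothetical.
static bool _sketchCheckExpGrad() {
    auto x = _Input({1}, NHWC);
    x->writeMap<float>()[0] = 0.5f;
    auto y    = _Exp(x);
    auto expr = y->expr().first;
    auto grad = OpGrad::get(expr->get()->type());
    if (nullptr == grad) {
        return false;
    }
    // With an upstream diff of 1, the result is the pure derivative exp(x).
    auto analytic = grad->onGrad(expr, {_Const(1.0f, {}, NHWC)})[0];
    // Central difference (exp(x+h) - exp(x-h)) / (2h), also built with Express.
    const float h = 1e-3f;
    auto numeric  = (_Exp(_Const(0.5f + h, {}, NHWC)) - _Exp(_Const(0.5f - h, {}, NHWC)))
                    * _Const(1.0f / (2.0f * h), {}, NHWC);
    float d = (analytic - numeric)->readMap<float>()[0];
    return d < 1e-2f && d > -1e-2f;
}
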
class SigmoidGrad : public OpGrad {
public:
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        std::vector<Express::VARP> result(1, nullptr);
        auto outputDiff = backwardOutput[0];
        std::vector<Express::VARP> output{Variable::create(expr, 0)};

        // y = 1 / (1 + e^(-x)), dy = y * (1 - y) * dx = (y - y*y) * dx
        auto mul  = _Multiply(output[0], output[0]);
        auto sub  = _Subtract(output[0], mul);
        auto grad = _Multiply(sub, outputDiff);
        result[0] = grad;
        result[0]->setName(expr->name() + "_Grad");
        return result;
    }
};
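
// Why (y - y*y): with y = 1/(1+e^(-x)),
//   dy/dx = e^(-x) / (1+e^(-x))^2 = [1/(1+e^(-x))] * [e^(-x)/(1+e^(-x))]
//         = y * (1 - y) = y - y*y,
// so the backward pass needs only the forward output y, never x itself.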

class TanhGrad : public OpGrad {
public:
    virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
                                              const std::vector<Express::VARP>& backwardOutput) override {
        std::vector<Express::VARP> result{nullptr};
        std::vector<Express::VARP> output{Variable::create(expr, 0)};

        auto outputDiff = backwardOutput[0];
        // d tanh(x) = (1 - tanh(x)^2) * dx
        result[0] = (_Const(1.0f, {}, NCHW) - _Square(output[0])) * outputDiff;
        return result;
    }
};
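
// Why (1 - tanh^2): tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x)); the quotient
// rule gives d tanh(x)/dx = 1 - tanh(x)^2, so here too the gradient is
// computed from the forward output alone.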

static const auto gRegister = []() {
    static UnaryGrad _c;
    static SigmoidGrad _s;
    static TanhGrad _t;
    OpGrad::insert(OpType_UnaryOp, &_c);
    OpGrad::insert(OpType_Sigmoid, &_s);
    OpGrad::insert(OpType_TanH, &_t);
    return true;
}();
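
// Usage sketch (illustrative): once the lambda above has run at static
// initialization time, a backward pass can recover the right rule from an
// expression's op type. `expr` and `outputDiff` are assumed to be supplied
// by the caller:
//
//   if (auto* grad = OpGrad::get(expr->get()->type())) {
//       auto inputDiffs = grad->onGrad(expr, {outputDiff});
//   }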