1 //
2 // UnaryGrad.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/05/25.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "OpGrad.hpp"
10 #include "core/Macro.h"
11 using namespace std;
12 using namespace MNN;
13 using namespace MNN::Express;
14
15 class UnaryGrad : public OpGrad {
16 public:
onGrad(Express::EXPRP expr,const std::vector<Express::VARP> & backwardOutput)17 virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
18 const std::vector<Express::VARP>& backwardOutput) override {
19 std::unique_ptr<OpT> forwardOp(expr->get()->UnPack());
20 auto outputDiff = backwardOutput[0];
21 auto input = expr->inputs()[0];
22 std::vector<Express::VARP> res(1, nullptr);
23 std::vector<Express::VARP> output{Variable::create(expr, 0)};
24
25 switch (forwardOp->main.AsUnaryOp()->opType) {
26 case MNN::UnaryOpOperation_LOG1P: {
27 // d log(1+x) = 1/(1+x) * dx = dx / (1+x)
28 auto oneConst = _Const(1.0f, {}, NHWC);
29 auto addOne = _Add(input, oneConst);
30 res[0] = _Divide(outputDiff, addOne);
31 break;
32 }
33 case MNN::UnaryOpOperation_EXP: {
34 // d Exp(x) = Exp(x) * dx
35 res[0] = _Multiply(outputDiff, output[0]);
36 break;
37 }
38 case MNN::UnaryOpOperation_LOG: {
39 // d Log(x) = dx / x
40 res[0] = _Divide(outputDiff, input);
41 break;
42 }
43 case MNN::UnaryOpOperation_COS: {
44 // d Sin(x) = -dx * Sin(x)
45 res[0] = _Negative(outputDiff) * _Sin(input);
46 break;
47 }
48 case MNN::UnaryOpOperation_SIN: {
49 // d Sin(x) = dx * Cos(x)
50 res[0] = outputDiff * _Cos(input);
51 break;
52 }
53 case MNN::UnaryOpOperation_ABS: {
54 // d Abs(x) = dx * (x > 0 ? 1 : -1)
55 res[0] = outputDiff * _Sign(input);
56 break;
57 }
58 case MNN::UnaryOpOperation_NEG: {
59 // d (-x) = - dx
60 res[0] = _Negative(outputDiff);
61 break;
62 }
63 case MNN::UnaryOpOperation_SQRT: {
64 // d (-sqrt(x)) = 0.5 / sqrt(x) * dx
65 auto oneConst = _Const(0.5f, {}, NHWC);
66 auto mul = _Multiply(outputDiff, oneConst);
67 res[0] = _Divide(mul, output[0]);
68 break;
69 }
70 case MNN::UnaryOpOperation_SQUARE: {
71 // d (x^2) = (x*dx + x*dx)
72 auto mul = _Multiply(input, outputDiff);
73 res[0] = _Add(mul, mul);
74 break;
75 }
76 default:
77 return res;
78 }
79
80 res[0]->setName(expr->name() + "_Grad");
81 return res;
82 }
83 };
84 class SigmoidGrad : public OpGrad {
85 public:
onGrad(Express::EXPRP expr,const std::vector<Express::VARP> & backwardOutput)86 virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
87 const std::vector<Express::VARP>& backwardOutput) override {
88 std::vector<Express::VARP> result(1, nullptr);
89 auto outputDiff = backwardOutput[0];
90 std::vector<Express::VARP> output{Variable::create(expr, 0)};
91
92 // y = (1/(1+e(-x))) , dy = y(1-y) * dx = (y*y - y)*dx
93 auto mul = _Multiply(output[0], output[0]);
94 auto sub = _Subtract(mul, output[0]);
95 auto grad = _Multiply(sub, outputDiff);
96 result[0] = grad;
97 result[0]->setName(expr->name() + "_Grad");
98 return result;
99 }
100 };
101
102 class TanhGrad : public OpGrad {
103 public:
onGrad(Express::EXPRP expr,const std::vector<Express::VARP> & backwardOutput)104 virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
105 const std::vector<Express::VARP>& backwardOutput) override {
106 std::vector<Express::VARP> result{nullptr};
107 std::vector<Express::VARP> output{Variable::create(expr, 0)};
108
109 auto outputDiff = backwardOutput[0];
110 // d tanh(x) = (1-tanh(x)^2)dx
111 result[0] = (_Const(1.0f, {}, NCHW) - _Square(output[0])) * outputDiff;
112 return result;
113 }
114 };
115
__anon8c9c1f290102() 116 static const auto gRegister = []() {
117 static UnaryGrad _c;
118 static SigmoidGrad _s;
119 static TanhGrad _t;
120 OpGrad::insert(OpType_UnaryOp, &_c);
121 OpGrad::insert(OpType_Sigmoid, &_s);
122 OpGrad::insert(OpType_TanH, &_t);
123 return true;
124 }();
125