1 //
2 //  GeometryELU.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/07/23.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "geometry/GeometryComputer.hpp"
10 #include "core/OpCommonUtils.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12 
13 namespace MNN {
14 
initTensor(std::shared_ptr<Tensor> tensor,Tensor * input)15 static void initTensor(std::shared_ptr<Tensor> tensor, Tensor* input) {
16     tensor->buffer().type = input->getType();
17     TensorUtils::copyShape(input, tensor.get(), true);
18 }
19 
20 class GeometryELU : public GeometryComputer {
21 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const22     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
23         MNN_ASSERT(1 == inputs.size());
24         MNN_ASSERT(1 == outputs.size());
25         auto input = inputs[0];
26         auto output = outputs[0];
27         // ELU : y = x > 0 ? x : alpha * (exp(x) - 1)
28         // exp + sub + mul : y1 = alhpa * (exp(x) - 1)
29         // exp
30         std::shared_ptr<Tensor> expValue(new Tensor);
31         {
32             initTensor(expValue, input);
33             auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_EXP, input, expValue.get());
34             res.extras.emplace_back(expValue);
35             res.command.emplace_back(std::move(cmd));
36         }
37         // sub
38         std::shared_ptr<Tensor> subValue(new Tensor);
39         {
40             auto oneConst = context.allocConst(op, {}, halide_type_of<float>());
41             oneConst->host<float>()[0] = 1.0;
42             initTensor(subValue, input);
43             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_SUB, expValue.get(), oneConst.get(), subValue.get());
44             res.extras.emplace_back(subValue);
45             res.command.emplace_back(std::move(cmd));
46         }
47         // mul
48         std::shared_ptr<Tensor> mulValue(new Tensor);
49         {
50             auto alphaConst = context.allocConst(op, {}, halide_type_of<float>());
51             float alpha = 0.0;
52             if (op->type() == OpType_ELU) {
53                 alpha = op->main_as_ELU()->alpha();
54             } else if (op->type() == OpType_Selu){
55                 alpha = op->main_as_Selu()->alpha() *
56                         op->main_as_Selu()->scale();
57             }
58             alphaConst->host<float>()[0] = alpha;
59             initTensor(mulValue, input);
60             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, subValue.get(), alphaConst.get(), mulValue.get());
61             res.extras.emplace_back(mulValue);
62             res.command.emplace_back(std::move(cmd));
63         }
64         // compare + select : y = x > 0 ? x : y1
65         // compare
66         std::shared_ptr<Tensor> compValue(new Tensor);
67         {
68             auto zeroConst = context.allocConst(op, {}, halide_type_of<float>());
69             zeroConst->host<float>()[0] = 0;
70             compValue->buffer().type = halide_type_of<int>();
71             TensorUtils::copyShape(input, compValue.get(), true);
72             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_GREATER, input, zeroConst.get(), compValue.get());
73             res.extras.emplace_back(compValue);
74             res.command.emplace_back(std::move(cmd));
75         }
76         std::shared_ptr<Tensor> scaleValue(new Tensor);
77         {
78             if (op->type() == OpType_Selu) {
79                 auto scaleConst = context.allocConst(op, {}, halide_type_of<float>());
80                 float scale = op->main_as_Selu()->scale();
81                 scaleConst->host<float>()[0] = scale;
82                 initTensor(scaleValue, input);
83                 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, input, scaleConst.get(), scaleValue.get());
84                 res.extras.emplace_back(scaleValue);
85                 res.command.emplace_back(std::move(cmd));
86             }
87         }
88         // select
89         {
90             flatbuffers::FlatBufferBuilder builder;
91             OpBuilder opBuilder(builder);
92             opBuilder.add_type(OpType_Select);
93             builder.Finish(opBuilder.Finish());
94             auto y0 = op->type() == OpType_ELU ? input : scaleValue.get();
95             auto cmd = GeometryComputerUtils::makeCommand(builder, {compValue.get(), y0, mulValue.get()}, {output});
96             res.command.emplace_back(std::move(cmd));
97         }
98         return true;
99     }
100 };
101 
_create()102 static void _create() {
103     std::shared_ptr<GeometryComputer> comp(new GeometryELU);
104     GeometryComputer::registerGeometryComputer(comp, {OpType_ELU, OpType_Selu});
105 }
106 
107 REGISTER_GEOMETRY(GeometryELU, _create);
108 
109 } // namespace MNN
110 
111