1 //
2 // GeometryELU.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/07/23.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "geometry/GeometryComputer.hpp"
10 #include "core/OpCommonUtils.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12
13 namespace MNN {
14
initTensor(std::shared_ptr<Tensor> tensor,Tensor * input)15 static void initTensor(std::shared_ptr<Tensor> tensor, Tensor* input) {
16 tensor->buffer().type = input->getType();
17 TensorUtils::copyShape(input, tensor.get(), true);
18 }
19
20 class GeometryELU : public GeometryComputer {
21 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const22 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
23 MNN_ASSERT(1 == inputs.size());
24 MNN_ASSERT(1 == outputs.size());
25 auto input = inputs[0];
26 auto output = outputs[0];
27 // ELU : y = x > 0 ? x : alpha * (exp(x) - 1)
28 // exp + sub + mul : y1 = alhpa * (exp(x) - 1)
29 // exp
30 std::shared_ptr<Tensor> expValue(new Tensor);
31 {
32 initTensor(expValue, input);
33 auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_EXP, input, expValue.get());
34 res.extras.emplace_back(expValue);
35 res.command.emplace_back(std::move(cmd));
36 }
37 // sub
38 std::shared_ptr<Tensor> subValue(new Tensor);
39 {
40 auto oneConst = context.allocConst(op, {}, halide_type_of<float>());
41 oneConst->host<float>()[0] = 1.0;
42 initTensor(subValue, input);
43 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_SUB, expValue.get(), oneConst.get(), subValue.get());
44 res.extras.emplace_back(subValue);
45 res.command.emplace_back(std::move(cmd));
46 }
47 // mul
48 std::shared_ptr<Tensor> mulValue(new Tensor);
49 {
50 auto alphaConst = context.allocConst(op, {}, halide_type_of<float>());
51 float alpha = 0.0;
52 if (op->type() == OpType_ELU) {
53 alpha = op->main_as_ELU()->alpha();
54 } else if (op->type() == OpType_Selu){
55 alpha = op->main_as_Selu()->alpha() *
56 op->main_as_Selu()->scale();
57 }
58 alphaConst->host<float>()[0] = alpha;
59 initTensor(mulValue, input);
60 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, subValue.get(), alphaConst.get(), mulValue.get());
61 res.extras.emplace_back(mulValue);
62 res.command.emplace_back(std::move(cmd));
63 }
64 // compare + select : y = x > 0 ? x : y1
65 // compare
66 std::shared_ptr<Tensor> compValue(new Tensor);
67 {
68 auto zeroConst = context.allocConst(op, {}, halide_type_of<float>());
69 zeroConst->host<float>()[0] = 0;
70 compValue->buffer().type = halide_type_of<int>();
71 TensorUtils::copyShape(input, compValue.get(), true);
72 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_GREATER, input, zeroConst.get(), compValue.get());
73 res.extras.emplace_back(compValue);
74 res.command.emplace_back(std::move(cmd));
75 }
76 std::shared_ptr<Tensor> scaleValue(new Tensor);
77 {
78 if (op->type() == OpType_Selu) {
79 auto scaleConst = context.allocConst(op, {}, halide_type_of<float>());
80 float scale = op->main_as_Selu()->scale();
81 scaleConst->host<float>()[0] = scale;
82 initTensor(scaleValue, input);
83 auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, input, scaleConst.get(), scaleValue.get());
84 res.extras.emplace_back(scaleValue);
85 res.command.emplace_back(std::move(cmd));
86 }
87 }
88 // select
89 {
90 flatbuffers::FlatBufferBuilder builder;
91 OpBuilder opBuilder(builder);
92 opBuilder.add_type(OpType_Select);
93 builder.Finish(opBuilder.Finish());
94 auto y0 = op->type() == OpType_ELU ? input : scaleValue.get();
95 auto cmd = GeometryComputerUtils::makeCommand(builder, {compValue.get(), y0, mulValue.get()}, {output});
96 res.command.emplace_back(std::move(cmd));
97 }
98 return true;
99 }
100 };
101
_create()102 static void _create() {
103 std::shared_ptr<GeometryComputer> comp(new GeometryELU);
104 GeometryComputer::registerGeometryComputer(comp, {OpType_ELU, OpType_Selu});
105 }
106
107 REGISTER_GEOMETRY(GeometryELU, _create);
108
109 } // namespace MNN
110
111