1 //
2 //  MathOp.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/06/27.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef MathOp_HPP
10 #define MathOp_HPP
11 
12 namespace MNN {
13 namespace Express {
14 //BinaryOPs
15 MNN_PUBLIC VARP _Add(VARP x, VARP y);
16 MNN_PUBLIC VARP _Subtract(VARP x, VARP y);
17 MNN_PUBLIC VARP _Multiply(VARP x, VARP y);
18 MNN_PUBLIC VARP _Divide(VARP x, VARP y);
19 MNN_PUBLIC VARP _Pow(VARP x, VARP y);
20 MNN_PUBLIC VARP _Minimum(VARP x, VARP y);
21 MNN_PUBLIC VARP _Maximum(VARP x, VARP y);
22 MNN_PUBLIC VARP _BiasAdd(VARP value, VARP bias);
23 MNN_PUBLIC VARP _Greater(VARP x, VARP y);
24 MNN_PUBLIC VARP _GreaterEqual(VARP x, VARP y);
25 MNN_PUBLIC VARP _Less(VARP x, VARP y);
26 MNN_PUBLIC VARP _FloorDiv(VARP x, VARP y);
27 MNN_PUBLIC VARP _SquaredDifference(VARP x, VARP y);
28 MNN_PUBLIC VARP _Equal(VARP x, VARP y);
29 MNN_PUBLIC VARP _LessEqual(VARP x, VARP y);
30 MNN_PUBLIC VARP _FloorMod(VARP x, VARP y);
31 MNN_PUBLIC VARP _Atan2(VARP x, VARP y);
32 MNN_PUBLIC VARP _LogicalOr(VARP x, VARP y);
33 MNN_PUBLIC VARP _NotEqual(VARP x, VARP y);
34 
35 //UnaryOPs
36 MNN_PUBLIC VARP _Sign(VARP a);
37 MNN_PUBLIC VARP _Abs(VARP x);
38 MNN_PUBLIC VARP _Negative(VARP x);
39 MNN_PUBLIC VARP _Floor(VARP x);
40 MNN_PUBLIC VARP _Round(VARP x);
41 MNN_PUBLIC VARP _Ceil(VARP x);
42 MNN_PUBLIC VARP _Square(VARP x);
43 MNN_PUBLIC VARP _Sqrt(VARP x);
44 MNN_PUBLIC VARP _Rsqrt(VARP x);
45 MNN_PUBLIC VARP _Exp(VARP x);
46 MNN_PUBLIC VARP _Log(VARP x);
47 MNN_PUBLIC VARP _Sin(VARP x);
48 MNN_PUBLIC VARP _Sinh(VARP x);
49 MNN_PUBLIC VARP _Cos(VARP x);
50 MNN_PUBLIC VARP _Cosh(VARP x);
51 MNN_PUBLIC VARP _Tan(VARP x);
52 MNN_PUBLIC VARP _Asin(VARP x);
53 MNN_PUBLIC VARP _Asinh(VARP x);
54 MNN_PUBLIC VARP _Acos(VARP x);
55 MNN_PUBLIC VARP _Acosh(VARP x);
56 MNN_PUBLIC VARP _Atan(VARP x);
57 MNN_PUBLIC VARP _Atanh(VARP x);
58 MNN_PUBLIC VARP _Reciprocal(VARP x);
59 MNN_PUBLIC VARP _Log1p(VARP x);
60 MNN_PUBLIC VARP _Gelu(VARP x);
61 //Only one but not in UnaryOPs
62 MNN_PUBLIC VARP _Tanh(VARP x);
63 MNN_PUBLIC VARP _Sigmoid(VARP x);
64 MNN_PUBLIC VARP _Erf(VARP x);
65 MNN_PUBLIC VARP _Erfc(VARP x);
66 MNN_PUBLIC VARP _Erfinv(VARP x);
67 MNN_PUBLIC VARP _Expm1(VARP x);
68 
69 
70 //ReduceOPs
71 MNN_PUBLIC VARP _ReduceSum(VARP input_variable, INTS axis = {}, bool keepDims = false);
72 MNN_PUBLIC VARP _ReduceMean(VARP input_variable, INTS axis = {}, bool keepDims = false);
73 MNN_PUBLIC VARP _ReduceMax(VARP input_variable, INTS axis = {}, bool keepDims = false);
74 MNN_PUBLIC VARP _ReduceMin(VARP input_variable, INTS axis = {}, bool keepDims = false);
75 MNN_PUBLIC VARP _ReduceProd(VARP input_variable, INTS axis = {}, bool keepDims = false);
76 MNN_PUBLIC VARP _ReduceAny(VARP input_variable, INTS axis = {}, bool keepDims = false);
77 MNN_PUBLIC VARP _ReduceAll(VARP input_variable, INTS axis = {}, bool keepDims = false);
78 
79 MNN_PUBLIC VARP _ReduceSumMutable(VARP input_variable, VARP axis, bool keepDims = false);
80 MNN_PUBLIC VARP _ReduceMeanMutable(VARP input_variable, VARP axis, bool keepDims = false);
81 MNN_PUBLIC VARP _ReduceMaxMutable(VARP input_variable, VARP axis, bool keepDims = false);
82 MNN_PUBLIC VARP _ReduceMinMutable(VARP input_variable, VARP axis, bool keepDims = false);
83 MNN_PUBLIC VARP _ReduceProdMutable(VARP input_variable, VARP axis, bool keepDims = false);
84 MNN_PUBLIC VARP _ReduceAnyMutable(VARP input_variable, VARP axis, bool keepDims = false);
85 MNN_PUBLIC VARP _ReduceAllMutable(VARP input_variable, VARP axis, bool keepDims = false);
86 
87 //EltwiseOPs
88 MNN_PUBLIC VARP _Prod(VARP a, VARP b, std::vector<float> coeff);
89 MNN_PUBLIC VARP _Sum(VARP a, VARP b, std::vector<float> coeff);
90 MNN_PUBLIC VARP _Max(VARP a, VARP b, std::vector<float> coeff);
91 MNN_PUBLIC VARP _Sub(VARP a, VARP b, std::vector<float> coeff);
92 MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y,
93                     std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
94                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
95                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
96 MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, VARP y,
97                      std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
98                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
99                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
100 MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y,
101                      std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
102                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
103                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
104 MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y,
105                       std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
106                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
107                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
108 
109 
110 //OtherOPs
111 template<typename T>
_Cast(VARP x)112 VARP _Cast(VARP x) {
113     return _Cast(x, halide_type_of<T>());
114 }
115 MNN_PUBLIC VARP _Cast(VARP x, halide_type_t dtype);
116 MNN_PUBLIC VARP _MatMul(VARP a, VARP b, bool tranposeA = false, bool tranposeB = false);
117 MNN_PUBLIC VARP _Normalize(VARP x, int32_t acrossSpatial, int32_t channelShared, float eps, std::vector<float> scale);
118 MNN_PUBLIC VARP _ArgMax(VARP input, int axis = 0);
119 MNN_PUBLIC VARP _ArgMin(VARP input, int axis = 0);
120 MNN_PUBLIC VARP _BatchMatMul(VARP x, VARP y, bool adj_x = false, bool adj_y = false);
121 MNN_PUBLIC VARP _UnravelIndex(VARP indices, VARP dims);
122 MNN_PUBLIC VARP _ScatterNd(VARP indices, VARP updates, VARP shape);
123 MNN_PUBLIC VARP _OneHot(VARP indices, VARP depth, VARP onValue, VARP offValue, int axis = -1);
124 MNN_PUBLIC VARP _BroadcastTo(VARP a, VARP shape);
125 MNN_PUBLIC VARP _LinSpace(VARP start, VARP stop, VARP num);
126 }; // namespace Express
127 }; // namespace MNN
128 
129 #endif /* MathOp_HPP */
130