1 //
2 // MathOp.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/06/27.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include <algorithm>
10 #include <map>
11 #include <numeric>
12 #include <MNN/expr/ExprCreator.hpp>
13 #include <MNN/MNNDefine.h>
14 #include "Utils.hpp"
15
16 namespace MNN {
17 namespace Express {
_convertDataType(halide_type_t type)18 static DataType _convertDataType(halide_type_t type) {
19 if (type.code == halide_type_float) {
20 return DataType_DT_FLOAT;
21 }
22 if (type.code == halide_type_uint && type.bits == 8) {
23 return DataType_DT_UINT8;
24 }
25 if (type.code == halide_type_int && type.bits == 8) {
26 return DataType_DT_INT8;
27 }
28 if (type.code == halide_type_int && type.bits == 32) {
29 return DataType_DT_INT32;
30 }
31 return DataType_DT_INVALID;
32 }
_checkNC4HW4(VARP x)33 static VARP _checkNC4HW4(VARP x) {
34 #ifdef MNN_EXPR_SHAPE_EAGER
35 auto info = x->getInfo();
36 if (nullptr != info && info->order == NC4HW4) {
37 return _Convert(x, NCHW);
38 }
39 #endif
40 return x;
41 }
_Binary(VARP x,VARP y,BinaryOpOperation operation)42 static VARP _Binary(VARP x, VARP y, BinaryOpOperation operation) {
43 x = _checkNC4HW4(x);
44 y = _checkNC4HW4(y);
45 flatbuffers::FlatBufferBuilder builder;
46 BinaryOpBuilder parameter(builder);
47 parameter.add_opType(operation);
48 auto paOffset = parameter.Finish();
49 OpBuilder opB(builder);
50 opB.add_main(paOffset.Union());
51 opB.add_type(OpType_BinaryOp);
52 opB.add_main_type(OpParameter_BinaryOp);
53 builder.Finish(opB.Finish());
54 std::shared_ptr<BufferStorage> extra(new BufferStorage);
55 extra->storage.reset(builder.ReleaseRaw(extra->allocated_size, extra->offset));
56 return Variable::create(Expr::create(extra, {x, y}, 1));
57 }
_Unary(VARP x,UnaryOpOperation operation)58 static VARP _Unary(VARP x, UnaryOpOperation operation) {
59 flatbuffers::FlatBufferBuilder builder;
60 UnaryOpBuilder parameter(builder);
61 parameter.add_opType(operation);
62 auto paOffset = parameter.Finish();
63 OpBuilder opB(builder);
64 opB.add_main(paOffset.Union());
65 opB.add_type(OpType_UnaryOp);
66 opB.add_main_type(OpParameter_UnaryOp);
67 builder.Finish(opB.Finish());
68 std::shared_ptr<BufferStorage> extra(new BufferStorage);
69 extra->storage.reset(builder.ReleaseRaw(extra->allocated_size, extra->offset));
70 return Variable::create(Expr::create(extra, {x}, 1));
71 }
_Reduce(VARP x,INTS dim,ReductionType type,bool keepDim)72 static VARP _Reduce(VARP x, INTS dim, ReductionType type, bool keepDim) {
73 x = _checkNC4HW4(x);
74 flatbuffers::FlatBufferBuilder builder;
75 flatbuffers::Offset<flatbuffers::Vector<int>> dimOffset;
76 if (!dim.empty()) {
77 dimOffset = builder.CreateVector(dim);
78 }
79 ReductionParamBuilder parameter(builder);
80 parameter.add_operation(type);
81 parameter.add_keepDims(keepDim);
82 if (!dim.empty()) {
83 parameter.add_dim(dimOffset);
84 }
85 auto paOffset = parameter.Finish();
86 OpBuilder opB(builder);
87 opB.add_main(paOffset.Union());
88 opB.add_type(OpType_Reduction);
89 opB.add_main_type(OpParameter_ReductionParam);
90 builder.Finish(opB.Finish());
91 std::shared_ptr<BufferStorage> extra(new BufferStorage);
92 extra->storage.reset(builder.ReleaseRaw(extra->allocated_size, extra->offset));
93 return Variable::create(Expr::create(extra, {x}, 1));
94 }
_ReduceMutable(VARP x,VARP dim,ReductionType type,bool keepDim)95 static VARP _ReduceMutable(VARP x, VARP dim, ReductionType type, bool keepDim) {
96 x = _checkNC4HW4(x);
97 flatbuffers::FlatBufferBuilder builder;
98 ReductionParamBuilder parameter(builder);
99 parameter.add_operation(type);
100 parameter.add_keepDims(keepDim);
101 auto paOffset = parameter.Finish();
102 OpBuilder opB(builder);
103 opB.add_main(paOffset.Union());
104 opB.add_type(OpType_Reduction);
105 opB.add_main_type(OpParameter_ReductionParam);
106 builder.Finish(opB.Finish());
107 // TODO: Remove Copy
108 std::shared_ptr<BufferStorage> extra(new BufferStorage);
109 extra->storage.reset(builder.ReleaseRaw(extra->allocated_size, extra->offset));
110 return Variable::create(Expr::create(extra, {x, dim}, 1));
111 }
_Eltwise(VARP a,VARP b,EltwiseType type,std::vector<float> coeff)112 static VARP _Eltwise(VARP a, VARP b, EltwiseType type, std::vector<float> coeff) {
113 std::unique_ptr<OpT> op(new OpT);
114 op->main.type = OpParameter_Eltwise;
115 op->type = OpType_Eltwise;
116 op->main.value = new EltwiseT;
117 op->main.AsEltwise()->type = type;
118 op->main.AsEltwise()->coeff = coeff;
119 return (Variable::create(Expr::create(std::move(op), {a, b})));
120 }
_EltwiseInt8(VARP x,VARP y,EltwiseType type,std::vector<int8_t> x_weight,std::vector<int32_t> x_bias,std::vector<float> x_scale,std::vector<float> x_tensorScale,std::vector<int8_t> y_weight,std::vector<int32_t> y_bias,std::vector<float> y_scale,std::vector<float> y_tensorScale,std::vector<int8_t> output_weight,std::vector<int32_t> output_bias,std::vector<float> output_scale,std::vector<float> output_tensorScale)121 static VARP _EltwiseInt8(VARP x, VARP y, EltwiseType type,
122 std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
123 std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
124 std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
125 {
126 std::unique_ptr<OpT> op(new OpT);
127 std::unique_ptr<QuantizedFloatParamT> param_x(new QuantizedFloatParamT);
128 std::unique_ptr<QuantizedFloatParamT> param_y(new QuantizedFloatParamT);
129 std::unique_ptr<QuantizedFloatParamT> param_output(new QuantizedFloatParamT);
130 auto param_op = new EltwiseInt8T;
131 param_x->weight = x_weight;
132 param_x->bias = x_bias;
133 param_x->scale = x_scale;
134 param_x->tensorScale = y_tensorScale;
135 param_y->weight = y_weight;
136 param_y->bias = y_bias;
137 param_y->scale = y_scale;
138 param_y->tensorScale = y_tensorScale;
139 param_output->weight = output_weight;
140 param_output->bias = output_bias;
141 param_output->scale = output_scale;
142 param_output->tensorScale = output_tensorScale;
143 param_op->type = type;
144 param_op->inputQuan0 = std::move(param_x);
145 param_op->inputQuan1 = std::move(param_y);
146 param_op->outputQuan = std::move(param_output);
147 op->main.type = OpParameter_EltwiseInt8;
148 op->type = OpType_EltwiseInt8;
149 op->main.value = param_op;
150 return (Variable::create(Expr::create(std::move(op), {x, y})));
151 }
152
153 /*Casts a variable to a new type.
154 Args:
155 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64, Halide_Type_Uint8
156 dtype: The destination type. The list of supported dtypes is the same as x.
157 Returns:
158 A variable with same shape as x and same type as dtype.
159 */
_Cast(VARP x,halide_type_t dtype)160 VARP _Cast(VARP x, halide_type_t dtype) {
161 std::unique_ptr<OpT> op(new OpT);
162 op->main.type = OpParameter_CastParam;
163 op->type = OpType_Cast;
164 op->main.value = new CastParamT;
165 op->main.AsCastParam()->dstT = _convertDataType(dtype);
166 return (Variable::create(Expr::create(std::move(op), {x})));
167 }
168
169 /*Computes the absolute value of a variable.
170 Given a variable of integer or floating-point values, this operation returns a variable of the same type,
171 where each element contains the absolute value of the corresponding element in the input.
172 x = MNN.const((-1.0, -2.0, 3.0), (3, ))
173 x = MNN.abs(x) # (1.0, 2.0, 3.0)
174 Args:
175 x: A variable of type Halide_Type_Int or Halide_Type_Float
176 Returns:
177 A variable the same size, type as x with absolute values.
178 */
_Abs(VARP x)179 VARP _Abs(VARP x)
180 {
181 return _Unary(x, UnaryOpOperation_ABS);
182 }
183 /*Computes numerical negative value element-wise.
184 x = MNN.const((-1.0, -2.0, 3.0), (3, ))
185 x = MNN.negative(x) #(1.0, 2.0, -3.0)
186 Args:
187 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
188 Returns:
189 A variable. Has the same type as x.
190 */
_Negative(VARP x)191 VARP _Negative(VARP x)
192 {
193 return _Unary(x, UnaryOpOperation_NEG);
194 }
195 /*Returns element-wise largest integer not greater than x.
196 Args:
197 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
198 Returns:
199 A variable. Has the same type as x.
200 */
_Floor(VARP x)201 VARP _Floor(VARP x)
202 {
203 return _Unary(x, UnaryOpOperation_FLOOR);
204 }
205 /*Returns element-wise smallest integer not less than x.
206 Args:
207 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
208 Returns:
209 A variable. Has the same type as x.
210 */
_Ceil(VARP x)211 VARP _Ceil(VARP x)
212 {
213 return _Unary(x, UnaryOpOperation_CEIL);
214 }
215
216 /*Returns element-wise rounded integer not less than x.
217 Args:
218 x: A variable. Must be Halide_Type_Float
219 Returns:
220 A variable. Halide_Type_Float.
221 */
_Round(VARP x)222 VARP _Round(VARP x) {
223 return _Unary(x, UnaryOpOperation_ROUND);
224 }
225
226 /*Computes square of x element-wise.
227 Args:
228 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
229 Returns:
230 A variable. Has the same type as x.
231 */
_Square(VARP x)232 VARP _Square(VARP x)
233 {
234 return _Unary(x, UnaryOpOperation_SQUARE);
235 }
236
237 /*Computes square root of x element-wise.
238 Args:
239 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
240 Returns:
241 A variable. Has the same type as x.
242 */
_Sqrt(VARP x)243 VARP _Sqrt(VARP x)
244 {
245 return _Unary(x, UnaryOpOperation_SQRT);
246 }
247
248 /*Computes reciprocal of square root of x element-wise.
249 Args:
250 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
251 Returns:
252 A variable. Has the same type as x.
253 */
_Rsqrt(VARP x)254 VARP _Rsqrt(VARP x)
255 {
256 return _Unary(x, UnaryOpOperation_RSQRT);
257 }
258
259 /*Computes exponential of x element-wise.
260 Args:
261 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
262 Returns:
263 A variable. Has the same type as x.
264 */
_Exp(VARP x)265 VARP _Exp(VARP x)
266 {
267 return _Unary(x, UnaryOpOperation_EXP);
268 }
269
270 /*Computes natural logarithm of x element-wise.
271 Args:
272 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
273 Returns:
274 A variable. Has the same type as x.
275 */
_Log(VARP x)276 VARP _Log(VARP x)
277 {
278 return _Unary(x, UnaryOpOperation_LOG);
279 }
280
281 /*Computes sine of x element-wise.
282 Given an input variable, this function computes sine of every element in the variable.
283 Input range is (-inf, inf) and output range is [-1,1].
284 Args:
285 x: A variable. Must be one of the following types: Halide_Type_Float
286 Returns:
287 A variable. Has the same type as x.
288 */
_Sin(VARP x)289 VARP _Sin(VARP x)
290 {
291 return _Unary(x, UnaryOpOperation_SIN);
292 }
293
294 /*Computes cos of x element-wise.
295 Given an input variable, this function computes cosine of every element in the variable.
296 Input range is (-inf, inf) and output range is [-1,1]. If input lies outside the boundary, nan is returned.
297 Args:
298 x: A variable. Must be one of the following types: Halide_Type_Float
299 Returns:
300 A variable. Has the same type as x.
301 */
_Cos(VARP x)302 VARP _Cos(VARP x)
303 {
304 return _Unary(x, UnaryOpOperation_COS);
305 }
306
307 /*Computes tan of x element-wise.
308 Given an input variable, this function computes tangent of every element in the variable.
309 Input range is (-inf, inf) and output range is (-inf, inf). If input lies outside the boundary, nan is returned.
310 Args:
311 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
312 Returns:
313 A variable. Has the same type as x.
314 */
_Tan(VARP x)315 VARP _Tan(VARP x)
316 {
317 return _Unary(x, UnaryOpOperation_TAN);
318 }
319
320 /*Computes the trignometric inverse sine of x element-wise.
321 The asin operation returns the inverse of sin, such that if y = sin(x) then, x = asin(y).
322 Note: The output of asin will lie within the invertible range of sine, i.e [-pi/2, pi/2].
323 Args:
324 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
325 Returns:
326 A variable. Has the same type as x.
327 */
_Asin(VARP x)328 VARP _Asin(VARP x)
329 {
330 return _Unary(x, UnaryOpOperation_ASIN);
331 }
332 /*Computes acos of x element-wise.
333 Args:
334 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
335 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
336 Returns:
337 A variable. Has the same type as x.
338 */
_Acos(VARP x)339 VARP _Acos(VARP x)
340 {
341 return _Unary(x, UnaryOpOperation_ACOS);
342 }
343
344 /*Computes acosh of x element-wise.
345 Args:
346 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
347 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
348 Returns:
349 A variable. Has the same type as x.
350 */
_Acosh(VARP x)351 VARP _Acosh(VARP x)
352 {
353 return _Unary(x, UnaryOpOperation_ACOSH);
354 }
355
356 /*Computes asinh of x element-wise.
357 Args:
358 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
359 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
360 Returns:
361 A variable. Has the same type as x.
362 */
_Asinh(VARP x)363 VARP _Asinh(VARP x)
364 {
365 return _Unary(x, UnaryOpOperation_ASINH);
366 }
367
368 /*Computes atanh of x element-wise.
369 Args:
370 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
371 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
372 Returns:
373 A variable. Has the same type as x.
374 */
_Atanh(VARP x)375 VARP _Atanh(VARP x)
376 {
377 return _Unary(x, UnaryOpOperation_ATANH);
378 }
379
380 /*Computes cosh of x element-wise.
381 Args:
382 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
383 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
384 Returns:
385 A variable. Has the same type as x.
386 */
_Cosh(VARP x)387 VARP _Cosh(VARP x)
388 {
389 return _Unary(x, UnaryOpOperation_COSH);
390 }
391
392 /*Computes sinh of x element-wise.
393 Args:
394 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
395 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
396 Returns:
397 A variable. Has the same type as x.
398 */
_Sinh(VARP x)399 VARP _Sinh(VARP x)
400 {
401 return _Unary(x, UnaryOpOperation_SINH);
402 }
403
404 /*Computes the Gauss error function of `x` element-wise.
405 Args:
406 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
407 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
408 Returns:
409 A variable. Has the same type as x.
410 */
_Erf(VARP x)411 VARP _Erf(VARP x)
412 {
413 return _Unary(x, UnaryOpOperation_ERF);
414 }
415
416 /*Computes the complementary error function of `x` element-wise.
417 Args:
418 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
419 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
420 Returns:
421 A variable. Has the same type as x.
422 */
_Erfc(VARP x)423 VARP _Erfc(VARP x)
424 {
425 return _Unary(x, UnaryOpOperation_ERFC);
426 }
427
428 /*Computes the inverse function for erf, for `x` element-wise.
429 Args:
430 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
431 Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
432 Returns:
433 A variable. Has the same type as x.
434 */
_Erfinv(VARP x)435 VARP _Erfinv(VARP x)
436 {
437 return _Unary(x, UnaryOpOperation_ERFINV);
438 }
439
440 /*Computes sign of x eltment-wise
441 sign(x) = 0 if x=0
442 sign(x) =-1 if x<0
443 sign(x) = 1 if x>0
444 */
_Sign(VARP x)445 VARP _Sign(VARP x) {
446 return _Unary(x, UnaryOpOperation_SIGN);
447 }
448
449 /*Computes the trignometric inverse tangent of x element-wise.
450 The atan operation returns the inverse of tan, such that if y = tan(x) then, x = atan(y).
451 Note: The output of atan will lie within the invertible range of tan, i.e (-pi/2, pi/2).
452 Args:
453 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
454 Returns:
455 A variable. Has the same type as x.
456 */
_Atan(VARP x)457 VARP _Atan(VARP x)
458 {
459 return _Unary(x, UnaryOpOperation_ATAN);
460 }
461
462 /*Computes the reciprocal of x element-wise.
463 Args:
464 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
465 Returns:
466 A variable. Has the same type as x.
467 */
_Reciprocal(VARP x)468 VARP _Reciprocal(VARP x)
469 {
470 return _Unary(x, UnaryOpOperation_RECIPROCAL);
471 }
472
473 /*Computes natural logarithm of (1 + x) element-wise.
474 Args:
475 x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
476 Returns:
477 A variable. Has the same type as x.
478 */
_Log1p(VARP x)479 VARP _Log1p(VARP x)
480 {
481 return _Unary(x, UnaryOpOperation_LOG1P);
482 }
483
484 /*Computes Gelu of x element-wise.
485 Args:
486 x: A variable. Must be one of the following types: Halide_Type_Float
487 Returns:
488 A variable. Has the same type as x .
489 */
_Gelu(VARP x)490 VARP _Gelu(VARP x)
491 {
492 return _Unary(x, UnaryOpOperation_GELU);
493 }
494
495 /*Computes hyperbolic tangent of x element-wise.
496 Given an input variable, this function computes hyperbolic tangent of every element in the variable.
497 Input range is [-inf, inf] and output range is [-1,1].
498 Args:
499 x: A variable. Must be one of the following types: Halide_Type_Float
500 Returns:
501 A variable. Has the same type as x.
502 */
_Tanh(VARP x)503 VARP _Tanh(VARP x) {
504 std::unique_ptr<OpT> op(new OpT);
505 op->type = OpType_TanH;
506 return (Variable::create(Expr::create(op.get(), {x})));
507 }
508 /*Computes sigmoid of x element-wise.
509 Args:
510 x: A variable. Must be one of the following types: Halide_Type_Float
511 Returns:
512 A variable. Has the same type as x.
513 */
_Sigmoid(VARP x)514 VARP _Sigmoid(VARP x) {
515 std::unique_ptr<OpT> op(new OpT);
516 op->type = OpType_Sigmoid;
517 return (Variable::create(Expr::create(op.get(), {x})));
518 }
519
520 /*Computes ((exponential of x) - 1) element-wise.
521 Args:
522 x: A variable. Must be one of the following types: Halide_Type_Float
523 Returns:
524 A variable. Has the same type as x.
525 */
_Expm1(VARP x)526 VARP _Expm1(VARP x) {
527 return _Unary(x, UnaryOpOperation_EXPM1);
528 }
529
530
531 /*Returns x + y element-wise.
532 Args:
533 x: A variable. Must be one of the following types:
534 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64, Halide_Type_Uint8.
535 y: A variable. Must have the same type as x.
536 Returns:
537 A variable. Has the same type as x.
538 */
_Add(VARP x,VARP y)539 VARP _Add(VARP x, VARP y) {
540 return _Binary(x, y, BinaryOpOperation_ADD);
541 }
542
543 /*Returns x - y element-wise.
544 Args:
545 x: A variable. Must be one of the following types:
546 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64, Halide_Type_Uint8.
547 y: A variable. Must have the same type as x.
548 Returns:
549 A variable. Has the same type as x.
550 */
_Subtract(VARP x,VARP y)551 VARP _Subtract(VARP x, VARP y) {
552 return _Binary(x, y, BinaryOpOperation_SUB);
553 }
554
555 /*Returns x * y element-wise.
556 Args:
557 x: A variable. Must be one of the following types:
558 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64, Halide_Type_Uint8.
559 y: A variable. Must have the same type as x.
560 Returns:
561 A variable. Has the same type as x.
562 */
_Multiply(VARP x,VARP y)563 VARP _Multiply(VARP x, VARP y) {
564 return _Binary(x, y, BinaryOpOperation_MUL);
565 }
566
567 /*Computes Python style division of x by y.
568 Args:
569 x: A variable. Must be one of the following types:
570 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64, Halide_Type_Uint8.
571 y: A variable. Must have the same type as x.
572 Returns:
573 A variable. Has the same type as x.
574 */
_Divide(VARP x,VARP y)575 VARP _Divide(VARP x, VARP y) {
576 return _Binary(x, y, BinaryOpOperation_REALDIV);
577 }
578
579 /*Computes the power of one value to another.
580 Args:
581 x: A variable. Must be one of the following types:
582 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64
583 y: A variable. Must be one of the following types:
584 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64
585 Returns:
586 A variable. Has the same type as x.
587 */
_Pow(VARP x,VARP y)588 VARP _Pow(VARP x, VARP y) {
589 return _Binary(x, y, BinaryOpOperation_POW);
590 }
591
592 /*Returns the min of x and y (i.e. x < y ? x : y) element-wise.
593 Args:
594 x: A variable. Must be one of the following types:
595 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64
596 y: A variable. Must have the same type as x.
597 Returns:
598 A variable. Has the same type as x.
599 */
_Minimum(VARP x,VARP y)600 VARP _Minimum(VARP x, VARP y) {
601 return _Binary(x, y, BinaryOpOperation_MINIMUM);
602 }
603 /*Returns the max of x and y (i.e. x > y ? x : y) element-wise.
604 Args:
605 x: A variable. Must be one of the following types:
606 Halide_Type_Int or Halide_Type_Float, Halide_Type_Int64
607 y: A variable. Must have the same type as x.
608 Returns:
609 A variable. Has the same type as x.
610 */
_Maximum(VARP x,VARP y)611 VARP _Maximum(VARP x, VARP y) {
612 return _Binary(x, y, BinaryOpOperation_MAXIMUM);
613 }
614
615 /*Adds bias to value.
616 This is (mostly) a special case of add where bias is restricted to 1-D.
617 Broadcasting is supported, so value may have any number of dimensions.
618 Unlike add, the type of bias is allowed to differ from value in the case where both types are quantized.
619 Args:
620 value: A variable with type Halide_Type_Float, Halide_Type_Int
621 bias: A 1-D variable with size matching the channel dimension of value.
622 Must be the same type as value unless value is a quantized type, in which case a different quantized type may be used.
623 Returns:
624 A variable with the same type as value.
625 */
_BiasAdd(VARP value,VARP bias)626 VARP _BiasAdd(VARP value, VARP bias) {
627 return _Add(value, bias);
628 }
629
630 /*Returns the truth value of (x > y) element-wise.
631 Args:
632 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
633 y: A variable. Must have the same type as x.
634 Returns:
635 A variable of type bool.
636 */
637
_Greater(VARP x,VARP y)638 VARP _Greater(VARP x, VARP y) {
639 return _Binary(x, y, BinaryOpOperation_GREATER);
640 }
641
642 /*Returns the truth value of (x >= y) element-wise.
643 Args:
644 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
645 y: A variable. Must have the same type as x.
646 Returns:
647 A variable of type bool.
648 */
649
_GreaterEqual(VARP x,VARP y)650 VARP _GreaterEqual(VARP x, VARP y) {
651 return _Binary(x, y, BinaryOpOperation_GREATER_EQUAL);
652 }
653
654 /*Returns the truth value of (x < y) element-wise.
655 Args:
656 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
657 y: A variable. Must have the same type as x.
658 Returns:
659 A variable of type bool.
660 */
661
_Less(VARP x,VARP y)662 VARP _Less(VARP x, VARP y) {
663 return _Binary(x, y, BinaryOpOperation_LESS);
664 }
665
666 /*Returns the value of (x // y) element-wise.
667 Args:
668 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
669 y: A variable. Must have the same type as x.
670 Returns:
671 A variable. Has the same type as x.
672 */
673
_FloorDiv(VARP x,VARP y)674 VARP _FloorDiv(VARP x, VARP y) {
675 return _Binary(x, y, BinaryOpOperation_FLOORDIV);
676 }
677
678 /*Returns the value of (x - y)(x - y) element-wise.
679 Args:
680 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
681 y: A variable. Must have the same type as x.
682 Returns:
683 A variable. Has the same type as x.
684 */
685
_SquaredDifference(VARP x,VARP y)686 VARP _SquaredDifference(VARP x, VARP y) {
687 return _Binary(x, y, BinaryOpOperation_SquaredDifference);
688 }
689
690 /*Returns the truth value of (x == y) element-wise.
691 Args:
692 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
693 y: A variable. Must have the same type as x.
694 Returns:
695 A variable of type bool.
696 */
697
_Equal(VARP x,VARP y)698 VARP _Equal(VARP x, VARP y) {
699 return _Binary(x, y, BinaryOpOperation_EQUAL);
700 }
701
702 /*Returns the truth value of (x <= y) element-wise.
703 Args:
704 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
705 y: A variable. Must have the same type as x.
706 Returns:
707 A variable of type bool.
708 */
709
_LessEqual(VARP x,VARP y)710 VARP _LessEqual(VARP x, VARP y) {
711 return _Binary(x, y, BinaryOpOperation_LESS_EQUAL);
712 }
713
714 /*Returns element-wise remainder of division
715 Args:
716 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
717 y: A variable. Must have the same type as x.
718 Returns:
719 A variable. Has the same type as x.
720 */
721
_FloorMod(VARP x,VARP y)722 VARP _FloorMod(VARP x, VARP y) {
723 return _Binary(x, y, BinaryOpOperation_FLOORMOD);
724 }
725
726 /*Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
727 Args:
728 x: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
729 y: A variable. Must have the same type as x.
730 Returns:
731 A variable. Has the same type as x.
732 */
733
_Atan2(VARP x,VARP y)734 VARP _Atan2(VARP x, VARP y) {
735 return _Binary(x, y, BinaryOpOperation_ATAN2);
736 }
737
738 /*Returns the truth value of x OR y element-wise.
739 Args:
740 x: A variable. Must be one of the following types: Halide_Type_Int
741 y: A variable. Must have the same type as x.
742 Returns:
743 A variable. Has the same type as x.
744 */
745
_LogicalOr(VARP x,VARP y)746 VARP _LogicalOr(VARP x, VARP y) {
747 return _Binary(x, y, BinaryOpOperation_LOGICALOR);
748 }
749
750 /*Returns the truth value of x != y element-wise.
751 Args:
752 x: A variable. Must be one of the following types: Halide_Type_Int
753 y: A variable. Must have the same type as x.
754 Returns:
755 A variable. Has the same type as x.
756 */
757
_NotEqual(VARP x,VARP y)758 VARP _NotEqual(VARP x, VARP y) {
759 return _Binary(x, y, BinaryOpOperation_NOTEQUAL);
760 }
761
762 /*Computes the sum of elements across dimensions of a variable
763 Reduces input_variable along the dimensions given in axis.
764 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
765 If keepdims is true, the reduced dimensions are retained with length 1.
766 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
767 Args:
768 input_variable: The variable to reduce. Should have numeric type.
769 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
770 Must be in the range [-rank(input_variable), rank(input_variable)).
771 keepdims: If true, retains reduced dimensions with length 1.
772 Returns:
773 The reduced variable, of the same dtype as the input_variable.
774 */
_ReduceSum(VARP input_variable,INTS axis,bool keepdims)775 VARP _ReduceSum(VARP input_variable, INTS axis, bool keepdims) {
776 return _Reduce(input_variable, axis, ReductionType_SUM, keepdims);
777 }
778
_ReduceSumMutable(VARP input_variable,VARP axis,bool keepdims)779 VARP _ReduceSumMutable(VARP input_variable, VARP axis, bool keepdims) {
780 return _ReduceMutable(input_variable, axis, ReductionType_SUM, keepdims);
781 }
782 //ruhuan:TODO: ReductionType_ASUM and ReductionType_SUMSQ
783
784
785
786 /*Computes the mean of elements across dimensions of a variable.
787 Reduces input_variable along the dimensions given in axis.
788 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
789 If keepdims is true, the reduced dimensions are retained with length 1.
790 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
791 Args:
792 input_variable: The variable to reduce. Should have numeric type.
793 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
794 Must be in the range [-rank(input_variable), rank(input_variable)).
795 keepdims: If true, retains reduced dimensions with length 1.
796 Returns:
797 The reduced variable, of the same dtype as the input_variable.
798 */
_ReduceMean(VARP input_variable,INTS axis,bool keepdims)799 VARP _ReduceMean(VARP input_variable, INTS axis, bool keepdims) {
800 return _Reduce(input_variable, axis, ReductionType_MEAN, keepdims);
801 }
_ReduceMeanMutable(VARP input_variable,VARP axis,bool keepdims)802 VARP _ReduceMeanMutable(VARP input_variable, VARP axis, bool keepdims) {
803 return _ReduceMutable(input_variable, axis, ReductionType_MEAN, keepdims);
804 }
805
806 /*Computes the variance of elements across dimensions of a variable.
807 Reduces input_variable along the dimensions given in axis.
808 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
809 If keepdims is true, the reduced dimensions are retained with length 1.
810 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
811 Args:
812 input_variable: The variable to reduce. Should have numeric type.
813 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
814 Must be in the range [-rank(input_variable), rank(input_variable)).
815 keepdims: If true, retains reduced dimensions with length 1.
816 Returns:
817 The reduced variable, of the same dtype as the input_variable.
818 */
_ReduceVariance(VARP input_variable,INTS axis,bool keepdims)819 VARP _ReduceVariance(VARP input_variable, INTS axis, bool keepdims) {
820 auto mean = _ReduceMean(input_variable, axis, true); // to use broadcast of subtract
821 auto variance = _ReduceMean(_Square(_Subtract(input_variable, mean)), axis, keepdims);
822 return variance;
823 }
824
825 /*Computes the maximum of elements across dimensions of a variable.
826 Reduces input_variable along the dimensions given in axis.
827 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
828 If keepdims is true, the reduced dimensions are retained with length 1.
829 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
830 Args:
831 input_variable: The variable to reduce. Should have numeric type.
832 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
833 Must be in the range [-rank(input_variable), rank(input_variable)).
834 keepdims: If true, retains reduced dimensions with length 1.
835 Returns:
836 The reduced variable, of the same dtype as the input_variable.
837 */
_ReduceMax(VARP input_variable,INTS axis,bool keepdims)838 VARP _ReduceMax(VARP input_variable, INTS axis, bool keepdims) {
839 return _Reduce(input_variable, axis, ReductionType_MAXIMUM, keepdims);
840 }
_ReduceMaxMutable(VARP input_variable,VARP axis,bool keepdims)841 VARP _ReduceMaxMutable(VARP input_variable, VARP axis, bool keepdims) {
842 return _ReduceMutable(input_variable, axis, ReductionType_MAXIMUM, keepdims);
843 }
844
845 /*Computes the minimum of elements across dimensions of a variable.
846 Reduces input_variable along the dimensions given in axis.
847 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
848 If keepdims is true, the reduced dimensions are retained with length 1.
849 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
850 Args:
851 input_variable: The variable to reduce. Should have numeric type.
852 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
853 Must be in the range [-rank(input_variable), rank(input_variable)).
854 keepdims: If true, retains reduced dimensions with length 1.
855 Returns:
856 The reduced variable, of the same dtype as the input_variable.
857 */
_ReduceMin(VARP input_variable,INTS axis,bool keepdims)858 VARP _ReduceMin(VARP input_variable, INTS axis, bool keepdims) {
859 return _Reduce(input_variable, axis, ReductionType_MINIMUM, keepdims);
860 }
_ReduceMinMutable(VARP input_variable,VARP axis,bool keepdims)861 VARP _ReduceMinMutable(VARP input_variable, VARP axis, bool keepdims) {
862 return _ReduceMutable(input_variable, axis, ReductionType_MINIMUM, keepdims);
863 }
864
865 /*Computes the product of elements across dimensions of a variable.
866 Reduces input_variable along the dimensions given in axis.
867 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
868 If keepdims is true, the reduced dimensions are retained with length 1.
869 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
870 Args:
871 input_variable: The variable to reduce. Should have numeric type.
872 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
873 Must be in the range [-rank(input_variable), rank(input_variable)).
874 keepdims: If true, retains reduced dimensions with length 1.
875 Returns:
876 The reduced variable, of the same dtype as the input_variable.
877 */
_ReduceProd(VARP input_variable,INTS axis,bool keepdims)878 VARP _ReduceProd(VARP input_variable, INTS axis, bool keepdims) {
879 return _Reduce(input_variable, axis, ReductionType_PROD, keepdims);
880 }
_ReduceProdMutable(VARP input_variable,VARP axis,bool keepdims)881 VARP _ReduceProdMutable(VARP input_variable, VARP axis, bool keepdims) {
882 return _ReduceMutable(input_variable, axis, ReductionType_PROD, keepdims);
883 }
884 /*Computes the "logical or" of elements across dimensions of a variable.
885 Reduces input_variable along the dimensions given in axis.
886 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
887 If keepdims is true, the reduced dimensions are retained with length 1.
888 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
889 Args:
890 input_variable: The variable to reduce. Should have booling type.
891 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
892 Must be in the range [-rank(input_variable), rank(input_variable)).
893 keepdims: If true, retains reduced dimensions with length 1.
894 Returns:
895 The reduced variable, of the same dtype as the input_variable.
896 */
_ReduceAny(VARP input_variable,INTS axis,bool keepdims)897 VARP _ReduceAny(VARP input_variable, INTS axis, bool keepdims) {
898 return _Reduce(input_variable, axis, ReductionType_ANY, keepdims);
899 }
_ReduceAnyMutable(VARP input_variable,VARP axis,bool keepdims)900 VARP _ReduceAnyMutable(VARP input_variable, VARP axis, bool keepdims) {
901 return _ReduceMutable(input_variable, axis, ReductionType_ANY, keepdims);
902 }
903 /*Computes the "logical and" of elements across dimensions of a variable.
904 Reduces input_variable along the dimensions given in axis.
905 Unless keepdims is true, the rank of the variable is reduced by 1 for each entry in axis.
906 If keepdims is true, the reduced dimensions are retained with length 1.
907 If axis is empty, all dimensions are reduced, and a variable with a single element is returned.
908 Args:
909 input_variable: The variable to reduce. Should have booling type.
910 axis: The dimensions to reduce. If empty(the default), reduces all dimensions.
911 Must be in the range [-rank(input_variable), rank(input_variable)).
912 keepdims: If true, retains reduced dimensions with length 1.
913 Returns:
914 The reduced variable, of the same dtype as the input_variable.
915 */
_ReduceAll(VARP input_variable,INTS axis,bool keepdims)916 VARP _ReduceAll(VARP input_variable, INTS axis, bool keepdims) {
917 return _Reduce(input_variable, axis, ReductionType_ALL, keepdims);
918 }
_ReduceAllMutable(VARP input_variable,VARP axis,bool keepdims)919 VARP _ReduceAllMutable(VARP input_variable, VARP axis, bool keepdims) {
920 return _ReduceMutable(input_variable, axis, ReductionType_ALL, keepdims);
921 }
922
923 /*Multiply the matrix "a" by the matrix "b".
924 The inputs must be two-dimensional matrices and the inner dimension of "a" (after being transposed if transpose_a is true)
925 must match the outer dimension of "b" (after being transposed if transposed_b is true).
926 Arguments:
927 a: a variable representing a matrix "a"
928 b: a variable representing a matrix "b"
929 tranposeA: If true, "a" is transposed before multiplication.
930 tranposeB: If true, "b" is transposed before multiplication.
931 Returns:
932 The product variable.
933 */
_MatMul(VARP a,VARP b,bool tranposeA,bool tranposeB)934 VARP _MatMul(VARP a, VARP b, bool tranposeA, bool tranposeB) {
935 std::unique_ptr<OpT> op(new OpT);
936 op->main.type = OpParameter_MatMul;
937 op->type = OpType_MatMul;
938 op->main.value = new MatMulT;
939 op->main.AsMatMul()->transposeA = tranposeA;
940 op->main.AsMatMul()->transposeB = tranposeB;
941 return (Variable::create(Expr::create(op.get(), {a, b})));
942 }
_Normalize(VARP x,int32_t acrossSpatial,int32_t channelShared,float eps,std::vector<float> scale)943 VARP _Normalize(VARP x, int32_t acrossSpatial, int32_t channelShared, float eps, std::vector<float> scale) {
944 std::unique_ptr<OpT> op(new OpT);
945 op->main.type = OpParameter_Normalize;
946 op->type = OpType_Normalize;
947 op->main.value = new NormalizeT;
948 op->main.AsNormalize()->acrossSpatial = acrossSpatial;
949 op->main.AsNormalize()->channelShared = channelShared;
950 op->main.AsNormalize()->eps = eps;
951 op->main.AsNormalize()->scale = scale;
952 return (Variable::create(Expr::create(std::move(op), {x})));
953 }
954 /* Compute the element-wise prod
955 Args:
956 a: A variable. Must be one of the following types: Halide_Type_Float
957 b: A variable. Must be one of the following types: Halide_Type_Float
958 coeff: blob-wise coefficients
959 Returns:
960 The prod variable.
961 */
_Prod(VARP a,VARP b,std::vector<float> coeff)962 VARP _Prod(VARP a, VARP b, std::vector<float> coeff) {
963 return _Eltwise(a, b, EltwiseType_PROD, coeff);
964 }
965 /* Compute the element-wise sum
966 Args:
967 a: A variable. Must be one of the following types: Halide_Type_Float
968 b: A variable. Must be one of the following types: Halide_Type_Float
969 coeff: blob-wise coefficients
970 Returns:
971 The sum variable.
972 */
_Sum(VARP a,VARP b,std::vector<float> coeff)973 VARP _Sum(VARP a, VARP b, std::vector<float> coeff) {
974 return _Eltwise(a, b, EltwiseType_SUM, coeff);
975 }
976 /* Compute the element-wise max
977 Args:
978 a: A variable. Must be one of the following types: Halide_Type_Float
979 b: A variable. Must be one of the following types: Halide_Type_Float
980 coeff: blob-wise coefficients
981 Returns:
982 The max variable.
983 */
_Max(VARP a,VARP b,std::vector<float> coeff)984 VARP _Max(VARP a, VARP b, std::vector<float> coeff) {
985 return _Eltwise(a, b, EltwiseType_MAXIMUM, coeff);
986 }
987 /* Compute the element-wise sub
988 Args:
989 a: A variable. Must be one of the following types: Halide_Type_Float
990 b: A variable. Must be one of the following types: Halide_Type_Float
991 coeff: blob-wise coefficients
992 Returns:
993 The sub variable.
994 */
_Sub(VARP a,VARP b,std::vector<float> coeff)995 VARP _Sub(VARP a, VARP b, std::vector<float> coeff) {
996 return _Eltwise(a, b, EltwiseType_SUB, coeff);
997 }
998
999
1000 /*Returns the index with the largest value across axes of a tensor.
1001 Args: input: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
1002 axis: A int.
1003 must be in the range -rank(input), rank(input)). Describes which axis of the input variable to reduce across.
1004 For vectors, use axis = 0.
1005 Returns:
1006 A variable of type int.
1007 */
_ArgMax(VARP input,int axis)1008 VARP _ArgMax(VARP input, int axis) {
1009 input = _checkNC4HW4(input);
1010 std::unique_ptr<OpT> op(new OpT);
1011 op->main.type = OpParameter_ArgMax;
1012 op->type = OpType_ArgMax;
1013 op->main.value = new ArgMaxT;
1014 op->main.AsArgMax()->axis = axis;
1015 op->main.AsArgMax()->outMaxVal = 0;
1016 op->main.AsArgMax()->topK = 0;
1017 op->main.AsArgMax()->softmaxThreshold = 0;
1018 return (Variable::create(Expr::create(std::move(op), {input})));
1019
1020 }
1021
1022 /*Returns the index with the smallest value across axes of a tensor.
1023 Args: input: A variable. Must be one of the following types: Halide_Type_Float, Halide_Type_Int
1024 axis: A int.
1025 must be in the range -rank(input), rank(input)). Describes which axis of the input variable to reduce across.
1026 For vectors, use axis = 0.
1027 Returns:
1028 A variable of type int.
1029 */
_ArgMin(VARP input,int axis)1030 VARP _ArgMin(VARP input, int axis) {
1031 input = _checkNC4HW4(input);
1032 std::unique_ptr<OpT> op(new OpT);
1033 op->main.type = OpParameter_ArgMax;
1034 op->type = OpType_ArgMin;
1035 op->main.value = new ArgMaxT;
1036 op->main.AsArgMax()->axis = axis;
1037 op->main.AsArgMax()->outMaxVal = 0;
1038 op->main.AsArgMax()->topK = 0;
1039 op->main.AsArgMax()->softmaxThreshold = 0;
1040 return (Variable::create(Expr::create(std::move(op), {input})));
1041 }
1042
1043 /*Multiplies slices of two variable in batches
1044 Multiplies all slices of variable x and y (each slice can be viewed as an element of a batch),
1045 and arranges the individual results in a single output variable of the same batch size.
1046 Each of the individual slices can optionally be adjointed (to adjoint a matrix means to transpose and conjugate it)
1047 before multiplication by setting the adj_x or adj_y flag to True, which are by default False.
1048 The input variable x and y are 2-D or higher with shape [..., r_x, c_x] and [..., r_y, c_y].
1049 The output variable is 2-D or higher with shape [..., r_o, c_o], where:
1050 r_o = c_x if adj_x else r_x
1051 c_o = r_y if adj_y else c_y
1052 It is computed as:
1053 output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
1054 Arguments:
1055 x: 2-D or higher with shape [..., r_x, c_x].
1056 y: 2-D or higher with shape [..., r_y, c_y].
1057 Optional:
1058 adj_x: If True, adjoint the slices of x. Defaults to False.
1059 adj_y: If True, adjoint the slices of y. Defaults to False.
1060 Returns:
1061 Output: 3-D or higher with shape [..., r_o, c_o]
1062 */
_BatchMatMul(VARP x,VARP y,bool adj_x,bool adj_y)1063 VARP _BatchMatMul(VARP x, VARP y, bool adj_x, bool adj_y) {
1064 std::unique_ptr<OpT> op(new OpT);
1065 op->main.type = OpParameter_BatchMatMulParam;
1066 op->type = OpType_BatchMatMul;
1067 op->main.value = new BatchMatMulParamT;
1068 op->main.AsBatchMatMulParam()->adjX = adj_x;
1069 op->main.AsBatchMatMulParam()->adjY = adj_y;
1070
1071 return (Variable::create(Expr::create(std::move(op), {x, y})));
1072 }
1073
1074
_UnravelIndex(VARP indices,VARP dims)1075 VARP _UnravelIndex(VARP indices, VARP dims) {
1076 std::unique_ptr<OpT> op(new OpT);
1077 op->main.type = OpParameter_NONE;
1078 op->type = OpType_UnravelIndex;
1079 op->main.value = nullptr;
1080
1081 return (Variable::create(Expr::create(std::move(op), {indices, dims})));
1082 }
1083
_ScatterNd(VARP indices,VARP updates,VARP shape)1084 VARP _ScatterNd(VARP indices, VARP updates, VARP shape) {
1085 std::unique_ptr<OpT> op(new OpT);
1086 op->main.type = OpParameter_NONE;
1087 op->type = OpType_ScatterNd;
1088 op->main.value = nullptr;
1089 return (Variable::create(Expr::create(std::move(op), {indices, updates, shape})));
1090 }
1091
_OneHot(VARP indices,VARP depth,VARP onValue,VARP offValue,int axis)1092 VARP _OneHot(VARP indices, VARP depth, VARP onValue, VARP offValue, int axis) {
1093 std::unique_ptr<OpT> op(new OpT);
1094 op->type = OpType_OneHot;
1095 op->main.type = OpParameter_OneHotParam;
1096 op->main.value = new OneHotParamT;
1097 op->main.AsOneHotParam()->axis = axis;
1098
1099 return (Variable::create(Expr::create(std::move(op), {indices, depth, onValue, offValue})));
1100 }
1101
_BroadcastTo(VARP a,VARP shape)1102 VARP _BroadcastTo(VARP a, VARP shape) {
1103 std::unique_ptr<OpT> op(new OpT);
1104 op->type = OpType_BroadcastTo;
1105 op->main.type = OpParameter_NONE;
1106 op->main.value = nullptr;
1107 return (Variable::create(Expr::create(std::move(op), {a, shape})));
1108 }
1109
_LinSpace(VARP start,VARP stop,VARP num)1110 VARP _LinSpace(VARP start, VARP stop, VARP num) {
1111 std::unique_ptr<OpT> op(new OpT);
1112 op->type = OpType_LinSpace;
1113 op->main.type = OpParameter_NONE;
1114 op->main.value = nullptr;
1115 return (Variable::create(Expr::create(std::move(op), {start, stop, num})));
1116 }
1117
_EltwiseProdInt8(VARP x,VARP y,std::vector<int8_t> x_weight,std::vector<int32_t> x_bias,std::vector<float> x_scale,std::vector<float> x_tensorScale,std::vector<int8_t> y_weight,std::vector<int32_t> y_bias,std::vector<float> y_scale,std::vector<float> y_tensorScale,std::vector<int8_t> output_weight,std::vector<int32_t> output_bias,std::vector<float> output_scale,std::vector<float> output_tensorScale)1118 VARP _EltwiseProdInt8(VARP x, VARP y,
1119 std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
1120 std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
1121 std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
1122 {
1123 return _EltwiseInt8(x, y, EltwiseType_PROD,
1124 x_weight, x_bias, x_scale, x_tensorScale,
1125 y_weight, y_bias, y_scale, y_tensorScale,
1126 output_weight, output_bias, output_scale, output_tensorScale);
1127 }
1128
_EltwiseSumInt8(VARP x,VARP y,std::vector<int8_t> x_weight,std::vector<int32_t> x_bias,std::vector<float> x_scale,std::vector<float> x_tensorScale,std::vector<int8_t> y_weight,std::vector<int32_t> y_bias,std::vector<float> y_scale,std::vector<float> y_tensorScale,std::vector<int8_t> output_weight,std::vector<int32_t> output_bias,std::vector<float> output_scale,std::vector<float> output_tensorScale)1129 VARP _EltwiseSumInt8(VARP x, VARP y,
1130 std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
1131 std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
1132 std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
1133 {
1134 return _EltwiseInt8(x, y, EltwiseType_SUM,
1135 x_weight, x_bias, x_scale, x_tensorScale,
1136 y_weight, y_bias, y_scale, y_tensorScale,
1137 output_weight, output_bias, output_scale, output_tensorScale);
1138 }
1139
_EltwiseSubInt8(VARP x,VARP y,std::vector<int8_t> x_weight,std::vector<int32_t> x_bias,std::vector<float> x_scale,std::vector<float> x_tensorScale,std::vector<int8_t> y_weight,std::vector<int32_t> y_bias,std::vector<float> y_scale,std::vector<float> y_tensorScale,std::vector<int8_t> output_weight,std::vector<int32_t> output_bias,std::vector<float> output_scale,std::vector<float> output_tensorScale)1140 VARP _EltwiseSubInt8(VARP x, VARP y,
1141 std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
1142 std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
1143 std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
1144 {
1145 return _EltwiseInt8(x, y, EltwiseType_SUB,
1146 x_weight, x_bias, x_scale, x_tensorScale,
1147 y_weight, y_bias, y_scale, y_tensorScale,
1148 output_weight, output_bias, output_scale, output_tensorScale);
1149 }
1150
_EltwiseMaxInt8(VARP x,VARP y,std::vector<int8_t> x_weight,std::vector<int32_t> x_bias,std::vector<float> x_scale,std::vector<float> x_tensorScale,std::vector<int8_t> y_weight,std::vector<int32_t> y_bias,std::vector<float> y_scale,std::vector<float> y_tensorScale,std::vector<int8_t> output_weight,std::vector<int32_t> output_bias,std::vector<float> output_scale,std::vector<float> output_tensorScale)1151 VARP _EltwiseMaxInt8(VARP x, VARP y,
1152 std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
1153 std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
1154 std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
1155 {
1156 return _EltwiseInt8(x, y, EltwiseType_MAXIMUM,
1157 x_weight, x_bias, x_scale, x_tensorScale,
1158 y_weight, y_bias, y_scale, y_tensorScale,
1159 output_weight, output_bias, output_scale, output_tensorScale);
1160 }
1161
1162 } // namespace Express
1163 } // namespace MNN
1164