1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5
6 #include <functional>
7 #include "onnx/defs/schema.h"
8 #include "onnx/defs/tensor_proto_util.h"
9 #include "onnx/defs/function.h"
10
11
12 namespace ONNX_NAMESPACE {
13
MathDocGenerator_opset13(const char * name)14 std::function<void(OpSchema&)> MathDocGenerator_opset13(const char* name) {
15 return [=](OpSchema& schema) {
16 std::string doc;
17 POPULATE_OP_DOC_STR(
18 doc = R"DOC(
19 Performs element-wise binary {name} (with Numpy-style broadcasting support).
20
21 {broadcast_doc}
22 )DOC";
23 ReplaceAll(doc, "{name}", name);
24 ReplaceAll(
25 doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
26 schema.SetDoc(doc);
27 schema.Input(0,
28 "A",
29 "First operand.",
30 "T",
31 OpSchema::Single,
32 true,
33 1,
34 OpSchema::Differentiable);
35 schema.Input(1,
36 "B",
37 "Second operand.",
38 "T",
39 OpSchema::Single,
40 true,
41 1,
42 OpSchema::Differentiable);
43 schema.Output(0,
44 "C",
45 "Result, has same element type as two inputs",
46 "T",
47 OpSchema::Single,
48 true,
49 1,
50 OpSchema::Differentiable);
51 schema.TypeConstraint(
52 "T",
53 OpSchema::numeric_types_for_math_reduction_with_bfloat(),
54 "Constrain input and output types to high-precision numeric tensors.");
55 schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
56 propagateElemTypeFromInputToOutput(ctx, 0, 0);
57 if (hasNInputShapes(ctx, 2))
58 bidirectionalBroadcastShapeInference(
59 ctx.getInputType(0)->tensor_type().shape(),
60 ctx.getInputType(1)->tensor_type().shape(),
61 *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
62 });
63 };
64 }
65
// Opset-13 registrations of the element-wise binary arithmetic operators.
// Each schema is populated by the shared MathDocGenerator_opset13 callback.
ONNX_OPERATOR_SET_SCHEMA(
    Add,
    13,
    OpSchema().FillUsing(MathDocGenerator_opset13("addition")));

ONNX_OPERATOR_SET_SCHEMA(
    Sub,
    13,
    OpSchema().FillUsing(MathDocGenerator_opset13("subtraction")));

ONNX_OPERATOR_SET_SCHEMA(
    Mul,
    13,
    OpSchema().FillUsing(MathDocGenerator_opset13("multiplication")));

ONNX_OPERATOR_SET_SCHEMA(
    Div,
    13,
    OpSchema().FillUsing(MathDocGenerator_opset13("division")));
85
MathDocGenerator_opset_7(const char * name)86 std::function<void(OpSchema&)> MathDocGenerator_opset_7(const char* name) {
87 return [=](OpSchema& schema) {
88 std::string doc;
89 POPULATE_OP_DOC_STR(
90 doc = R"DOC(
91 Performs element-wise binary {name} (with Numpy-style broadcasting support).
92
93 {broadcast_doc}
94 )DOC";
95 ReplaceAll(doc, "{name}", name);
96 ReplaceAll(
97 doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
98 schema.SetDoc(doc);
99 schema.Input(0, "A", "First operand.", "T");
100 schema.Input(1, "B", "Second operand.", "T");
101 schema.Output(0, "C", "Result, has same element type as two inputs", "T");
102 schema.TypeConstraint(
103 "T",
104 OpSchema::numeric_types_for_math_reduction(),
105 "Constrain input and output types to high-precision numeric tensors.");
106 schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
107 propagateElemTypeFromInputToOutput(ctx, 0, 0);
108 if (hasNInputShapes(ctx, 2))
109 bidirectionalBroadcastShapeInference(
110 ctx.getInputType(0)->tensor_type().shape(),
111 ctx.getInputType(1)->tensor_type().shape(),
112 *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
113 });
114 };
115 }
116
// Opset-7 registrations of the element-wise binary arithmetic operators,
// populated by the shared MathDocGenerator_opset_7 callback.
ONNX_OPERATOR_SET_SCHEMA(
    Add,
    7,
    OpSchema().FillUsing(MathDocGenerator_opset_7("addition")));

ONNX_OPERATOR_SET_SCHEMA(
    Sub,
    7,
    OpSchema().FillUsing(MathDocGenerator_opset_7("subtraction")));

ONNX_OPERATOR_SET_SCHEMA(
    Mul,
    7,
    OpSchema().FillUsing(MathDocGenerator_opset_7("multiplication")));

ONNX_OPERATOR_SET_SCHEMA(
    Div,
    7,
    OpSchema().FillUsing(MathDocGenerator_opset_7("division")));
136
SoftmaxFamilyDocGenerator_opset_11(const char * name,const char * description)137 std::function<void(OpSchema&)> SoftmaxFamilyDocGenerator_opset_11(
138 const char* name,
139 const char* description) {
140 return [=](OpSchema& schema) {
141 std::string doc;
142 POPULATE_OP_DOC_STR(doc = R"DOC(
143 The operator computes the {name} ({description}) values for each layer in the batch
144 of the given input.
145
146 The input does not need to explicitly be a 2D vector; rather, it will be
147 coerced into one. For an arbitrary n-dimensional tensor
148 input \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is
149 the axis provided, then input will be coerced into a 2-dimensional tensor with
150 dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
151 case where axis=1, this means the input tensor will be coerced into a 2D tensor
152 of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
153 In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D.
154 Each of these dimensions must be matched correctly, or else the operator
155 will throw errors. The output tensor has the same shape
156 and contains the {name} values of the corresponding input.
157 )DOC";
158 ReplaceAll(doc, "{name}", name);
159 ReplaceAll(doc, "{description}", description););
160 schema.SetDoc(doc);
161 schema.Attr(
162 "axis",
163 "Describes the axis of the inputs when coerced "
164 "to 2D; defaults to one because the 0th axis most likely describes "
165 "the batch_size. Negative value means counting dimensions "
166 "from the back. Accepted range is [-r, r-1] where r = rank(input).",
167 AttributeProto::INT,
168 static_cast<int64_t>(1));
169 schema.Input(
170 0,
171 "input",
172 "The input tensor that's coerced into a 2D matrix of size (NxD) "
173 "as described above.",
174 "T");
175 schema.Output(
176 0,
177 "output",
178 "The output values with the same "
179 "shape as input tensor (the original size without coercion).",
180 "T");
181 schema.TypeConstraint(
182 "T",
183 {"tensor(float16)", "tensor(float)", "tensor(double)"},
184 "Constrain input and output types to float tensors.");
185 schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
186 // Type inference
187 propagateElemTypeFromInputToOutput(ctx, 0, 0);
188
189 // Shape inference starts
190 if (!hasNInputShapes(ctx, 1)) {
191 return;
192 }
193
194 // Validate the value of 'axis'
195 const TensorShapeProto& input_shape =
196 ctx.getInputType(0)->tensor_type().shape();
197 int r = input_shape.dim_size();
198 int axis = static_cast<int>(getAttribute(ctx, "axis", 1));
199 if (axis < -r || axis >= r) {
200 fail_shape_inference(
201 "'axis' must be in [",
202 -r,
203 " , ",
204 (r - 1),
205 "]. Its actual value is: ",
206 axis);
207 }
208
209 // Shape inference
210 propagateShapeFromInputToOutput(ctx, 0, 0);
211 });
212 };
213 }
214
// Opset-11 registrations of the softmax family; each schema is populated by
// the shared SoftmaxFamilyDocGenerator_opset_11 callback.
ONNX_OPERATOR_SET_SCHEMA(
    Softmax,
    11,
    OpSchema().FillUsing(SoftmaxFamilyDocGenerator_opset_11(
        "softmax",
        "normalized exponential")));

ONNX_OPERATOR_SET_SCHEMA(
    LogSoftmax,
    11,
    OpSchema().FillUsing(
        SoftmaxFamilyDocGenerator_opset_11("logsoftmax", "log of softmax")));

ONNX_OPERATOR_SET_SCHEMA(
    Hardmax,
    11,
    OpSchema().FillUsing(SoftmaxFamilyDocGenerator_opset_11(
        "hardmax",
        "1 for the first maximum value, and 0 for all others")));
234
// Specification text for the opset-10 Mod operator (registered below).
static const char* Mod_doc_10 = R"DOC(
Performs element-wise binary modulus (with Numpy-style broadcasting support).
The sign of the remainder is the same as that of the Divisor.

Mod operator can also behave like C fmod() or numpy.fmod. In this case, the sign of the remainder however, will be the same as the Dividend
(in contrast to integer mod). To force a behavior like numpy.fmod() an 'fmod' Attribute is provided.
This attribute is set to 0 by default causing the behavior to be like integer mod.
Setting this attribute to 1 causes the remainder to be calculated similar to that of numpy.fmod().

If the input type is floating point, then `fmod` attribute must be set to 1.

In case of dividend being zero, the results will be platform dependent.

This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md).
)DOC";

// Opset-10 Mod: element-wise remainder over all numeric tensor types, with
// an 'fmod' attribute selecting between integer-mod and fmod semantics.
ONNX_OPERATOR_SET_SCHEMA(
    Mod,
    10,
    OpSchema()
        .SetDoc(Mod_doc_10)
        .Attr(
            "fmod",
            "Whether the operator should behave like fmod (default=0 meaning it will do integer mods); Set this to 1 to force fmod treatment",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Input(0, "A", "Dividend tensor", "T")
        .Input(1, "B", "Divisor tensor", "T")
        .Output(0, "C", "Remainder tensor", "T")
        .TypeConstraint(
            "T",
            OpSchema::all_numeric_types(),
            "Constrain input and output types to high-precision numeric tensors.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Element type follows input A; the shape is the bidirectional
          // broadcast of A and B when both shapes are known.
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2))
            bidirectionalBroadcastShapeInference(
                ctx.getInputType(0)->tensor_type().shape(),
                ctx.getInputType(1)->tensor_type().shape(),
                *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
        }));
276
277 static const char* Neg_ver6_doc = R"DOC(
278 Neg takes one input data (Tensor<T>) and produces one output data
279 (Tensor<T>) where each element flipped sign, y = -x, is applied to
280 the tensor elementwise.
281 )DOC";
282
283 ONNX_OPERATOR_SET_SCHEMA(
284 Neg,
285 6,
286 OpSchema()
287 .SetDoc(Neg_ver6_doc)
288 .Input(0, "X", "Input tensor", "T")
289 .Output(0, "Y", "Output tensor", "T")
290 .TypeConstraint(
291 "T",
292 {"tensor(float)",
293 "tensor(int32)",
294 "tensor(int8)",
295 "tensor(int16)",
296 "tensor(int64)",
297 "tensor(float16)",
298 "tensor(double)"},
299 "Constrain input and output types to signed numeric tensors.")
300 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
301
302 static const char* Abs_ver6_doc = R"DOC(
303 Absolute takes one input data (Tensor<T>) and produces one output data
304 (Tensor<T>) where the absolute is, y = abs(x), is applied to
305 the tensor elementwise.
306 )DOC";
307
308 ONNX_OPERATOR_SET_SCHEMA(
309 Abs,
310 6,
311 OpSchema()
312 .SetDoc(Abs_ver6_doc)
313 .Input(0, "X", "Input tensor", "T")
314 .Output(0, "Y", "Output tensor", "T")
315 .TypeConstraint(
316 "T",
317 OpSchema::all_numeric_types(),
318 "Constrain input and output types to all numeric tensors.")
319 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
320
321 static const char* Reciprocal_ver6_doc = R"DOC(
322 Reciprocal takes one input data (Tensor<T>) and produces one output data
323 (Tensor<T>) where the reciprocal is, y = 1/x, is applied to
324 the tensor elementwise.
325 )DOC";
326
327 ONNX_OPERATOR_SET_SCHEMA(
328 Reciprocal,
329 6,
330 OpSchema()
331 .SetDoc(Reciprocal_ver6_doc)
332 .Input(0, "X", "Input tensor", "T")
333 .Output(0, "Y", "Output tensor", "T")
334 .TypeConstraint(
335 "T",
336 {"tensor(float16)", "tensor(float)", "tensor(double)"},
337 "Constrain input and output types to float tensors.")
338 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
339
340 static const char* Floor_ver6_doc = R"DOC(
341 Floor takes one input data (Tensor<T>) and produces one output data
342 (Tensor<T>) where the floor is, y = floor(x), is applied to
343 the tensor elementwise.
344 )DOC";
345
346 ONNX_OPERATOR_SET_SCHEMA(
347 Floor,
348 6,
349 OpSchema()
350 .SetDoc(Floor_ver6_doc)
351 .Input(0, "X", "Input tensor", "T")
352 .Output(0, "Y", "Output tensor", "T")
353 .TypeConstraint(
354 "T",
355 {"tensor(float16)", "tensor(float)", "tensor(double)"},
356 "Constrain input and output types to float tensors.")
357 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
358
359 static const char* Ceil_ver6_doc = R"DOC(
360 Ceil takes one input data (Tensor<T>) and produces one output data
361 (Tensor<T>) where the ceil is, y = ceil(x), is applied to
362 the tensor elementwise.
363 )DOC";
364
365 ONNX_OPERATOR_SET_SCHEMA(
366 Ceil,
367 6,
368 OpSchema()
369 .SetDoc(Ceil_ver6_doc)
370 .Input(0, "X", "Input tensor", "T")
371 .Output(0, "Y", "Output tensor", "T")
372 .TypeConstraint(
373 "T",
374 {"tensor(float16)", "tensor(float)", "tensor(double)"},
375 "Constrain input and output types to float tensors.")
376 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
377
378 static const char* Sqrt_ver6_doc = R"DOC(
379 Square root takes one input data (Tensor<T>) and produces one output data
380 (Tensor<T>) where the square root is, y = x^0.5, is applied to
381 the tensor elementwise. If x is negative, then it will return NaN.
382 )DOC";
383
384 ONNX_OPERATOR_SET_SCHEMA(
385 Sqrt,
386 6,
387 OpSchema()
388 .SetDoc(Sqrt_ver6_doc)
389 .Input(0, "X", "Input tensor", "T")
390 .Output(0, "Y", "Output tensor", "T")
391 .TypeConstraint(
392 "T",
393 {"tensor(float16)", "tensor(float)", "tensor(double)"},
394 "Constrain input and output types to float tensors.")
395 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
396
397 static const char* Relu_ver6_doc = R"DOC(
398 Relu takes one input data (Tensor<T>) and produces one output data
399 (Tensor<T>) where the rectified linear function, y = max(0, x), is applied to
400 the tensor elementwise.
401 )DOC";
402
403 ONNX_OPERATOR_SET_SCHEMA(
404 Relu,
405 6,
406 OpSchema()
407 .SetDoc(Relu_ver6_doc)
408 .Input(0, "X", "Input tensor", "T")
409 .Output(0, "Y", "Output tensor", "T")
410 .TypeConstraint(
411 "T",
412 {"tensor(float16)", "tensor(float)", "tensor(double)"},
413 "Constrain input and output types to float tensors.")
414 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
415
static const char* Relu_ver13_doc = R"DOC(
Relu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the rectified linear function, y = max(0, x), is applied to
the tensor elementwise.
)DOC";

// Opset-13 Relu: same element-wise contract as opset 6, but the type
// constraint adds bfloat16 and input/output are marked differentiable.
ONNX_OPERATOR_SET_SCHEMA(
    Relu,
    13,
    OpSchema()
        .SetDoc(Relu_ver13_doc)
        .Input(0,
            "X",
            "Input tensor",
            "T",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        .Output(0,
            "Y",
            "Output tensor",
            "T",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        .TypeConstraint(
            "T",
            {"tensor(float16)",
             "tensor(float)",
             "tensor(double)",
             "tensor(bfloat16)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
451
452 static const char* Exp_ver6_doc = R"DOC(
453 Calculates the exponential of the given input tensor, element-wise.
454 )DOC";
455
456 ONNX_OPERATOR_SET_SCHEMA(
457 Exp,
458 6,
459 OpSchema()
460 .SetDoc(Exp_ver6_doc)
461 .Input(0, "input", "Input tensor", "T")
462 .Output(
463 0,
464 "output",
465 "The exponential of the input tensor computed "
466 "element-wise",
467 "T")
468 .TypeConstraint(
469 "T",
470 {"tensor(float16)", "tensor(float)", "tensor(double)"},
471 "Constrain input and output types to float tensors.")
472 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
473
474 static const char* Log_ver6_doc = R"DOC(
475 Calculates the natural log of the given input tensor, element-wise.
476 )DOC";
477
478 ONNX_OPERATOR_SET_SCHEMA(
479 Log,
480 6,
481 OpSchema()
482 .SetDoc(Log_ver6_doc)
483 .Input(0, "input", "Input tensor", "T")
484 .Output(
485 0,
486 "output",
487 "The natural log of the input tensor computed "
488 "element-wise",
489 "T")
490 .TypeConstraint(
491 "T",
492 {"tensor(float16)", "tensor(float)", "tensor(double)"},
493 "Constrain input and output types to float tensors.")
494 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
495
496 static const char* Tanh_ver6_doc = R"DOC(
497 Calculates the hyperbolic tangent of the given input tensor element-wise.
498 )DOC";
499
500 ONNX_OPERATOR_SET_SCHEMA(
501 Tanh,
502 6,
503 OpSchema()
504 .SetDoc(Tanh_ver6_doc)
505 .Input(0, "input", "Input tensor", "T")
506 .Output(
507 0,
508 "output",
509 "The hyperbolic tangent values of the input tensor "
510 "computed element-wise",
511 "T")
512 .TypeConstraint(
513 "T",
514 {"tensor(float16)", "tensor(float)", "tensor(double)"},
515 "Constrain input and output types to float tensors.")
516 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
517
static const char* Pow_ver13_doc = R"DOC(
Pow takes input data (Tensor<T>) and exponent Tensor, and
produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
is applied to the data tensor elementwise.
)DOC";

// Opset-13 Pow: base X uses type variable "T" (which also determines the
// output type), while exponent Y uses the wider "T1"; the output shape is
// the bidirectional broadcast of the two inputs.
ONNX_OPERATOR_SET_SCHEMA(
    Pow,
    13,
    OpSchema()
        .SetDoc(GET_OP_DOC_STR(
            std::string(Pow_ver13_doc) + GenerateBroadcastingDocMul()))
        .Input(0,
            "X",
            "First operand, base of the exponent.",
            "T",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        .Input(1,
            "Y",
            "Second operand, power of the exponent.",
            "T1",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        .Output(0,
            "Z",
            "Output tensor",
            "T",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        .TypeConstraint(
            "T",
            {"tensor(int32)",
             "tensor(int64)",
             "tensor(float16)",
             "tensor(float)",
             "tensor(double)",
             "tensor(bfloat16)"},
            "Constrain input X and output types to float/int tensors.")
        .TypeConstraint(
            "T1",
            {"tensor(uint8)",
             "tensor(uint16)",
             "tensor(uint32)",
             "tensor(uint64)",
             "tensor(int8)",
             "tensor(int16)",
             "tensor(int32)",
             "tensor(int64)",
             "tensor(float16)",
             "tensor(float)",
             "tensor(double)"},
            "Constrain input Y types to float/int tensors.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Output element type follows the base X (input 0), not the
          // exponent; shape is the broadcast of both inputs when known.
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2))
            bidirectionalBroadcastShapeInference(
                ctx.getInputType(0)->tensor_type().shape(),
                ctx.getInputType(1)->tensor_type().shape(),
                *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
        }));
585
static const char* Pow_ver12_doc = R"DOC(
Pow takes input data (Tensor<T>) and exponent Tensor, and
produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
is applied to the data tensor elementwise.
)DOC";

// Opset-12 Pow: like opset 13 but without bfloat16 in "T" and without
// differentiability annotations on the inputs/output.
ONNX_OPERATOR_SET_SCHEMA(
    Pow,
    12,
    OpSchema()
        .SetDoc(GET_OP_DOC_STR(
            std::string(Pow_ver12_doc) + GenerateBroadcastingDocMul()))
        .Input(0, "X", "First operand, base of the exponent.", "T")
        .Input(1, "Y", "Second operand, power of the exponent.", "T1")
        .Output(0, "Z", "Output tensor.", "T")
        .TypeConstraint(
            "T",
            {"tensor(int32)",
             "tensor(int64)",
             "tensor(float16)",
             "tensor(float)",
             "tensor(double)"},
            "Constrain input X and output types to float/int tensors.")
        .TypeConstraint(
            "T1",
            {"tensor(uint8)",
             "tensor(uint16)",
             "tensor(uint32)",
             "tensor(uint64)",
             "tensor(int8)",
             "tensor(int16)",
             "tensor(int32)",
             "tensor(int64)",
             "tensor(float16)",
             "tensor(float)",
             "tensor(double)"},
            "Constrain input Y types to float/int tensors.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Output element type follows the base X (input 0); shape is the
          // bidirectional broadcast of both inputs when known.
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2))
            bidirectionalBroadcastShapeInference(
                ctx.getInputType(0)->tensor_type().shape(),
                ctx.getInputType(1)->tensor_type().shape(),
                *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
        }));
631
632 static const char* Sigmoid_ver6_doc = R"DOC(
633 Sigmoid takes one input data (Tensor<T>) and produces one output data
634 (Tensor<T>) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the
635 tensor elementwise.
636 )DOC";
637
638 ONNX_OPERATOR_SET_SCHEMA(
639 Sigmoid,
640 6,
641 OpSchema()
642 .SetDoc(Sigmoid_ver6_doc)
643 .Input(0, "X", "Input tensor", "T")
644 .Output(0, "Y", "Output tensor", "T")
645 .TypeConstraint(
646 "T",
647 {"tensor(float16)", "tensor(float)", "tensor(double)"},
648 "Constrain input and output types to float tensors.")
649 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
650
651 // Generate opschema for element-wise ops. Leaves type constraint "T"
652 // unspecified.
ElementwiseMultiOpDocGenerator_opset8(const char * name)653 std::function<void(OpSchema&)> ElementwiseMultiOpDocGenerator_opset8(
654 const char* name) {
655 return [=](OpSchema& schema) {
656 std::string doc;
657 POPULATE_OP_DOC_STR(
658 doc = R"DOC(
659 Element-wise {name} of each of the input tensors (with Numpy-style broadcasting support).
660 All inputs and outputs must have the same data type.
661 {broadcast_doc}
662 )DOC";
663 ReplaceAll(doc, "{name}", name);
664 ReplaceAll(
665 doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
666 schema.SetDoc(doc);
667 schema.Input(
668 0,
669 "data_0",
670 "List of tensors for " + std::string(name) + ".",
671 "T",
672 OpSchema::Variadic);
673 schema.Output(0, name, "Output tensor.", "T");
674 schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
675 propagateElemTypeFromInputToOutput(ctx, 0, 0);
676 int num_inputs = static_cast<int>(ctx.getNumInputs());
677 std::vector<const TensorShapeProto*> shapes;
678 for (int i = 0; i < num_inputs; ++i) {
679 auto input_type = ctx.getInputType(i);
680 if (nullptr == input_type || !input_type->has_tensor_type() ||
681 !input_type->tensor_type().has_shape()) {
682 return;
683 }
684 shapes.push_back(&input_type->tensor_type().shape());
685 }
686
687 multidirectionalBroadcastShapeInference(
688 shapes,
689 *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
690 });
691 };
692 }
693
// Variadic element-wise reductions. Max/Min (opset 12) allow all numeric
// types; Sum/Mean (opset 8) are restricted to float tensors. The shared
// generator leaves "T" open, so each registration adds its own constraint.
ONNX_OPERATOR_SET_SCHEMA(
    Max,
    12,
    OpSchema()
        .FillUsing(ElementwiseMultiOpDocGenerator_opset8("max"))
        .TypeConstraint(
            "T",
            OpSchema::all_numeric_types(),
            "Constrain input and output types to numeric tensors."));

ONNX_OPERATOR_SET_SCHEMA(
    Min,
    12,
    OpSchema()
        .FillUsing(ElementwiseMultiOpDocGenerator_opset8("min"))
        .TypeConstraint(
            "T",
            OpSchema::all_numeric_types(),
            "Constrain input and output types to numeric tensors."));

ONNX_OPERATOR_SET_SCHEMA(
    Sum,
    8,
    OpSchema()
        .FillUsing(ElementwiseMultiOpDocGenerator_opset8("sum"))
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));

ONNX_OPERATOR_SET_SCHEMA(
    Mean,
    8,
    OpSchema()
        .FillUsing(ElementwiseMultiOpDocGenerator_opset8("mean"))
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
733
static const char* Clip_ver12_doc = R"DOC(
Clip operator limits the given input within an interval. The interval is
specified by the inputs 'min' and 'max'. They default to
numeric_limits::lowest() and numeric_limits::max(), respectively.
)DOC";

// Opset-12 Clip: min/max are optional scalar inputs (not attributes as in
// earlier opsets); the output keeps the input's type and shape.
ONNX_OPERATOR_SET_SCHEMA(
    Clip,
    12,
    OpSchema()
        .SetDoc(Clip_ver12_doc)
        .Input(0, "input", "Input tensor whose elements to be clipped", "T")
        .Input(
            1,
            "min",
            "Minimum value, under which element is replaced by min. "
            "It must be a scalar(tensor of empty shape).",
            "T",
            OpSchema::Optional)
        .Input(
            2,
            "max",
            "Maximum value, above which element is replaced by max. "
            "It must be a scalar(tensor of empty shape).",
            "T",
            OpSchema::Optional)
        .Output(0, "output", "Output tensor with clipped input elements", "T")
        .TypeConstraint(
            "T",
            OpSchema::all_numeric_types(),
            "Constrain input and output types to all numeric tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
766
static const char* Gemm_ver11_doc = R"DOC(General Matrix multiplication:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3

A' = transpose(A) if transA else A

B' = transpose(B) if transB else B

Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N),
and output tensor Y has shape (M, N). A will be transposed before doing the
computation if attribute transA is non-zero, same for B and transB.
)DOC";

// Opset-11 Gemm: Y = alpha * A' * B' + beta * C, with optional transposes
// controlled by the transA/transB attributes and optional bias input C.
ONNX_OPERATOR_SET_SCHEMA(
    Gemm,
    11,
    OpSchema()
        .SetDoc(GET_OP_DOC_STR(
            std::string(Gemm_ver11_doc) +
            GenerateBroadcastingDocUni("tensor C", "tensor A * B") + "\n" +
            GenerateOptionalArgumentsDoc()))
        .Input(
            0,
            "A",
            "Input tensor A. "
            "The shape of A should be (M, K) if transA is 0, "
            "or (K, M) if transA is non-zero.",
            "T")
        .Input(
            1,
            "B",
            "Input tensor B. "
            "The shape of B should be (K, N) if transB is 0, "
            "or (N, K) if transB is non-zero.",
            "T")
        .Input(
            2,
            "C",
            "Optional input tensor C. "
            "If not specified, the computation is done as if C is a scalar 0. "
            "The shape of C should be unidirectional broadcastable to (M, N).",
            "T",
            OpSchema::Optional)
        .Output(0, "Y", "Output tensor of shape (M, N).", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)",
             "tensor(float)",
             "tensor(double)",
             "tensor(uint32)",
             "tensor(uint64)",
             "tensor(int32)",
             "tensor(int64)"},
            "Constrain input and output types to float/int tensors.")
        .Attr(
            "transA",
            "Whether A should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "transB",
            "Whether B should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "alpha",
            "Scalar multiplier for the product of input tensors A * B.",
            AttributeProto::FLOAT,
            1.0f)
        .Attr(
            "beta",
            "Scalar multiplier for input tensor C.",
            AttributeProto::FLOAT,
            1.0f)
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Output element type follows input A.
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          // Shape inference needs both A and B shapes (C only affects
          // values, not the (M, N) output shape).
          if (hasNInputShapes(ctx, 2)) {
            // Missing transA/transB attributes default to "not transposed".
            auto transAAttr = ctx.getAttribute("transA");
            bool transA =
                transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
            auto transBAttr = ctx.getAttribute("transB");
            bool transB =
                transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
            auto& first_input_shape = getInputShape(ctx, 0);
            auto& second_input_shape = getInputShape(ctx, 1);
            // Both operands must be matrices (rank 2).
            if (first_input_shape.dim_size() != 2) {
              fail_shape_inference("First input does not have rank 2");
            }
            if (second_input_shape.dim_size() != 2) {
              fail_shape_inference("Second input does not have rank 2");
            }
            // Output is (M, N): M from A (or A' when transA), N from B
            // (or B' when transB).
            updateOutputShape(
                ctx,
                0,
                {first_input_shape.dim(transA ? 1 : 0),
                 second_input_shape.dim(transB ? 0 : 1)});
          }
        }));
865
// Shape inference shared by opset-9 MatMul. Implements numpy.matmul
// semantics: 1-D operands are temporarily promoted to rank 2 (a leading or
// trailing 1 is inserted), the two innermost dimensions are checked for
// matrix-multiply compatibility, any leading "batch" dimensions are
// bidirectionally broadcast, and the promoted dimensions are dropped again
// from the result. Writes the inferred shape into output 0 of `ctx`.
void matmulShapeInference_opset_9(
    ONNX_NAMESPACE::InferenceContext& ctx,
    int input1Idx,
    int input2Idx) {
  // Nothing to infer unless both operand shapes are known.
  if (!hasInputShape(ctx, input1Idx) || !hasInputShape(ctx, input2Idx)) {
    return;
  }

  const auto shape0 = ctx.getInputType(input1Idx)->tensor_type().shape();
  const auto shape1 = ctx.getInputType(input2Idx)->tensor_type().shape();

  // Scalars (rank 0) are not valid matmul operands.
  if (shape0.dim_size() == 0 || shape1.dim_size() == 0) {
    fail_shape_inference("Input tensors of wrong rank (0).");
  }

  ONNX_NAMESPACE::TensorShapeProto shapeL, shapeR;

  // First promote each shape to at least rank-2. This logic is
  // specific to matmul, not generic broadcasting.
  // A 1-D left operand becomes (1, d); a 1-D right operand becomes (d, 1).
  {
    if (shape0.dim_size() == 1) {
      shapeL.add_dim()->set_dim_value(1);
      *shapeL.add_dim() = shape0.dim(0);
    } else {
      *shapeL.mutable_dim() = shape0.dim();
    }
    if (shape1.dim_size() == 1) {
      *shapeR.add_dim() = shape1.dim(0);
      shapeR.add_dim()->set_dim_value(1);
    } else {
      *shapeR.mutable_dim() = shape1.dim();
    }
  }

  // Check for compatible matrix multiply dimensions
  // (inner dim of L must match the second-to-last dim of R when both are
  // statically known; symbolic dims are not checked here).
  {
    auto dimL = shapeL.dim(shapeL.dim_size() - 1);
    auto dimR = shapeR.dim(shapeR.dim_size() - 2);
    if (dimL.has_dim_value() && dimR.has_dim_value() &&
        dimL.dim_value() != dimR.dim_value()) {
      fail_shape_inference("Incompatible dimensions for matrix multiplication");
    }
  }

  ONNX_NAMESPACE::TensorShapeProto resultShape;

  // Now call out to generic multidimensional broadcasting for
  // the broadcastable prefixes (everything except the last two dims).
  {
    ONNX_NAMESPACE::TensorShapeProto prefixShapeL, prefixShapeR;
    for (int i = 0; i < shapeL.dim_size() - 2; ++i) {
      *prefixShapeL.add_dim() = shapeL.dim(i);
    }
    for (int i = 0; i < shapeR.dim_size() - 2; ++i) {
      *prefixShapeR.add_dim() = shapeR.dim(i);
    }
    bidirectionalBroadcastShapeInference(
        prefixShapeL, prefixShapeR, resultShape);
  }

  // Back to matmul-specific. Add the trailing dimensions back in.
  // A dimension that was synthesized for a 1-D operand above is dropped
  // again, matching numpy.matmul's treatment of vector operands.
  {
    if (shape0.dim_size() != 1) {
      *resultShape.add_dim() = shapeL.dim(shapeL.dim_size() - 2);
    }
    if (shape1.dim_size() != 1) {
      *resultShape.add_dim() = shapeR.dim(shapeR.dim_size() - 1);
    }
  }

  *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape;
}
938
static const char* MatMul_ver9_doc = R"DOC(
Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
)DOC";

// Opset-9 MatMul: shape inference is delegated to
// matmulShapeInference_opset_9 above.
ONNX_OPERATOR_SET_SCHEMA(
    MatMul,
    9,
    OpSchema()
        .Input(0, "A", "N-dimensional matrix A", "T")
        .Input(1, "B", "N-dimensional matrix B", "T")
        .Output(0, "Y", "Matrix multiply results from A * B", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)",
             "tensor(float)",
             "tensor(double)",
             "tensor(uint32)",
             "tensor(uint64)",
             "tensor(int32)",
             "tensor(int64)"},
            "Constrain input and output types to float/int tensors.")
        .SetDoc(MatMul_ver9_doc)
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          matmulShapeInference_opset_9(ctx, 0, 1);
        }));
965
static const char* Expand_ver8_doc = R"DOC(
Broadcast the input tensor following the given shape and the broadcast rule.
The broadcast rule is similar to numpy.array(input) * numpy.ones(shape):
Dimensions are right alignment;
Two corresponding dimension must have the same value, or one of them is equal to 1.
Also, this operator is similar to numpy.broadcast_to(input, shape),
but the major difference is numpy.broadcast_to() does not allow shape to be smaller than input.size().
It is possible that the output.shape is not equal to shape, when some dimensions in shape is equal to 1,
or the shape.ndim < input.shape.ndim.
)DOC";

// Opset-8 Expand: broadcasts 'input' against the target given by the
// 'shape' input tensor.
ONNX_OPERATOR_SET_SCHEMA(
    Expand,
    8,
    OpSchema()
        .SetDoc(Expand_ver8_doc)
        .Input(0, "input", "Input tensor", "T")
        .Input(
            1,
            "shape",
            "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule",
            "tensor(int64)")
        .Output(0, "output", "Output tensor", "T")
        .TypeConstraint(
            "T",
            OpSchema::all_tensor_types(),
            "Constrain input and output types to all tensors.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Type inference
          propagateElemTypeFromInputToOutput(ctx, 0, 0);

          // Shape inference
          // For shape inference (and rank inference), we need both input shape
          // and values in 'shape' tensor — the latter is only available when
          // 'shape' is a constant initializer (getInputData returns non-null).
          const auto* shape_initializer = ctx.getInputData(1);
          if (hasNInputShapes(ctx, 2) && nullptr != shape_initializer) {
            const auto& shape_initializer_shape =
                ctx.getInputType(1)->tensor_type().shape();
            // 'shape' must be a 1-D INT64 tensor per the schema above.
            if (shape_initializer_shape.dim_size() != 1 ||
                shape_initializer->data_type() != TensorProto::INT64) {
              fail_shape_inference("'shape' input must be 1D tensor of type INT64");
            }

            const auto& input_shape =
                ctx.getInputType(0)->tensor_type().shape();
            const auto& shape_data = ParseData<int64_t>(shape_initializer);

            // Materialize the target shape from the initializer's values.
            TensorShapeProto second_shape;
            for (const auto& e : shape_data) {
              auto* dim = second_shape.add_dim();
              dim->set_dim_value(e);
            }

            // Output shape is the bidirectional broadcast of the input
            // shape and the requested target shape.
            bidirectionalBroadcastShapeInference(
                input_shape, second_shape, *getOutputShape(ctx, 0));
          }
          return;
        }));
1024
// Spec text for Sign-9 (frozen for this opset version).
static const char* Sign_ver9_doc = R"DOC(
Calculate the sign of the given input tensor element-wise.
If input > 0, output 1. if input < 0, output -1. if input == 0, output 0.
)DOC";

// Registration of Sign for opset 9. Output shape/type equal the input's,
// so inference simply forwards them.
ONNX_OPERATOR_SET_SCHEMA(
    Sign,
    9,
    OpSchema()
        .SetDoc(Sign_ver9_doc)
        .Input(0, "input", "Input tensor", "T")
        .Output(
            0,
            "output",
            "The sign of the input tensor "
            "computed element-wise. It has the same shape and type of the input.",
            "T")
        .TypeConstraint(
            "T",
            OpSchema::all_numeric_types(),
            "Constrain input and output types to all numeric tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1047
// Spec text for Erf-9 (frozen for this opset version).
static const char* Erf_ver9_doc = R"DOC(
Computes the error function of the given input tensor element-wise.
)DOC";

// Registration of Erf for opset 9. Element-wise op: output shape/type are
// propagated unchanged from the input.
ONNX_OPERATOR_SET_SCHEMA(
    Erf,
    9,
    OpSchema()
        .SetDoc(Erf_ver9_doc)
        .Input(0, "input", "Input tensor", "T")
        .Output(
            0,
            "output",
            "The error function of the input tensor "
            "computed element-wise. It has the same shape and type of the input.",
            "T")
        .TypeConstraint(
            "T",
            OpSchema::all_numeric_types(),
            "Constrain input and output types to all numeric tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1069
// Spec text for CumSum-11 (frozen for this opset version).
static const char* CumSum_ver11_doc = R"DOC(
Performs cumulative sum of the input elements along the given axis.
By default, it will do the sum inclusively meaning the first element is copied as is.
Through an `exclusive` attribute, this behavior can change to exclude the first element.
It can also perform summation in the opposite direction of the axis. For that, set `reverse` attribute to 1.

Example:
```
input_x = [1, 2, 3]
axis=0
output = [1, 3, 6]
exclusive=1
output = [0, 1, 3]
exclusive=0
reverse=1
output = [6, 5, 3]
exclusive=1
reverse=1
output = [5, 3, 0]
```
)DOC";

// Registration of CumSum for opset 11. Output has the same shape and type
// as input 'x', so inference forwards them from the first input.
ONNX_OPERATOR_SET_SCHEMA(
    CumSum,
    11,
    OpSchema()
        .SetDoc(CumSum_ver11_doc)
        .Attr(
            "exclusive",
            "If set to 1 will return exclusive sum in which the top element is not included."
            " In other terms, if set to 1, the j-th output element would be the sum of the first (j-1) elements."
            " Otherwise, it would be the sum of the first j elements.",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "reverse",
            "If set to 1 will perform the sums in reverse direction.",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Input(
            0,
            "x",
            "An input tensor that is to be processed.",
            "T",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        .Input(
            1,
            "axis",
            "A 0-D tensor. Must be in the range [-rank(x), rank(x)-1]. "
            "Negative value means counting dimensions from the back.",
            "T2",
            OpSchema::Single,
            true,
            1,
            OpSchema::NonDifferentiable)
        .Output(
            0,
            "y",
            "Output tensor of the same type as 'x' with cumulative sums of the x's elements",
            "T",
            OpSchema::Single,
            true,
            1,
            OpSchema::Differentiable)
        // NOTE(review): the description below says "any tensor type" but the
        // constraint actually lists only six numeric types — the description
        // string is frozen spec text, so it is left as-is here.
        .TypeConstraint(
            "T",
            {"tensor(uint32)",
             "tensor(uint64)",
             "tensor(int32)",
             "tensor(int64)",
             "tensor(float)",
             "tensor(double)"},
            "Input can be of any tensor type.")
        .TypeConstraint(
            "T2",
            {"tensor(int32)", "tensor(int64)"},
            "axis tensor can be int32 or int64 only")
        .TypeAndShapeInferenceFunction(
            ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
1152
// Spec text for NegativeLogLikelihoodLoss-12 (frozen for this opset version;
// includes worked numpy-style examples for each reduction mode).
static const char* NegativeLogLikelihoodLoss_ver12_doc = R"DOC(
A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss.
Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0.
The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C).
The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes)
or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples.
The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as:
    loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k].
When an optional "weight" is provided, the sample loss is calculated as:
    loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c].
loss is zero for the case when target-value equals ignore_index.

    loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index
If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk).
If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged:
    mean(loss), if "weight" is not provided,
or if weight is provided,
    sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]]), for all samples.
If "reduction" attribute is set to "sum", the output is a scalar:
    sum(loss).
See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss.
Example 1:
    // negative log likelihood loss, "none" reduction
    N, C, d1 = 2, 3, 2
    input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
             [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
    target = [[2, 1], [0, 2]]
    loss = np.zeros((N, d1))
    for n in range(N):
        for d_1 in range(d1):
            c = target[n][d_1]
            loss[n][d_1] = -input[n][c][d_1]
    // print(loss)
    // [[-3. -2.]
    //  [-0. -2.]]
Example 2:
    // weighted negative log likelihood loss, sum reduction
    N, C, d1 = 2, 3, 2
    input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
            [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
    target = [[2, 1], [0, 2]]
    weight = [0.2, 0.3, 0.1]
    loss = np.zeros((N, d1))
    for n in range(N):
        for d_1 in range(d1):
            c = target[n][d_1]
            loss[n][d_1] = -input[n][c][d_1] * weight[c]
    loss = np.sum(loss)
    // print(loss)
    // -1.1
Example 3:
    // weighted negative log likelihood loss, mean reduction
    N, C, d1 = 2, 3, 2
    input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
            [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
    target = [[2, 1], [0, 2]]
    weight = [0.2, 0.3, 0.1]
    loss = np.zeros((N, d1))
    weight_total = 0
    for n in range(N):
        for d_1 in range(d1):
            c = target[n][d_1]
            loss[n][d_1] = -input[n][c][d_1] * weight[c]
            weight_total = weight_total + weight[c]
    loss = np.sum(loss) / weight_total
    // print(loss)
    // -1.57
)DOC";
1221
ToDimensionOneFloatTensor_old(float value)1222 TensorProto ToDimensionOneFloatTensor_old(float value) {
1223 auto t = ToTensor(std::vector<float>({value}));
1224 t.add_dims(1);
1225 return t;
1226 }
1227
ToDimensionOneTensor_old(int32_t value)1228 TensorProto ToDimensionOneTensor_old(int32_t value) {
1229 auto t = ToTensor(std::vector<int32_t>({value}));
1230 t.add_dims(1);
1231 return t;
1232 }
1233
ToDimensionOneInt64Tensor_old(int64_t value)1234 TensorProto ToDimensionOneInt64Tensor_old(int64_t value) {
1235 auto t = ToTensor(std::vector<int64_t>({value}));
1236 t.add_dims(1);
1237 return t;
1238 }
1239
ToDimensionOneInt64Tensor_old(std::vector<int64_t> value)1240 TensorProto ToDimensionOneInt64Tensor_old(std::vector<int64_t> value) {
1241 auto t = ToTensor(value);
1242 t.add_dims(value.size());
1243 return t;
1244 }
1245
// Builds the opset-12 function body (a small ONNX sub-graph) implementing
// NegativeLogLikelihoodLoss in terms of primitive operators.
//
// The generated graph gathers -input[n][target[n]][d...] via GatherElements,
// optionally zeroes out samples whose target equals the "ignore_index"
// attribute, optionally scales by the "weight" input, and finally applies the
// reduction selected by the "reduction" attribute ("none" | "sum" | "mean",
// default "mean").
//
// Returns false when the input element type is unknown (no correct body can
// be built); otherwise appends the generated nodes to functionProto,
// finalizes it via schema.BuildFunction, and returns true.
bool BuildContextDependentFunctionBody_opset12(
    const FunctionBodyBuildContext& ctx,
    const OpSchema& schema,
    FunctionProto& functionProto) {
  if (ctx.getInputType(0) == nullptr) {
    // we cannot create a correct function body without knowing the input type
    return false;
  }
  auto input_type = ctx.getInputType(0)->tensor_type().elem_type();
  bool float_input = input_type == TensorProto_DataType_FLOAT;
  // "reduction" defaults to "mean" when the attribute is absent or not a
  // string, matching the schema's declared default.
  auto reduction_attr_proto = ctx.getAttribute("reduction");
  std::string reduction_attr =
      reduction_attr_proto != nullptr && reduction_attr_proto->has_s() ? reduction_attr_proto->s() : "mean";
  std::vector<FunctionBodyHelper::NodeDef> body;
  // One-element int32 constants 0 and 1, used as starts/ends/axes of the
  // Slice nodes below.
  body.push_back(
      {{"const_zero"},
       "Constant",
       {},
       {MakeAttribute("value", ToDimensionOneTensor_old(0))}});

  body.push_back(
      {{"const_one"},
       "Constant",
       {},
       {MakeAttribute("value", ToDimensionOneTensor_old(1))}});

  // Insert a class axis of size 1: (N, d1, ..., dk) -> (N, 1, d1, ..., dk),
  // so GatherElements can pick one class per sample along axis 1.
  body.push_back(
      {{"expanded_target"},
       "Unsqueeze",
       {"target"},
       {MakeAttribute("axes", std::vector<int64_t>({1}))}});

  if (ctx.getAttribute("ignore_index") == nullptr) {
    // No ignore_index: gather the log-probability of each sample's target
    // class directly and negate it.
    body.push_back(
        {{"input_gather_element"},
         "GatherElements",
         {"input", "expanded_target"},
         {MakeAttribute("axis", (int64_t)1)}});

    body.push_back({{"loss_NCdd"}, "Neg", {"input_gather_element"}});

    // Slice(starts=0, ends=1, axes=1): keep only channel 0 of the gathered
    // result, shape (N, 1, d1, ..., dk).
    body.push_back(
        {{"loss_N1dd"},
         "Slice",
         {"loss_NCdd", "const_zero", "const_one", "const_one"}});

    if (!ctx.hasInput(2)) {
      // Unweighted loss: squeeze away the singleton class axis, then reduce
      // as requested.
      if (reduction_attr == "none") {
        body.push_back(
            {{"loss"},
             "Squeeze",
             {"loss_N1dd"},
             {MakeAttribute("axes", std::vector<int64_t>({1}))}});
      } else {
        body.push_back(
            {{"loss_Ndd"},
             "Squeeze",
             {"loss_N1dd"},
             {MakeAttribute("axes", std::vector<int64_t>({1}))}});
        if (reduction_attr == "mean") {
          body.push_back(
              {{"loss"},
               "ReduceMean",
               {"loss_Ndd"},
               {MakeAttribute("keepdims", (int64_t)0)}});
        } else {
          body.push_back(
              {{"loss"},
               "ReduceSum",
               {"loss_Ndd"},
               {MakeAttribute("keepdims", (int64_t)0)}});
        }
      }
    } else {
      // Weighted loss: the per-sample weight is weight[target].
      body.push_back({{"weight_gather"}, "Gather", {"weight", "target"}});
      body.push_back(
          {{"loss_unweighted"},
           "Squeeze",
           {"loss_N1dd"},
           {MakeAttribute("axes", std::vector<int64_t>({1}))}});
      if (reduction_attr == "none") {
        body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}});
      } else {
        body.push_back(
            {{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}});
        if (reduction_attr == "mean") {
          // Weighted mean: sum(loss) / sum(gathered weights).
          body.push_back(
              {{"loss_sum"},
               "ReduceSum",
               {"loss_Ndd"},
               {MakeAttribute("keepdims", (int64_t)0)}});
          body.push_back(
              {{"weight_gather_sum"},
               "ReduceSum",
               {"weight_gather"},
               {MakeAttribute("keepdims", (int64_t)0)}});
          body.push_back({{"loss"}, "Div", {"loss_sum", "weight_gather_sum"}});
        } else {
          body.push_back(
              {{"loss"},
               "ReduceSum",
               {"loss_Ndd"},
               {MakeAttribute("keepdims", (int64_t)0)}});
        }
      }
    }
  } else {
    // ignore_index is set: samples whose target equals ignore_index must
    // contribute zero loss (and zero weight in the "mean" denominator).
    body.push_back(
        {{"const_ignore_index"},
         "Constant",
         {},
         {MakeAttribute(
             "value",
             ToDimensionOneInt64Tensor_old(
                 ctx.getAttribute("ignore_index")->i()))}});

    // Zero tensor with the same element type and shape as expanded_target
    // (x - x == 0), used to rewrite ignored targets below.
    body.push_back(
        {{"const_zero_target_typed"},
         "Sub",
         {"expanded_target", "expanded_target"}});
    // Cast targets to INT64 so they can be compared to const_ignore_index.
    body.push_back(
        {{"expanded_target_int64"},
         "Cast",
         {"expanded_target"},
         {MakeAttribute(
             "to",
             (int64_t)TensorProto_DataType::TensorProto_DataType_INT64)}});

    // mask is true exactly where the target equals ignore_index.
    body.push_back(
        {{"mask"}, "Equal", {"expanded_target_int64", "const_ignore_index"}});
    // Replace ignored targets with class 0 so the GatherElements below stays
    // within [0, C) even when ignore_index is outside that range.
    body.push_back(
        {{"transform_targets"},
         "Where",
         {"mask", "const_zero_target_typed", "expanded_target"}});
    body.push_back(
        {{"input_gather_element"},
         "GatherElements",
         {"input", "transform_targets"},
         {MakeAttribute("axis", (int64_t)1)}});
    body.push_back(
        {{"const_zero_float"},
         "Constant",
         {},
         {MakeAttribute("value", ToDimensionOneFloatTensor_old(0.0f))}});
    if (!float_input) {
      // The zero constant is created as float; cast it to the actual input
      // element type when the input is not float.
      body.push_back(
          {{"const_zero_casted"},
           "Cast",
           {"const_zero_float"},
           {MakeAttribute("to", static_cast<int64_t>(input_type))}});
    }
    // Zero out the gathered values of ignored samples before negating.
    body.push_back(
        {{"input_gather_element_transform"},
         "Where",
         {"mask", float_input ? "const_zero_float" : "const_zero_casted", "input_gather_element"}});
    body.push_back({{"loss_NCdd"}, "Neg", {"input_gather_element_transform"}});
    body.push_back(
        {{"loss_N1dd"},
         "Slice",
         {"loss_NCdd", "const_zero", "const_one", "const_one"}});

    if (!ctx.hasInput(2)) {
      // No explicit weights: synthesize weight_gather as 0 for ignored
      // samples and 1 otherwise, so the "mean" reduction divides by the
      // number of non-ignored samples.
      body.push_back(
          {{"squeeze_mask"},
           "Squeeze",
           {"mask"},
           {MakeAttribute("axes", std::vector<int64_t>({1}))}});

      body.push_back(
          {{"const_one_float"},
           "Constant",
           {},
           {MakeAttribute("value", ToDimensionOneFloatTensor_old(1.0f))}});
      if (!float_input) {
        body.push_back(
            {{"const_one_casted"},
             "Cast",
             {"const_one_float"},
             {MakeAttribute("to", static_cast<int64_t>(input_type))}});
      }
      body.push_back(
          {{"weight_gather"},
           "Where",
           {"squeeze_mask", float_input ? "const_zero_float" : "const_zero_casted",
            float_input ? "const_one_float" :"const_one_casted"}});

    } else {
      // Explicit weights: gather them by the (rewritten) targets, then zero
      // the weights of ignored samples.
      body.push_back(
          {{"weight_gather_temp"}, "Gather", {"weight", "transform_targets"}});

      body.push_back(
          {{"weight_gather_temp_1"},
           "Where",
           {"mask", float_input ? "const_zero_float" : "const_zero_casted", "weight_gather_temp"}});

      body.push_back(
          {{"weight_gather"},
           "Squeeze",
           {"weight_gather_temp_1"},
           {MakeAttribute("axes", std::vector<int64_t>({1}))}});
    }

    // Apply the weights and the requested reduction (same pattern as the
    // no-ignore_index branch above).
    body.push_back(
        {{"loss_unweighted"},
         "Squeeze",
         {"loss_N1dd"},
         {MakeAttribute("axes", std::vector<int64_t>({1}))}});
    if (reduction_attr == "none") {
      body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}});
    } else {
      body.push_back(
          {{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}});
      if (reduction_attr == "mean") {
        body.push_back(
            {{"loss_sum"},
             "ReduceSum",
             {"loss_Ndd"},
             {MakeAttribute("keepdims", (int64_t)0)}});
        body.push_back(
            {{"weight_gather_sum"},
             "ReduceSum",
             {"weight_gather"},
             {MakeAttribute("keepdims", (int64_t)0)}});
        body.push_back({{"loss"}, "Div", {"loss_sum", "weight_gather_sum"}});
      } else {
        body.push_back(
            {{"loss"},
             "ReduceSum",
             {"loss_Ndd"},
             {MakeAttribute("keepdims", (int64_t)0)}});
      }
    }
  }

  // Materialize the NodeDefs into the FunctionProto and finalize it.
  auto func_nodes = FunctionBodyHelper::BuildNodes(body);
  for (const auto& node : func_nodes) {
    auto new_node = functionProto.add_node();
    new_node->CopyFrom(node);
  }

  schema.BuildFunction(functionProto);
  return true;
}
1489
// Registration of NegativeLogLikelihoodLoss for opset 12. The function body
// is built on demand by BuildContextDependentFunctionBody_opset12 (it depends
// on the input element type and the presence of ignore_index/weight).
ONNX_OPERATOR_SET_SCHEMA(
    NegativeLogLikelihoodLoss,
    12,
    OpSchema()
        .SetDoc(NegativeLogLikelihoodLoss_ver12_doc)
        .Input(
            0,
            "input",
            "Input tensor of shape (N, C) or (N, C, d1, d2, ..., dk).",
            "T")
        .Input(
            1,
            "target",
            "Target tensor of shape (N) or (N, d1, d2, ..., dk). Target element value shall be in range of [0, C). "
            "If ignore_index is specified, it may have a value outside [0, C) and the target values should either be "
            "in the range [0, C) or have the value ignore_index.",
            "Tind")
        .Input(
            2,
            "weight",
            "Optional rescaling weight tensor. "
            "If given, it has to be a tensor of size C. Otherwise, it is treated as if having all ones.",
            "T",
            OpSchema::Optional)
        .Output(0, "loss", "The negative log likelihood loss", "T")
        .Attr(
            "reduction",
            "Type of reduction to apply to loss: none, sum, mean (default). "
            "'none': the output is the loss for each sample. "
            "'sum': the output will be summed. "
            "'mean': the sum of the output will be divided by the sum of applied weights.",
            AttributeProto::STRING,
            std::string("mean"))
        // The trailing 'false' marks the attribute as not required.
        .Attr(
            "ignore_index",
            "Specifies a target value that is ignored and does not contribute to the input gradient. It's an optional value.",
            AttributeProto::INT,
            false)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input, weight, and output types to floating-point tensors.")
        .TypeConstraint(
            "Tind",
            {"tensor(int32)", "tensor(int64)"},
            "Constrain target to integer types")
        .SetContextDependentFunctionBodyBuilder(
            BuildContextDependentFunctionBody_opset12)
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Type inference
          propagateElemTypeFromInputToOutput(ctx, 0, 0);

          // Shape inference
          if (hasNInputShapes(ctx, 2)) {
            const TensorShapeProto& input_shape =
                ctx.getInputType(0)->tensor_type().shape();
            const TensorShapeProto& target_shape =
                ctx.getInputType(1)->tensor_type().shape();

            const int input_rank = static_cast<int>(input_shape.dim_size());
            const int target_rank = static_cast<int>(target_shape.dim_size());

            // input is (N, C, d1, ..., dk); target is (N, d1, ..., dk).
            if (input_rank < 2) {
              fail_shape_inference("Input rank must be >= 2.");
            }
            if (target_rank != input_rank - 1) {
              fail_shape_inference(
                  "Target rank must be 1 less than the input rank.");
            }

            // match input dimensions (N, C, d1, ..., dk) with target
            // dimensions of (N, d1, ..., dk): compare N at dim 0, then skip
            // the class dimension C of the input for the remaining dims.
            for (int dim = 0; dim < target_rank; dim++) {
              const auto input_dim =
                  dim == 0 ? input_shape.dim(dim) : input_shape.dim(dim + 1);
              const auto target_dim = target_shape.dim(dim);
              if (input_dim.has_dim_value() && target_dim.has_dim_value() &&
                  input_dim.dim_value() != target_dim.dim_value())
                fail_shape_inference(
                    "Input and target dimension value mismatch.");
            }

            // Optional weight must be a 1-D tensor (of size C).
            if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) {
              const TensorShapeProto& weight_shape =
                  ctx.getInputType(2)->tensor_type().shape();
              if (weight_shape.dim_size() != 1) {
                fail_shape_inference("Weight rank must be 1.");
              }
            }

            TensorShapeProto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
            if (getAttribute(ctx, "reduction", "mean") == "none") {
              // output tensor is of shape (N, d1, d2, ..., dk) if
              // reduction attribute is "none".
              for (int i = 0; i < input_rank - 1; i++) {
                auto* dim = output_shape->add_dim();
                if (i == 0)
                  *dim = input_shape.dim(i);
                else
                  *dim = input_shape.dim(i + 1);
              }
            }
            // otherwise output is a scalar.
          }
        }));
1595
// Shared description of the "reduction" attribute for
// SoftmaxCrossEntropyLoss-12.
const char* reduction_doc_sce_opset12 =
    "Type of reduction to apply to loss: none, sum, mean(default). "
    "'none': no reduction will be applied, "
    "'sum': the output will be summed. "
    "'mean': the sum of the output will be divided by the number of "
    "elements in the output.";

// Spec text for SoftmaxCrossEntropyLoss-12 (frozen for this opset version).
static const char* SoftmaxCrossEntropyLoss_ver12_doc =
    R"DOC(Loss function that measures the softmax cross entropy
between 'scores' and 'labels'.
This operator first computes a loss tensor whose shape is identical to the labels input.
If the input is 2-D with shape (N, C), the loss tensor may be a N-element vector L = (l_1, l_2, ..., l_N).
If the input is N-D tensor with shape (N, C, D1, D2, ..., Dk),
the loss tensor L may have (N, D1, D2, ..., Dk) as its shape and L[i,][j_1][j_2]...[j_k] denotes a scalar element in L.
After L is available, this operator can optionally do a reduction operator.

shape(scores): (N, C) where C is the number of classes, or (N, C, D1, D2,..., Dk),
        with K >= 1 in case of K-dimensional loss.
shape(labels): (N) where each value is 0 <= labels[i] <= C-1, or (N, D1, D2,..., Dk),
        with K >= 1 in case of K-dimensional loss.

The loss for one sample, l_i, can caculated as follows:
    l[i][d1][d2]...[dk] = -y[i][c][d1][d2]..[dk], where i is the index of classes.
or
    l[i][d1][d2]...[dk] = -y[i][c][d1][d2]..[dk] * weights[c], if 'weights' is provided.

loss is zero for the case when label-value equals ignore_index.
    l[i][d1][d2]...[dk]  = 0, when labels[n][d1][d2]...[dk] = ignore_index

where:
    p = Softmax(scores)
    y = Log(p)
    c = labels[i][d1][d2]...[dk]

Finally, L is optionally reduced:
If reduction = 'none', the output is L with shape (N, D1, D2, ..., Dk).
If reduction = 'sum', the output is scalar: Sum(L).
If reduction = 'mean', the output is scalar: ReduceMean(L), or if weight is provided: ReduceSum(L) / ReduceSum(W),
where tensor W is of shape (N, D1, D2, ..., Dk) and W[n][d1][d2]...[dk] = weights[labels[i][d1][d2]...[dk]].
)DOC";
1636
// Builds the opset-12 function body for SoftmaxCrossEntropyLoss:
// a LogSoftmax over the class axis followed by NegativeLogLikelihoodLoss,
// with "reduction" (and optionally "ignore_index") forwarded by reference.
// Always returns true after appending the nodes to functionProto.
bool BuildContextDependentFunctionBodySCE_opset12(
    const FunctionBodyBuildContext& ctx,
    const OpSchema& schema,
    FunctionProto& functionProto) {
  std::vector<FunctionBodyHelper::NodeDef> body;

  // Using stable implementation of LogSoftmax:
  // reshape (N, C, D1, ..., Dk) -> (N, C, D) using Reshape's special values
  // (0 = copy dim, -1 = infer), move C to the last axis, apply LogSoftmax
  // on that axis, then restore the original layout and shape.
  body.push_back(
      {{"Shape3D"},
       "Constant",
       {},
       {MakeAttribute("value", ToDimensionOneInt64Tensor_old({0,0,-1}))}});
  body.push_back(
      {{"X_NCD"},
       "Reshape",
       {"scores", "Shape3D"}});
  body.push_back(
      {{"X_NDC"},
       "Transpose",
       {"X_NCD"},
       {MakeAttribute("perm", std::vector<int64_t>({0,2,1}))}});
  body.push_back(
      {{"X_LogSM"},
       "LogSoftmax",
       {"X_NDC"},
       {MakeAttribute("axis", (int64_t)2)}});
  body.push_back(
      {{"X_LogSM_NCD"},
       "Transpose",
       {"X_LogSM"},
       {MakeAttribute("perm", std::vector<int64_t>({0,2,1}))}});
  body.push_back(
      {{"X_shape"},
       "Shape",
       {"scores"}});
  body.push_back(
      {{"X_Log"},
       "Reshape",
       {"X_LogSM_NCD", "X_shape"}});

  // Review(mzs): Ideally we want to reuse the output from Log for sub-graph
  // output as well but looking at the graph resolve code it does not include
  // graph outputs as intermediate outputs, hence if intermediate X_log is
  // renamed as log_prob then it will be treated as graph output and will not be
  // available to NegativeLogLikelihoodLoss. May be my understanding is
  // incorrect or there is a bug in function population code in ORT but I will
  // dig further to be 100%. In the meantime we just replicate the log.
  if (ctx.hasOutput(1)) {
    body.push_back({{"log_prob"}, "Identity", {"X_Log"}});
  }

  std::vector<std::string> input_tensor_names{"X_Log", "labels"};
  // The "reduction" attribute is forwarded to the inner
  // NegativeLogLikelihoodLoss node by reference.
  std::vector<FunctionBodyHelper::AttributeProtoWrapper> attributes{
      MakeRefAttribute("reduction", AttributeProto::STRING)};
  // Add weights as input if needed.
  if (ctx.hasInput(2)) {
    input_tensor_names.push_back("weights");
  }

  // add ignore_index attributes if needed.
  if (ctx.getAttribute("ignore_index") != nullptr) {
    attributes.push_back(MakeRefAttribute("ignore_index", AttributeProto::INT));
  }

  body.push_back(
      {{"output"},
       "NegativeLogLikelihoodLoss",
       input_tensor_names,
       attributes});

  // Materialize the NodeDefs into the FunctionProto and finalize it.
  auto func_nodes = FunctionBodyHelper::BuildNodes(body);
  for (const auto& node : func_nodes) {
    auto new_node = functionProto.add_node();
    new_node->CopyFrom(node);
  }

  schema.BuildFunction(functionProto);
  return true;
}
1716
// Registration of SoftmaxCrossEntropyLoss for opset 12. The function body is
// built on demand by BuildContextDependentFunctionBodySCE_opset12.
ONNX_OPERATOR_SET_SCHEMA(
    SoftmaxCrossEntropyLoss,
    12,
    OpSchema()
        .SetDoc(SoftmaxCrossEntropyLoss_ver12_doc)
        .Attr(
            "reduction",
            reduction_doc_sce_opset12,
            AttributeProto::STRING,
            std::string("mean"))
        // The trailing 'false' marks the attribute as not required.
        .Attr(
            "ignore_index",
            "Specifies a target value that is ignored and does not contribute to the input gradient. It's an optional value.",
            AttributeProto::INT,
            false)
        .Input(
            0,
            "scores",
            "The predicted outputs with shape [batch_size, class_size], or "
            "[batch_size, class_size, D1, D2 , ..., Dk], where K is the number of dimensions.",
            "T")
        .Input(
            1,
            "labels",
            "The ground truth output tensor, with shape [batch_size], or "
            "[batch_size, D1, D2, ..., Dk], where K is the number of dimensions. "
            "Labels element value shall be in range of [0, C). "
            "If ignore_index is specified, it may have a value outside [0, C) and the label values should either be "
            "in the range [0, C) or have the value ignore_index.",
            "Tind")
        .Input(
            2,
            "weights",
            "A manual rescaling weight given to each class. If given, it has to "
            "be a 1D Tensor assigning weight to each of the classes. Otherwise, "
            "it is treated as if having all ones.",
            "T",
            OpSchema::Optional)
        .Output(
            0,
            "output",
            "Weighted loss float Tensor. If reduction is 'none', this has the "
            "shape of [batch_size], or [batch_size, D1, D2, ..., Dk] in case of "
            "K-dimensional loss. Otherwise, it is a scalar.",
            "T")
        .Output(
            1,
            "log_prob",
            "Log probability tensor. If the output of softmax is prob, its value is log(prob).",
            "T",
            OpSchema::Optional)
        .TypeConstraint(
            "T",
            {"tensor(float16)",
             "tensor(float)",
             "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeConstraint(
            "Tind",
            {"tensor(int32)", "tensor(int64)"},
            "Constrain target to integer types")
        .SetContextDependentFunctionBodyBuilder(
            BuildContextDependentFunctionBodySCE_opset12)
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          std::string reduction = getAttribute(ctx, "reduction", "mean");
          if (reduction.compare("none") == 0) {
            // "none": loss keeps the labels' shape (N, D1, ..., Dk).
            if (hasInputShape(ctx, 1)) {
              propagateShapeFromInputToOutput(ctx, 1, 0);
            }
          } else {
            // "sum"/"mean": loss is a scalar (empty shape).
            updateOutputShape(ctx, 0, TensorShapeProto());
          }

          // Optional second output log_prob matches the scores input.
          if (ctx.getNumOutputs() == 2) {
            propagateElemTypeFromInputToOutput(ctx, 0, 1);
            propagateShapeFromInputToOutput(ctx, 0, 1);
          }
        }));
1796
// Returns a schema-filling functor shared by the opset-1 Softmax family
// (Softmax, LogSoftmax, Hardmax). `name` and `description` are substituted
// into the shared DOC template; everything else (axis attribute, single
// input/output, float-only type constraint, pass-through shape inference)
// is identical across the three operators.
std::function<void(OpSchema&)> SoftmaxFamilyDocGenerator_opset1(
    const char* name,
    const char* description) {
  return [=](OpSchema& schema) {
    std::string doc;
    POPULATE_OP_DOC_STR(doc = R"DOC(
The operator computes the {name} ({description}) values for each layer in the batch
 of the given input. The input is a 2-D tensor (Tensor<float>) of size
(batch_size x input_feature_dimensions). The output tensor has the same shape
and contains the {name} values of the corresponding input.

Input does not need to explicitly be a 2D vector; rather, it will be
coerced into one. For an arbitrary n-dimensional tensor
input \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is
the axis provided, then input will be coerced into a 2-dimensional tensor with
dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
case where axis=1, this means the input tensor will be coerced into a 2D tensor
of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D.
Each of these dimensions must be matched correctly, or else the operator
will throw errors.
)DOC";
                        ReplaceAll(doc, "{name}", name);
                        ReplaceAll(doc, "{description}", description););
    schema.SetDoc(doc);
    schema.Attr(
        "axis",
        "Describes the axis of the inputs when coerced "
        "to 2D; defaults to one because the 0th axis most likely describes "
        "the batch_size",
        AttributeProto::INT,
        static_cast<int64_t>(1));
    schema.Input(
        0,
        "input",
        "The input tensor that's coerced into a 2D matrix of size (NxD) "
        "as described above.",
        "T");
    schema.Output(
        0,
        "output",
        "The output values with the same "
        "shape as input tensor (the original size without coercion).",
        "T");
    schema.TypeConstraint(
        "T",
        {"tensor(float16)", "tensor(float)", "tensor(double)"},
        "Constrain input and output types to float tensors.");
    // Element-wise over the coerced 2D view: output shape/type equal input's.
    schema.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);
  };
}
1848
// Opset-1 registrations of the Softmax family; each gets the shared schema
// with its own name/description substituted into the DOC template.
ONNX_OPERATOR_SET_SCHEMA(
    Softmax,
    1,
    OpSchema().FillUsing(
        SoftmaxFamilyDocGenerator_opset1("softmax", "normalized exponential")));

ONNX_OPERATOR_SET_SCHEMA(
    LogSoftmax,
    1,
    OpSchema().FillUsing(
        SoftmaxFamilyDocGenerator_opset1("logsoftmax", "log of softmax")));

ONNX_OPERATOR_SET_SCHEMA(
    Hardmax,
    1,
    OpSchema().FillUsing(SoftmaxFamilyDocGenerator_opset1(
        "hardmax",
        "1 for the first maximum value, and 0 for all others")));
1867
// Shared description of the legacy (pre-opset-7) limited broadcasting rules,
// substituted into the binary-math-op DOC templates below. Frozen spec text.
const char* kBroadcastDoc_old = R"DOC(
If necessary the right-hand-side argument will be broadcasted to match the
shape of left-hand-side argument. When broadcasting is specified, the second
tensor can either be of element size 1 (including a scalar tensor and any
tensor with rank equal to or smaller than the first tensor), or having its
shape as a contiguous subset of the first tensor's shape. The starting of the
mutually equal shape is specified by the argument "axis", and if it is not set,
suffix matching is assumed. 1-dim expansion doesn't work yet.

For example, the following tensor shapes are supported (with broadcast=1):

  shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar tensor
  shape(A) = (2, 3, 4, 5), shape(B) = (1, 1), i.e. B is an 1-element tensor
  shape(A) = (2, 3, 4, 5), shape(B) = (5,)
  shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)
  shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1
  shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0

Attribute `broadcast=1` needs to be passed to enable broadcasting.
)DOC";
1888
// Returns a schema-filling functor for the opset-1 binary math ops
// (Add/Sub/...). These use the legacy limited-broadcast semantics
// (kBroadcastDoc_old) controlled by the "broadcast"/"axis" attributes,
// constrain types to float tensors only, and declare the long-removed
// "consumed_inputs" legacy attribute. No shape inference is registered.
std::function<void(OpSchema&)> MathDocGenerator_old(const char* name) {
  return [=](OpSchema& schema) {
    std::string doc;
    POPULATE_OP_DOC_STR(doc = R"DOC(
Performs element-wise binary {name} (with limited broadcast support).
{broadcast_doc})DOC";
                        ReplaceAll(doc, "{name}", name);
                        ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc_old););
    schema.SetDoc(doc);
    schema.Attr(
        "broadcast",
        "Pass 1 to enable broadcasting",
        AttributeProto::INT,
        static_cast<int64_t>(0));

    // This attribute was added via AllowConsumed API in OpSchema.
    // After removing the API, we're now using the Attr API to simulate the old
    // definition.
    schema.Attr(
        "consumed_inputs",
        "legacy optimization attribute.",
        AttributeProto::INTS,
        OPTIONAL_VALUE);
    schema.Attr(
        "axis",
        "If set, defines the broadcast dimensions. See doc for details.",
        AttributeProto::INT,
        OPTIONAL_VALUE);
    schema.Input(
        0,
        "A",
        "First operand, should share the type with the second operand.",
        "T");
    schema.Input(
        1,
        "B",
        "Second operand. With broadcasting can be of smaller size than A. "
        "If broadcasting is disabled it should be of the same size.",
        "T");
    schema.Output(0, "C", "Result, has same dimensions and type as A", "T");
    schema.TypeConstraint(
        "T",
        {"tensor(float16)", "tensor(float)", "tensor(double)"},
        "Constrain input and output types to float tensors.");
  };
}
1935
// Returns a schema-filling functor for the opset-6 elementwise binary ops.
// Differences from the opset-1 generator above: the legacy `consumed_inputs`
// attribute is dropped, the type constraint widens to the high-precision
// numeric types, and type/shape inference propagates from the first input.
// Legacy broadcasting via `broadcast`/`axis` is still in effect at opset 6.
std::function<void(OpSchema&)> MathDocGenerator_old_opset6(const char* name) {
  return [=](OpSchema& schema) {
    std::string doc;
    // Doc string is built only when doc strings are compiled in.
    POPULATE_OP_DOC_STR(doc = R"DOC(
Performs element-wise binary {name} (with limited broadcast support).
{broadcast_doc})DOC";
                        ReplaceAll(doc, "{name}", name);
                        ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc_old););
    schema.SetDoc(doc);
    schema.Attr(
        "broadcast",
        "Pass 1 to enable broadcasting",
        AttributeProto::INT,
        static_cast<int64_t>(0));
    schema.Attr(
        "axis",
        "If set, defines the broadcast dimensions. See doc for details.",
        AttributeProto::INT,
        OPTIONAL_VALUE);
    schema.Input(
        0,
        "A",
        "First operand, should share the type with the second operand.",
        "T");
    schema.Input(
        1,
        "B",
        "Second operand. With broadcasting can be of smaller size than A. "
        "If broadcasting is disabled it should be of the same size.",
        "T");
    schema.Output(0, "C", "Result, has same dimensions and type as A", "T");
    schema.TypeConstraint(
        "T",
        OpSchema::numeric_types_for_math_reduction(),
        "Constrain input and output types to high-precision numeric tensors.");
    // Output mirrors input 0's type and shape (legacy broadcast keeps A's
    // dimensions, per the doc above).
    schema.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);
  };
}
1974
// Registrations of the legacy elementwise binary ops. Opset 1 uses
// MathDocGenerator_old (float-only, legacy `consumed_inputs`); opset 6 uses
// MathDocGenerator_old_opset6 (wider numeric types, shape inference).
ONNX_OPERATOR_SET_SCHEMA(
    Add,
    1,
    OpSchema().FillUsing(MathDocGenerator_old("addition")));

ONNX_OPERATOR_SET_SCHEMA(
    Sub,
    1,
    OpSchema().FillUsing(MathDocGenerator_old("subtraction")));

ONNX_OPERATOR_SET_SCHEMA(
    Mul,
    1,
    OpSchema().FillUsing(MathDocGenerator_old("multiplication")));

ONNX_OPERATOR_SET_SCHEMA(
    Div,
    1,
    OpSchema().FillUsing(MathDocGenerator_old("division")));

ONNX_OPERATOR_SET_SCHEMA(
    Add,
    6,
    OpSchema().FillUsing(MathDocGenerator_old_opset6("addition")));

ONNX_OPERATOR_SET_SCHEMA(
    Sub,
    6,
    OpSchema().FillUsing(MathDocGenerator_old_opset6("subtraction")));

ONNX_OPERATOR_SET_SCHEMA(
    Mul,
    6,
    OpSchema().FillUsing(MathDocGenerator_old_opset6("multiplication")));

ONNX_OPERATOR_SET_SCHEMA(
    Div,
    6,
    OpSchema().FillUsing(MathDocGenerator_old_opset6("division")));
2014
static const char* Pow_ver1_doc = R"DOC(
Pow takes input data (Tensor<T>) and exponent Tensor, and
produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
is applied to the data tensor elementwise.
)DOC";

// Opset-1 Pow: legacy broadcasting (`broadcast`/`axis` attributes), float
// tensors only; output takes X's type and shape.
ONNX_OPERATOR_SET_SCHEMA(
    Pow,
    1,
    OpSchema()
        .SetDoc(Pow_ver1_doc + std::string(kBroadcastDoc_old))
        .Input(0, "X", "Input tensor of any shape, base of the exponent.", "T")
        .Input(
            1,
            "Y",
            "Input tensor of any shape broadcastable to X shape, "
            "the exponent component.",
            "T")
        .Attr(
            "broadcast",
            "Pass 1 to enable broadcasting",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "axis",
            "If set, defines the broadcast dimensions. See doc for details.",
            AttributeProto::INT,
            OPTIONAL_VALUE)
        .Output(0, "Z", "Output tensor (same size as X)", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
2049
static const char* Pow_ver7_doc = R"DOC(
Pow takes input data (Tensor<T>) and exponent Tensor, and
produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
is applied to the data tensor elementwise.
)DOC";

// Opset-7 Pow: switches from legacy attribute-driven broadcasting to
// Numpy-style multidirectional broadcasting; still float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Pow,
    7,
    OpSchema()
        .SetDoc(std::string(Pow_ver7_doc) + GenerateBroadcastingDocMul())
        .Input(0, "X", "First operand, base of the exponent.", "T")
        .Input(1, "Y", "Second operand, power of the exponent.", "T")
        .Output(0, "Z", "Output tensor.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Output element type comes from X; shape is the bidirectional
          // broadcast of the two input shapes (only when both are known).
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2))
            bidirectionalBroadcastShapeInference(
                ctx.getInputType(0)->tensor_type().shape(),
                ctx.getInputType(1)->tensor_type().shape(),
                *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
        }));
2076
static const char* Neg_ver1_doc = R"DOC(
Neg takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where each element flipped sign, y = -x, is applied to
the tensor elementwise.
)DOC";

// Opset-1 Neg: elementwise negation, float tensors only; no shape inference
// was registered at this opset version.
ONNX_OPERATOR_SET_SCHEMA(
    Neg,
    1,
    OpSchema()
        .SetDoc(Neg_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2102
static const char* Abs_ver1_doc = R"DOC(
Absolute takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the absolute is, y = abs(x), is applied to
the tensor elementwise.
)DOC";

// Opset-1 Abs: elementwise absolute value, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Abs,
    1,
    OpSchema()
        .SetDoc(Abs_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2128
static const char* Reciprocal_ver1_doc = R"DOC(
Reciprocal takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the reciprocal is, y = 1/x, is applied to
the tensor elementwise.
)DOC";

// Opset-1 Reciprocal: elementwise 1/x, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Reciprocal,
    1,
    OpSchema()
        .SetDoc(Reciprocal_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2154
static const char* Floor_ver1_doc = R"DOC(
Floor takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the floor is, y = floor(x), is applied to
the tensor elementwise.
)DOC";

// Opset-1 Floor: elementwise floor, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Floor,
    1,
    OpSchema()
        .SetDoc(Floor_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2180
static const char* Ceil_ver1_doc = R"DOC(
Ceil takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the ceil is, y = ceil(x), is applied to
the tensor elementwise.
)DOC";

// Opset-1 Ceil: elementwise ceiling, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Ceil,
    1,
    OpSchema()
        .SetDoc(Ceil_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2206
static const char* Sqrt_ver1_doc = R"DOC(
Square root takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the square root is, y = x^0.5, is applied to
the tensor elementwise. If x is negative, then it will return NaN.
)DOC";

// Opset-1 Sqrt: elementwise square root (NaN for negative inputs, per doc),
// float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Sqrt,
    1,
    OpSchema()
        .SetDoc(Sqrt_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2232
static const char* Relu_ver1_doc = R"DOC(
Relu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the rectified linear function, y = max(0, x), is applied to
the tensor elementwise.
)DOC";

// Opset-1 Relu: elementwise max(0, x), float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Relu,
    1,
    OpSchema()
        .SetDoc(Relu_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2258
static const char* LeakyRelu_ver1_doc = R"DOC(
LeakyRelu takes input data (Tensor<T>) and an argument alpha, and produces one
output data (Tensor<T>) where the function `f(x) = alpha * x for x < 0`,
`f(x) = x for x >= 0`, is applied to the data tensor elementwise.
)DOC";

// Opset-1 LeakyRelu: elementwise leaky ReLU with `alpha` (default 0.01),
// float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    LeakyRelu,
    1,
    OpSchema()
        .Attr(
            "alpha",
            "Coefficient of leakage default to 0.01.",
            AttributeProto::FLOAT,
            0.01f)
        .SetDoc(LeakyRelu_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2289
static const char* Selu_ver1_doc = R"DOC(
Selu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the scaled exponential linear unit function,
`y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`,
is applied to the tensor elementwise.
)DOC";

// Opset-1 Selu: scaled ELU with `alpha` (default 1.6732) and `gamma`
// (default 1.0507), float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Selu,
    1,
    OpSchema()
        .Attr(
            "alpha",
            "Coefficient of SELU default to 1.6732.",
            AttributeProto::FLOAT,
            1.6732f)
        .Attr(
            "gamma",
            "Coefficient of SELU default to 1.0507.",
            AttributeProto::FLOAT,
            1.0507f)
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .SetDoc(Selu_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2326
static const char* Elu_ver1_doc = R"DOC(
Elu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the function `f(x) = alpha * (exp(x) - 1.) for x <
0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise.

)DOC";

// Opset-1 Elu: elementwise ELU with `alpha` (default 1.0), float tensors
// only.
// NOTE(review): the output description reads "1D input tensor" — this looks
// like a copy/paste slip for "output tensor", but the strings are kept as-is
// since they are part of the frozen opset-1 operator documentation.
ONNX_OPERATOR_SET_SCHEMA(
    Elu,
    1,
    OpSchema()
        .Attr(
            "alpha",
            "Coefficient of ELU default to 1.0.",
            AttributeProto::FLOAT,
            1.0f)
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .SetDoc(Elu_ver1_doc)
        .Input(0, "X", "1D input tensor", "T")
        .Output(0, "Y", "1D input tensor", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2358
static const char* Exp_ver1_doc = R"DOC(
Calculates the exponential of the given input tensor, element-wise.
)DOC";

// Opset-1 Exp: elementwise exponential, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Exp,
    1,
    OpSchema()
        .SetDoc(Exp_ver1_doc)
        .Input(0, "input", "Input tensor", "T")
        .Output(
            0,
            "output",
            "The exponential of the input tensor computed "
            "element-wise",
            "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2387
static const char* Log_ver1_doc = R"DOC(
Calculates the natural log of the given input tensor, element-wise.
)DOC";

// Opset-1 Log: elementwise natural logarithm, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Log,
    1,
    OpSchema()
        .SetDoc(Log_ver1_doc)
        .Input(0, "input", "Input tensor", "T")
        .Output(
            0,
            "output",
            "The natural log of the input tensor computed "
            "element-wise",
            "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2416
static const char* Tanh_ver1_doc = R"DOC(
Calculates the hyperbolic tangent of the given input tensor element-wise.
)DOC";

// Opset-1 Tanh: elementwise hyperbolic tangent, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Tanh,
    1,
    OpSchema()
        .SetDoc(Tanh_ver1_doc)
        .Input(0, "input", "1-D input tensor", "T")
        .Output(
            0,
            "output",
            "The hyperbolic tangent values of the input tensor "
            "computed element-wise",
            "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2445
2446 static const char* PRelu_ver1_doc = R"DOC(
2447
2448 PRelu takes input data (Tensor<T>) and slope tensor as input, and produces one
2449 output data (Tensor<T>) where the function `f(x) = slope * x for x < 0`,
2450 `f(x) = x for x >= 0`., is applied to the data tensor elementwise.
2451
2452 )DOC";
2453
2454 ONNX_OPERATOR_SET_SCHEMA(
2455 PRelu,
2456 1,
2457 OpSchema()
2458 .SetDoc(PRelu_ver1_doc)
2459 .Input(0, "X", "Input tensor", "T")
2460 .Input(
2461 1,
2462 "slope",
2463 "Slope tensor. If `Slope` is of size 1, the value is shared"
2464 "across different channels",
2465 "T")
2466 .Output(0, "Y", "Output tensor", "T")
2467 // This attribute was added via AllowConsumed API in OpSchema.
2468 // After removing the API, we're now using the Attr API to simulate the
2469 // old definition.
2470 .Attr(
2471 "consumed_inputs",
2472 "legacy optimization attribute.",
2473 AttributeProto::INTS,
2474 OPTIONAL_VALUE)
2475 .TypeConstraint(
2476 "T",
2477 {"tensor(float16)", "tensor(float)", "tensor(double)"},
2478 "Constrain input and output types to float tensors."));
2479
2480 ONNX_OPERATOR_SET_SCHEMA(
2481 PRelu,
2482 6,
2483 OpSchema()
2484 .SetDoc(PRelu_ver1_doc)
2485 .Input(0, "X", "Input tensor", "T")
2486 .Input(
2487 1,
2488 "slope",
2489 "Slope tensor. If `Slope` is of size 1, the value is shared"
2490 "across different channels",
2491 "T")
2492 .Output(0, "Y", "Output tensor", "T")
2493 .TypeConstraint(
2494 "T",
2495 {"tensor(float16)", "tensor(float)", "tensor(double)"},
2496 "Constrain input and output types to float tensors.")
2497 .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
2498
static const char* PRelu_ver7_doc = R"DOC(
PRelu takes input data (Tensor<T>) and slope tensor as input, and produces one
output data (Tensor<T>) where the function `f(x) = slope * x for x < 0`,
`f(x) = x for x >= 0`., is applied to the data tensor elementwise.
)DOC";

// Opset-7 PRelu: slope becomes unidirectionally broadcastable to X (see the
// appended broadcasting doc); output keeps X's type and shape.
ONNX_OPERATOR_SET_SCHEMA(
    PRelu,
    7,
    OpSchema()
        .SetDoc(GET_OP_DOC_STR(
            std::string(PRelu_ver7_doc) +
            GenerateBroadcastingDocUni("tensor slope", "input tensor X")))
        .Input(0, "X", "Input tensor", "T")
        .Input(
            1,
            "slope",
            "Slope tensor. The shape of slope can be smaller then first input X; "
            "if so, its shape must be unidirectional broadcastable to X",
            "T")
        .Output(0, "Y", "Output tensor (same size as X)", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
2525
static const char* Sigmoid_ver1_doc = R"DOC(
Sigmoid takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the
tensor elementwise.
)DOC";

// Opset-1 Sigmoid: elementwise logistic sigmoid, float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Sigmoid,
    1,
    OpSchema()
        .SetDoc(Sigmoid_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2551
static const char* HardSigmoid_ver1_doc = R"DOC(
HardSigmoid takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)),
is applied to the tensor elementwise.
)DOC";

// Opset-1 HardSigmoid: piecewise-linear sigmoid with `alpha` (default 0.2)
// and `beta` (default 0.5), float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    HardSigmoid,
    1,
    OpSchema()
        .Attr(
            "alpha",
            "Value of alpha default to 0.2",
            AttributeProto::FLOAT,
            0.2f)
        .Attr(
            "beta",
            "Value of beta default to 0.5",
            AttributeProto::FLOAT,
            0.5f)
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .SetDoc(HardSigmoid_ver1_doc)
        .Input(0, "X", "Input tensor", "T")
        .Output(0, "Y", "Output tensor", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2587
static const char* Max_ver1_doc = R"DOC(
Element-wise max of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-1 Max: variadic elementwise maximum (no broadcasting at this
// version), float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Max,
    1,
    OpSchema()
        .SetDoc(Max_ver1_doc)
        .Input(0, "data_0", "List of tensors for Max.", "T", OpSchema::Variadic)
        .Output(0, "max", "Output tensor. Same dimension as inputs.", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2612
static const char* Min_ver1_doc = R"DOC(
Element-wise min of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-1 Min: variadic elementwise minimum (no broadcasting at this
// version), float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Min,
    1,
    OpSchema()
        .SetDoc(Min_ver1_doc)
        .Input(0, "data_0", "List of tensors for Min", "T", OpSchema::Variadic)
        .Output(0, "min", "Output tensor. Same dimension as inputs.", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2637
static const char* Sum_ver1_doc = R"DOC(
Element-wise sum of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-1 Sum: variadic elementwise sum (no broadcasting at this version),
// float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Sum,
    1,
    OpSchema()
        .SetDoc(Sum_ver1_doc)
        .Input(0, "data_0", "List of tensors for Sum.", "T", OpSchema::Variadic)
        .Output(0, "sum", "Output tensor. Same dimension as inputs.", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2662
static const char* Mean_ver1_doc = R"DOC(
Element-wise mean of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-1 Mean: variadic elementwise mean (no broadcasting at this version),
// float tensors only.
ONNX_OPERATOR_SET_SCHEMA(
    Mean,
    1,
    OpSchema()
        .SetDoc(Mean_ver1_doc)
        .Input(
            0,
            "data_0",
            "List of tensors for Mean.",
            "T",
            OpSchema::Variadic)
        .Output(0, "mean", "Output tensor. Same dimension as inputs.", "T")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2692
static const char* Clip_ver1_doc = R"DOC(
Clip operator limits the given input within an interval. The interval is
specified with arguments 'min' and 'max'. They default to
numeric_limits::lowest() and numeric_limits::max() respectively.
)DOC";

// Opset-1 Clip: min/max are optional FLOAT attributes (they became inputs in
// a later opset); float tensors only, no shape inference registered here.
ONNX_OPERATOR_SET_SCHEMA(
    Clip,
    1,
    OpSchema()
        .SetDoc(Clip_ver1_doc)
        .Attr(
            "min",
            "Minimum value, under which element is replaced by min",
            AttributeProto::FLOAT,
            OPTIONAL_VALUE)
        .Attr(
            "max",
            "Maximum value, above which element is replaced by max",
            AttributeProto::FLOAT,
            OPTIONAL_VALUE)
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "consumed_inputs",
            "legacy optimization attribute.",
            AttributeProto::INTS,
            OPTIONAL_VALUE)
        .Input(0, "input", "Input tensor whose elements to be clipped", "T")
        .Output(0, "output", "Output tensor with clipped input elements", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors."));
2728
static const char* Gemm_ver1_doc = R"DOC(General Matrix multiplication:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
Compute Y = alpha * A * B + beta * C, where input tensor A has
dimension (M X K), input tensor B has dimension (K X N), input tensor C and
output tensor Y have dimension (M X N).
If attribute broadcast is non-zero, input tensor C will be broadcasted to match
the dimension requirement. A will be transposed before doing the computation
if attribute transA is non-zero, same for B and transB.
)DOC";

// Opset-1 Gemm: Y = alpha*A*B + beta*C with transA/transB flags and the
// legacy explicit `broadcast` attribute for C; float tensors only, no shape
// inference registered at this version.
ONNX_OPERATOR_SET_SCHEMA(
    Gemm,
    1,
    OpSchema()
        .SetDoc(Gemm_ver1_doc)
        .Input(0, "A", "Input tensor A", "T")
        .Input(1, "B", "Input tensor B", "T")
        .Input(2, "C", "Input tensor C, can be inplace.", "T")
        .Output(0, "Y", "Output tensor.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        // This attribute was added via AllowConsumed API in OpSchema.
        // After removing the API, we're now using the Attr API to simulate the
        // old definition.
        .Attr(
            "transA",
            "Whether A should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "transB",
            "Whether B should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "broadcast",
            "Whether C should be broadcasted",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "alpha",
            "Scalar multiplier for the product of input tensors A * B, the default value is 1.0.",
            AttributeProto::FLOAT,
            1.0f)
        .Attr(
            "beta",
            "Scalar multiplier for input tensor C, the default value is 1.0.",
            AttributeProto::FLOAT,
            1.0f));
2780
static const char* Gemm_ver6_doc = R"DOC(General Matrix multiplication:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3
Compute Y = alpha * A * B + beta * C, where input tensor A has
dimension (M X K), input tensor B has dimension (K X N), input tensor C and
output tensor Y have dimension (M X N).
If attribute broadcast is non-zero, input tensor C will be broadcasted to match
the dimension requirement. A will be transposed before doing the computation
if attribute transA is non-zero, same for B and transB.
)DOC";

// Opset-6 Gemm: same contract as opset 1 minus `consumed_inputs`; adds shape
// inference that derives (M, N) from A and B when both shapes are known, or
// falls back to C's shape when broadcast is disabled.
ONNX_OPERATOR_SET_SCHEMA(
    Gemm,
    6,
    OpSchema()
        .SetDoc(Gemm_ver6_doc)
        .Input(0, "A", "Input tensor A", "T")
        .Input(1, "B", "Input tensor B", "T")
        .Input(2, "C", "Input tensor C", "T")
        .Output(0, "Y", "Output tensor.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .Attr(
            "transA",
            "Whether A should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "transB",
            "Whether B should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "broadcast",
            "Whether C should be broadcasted",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "alpha",
            "Scalar multiplier for the product of input tensors A * B, the default value is 1.0.",
            AttributeProto::FLOAT,
            1.0f)
        .Attr(
            "beta",
            "Scalar multiplier for input tensor C, the default value is 1.0.",
            AttributeProto::FLOAT,
            1.0f)
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2)) {
            // Missing transA/transB attributes default to "not transposed".
            auto transAAttr = ctx.getAttribute("transA");
            bool transA =
                transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
            auto transBAttr = ctx.getAttribute("transB");
            bool transB =
                transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;

            // Output is (M, N): M from A (dim 0, or dim 1 if transposed),
            // N from B (dim 1, or dim 0 if transposed).
            *ctx.getOutputType(0)
                 ->mutable_tensor_type()
                 ->mutable_shape()
                 ->add_dim() =
                ctx.getInputType(0)->tensor_type().shape().dim(transA ? 1 : 0);
            *ctx.getOutputType(0)
                 ->mutable_tensor_type()
                 ->mutable_shape()
                 ->add_dim() =
                ctx.getInputType(1)->tensor_type().shape().dim(transB ? 0 : 1);
          } else if (
              hasInputShape(ctx, 2) &&
              (!ctx.getAttribute("broadcast") ||
               static_cast<int>(ctx.getAttribute("broadcast")->i()) == 0)) {
            // Without broadcast, C must already be (M, N), so its shape is
            // the output shape.
            *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() =
                ctx.getInputType(2)->tensor_type().shape();
          }
        }));
2857
static const char* Gemm_ver7_doc = R"DOC(General Matrix multiplication:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3

A' = transpose(A) if transA else A

B' = transpose(B) if transB else B

Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N),
and output tensor Y has shape (M, N). A will be transposed before doing the
computation if attribute transA is non-zero, same for B and transB.
)DOC";

// Opset-7 Gemm. Kept for backward compatibility with models targeting
// older operator sets; later opsets widen the type constraint.
ONNX_OPERATOR_SET_SCHEMA(
    Gemm,
    7,
    OpSchema()
        .SetDoc(GET_OP_DOC_STR(
            std::string(Gemm_ver7_doc) +
            GenerateBroadcastingDocUni("tensor C", "tensor A * B")))
        .Input(
            0,
            "A",
            "Input tensor A. "
            "The shape of A should be (M, K) if transA is 0, "
            "or (K, M) if transA is non-zero.",
            "T")
        .Input(
            1,
            "B",
            "Input tensor B. "
            "The shape of B should be (K, N) if transB is 0, "
            "or (N, K) if transB is non-zero.",
            "T")
        .Input(
            2,
            "C",
            "Input tensor C. "
            "The shape of C should be unidirectional broadcastable to (M, N).",
            "T")
        .Output(0, "Y", "Output tensor of shape (M, N).", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .Attr(
            "transA",
            "Whether A should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "transB",
            "Whether B should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "alpha",
            "Scalar multiplier for the product of input tensors A * B.",
            AttributeProto::FLOAT,
            1.0f)
        .Attr(
            "beta",
            "Scalar multiplier for input tensor C.",
            AttributeProto::FLOAT,
            1.0f)
        // Output element type follows input A; output shape is (M, N),
        // derived from the (possibly transposed) 2-D shapes of A and B.
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2)) {
            // Absent transA/transB attributes mean "no transpose".
            auto transAAttr = ctx.getAttribute("transA");
            bool transA =
                transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
            auto transBAttr = ctx.getAttribute("transB");
            bool transB =
                transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
            auto& first_input_shape = getInputShape(ctx, 0);
            auto& second_input_shape = getInputShape(ctx, 1);
            if (first_input_shape.dim_size() != 2) {
              fail_shape_inference("First input does not have rank 2");
            }
            if (second_input_shape.dim_size() != 2) {
              fail_shape_inference("Second input does not have rank 2");
            }
            // M comes from A (dim 0, or dim 1 if transposed); N comes
            // from B (dim 1, or dim 0 if transposed).
            updateOutputShape(
                ctx,
                0,
                {first_input_shape.dim(transA ? 1 : 0),
                 second_input_shape.dim(transB ? 0 : 1)});
          }
        }));
2947
static const char* Gemm_ver9_doc = R"DOC(General Matrix multiplication:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3

A' = transpose(A) if transA else A

B' = transpose(B) if transB else B

Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N),
and output tensor Y has shape (M, N). A will be transposed before doing the
computation if attribute transA is non-zero, same for B and transB.
)DOC";

// Opset-9 Gemm. Identical to opset 7 except the type constraint is
// widened to also allow uint32/uint64/int32/int64 tensors.
ONNX_OPERATOR_SET_SCHEMA(
    Gemm,
    9,
    OpSchema()
        .SetDoc(GET_OP_DOC_STR(
            std::string(Gemm_ver9_doc) +
            GenerateBroadcastingDocUni("tensor C", "tensor A * B")))
        .Input(
            0,
            "A",
            "Input tensor A. "
            "The shape of A should be (M, K) if transA is 0, "
            "or (K, M) if transA is non-zero.",
            "T")
        .Input(
            1,
            "B",
            "Input tensor B. "
            "The shape of B should be (K, N) if transB is 0, "
            "or (N, K) if transB is non-zero.",
            "T")
        .Input(
            2,
            "C",
            "Input tensor C. "
            "The shape of C should be unidirectional broadcastable to (M, N).",
            "T")
        .Output(0, "Y", "Output tensor of shape (M, N).", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)",
             "tensor(float)",
             "tensor(double)",
             "tensor(uint32)",
             "tensor(uint64)",
             "tensor(int32)",
             "tensor(int64)"},
            "Constrain input and output types to float/int tensors.")
        .Attr(
            "transA",
            "Whether A should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "transB",
            "Whether B should be transposed",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "alpha",
            "Scalar multiplier for the product of input tensors A * B.",
            AttributeProto::FLOAT,
            1.0f)
        .Attr(
            "beta",
            "Scalar multiplier for input tensor C.",
            AttributeProto::FLOAT,
            1.0f)
        // Output element type follows input A; output shape is (M, N),
        // derived from the (possibly transposed) 2-D shapes of A and B.
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (hasNInputShapes(ctx, 2)) {
            // Absent transA/transB attributes mean "no transpose".
            auto transAAttr = ctx.getAttribute("transA");
            bool transA =
                transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
            auto transBAttr = ctx.getAttribute("transB");
            bool transB =
                transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
            auto& first_input_shape = getInputShape(ctx, 0);
            auto& second_input_shape = getInputShape(ctx, 1);
            if (first_input_shape.dim_size() != 2) {
              fail_shape_inference("First input does not have rank 2");
            }
            if (second_input_shape.dim_size() != 2) {
              fail_shape_inference("Second input does not have rank 2");
            }
            // M comes from A (dim 0, or dim 1 if transposed); N comes
            // from B (dim 1, or dim 0 if transposed).
            updateOutputShape(
                ctx,
                0,
                {first_input_shape.dim(transA ? 1 : 0),
                 second_input_shape.dim(transB ? 0 : 1)});
          }
        }));
3043
static const char* Max_ver6_doc = R"DOC(
Element-wise max of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-6 Max: requires all inputs to share one shape (no broadcasting),
// so the output simply mirrors the first input's type and shape.
// Superseded by the broadcasting opset-8 version further below.
ONNX_OPERATOR_SET_SCHEMA(
    Max,
    6,
    OpSchema()
        .SetDoc(Max_ver6_doc)
        .Input(0, "data_0", "List of tensors for Max.", "T", OpSchema::Variadic)
        .Output(0, "max", "Output tensor. Same dimension as inputs.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
3061
static const char* Min_ver6_doc = R"DOC(
Element-wise min of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-6 Min: requires all inputs to share one shape (no broadcasting),
// so the output simply mirrors the first input's type and shape.
// Superseded by the broadcasting opset-8 version further below.
ONNX_OPERATOR_SET_SCHEMA(
    Min,
    6,
    OpSchema()
        .SetDoc(Min_ver6_doc)
        .Input(0, "data_0", "List of tensors for Min", "T", OpSchema::Variadic)
        .Output(0, "min", "Output tensor. Same dimension as inputs.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
3079
static const char* Sum_ver6_doc = R"DOC(
Element-wise sum of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-6 Sum: requires all inputs to share one shape (no broadcasting),
// so the output simply mirrors the first input's type and shape.
ONNX_OPERATOR_SET_SCHEMA(
    Sum,
    6,
    OpSchema()
        .SetDoc(Sum_ver6_doc)
        .Input(0, "data_0", "List of tensors for Sum.", "T", OpSchema::Variadic)
        .Output(0, "sum", "Output tensor. Same dimension as inputs.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
3097
static const char* Mean_ver6_doc = R"DOC(
Element-wise mean of each of the input tensors. All inputs and outputs must
have the same shape and data type.
)DOC";

// Opset-6 Mean: requires all inputs to share one shape (no broadcasting),
// so the output simply mirrors the first input's type and shape.
ONNX_OPERATOR_SET_SCHEMA(
    Mean,
    6,
    OpSchema()
        .SetDoc(Mean_ver6_doc)
        .Input(
            0,
            "data_0",
            "List of tensors for Mean.",
            "T",
            OpSchema::Variadic)
        .Output(0, "mean", "Output tensor. Same dimension as inputs.", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
3120
3121 static const char* MatMul_ver1_doc = R"DOC(
3122 Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
3123 )DOC";
3124
3125 ONNX_OPERATOR_SET_SCHEMA(
3126 MatMul,
3127 1,
3128 OpSchema()
3129 .Input(0, "A", "N-dimensional matrix A", "T")
3130 .Input(1, "B", "N-dimensional matrix B", "T")
3131 .Output(0, "Y", "Matrix multiply results from A * B", "T")
3132 .TypeConstraint(
3133 "T",
3134 {"tensor(float16)", "tensor(float)", "tensor(double)"},
3135 "Constrain input and output types to float tensors.")
3136 .SetDoc(MatMul_ver1_doc)
__anon4a9f2ddb1802(InferenceContext& ctx) 3137 .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
3138 propagateElemTypeFromInputToOutput(ctx, 0, 0);
3139 if (!hasNInputShapes(ctx, 2)) {
3140 return;
3141 }
3142
3143 const auto shape0 = ctx.getInputType(0)->tensor_type().shape();
3144 const auto shape1 = ctx.getInputType(1)->tensor_type().shape();
3145
3146 if (shape0.dim_size() == 0 || shape1.dim_size() == 0) {
3147 fail_shape_inference("Input tensors of wrong rank (0).");
3148 }
3149
3150 TensorShapeProto shapeL, shapeR;
3151
3152 // First promote each shape to at least rank-2. This logic is
3153 // specific to matmul, not generic broadcasting.
3154 {
3155 if (shape0.dim_size() == 1) {
3156 shapeL.add_dim()->set_dim_value(1);
3157 *shapeL.add_dim() = shape0.dim(0);
3158 } else {
3159 *shapeL.mutable_dim() = shape0.dim();
3160 }
3161 if (shape1.dim_size() == 1) {
3162 *shapeR.add_dim() = shape1.dim(0);
3163 shapeR.add_dim()->set_dim_value(1);
3164 } else {
3165 *shapeR.mutable_dim() = shape1.dim();
3166 }
3167 }
3168
3169 // Check for compatible matrix multiply dimensions
3170 {
3171 auto dimL = shapeL.dim(shapeL.dim_size() - 1);
3172 auto dimR = shapeR.dim(shapeR.dim_size() - 2);
3173 if (dimL.has_dim_value() && dimR.has_dim_value() &&
3174 dimL.dim_value() != dimR.dim_value()) {
3175 fail_shape_inference(
3176 "Incompatible dimensions for matrix multiplication");
3177 ;
3178 }
3179 }
3180
3181 TensorShapeProto resultShape;
3182
3183 // Now call out to generic multidimensional broadcasting for
3184 // the broadcastable prefixes.
3185 {
3186 TensorShapeProto prefixShapeL, prefixShapeR;
3187 for (int i = 0; i < shapeL.dim_size() - 2; ++i) {
3188 *prefixShapeL.add_dim() = shapeL.dim(i);
3189 }
3190 for (int i = 0; i < shapeR.dim_size() - 2; ++i) {
3191 *prefixShapeR.add_dim() = shapeR.dim(i);
3192 }
3193 bidirectionalBroadcastShapeInference(
3194 prefixShapeL, prefixShapeR, resultShape);
3195 }
3196
3197 // Back to matmul-specific. Add the trailing dimensions back in.
3198 {
3199 if (shape0.dim_size() != 1) {
3200 *resultShape.add_dim() = shapeL.dim(shapeL.dim_size() - 2);
3201 }
3202 if (shape1.dim_size() != 1) {
3203 *resultShape.add_dim() = shapeR.dim(shapeR.dim_size() - 1);
3204 }
3205 }
3206
3207 *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() =
3208 resultShape;
3209 }));
3210
static const char* TopK_ver1_doc = R"DOC(
Retrieve the top-K elements along a specified axis. Given an input tensor of
shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
-Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n]
which contains the values of the top k elements along the specified axis
-Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
contains the indices of the top k elements (original indices from the input
tensor).
Given two equivalent values, this operator uses the indices along the axis as
a tiebreaker. That is, the element with the lower index will appear first.
)DOC";

// Opset-1 TopK: k is a required attribute (it became a runtime input in
// opset 10, see below).
ONNX_OPERATOR_SET_SCHEMA(
    TopK,
    1,
    OpSchema()
        .SetDoc(TopK_ver1_doc)
        .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]", "T")
        .Output(
            0,
            "Values",
            "Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] "
            "containing top K values from the input tensor",
            "T")
        .Output(
            1,
            "Indices",
            "Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] "
            "containing the corresponding input tensor indices for the top K "
            "values.",
            "I")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeConstraint(
            "I",
            {"tensor(int64)"},
            "Constrain index tensor to int64")
        .Attr(
            "k",
            "Number of top elements to retrieve",
            AttributeProto::INT,
            true)
        .Attr(
            "axis",
            "Dimension on which to do the sort.",
            AttributeProto::INT,
            static_cast<int64_t>(-1))
        // Both outputs keep the input shape except along `axis`, whose
        // extent becomes k. Values has the input's element type; Indices
        // is always int64.
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Type inference:
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          updateOutputElemType(ctx, 1, TensorProto::INT64);

          // Shape inference:
          if (!hasInputShape(ctx, 0))
            return;
          auto& input_shape = getInputShape(ctx, 0);
          int64_t rank = input_shape.dim_size();
          // Negative axis counts from the back.
          int64_t axis = getAttribute(ctx, "axis", -1);
          if (axis < 0)
            axis += rank;
          if (axis < 0 || axis >= rank) {
            fail_shape_inference("Invalid value for attribute axis");
          }
          int64_t k = getAttribute(ctx, "k", -1);
          if (k <= 0) {
            fail_shape_inference("Invalid value for attribute k");
          }
          // TODO: unclear what results should be if axis has less than k
          // elements.
          TensorShapeProto result_shape = input_shape;
          result_shape.mutable_dim(static_cast<int>(axis))->set_dim_value(k);
          updateOutputShape(ctx, 0, result_shape);
          updateOutputShape(ctx, 1, result_shape);
        }));
3287
static const char* TopK_ver10_doc = R"DOC(
Retrieve the top-K elements along a specified axis. Given an input tensor of
shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
-Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n]
which contains the values of the top k elements along the specified axis
-Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which
contains the indices of the top k elements (original indices from the input
tensor).

Given two equivalent values, this operator uses the indices along the axis as
a tiebreaker. That is, the element with the lower index will appear first.
)DOC";

// Opset-10 TopK: k is now a runtime input tensor rather than an
// attribute, so the output shape is only fully inferable when K's value
// is statically known.
ONNX_OPERATOR_SET_SCHEMA(
    TopK,
    10,
    OpSchema()
        .SetDoc(TopK_ver10_doc)
        .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]", "T")
        .Input(
            1,
            "K",
            "A 1-D tensor containing a single positive value corresponding to the number of top elements to retrieve",
            "tensor(int64)")
        .Output(
            0,
            "Values",
            "Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] "
            "containing top K values from the input tensor",
            "T")
        .Output(
            1,
            "Indices",
            "Tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] "
            "containing the corresponding input tensor indices for the top K "
            "values.",
            "I")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeConstraint(
            "I",
            {"tensor(int64)"},
            "Constrain index tensor to int64")
        .Attr(
            "axis",
            "Dimension on which to do the sort.",
            AttributeProto::INT,
            static_cast<int64_t>(-1))
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Type inference:
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          updateOutputElemType(ctx, 1, TensorProto::INT64);
          // Shape inference:
          if (!hasInputShape(ctx, 0))
            return;
          auto& input_shape = getInputShape(ctx, 0);
          int64_t rank = input_shape.dim_size();
          // Negative axis counts from the back.
          int64_t axis = getAttribute(ctx, "axis", -1);
          if (axis < 0)
            axis += rank;
          if (axis < 0 || axis >= rank) {
            fail_shape_inference("Invalid value for attribute axis");
          }

          const auto& axis_dim = input_shape.dim(static_cast<int>(axis));
          // getInputData returns K's initializer/constant value if it is
          // statically known, nullptr otherwise.
          const auto* k = ctx.getInputData(1);

          // Infer output shape if:
          // (1) 'K' is available
          // (2) axis_dim has dim value
          // Otherwise cannot reliably compute output shape as axis dim value is
          // unknown and hence cannot determine if axis dim value >= k (which
          // should be enforced)
          if (nullptr != k && axis_dim.has_dim_value()) {
            int64_t k_value = 0;
            if (k->dims_size() != 1 || k->dims(0) != 1) {
              fail_shape_inference("K input must be a one-dimensional tensor of size 1.");
            }

            if (k->data_type() == TensorProto::INT64) {
              const auto& data = ParseData<int64_t>(k);
              k_value = data[0];
            } else {
              fail_shape_inference("K input must be of type int64.");
            }

            if (axis_dim.dim_value() < k_value) {
              fail_shape_inference("Axis has less than the requested k elements.");
            }

            // Output shape equals input shape except along `axis`, which
            // becomes k.
            TensorShapeProto result_shape = input_shape;
            result_shape.mutable_dim(static_cast<int>(axis))
                ->set_dim_value(k_value);

            updateOutputShape(ctx, 0, result_shape);
            updateOutputShape(ctx, 1, result_shape);

            return;
          }

          // Infer output shapes' rank in any case: emit one unknown dim
          // per input dim so downstream inference at least knows the rank.
          auto* output_shape_0 = getOutputShape(ctx, 0);
          auto* output_shape_1 = getOutputShape(ctx, 1);
          for (int i = 0; i < input_shape.dim_size(); ++i) {
            output_shape_0->add_dim();
            output_shape_1->add_dim();
          }

          return;
        }));
3400
static const char* Clip_ver6_doc = R"DOC(
Clip operator limits the given input within an interval. The interval is
specified with arguments 'min' and 'max'. They default to
numeric_limits::lowest() and numeric_limits::max() respectively.
)DOC";

// Opset-6 Clip: min/max are float attributes (they became optional
// inputs in opset 11, see below). Output mirrors the input's type and
// shape.
ONNX_OPERATOR_SET_SCHEMA(
    Clip,
    6,
    OpSchema()
        .SetDoc(Clip_ver6_doc)
        .Attr(
            "min",
            "Minimum value, under which element is replaced by min",
            AttributeProto::FLOAT,
            std::numeric_limits<float>::lowest())
        .Attr(
            "max",
            "Maximum value, above which element is replaced by max",
            AttributeProto::FLOAT,
            std::numeric_limits<float>::max())
        .Input(0, "input", "Input tensor whose elements to be clipped", "T")
        .Output(0, "output", "Output tensor with clipped input elements", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
3429
static const char* Clip_ver11_doc = R"DOC(
Clip operator limits the given input within an interval. The interval is
specified by the inputs 'min' and 'max'. They default to
numeric_limits::lowest() and numeric_limits::max(), respectively.
)DOC";

// Opset-11 Clip: min/max are now optional scalar inputs (same element
// type as the data) instead of float attributes. Output mirrors the
// input's type and shape.
ONNX_OPERATOR_SET_SCHEMA(
    Clip,
    11,
    OpSchema()
        .SetDoc(Clip_ver11_doc)
        .Input(0, "input", "Input tensor whose elements to be clipped", "T")
        .Input(
            1,
            "min",
            "Minimum value, under which element is replaced by min. "
            "It must be a scalar(tensor of empty shape).",
            "T",
            OpSchema::Optional)
        .Input(
            2,
            "max",
            "Maximum value, above which element is replaced by max. "
            "It must be a scalar(tensor of empty shape).",
            "T",
            OpSchema::Optional)
        .Output(0, "output", "Output tensor with clipped input elements", "T")
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(double)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
3462
ElementwiseMultiOpDocGenerator_old(const char * name)3463 std::function<void(OpSchema&)> ElementwiseMultiOpDocGenerator_old(
3464 const char* name) {
3465 return [=](OpSchema& schema) {
3466 std::string doc;
3467 POPULATE_OP_DOC_STR(
3468 doc = R"DOC(
3469 Element-wise {name} of each of the input tensors (with Numpy-style broadcasting support).
3470 All inputs and outputs must have the same data type.
3471 {broadcast_doc}
3472 )DOC";
3473 ReplaceAll(doc, "{name}", name);
3474 ReplaceAll(
3475 doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
3476 schema.SetDoc(doc);
3477 schema.Input(
3478 0,
3479 "data_0",
3480 "List of tensors for " + std::string(name) + ".",
3481 "T",
3482 OpSchema::Variadic);
3483 schema.Output(0, name, "Output tensor.", "T");
3484 schema.TypeConstraint(
3485 "T",
3486 {"tensor(float16)", "tensor(float)", "tensor(double)"},
3487 "Constrain input and output types to float tensors.");
3488 schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
3489 propagateElemTypeFromInputToOutput(ctx, 0, 0);
3490 int num_inputs = static_cast<int>(ctx.getNumInputs());
3491 std::vector<const TensorShapeProto*> shapes;
3492 for (int i = 0; i < num_inputs; ++i) {
3493 auto input_type = ctx.getInputType(i);
3494 if (nullptr == input_type || !input_type->has_tensor_type() ||
3495 !input_type->tensor_type().has_shape()) {
3496 return;
3497 }
3498 shapes.push_back(&input_type->tensor_type().shape());
3499 }
3500
3501 multidirectionalBroadcastShapeInference(
3502 shapes,
3503 *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
3504 });
3505 };
3506 }
3507
// Opset-8 Max: adds Numpy-style multidirectional broadcasting over the
// opset-6 version above.
ONNX_OPERATOR_SET_SCHEMA(
    Max,
    8,
    OpSchema().FillUsing(ElementwiseMultiOpDocGenerator_old("max")));
3512
// Opset-8 Min: adds Numpy-style multidirectional broadcasting over the
// opset-6 version above.
ONNX_OPERATOR_SET_SCHEMA(
    Min,
    8,
    OpSchema().FillUsing(ElementwiseMultiOpDocGenerator_old("min")));
3517
3518 } // namespace ONNX_NAMESPACE
3519