1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 #include "onnx/defs/function.h"
6 #include "onnx/defs/schema.h"
7 
8 #include <algorithm>
9 #include <numeric>
10 
11 namespace ONNX_NAMESPACE {
12 
13 static const char* Optional_ver15_doc = R"DOC(
14 Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute,
15 or a non-empty value containing the input element.
16 )DOC";
17 
18 ONNX_OPERATOR_SET_SCHEMA(
19     Optional,
20     15,
21     OpSchema()
22         .SetDoc(Optional_ver15_doc)
23         .Input(0, "input", "The input element.", "V", OpSchema::Optional)
24         .Attr("type", "Type of the element in the optional output", AttributeProto::TYPE_PROTO, OPTIONAL_VALUE)
25         .Output(0, "output", "The optional output enclosing the input element.", "O")
26         .TypeConstraint(
27             "V",
__anon2a3bc79a0102()28             [](){
29               auto t = OpSchema::all_tensor_types();
30               auto s = OpSchema::all_tensor_sequence_types();
31               t.insert(t.end(), s.begin(), s.end());
32               return t;
33             }(),
34             "Constrains input type to all tensor and sequence types.")
35         .TypeConstraint(
36             "O",
37             OpSchema::all_optional_types(),
38             "Constrains output type to all optional tensor or optional sequence types.")
__anon2a3bc79a0202(InferenceContext& ctx) 39         .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
40           const size_t numOutputs = ctx.getNumOutputs();
41           if (numOutputs != 1) {
42             fail_type_inference("Optional is expected to have an output.");
43           }
44 
45           const size_t numInputs = ctx.getNumInputs();
46           const auto* attr_proto = ctx.getAttribute("type");
47 
48           if ((numInputs == 0) && (attr_proto != nullptr)) {
49             if (!attr_proto->has_tp())
50               fail_type_inference("Attribute 'type' should be a TypeProto and it should specify a type.");
51             auto attr_tp = attr_proto->tp();
52 
53             ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(attr_tp);
54           } else if (numInputs == 1) {
55             auto input_type = ctx.getInputType(0);
56             if (input_type == nullptr) {
57               fail_type_inference("Input type is null. Type information is expected for the input.");
58             }
59             ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(*input_type);
60           } else {
61             fail_type_inference("Optional is expected to have either an input or the type attribute set.");
62           }
63         }));
64 
65 static const char* OptionalHasElement_ver1_doc = R"DOC(
66 Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false.
67 )DOC";
68 
69 ONNX_OPERATOR_SET_SCHEMA(
70     OptionalHasElement,
71     15,
72     OpSchema()
73         .SetDoc(OptionalHasElement_ver1_doc)
74         .Input(0, "input", "The optional input.", "O")
75         .Output(
76             0,
77             "output",
78             "A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.",
79             "B")
80         .TypeConstraint(
81             "O",
82             OpSchema::all_optional_types(),
83             "Constrains input type to optional tensor and optional sequence types.")
84         .TypeConstraint("B", {"tensor(bool)"}, "Constrains output to a boolean tensor.")
__anon2a3bc79a0302(InferenceContext& ctx) 85         .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
86           const size_t numInputs = ctx.getNumInputs();
87           if (numInputs != 1) {
88             fail_type_inference("OptionalHasElement is expected to have 1 input.");
89           }
90           const size_t numOutputs = ctx.getNumOutputs();
91           if (numOutputs != 1) {
92             fail_type_inference("OptionalHasElement is expected to have 1 output.");
93           }
94           auto* output_tensor_type = ctx.getOutputType(0)->mutable_tensor_type();
95           output_tensor_type->set_elem_type(TensorProto::BOOL);
96           output_tensor_type->mutable_shape()->Clear();
97         }));
98 
99 static const char* OptionalGetElement_ver1_doc = R"DOC(
100 Outputs the element in the optional-type input. It is an error if the input value does not have an element
101 and the behavior is undefined in this case.
102 )DOC";
103 
104 ONNX_OPERATOR_SET_SCHEMA(
105     OptionalGetElement,
106     15,
107     OpSchema()
108         .SetDoc(OptionalGetElement_ver1_doc)
109         .Input(0, "input", "The optional input.", "O")
110         .Output(0, "output", "Output element in the optional input.", "V")
111         .TypeConstraint(
112             "O",
113             OpSchema::all_optional_types(),
114             "Constrains input type to optional tensor and optional sequence types.")
115         .TypeConstraint(
116             "V",
__anon2a3bc79a0402()117             [](){
118               auto t = OpSchema::all_tensor_types();
119               auto s = OpSchema::all_tensor_sequence_types();
120               t.insert(t.end(), s.begin(), s.end());
121               return t;
122             }(),
123             "Constrain output type to all tensor or sequence types.")
__anon2a3bc79a0502(InferenceContext& ctx) 124         .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
125           const size_t numInputs = ctx.getNumInputs();
126           if (numInputs != 1) {
127             fail_type_inference("OptionalGetElement must have an input element.");
128           }
129           auto input_type = ctx.getInputType(0);
130           if (input_type == nullptr) {
131             fail_type_inference("Input type is null. Input must have Type information.");
132           }
133           if (!input_type->has_optional_type() || !input_type->optional_type().has_elem_type()) {
134             fail_type_inference("Input must be an optional-type value containing an element with type information.");
135           }
136           ctx.getOutputType(0)->CopyFrom(input_type->optional_type().elem_type());
137         }));
138 
139 } // namespace ONNX_NAMESPACE