1 /* 2 * SPDX-License-Identifier: Apache-2.0 3 */ 4 5 #include "onnx/defs/function.h" 6 #include "onnx/defs/schema.h" 7 8 #include <algorithm> 9 #include <numeric> 10 11 namespace ONNX_NAMESPACE { 12 13 static const char* Optional_ver15_doc = R"DOC( 14 Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute, 15 or a non-empty value containing the input element. 16 )DOC"; 17 18 ONNX_OPERATOR_SET_SCHEMA( 19 Optional, 20 15, 21 OpSchema() 22 .SetDoc(Optional_ver15_doc) 23 .Input(0, "input", "The input element.", "V", OpSchema::Optional) 24 .Attr("type", "Type of the element in the optional output", AttributeProto::TYPE_PROTO, OPTIONAL_VALUE) 25 .Output(0, "output", "The optional output enclosing the input element.", "O") 26 .TypeConstraint( 27 "V", __anon2a3bc79a0102()28 [](){ 29 auto t = OpSchema::all_tensor_types(); 30 auto s = OpSchema::all_tensor_sequence_types(); 31 t.insert(t.end(), s.begin(), s.end()); 32 return t; 33 }(), 34 "Constrains input type to all tensor and sequence types.") 35 .TypeConstraint( 36 "O", 37 OpSchema::all_optional_types(), 38 "Constrains output type to all optional tensor or optional sequence types.") __anon2a3bc79a0202(InferenceContext& ctx) 39 .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { 40 const size_t numOutputs = ctx.getNumOutputs(); 41 if (numOutputs != 1) { 42 fail_type_inference("Optional is expected to have an output."); 43 } 44 45 const size_t numInputs = ctx.getNumInputs(); 46 const auto* attr_proto = ctx.getAttribute("type"); 47 48 if ((numInputs == 0) && (attr_proto != nullptr)) { 49 if (!attr_proto->has_tp()) 50 fail_type_inference("Attribute 'type' should be a TypeProto and it should specify a type."); 51 auto attr_tp = attr_proto->tp(); 52 53 ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(attr_tp); 54 } else if (numInputs == 1) { 55 auto input_type = ctx.getInputType(0); 56 if (input_type == nullptr) { 57 fail_type_inference("Input type is null. Type information is expected for the input."); 58 } 59 ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(*input_type); 60 } else { 61 fail_type_inference("Optional is expected to have either an input or the type attribute set."); 62 } 63 })); 64 65 static const char* OptionalHasElement_ver1_doc = R"DOC( 66 Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false. 67 )DOC"; 68 69 ONNX_OPERATOR_SET_SCHEMA( 70 OptionalHasElement, 71 15, 72 OpSchema() 73 .SetDoc(OptionalHasElement_ver1_doc) 74 .Input(0, "input", "The optional input.", "O") 75 .Output( 76 0, 77 "output", 78 "A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.", 79 "B") 80 .TypeConstraint( 81 "O", 82 OpSchema::all_optional_types(), 83 "Constrains input type to optional tensor and optional sequence types.") 84 .TypeConstraint("B", {"tensor(bool)"}, "Constrains output to a boolean tensor.") __anon2a3bc79a0302(InferenceContext& ctx) 85 .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { 86 const size_t numInputs = ctx.getNumInputs(); 87 if (numInputs != 1) { 88 fail_type_inference("OptionalHasElement is expected to have 1 input."); 89 } 90 const size_t numOutputs = ctx.getNumOutputs(); 91 if (numOutputs != 1) { 92 fail_type_inference("OptionalHasElement is expected to have 1 output."); 93 } 94 auto* output_tensor_type = ctx.getOutputType(0)->mutable_tensor_type(); 95 output_tensor_type->set_elem_type(TensorProto::BOOL); 96 output_tensor_type->mutable_shape()->Clear(); 97 })); 98 99 static const char* OptionalGetElement_ver1_doc = R"DOC( 100 Outputs the element in the optional-type input. It is an error if the input value does not have an element 101 and the behavior is undefined in this case. 102 )DOC"; 103 104 ONNX_OPERATOR_SET_SCHEMA( 105 OptionalGetElement, 106 15, 107 OpSchema() 108 .SetDoc(OptionalGetElement_ver1_doc) 109 .Input(0, "input", "The optional input.", "O") 110 .Output(0, "output", "Output element in the optional input.", "V") 111 .TypeConstraint( 112 "O", 113 OpSchema::all_optional_types(), 114 "Constrains input type to optional tensor and optional sequence types.") 115 .TypeConstraint( 116 "V", __anon2a3bc79a0402()117 [](){ 118 auto t = OpSchema::all_tensor_types(); 119 auto s = OpSchema::all_tensor_sequence_types(); 120 t.insert(t.end(), s.begin(), s.end()); 121 return t; 122 }(), 123 "Constrain output type to all tensor or sequence types.") __anon2a3bc79a0502(InferenceContext& ctx) 124 .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { 125 const size_t numInputs = ctx.getNumInputs(); 126 if (numInputs != 1) { 127 fail_type_inference("OptionalGetElement must have an input element."); 128 } 129 auto input_type = ctx.getInputType(0); 130 if (input_type == nullptr) { 131 fail_type_inference("Input type is null. Input must have Type information."); 132 } 133 if (!input_type->has_optional_type() || !input_type->optional_type().has_elem_type()) { 134 fail_type_inference("Input must be an optional-type value containing an element with type information."); 135 } 136 ctx.getOutputType(0)->CopyFrom(input_type->optional_type().elem_type()); 137 })); 138 139 } // namespace ONNX_NAMESPACE