1 // Copyright (c) ONNX Project Contributors.
2 // Licensed under the MIT license.
3 
#include "onnx/defs/schema.h"

#include <cstring>
#include <limits>
#include <stdexcept>
#include <unordered_set>

#include "onnx/checker.h"
#include "onnx/defs/operator_sets.h"

#ifdef ONNX_ML
#include "onnx/defs/operator_sets-ml.h"
#endif

#include "onnx/common/assertions.h"
#include "onnx/common/stl_backports.h"
16 
17 namespace ONNX_NAMESPACE {
18 
RegisterSchema(OpSchema && schema)19 void RegisterSchema(OpSchema&& schema) {
20   OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration = schema;
21 }
22 
23 #ifndef NDEBUG
Instance()24 DbgOperatorSetTracker& DbgOperatorSetTracker::Instance() {
25   static DbgOperatorSetTracker instance;
26   return instance;
27 }
28 #endif
29 
FormalParameter(std::string name,DataTypeSet allowed_type_set,std::string type_str,std::string description,FormalParameterOption param_option,bool is_homogeneous)30 OpSchema::FormalParameter::FormalParameter(
31     std::string name,
32     DataTypeSet allowed_type_set,
33     std::string type_str,
34     std::string description,
35     FormalParameterOption param_option,
36     bool is_homogeneous)
37     : name_(std::move(name)),
38       type_set_(std::move(allowed_type_set)),
39       type_str_(std::move(type_str)),
40       description_(std::move(description)),
41       param_option_(param_option),
42       is_homogeneous_(is_homogeneous) {}
43 
FormalParameter(std::string name,std::string description,std::string type_str,FormalParameterOption param_option,bool is_homogeneous)44 OpSchema::FormalParameter::FormalParameter(
45     std::string name,
46     std::string description,
47     std::string type_str,
48     FormalParameterOption param_option,
49     bool is_homogeneous)
50     : name_(std::move(name)),
51       type_str_(std::move(type_str)),
52       description_(std::move(description)),
53       param_option_(param_option),
54       is_homogeneous_(is_homogeneous) {}
55 
GetName() const56 const std::string& OpSchema::FormalParameter::GetName() const {
57   return name_;
58 }
59 
GetTypes() const60 const DataTypeSet& OpSchema::FormalParameter::GetTypes() const {
61   return type_set_;
62 }
63 
MutableTypes()64 DataTypeSet& OpSchema::FormalParameter::MutableTypes() {
65   return type_set_;
66 }
67 
GetTypeStr() const68 const std::string& OpSchema::FormalParameter::GetTypeStr() const {
69   return type_str_;
70 }
71 
GetDescription() const72 const std::string& OpSchema::FormalParameter::GetDescription() const {
73   return description_;
74 }
75 
GetOption() const76 OpSchema::FormalParameterOption OpSchema::FormalParameter::GetOption() const {
77   return param_option_;
78 }
79 
GetIsHomogeneous() const80 bool OpSchema::FormalParameter::GetIsHomogeneous() const {
81   return is_homogeneous_;
82 }
83 
Instance()84 OpSchemaRegistry* OpSchemaRegistry::Instance() {
85   static OpSchemaRegistry instance;
86   return &instance;
87 }
88 
Verify(const NodeProto & node) const89 void OpSchema::Verify(const NodeProto& node) const {
90   if (deprecated_) {
91     fail_check(
92         "Operator '",
93         name_,
94         "' has been deprecated since version ",
95         since_version_);
96   }
97 
98   // Check the number of inputs.
99   if (node.input_size() < min_input_ || node.input_size() > max_input_) {
100     fail_check(
101         "Node (",
102         node.name(),
103         ") has input size ",
104         node.input_size(),
105         " not in range [min=",
106         min_input_,
107         ", max=",
108         max_input_,
109         "].");
110   }
111 
112   if (!num_inputs_allowed_(node.input_size())) {
113     fail_check(
114         "Node (",
115         node.name(),
116         ") has input size ",
117         node.input_size(),
118         " not in allowed input sizes.");
119   }
120 
121   // Check the number of outputs.
122   if (node.output_size() < min_output_ || node.output_size() > max_output_) {
123     fail_check(
124         "Node (",
125         node.name(),
126         ") has output size ",
127         node.output_size(),
128         " not in range [min=",
129         min_output_,
130         ", max=",
131         max_output_,
132         "].");
133   }
134 
135   if (!num_outputs_allowed_(node.output_size())) {
136     fail_check(
137         "Node (",
138         node.name(),
139         "has output size ",
140         node.output_size(),
141         " not in allowed output sizes.");
142   }
143 
144   // Check the values of inputs / outputs
145   for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
146     if (in_idx >= static_cast<int>(inputs_.size())) {
147       if (inputs_.size() > 0 && Variadic == inputs_.back().GetOption()) {
148         // The last input formal parameter should be variadic.
149         break;
150       } else {
151         fail_check(
152             "Node (",
153             node.name(),
154             ") has more inputs (",
155             node.input_size(),
156             ") than declared (",
157             inputs_.size(),
158             ") in op definition.");
159       }
160     }
161     if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
162       fail_check(
163           "Node (",
164           node.name(),
165           ")'s input ",
166           in_idx,
167           " is marked single but has an empty string in the graph");
168     }
169   }
170 
171   for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
172     if (out_idx >= static_cast<int>(outputs_.size())) {
173       if (outputs_.size() > 0 && Variadic == outputs_.back().GetOption()) {
174         // The last output formal parameter should be variadic.
175         break;
176       } else {
177         fail_check(
178             "Node (",
179             node.name(),
180             ") has more outputs (",
181             node.output_size(),
182             ") than declared (",
183             outputs_.size(),
184             ") in op definition.");
185       }
186     }
187 
188     if (node.output(out_idx).empty() &&
189         (Single == outputs_[out_idx].GetOption())) {
190       fail_check(
191           "Node (",
192           node.name(),
193           ")'s output ",
194           out_idx,
195           " is marked single but has an empty string in the graph");
196     }
197   }
198 
199   // An internal symbol is defined as starting with two underscores. Attributes
200   // with names meeting this condition are considered implementation details
201   // and should be ignored for the purpose of schema checking.
202   auto isInternalSymbol = [](const std::string& sym) -> bool {
203     return sym.length() >= 2 && sym[0] == '_' && sym[1] == '_';
204   };
205 
206   // Check attributes
207   std::unordered_set<std::string> seen_attr_names{};
208   for (const auto& attr_proto : node.attribute()) {
209     const auto& name = attr_proto.name();
210 
211     if (!seen_attr_names.insert(name).second) {
212       fail_check("Attribute '", name, "' appeared multiple times.");
213     };
214 
215     const auto& search = attributes_.find(name);
216     AttributeProto::AttributeType expected_type;
217     if (search != attributes_.end()) {
218       expected_type = search->second.type;
219     } else if (allows_unchecked_attributes_ || isInternalSymbol(name)) {
220       continue;
221     } else {
222       fail_check(
223           "Unrecognized attribute: ", name, " for operator ", node.op_type());
224     }
225 
226     if (attr_proto.has_ref_attr_name()) {
227       if (!attr_proto.has_type() || attr_proto.type() != expected_type) {
228         fail_check(
229             "Mismatched attribute type in '", node.name() + " : " + name, "'");
230       }
231       continue;
232     }
233 
234     switch (expected_type) {
235       case AttributeProto::FLOAT:
236         if (!attr_proto.has_f()) {
237           fail_check("Attribute '", name, "' is expected to have field 'f'");
238         }
239         break;
240       case AttributeProto::INT:
241         if (!attr_proto.has_i()) {
242           fail_check("Attribute '", name, "' is expected to have field 'i'");
243         }
244         break;
245       case AttributeProto::STRING:
246         if (!attr_proto.has_s()) {
247           fail_check("Attribute '", name, "' is expected to have field 's'");
248         }
249         break;
250       case AttributeProto::TENSOR:
251         if (!attr_proto.has_t()) {
252           fail_check("Attribute '", name, "' is expected to have field 't'");
253         }
254         break;
255       case AttributeProto::GRAPH:
256         if (!attr_proto.has_g()) {
257           fail_check("Attribute '", name, "' is expected to have field 'g'");
258         }
259         break;
260       case AttributeProto::FLOATS:
261         if (!attr_proto.floats_size()) {
262           fail_check(
263               "Attribute '", name, "' is expected to have field 'floats'");
264         }
265         break;
266       case AttributeProto::INTS:
267         if (!attr_proto.ints_size()) {
268           fail_check("Attribute '", name, "' is expected to have field 'ints'");
269         }
270         break;
271       case AttributeProto::STRINGS:
272         if (!attr_proto.strings_size()) {
273           fail_check(
274               "Attribute '", name, "' is expected to have field 'strings'");
275         }
276         break;
277       case AttributeProto::TENSORS:
278         if (!attr_proto.tensors_size()) {
279           fail_check(
280               "Attribute '", name, "' is expected to have field 'tensors'");
281         }
282         break;
283       case AttributeProto::GRAPHS:
284         if (!attr_proto.graphs_size()) {
285           fail_check(
286               "Attribute '", name, "' is expected to have field 'graphs'");
287         }
288         break;
289       default:
290         fail_check("Attribute '", name, " has unknown expected type");
291     }
292   }
293   for (const auto& pair : attributes_) {
294     const auto& attr = pair.second;
295     if (!attr.required) {
296       continue;
297     }
298     if (!seen_attr_names.count(attr.name)) {
299       fail_check("Required attribute '", attr.name, "' is missing.");
300     }
301   }
302 
303   // Phew. All verifications passed.
304 }
305 
SinceVersion(OperatorSetVersion v)306 OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) {
307   since_version_ = v;
308   return *this;
309 }
310 
Deprecate()311 OpSchema& OpSchema::Deprecate() {
312   deprecated_ = true;
313   return *this;
314 }
315 
NumInputs(std::set<int> allowed_input_nums)316 OpSchema& OpSchema::NumInputs(std::set<int> allowed_input_nums) {
317   num_inputs_allowed_ =
318       [MOVE_CAPTURE_IF_CPP14(allowed_input_nums)](int n) -> bool {
319     return allowed_input_nums.count(n);
320   };
321   return *this;
322 }
323 
NumOutputs(std::set<int> allowed_output_nums)324 OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) {
325   num_outputs_allowed_ =
326       [MOVE_CAPTURE_IF_CPP14(allowed_output_nums)](int n) -> bool {
327     return allowed_output_nums.count(n);
328   };
329   return *this;
330 }
331 
TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction)332 OpSchema& OpSchema::TypeAndShapeInferenceFunction(
333     InferenceFunction inferenceFunction) {
334   tensor_inference_function_ = inferenceFunction;
335   return *this;
336 }
337 
SetSupportLevel(SupportType support)338 OpSchema& OpSchema::SetSupportLevel(SupportType support) {
339   support_ = support;
340   return *this;
341 }
342 
SetDoc(std::string doc)343 OpSchema& OpSchema::SetDoc(std::string doc) {
344   doc_ = std::move(doc);
345   return *this;
346 }
347 
348 // Functions to specify name for the operator schema.
SetName(std::string name)349 OpSchema& OpSchema::SetName(std::string name) {
350   name_ = std::move(name);
351   return *this;
352 }
353 
SetName(const char * name)354 OpSchema& OpSchema::SetName(const char* name) {
355   return SetName(std::string(name));
356 }
357 
358 // Functions to specify code location for the operator schema.
SetLocation(std::string file,int line)359 OpSchema& OpSchema::SetLocation(std::string file, int line) {
360   file_ = std::move(file);
361   line_ = line;
362   return *this;
363 }
364 
SetLocation(const char * file,int line)365 OpSchema& OpSchema::SetLocation(const char* file, int line) {
366   return SetLocation(std::string(file), line);
367 }
368 
SetDomain(std::string domain)369 OpSchema& OpSchema::SetDomain(std::string domain) {
370   domain_ = std::move(domain);
371   return *this;
372 }
373 
SetDomain(const char * domain)374 OpSchema& OpSchema::SetDomain(const char* domain) {
375   return SetDomain(std::string(domain));
376 }
377 
Attr(Attribute attr)378 OpSchema& OpSchema::Attr(Attribute attr) {
379   auto name = attr.name; // copy name so we can move attr in the next line
380   attributes_.insert(std::make_pair(std::move(name), std::move(attr)));
381   return *this;
382 }
383 
Attr(std::string name,std::string description,AttributeProto::AttributeType type,bool required)384 OpSchema& OpSchema::Attr(
385     std::string name,
386     std::string description,
387     AttributeProto::AttributeType type,
388     bool required) {
389   Attr(Attribute{std::move(name), std::move(description), type, required});
390   return *this;
391 }
392 
Attr(const char * name,const char * description,AttributeProto::AttributeType type,bool required)393 OpSchema& OpSchema::Attr(
394     const char* name,
395     const char* description,
396     AttributeProto::AttributeType type,
397     bool required) {
398   return Attr(std::string(name), std::string(description), type, required);
399 }
400 
// Defines two Attr() overloads for an attribute whose default value is a
// single scalar of C++ type `type`, stored in AttributeProto via
// set_<field>(). The std::string overload calls fail_schema if the declared
// `attr_type` differs from `attrtype`, builds an AttributeProto carrying the
// name, default value and type, and registers it through Attr(Attribute).
// The const char* overload converts its strings and forwards to it.
#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field, attrtype)                \
  OpSchema& OpSchema::Attr(                                                 \
      std::string name,                                                     \
      std::string description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const type& default_value) {                                          \
    if (attrtype != attr_type) {                                            \
      fail_schema("Attribute specification type mismatch.");                \
    }                                                                       \
    AttributeProto a;                                                       \
    a.set_name(name);                                                       \
    a.set_##field(default_value);                                           \
    a.set_type(attr_type);                                                  \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this;                                                           \
  }                                                                         \
  OpSchema& OpSchema::Attr(                                                 \
      const char* name,                                                     \
      const char* description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const type& default_value) {                                          \
    return Attr(                                                            \
        std::string(name),                                                  \
        std::string(description),                                           \
        attr_type,                                                          \
        default_value);                                                     \
  }
428 
// Defines an Attr() overload for an attribute whose default value is a list
// of scalars of C++ type `type`. Each element is appended to the
// AttributeProto with add_<field>(). Calls fail_schema if the declared
// `attr_type` differs from `attrtype`. Only a std::string overload is
// generated (no const char* variant).
#define ATTR_SETTER_WITH_LIST_VALUE(type, field, attrtype)                  \
  OpSchema& OpSchema::Attr(                                                 \
      std::string name,                                                     \
      std::string description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const std::vector<type>& default_value) {                             \
    if (attrtype != attr_type) {                                            \
      fail_schema("Attribute specification type mismatch.");                \
    }                                                                       \
    AttributeProto a;                                                       \
    a.set_name(name);                                                       \
    a.set_type(attr_type);                                                  \
    for (const auto& v : default_value) {                                   \
      a.add_##field(v);                                                     \
    }                                                                       \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this;                                                           \
  }
447 
// Defines an Attr() overload for an attribute whose default value is a single
// message of type `type` (e.g. TensorProto, GraphProto), copied into the
// AttributeProto through mutable_<field>(). Calls fail_schema if the declared
// `attr_type` differs from `attrtype`. The assembled AttributeProto is moved
// (not copied) into the Attribute, matching the other setter macros, so a
// potentially large tensor/graph default is not duplicated.
#define ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(type, field, attrtype)         \
  OpSchema& OpSchema::Attr(                                                 \
      std::string name,                                                     \
      std::string description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const type& default_value) {                                          \
    if (attrtype != attr_type) {                                            \
      fail_schema("Attribute specification type mismatch.");                \
    }                                                                       \
    AttributeProto a;                                                       \
    a.set_name(name);                                                       \
    *(a.mutable_##field()) = default_value;                                 \
    a.set_type(attr_type);                                                  \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this;                                                           \
  }
464 
// Defines an Attr() overload for an attribute whose default value is a list
// of `type` values that are assigned (not passed to a setter): each element
// is copied into a fresh entry obtained from add_<field>(). Used for
// std::string, TensorProto and GraphProto lists. Calls fail_schema if the
// declared `attr_type` differs from `attrtype`.
#define ATTR_SETTER_WITH_LIST_COMPLEXVALUE(type, field, attrtype)           \
  OpSchema& OpSchema::Attr(                                                 \
      std::string name,                                                     \
      std::string description,                                              \
      AttributeProto::AttributeType attr_type,                              \
      const std::vector<type>& default_value) {                             \
    if (attrtype != attr_type) {                                            \
      fail_schema("Attribute specification type mismatch.");                \
    }                                                                       \
    AttributeProto a;                                                       \
    a.set_name(name);                                                       \
    a.set_type(attr_type);                                                  \
    for (const auto& v : default_value) {                                   \
      *(a.add_##field()) = v;                                               \
    }                                                                       \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this;                                                           \
  }
483 
// Instantiate the Attr() overloads for every supported default-value kind:
// scalars (int64_t, float, std::string), single messages (TensorProto,
// GraphProto), scalar lists (int64_t, float) and assigned-element lists
// (std::string, TensorProto, GraphProto).
ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeProto::INT)
ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeProto::FLOAT)
ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeProto::STRING)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TensorProto, t, AttributeProto::TENSOR)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(GraphProto, g, AttributeProto::GRAPH)
ATTR_SETTER_WITH_LIST_VALUE(int64_t, ints, AttributeProto::INTS)
ATTR_SETTER_WITH_LIST_VALUE(float, floats, AttributeProto::FLOATS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(
    std::string,
    strings,
    AttributeProto::STRINGS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(
    TensorProto,
    tensors,
    AttributeProto::TENSORS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(GraphProto, graphs, AttributeProto::GRAPHS)
500 
501 OpSchema& OpSchema::AllowUncheckedAttributes() {
502   allows_unchecked_attributes_ = true;
503   return *this;
504 }
505 
Input(int n,std::string name,std::string description,std::string type_str,OpSchema::FormalParameterOption param_option,bool is_homogeneous)506 OpSchema& OpSchema::Input(
507     int n,
508     std::string name,
509     std::string description,
510     std::string type_str,
511     OpSchema::FormalParameterOption param_option,
512     bool is_homogeneous) {
513   if (int(inputs_.size()) <= n) {
514     inputs_.resize(n + 1);
515   }
516   inputs_[n] = FormalParameter(
517       std::move(name),
518       std::move(description),
519       std::move(type_str),
520       param_option,
521       is_homogeneous);
522   return *this;
523 }
524 
Input(int n,const char * name,const char * description,const char * type_str,FormalParameterOption param_option,bool is_homogeneous)525 OpSchema& OpSchema::Input(
526     int n,
527     const char* name,
528     const char* description,
529     const char* type_str,
530     FormalParameterOption param_option,
531     bool is_homogeneous) {
532   return Input(
533       n,
534       std::string(name),
535       std::string(description),
536       std::string(type_str),
537       param_option,
538       is_homogeneous);
539 }
540 
Output(int n,std::string name,std::string description,std::string type_str,OpSchema::FormalParameterOption param_option,bool is_homogeneous)541 OpSchema& OpSchema::Output(
542     int n,
543     std::string name,
544     std::string description,
545     std::string type_str,
546     OpSchema::FormalParameterOption param_option,
547     bool is_homogeneous) {
548   if (int(outputs_.size()) <= n) {
549     outputs_.resize(n + 1);
550   }
551   outputs_[n] = FormalParameter(
552       std::move(name),
553       std::move(description),
554       std::move(type_str),
555       param_option,
556       is_homogeneous);
557   return *this;
558 }
559 
Output(int n,const char * name,const char * description,const char * type_str,FormalParameterOption param_option,bool is_homogeneous)560 OpSchema& OpSchema::Output(
561     int n,
562     const char* name,
563     const char* description,
564     const char* type_str,
565     FormalParameterOption param_option,
566     bool is_homogeneous) {
567   return Output(
568       n,
569       std::string(name),
570       std::string(description),
571       std::string(type_str),
572       param_option,
573       is_homogeneous);
574 }
575 
TypeConstraint(std::string type_str,std::vector<std::string> constraints,std::string description)576 OpSchema& OpSchema::TypeConstraint(
577     std::string type_str,
578     std::vector<std::string> constraints,
579     std::string description) {
580   if (type_constraints_.end() != type_constraints_.find(type_str)) {
581     fail_schema("Duplicate type constraint name");
582   }
583 
584   DataTypeSet d;
585   for (const auto& t : constraints) {
586     d.insert(Utils::DataTypeUtils::ToType(t));
587   }
588   type_constraints_.insert(
589       std::make_pair(type_str, std::make_pair(d, description)));
590   type_constraint_params_.push_back(TypeConstraintParam(
591       std::move(type_str), std::move(constraints), std::move(description)));
592   return *this;
593 }
594 
TypeConstraint(const char * type_str,std::initializer_list<const char * > constraints,const char * description)595 OpSchema& OpSchema::TypeConstraint(
596     const char* type_str,
597     std::initializer_list<const char*> constraints,
598     const char* description) {
599   std::vector<std::string> constraints_vector;
600   constraints_vector.reserve(constraints.size());
601   for (auto iter = constraints.begin(); iter != constraints.end(); ++iter) {
602     constraints_vector.push_back(*iter);
603   }
604 
605   return TypeConstraint(
606       std::string(type_str), constraints_vector, std::string(description));
607 }
608 
ParseAndSetTypes(std::vector<OpSchema::FormalParameter> * formal_parameters)609 void OpSchema::ParseAndSetTypes(
610     /*out*/ std::vector<OpSchema::FormalParameter>* formal_parameters) {
611   for (auto& formal_parameter : *formal_parameters) {
612     auto& type = formal_parameter.GetTypeStr();
613     DataTypeSet allowed_types;
614     auto it = type_constraints_.find(type);
615     if (it != type_constraints_.end()) {
616       allowed_types = it->second.first;
617     } else {
618       allowed_types.emplace(Utils::DataTypeUtils::ToType(type));
619     }
620 
621     formal_parameter.MutableTypes() = allowed_types;
622   }
623 }
624 
FunctionBody(const std::vector<NodeProto> & func_nodes)625 OpSchema& OpSchema::FunctionBody(const std::vector<NodeProto>& func_nodes) {
626   for (const auto node : func_nodes) {
627     auto new_node = function_body_.add_node();
628     new_node->CopyFrom(node);
629   }
630   return *this;
631 }
632 
GetFunction() const633 const FunctionProto* OpSchema::GetFunction() const {
634   return function_body_.node_size()>0 ? &function_body_ : nullptr;
635 }
636 
FillUsing(const std::function<void (OpSchema &)> & populator)637 OpSchema& OpSchema::FillUsing(const std::function<void(OpSchema&)>& populator) {
638   if (populator) {
639     populator(*this);
640   }
641   return *this;
642 }
643 
BuildFunction()644 void OpSchema::BuildFunction(){
645   function_body_.set_name(this->name_);
646   function_body_.set_doc_string(this->doc_);
647   function_body_.set_since_version(this->since_version_);
648   function_body_.set_status(OperatorStatus(1 - (int)this->support_));
649   for (auto& i : inputs_) {
650     function_body_.add_input(i.GetName());
651   }
652   for (auto& o : outputs_) {
653     function_body_.add_output(o.GetName());
654   }
655   for (auto& a : attributes_) {
656     function_body_.add_attribute(a.first);
657   }
658 }
659 
Finalize()660 void OpSchema::Finalize() {
661 #define ENFORCE(x)                                                          \
662   do {                                                                      \
663     if (!(x))                                                               \
664       throw std::logic_error(                                               \
665           "ONNX Schema " + name_ + ": failed validating the check: " + #x); \
666   } while (0)
667 
668   // Calculate min/max number of inputs.
669   // <Min number of inputs> = <number of "single" inputs> + <number of
670   // "optional" but not trailing inputs>. <Max number of inputs> = <number of
671   // all inputs or std::numeric_limits<int>::max() (if the last input is
672   // variadic).
673 
674   // Flag indicates whether an optional input is trailing one (there's no single
675   // or variadic input behind).
676   for (size_t i = 0; i < inputs_.size(); ++i) {
677     switch (inputs_[i].GetOption()) {
678       case OpSchema::Single:
679         ++max_input_;
680         min_input_ = max_input_;
681         break;
682       case OpSchema::Optional:
683         ++max_input_;
684         break;
685       case OpSchema::Variadic:
686         // Only last input formal parameter could be variadic.
687         ENFORCE((inputs_.size() - 1) == i);
688         min_input_ = max_input_ + 1;
689         max_input_ = std::numeric_limits<int>::max();
690         break;
691     }
692   }
693 
694   // Calculate min/max number of outputs.
695   for (size_t i = 0; i < outputs_.size(); ++i) {
696     switch (outputs_[i].GetOption()) {
697       case OpSchema::Single:
698         ++max_output_;
699         min_output_ = max_output_;
700         break;
701       case OpSchema::Optional:
702         ++max_output_;
703         break;
704       case OpSchema::Variadic:
705         // Only last output formal parameter could be variadic.
706         ENFORCE((outputs_.size() - 1) == i);
707         min_output_ = max_output_ + 1;
708         max_output_ = std::numeric_limits<int>::max();
709         break;
710     }
711   }
712 
713   // all inputs and outputs have names
714   for (const auto& it : inputs_) {
715     ENFORCE(!(it.GetName().empty()));
716   }
717   for (const auto& it : outputs_) {
718     ENFORCE(!(it.GetName().empty()));
719   }
720 
721   ParseAndSetTypes(&inputs_);
722   ParseAndSetTypes(&outputs_);
723 
724   if (this->HasFunction()) {
725     BuildFunction();
726   }
727 }
728 
operator <<(std::ostream & out,const OpSchema & schema)729 std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
730   if (!schema.attributes_.empty()) {
731     out << "Attributes:" << std::endl;
732     for (const auto& pair : schema.attributes_) {
733       out << "  " << pair.second.name << " : " << pair.second.description
734           << std::endl;
735     }
736   }
737   if (schema.max_input_ > 0) {
738     out << "Inputs:" << std::endl;
739     if (!schema.inputs_.empty()) {
740       for (size_t i = 0; i < schema.inputs_.size(); ++i) {
741         const auto& p = schema.inputs_[i];
742         const auto& name = p.GetName();
743         const auto& description = p.GetDescription();
744         const auto& type_str = p.GetTypeStr();
745         out << "  " << i << ", " << ("" != name ? name : "(unnamed)") << " : "
746             << ("" != description ? description : "(no doc)") << " : "
747             << ("" != type_str ? type_str : "(no type)") << std::endl;
748       }
749     } else {
750       out << "  (no explicit description available)" << std::endl;
751     }
752   }
753   if (schema.max_output_ > 0) {
754     out << "Outputs:" << std::endl;
755     if (!schema.outputs_.empty()) {
756       for (size_t i = 0; i < schema.outputs_.size(); ++i) {
757         const auto& p = schema.outputs_[i];
758         const auto& name = p.GetName();
759         const auto& description = p.GetDescription();
760         const auto& type_str = p.GetTypeStr();
761         out << "  " << i << ", " << ("" != name ? name : "(unnamed)") << " : "
762             << ("" != description ? description : "(no doc)") << " : "
763             << ("" != type_str ? type_str : "(no type)") << std::endl;
764       }
765     } else {
766       out << "  (no explicit description available)" << std::endl;
767     }
768   }
769   out << std::endl;
770   if (schema.doc()) {
771     out << schema.doc();
772   } else {
773     out << "(no documentation yet)" << std::endl;
774   }
775   out << std::endl;
776   if (schema.line_) {
777     out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl;
778   }
779   return out;
780 }
781 
782 OpSchemaRegistry::DomainToVersionRange&
Instance()783 OpSchemaRegistry::DomainToVersionRange::Instance() {
784   static DomainToVersionRange domain_to_version_range;
785   return domain_to_version_range;
786 };
787 
788 // Private method used by OpSchemaRegisterOnce and OpSchemaRegistry::map()
789 OpName_Domain_Version_Schema_Map&
GetMapWithoutEnsuringRegistration()790 OpSchemaRegistry::GetMapWithoutEnsuringRegistration() {
791   static OpName_Domain_Version_Schema_Map map;
792   return map;
793 }
794 
map()795 OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
796   auto& map = GetMapWithoutEnsuringRegistration();
797 
798   // The following class is used to register operators the
799   // first time this method is called, in a thread-safe fashion.
800   class SchemasRegisterer {
801    public:
802     SchemasRegisterer() {
803       // In debug builds, the number of schema registered in this constructor
804       // is compared against the number of calls to schema registration macros.
805 #ifndef NDEBUG
806       size_t dbg_initial_schema_count = GetRegisteredSchemaCount();
807 #endif
808 
809       RegisterOnnxOperatorSetSchema();
810 
811 #ifdef ONNX_ML
812       RegisterOnnxMLOperatorSetSchema();
813 #endif
814 
815 #ifndef NDEBUG
816       size_t dbg_registered_schema_count =
817           GetRegisteredSchemaCount() - dbg_initial_schema_count;
818 
819       ONNX_ASSERTM(
820           dbg_registered_schema_count == ONNX_DBG_GET_COUNT_IN_OPSETS(),
821           "%u schema were exposed from operator sets and automatically placed into the static registry.  "
822           "%u were expected based on calls to registration macros. Operator set functions may need to be updated.",
823           dbg_registered_schema_count,
824           ONNX_DBG_GET_COUNT_IN_OPSETS());
825 #endif
826     }
827 
828    private:
829     static size_t GetRegisteredSchemaCount() {
830       size_t count = 0;
831       for (auto& x : GetMapWithoutEnsuringRegistration()) {
832         for (auto& y : x.second) {
833           count += y.second.size();
834         }
835       }
836       return count;
837     }
838   };
839 
840 #ifndef __ONNX_DISABLE_STATIC_REGISTRATION
841   static SchemasRegisterer schemasRegisterer;
842 #endif
843 
844   return map;
845 }
846 
// Replaces every non-overlapping occurrence of `from` in `s` with `to`,
// scanning left to right. The search resumes just past each inserted
// replacement, so text produced by `to` is never re-scanned. Returns the
// number of replacements made.
//
// An empty `from` performs no replacements and returns 0. Searching for ""
// matches at every position, including s.size(), so the loop would otherwise
// never terminate.
size_t ReplaceAll(std::string& s, const char* from, const char* to) {
  const std::string::size_type lenFrom = std::strlen(from);
  if (lenFrom == 0) {
    return 0;
  }
  const std::string::size_type lenTo = std::strlen(to);
  size_t numReplaced = 0;
  for (std::string::size_type pos = s.find(from); pos != std::string::npos;
       pos = s.find(from, pos + lenTo)) {
    s.replace(pos, lenFrom, to);
    ++numReplaced;
  }
  return numReplaced;
}
858 
859 } // namespace ONNX_NAMESPACE
860