1 // Copyright (c) ONNX Project Contributors.
2 // Licensed under the MIT license.
3
#include "onnx/defs/schema.h"

#include <cstring>
#include <limits>
#include <stdexcept>
#include <unordered_set>

#include "onnx/checker.h"
#include "onnx/defs/operator_sets.h"

#ifdef ONNX_ML
#include "onnx/defs/operator_sets-ml.h"
#endif

#include "onnx/common/assertions.h"
#include "onnx/common/stl_backports.h"
16
17 namespace ONNX_NAMESPACE {
18
RegisterSchema(OpSchema && schema)19 void RegisterSchema(OpSchema&& schema) {
20 OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration = schema;
21 }
22
23 #ifndef NDEBUG
Instance()24 DbgOperatorSetTracker& DbgOperatorSetTracker::Instance() {
25 static DbgOperatorSetTracker instance;
26 return instance;
27 }
28 #endif
29
FormalParameter(std::string name,DataTypeSet allowed_type_set,std::string type_str,std::string description,FormalParameterOption param_option,bool is_homogeneous)30 OpSchema::FormalParameter::FormalParameter(
31 std::string name,
32 DataTypeSet allowed_type_set,
33 std::string type_str,
34 std::string description,
35 FormalParameterOption param_option,
36 bool is_homogeneous)
37 : name_(std::move(name)),
38 type_set_(std::move(allowed_type_set)),
39 type_str_(std::move(type_str)),
40 description_(std::move(description)),
41 param_option_(param_option),
42 is_homogeneous_(is_homogeneous) {}
43
FormalParameter(std::string name,std::string description,std::string type_str,FormalParameterOption param_option,bool is_homogeneous)44 OpSchema::FormalParameter::FormalParameter(
45 std::string name,
46 std::string description,
47 std::string type_str,
48 FormalParameterOption param_option,
49 bool is_homogeneous)
50 : name_(std::move(name)),
51 type_str_(std::move(type_str)),
52 description_(std::move(description)),
53 param_option_(param_option),
54 is_homogeneous_(is_homogeneous) {}
55
GetName() const56 const std::string& OpSchema::FormalParameter::GetName() const {
57 return name_;
58 }
59
GetTypes() const60 const DataTypeSet& OpSchema::FormalParameter::GetTypes() const {
61 return type_set_;
62 }
63
MutableTypes()64 DataTypeSet& OpSchema::FormalParameter::MutableTypes() {
65 return type_set_;
66 }
67
GetTypeStr() const68 const std::string& OpSchema::FormalParameter::GetTypeStr() const {
69 return type_str_;
70 }
71
GetDescription() const72 const std::string& OpSchema::FormalParameter::GetDescription() const {
73 return description_;
74 }
75
GetOption() const76 OpSchema::FormalParameterOption OpSchema::FormalParameter::GetOption() const {
77 return param_option_;
78 }
79
GetIsHomogeneous() const80 bool OpSchema::FormalParameter::GetIsHomogeneous() const {
81 return is_homogeneous_;
82 }
83
Instance()84 OpSchemaRegistry* OpSchemaRegistry::Instance() {
85 static OpSchemaRegistry instance;
86 return &instance;
87 }
88
Verify(const NodeProto & node) const89 void OpSchema::Verify(const NodeProto& node) const {
90 if (deprecated_) {
91 fail_check(
92 "Operator '",
93 name_,
94 "' has been deprecated since version ",
95 since_version_);
96 }
97
98 // Check the number of inputs.
99 if (node.input_size() < min_input_ || node.input_size() > max_input_) {
100 fail_check(
101 "Node (",
102 node.name(),
103 ") has input size ",
104 node.input_size(),
105 " not in range [min=",
106 min_input_,
107 ", max=",
108 max_input_,
109 "].");
110 }
111
112 if (!num_inputs_allowed_(node.input_size())) {
113 fail_check(
114 "Node (",
115 node.name(),
116 ") has input size ",
117 node.input_size(),
118 " not in allowed input sizes.");
119 }
120
121 // Check the number of outputs.
122 if (node.output_size() < min_output_ || node.output_size() > max_output_) {
123 fail_check(
124 "Node (",
125 node.name(),
126 ") has output size ",
127 node.output_size(),
128 " not in range [min=",
129 min_output_,
130 ", max=",
131 max_output_,
132 "].");
133 }
134
135 if (!num_outputs_allowed_(node.output_size())) {
136 fail_check(
137 "Node (",
138 node.name(),
139 "has output size ",
140 node.output_size(),
141 " not in allowed output sizes.");
142 }
143
144 // Check the values of inputs / outputs
145 for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
146 if (in_idx >= static_cast<int>(inputs_.size())) {
147 if (inputs_.size() > 0 && Variadic == inputs_.back().GetOption()) {
148 // The last input formal parameter should be variadic.
149 break;
150 } else {
151 fail_check(
152 "Node (",
153 node.name(),
154 ") has more inputs (",
155 node.input_size(),
156 ") than declared (",
157 inputs_.size(),
158 ") in op definition.");
159 }
160 }
161 if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
162 fail_check(
163 "Node (",
164 node.name(),
165 ")'s input ",
166 in_idx,
167 " is marked single but has an empty string in the graph");
168 }
169 }
170
171 for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
172 if (out_idx >= static_cast<int>(outputs_.size())) {
173 if (outputs_.size() > 0 && Variadic == outputs_.back().GetOption()) {
174 // The last output formal parameter should be variadic.
175 break;
176 } else {
177 fail_check(
178 "Node (",
179 node.name(),
180 ") has more outputs (",
181 node.output_size(),
182 ") than declared (",
183 outputs_.size(),
184 ") in op definition.");
185 }
186 }
187
188 if (node.output(out_idx).empty() &&
189 (Single == outputs_[out_idx].GetOption())) {
190 fail_check(
191 "Node (",
192 node.name(),
193 ")'s output ",
194 out_idx,
195 " is marked single but has an empty string in the graph");
196 }
197 }
198
199 // An internal symbol is defined as starting with two underscores. Attributes
200 // with names meeting this condition are considered implementation details
201 // and should be ignored for the purpose of schema checking.
202 auto isInternalSymbol = [](const std::string& sym) -> bool {
203 return sym.length() >= 2 && sym[0] == '_' && sym[1] == '_';
204 };
205
206 // Check attributes
207 std::unordered_set<std::string> seen_attr_names{};
208 for (const auto& attr_proto : node.attribute()) {
209 const auto& name = attr_proto.name();
210
211 if (!seen_attr_names.insert(name).second) {
212 fail_check("Attribute '", name, "' appeared multiple times.");
213 };
214
215 const auto& search = attributes_.find(name);
216 AttributeProto::AttributeType expected_type;
217 if (search != attributes_.end()) {
218 expected_type = search->second.type;
219 } else if (allows_unchecked_attributes_ || isInternalSymbol(name)) {
220 continue;
221 } else {
222 fail_check(
223 "Unrecognized attribute: ", name, " for operator ", node.op_type());
224 }
225
226 if (attr_proto.has_ref_attr_name()) {
227 if (!attr_proto.has_type() || attr_proto.type() != expected_type) {
228 fail_check(
229 "Mismatched attribute type in '", node.name() + " : " + name, "'");
230 }
231 continue;
232 }
233
234 switch (expected_type) {
235 case AttributeProto::FLOAT:
236 if (!attr_proto.has_f()) {
237 fail_check("Attribute '", name, "' is expected to have field 'f'");
238 }
239 break;
240 case AttributeProto::INT:
241 if (!attr_proto.has_i()) {
242 fail_check("Attribute '", name, "' is expected to have field 'i'");
243 }
244 break;
245 case AttributeProto::STRING:
246 if (!attr_proto.has_s()) {
247 fail_check("Attribute '", name, "' is expected to have field 's'");
248 }
249 break;
250 case AttributeProto::TENSOR:
251 if (!attr_proto.has_t()) {
252 fail_check("Attribute '", name, "' is expected to have field 't'");
253 }
254 break;
255 case AttributeProto::GRAPH:
256 if (!attr_proto.has_g()) {
257 fail_check("Attribute '", name, "' is expected to have field 'g'");
258 }
259 break;
260 case AttributeProto::FLOATS:
261 if (!attr_proto.floats_size()) {
262 fail_check(
263 "Attribute '", name, "' is expected to have field 'floats'");
264 }
265 break;
266 case AttributeProto::INTS:
267 if (!attr_proto.ints_size()) {
268 fail_check("Attribute '", name, "' is expected to have field 'ints'");
269 }
270 break;
271 case AttributeProto::STRINGS:
272 if (!attr_proto.strings_size()) {
273 fail_check(
274 "Attribute '", name, "' is expected to have field 'strings'");
275 }
276 break;
277 case AttributeProto::TENSORS:
278 if (!attr_proto.tensors_size()) {
279 fail_check(
280 "Attribute '", name, "' is expected to have field 'tensors'");
281 }
282 break;
283 case AttributeProto::GRAPHS:
284 if (!attr_proto.graphs_size()) {
285 fail_check(
286 "Attribute '", name, "' is expected to have field 'graphs'");
287 }
288 break;
289 default:
290 fail_check("Attribute '", name, " has unknown expected type");
291 }
292 }
293 for (const auto& pair : attributes_) {
294 const auto& attr = pair.second;
295 if (!attr.required) {
296 continue;
297 }
298 if (!seen_attr_names.count(attr.name)) {
299 fail_check("Required attribute '", attr.name, "' is missing.");
300 }
301 }
302
303 // Phew. All verifications passed.
304 }
305
SinceVersion(OperatorSetVersion v)306 OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) {
307 since_version_ = v;
308 return *this;
309 }
310
Deprecate()311 OpSchema& OpSchema::Deprecate() {
312 deprecated_ = true;
313 return *this;
314 }
315
NumInputs(std::set<int> allowed_input_nums)316 OpSchema& OpSchema::NumInputs(std::set<int> allowed_input_nums) {
317 num_inputs_allowed_ =
318 [MOVE_CAPTURE_IF_CPP14(allowed_input_nums)](int n) -> bool {
319 return allowed_input_nums.count(n);
320 };
321 return *this;
322 }
323
NumOutputs(std::set<int> allowed_output_nums)324 OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) {
325 num_outputs_allowed_ =
326 [MOVE_CAPTURE_IF_CPP14(allowed_output_nums)](int n) -> bool {
327 return allowed_output_nums.count(n);
328 };
329 return *this;
330 }
331
TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction)332 OpSchema& OpSchema::TypeAndShapeInferenceFunction(
333 InferenceFunction inferenceFunction) {
334 tensor_inference_function_ = inferenceFunction;
335 return *this;
336 }
337
SetSupportLevel(SupportType support)338 OpSchema& OpSchema::SetSupportLevel(SupportType support) {
339 support_ = support;
340 return *this;
341 }
342
SetDoc(std::string doc)343 OpSchema& OpSchema::SetDoc(std::string doc) {
344 doc_ = std::move(doc);
345 return *this;
346 }
347
348 // Functions to specify name for the operator schema.
SetName(std::string name)349 OpSchema& OpSchema::SetName(std::string name) {
350 name_ = std::move(name);
351 return *this;
352 }
353
SetName(const char * name)354 OpSchema& OpSchema::SetName(const char* name) {
355 return SetName(std::string(name));
356 }
357
358 // Functions to specify code location for the operator schema.
SetLocation(std::string file,int line)359 OpSchema& OpSchema::SetLocation(std::string file, int line) {
360 file_ = std::move(file);
361 line_ = line;
362 return *this;
363 }
364
SetLocation(const char * file,int line)365 OpSchema& OpSchema::SetLocation(const char* file, int line) {
366 return SetLocation(std::string(file), line);
367 }
368
SetDomain(std::string domain)369 OpSchema& OpSchema::SetDomain(std::string domain) {
370 domain_ = std::move(domain);
371 return *this;
372 }
373
SetDomain(const char * domain)374 OpSchema& OpSchema::SetDomain(const char* domain) {
375 return SetDomain(std::string(domain));
376 }
377
Attr(Attribute attr)378 OpSchema& OpSchema::Attr(Attribute attr) {
379 auto name = attr.name; // copy name so we can move attr in the next line
380 attributes_.insert(std::make_pair(std::move(name), std::move(attr)));
381 return *this;
382 }
383
Attr(std::string name,std::string description,AttributeProto::AttributeType type,bool required)384 OpSchema& OpSchema::Attr(
385 std::string name,
386 std::string description,
387 AttributeProto::AttributeType type,
388 bool required) {
389 Attr(Attribute{std::move(name), std::move(description), type, required});
390 return *this;
391 }
392
Attr(const char * name,const char * description,AttributeProto::AttributeType type,bool required)393 OpSchema& OpSchema::Attr(
394 const char* name,
395 const char* description,
396 AttributeProto::AttributeType type,
397 bool required) {
398 return Attr(std::string(name), std::string(description), type, required);
399 }
400
// Defines two Attr(...) overloads (std::string and const char* names) that
// register an attribute whose default is a single scalar `type` value stored
// in AttributeProto field `field`. Passing an attr_type other than `attrtype`
// triggers fail_schema.
#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field, attrtype)                   \
  OpSchema& OpSchema::Attr(                                                    \
      std::string name,                                                        \
      std::string description,                                                 \
      AttributeProto::AttributeType attr_type,                                 \
      const type& default_value) {                                             \
    if (attr_type != attrtype) {                                               \
      fail_schema("Attribute specification type mismatch.");                   \
    }                                                                          \
    AttributeProto proto;                                                      \
    proto.set_name(name);                                                      \
    proto.set_##field(default_value);                                          \
    proto.set_type(attr_type);                                                 \
    return Attr(                                                               \
        Attribute(std::move(name), std::move(description), std::move(proto))); \
  }                                                                            \
  OpSchema& OpSchema::Attr(                                                    \
      const char* name,                                                        \
      const char* description,                                                 \
      AttributeProto::AttributeType attr_type,                                 \
      const type& default_value) {                                             \
    std::string owned_name(name);                                              \
    std::string owned_description(description);                                \
    return Attr(                                                               \
        std::move(owned_name),                                                 \
        std::move(owned_description),                                          \
        attr_type,                                                             \
        default_value);                                                        \
  }
428
// Defines an Attr(...) overload that registers an attribute whose default is
// a list of scalar `type` values, appended to the repeated AttributeProto
// field `field`. Passing an attr_type other than `attrtype` triggers
// fail_schema.
#define ATTR_SETTER_WITH_LIST_VALUE(type, field, attrtype)                     \
  OpSchema& OpSchema::Attr(                                                    \
      std::string name,                                                        \
      std::string description,                                                 \
      AttributeProto::AttributeType attr_type,                                 \
      const std::vector<type>& default_value) {                                \
    if (attr_type != attrtype) {                                               \
      fail_schema("Attribute specification type mismatch.");                   \
    }                                                                          \
    AttributeProto proto;                                                      \
    proto.set_name(name);                                                      \
    proto.set_type(attr_type);                                                 \
    for (size_t idx = 0; idx < default_value.size(); ++idx) {                  \
      proto.add_##field(default_value[idx]);                                   \
    }                                                                          \
    return Attr(                                                               \
        Attribute(std::move(name), std::move(description), std::move(proto))); \
  }
447
// Defines an Attr(...) overload that registers an attribute whose default is
// a single protobuf message (`type`, e.g. TensorProto/GraphProto) copied into
// AttributeProto field `field`. Passing an attr_type other than `attrtype`
// triggers fail_schema. The assembled AttributeProto is moved into the
// Attribute (it used to be copied, duplicating a potentially large tensor or
// graph), consistent with the sibling setter macros.
#define ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(type, field, attrtype)           \
  OpSchema& OpSchema::Attr(                                                   \
      std::string name,                                                       \
      std::string description,                                                \
      AttributeProto::AttributeType attr_type,                                \
      const type& default_value) {                                            \
    if (attrtype != attr_type) {                                              \
      fail_schema("Attribute specification type mismatch.");                  \
    }                                                                         \
    AttributeProto a;                                                         \
    a.set_name(name);                                                         \
    *(a.mutable_##field()) = default_value;                                   \
    a.set_type(attr_type);                                                    \
    Attr(Attribute(std::move(name), std::move(description), std::move(a)));   \
    return *this;                                                             \
  }
464
// Defines an Attr(...) overload that registers an attribute whose default is
// a list of `type` values (strings or protobuf messages), each copied into a
// new element of the repeated AttributeProto field `field`. Passing an
// attr_type other than `attrtype` triggers fail_schema.
#define ATTR_SETTER_WITH_LIST_COMPLEXVALUE(type, field, attrtype)              \
  OpSchema& OpSchema::Attr(                                                    \
      std::string name,                                                        \
      std::string description,                                                 \
      AttributeProto::AttributeType attr_type,                                 \
      const std::vector<type>& default_value) {                                \
    if (attr_type != attrtype) {                                               \
      fail_schema("Attribute specification type mismatch.");                   \
    }                                                                          \
    AttributeProto proto;                                                      \
    proto.set_name(name);                                                      \
    proto.set_type(attr_type);                                                 \
    for (size_t idx = 0; idx < default_value.size(); ++idx) {                  \
      *proto.add_##field() = default_value[idx];                               \
    }                                                                          \
    return Attr(                                                               \
        Attribute(std::move(name), std::move(description), std::move(proto))); \
  }
483
ATTR_SETTER_WITH_SINGLE_VALUE(int64_t,i,AttributeProto::INT)484 ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeProto::INT)
485 ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeProto::FLOAT)
486 ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeProto::STRING)
487 ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TensorProto, t, AttributeProto::TENSOR)
488 ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(GraphProto, g, AttributeProto::GRAPH)
489 ATTR_SETTER_WITH_LIST_VALUE(int64_t, ints, AttributeProto::INTS)
490 ATTR_SETTER_WITH_LIST_VALUE(float, floats, AttributeProto::FLOATS)
491 ATTR_SETTER_WITH_LIST_COMPLEXVALUE(
492 std::string,
493 strings,
494 AttributeProto::STRINGS)
495 ATTR_SETTER_WITH_LIST_COMPLEXVALUE(
496 TensorProto,
497 tensors,
498 AttributeProto::TENSORS)
499 ATTR_SETTER_WITH_LIST_COMPLEXVALUE(GraphProto, graphs, AttributeProto::GRAPHS)
500
501 OpSchema& OpSchema::AllowUncheckedAttributes() {
502 allows_unchecked_attributes_ = true;
503 return *this;
504 }
505
Input(int n,std::string name,std::string description,std::string type_str,OpSchema::FormalParameterOption param_option,bool is_homogeneous)506 OpSchema& OpSchema::Input(
507 int n,
508 std::string name,
509 std::string description,
510 std::string type_str,
511 OpSchema::FormalParameterOption param_option,
512 bool is_homogeneous) {
513 if (int(inputs_.size()) <= n) {
514 inputs_.resize(n + 1);
515 }
516 inputs_[n] = FormalParameter(
517 std::move(name),
518 std::move(description),
519 std::move(type_str),
520 param_option,
521 is_homogeneous);
522 return *this;
523 }
524
Input(int n,const char * name,const char * description,const char * type_str,FormalParameterOption param_option,bool is_homogeneous)525 OpSchema& OpSchema::Input(
526 int n,
527 const char* name,
528 const char* description,
529 const char* type_str,
530 FormalParameterOption param_option,
531 bool is_homogeneous) {
532 return Input(
533 n,
534 std::string(name),
535 std::string(description),
536 std::string(type_str),
537 param_option,
538 is_homogeneous);
539 }
540
Output(int n,std::string name,std::string description,std::string type_str,OpSchema::FormalParameterOption param_option,bool is_homogeneous)541 OpSchema& OpSchema::Output(
542 int n,
543 std::string name,
544 std::string description,
545 std::string type_str,
546 OpSchema::FormalParameterOption param_option,
547 bool is_homogeneous) {
548 if (int(outputs_.size()) <= n) {
549 outputs_.resize(n + 1);
550 }
551 outputs_[n] = FormalParameter(
552 std::move(name),
553 std::move(description),
554 std::move(type_str),
555 param_option,
556 is_homogeneous);
557 return *this;
558 }
559
Output(int n,const char * name,const char * description,const char * type_str,FormalParameterOption param_option,bool is_homogeneous)560 OpSchema& OpSchema::Output(
561 int n,
562 const char* name,
563 const char* description,
564 const char* type_str,
565 FormalParameterOption param_option,
566 bool is_homogeneous) {
567 return Output(
568 n,
569 std::string(name),
570 std::string(description),
571 std::string(type_str),
572 param_option,
573 is_homogeneous);
574 }
575
TypeConstraint(std::string type_str,std::vector<std::string> constraints,std::string description)576 OpSchema& OpSchema::TypeConstraint(
577 std::string type_str,
578 std::vector<std::string> constraints,
579 std::string description) {
580 if (type_constraints_.end() != type_constraints_.find(type_str)) {
581 fail_schema("Duplicate type constraint name");
582 }
583
584 DataTypeSet d;
585 for (const auto& t : constraints) {
586 d.insert(Utils::DataTypeUtils::ToType(t));
587 }
588 type_constraints_.insert(
589 std::make_pair(type_str, std::make_pair(d, description)));
590 type_constraint_params_.push_back(TypeConstraintParam(
591 std::move(type_str), std::move(constraints), std::move(description)));
592 return *this;
593 }
594
TypeConstraint(const char * type_str,std::initializer_list<const char * > constraints,const char * description)595 OpSchema& OpSchema::TypeConstraint(
596 const char* type_str,
597 std::initializer_list<const char*> constraints,
598 const char* description) {
599 std::vector<std::string> constraints_vector;
600 constraints_vector.reserve(constraints.size());
601 for (auto iter = constraints.begin(); iter != constraints.end(); ++iter) {
602 constraints_vector.push_back(*iter);
603 }
604
605 return TypeConstraint(
606 std::string(type_str), constraints_vector, std::string(description));
607 }
608
ParseAndSetTypes(std::vector<OpSchema::FormalParameter> * formal_parameters)609 void OpSchema::ParseAndSetTypes(
610 /*out*/ std::vector<OpSchema::FormalParameter>* formal_parameters) {
611 for (auto& formal_parameter : *formal_parameters) {
612 auto& type = formal_parameter.GetTypeStr();
613 DataTypeSet allowed_types;
614 auto it = type_constraints_.find(type);
615 if (it != type_constraints_.end()) {
616 allowed_types = it->second.first;
617 } else {
618 allowed_types.emplace(Utils::DataTypeUtils::ToType(type));
619 }
620
621 formal_parameter.MutableTypes() = allowed_types;
622 }
623 }
624
FunctionBody(const std::vector<NodeProto> & func_nodes)625 OpSchema& OpSchema::FunctionBody(const std::vector<NodeProto>& func_nodes) {
626 for (const auto node : func_nodes) {
627 auto new_node = function_body_.add_node();
628 new_node->CopyFrom(node);
629 }
630 return *this;
631 }
632
GetFunction() const633 const FunctionProto* OpSchema::GetFunction() const {
634 return function_body_.node_size()>0 ? &function_body_ : nullptr;
635 }
636
FillUsing(const std::function<void (OpSchema &)> & populator)637 OpSchema& OpSchema::FillUsing(const std::function<void(OpSchema&)>& populator) {
638 if (populator) {
639 populator(*this);
640 }
641 return *this;
642 }
643
BuildFunction()644 void OpSchema::BuildFunction(){
645 function_body_.set_name(this->name_);
646 function_body_.set_doc_string(this->doc_);
647 function_body_.set_since_version(this->since_version_);
648 function_body_.set_status(OperatorStatus(1 - (int)this->support_));
649 for (auto& i : inputs_) {
650 function_body_.add_input(i.GetName());
651 }
652 for (auto& o : outputs_) {
653 function_body_.add_output(o.GetName());
654 }
655 for (auto& a : attributes_) {
656 function_body_.add_attribute(a.first);
657 }
658 }
659
Finalize()660 void OpSchema::Finalize() {
661 #define ENFORCE(x) \
662 do { \
663 if (!(x)) \
664 throw std::logic_error( \
665 "ONNX Schema " + name_ + ": failed validating the check: " + #x); \
666 } while (0)
667
668 // Calculate min/max number of inputs.
669 // <Min number of inputs> = <number of "single" inputs> + <number of
670 // "optional" but not trailing inputs>. <Max number of inputs> = <number of
671 // all inputs or std::numeric_limits<int>::max() (if the last input is
672 // variadic).
673
674 // Flag indicates whether an optional input is trailing one (there's no single
675 // or variadic input behind).
676 for (size_t i = 0; i < inputs_.size(); ++i) {
677 switch (inputs_[i].GetOption()) {
678 case OpSchema::Single:
679 ++max_input_;
680 min_input_ = max_input_;
681 break;
682 case OpSchema::Optional:
683 ++max_input_;
684 break;
685 case OpSchema::Variadic:
686 // Only last input formal parameter could be variadic.
687 ENFORCE((inputs_.size() - 1) == i);
688 min_input_ = max_input_ + 1;
689 max_input_ = std::numeric_limits<int>::max();
690 break;
691 }
692 }
693
694 // Calculate min/max number of outputs.
695 for (size_t i = 0; i < outputs_.size(); ++i) {
696 switch (outputs_[i].GetOption()) {
697 case OpSchema::Single:
698 ++max_output_;
699 min_output_ = max_output_;
700 break;
701 case OpSchema::Optional:
702 ++max_output_;
703 break;
704 case OpSchema::Variadic:
705 // Only last output formal parameter could be variadic.
706 ENFORCE((outputs_.size() - 1) == i);
707 min_output_ = max_output_ + 1;
708 max_output_ = std::numeric_limits<int>::max();
709 break;
710 }
711 }
712
713 // all inputs and outputs have names
714 for (const auto& it : inputs_) {
715 ENFORCE(!(it.GetName().empty()));
716 }
717 for (const auto& it : outputs_) {
718 ENFORCE(!(it.GetName().empty()));
719 }
720
721 ParseAndSetTypes(&inputs_);
722 ParseAndSetTypes(&outputs_);
723
724 if (this->HasFunction()) {
725 BuildFunction();
726 }
727 }
728
operator <<(std::ostream & out,const OpSchema & schema)729 std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
730 if (!schema.attributes_.empty()) {
731 out << "Attributes:" << std::endl;
732 for (const auto& pair : schema.attributes_) {
733 out << " " << pair.second.name << " : " << pair.second.description
734 << std::endl;
735 }
736 }
737 if (schema.max_input_ > 0) {
738 out << "Inputs:" << std::endl;
739 if (!schema.inputs_.empty()) {
740 for (size_t i = 0; i < schema.inputs_.size(); ++i) {
741 const auto& p = schema.inputs_[i];
742 const auto& name = p.GetName();
743 const auto& description = p.GetDescription();
744 const auto& type_str = p.GetTypeStr();
745 out << " " << i << ", " << ("" != name ? name : "(unnamed)") << " : "
746 << ("" != description ? description : "(no doc)") << " : "
747 << ("" != type_str ? type_str : "(no type)") << std::endl;
748 }
749 } else {
750 out << " (no explicit description available)" << std::endl;
751 }
752 }
753 if (schema.max_output_ > 0) {
754 out << "Outputs:" << std::endl;
755 if (!schema.outputs_.empty()) {
756 for (size_t i = 0; i < schema.outputs_.size(); ++i) {
757 const auto& p = schema.outputs_[i];
758 const auto& name = p.GetName();
759 const auto& description = p.GetDescription();
760 const auto& type_str = p.GetTypeStr();
761 out << " " << i << ", " << ("" != name ? name : "(unnamed)") << " : "
762 << ("" != description ? description : "(no doc)") << " : "
763 << ("" != type_str ? type_str : "(no type)") << std::endl;
764 }
765 } else {
766 out << " (no explicit description available)" << std::endl;
767 }
768 }
769 out << std::endl;
770 if (schema.doc()) {
771 out << schema.doc();
772 } else {
773 out << "(no documentation yet)" << std::endl;
774 }
775 out << std::endl;
776 if (schema.line_) {
777 out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl;
778 }
779 return out;
780 }
781
782 OpSchemaRegistry::DomainToVersionRange&
Instance()783 OpSchemaRegistry::DomainToVersionRange::Instance() {
784 static DomainToVersionRange domain_to_version_range;
785 return domain_to_version_range;
786 };
787
788 // Private method used by OpSchemaRegisterOnce and OpSchemaRegistry::map()
789 OpName_Domain_Version_Schema_Map&
GetMapWithoutEnsuringRegistration()790 OpSchemaRegistry::GetMapWithoutEnsuringRegistration() {
791 static OpName_Domain_Version_Schema_Map map;
792 return map;
793 }
794
map()795 OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
796 auto& map = GetMapWithoutEnsuringRegistration();
797
798 // The following class is used to register operators the
799 // first time this method is called, in a thread-safe fashion.
800 class SchemasRegisterer {
801 public:
802 SchemasRegisterer() {
803 // In debug builds, the number of schema registered in this constructor
804 // is compared against the number of calls to schema registration macros.
805 #ifndef NDEBUG
806 size_t dbg_initial_schema_count = GetRegisteredSchemaCount();
807 #endif
808
809 RegisterOnnxOperatorSetSchema();
810
811 #ifdef ONNX_ML
812 RegisterOnnxMLOperatorSetSchema();
813 #endif
814
815 #ifndef NDEBUG
816 size_t dbg_registered_schema_count =
817 GetRegisteredSchemaCount() - dbg_initial_schema_count;
818
819 ONNX_ASSERTM(
820 dbg_registered_schema_count == ONNX_DBG_GET_COUNT_IN_OPSETS(),
821 "%u schema were exposed from operator sets and automatically placed into the static registry. "
822 "%u were expected based on calls to registration macros. Operator set functions may need to be updated.",
823 dbg_registered_schema_count,
824 ONNX_DBG_GET_COUNT_IN_OPSETS());
825 #endif
826 }
827
828 private:
829 static size_t GetRegisteredSchemaCount() {
830 size_t count = 0;
831 for (auto& x : GetMapWithoutEnsuringRegistration()) {
832 for (auto& y : x.second) {
833 count += y.second.size();
834 }
835 }
836 return count;
837 }
838 };
839
840 #ifndef __ONNX_DISABLE_STATIC_REGISTRATION
841 static SchemasRegisterer schemasRegisterer;
842 #endif
843
844 return map;
845 }
846
// Replaces every non-overlapping occurrence of `from` in `s` with `to`,
// scanning left to right. Scanning resumes just past each inserted `to`, so
// replacement text is never re-matched. Returns the number of replacements.
// An empty `from` matches nothing and leaves `s` unchanged. Without this
// guard the loop would insert `to` at every position, and it would never
// terminate when `to` is empty as well.
size_t ReplaceAll(std::string& s, const char* from, const char* to) {
  const std::string::size_type lenFrom = std::strlen(from);
  if (lenFrom == 0) {
    return 0;
  }
  const std::string::size_type lenTo = std::strlen(to);
  size_t numReplaced = 0;
  for (std::string::size_type pos = s.find(from); pos != std::string::npos;
       pos = s.find(from, pos + lenTo)) {
    s.replace(pos, lenFrom, to);
    ++numReplaced;
  }
  return numReplaced;
}
858
859 } // namespace ONNX_NAMESPACE
860