1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5 // ATTENTION: The code in this file is highly EXPERIMENTAL.
6 // Adventurous users should note that the APIs will probably change.
7
8 #pragma once
9
10 #include <atomic>
11 #include <algorithm>
12 #include <cstdint>
13 #include <functional>
14 #include <iostream>
15 #include <memory>
16 #include <sstream>
17 #include <stdint.h>
18 #include <string>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "onnx/string_utils.h"
23 #include "onnx/common/array_ref.h"
24 #include "onnx/common/assertions.h"
25 #include "onnx/common/interned_strings.h"
26 #include "onnx/common/graph_node_list.h"
27 #include "onnx/common/tensor.h"
28 #include "onnx/common/common.h"
29
30
// Deletes the copy constructor and the copy-assignment operator of `Type`.
// Place it inside the class body; the invocation needs a trailing semicolon.
#define ONNX_DISALLOW_COPY_AND_ASSIGN(Type) \
  Type(const Type&) = delete;               \
  Type& operator=(const Type&) = delete
34
35
36 namespace ONNX_NAMESPACE {
37
38 // Graph represents one "function" of computation.
39 // It uses a simple ownership model where the graph owns all the nodes inside it.
40 // All references inside the graph are raw pointers.
41 // Destroying the Graph will invalidate any pointers to nodes in the graph.
42 struct Graph;
43
44
45 // Node is the base class of the IR graph. It represents one computation
46 // and dependencies on a list of Values. The "prim-ops", so to speak.
47 struct Node;
48
49
// A Value represents an input or output of a node and is either a
// Tensor or an opaque Handle object, as determined by type().
52 struct Value;
53
54
// Scope guard: runs `destructor` when the guard goes out of scope unless
// release() was called first.
//
// The guard owns its callback. It is therefore move-only: an implicit copy
// would make both the copy and the original invoke the callback. A moved-from
// guard is released, so only the destination runs the callback.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

 public:
  // Not explicit, so existing copy-initialization from a lambda keeps working.
  ResourceGuard(std::function<void()> destructor)
      : destructor_(std::move(destructor)), released_(false) {}

  // Transfers responsibility for running the callback to the new guard.
  ResourceGuard(ResourceGuard&& other)
      : destructor_(std::move(other.destructor_)), released_(other.released_) {
    other.released_ = true;
  }

  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;
  ResourceGuard& operator=(ResourceGuard&&) = delete;

  ~ResourceGuard() {
    // An empty std::function would throw bad_function_call from a destructor
    // (std::terminate); treat it as "nothing to do" instead.
    if (!released_ && destructor_) {
      destructor_();
    }
  }

  // Disarms the guard; the callback will not run.
  void release() {
    released_ = true;
  }
};
72
73
// One dimension of a value's shape. It is in exactly one of three states:
//   unknown  : is_unknown == true
//   concrete : is_int == true and `dim` holds the extent
//   symbolic : is_int == false and `param` holds the symbolic name
// Every field is initialized in every constructor, so copying a
// default-constructed Dimension never reads indeterminate values.
struct Dimension final {
  Dimension() : is_unknown(true), is_int(false), dim(-1) {}
  Dimension(std::string param)
      : is_unknown(false), is_int(false), dim(-1), param(std::move(param)) {}
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {}

  bool is_unknown;
  bool is_int;
  int64_t dim;  // -1 unless is_int is true
  std::string param;
};
86
87
// Kind tag of an attribute value. The enumerator order matters: it is used as
// an index into the name table of toString(AttributeKind).
enum class AttributeKind : uint8_t {
  f,    // float
  fs,   // float list
  i,    // int
  is,   // int list
  s,    // string
  ss,   // string list
  t,    // tensor
  ts,   // tensor list
  g,    // subgraph
  gs,   // subgraph list
  tp,   // type proto
  tps   // type proto list
};
93
94
toString(AttributeKind kind)95 static inline const char * toString(AttributeKind kind) {
96 static constexpr const char* names[] = {"f","fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
97 ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
98 return names[int(kind)];
99 }
100
101
102 struct AttributeValue {
AttributeValueAttributeValue103 AttributeValue(Symbol name)
104 : name(name) {}
105 using Ptr = std::unique_ptr<AttributeValue>;
106 Symbol name;
107 virtual AttributeKind kind() const = 0;
108 virtual Ptr clone() const = 0;
109 virtual ~AttributeValue() = default;
110 };
111
112
113 template<typename T, AttributeKind Kind>
114 struct ScalarAttributeValue final : public AttributeValue {
115 using ConstructorType = const T &;
116 using ValueType = T;
ScalarAttributeValuefinal117 ScalarAttributeValue(Symbol name, ConstructorType value_)
118 : AttributeValue(name), value_(value_) {}
valuefinal119 ValueType & value() {
120 return value_;
121 }
clonefinal122 virtual Ptr clone() const override {
123 return Ptr(new ScalarAttributeValue(name, value_));
124 }
kindfinal125 virtual AttributeKind kind() const override { return Kind; }
126
127 private:
128 ValueType value_;
129 };
130
131
132 template<typename T, AttributeKind Kind>
133 struct VectorAttributeValue final : public AttributeValue {
134 using ConstructorType = const std::vector<T> &&;
135 using ValueType = std::vector<T>;
VectorAttributeValuefinal136 VectorAttributeValue(Symbol name, ConstructorType value_)
137 : AttributeValue(name), value_(std::move(value_)) {}
valuefinal138 ValueType & value() {
139 return value_;
140 }
kindfinal141 virtual AttributeKind kind() const override { return Kind; }
clonefinal142 virtual std::unique_ptr<AttributeValue> clone() const override {
143 auto copy = value_;
144 return Ptr(new VectorAttributeValue(name, std::move(copy)));
145 }
146 private:
147 ValueType value_;
148 };
149
150
151 using FloatAttr = ScalarAttributeValue<double,AttributeKind::f>;
152 using FloatsAttr = VectorAttributeValue<double,AttributeKind::fs>;
153 using IntAttr = ScalarAttributeValue<int64_t,AttributeKind::i>;
154 using IntsAttr = VectorAttributeValue<int64_t,AttributeKind::is>;
155 using StringAttr = ScalarAttributeValue<std::string,AttributeKind::s>;
156 using StringsAttr = VectorAttributeValue<std::string,AttributeKind::ss>;
157 using TensorAttr = ScalarAttributeValue<Tensor,AttributeKind::t>;
158 using TensorsAttr = VectorAttributeValue<Tensor,AttributeKind::ts>;
159 using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>,AttributeKind::g>;
160 using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs>;
161 using TypeProtoAttr = ScalarAttributeValue<TypeProto,AttributeKind::tp>;
162 using TypeProtosAttr = VectorAttributeValue<TypeProto,AttributeKind::tps>;
163
164
// CRTP so that Node, which inherits Attributes, can be returned for
// method chaining, e.g.:
//   Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5);
// We return Derived* pointers because Nodes are normally held as pointers.
169 template<typename Derived>
170 struct Attributes {
AttributesAttributes171 Attributes() {}
copyAttributesAttributes172 void copyAttributes(const Attributes & rhs) {
173 values_.clear();
174 values_.reserve(rhs.values_.size());
175 for(auto & i : rhs.values_) {
176 values_.push_back(i->clone());
177 }
178 }
hasAttributeAttributes179 bool hasAttribute(Symbol name) const {
180 return find(name,false) != values_.end();
181 }
kindOfAttributes182 AttributeKind kindOf(Symbol name) const {
183 return (*find(name,true))->kind();
184 }
removeAttributeAttributes185 Derived* removeAttribute(Symbol name) {
186 values_.erase(find(name,true));
187 return This();
188 }
hasAttributesAttributes189 bool hasAttributes() const {
190 return !values_.empty();
191 }
192 // The names are returned in order, since name actually is the index.
attributeNamesAttributes193 std::vector<Symbol> attributeNames() const {
194 std::vector<Symbol> names;
195 names.reserve(values_.size());
196 for(auto & a : values_)
197 names.push_back(a->name);
198 return names;
199 }
200
201 #define CREATE_ACCESSOR(Kind, method) \
202 Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
203 return set<Kind##Attr>(name,std::forward<Kind##Attr::ConstructorType>(v)); \
204 } \
205 const Kind##Attr::ValueType& method(Symbol name) const { \
206 return get<Kind##Attr>(name); \
207 }
CREATE_ACCESSORAttributes208 CREATE_ACCESSOR(Float,f)
209 CREATE_ACCESSOR(Floats,fs)
210 CREATE_ACCESSOR(String,s)
211 CREATE_ACCESSOR(Strings,ss)
212 CREATE_ACCESSOR(Int,i)
213 CREATE_ACCESSOR(Ints,is)
214 CREATE_ACCESSOR(Tensor,t)
215 CREATE_ACCESSOR(Tensors,ts)
216 CREATE_ACCESSOR(Graph,g)
217 CREATE_ACCESSOR(Graphs,gs)
218 CREATE_ACCESSOR(TypeProto,tp)
219 CREATE_ACCESSOR(TypeProtos,tps)
220
221 #undef CREATE_ACCESSOR
222
223 private:
224 Derived* This() {
225 return static_cast<Derived*>(this);
226 }
227 template<typename T>
setAttributes228 Derived* set(Symbol name, typename T::ConstructorType v) {
229 auto it = find(name, false);
230 auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
231 if(it == values_.end()) {
232 values_.push_back(std::move(nv));
233 } else {
234 *it = std::move(nv);
235 }
236 return This();
237 }
238 template<typename T>
getAttributes239 typename T::ValueType & get(Symbol name) const {
240 auto it = find(name, true);
241 T* child = static_cast<T*>(it->get());
242 return child->value();
243 }
244 using AVPtr = AttributeValue::Ptr;
245 // NB: For determinism, we use a vector rather than a hash map. This does
246 // mean that lookups are O(n), so you shouldn't use Attributes to store
247 // a big pile of messages.
248 std::vector<AVPtr> values_;
249 using iterator = std::vector<AVPtr>::iterator;
findAttributes250 iterator find(Symbol name, bool required) {
251 auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
252 return v->name == name;
253 });
254 ONNX_ASSERT(!required || it != values_.end());
255 return it;
256 }
257 using const_iterator = std::vector<AVPtr>::const_iterator;
findAttributes258 const_iterator find(Symbol name, bool required) const {
259 auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
260 return v->name == name;
261 });
262 ONNX_ASSERTM(!required || it != values_.end(),
263 "%s:%u: %s: required undefined attribute '%s'", __FILE__, __LINE__, __func__, name.toString());
264 return it;
265 }
266 };
267
268
269
// Each use of a Value is represented by this type; see Value::uses().
// 'user' is the node that consumes the value, and 'offset' is the index
// into user's inputs at which the value appears.
273 struct Use final {
Usefinal274 Use(Node * user, size_t offset)
275 : user(user), offset(offset) {}
276 Node * user;
277 size_t offset;
278 };
279
280 static inline bool operator==(const Use & a, const Use & b) {
281 return a.user == b.user && a.offset == b.offset;
282 }
283
284
285 // the list types are intentionally simple, but we type-def
286 // them here so if we need to change them, refactoring will be easier
287 using node_list = std::vector<Node*>;
288 using value_list = std::vector<Value*>;
289 using use_list = std::vector<Use>;
290 using NodeKind = Symbol;
291
292
293 struct Value final {
294 ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
295 Value(Node * node_, size_t offset_);
296
297 private:
298 friend struct Node;
299 friend struct Graph;
300 Node * node_;
301 size_t offset_;
302 size_t unique_ = 0; // unique id
303 size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
304 use_list uses_in_current_graph_;
305 bool has_unique_name_;
306 std::string unique_name_;
307 int32_t elem_type_;
308 bool has_sizes_;
309 std::vector<Dimension> sizes_;
310
311 public:
setElemTypefinal312 Value* setElemType(int32_t elem_type) {
313 elem_type_ = elem_type;
314 return this;
315 }
elemTypefinal316 int32_t elemType() const {
317 return elem_type_;
318 }
has_sizesfinal319 bool has_sizes() const { return has_sizes_; }
setSizesfinal320 Value* setSizes(std::vector<Dimension> sizes) {
321 has_sizes_ = true;
322 sizes_ = std::move(sizes);
323 return this;
324 }
wipeSizesfinal325 Value* wipeSizes() {
326 has_sizes_ = false;
327 sizes_ = std::vector<Dimension>();
328 return this;
329 }
sizesfinal330 const std::vector<Dimension>& sizes() const {
331 return sizes_;
332 }
uniquefinal333 size_t unique() const {
334 return unique_;
335 }
has_unique_namefinal336 bool has_unique_name() const {
337 return has_unique_name_;
338 }
uniqueNamefinal339 std::string uniqueName() const {
340 if(has_unique_name())
341 return unique_name_;
342 return ONNX_NAMESPACE::to_string(unique());
343 }
344 Value* setUniqueName(const std::string &name, bool rename_subgraph_captured_nodes=true);
setStagefinal345 Value* setStage(size_t s) {
346 stage_ = s;
347 return this;
348 }
stagefinal349 size_t stage() const {
350 return stage_;
351 }
nodefinal352 Node* node() {
353 return node_;
354 }
offsetfinal355 size_t offset() const {
356 return offset_;
357 }
nodefinal358 const Node * node() const {
359 return node_;
360 }
361 Graph * owningGraph();
362 const Graph * owningGraph() const;
363 // TODO: make this more const correct
364 const use_list uses() const;
365
366 // Replaces all uses of this node with 'newValue'.
367 //
368 // Given: %3 = f(%1, %2)
369 // %4 = g(%3)
370 // %5 = h(%3, %3)
371 // Execute: %3.replaceAllUsesWith(%6)
372 // Result: %3 = f(%1, %2)
373 // %4 = g(%6)
374 // %5 = h(%6, %6)
375 void replaceAllUsesWith(Value * newValue);
376
copyMetadatafinal377 Value* copyMetadata(Value * from) {
378 setElemType(from->elemType());
379 setSizes(from->sizes());
380 if (from->has_unique_name()) {
381 setUniqueName(from->uniqueName());
382 }
383 return this;
384 }
385
386 };
387
388
// A single operation in a Graph: it consumes input Values, produces output
// Values (which it allocates in addOutput), and carries named attributes via
// the CRTP Attributes base. Nodes are linked into their graph's node list
// through next()/prev().
struct Node : public Attributes<Node> {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
  friend struct Value;
  friend graph_node_list;
  friend const_graph_node_list;
  friend graph_node_list_iterator;
  friend const_graph_node_list_iterator;

 private:
  // Every node except Return/Param is associated with exactly one place in
  // the node list of graph_.
  // The list is a circular doubly-linked list. The Return node is the
  // sentinel for both the beginning and the end, so the list never holds
  // null pointers.
  // next_in_graph[0] is the next pointer and next_in_graph[1] the prev
  // pointer. An array is used so the same iterator class can walk the list
  // forwards and backwards.
  // The list order is a topological sort.

  Node* next_in_graph[2] = { nullptr, nullptr };
  Node* & next() { return next_in_graph[kNextDirection]; }
  Node* & prev() { return next_in_graph[kPrevDirection]; }
  Node* const & next() const { return next_in_graph[kNextDirection]; }
  Node* const & prev() const { return next_in_graph[kPrevDirection]; }

  const NodeKind kind_;
  std::vector<Value*> inputs_;
  std::vector<Value*> outputs_;
  Graph* graph_;
  size_t stage_;
  bool has_name_;
  std::string name_;
  bool has_domain_;
  std::string domain_;
  bool has_doc_string_;
  std::string doc_string_;

 protected:
  Node(Graph * graph_, NodeKind kind_); // defined after Graph

 public:
  bool has_name() const {
    return has_name_;
  }
  const std::string& name() const {
    return name_;
  }
  void setName(std::string name) {
    has_name_ = true;
    name_ = std::move(name);
  }
  bool has_domain() const {
    return has_domain_;
  }
  const std::string& domain() const {
    return domain_;
  }
  void setDomain(std::string domain) {
    has_domain_ = true;
    domain_ = std::move(domain);
  }
  bool has_doc_string() const {
    return has_doc_string_;
  }
  const std::string& docString() const {
    return doc_string_;
  }
  void setDocString(std::string doc_string) {
    has_doc_string_ = true;
    doc_string_ = std::move(doc_string);
  }
  NodeKind kind() const {
    return kind_;
  }
  Graph * owningGraph() {
    return graph_;
  }
  const Graph * owningGraph() const {
    return graph_;
  }
  size_t stage() const {
    return stage_;
  }
  Node* setStage(size_t s) {
    stage_ = s;
    return this;
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize inputs (e.g., using addInput).
  // We can't return a std::vector<Value*>& because there's no
  // way to soundly cast to std::vector<const Value*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> inputs() {
    return inputs_;
  }
  ArrayRef<const Value*> inputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {inputs_.data(), inputs_.size()};
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize outputs (e.g., using addOutput or
  // eraseOutput).
  // We can't return a std::vector<Value*>& because there's no
  // way to soundly cast to std::vector<const Value*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> outputs() {
    return outputs_;
  }
  ArrayRef<const Value*> outputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {outputs_.data(), outputs_.size()};
  }
  // True if any output of this node has at least one use.
  bool hasUses() const {
    for(auto o : outputs()) {
      if(!o->uses().empty())
        return true;
    }
    return false;
  }
  // Redirects every use of output i of this node to output i of `n`.
  // Both nodes must have the same number of outputs.
  void replaceAllUsesWith(Node * n) {
    ONNX_ASSERT(outputs().size() == n->outputs().size());
    size_t nOutputs = outputs().size();
    for(size_t i = 0; i < nOutputs; i++) {
      outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
    }
  }
  // lots of things like chunk have a single input or single output, so we have a
  // helper to make accessing it easier
  Value * input() {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  Value * output() {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  const Value * input() const {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  // NOTE(review): unlike input() const, this const overload returns a
  // mutable Value*, so a caller holding a const Node* can modify the output.
  Value * output() const {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  // Access a particular input. This is a checked index.
  Value * input(size_t i) {
    return inputs_.at(i);
  }
  const Value * input(size_t i) const {
    return inputs_.at(i);
  }

  // Graphs

  // Note [Topological invariant]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // We always maintain an up-to-date topological ordering of all nodes via
  // the next()/prev() links.  All transformations to graphs must preserve
  // this topological ordering: for example, it is only valid to 'addInput'
  // with an input which is topologically before the current node.
  //
  // Usually, it is obvious whether or not topological order is maintained;
  // for example, if you are adding nodes to the end of the topsort, it's
  // impossible for them to refer to inputs that are not in the topsort.
  // If it is not obvious, please comment accordingly.

  // Add 'node' as an input to 'this' at the end of existing
  // arguments.  Returns the added node for ease of chaining.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.addInput(%4)
  // Result:  %3 = f(%1, %2, %4)
  Value* addInput(Value * node) {
    ONNX_ASSERT(graph_ == node->owningGraph());
    node->uses_in_current_graph_.emplace_back(this, inputs_.size());
    inputs_.push_back(node);
    return node;
  }

  // Replace the input of 'this' at position 'i' with
  // 'newValue', returning the old node.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.replaceInput(1, %4)
  // Result:  %3 = f(%1, %4)
  Value * replaceInput(size_t i, Value * newValue) {
    ONNX_ASSERT(newValue->owningGraph() == graph_);
    Value * old = dropInput(i);
    inputs_[i] = newValue;
    newValue->uses_in_current_graph_.emplace_back(this, i);
    return old;
  }

  // Replace all occurrences of 'from' in the inputs of this
  // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
  //
  // Given:   %3 = f(%1, %2, %1)
  // Execute: %3.replaceInputWith(%1, %4)
  // Result:  %3 = f(%4, %2, %4)
  void replaceInputWith(Value * from, Value * to) {
    ONNX_ASSERT(from->owningGraph() == graph_);
    ONNX_ASSERT(to->owningGraph() == graph_);
    size_t i = 0;
    // replaceInput only overwrites slot i (inputs_ is never resized here),
    // so the ArrayRef being iterated stays valid.
    for(auto input : inputs()) {
      if(input == from)
        replaceInput(i, to);
      i++;
    }
  }

  // Appends a new output Value (allocated with new) and returns it.
  Value* addOutput() {
    outputs_.push_back(new Value(this, outputs_.size()));
    return outputs_.back();
  }

  void eraseOutput(size_t i);

  // Insert unattached 'this' node before 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertBefore(%4)
  // Result:  %3 = f(%1, %2)
  //          %5 = h(%1)
  //          %4 = g(%3)
  Node* insertBefore(Node * n) {
    ONNX_ASSERT(n->inGraphList());
    insertAfter(n->prev());
    return this;
  }

  // Insert unattached 'this' node after 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertAfter(%4)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = h(%1)
  Node* insertAfter(Node * n) {
    ONNX_ASSERT(!inGraphList() && n->inGraphList());
    Node * next = n->next();
    n->next() = this;
    this->prev() = n;
    this->next() = next;
    next->prev() = this;
    return this;
  }

  // Move 'this' (already in the graph) after 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %2.moveAfter(%3)
  // Result: %3 = g(%1)
  //         %2 = f(%1)
  //
  void moveAfter(Node * n) {
    removeFromList();
    insertAfter(n);
  }

  // Move 'this' (already in the graph) before 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %3.moveBefore(%2)
  // Result: %3 = g(%1)
  //         %2 = f(%1)
  void moveBefore(Node * n) {
    removeFromList();
    insertBefore(n);
  }

  // Remove the input at 'i' from this node.
  //
  // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
  // removeInput.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeInput(1)
  // Result: %3 = f(%1)
  void removeInput(size_t i) {
    dropInput(i);
    // everything after this input shifts left,
    // so we need to update their use offsets to match
    for(size_t j = i+1; j < inputs_.size(); j++) {
      auto it = findUseForInput(j);
      it->offset--;
    }
    inputs_.erase(inputs_.begin() + i);
  }

  // Remove all inputs from a node.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeAllInputs()
  // Result: %3 = f()
  void removeAllInputs() {
    for(size_t i = 0; i < inputs().size(); ++i)
      dropInput(i);
    inputs_.clear();
  }

  // Check whether this node is before node n in the graph.
  bool isBefore(Node* n);

  // iterators of the node list starting at this node
  // useful for resuming a search starting at this node
  graph_node_list_iterator iterator();
  graph_node_list_iterator reverseIterator();
  const_graph_node_list_iterator iterator() const;
  const_graph_node_list_iterator reverseIterator() const;

  // Remove 'this' from the instruction list and deallocate it.
  //
  // Invariant: no outputs of 'this' may have any uses.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %2.destroy()
  // Result: %3 = g(%1)
  void destroy();

  // Dynamically cast this node to the subclass indicated by the
  // template variable, returning nullptr if the cast is invalid.
  // The check compares T::Kind with kind().
  //
  // Example usage: if(auto s = n.cast<Select>()) { ... }
  //
  // TODO: Make this const correct
  template<typename T>
  T* cast() {
    if(T::Kind == kind())
      return static_cast<T*>(this);
    return nullptr;
  }
  // Like cast<T>(), but asserts instead of returning nullptr on mismatch.
  template<typename T>
  T* expect() {
    ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
    return static_cast<T*>(this);
  }

  virtual ~Node() = default;

 private:
  // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
  use_list::iterator findUseForInput(size_t i) {
    auto & input_uses = inputs_[i]->uses_in_current_graph_;
    // O(N) on the use list, but unless we get nodes with +100 uses
    // vector traversal still is probably faster than linked list
    auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
    ONNX_ASSERT(use_it != input_uses.end());
    return use_it;
  }

  // Removes this node's use from input i's use list and sets inputs_[i] to
  // nullptr, returning the old input. Only used internally, right before the
  // slot is overwritten with a new value or erased from the list.
  Value* dropInput(size_t i) {
    ONNX_ASSERT(i < inputs_.size());
    auto input_node = inputs_[i];
    auto use_it = findUseForInput(i);
    input_node->uses_in_current_graph_.erase(use_it);
    inputs_[i] = nullptr;
    return input_node;
  }

  // A node is in the list iff next() is set; prev() must be null otherwise.
  bool inGraphList() const {
    ONNX_ASSERT(next() != nullptr || prev() == nullptr);
    return next() != nullptr;
  }
  // Unlinks this node from the list and clears both of its links.
  void removeFromList() {
    ONNX_ASSERT(inGraphList());
    Node * next = this->next();
    Node * prev = this->prev();
    prev->next() = next;
    next->prev() = prev;
    this->next() = nullptr;
    this->prev() = nullptr;
  }

 protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  virtual Node * allocNewInstance(Graph * g) {
    return new Node(g, kind());
  }
  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  //
  // NB: This does NOT clone stages.  You're expected to set the stage correctly
  // if you are going to preserve it.
  virtual void cloneFrom(Node * s) {
    copyAttributes(*s);
  }
};
798
799 // A class with the same properties as OperatorSetIdProto, but without protobuf
800 // overhead, resulting in a simpler and more readable workflow.
801 class OpSetID final {
802 private:
803 std::string domain_;
804 int64_t version_;
805
806 public:
OpSetID(const OperatorSetIdProto & proto)807 explicit OpSetID(const OperatorSetIdProto& proto)
808 :domain_(proto.domain()), version_(proto.version()) {}
809
810 // Default Domain Constructor
OpSetID(const int64_t version)811 explicit OpSetID(const int64_t version)
812 :domain_(""), version_(version) {}
813
OpSetID(const std::string & domain,int64_t version)814 explicit OpSetID(const std::string& domain, int64_t version)
815 :domain_(domain), version_(version) {}
816
817 // target must be in the form "<domain>&<version>"
toString()818 std::string toString() const {
819 return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
820 }
821
822 // target must be in the form "<domain>&<version>"
fromString(const std::string & target)823 static OpSetID fromString(const std::string& target) {
824 ONNX_TRY {
825 std::string new_domain = target.substr(0, target.find("$"));
826 int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
827 return OpSetID(std::move(new_domain), new_version);
828 } ONNX_CATCH (const std::runtime_error& e) {
829 ONNX_HANDLE_EXCEPTION([&]() {
830 ONNX_ASSERTM(false, "Error in fromString: %s", e.what());
831 });
832 }
833
834 // The control will never reach here.
835 // In the default build where exceptions are turned on in case of any error
836 // the control will enter catch block where an exception will be thrown again.
837 // In case of "no exception build" the code aborts at the site of first exception.
838 // Adding this to appease the warning "control may reach end of non-void function"
839 // as the mac build fails when ONNX_WERROR==ON
840 return OpSetID("", 0);
841 }
842
domain()843 const std::string& domain() const {
844 return domain_;
845 }
846
version()847 int64_t version() const {
848 return version_;
849 }
850
incrementVersion(int64_t step)851 void incrementVersion(int64_t step) {
852 version_ += step;
853 }
854
setVersion(int64_t newVal)855 void setVersion(int64_t newVal) {
856 version_ = newVal;
857 }
858 };
859
860 struct Graph final {
861 ONNX_DISALLOW_COPY_AND_ASSIGN(Graph);
862 friend struct Node;
863 friend struct Value;
864
865 private:
866 // only used to keep track of allocated nodes
867 // actual representation of Graph is done with
868 // inputs, outputs, nodes
869
870 std::unordered_set<const Node*> all_nodes;
871 std::unordered_set<const Value*> all_values;
872 size_t next_unique_;
873
874 size_t new_node_stage_;
875
876 // holds outputs in a way that can be reflected
877 // as a Use object
878 // also used as the beginning/end of the circular node list to avoid
879 // having corner cases where the list is empty.
880 Node * const output_;
881 Node * const input_;
882
883 std::vector<Tensor> initializers_;
884 std::vector<std::string> initializer_names_;
885
886 bool has_name_;
887 std::string name_;
888 bool has_doc_string_;
889 std::string doc_string_;
890
891 std::vector <OpSetID> opset_versions_;
892
isNameUniquefinal893 bool isNameUnique(const std::string& name) const {
894 if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) !=
895 initializer_names_.cend()) {
896 return false;
897 }
898 const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
899 for (const Node* node : all_nodes) {
900 for (const auto& attr : node->attributeNames()) {
901 if (node->kindOf(attr) == AttributeKind::g) {
902 const auto& subgraph = node->g(attr);
903 if (!subgraph->isNameUnique(name)) {
904 return false;
905 }
906 } else if (node->kindOf(attr) == AttributeKind::gs) {
907 for (const auto& subgraph : node->gs(attr)) {
908 if (!subgraph->isNameUnique(name)) {
909 return false;
910 }
911 }
912 }
913 }
914 const auto found_in =
915 std::find_if(node->inputs().begin(), node->inputs().end(), f);
916 if (found_in != node->inputs().end()) {
917 return false;
918 }
919 const auto found_out =
920 std::find_if(node->outputs().begin(), node->outputs().end(), f);
921 if (found_out != node->outputs().end()) {
922 return false;
923 }
924 }
925 return true;
926 }
927
928
929 public:
// Creates an empty graph. The Return node (output_) is created first; per
// the member comments it also serves as the sentinel of the circular node
// list. The Param node (input_) carries the graph inputs as its outputs.
// NOTE: initialization follows member declaration order, so next_unique_ and
// new_node_stage_ are already set when create() runs (presumably create()
// reads them — confirm before reordering members).
Graph()
    : next_unique_(0)
    , new_node_stage_(0)
    , output_(initOutput(create(kReturn, 0)))
    , input_(create(kParam, 0))
    , has_name_(false)
    , has_doc_string_(false) {}
937
has_doc_stringfinal938 bool has_doc_string() const {
939 return has_doc_string_;
940 }
docStringfinal941 const std::string& docString() {
942 return doc_string_;
943 }
setDocStringfinal944 void setDocString(std::string doc_string) {
945 has_doc_string_ = true;
946 doc_string_ = std::move(doc_string);
947 }
948
addInitializerfinal949 void addInitializer(Tensor initializer, std::string name) {
950 initializers_.push_back(std::move(initializer));
951 initializer_names_.push_back(std::move(name));
952 }
eraseInitializerfinal953 void eraseInitializer(const std::string &name) {
954 initializers_.erase(
955 std::remove_if(
956 initializers_.begin(),
957 initializers_.end(),
958 [&name](Tensor& initializer) {
959 return initializer.name() == name;
960 }),
961 initializers_.end());
962 initializer_names_.erase(
963 std::remove(
964 initializer_names_.begin(),
965 initializer_names_.end(),
966 name),
967 initializer_names_.end());
968 }
clearInitializersfinal969 void clearInitializers() {
970 initializers_.clear();
971 initializer_names_.clear();
972 }
initializersfinal973 const std::vector<Tensor>& initializers() const {
974 return initializers_;
975 }
initializer_namesfinal976 const std::vector<std::string>& initializer_names() const {
977 return initializer_names_;
978 }
getInitializerfinal979 std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const {
980 for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) {
981 if (name == it->name()) {
982 return it;
983 }
984 }
985 return initializers_.end();
986 }
inputsfinal987 ArrayRef<Value*> inputs() {
988 return input_->outputs();
989 }
inputsfinal990 ArrayRef<const Value*> inputs() const {
991 const auto & inputs = input_->outputs();
992 return {inputs.data(), inputs.size()};
993 }
outputsfinal994 ArrayRef<Value*> outputs() {
995 return output_->inputs();
996 }
outputsfinal997 ArrayRef<const Value*> outputs() const {
998 return static_cast<const Node*>(output_)->inputs();
999 }
nodesfinal1000 graph_node_list nodes() {
1001 return graph_node_list(output_, kNextDirection);
1002 }
nodesfinal1003 const_graph_node_list nodes() const {
1004 return const_graph_node_list(output_, kNextDirection);
1005 }
1006
opset_versions_mutablefinal1007 std::vector<OpSetID>& opset_versions_mutable() {
1008 return opset_versions_;
1009 }
1010
getNextUniquefinal1011 size_t getNextUnique() {
1012 std::string next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1013 while(!isNameUnique(next_unique_name)) {
1014 next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1015 }
1016 return next_unique_;
1017 }
1018
  // Calling begin()/end()/rbegin()/rend() on the temporary returned by nodes()
  // is safe: graph_node_list is a non-owning view, so the returned iterator
  // stays valid after that temporary is destroyed.
beginfinal1022 graph_node_list_iterator begin() {
1023 return nodes().begin();
1024 }
beginfinal1025 const_graph_node_list_iterator begin() const {
1026 return nodes().begin();
1027 }
endfinal1028 graph_node_list_iterator end() {
1029 return nodes().end();
1030 }
endfinal1031 const_graph_node_list_iterator end() const {
1032 return nodes().end();
1033 }
rbeginfinal1034 graph_node_list_iterator rbegin() {
1035 return nodes().rbegin();
1036 }
rbeginfinal1037 const_graph_node_list_iterator rbegin() const {
1038 return nodes().rbegin();
1039 }
rendfinal1040 graph_node_list_iterator rend() {
1041 return nodes().rend();
1042 }
rendfinal1043 const_graph_node_list_iterator rend() const {
1044 return nodes().rend();
1045 }
return_nodefinal1046 Node * return_node() {
1047 return output_;
1048 }
return_nodefinal1049 const Node * return_node() const {
1050 return output_;
1051 }
1052
addInputfinal1053 Value * addInput() {
1054 return input_->addOutput();
1055 }
eraseInputfinal1056 void eraseInput(size_t i) {
1057 input_->eraseOutput(i);
1058 }
advanceStagefinal1059 void advanceStage() {
1060 new_node_stage_++;
1061 }
setStagefinal1062 void setStage(size_t new_stage) {
1063 new_node_stage_ = new_stage;
1064 }
stagefinal1065 size_t stage() const {
1066 return new_node_stage_;
1067 }
setStageTemporaryfinal1068 ResourceGuard setStageTemporary(size_t s) {
1069 auto prev_stage = new_node_stage_;
1070 new_node_stage_ = s;
1071 return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; });
1072 }
1073
registerOutputfinal1074 size_t registerOutput(Value * n) {
1075 output_->addInput(n);
1076 return outputs().size() - 1;
1077 }
1078
1079 Node * create(NodeKind kind, size_t num_outputs=1) {
1080 // NB: Node constructor adds node to all_nodes
1081 auto n = new Node(this, kind);
1082 for(size_t i = 0; i < num_outputs; i++)
1083 n->addOutput();
1084 return n;
1085 }
1086
1087 Node * create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs=1) {
1088 auto n = create(kind, num_outputs);
1089 for(auto i : inputs)
1090 n->addInput(i);
1091 return n;
1092 }
1093
appendNodefinal1094 Node * appendNode(Node * n) {
1095 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1096 n->insertBefore(output_);
1097 return n;
1098 }
1099
prependNodefinal1100 Node * prependNode(Node * n) {
1101 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1102 n->insertAfter(output_);
1103 return n;
1104 }
1105
  // Adds a tensor to the initializer list, the initializer-name list, and the
  // graph inputs. The initializer name, tensor name, and value name are kept in sync.
addInitializerAndInputfinal1108 Value* addInitializerAndInput(const Tensor& initializer, std::string name) {
1109 Tensor initializerCopy = initializer;
1110 std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(),
1111 initializerCopy.sizes().cend()};
1112 Value* new_init = addInput();
1113 initializerCopy.setName(name);
1114 new_init->setUniqueName(name);
1115 new_init->setSizes(dim_sizes);
1116 new_init->setElemType(initializerCopy.elem_type());
1117 addInitializer(std::move(initializerCopy), std::move(name));
1118 return new_init;
1119 }
1120
addInitializerAndInputfinal1121 Value* addInitializerAndInput(const Tensor &initializer) {
1122 return addInitializerAndInput(initializer, ONNX_NAMESPACE::to_string(getNextUnique()));
1123 }
1124
1125
  // Erases a value from the initializer list, the initializer-name list, and
  // the graph inputs. The value must have no uses.
eraseInitializerAndInputfinal1128 void eraseInitializerAndInput(Value* v) {
1129 eraseInitializer(v->uniqueName());
1130 eraseInput(v->offset());
1131 }
1132
~Graphfinal1133 ~Graph() {
1134 for (const Node * n : all_nodes)
1135 delete n;
1136 for (const Value * v : all_values)
1137 delete v;
1138 }
1139
toStringfinal1140 std::string toString() const {
1141 std::ostringstream oss;
1142 oss << *this;
1143 return oss.str();
1144 }
1145
has_namefinal1146 bool has_name() const {
1147 return has_name_;
1148 }
1149
namefinal1150 const std::string& name() const {
1151 return name_;
1152 }
1153
setNamefinal1154 void setName(std::string name) {
1155 has_name_ = true;
1156 name_ = std::move(name);
1157 }
1158
1159 friend std::ostream& operator<<(std::ostream & out, const Graph & g);
1160
forSelfAndEachSubGraphfinal1161 void forSelfAndEachSubGraph(std::function<void(Graph*)> fn) {
1162 fn(this);
1163
1164 for (const Node* node : all_nodes) {
1165 for (const auto& attr : node->attributeNames()) {
1166 if (node->kindOf(attr) == AttributeKind::g) {
1167 std::shared_ptr<Graph> subgraph = node->g(attr);
1168 subgraph->forSelfAndEachSubGraph(fn);
1169 } else if (node->kindOf(attr) == AttributeKind::gs) {
1170 for (const auto& subgraph : node->gs(attr)) {
1171 subgraph->forSelfAndEachSubGraph(fn);
1172 }
1173 }
1174 }
1175 }
1176 }
1177
forSelfAndEachSubGraphfinal1178 void forSelfAndEachSubGraph(std::function<void(const Graph*)> fn) const {
1179 std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) {fn(graph);};
1180 const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn);
1181 }
1182
forEachNodefinal1183 void forEachNode(std::function<void(Node*)> fn) {
1184 forSelfAndEachSubGraph([fn](Graph *graph) {
1185 for(Node* node : graph->nodes()) {
1186 fn(node);
1187 }
1188 });
1189 }
1190
forEachNodefinal1191 void forEachNode(std::function<void(const Node*)> fn) const {
1192 std::function<void(Node*)> tmp_fn = [fn](Node* node) {fn(node);};
1193 const_cast<Graph*>(this)->forEachNode(tmp_fn);
1194 }
1195
1196 private:
1197
1198 // should only be called in the constructor
initOutputfinal1199 Node* initOutput(Node* p) {
1200 p->next() = p;
1201 p->prev() = p;
1202 p->setStage(std::numeric_limits<size_t>::max());
1203 return p;
1204 }
1205
freeNodefinal1206 void freeNode(Node * n) {
1207 auto it = all_nodes.find(n);
1208 ONNX_ASSERT(it != all_nodes.end());
1209 delete *it;
1210 all_nodes.erase(it);
1211 }
freeValuefinal1212 void freeValue(Value * v) {
1213 auto it = all_values.find(v);
1214 ONNX_ASSERT(it != all_values.end());
1215 all_values.erase(it);
1216 }
1217 };
1218
Value(Node * node_,size_t offset_)1219 inline Value::Value(Node *node_, size_t offset_)
1220 : node_(node_), offset_(offset_), unique_(node_->graph_->getNextUnique()),
1221 stage_(node_->graph_->new_node_stage_), has_unique_name_(false),
1222 elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
1223 has_sizes_(false) {
1224 node_->graph_->all_values.emplace(this);
1225 }
1226
owningGraph()1227 inline Graph * Value::owningGraph() {
1228 return node()->owningGraph();
1229 }
1230
owningGraph()1231 inline const Graph * Value::owningGraph() const {
1232 return node()->owningGraph();
1233 }
1234
1235 // `captured` nodes in subgraph determines which value it captures
1236 // by storing the value's unique name, so old unique names in `captured` nodes
1237 // should also be updated.
setUniqueName(const std::string & name,bool rename_subgraph_captured_nodes)1238 inline Value* Value::setUniqueName(const std::string &name, bool rename_subgraph_captured_nodes) {
1239 if (has_unique_name() && rename_subgraph_captured_nodes) {
1240 auto *graph = owningGraph();
1241 graph->forEachNode([this, &name](Node *node) {
1242 if (node->owningGraph() == this->owningGraph()) {
1243 // skip non-subgraph
1244 return;
1245 }
1246 if (node->kind() == kCaptured) {
1247 Value *output = node->output();
1248 if (output->uniqueName() == this->uniqueName()) {
1249 output->setUniqueName(name, false);
1250 }
1251 }
1252 });
1253 }
1254 unique_name_ = name;
1255 has_unique_name_ = true;
1256 return this;
1257 }
1258
replaceAllUsesWith(Value * newValue)1259 inline void Value::replaceAllUsesWith(Value * newValue) {
1260 auto* graph = owningGraph();
1261 ONNX_ASSERT(graph == newValue->owningGraph());
1262 // propagate sizes and elem type
1263 if (this->has_sizes()) {
1264 newValue->setSizes(this->sizes());
1265 }
1266 if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
1267 newValue->setElemType(this->elemType());
1268 }
1269 const auto unique_name = this->uniqueName();
1270 // We do not want the optimization to change the graph output name
1271 if (std::find(graph->outputs().rbegin(), graph->outputs().rend(),
1272 this) != graph->outputs().rend()) {
1273 newValue->setUniqueName(unique_name);
1274 // The "unique" semantic of unique_name should be kept or uses()
1275 // will return an incorrect result when the value is used in subgraph
1276 this->setUniqueName(ONNX_NAMESPACE::to_string(graph->getNextUnique()), false);
1277 }
1278 newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
1279 for(auto u : uses_in_current_graph_) {
1280 u.user->inputs_[u.offset] = newValue;
1281 newValue->uses_in_current_graph_.push_back(u);
1282 }
1283 graph->forEachNode([this, &newValue, &unique_name](Node *node) {
1284 if (node->owningGraph() == this->owningGraph()) {
1285 // skip non-subgraph
1286 return;
1287 }
1288 if (node->kind() == kCaptured) {
1289 Value *output = node->output();
1290 if (output->uniqueName() == unique_name) {
1291 output->setUniqueName(newValue->uniqueName());
1292 }
1293 }
1294 });
1295 uses_in_current_graph_.clear();
1296 assert(this->uses().empty());
1297 }
1298
Node(Graph * graph_,NodeKind kind_)1299 inline Node::Node(Graph * graph_, NodeKind kind_) :
1300 kind_(kind_),
1301 graph_(graph_),
1302 stage_(graph_->new_node_stage_),
1303 has_name_(false),
1304 has_domain_(false),
1305 has_doc_string_(false) {
1306 graph_->all_nodes.emplace(this);
1307 }
1308
eraseOutput(size_t i)1309 inline void Node::eraseOutput(size_t i) {
1310 ONNX_ASSERT(i < outputs_.size());
1311 ONNX_ASSERT(outputs_[i]->uses().empty());
1312 Value * n = outputs_[i];
1313 outputs_.erase(outputs_.begin() + i);
1314 owningGraph()->freeValue(n);
1315 for(size_t j = i; j < outputs_.size(); j++) {
1316 outputs_[j]->offset_--;
1317 }
1318 }
1319
isBefore(Node * n)1320 inline bool Node::isBefore(Node* n) {
1321 if (n == nullptr || this == n) {
1322 // Bail out early.
1323 return false;
1324 }
1325 // return true if node is Param (in initializers)
1326 if (kind_ == kParam) {
1327 return true;
1328 }
1329 // return false if target node is Param (in initializers)
1330 if (n->kind() == kParam) {
1331 return false;
1332 }
1333 ONNX_ASSERT(n->inGraphList());
1334 for (Node* p = next(); p != *graph_->end(); p = p->next()) {
1335 if (p == n) {
1336 return true;
1337 }
1338 }
1339 return false;
1340 }
1341
destroy()1342 inline void Node::destroy() {
1343 ONNX_ASSERT(inGraphList());
1344 while(!outputs().empty())
1345 eraseOutput(outputs().size() - 1);
1346 removeAllInputs();
1347 removeFromList();
1348 graph_->freeNode(this);
1349 }
1350
1351 /************* All nodes not required to be defined before Graph **************/
1352
iterator()1353 inline graph_node_list_iterator Node::iterator() {
1354 return graph_node_list_iterator(this, 0);
1355 }
reverseIterator()1356 inline graph_node_list_iterator Node::reverseIterator() {
1357 return iterator().reverse();
1358 }
iterator()1359 inline const_graph_node_list_iterator Node::iterator() const {
1360 return const_graph_node_list_iterator(this, 0);
1361 }
reverseIterator()1362 inline const_graph_node_list_iterator Node::reverseIterator() const {
1363 return iterator().reverse();
1364 }
1365
1366 // Returns a list about which nodes are using this value,
1367 // nodes in subgraph are also included.
1368 // This method is usually used to check whether it is
1369 // safe to delete a Value.
uses()1370 inline const use_list Value::uses() const {
1371 use_list all_uses = uses_in_current_graph_;
1372 owningGraph()->forEachNode([this, &all_uses](const Node* node) {
1373 if (node->owningGraph() == this->owningGraph()) {
1374 // skip non-subgraph
1375 return;
1376 }
1377 if (node->kind() == kCaptured) {
1378 const Value* output = node->outputs()[0];
1379 if (output->uniqueName() == this->uniqueName()) {
1380 const auto output_uses = output->uses();
1381 all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end());
1382 }
1383 }
1384 });
1385 return all_uses;
1386 }
1387
1388
1389 } // namespace ONNX_NAMESPACE
1390