1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 // ATTENTION: The code in this file is highly EXPERIMENTAL.
6 // Adventurous users should note that the APIs will probably change.
7 
8 #pragma once
9 
10 #include <atomic>
11 #include <algorithm>
12 #include <cstdint>
13 #include <functional>
14 #include <iostream>
15 #include <memory>
16 #include <sstream>
17 #include <stdint.h>
18 #include <string>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "onnx/string_utils.h"
23 #include "onnx/common/array_ref.h"
24 #include "onnx/common/assertions.h"
25 #include "onnx/common/interned_strings.h"
26 #include "onnx/common/graph_node_list.h"
27 #include "onnx/common/tensor.h"
28 #include "onnx/common/common.h"
29 
30 
// Makes TypeName non-copyable by deleting its copy assignment operator and
// copy constructor. Because these are user-declared, the compiler also does
// not implicitly declare move operations for the type. The expansion has no
// trailing semicolon, so a use reads like a declaration:
//   ONNX_DISALLOW_COPY_AND_ASSIGN(Foo);
#define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName)    \
  TypeName& operator=(const TypeName&) = delete;   \
  TypeName(const TypeName&) = delete
34 
35 
36 namespace ONNX_NAMESPACE {
37 
// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside it.
// All references inside the graph are raw pointers.
// Destroying the Graph will invalidate any pointers to nodes in the graph.
struct Graph;


// Node is the base class of the IR graph. It represents one computation
// together with its dependencies on a list of input Values ("prim-ops", so to
// speak).
struct Node;


// A Value represents an input or output of a node.
// NOTE(review): an older version of this comment said a Value is "either a
// Tensor or an opaque Handle object, as determined by type()". The Value
// definition in this header exposes elemType()/sizes() and no type()
// accessor; confirm the intended description.
struct Value;
53 
54 
// Scope guard: invokes the supplied callback when the guard is destroyed,
// unless release() was called first.
//
// Copying is disabled so the callback can never run twice. (With only a
// user-declared destructor, the implicit copy constructor was used even for
// std::move, and both guards fired.) Moving transfers the responsibility: the
// moved-from guard is marked released. An empty callback is treated as
// "nothing to do" rather than throwing std::bad_function_call from the
// destructor, which would call std::terminate.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

public:
  ResourceGuard(std::function<void()> destructor)
    : destructor_(std::move(destructor))
    , released_(false) {}

  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

  // Takes over 'other's pending callback; 'other' will no longer fire.
  ResourceGuard(ResourceGuard&& other) noexcept
    : destructor_(std::move(other.destructor_))
    , released_(other.released_) {
    other.released_ = true;
  }
  ResourceGuard& operator=(ResourceGuard&&) = delete;

  ~ResourceGuard() {
    if (!released_ && destructor_) {
      destructor_();
    }
  }

  // Disarms the guard: the callback will not be invoked on destruction.
  void release() {
    released_ = true;
  }
};
72 
73 
// One entry of a Value's shape. A Dimension is in exactly one of three states:
//  - unknown:  is_unknown == true (the default-constructed state);
//  - symbolic: is_unknown == false, is_int == false, 'param' holds the name;
//  - concrete: is_unknown == false, is_int == true, 'dim' holds the size.
// Every member has an in-class default, so a default-constructed Dimension
// never exposes indeterminate values of is_int / dim (previously reading them
// was undefined behavior).
struct Dimension final {
  Dimension() = default;
  Dimension(std::string param)
    : is_unknown(false), is_int(false), dim(-1), param(std::move(param)) {
  }
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {}

  bool is_unknown = true;
  bool is_int = false;
  int64_t dim = -1;
  std::string param;
};
86 
87 
// Discriminates the payload type held by an AttributeValue. The underlying
// values follow declaration order and are used as indices into the name table
// of toString(AttributeKind), so the two must be kept in sync.
enum class AttributeKind : uint8_t {
  f,    // float
  fs,   // float list
  i,    // int
  is,   // int list
  s,    // string
  ss,   // string list
  t,    // tensor
  ts,   // tensor list
  g,    // subgraph
  gs,   // subgraph list
  tp,   // type proto
  tps   // type proto list
};
93 
94 
toString(AttributeKind kind)95 static inline const char * toString(AttributeKind kind) {
96   static constexpr const char* names[] = {"f","fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
97   ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
98   return names[int(kind)];
99 }
100 
101 
102 struct AttributeValue {
AttributeValueAttributeValue103   AttributeValue(Symbol name)
104   : name(name) {}
105   using Ptr = std::unique_ptr<AttributeValue>;
106   Symbol name;
107   virtual AttributeKind kind() const = 0;
108   virtual Ptr clone() const = 0;
109   virtual ~AttributeValue() = default;
110 };
111 
112 
113 template<typename T, AttributeKind Kind>
114 struct ScalarAttributeValue final : public AttributeValue {
115   using ConstructorType = const T &;
116   using ValueType = T;
ScalarAttributeValuefinal117   ScalarAttributeValue(Symbol name, ConstructorType value_)
118   : AttributeValue(name), value_(value_) {}
valuefinal119   ValueType & value() {
120     return value_;
121   }
clonefinal122   virtual Ptr clone() const override {
123     return Ptr(new ScalarAttributeValue(name, value_));
124   }
kindfinal125   virtual AttributeKind kind() const override { return Kind; }
126 
127 private:
128   ValueType value_;
129 };
130 
131 
132 template<typename T, AttributeKind Kind>
133 struct VectorAttributeValue final : public AttributeValue {
134   using ConstructorType = const std::vector<T> &&;
135   using ValueType = std::vector<T>;
VectorAttributeValuefinal136   VectorAttributeValue(Symbol name, ConstructorType value_)
137   : AttributeValue(name), value_(std::move(value_)) {}
valuefinal138   ValueType & value() {
139     return value_;
140   }
kindfinal141   virtual AttributeKind kind() const override { return Kind; }
clonefinal142   virtual std::unique_ptr<AttributeValue> clone() const override {
143     auto copy = value_;
144     return Ptr(new VectorAttributeValue(name, std::move(copy)));
145   }
146 private:
147   ValueType value_;
148 };
149 
150 
151 using FloatAttr = ScalarAttributeValue<double,AttributeKind::f>;
152 using FloatsAttr = VectorAttributeValue<double,AttributeKind::fs>;
153 using IntAttr = ScalarAttributeValue<int64_t,AttributeKind::i>;
154 using IntsAttr = VectorAttributeValue<int64_t,AttributeKind::is>;
155 using StringAttr = ScalarAttributeValue<std::string,AttributeKind::s>;
156 using StringsAttr = VectorAttributeValue<std::string,AttributeKind::ss>;
157 using TensorAttr = ScalarAttributeValue<Tensor,AttributeKind::t>;
158 using TensorsAttr = VectorAttributeValue<Tensor,AttributeKind::ts>;
159 using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>,AttributeKind::g>;
160 using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs>;
161 using TypeProtoAttr = ScalarAttributeValue<TypeProto,AttributeKind::tp>;
162 using TypeProtosAttr = VectorAttributeValue<TypeProto,AttributeKind::tps>;
163 
164 
165 // CRTP so that Node which inherits Attributes can be return for
166 // method chaining e.g:
167 // Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5);
168 // we return Derived* pointers because Nodes are normally held as pointers.
169 template<typename Derived>
170 struct Attributes {
AttributesAttributes171   Attributes() {}
copyAttributesAttributes172   void copyAttributes(const Attributes & rhs) {
173     values_.clear();
174     values_.reserve(rhs.values_.size());
175     for(auto & i : rhs.values_) {
176       values_.push_back(i->clone());
177     }
178   }
hasAttributeAttributes179   bool hasAttribute(Symbol name) const {
180     return find(name,false) != values_.end();
181   }
kindOfAttributes182   AttributeKind kindOf(Symbol name) const {
183     return (*find(name,true))->kind();
184   }
removeAttributeAttributes185   Derived* removeAttribute(Symbol name) {
186     values_.erase(find(name,true));
187     return This();
188   }
hasAttributesAttributes189   bool hasAttributes() const {
190     return !values_.empty();
191   }
192   // The names are returned in order, since name actually is the index.
attributeNamesAttributes193   std::vector<Symbol> attributeNames() const {
194     std::vector<Symbol> names;
195     names.reserve(values_.size());
196     for(auto & a : values_)
197       names.push_back(a->name);
198     return names;
199   }
200 
201   #define CREATE_ACCESSOR(Kind, method) \
202   Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
203     return set<Kind##Attr>(name,std::forward<Kind##Attr::ConstructorType>(v)); \
204   } \
205   const Kind##Attr::ValueType& method(Symbol name) const { \
206     return get<Kind##Attr>(name); \
207   }
CREATE_ACCESSORAttributes208   CREATE_ACCESSOR(Float,f)
209   CREATE_ACCESSOR(Floats,fs)
210   CREATE_ACCESSOR(String,s)
211   CREATE_ACCESSOR(Strings,ss)
212   CREATE_ACCESSOR(Int,i)
213   CREATE_ACCESSOR(Ints,is)
214   CREATE_ACCESSOR(Tensor,t)
215   CREATE_ACCESSOR(Tensors,ts)
216   CREATE_ACCESSOR(Graph,g)
217   CREATE_ACCESSOR(Graphs,gs)
218   CREATE_ACCESSOR(TypeProto,tp)
219   CREATE_ACCESSOR(TypeProtos,tps)
220 
221   #undef CREATE_ACCESSOR
222 
223 private:
224   Derived* This() {
225     return static_cast<Derived*>(this);
226   }
227   template<typename T>
setAttributes228   Derived* set(Symbol name, typename T::ConstructorType v) {
229     auto it = find(name, false);
230     auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
231     if(it == values_.end()) {
232       values_.push_back(std::move(nv));
233     } else {
234       *it = std::move(nv);
235     }
236     return This();
237   }
238   template<typename T>
getAttributes239   typename T::ValueType & get(Symbol name) const {
240     auto it = find(name, true);
241     T* child = static_cast<T*>(it->get());
242     return child->value();
243   }
244   using AVPtr = AttributeValue::Ptr;
245   // NB: For determinism, we use a vector rather than a hash map.  This does
246   // mean that lookups are O(n), so you shouldn't use Attributes to store
247   // a big pile of messages.
248   std::vector<AVPtr> values_;
249   using iterator = std::vector<AVPtr>::iterator;
findAttributes250   iterator find(Symbol name, bool required) {
251     auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
252       return v->name == name;
253     });
254     ONNX_ASSERT(!required || it != values_.end());
255     return it;
256   }
257   using const_iterator = std::vector<AVPtr>::const_iterator;
findAttributes258   const_iterator find(Symbol name, bool required) const {
259     auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
260       return v->name == name;
261     });
262     ONNX_ASSERTM(!required || it != values_.end(),
263         "%s:%u: %s: required undefined attribute '%s'", __FILE__, __LINE__, __func__, name.toString());
264     return it;
265   }
266 };
267 
268 
269 
270 // Each use is represented by this type, see Node::uses()
271 // 'user' is the consumer of the value, offset is the index into
272 // 'user's input this where the produces will be found.
273 struct Use final {
Usefinal274   Use(Node * user, size_t offset)
275   : user(user), offset(offset) {}
276   Node * user;
277   size_t offset;
278 };
279 
280 static inline bool operator==(const Use & a, const Use & b) {
281   return a.user == b.user && a.offset == b.offset;
282 }
283 
284 
285 // the list types are intentionally simple, but we type-def
286 // them here so if we need to change them, refactoring will be easier
287 using node_list = std::vector<Node*>;
288 using value_list = std::vector<Value*>;
289 using use_list = std::vector<Use>;
290 using NodeKind = Symbol;
291 
292 
293 struct Value final {
294   ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
295   Value(Node * node_, size_t offset_);
296 
297 private:
298   friend struct Node;
299   friend struct Graph;
300   Node * node_;
301   size_t offset_;
302   size_t unique_ = 0;          // unique id
303   size_t stage_ = 0;           // 0-forward, 1-backward, 2-double-backward,...
304   use_list uses_in_current_graph_;
305   bool has_unique_name_;
306   std::string unique_name_;
307   int32_t elem_type_;
308   bool has_sizes_;
309   std::vector<Dimension> sizes_;
310 
311 public:
setElemTypefinal312   Value* setElemType(int32_t elem_type) {
313     elem_type_ = elem_type;
314     return this;
315   }
elemTypefinal316   int32_t elemType() const {
317     return elem_type_;
318   }
has_sizesfinal319   bool has_sizes() const { return has_sizes_; }
setSizesfinal320   Value* setSizes(std::vector<Dimension> sizes) {
321     has_sizes_ = true;
322     sizes_ = std::move(sizes);
323     return this;
324   }
wipeSizesfinal325   Value* wipeSizes() {
326     has_sizes_ = false;
327     sizes_ = std::vector<Dimension>();
328     return this;
329   }
sizesfinal330   const std::vector<Dimension>& sizes() const {
331     return sizes_;
332   }
uniquefinal333   size_t unique() const {
334     return unique_;
335   }
has_unique_namefinal336   bool has_unique_name() const {
337     return has_unique_name_;
338   }
uniqueNamefinal339   std::string uniqueName() const {
340     if(has_unique_name())
341       return unique_name_;
342     return ONNX_NAMESPACE::to_string(unique());
343   }
344   Value* setUniqueName(const std::string &name, bool rename_subgraph_captured_nodes=true);
setStagefinal345   Value* setStage(size_t s) {
346     stage_ = s;
347     return this;
348   }
stagefinal349   size_t stage() const {
350     return stage_;
351   }
nodefinal352   Node* node() {
353     return node_;
354   }
offsetfinal355   size_t offset() const {
356     return offset_;
357   }
nodefinal358   const Node * node() const {
359     return node_;
360   }
361   Graph * owningGraph();
362   const Graph * owningGraph() const;
363   // TODO: make this more const correct
364   const use_list uses() const;
365 
366   // Replaces all uses of this node with 'newValue'.
367   //
368   // Given:   %3 = f(%1, %2)
369   //          %4 = g(%3)
370   //          %5 = h(%3, %3)
371   // Execute: %3.replaceAllUsesWith(%6)
372   // Result:  %3 = f(%1, %2)
373   //          %4 = g(%6)
374   //          %5 = h(%6, %6)
375   void replaceAllUsesWith(Value * newValue);
376 
copyMetadatafinal377   Value* copyMetadata(Value * from) {
378     setElemType(from->elemType());
379     setSizes(from->sizes());
380     if (from->has_unique_name()) {
381       setUniqueName(from->uniqueName());
382     }
383     return this;
384   }
385 
386 };
387 
388 
// A single operation in a Graph: a kind (NodeKind), an ordered list of input
// Values, an ordered list of output Values (created via addOutput()), optional
// name/domain/doc string, and attributes inherited from the Attributes<Node>
// CRTP base (so attribute setters return Node* for chaining).
struct Node : public Attributes<Node> {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
  friend struct Value;
  friend graph_node_list;
  friend const_graph_node_list;
  friend graph_node_list_iterator;
  friend const_graph_node_list_iterator;

private:
  // Each node except Return/Param is associated with exactly one place in
  // the node list of graph_.
  // The list is a circular doubly-linked list. The Return node is the
  // sentinel for the beginning and end of the list, so the list never
  // contains null pointers.
  // next_in_graph[0] is the next pointer.
  // next_in_graph[1] is the prev pointer.
  // An array is used so the same iterator class can walk the list forward
  // and in reverse.
  // This list represents a topological sort.

  Node* next_in_graph[2] = { nullptr, nullptr };
  Node* & next() { return next_in_graph[kNextDirection]; }
  Node* & prev() { return next_in_graph[kPrevDirection]; }
  Node* const & next() const { return next_in_graph[kNextDirection]; }
  Node* const & prev() const { return next_in_graph[kPrevDirection]; }

  const NodeKind kind_;
  std::vector<Value*> inputs_;
  std::vector<Value*> outputs_;
  Graph* graph_;
  size_t stage_;
  bool has_name_;
  std::string name_;
  bool has_domain_;
  std::string domain_;
  bool has_doc_string_;
  std::string doc_string_;

protected:
  Node(Graph * graph_, NodeKind kind_); //defined after graph

public:
  // Optional node name; has_name() reports whether setName() was called.
  bool has_name() const {
    return has_name_;
  }
  const std::string& name() const {
    return name_;
  }
  void setName(std::string name) {
    has_name_ = true;
    name_ = std::move(name);
  }
  // Optional operator domain; has_domain() reports whether setDomain() was called.
  bool has_domain() const {
    return has_domain_;
  }
  const std::string& domain() const {
    return domain_;
  }
  void setDomain(std::string domain) {
    has_domain_ = true;
    domain_ = std::move(domain);
  }
  // Optional documentation string.
  bool has_doc_string() const {
    return has_doc_string_;
  }
  const std::string& docString() const {
    return doc_string_;
  }
  void setDocString(std::string doc_string) {
    has_doc_string_ = true;
    doc_string_ = std::move(doc_string);
  }
  NodeKind kind() const {
    return kind_;
  }
  Graph * owningGraph() {
    return graph_;
  }
  const Graph * owningGraph() const {
    return graph_;
  }
  size_t stage() const {
    return stage_;
  }
  Node* setStage(size_t s) {
    stage_ = s;
    return this;
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize inputs (e.g., using addInput)
  // We can't return a std::vector<Value*>& because there's no
  // way to soundly cast to std::vector<const Value*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> inputs() {
    return inputs_;
  }
  ArrayRef<const Value*> inputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {inputs_.data(), inputs_.size()};
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize outputs (e.g., using addOutput)
  // We can't return a std::vector<Value*>& because there's no
  // way to soundly cast to std::vector<const Value*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> outputs() {
    return outputs_;
  }
  ArrayRef<const Value*> outputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {outputs_.data(), outputs_.size()};
  }
  // True if any output of this node has at least one use.
  bool hasUses() const {
    for(auto o : outputs()) {
      if(!o->uses().empty())
        return true;
    }
    return false;
  }
  // Redirects every use of this node's i-th output to n's i-th output.
  // Asserts that both nodes have the same number of outputs.
  void replaceAllUsesWith(Node * n) {
    ONNX_ASSERT(outputs().size() == n->outputs().size());
    size_t nOutputs = outputs().size();
    for(size_t i = 0; i < nOutputs; i++) {
      outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
    }
  }
  // lots of things like chunk have a single input or single output, so we have a
  // helper to make accessing it easier
  Value * input() {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  Value * output() {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  const Value * input() const {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  // NOTE(review): unlike input() const, this returns a non-const Value*
  // from a const Node.
  Value * output() const {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  // Access a particular input.  This is a checked index.
  Value * input(size_t i) {
    return inputs_.at(i);
  }
  const Value * input(size_t i) const {
    return inputs_.at(i);
  }

  // Graphs

  // Note [Topological invariant]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // We always maintain an up-to-date topological ordering of all nodes via
  // the next()/prev() links.  All transformations to graphs must preserve
  // this topological ordering: for example, it is only valid to 'addInput'
  // with an input which is topologically before the current node.
  //
  // Usually, it is obvious whether or not topological order is maintained;
  // for example, if you are adding nodes to the end of the topsort, it's
  // impossible for them to refer to inputs that are not in the topsort.
  // If it is not obvious, please comment accordingly.

  // Add value 'node' as an input to 'this' at the end of the existing
  // arguments, and record the corresponding Use on it. Returns the added
  // value for ease of chaining. Asserts that the value belongs to this graph.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.addInput(%4)
  // Result:  %3 = f(%1, %2, %4)
  Value* addInput(Value * node) {
    ONNX_ASSERT(graph_ == node->owningGraph());
    node->uses_in_current_graph_.emplace_back(this, inputs_.size());
    inputs_.push_back(node);
    return node;
  }

  // Replace the input of 'this' at position 'i' with
  // 'newValue', returning the old value.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.replaceInput(1, %4)
  // Result:  %3 = f(%1, %4)
  Value * replaceInput(size_t i, Value * newValue) {
    ONNX_ASSERT(newValue->owningGraph() == graph_);
    Value * old = dropInput(i);
    inputs_[i] = newValue;
    newValue->uses_in_current_graph_.emplace_back(this, i);
    return old;
  }

  // Replace all occurrences of 'from' in the inputs of this
  // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
  //
  // Given:   %3 = f(%1, %2, %1)
  // Execute: %3.replaceInputWith(%1, %4)
  // Result:  %3 = f(%4, %2, %4)
  void replaceInputWith(Value * from, Value * to) {
    ONNX_ASSERT(from->owningGraph() == graph_);
    ONNX_ASSERT(to->owningGraph() == graph_);
    size_t i = 0;
    for(auto input : inputs()) {
      if(input == from)
        replaceInput(i, to);
      i++;
    }
  }

  // Appends a newly allocated output Value whose offset is its index in
  // outputs_, and returns it.
  Value* addOutput() {
    outputs_.push_back(new Value(this, outputs_.size()));
    return outputs_.back();
  }

  void eraseOutput(size_t i);

  // Insert unattached 'this' node before 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertBefore(%4)
  // Result:  %3 = f(%1, %2)
  //          %5 = h(%1)
  //          %4 = g(%3)
  Node* insertBefore(Node * n) {
    ONNX_ASSERT(n->inGraphList());
    insertAfter(n->prev());
    return this;
  }

  // Insert unattached 'this' node after 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertAfter(%4)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = h(%1)
  Node* insertAfter(Node * n) {
    ONNX_ASSERT(!inGraphList() && n->inGraphList());
    Node * next = n->next();
    n->next() = this;
    this->prev() = n;
    this->next() = next;
    next->prev() = this;
    return this;
  }

  // Move 'this' (already in the graph) after 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %2.moveAfter(%3)
  // Result: %3 = g(%1)
  //         %2 = f(%1)
  //
  void moveAfter(Node * n) {
    removeFromList();
    insertAfter(n);
  }

  // Move 'this' (already in the graph) before 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %3.moveBefore(%2)
  // Result: %3 = g(%1)
  //         %2 = f(%1)
  void moveBefore(Node * n) {
    removeFromList();
    insertBefore(n);
  }

  // Remove the input at 'i' from this node.
  //
  // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
  // removeInput.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeInput(1)
  // Result: %3 = f(%1)
  void removeInput(size_t i) {
    dropInput(i);
    // everything after this input shifts left,
    // so we need to update their use offsets to match
    for(size_t j = i+1; j < inputs_.size(); j++) {
      auto it = findUseForInput(j);
      it->offset--;
    }
    inputs_.erase(inputs_.begin() + i);
  }

  // Remove all inputs from a node.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeAllInputs()
  // Result: %3 = f()
  void removeAllInputs() {
    for(size_t i = 0; i < inputs().size(); ++i)
      dropInput(i);
    inputs_.clear();
  }

  // Check whether this node is before node n in the graph.
  bool isBefore(Node* n);

  // iterators of the node list starting at this node
  // useful for resuming a search starting at this node
  graph_node_list_iterator iterator();
  graph_node_list_iterator reverseIterator();
  const_graph_node_list_iterator iterator() const;
  const_graph_node_list_iterator reverseIterator() const;

  // Remove 'this' from the instruction list and deallocate it.
  //
  // Invariant: no outputs of 'this' may have any uses.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %2.destroy()
  // Result: %3 = g(%1)
  void destroy();

  // Dynamically cast this node to the subclass indicated by the
  // template variable, returning nullptr if the cast is invalid.
  // The check compares T::Kind with kind(); no RTTI is used.
  //
  // Example usage: if(auto s = n.cast<Select>()) { ... }
  //
  // TODO: Make this const correct
  template<typename T>
  T* cast() {
    if(T::Kind == kind())
      return static_cast<T*>(this);
    return nullptr;
  }
  // Like cast<T>(), but asserts (with both kind names) instead of returning
  // nullptr on mismatch.
  template<typename T>
  T* expect() {
    ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
    return static_cast<T*>(this);
  }

  virtual ~Node() = default;

private:
  // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
  use_list::iterator findUseForInput(size_t i) {
    auto & input_uses = inputs_[i]->uses_in_current_graph_;
    // O(N) on the use list, but unless we get nodes with +100 uses
    // vector traversal still is probably faster than linked list
    auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
    ONNX_ASSERT(use_it != input_uses.end());
    return use_it;
  }

  // remove the use of input i, this sets input i to nullptr, but
  // is only used internally to Node before setting it to a new value
  // or erasing the entry from the list.
  Value* dropInput(size_t i) {
    ONNX_ASSERT(i < inputs_.size());
    auto input_node = inputs_[i];
    auto use_it = findUseForInput(i);
    input_node->uses_in_current_graph_.erase(use_it);
    inputs_[i] = nullptr;
    return input_node;
  }

  // A node is linked into the list iff next() is non-null; a null next()
  // with a non-null prev() would be an inconsistent state.
  bool inGraphList() const {
    ONNX_ASSERT(next() != nullptr || prev() == nullptr);
    return next() != nullptr;
  }
  // Unlinks 'this' from the node list, leaving next()/prev() null.
  void removeFromList() {
    ONNX_ASSERT(inGraphList());
    Node * next = this->next();
    Node * prev = this->prev();
    prev->next() = next;
    next->prev() = prev;
    this->next() = nullptr;
    this->prev() = nullptr;
  }

protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  virtual Node * allocNewInstance(Graph * g) {
    return new Node(g, kind());
  }
  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  //
  // NB: This does NOT clone stages.  You're expected to set the stage correctly
  // if you are going to preserve it.
  virtual void cloneFrom(Node * s) {
    copyAttributes(*s);
  }
};
798 
799 // A class with the same properties as OperatorSetIdProto, but without protobuf
800 // overhead, resulting in a simpler and more readable workflow.
801 class OpSetID final {
802   private:
803     std::string domain_;
804     int64_t version_;
805 
806   public:
OpSetID(const OperatorSetIdProto & proto)807     explicit OpSetID(const OperatorSetIdProto& proto)
808       :domain_(proto.domain()), version_(proto.version()) {}
809 
810     // Default Domain Constructor
OpSetID(const int64_t version)811     explicit OpSetID(const int64_t version)
812       :domain_(""), version_(version) {}
813 
OpSetID(const std::string & domain,int64_t version)814     explicit OpSetID(const std::string& domain, int64_t version)
815       :domain_(domain), version_(version) {}
816 
817     // target must be in the form "<domain>&<version>"
toString()818     std::string toString() const {
819       return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
820     }
821 
822     // target must be in the form "<domain>&<version>"
fromString(const std::string & target)823     static OpSetID fromString(const std::string& target) {
824       ONNX_TRY {
825         std::string new_domain = target.substr(0, target.find("$"));
826         int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
827         return OpSetID(std::move(new_domain), new_version);
828       } ONNX_CATCH (const std::runtime_error& e) {
829         ONNX_HANDLE_EXCEPTION([&]() {
830           ONNX_ASSERTM(false, "Error in fromString: %s", e.what());
831         });
832       }
833 
834       // The control will never reach here.
835       // In the default build where exceptions are turned on in case of any error
836       // the control will enter catch block where an exception will be thrown again.
837       // In case of "no exception build" the code aborts at the site of first exception.
838       // Adding this to appease the warning "control may reach end of non-void function"
839       // as the mac build fails when ONNX_WERROR==ON
840       return OpSetID("", 0);
841     }
842 
domain()843     const std::string& domain() const {
844       return domain_;
845     }
846 
version()847     int64_t version() const {
848       return version_;
849     }
850 
incrementVersion(int64_t step)851     void incrementVersion(int64_t step) {
852       version_ += step;
853     }
854 
setVersion(int64_t newVal)855     void setVersion(int64_t newVal) {
856       version_ = newVal;
857     }
858 };
859 
860 struct Graph final {
861 ONNX_DISALLOW_COPY_AND_ASSIGN(Graph);
862 friend struct Node;
863 friend struct Value;
864 
865 private:
866   // only used to keep track of allocated nodes
867   // actual representation of Graph is done with
868   // inputs, outputs, nodes
869 
870   std::unordered_set<const Node*> all_nodes;
871   std::unordered_set<const Value*> all_values;
872   size_t next_unique_;
873 
874   size_t new_node_stage_;
875 
876   // holds outputs in a way that can be reflected
877   // as a Use object
878   // also used as the beginning/end of the circular node list to avoid
879   // having corner cases where the list is empty.
880   Node * const output_;
881   Node * const input_;
882 
883   std::vector<Tensor> initializers_;
884   std::vector<std::string> initializer_names_;
885 
886   bool has_name_;
887   std::string name_;
888   bool has_doc_string_;
889   std::string doc_string_;
890 
891   std::vector <OpSetID> opset_versions_;
892 
isNameUniquefinal893   bool isNameUnique(const std::string& name) const {
894     if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) !=
895         initializer_names_.cend()) {
896       return false;
897     }
898     const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
899     for (const Node* node : all_nodes) {
900       for (const auto& attr : node->attributeNames()) {
901         if (node->kindOf(attr) == AttributeKind::g) {
902           const auto& subgraph = node->g(attr);
903           if (!subgraph->isNameUnique(name)) {
904             return false;
905           }
906         } else if (node->kindOf(attr) == AttributeKind::gs) {
907           for (const auto& subgraph : node->gs(attr)) {
908             if (!subgraph->isNameUnique(name)) {
909               return false;
910             }
911           }
912         }
913       }
914       const auto found_in =
915           std::find_if(node->inputs().begin(), node->inputs().end(), f);
916       if (found_in != node->inputs().end()) {
917         return false;
918       }
919       const auto found_out =
920           std::find_if(node->outputs().begin(), node->outputs().end(), f);
921       if (found_out != node->outputs().end()) {
922         return false;
923       }
924     }
925     return true;
926   }
927 
928 
929 public:
Graphfinal930   Graph()
931   : next_unique_(0)
932   , new_node_stage_(0)
933   , output_(initOutput(create(kReturn, 0)))
934   , input_(create(kParam, 0))
935   , has_name_(false)
936   , has_doc_string_(false) {}
937 
has_doc_stringfinal938   bool has_doc_string() const {
939     return has_doc_string_;
940   }
docStringfinal941   const std::string& docString() {
942     return doc_string_;
943   }
setDocStringfinal944   void setDocString(std::string doc_string) {
945     has_doc_string_ = true;
946     doc_string_ = std::move(doc_string);
947   }
948 
addInitializerfinal949   void addInitializer(Tensor initializer, std::string name) {
950     initializers_.push_back(std::move(initializer));
951     initializer_names_.push_back(std::move(name));
952   }
eraseInitializerfinal953   void eraseInitializer(const std::string &name) {
954     initializers_.erase(
955         std::remove_if(
956             initializers_.begin(),
957             initializers_.end(),
958             [&name](Tensor& initializer) {
959               return initializer.name() == name;
960             }),
961         initializers_.end());
962     initializer_names_.erase(
963         std::remove(
964             initializer_names_.begin(),
965             initializer_names_.end(),
966             name),
967         initializer_names_.end());
968   }
clearInitializersfinal969   void clearInitializers() {
970     initializers_.clear();
971     initializer_names_.clear();
972   }
initializersfinal973   const std::vector<Tensor>& initializers() const {
974     return initializers_;
975   }
initializer_namesfinal976   const std::vector<std::string>& initializer_names() const {
977     return initializer_names_;
978   }
getInitializerfinal979   std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const {
980     for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) {
981       if (name == it->name()) {
982         return it;
983       }
984     }
985     return initializers_.end();
986   }
inputsfinal987   ArrayRef<Value*> inputs() {
988     return input_->outputs();
989   }
inputsfinal990   ArrayRef<const Value*> inputs() const {
991     const auto & inputs = input_->outputs();
992     return {inputs.data(), inputs.size()};
993   }
outputsfinal994   ArrayRef<Value*> outputs() {
995     return output_->inputs();
996   }
outputsfinal997   ArrayRef<const Value*> outputs() const {
998     return static_cast<const Node*>(output_)->inputs();
999   }
nodesfinal1000   graph_node_list nodes() {
1001     return graph_node_list(output_, kNextDirection);
1002   }
nodesfinal1003   const_graph_node_list nodes() const {
1004     return const_graph_node_list(output_, kNextDirection);
1005   }
1006 
opset_versions_mutablefinal1007   std::vector<OpSetID>& opset_versions_mutable() {
1008     return opset_versions_;
1009   }
1010 
getNextUniquefinal1011   size_t getNextUnique() {
1012       std::string next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1013       while(!isNameUnique(next_unique_name)) {
1014           next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1015       }
1016       return next_unique_;
1017   }
1018 
1019   // These invocations of begin() on output of function are OK
1020   // because graph_node_list is non-owning, so it doesn't matter
1021   // if it immediately dies after the invocation.
beginfinal1022   graph_node_list_iterator begin() {
1023     return nodes().begin();
1024   }
beginfinal1025   const_graph_node_list_iterator begin() const {
1026     return nodes().begin();
1027   }
endfinal1028   graph_node_list_iterator end() {
1029     return nodes().end();
1030   }
endfinal1031   const_graph_node_list_iterator end() const {
1032     return nodes().end();
1033   }
rbeginfinal1034   graph_node_list_iterator rbegin() {
1035     return nodes().rbegin();
1036   }
rbeginfinal1037   const_graph_node_list_iterator rbegin() const {
1038     return nodes().rbegin();
1039   }
rendfinal1040   graph_node_list_iterator rend() {
1041     return nodes().rend();
1042   }
rendfinal1043   const_graph_node_list_iterator rend() const {
1044     return nodes().rend();
1045   }
return_nodefinal1046   Node * return_node() {
1047     return output_;
1048   }
return_nodefinal1049   const Node * return_node() const {
1050     return output_;
1051   }
1052 
addInputfinal1053   Value * addInput() {
1054     return input_->addOutput();
1055   }
eraseInputfinal1056   void eraseInput(size_t i) {
1057     input_->eraseOutput(i);
1058   }
advanceStagefinal1059   void advanceStage() {
1060     new_node_stage_++;
1061   }
setStagefinal1062   void setStage(size_t new_stage) {
1063     new_node_stage_ = new_stage;
1064   }
stagefinal1065   size_t stage() const {
1066     return new_node_stage_;
1067   }
setStageTemporaryfinal1068   ResourceGuard setStageTemporary(size_t s) {
1069     auto prev_stage = new_node_stage_;
1070     new_node_stage_ = s;
1071     return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; });
1072   }
1073 
registerOutputfinal1074   size_t registerOutput(Value * n) {
1075     output_->addInput(n);
1076     return outputs().size() - 1;
1077   }
1078 
1079   Node * create(NodeKind kind, size_t num_outputs=1) {
1080     // NB: Node constructor adds node to all_nodes
1081     auto n = new Node(this, kind);
1082     for(size_t i = 0; i < num_outputs; i++)
1083       n->addOutput();
1084     return n;
1085   }
1086 
1087   Node * create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs=1) {
1088     auto n = create(kind, num_outputs);
1089     for(auto i : inputs)
1090       n->addInput(i);
1091     return n;
1092   }
1093 
appendNodefinal1094   Node * appendNode(Node * n) {
1095     ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1096     n->insertBefore(output_);
1097     return n;
1098   }
1099 
prependNodefinal1100   Node * prependNode(Node * n) {
1101     ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1102     n->insertAfter(output_);
1103     return n;
1104   }
1105 
1106   //Adds to graph initializer list, initializer names list, and as a graph input
1107   //Also syncs the initializer name, tensor name, and value name
addInitializerAndInputfinal1108   Value* addInitializerAndInput(const Tensor& initializer, std::string name) {
1109     Tensor initializerCopy = initializer;
1110     std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(),
1111                                      initializerCopy.sizes().cend()};
1112     Value* new_init = addInput();
1113     initializerCopy.setName(name);
1114     new_init->setUniqueName(name);
1115     new_init->setSizes(dim_sizes);
1116     new_init->setElemType(initializerCopy.elem_type());
1117     addInitializer(std::move(initializerCopy), std::move(name));
1118     return new_init;
1119   }
1120 
addInitializerAndInputfinal1121   Value* addInitializerAndInput(const Tensor &initializer) {
1122     return addInitializerAndInput(initializer, ONNX_NAMESPACE::to_string(getNextUnique()));
1123   }
1124 
1125 
1126   //Erases from graph initializer list, initializer names list, and as a graph input
1127   //Must have no uses
eraseInitializerAndInputfinal1128   void eraseInitializerAndInput(Value* v) {
1129     eraseInitializer(v->uniqueName());
1130     eraseInput(v->offset());
1131   }
1132 
~Graphfinal1133   ~Graph() {
1134     for (const Node * n : all_nodes)
1135       delete n;
1136     for (const Value * v : all_values)
1137       delete v;
1138   }
1139 
toStringfinal1140   std::string toString() const {
1141     std::ostringstream oss;
1142     oss << *this;
1143     return oss.str();
1144   }
1145 
has_namefinal1146   bool has_name() const {
1147     return has_name_;
1148   }
1149 
namefinal1150   const std::string& name() const {
1151     return name_;
1152   }
1153 
setNamefinal1154   void setName(std::string name) {
1155     has_name_ = true;
1156     name_ = std::move(name);
1157   }
1158 
1159   friend std::ostream& operator<<(std::ostream & out, const Graph & g);
1160 
forSelfAndEachSubGraphfinal1161   void forSelfAndEachSubGraph(std::function<void(Graph*)> fn) {
1162     fn(this);
1163 
1164     for (const Node* node : all_nodes) {
1165       for (const auto& attr : node->attributeNames()) {
1166         if (node->kindOf(attr) == AttributeKind::g) {
1167           std::shared_ptr<Graph> subgraph = node->g(attr);
1168           subgraph->forSelfAndEachSubGraph(fn);
1169         } else if (node->kindOf(attr) == AttributeKind::gs) {
1170           for (const auto& subgraph : node->gs(attr)) {
1171             subgraph->forSelfAndEachSubGraph(fn);
1172           }
1173         }
1174       }
1175     }
1176   }
1177 
forSelfAndEachSubGraphfinal1178   void forSelfAndEachSubGraph(std::function<void(const Graph*)> fn) const {
1179     std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) {fn(graph);};
1180     const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn);
1181   }
1182 
forEachNodefinal1183   void forEachNode(std::function<void(Node*)> fn) {
1184     forSelfAndEachSubGraph([fn](Graph *graph) {
1185       for(Node* node : graph->nodes()) {
1186         fn(node);
1187       }
1188     });
1189   }
1190 
forEachNodefinal1191   void forEachNode(std::function<void(const Node*)> fn) const {
1192     std::function<void(Node*)> tmp_fn = [fn](Node* node) {fn(node);};
1193     const_cast<Graph*>(this)->forEachNode(tmp_fn);
1194   }
1195 
1196 private:
1197 
1198   // should only be called in the constructor
initOutputfinal1199   Node* initOutput(Node* p) {
1200     p->next() = p;
1201     p->prev() = p;
1202     p->setStage(std::numeric_limits<size_t>::max());
1203     return p;
1204   }
1205 
freeNodefinal1206   void freeNode(Node * n) {
1207     auto it = all_nodes.find(n);
1208     ONNX_ASSERT(it != all_nodes.end());
1209     delete *it;
1210     all_nodes.erase(it);
1211   }
freeValuefinal1212   void freeValue(Value * v) {
1213     auto it = all_values.find(v);
1214     ONNX_ASSERT(it != all_values.end());
1215     all_values.erase(it);
1216   }
1217 };
1218 
Value(Node * node_,size_t offset_)1219 inline Value::Value(Node *node_, size_t offset_)
1220     : node_(node_), offset_(offset_), unique_(node_->graph_->getNextUnique()),
1221       stage_(node_->graph_->new_node_stage_), has_unique_name_(false),
1222       elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
1223       has_sizes_(false) {
1224   node_->graph_->all_values.emplace(this);
1225 }
1226 
owningGraph()1227 inline Graph * Value::owningGraph() {
1228   return node()->owningGraph();
1229 }
1230 
owningGraph()1231 inline const Graph * Value::owningGraph() const {
1232   return node()->owningGraph();
1233 }
1234 
1235 // `captured` nodes in subgraph determines which value it captures
1236 // by storing the value's unique name, so old unique names in `captured` nodes
1237 // should also be updated.
setUniqueName(const std::string & name,bool rename_subgraph_captured_nodes)1238 inline Value* Value::setUniqueName(const std::string &name, bool rename_subgraph_captured_nodes) {
1239   if (has_unique_name() && rename_subgraph_captured_nodes) {
1240     auto *graph = owningGraph();
1241     graph->forEachNode([this, &name](Node *node) {
1242       if (node->owningGraph() == this->owningGraph()) {
1243         // skip non-subgraph
1244         return;
1245       }
1246       if (node->kind() == kCaptured) {
1247         Value *output = node->output();
1248         if (output->uniqueName() == this->uniqueName()) {
1249           output->setUniqueName(name, false);
1250         }
1251       }
1252     });
1253   }
1254   unique_name_ = name;
1255   has_unique_name_ = true;
1256   return this;
1257 }
1258 
replaceAllUsesWith(Value * newValue)1259 inline void Value::replaceAllUsesWith(Value * newValue) {
1260   auto* graph = owningGraph();
1261   ONNX_ASSERT(graph == newValue->owningGraph());
1262   // propagate sizes and elem type
1263   if (this->has_sizes()) {
1264     newValue->setSizes(this->sizes());
1265   }
1266   if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
1267     newValue->setElemType(this->elemType());
1268   }
1269   const auto unique_name = this->uniqueName();
1270   // We do not want the optimization to change the graph output name
1271   if (std::find(graph->outputs().rbegin(), graph->outputs().rend(),
1272                 this) != graph->outputs().rend()) {
1273     newValue->setUniqueName(unique_name);
1274     // The "unique" semantic of unique_name should be kept or uses()
1275     // will return an incorrect result when the value is used in subgraph
1276     this->setUniqueName(ONNX_NAMESPACE::to_string(graph->getNextUnique()), false);
1277   }
1278   newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
1279   for(auto u : uses_in_current_graph_) {
1280     u.user->inputs_[u.offset] = newValue;
1281     newValue->uses_in_current_graph_.push_back(u);
1282   }
1283   graph->forEachNode([this, &newValue, &unique_name](Node *node) {
1284     if (node->owningGraph() == this->owningGraph()) {
1285       // skip non-subgraph
1286       return;
1287     }
1288     if (node->kind() == kCaptured) {
1289       Value *output = node->output();
1290       if (output->uniqueName() == unique_name) {
1291         output->setUniqueName(newValue->uniqueName());
1292       }
1293     }
1294   });
1295   uses_in_current_graph_.clear();
1296   assert(this->uses().empty());
1297 }
1298 
Node(Graph * graph_,NodeKind kind_)1299 inline Node::Node(Graph * graph_, NodeKind kind_) :
1300   kind_(kind_),
1301   graph_(graph_),
1302   stage_(graph_->new_node_stage_),
1303   has_name_(false),
1304   has_domain_(false),
1305   has_doc_string_(false) {
1306   graph_->all_nodes.emplace(this);
1307 }
1308 
eraseOutput(size_t i)1309 inline void Node::eraseOutput(size_t i) {
1310   ONNX_ASSERT(i < outputs_.size());
1311   ONNX_ASSERT(outputs_[i]->uses().empty());
1312   Value * n = outputs_[i];
1313   outputs_.erase(outputs_.begin() + i);
1314   owningGraph()->freeValue(n);
1315   for(size_t j = i; j < outputs_.size(); j++) {
1316     outputs_[j]->offset_--;
1317   }
1318 }
1319 
isBefore(Node * n)1320 inline bool Node::isBefore(Node* n) {
1321   if (n == nullptr || this == n) {
1322     // Bail out early.
1323     return false;
1324   }
1325   // return true if node is Param (in initializers)
1326   if (kind_ == kParam) {
1327     return true;
1328   }
1329   // return false if target node is Param (in initializers)
1330   if (n->kind() == kParam) {
1331     return false;
1332   }
1333   ONNX_ASSERT(n->inGraphList());
1334   for (Node* p = next(); p != *graph_->end(); p = p->next()) {
1335     if (p == n) {
1336       return true;
1337     }
1338   }
1339   return false;
1340 }
1341 
destroy()1342 inline void Node::destroy() {
1343   ONNX_ASSERT(inGraphList());
1344   while(!outputs().empty())
1345     eraseOutput(outputs().size() - 1);
1346   removeAllInputs();
1347   removeFromList();
1348   graph_->freeNode(this);
1349 }
1350 
1351 /************* All nodes not required to be defined before Graph **************/
1352 
iterator()1353 inline graph_node_list_iterator Node::iterator() {
1354   return graph_node_list_iterator(this, 0);
1355 }
reverseIterator()1356 inline graph_node_list_iterator Node::reverseIterator() {
1357   return iterator().reverse();
1358 }
iterator()1359 inline const_graph_node_list_iterator Node::iterator() const {
1360   return const_graph_node_list_iterator(this, 0);
1361 }
reverseIterator()1362 inline const_graph_node_list_iterator Node::reverseIterator() const {
1363   return iterator().reverse();
1364 }
1365 
1366 // Returns a list about which nodes are using this value,
1367 // nodes in subgraph are also included.
1368 // This method is usually used to check whether it is
1369 // safe to delete a Value.
uses()1370 inline const use_list Value::uses() const {
1371   use_list all_uses = uses_in_current_graph_;
1372   owningGraph()->forEachNode([this, &all_uses](const Node* node) {
1373     if (node->owningGraph() == this->owningGraph()) {
1374       // skip non-subgraph
1375       return;
1376     }
1377     if (node->kind() == kCaptured) {
1378       const Value* output = node->outputs()[0];
1379       if (output->uniqueName() == this->uniqueName()) {
1380         const auto output_uses = output->uses();
1381         all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end());
1382       }
1383     }
1384   });
1385   return all_uses;
1386 }
1387 
1388 
1389 } // namespace ONNX_NAMESPACE
1390