1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "glow/Base/Type.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/VerifierHelper.h"
#include "glow/Support/Support.h"

#include <algorithm>
22 
23 using namespace glow;
24 
setPredicate(const NodeValue & P)25 void Node::setPredicate(const NodeValue &P) { predicate_ = P; }
26 
hasPredicate() const27 bool Node::hasPredicate() const { return predicate_.getNode(); }
28 
getType(unsigned idx) const29 TypeRef Node::getType(unsigned idx) const {
30   assert(idx < getNumResults() && "Result number does not exist.");
31   return types_[idx];
32 }
33 
setType(unsigned idx,TypeRef ty)34 void Node::setType(unsigned idx, TypeRef ty) {
35   assert(types_[idx]->dims() == ty->dims() &&
36          "Better create a new node at this point");
37   setTypeUnsafe(idx, ty);
38 }
39 
setTypeUnsafe(unsigned idx,TypeRef ty)40 void Node::setTypeUnsafe(unsigned idx, TypeRef ty) {
41   assert(idx < getNumResults() && "Result number does not exist.");
42   types_[idx] = ty;
43 }
44 
getElementType(unsigned resNo) const45 ElemKind Node::getElementType(unsigned resNo) const {
46   TypeRef TR = getType(resNo);
47   return TR->getElementType();
48 }
49 
dims(unsigned resNo) const50 llvm::ArrayRef<dim_t> Node::dims(unsigned resNo) const {
51   TypeRef TR = getType(resNo);
52   return TR->dims();
53 }
54 
addResult(TypeRef T)55 void Node::addResult(TypeRef T) { types_.push_back(T); }
56 
isEqual(const Node & other) const57 bool Node::isEqual(const Node &other) const {
58   if (this == &other)
59     return true;
60 
61   if (getKind() != other.getKind())
62     return false;
63 
64   switch (getKind()) {
65 #define DEF_NODE(CLASS, NAME)                                                  \
66   case glow::Kinded::Kind::CLASS##Kind:                                        \
67     return static_cast<const CLASS *>(this)->isEqual(                          \
68         *static_cast<const CLASS *>(&other));
69 #include "glow/AutoGenNodes.def"
70 
71 #define DEF_INSTR(CLASS, NAME) case glow::Kinded::Kind::CLASS##Kind:
72 #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) DEF_INSTR(CLASS, NAME)
73 #define DEF_VALUE(CLASS, NAME) DEF_INSTR(CLASS, NAME)
74 #include "glow/AutoGenInstr.def"
75 
76     llvm_unreachable(
77         "Not reachable, values and instructions are not handled here");
78   }
79   return false;
80 }
81 
getPredicate() const82 const NodeValue Node::getPredicate() const { return predicate_; }
83 
namespace {
/// Node visitor that computes the hash of a node by dispatching to the
/// concrete node class's getHash().
class HashNodeVisitor : public NodeVisitor<HashNodeVisitor, llvm::hash_code> {
  using hash_code = llvm::hash_code;
  using super = NodeVisitor;

public:
  // One visit<Class>() per node class listed in AutoGenNodes.def. Each one
  // forwards to that class's getHash().
#define DEF_NODE(CLASS, NAME)                                                  \
  hash_code visit##CLASS(const CLASS *N) const { return N->getHash(); }
#include "glow/AutoGenNodes.def"

  /// Const entry point. The base NodeVisitor::visit() is invoked on a
  /// non-const visitor with a non-const Node*, so both are const_cast here.
  /// The visit<Class>() methods above only read the node.
  hash_code visit(const Node *N) const {
    return const_cast<HashNodeVisitor *>(this)->super::visit(
        const_cast<Node *>(N));
  }
};

} // namespace
101 
getHash() const102 llvm::hash_code Node::getHash() const { return HashNodeVisitor().visit(this); }
103 
visit(Node * parent,NodeWalker * visitor)104 void Node::visit(Node *parent, NodeWalker *visitor) {
105   if (hasPredicate()) {
106     getPredicate().getNode()->visit(this, visitor);
107   }
108 
109   switch (getKind()) {
110 #define DEF_NODE(CLASS, NAME)                                                  \
111   case glow::Kinded::Kind::CLASS##Kind:                                        \
112     return static_cast<CLASS *>(this)->visit(parent, visitor);
113 #include "glow/AutoGenNodes.def"
114   default:
115     llvm_unreachable("Unhandled node");
116   }
117 }
118 
119 //===----------------------------------------------------------------------===//
120 //                     Debug description methods
121 //===----------------------------------------------------------------------===//
122 
getNumInputs() const123 unsigned Node::getNumInputs() const {
124   switch (getKind()) {
125 #define DEF_NODE(CLASS, NAME)                                                  \
126   case glow::Kinded::Kind::CLASS##Kind:                                        \
127     return static_cast<const CLASS *>(this)->getNumInputs();
128 #include "glow/AutoGenNodes.def"
129   default:
130     llvm_unreachable("Unhandled node");
131   }
132 }
133 
getInputName(unsigned idx) const134 std::string Node::getInputName(unsigned idx) const {
135   switch (getKind()) {
136 #define DEF_NODE(CLASS, NAME)                                                  \
137   case glow::Kinded::Kind::CLASS##Kind:                                        \
138     return static_cast<const CLASS *>(this)->getInputName(idx);
139 #include "glow/AutoGenNodes.def"
140   default:
141     llvm_unreachable("Unhandled node");
142   }
143 }
144 
getNthInput(unsigned idx)145 NodeValue Node::getNthInput(unsigned idx) {
146   switch (getKind()) {
147 #define DEF_NODE(CLASS, NAME)                                                  \
148   case glow::Kinded::Kind::CLASS##Kind:                                        \
149     return static_cast<CLASS *>(this)->getNthInput(idx);
150 #include "glow/AutoGenNodes.def"
151   default:
152     llvm_unreachable("Unhandled node");
153   }
154 }
155 
getNthInput(unsigned idx) const156 const NodeValue Node::getNthInput(unsigned idx) const {
157   switch (getKind()) {
158 #define DEF_NODE(CLASS, NAME)                                                  \
159   case glow::Kinded::Kind::CLASS##Kind:                                        \
160     return static_cast<CLASS *>(const_cast<Node *>(this))->getNthInput(idx);
161 #include "glow/AutoGenNodes.def"
162   default:
163     llvm_unreachable("Unhandled node");
164   }
165 }
166 
setNthInput(unsigned idx,NodeValue val)167 void Node::setNthInput(unsigned idx, NodeValue val) {
168   switch (getKind()) {
169 #define DEF_NODE(CLASS, NAME)                                                  \
170   case glow::Kinded::Kind::CLASS##Kind:                                        \
171     if (getParent()) {                                                         \
172       getParent()->getLogContext()->logNodeInputChange(                        \
173           *this, this->getNthInput(idx), val);                                 \
174     }                                                                          \
175     return static_cast<CLASS *>(this)->setNthInput(idx, val);
176 #include "glow/AutoGenNodes.def"
177   default:
178     llvm_unreachable("Unhandled node");
179   }
180 }
181 
getNthResult(unsigned idx)182 NodeValue Node::getNthResult(unsigned idx) {
183   assert(idx < getNumResults());
184   return NodeValue(this, idx);
185 }
186 
getNthResult(unsigned idx) const187 const NodeValue Node::getNthResult(unsigned idx) const {
188   assert(idx < getNumResults());
189   return NodeValue(const_cast<Node *>(this), idx);
190 }
191 
getOutputName(unsigned idx) const192 llvm::StringRef Node::getOutputName(unsigned idx) const {
193   switch (getKind()) {
194 #define DEF_NODE(CLASS, NAME)                                                  \
195   case glow::Kinded::Kind::CLASS##Kind:                                        \
196     return static_cast<const CLASS *>(this)->getOutputName(idx);
197 #include "glow/AutoGenNodes.def"
198   default:
199     llvm_unreachable("Unhandled node");
200   }
201 }
202 
hasSideEffects() const203 bool Node::hasSideEffects() const {
204   switch (getKind()) {
205 #define DEF_NODE(CLASS, NAME)                                                  \
206   case glow::Kinded::Kind::CLASS##Kind:                                        \
207     return static_cast<const CLASS *>(this)->hasSideEffects();
208 #include "glow/AutoGenNodes.def"
209   default:
210     llvm_unreachable("Unhandled node");
211   }
212 }
213 
isCanonical() const214 bool Node::isCanonical() const {
215   switch (getKind()) {
216 #define DEF_NODE(CLASS, NAME)                                                  \
217   case glow::Kinded::Kind::CLASS##Kind:                                        \
218     return static_cast<const CLASS *>(this)->isCanonical();
219 #include "glow/AutoGenNodes.def"
220   default:
221     llvm_unreachable("Unhandled node");
222   }
223 }
224 
isDataParallel() const225 bool Node::isDataParallel() const {
226   switch (getKind()) {
227 #define DEF_NODE(CLASS, NAME)                                                  \
228   case glow::Kinded::Kind::CLASS##Kind:                                        \
229     return static_cast<const CLASS *>(this)->isDataParallel();
230 #include "glow/AutoGenNodes.def"
231   default:
232     llvm_unreachable("Unhandled node");
233   }
234 }
235 
/// \returns true if this node's kind is one of the kinds listed in the switch
/// below: Add, Mul, Sub, Div, Max, Min, CmpLTE, CmpLT, CmpEQ, Pow.
/// NOTE: Callers rely on these nodes having LHS as their 1st input, RHS as
/// their 2nd input, and Result as their 1st result. The static_asserts below
/// check at compile time that each listed node's LHSIdx/RHSIdx/ResultIdx match
/// ArithmeticNode's.
bool Node::isArithmetic() const {
  // Each case includes a static assert that the generated nodes that we
  // consider arithmetic have the expected format/order of LHS, RHS, Result.
#define ARITHMETIC_NODE_CASE(NODE_NAME_)                                       \
  static_assert((NODE_NAME_##Node::LHSIdx == ArithmeticNode::LHSIdx &&         \
                 NODE_NAME_##Node::RHSIdx == ArithmeticNode::RHSIdx &&         \
                 NODE_NAME_##Node::ResultIdx == ArithmeticNode::ResultIdx),    \
                #NODE_NAME_                                                    \
                "Node does not match expected arithmetic node format.");       \
  case glow::Kinded::Kind::NODE_NAME_##NodeKind:

  // All listed cases fall through to the single `return true`.
  switch (getKind()) {
    ARITHMETIC_NODE_CASE(Add)
    ARITHMETIC_NODE_CASE(Mul)
    ARITHMETIC_NODE_CASE(Sub)
    ARITHMETIC_NODE_CASE(Div)
    ARITHMETIC_NODE_CASE(Max)
    ARITHMETIC_NODE_CASE(Min)
    ARITHMETIC_NODE_CASE(CmpLTE)
    ARITHMETIC_NODE_CASE(CmpLT)
    ARITHMETIC_NODE_CASE(CmpEQ)
    ARITHMETIC_NODE_CASE(Pow)
    return true;
  default:
    return false;
  }
#undef ARITHMETIC_NODE_CASE
}
266 
isOverwrittenNthInput(unsigned idx) const267 bool Node::isOverwrittenNthInput(unsigned idx) const {
268   switch (getKind()) {
269 #define DEF_NODE(CLASS, NAME)                                                  \
270   case glow::Kinded::Kind::CLASS##Kind:                                        \
271     return static_cast<const CLASS *>(this)->isOverwrittenNthInput(idx);
272 #include "glow/AutoGenNodes.def"
273   default:
274     llvm_unreachable("Unhandled node");
275   }
276 }
277 
getDebugDesc() const278 std::string Node::getDebugDesc() const {
279   switch (getKind()) {
280 #define DEF_NODE(CLASS, NAME)                                                  \
281   case glow::Kinded::Kind::CLASS##Kind:                                        \
282     return static_cast<const CLASS *>(this)->getDebugDesc();
283 #include "glow/AutoGenNodes.def"
284   default:
285     llvm_unreachable("Unhandled node");
286   }
287 }
288 
dump(llvm::raw_ostream & out) const289 void Node::dump(llvm::raw_ostream &out) const { out << this->getDebugDesc(); }
290 
dump() const291 void Node::dump() const { dump(llvm::outs()); }
292 
toString() const293 std::string Node::toString() const { return this->getDebugDesc(); }
294 
getTotMemSize() const295 size_t Node::getTotMemSize() const {
296   size_t totMemSize = 0;
297   for (unsigned idx = 0, e = getNumInputs(); idx < e; idx++) {
298     totMemSize += getNthInput(idx).getType()->getSizeInBytes();
299   }
300   for (unsigned idx = 0, e = getNumResults(); idx < e; idx++) {
301     totMemSize += getNthResult(idx).getType()->getSizeInBytes();
302   }
303   return totMemSize;
304 }
305 
clone() const306 Node *Node::clone() const {
307   switch (getKind()) {
308 #define DEF_NODE(CLASS, NAME)                                                  \
309   case glow::Kinded::Kind::CLASS##Kind:                                        \
310     return static_cast<const CLASS *>(this)->clone();
311 #include "glow/AutoGenNodes.def"
312   default:
313     llvm_unreachable("Unhandled node");
314   }
315 }
316 
destroyNode(Node * N)317 void Node::destroyNode(Node *N) {
318   switch (N->getKind()) {
319 #define DEF_NODE(CLASS, NAME)                                                  \
320   case glow::Kinded::Kind::CLASS##Kind: {                                      \
321     delete static_cast<CLASS *>(N);                                            \
322     break;                                                                     \
323   }
324 #include "glow/AutoGenNodes.def"
325   default:
326     llvm_unreachable("Unhandled node");
327   }
328 }
329 
330 namespace glow {
331 
operator <<(llvm::raw_ostream & os,const Node & node)332 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node &node) {
333   node.dump(os);
334   return os;
335 }
336 
operator <<(llvm::raw_ostream & os,const Node * node)337 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node *node) {
338   assert(node != nullptr && "Null Pointer.");
339   node->dump(os);
340   return os;
341 }
342 } // namespace glow
343 
344 //===----------------------------------------------------------------------===//
345 //                       Nodes verification
346 //===----------------------------------------------------------------------===//
347 
verify() const348 bool Node::verify() const {
349   // Verify the shared members of the node.
350   bool isValid = true;
351 
352   // Verify the predicate field.
353   if (hasPredicate()) {
354     auto pred = getPredicate();
355     if (!expectCompareTrue("Invalid predicate", bool(pred.getNode()), true,
356                            this)) {
357       // The following code assumes pred is valid.
358       return false;
359     }
360     auto Ty = pred.getType();
361     isValid &= expectCompareTrue("Predicate must be a vector",
362                                  Ty->dims().size(), size_t(1), this);
363   }
364 
365   if (getParent()) {
366     isValid &=
367         expectCompareTrue("Node not present in its parent",
368                           std::find(getParent()->getNodes().begin(),
369                                     getParent()->getNodes().end(),
370                                     *this) != getParent()->getNodes().end(),
371                           true, this);
372   }
373 
374   // Verify node-specific properties:
375   switch (getKind()) {
376 #define DEF_NODE(CLASS, NAME)                                                  \
377   case glow::Kinded::Kind::CLASS##Kind:                                        \
378     isValid &= static_cast<const CLASS *>(this)->verify();                     \
379     break;
380 #include "glow/AutoGenNodes.def"
381   default:
382     llvm_unreachable("Unhandled node");
383   }
384   return isValid;
385 }
386 
387 //===----------------------------------------------------------------------===//
388 // ilist_traits<glow::Node> Implementation
389 //===----------------------------------------------------------------------===//
390 
/// \returns the Function that owns this traits object. The traits object is a
/// base of the Function's iplist<Node> member (the one named by
/// Function::getNodesMemberPtr()). The Function's address is therefore
/// recovered by subtracting that member's byte offset from the list's address.
/// NOTE(review): the offset is computed by applying ->* to a null Function
/// pointer, an offsetof-style trick. This is technically undefined behavior in
/// ISO C++, although mainstream compilers accept it.
Function *llvm::ilist_traits<Node>::getContainingFunction() {
  size_t Offset(size_t(&((Function *)nullptr->*Function::getNodesMemberPtr())));
  iplist<Node> *Anchor(static_cast<iplist<Node> *>(this));
  return reinterpret_cast<Function *>(reinterpret_cast<char *>(Anchor) -
                                      Offset);
}
399 
addNodeToList(Node * node)400 void llvm::ilist_traits<Node>::addNodeToList(Node *node) {
401   assert(node->getParent() == nullptr && "Already in a list!");
402   node->setParent(getContainingFunction());
403 }
404 
removeNodeFromList(Node * node)405 void llvm::ilist_traits<Node>::removeNodeFromList(Node *node) {
406   // When an instruction is removed from a function, clear the parent pointer.
407   assert(node->getParent() && "Not in a list!");
408   node->setParent(nullptr);
409 }
410 
transferNodesFromList(llvm::ilist_traits<Node> & L2,node_iterator first,node_iterator last)411 void llvm::ilist_traits<Node>::transferNodesFromList(
412     llvm::ilist_traits<Node> &L2, node_iterator first, node_iterator last) {
413   // If transferring nodes within the same Function, no reason to
414   // update their parent pointers.
415   Function *ThisParent = getContainingFunction();
416   if (ThisParent == L2.getContainingFunction())
417     return;
418 
419   // Update the parent fields in the nodes.
420   for (; first != last; ++first)
421     first->setParent(ThisParent);
422 }
423