1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "glow/Base/Type.h"
18 #include "glow/Graph/Graph.h"
19 #include "glow/Graph/Nodes.h"
20 #include "glow/Graph/VerifierHelper.h"
21 #include "glow/Support/Support.h"
22
23 using namespace glow;
24
setPredicate(const NodeValue & P)25 void Node::setPredicate(const NodeValue &P) { predicate_ = P; }
26
hasPredicate() const27 bool Node::hasPredicate() const { return predicate_.getNode(); }
28
getType(unsigned idx) const29 TypeRef Node::getType(unsigned idx) const {
30 assert(idx < getNumResults() && "Result number does not exist.");
31 return types_[idx];
32 }
33
setType(unsigned idx,TypeRef ty)34 void Node::setType(unsigned idx, TypeRef ty) {
35 assert(types_[idx]->dims() == ty->dims() &&
36 "Better create a new node at this point");
37 setTypeUnsafe(idx, ty);
38 }
39
setTypeUnsafe(unsigned idx,TypeRef ty)40 void Node::setTypeUnsafe(unsigned idx, TypeRef ty) {
41 assert(idx < getNumResults() && "Result number does not exist.");
42 types_[idx] = ty;
43 }
44
getElementType(unsigned resNo) const45 ElemKind Node::getElementType(unsigned resNo) const {
46 TypeRef TR = getType(resNo);
47 return TR->getElementType();
48 }
49
dims(unsigned resNo) const50 llvm::ArrayRef<dim_t> Node::dims(unsigned resNo) const {
51 TypeRef TR = getType(resNo);
52 return TR->dims();
53 }
54
addResult(TypeRef T)55 void Node::addResult(TypeRef T) { types_.push_back(T); }
56
isEqual(const Node & other) const57 bool Node::isEqual(const Node &other) const {
58 if (this == &other)
59 return true;
60
61 if (getKind() != other.getKind())
62 return false;
63
64 switch (getKind()) {
65 #define DEF_NODE(CLASS, NAME) \
66 case glow::Kinded::Kind::CLASS##Kind: \
67 return static_cast<const CLASS *>(this)->isEqual( \
68 *static_cast<const CLASS *>(&other));
69 #include "glow/AutoGenNodes.def"
70
71 #define DEF_INSTR(CLASS, NAME) case glow::Kinded::Kind::CLASS##Kind:
72 #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) DEF_INSTR(CLASS, NAME)
73 #define DEF_VALUE(CLASS, NAME) DEF_INSTR(CLASS, NAME)
74 #include "glow/AutoGenInstr.def"
75
76 llvm_unreachable(
77 "Not reachable, values and instructions are not handled here");
78 }
79 return false;
80 }
81
getPredicate() const82 const NodeValue Node::getPredicate() const { return predicate_; }
83
84 namespace {
85 class HashNodeVisitor : public NodeVisitor<HashNodeVisitor, llvm::hash_code> {
86 using hash_code = llvm::hash_code;
87 using super = NodeVisitor;
88
89 public:
90 #define DEF_NODE(CLASS, NAME) \
91 hash_code visit##CLASS(const CLASS *N) const { return N->getHash(); }
92 #include "glow/AutoGenNodes.def"
93
visit(const Node * N) const94 hash_code visit(const Node *N) const {
95 return const_cast<HashNodeVisitor *>(this)->super::visit(
96 const_cast<Node *>(N));
97 }
98 };
99
100 } // namespace
101
getHash() const102 llvm::hash_code Node::getHash() const { return HashNodeVisitor().visit(this); }
103
visit(Node * parent,NodeWalker * visitor)104 void Node::visit(Node *parent, NodeWalker *visitor) {
105 if (hasPredicate()) {
106 getPredicate().getNode()->visit(this, visitor);
107 }
108
109 switch (getKind()) {
110 #define DEF_NODE(CLASS, NAME) \
111 case glow::Kinded::Kind::CLASS##Kind: \
112 return static_cast<CLASS *>(this)->visit(parent, visitor);
113 #include "glow/AutoGenNodes.def"
114 default:
115 llvm_unreachable("Unhandled node");
116 }
117 }
118
119 //===----------------------------------------------------------------------===//
120 // Debug description methods
121 //===----------------------------------------------------------------------===//
122
getNumInputs() const123 unsigned Node::getNumInputs() const {
124 switch (getKind()) {
125 #define DEF_NODE(CLASS, NAME) \
126 case glow::Kinded::Kind::CLASS##Kind: \
127 return static_cast<const CLASS *>(this)->getNumInputs();
128 #include "glow/AutoGenNodes.def"
129 default:
130 llvm_unreachable("Unhandled node");
131 }
132 }
133
getInputName(unsigned idx) const134 std::string Node::getInputName(unsigned idx) const {
135 switch (getKind()) {
136 #define DEF_NODE(CLASS, NAME) \
137 case glow::Kinded::Kind::CLASS##Kind: \
138 return static_cast<const CLASS *>(this)->getInputName(idx);
139 #include "glow/AutoGenNodes.def"
140 default:
141 llvm_unreachable("Unhandled node");
142 }
143 }
144
getNthInput(unsigned idx)145 NodeValue Node::getNthInput(unsigned idx) {
146 switch (getKind()) {
147 #define DEF_NODE(CLASS, NAME) \
148 case glow::Kinded::Kind::CLASS##Kind: \
149 return static_cast<CLASS *>(this)->getNthInput(idx);
150 #include "glow/AutoGenNodes.def"
151 default:
152 llvm_unreachable("Unhandled node");
153 }
154 }
155
getNthInput(unsigned idx) const156 const NodeValue Node::getNthInput(unsigned idx) const {
157 switch (getKind()) {
158 #define DEF_NODE(CLASS, NAME) \
159 case glow::Kinded::Kind::CLASS##Kind: \
160 return static_cast<CLASS *>(const_cast<Node *>(this))->getNthInput(idx);
161 #include "glow/AutoGenNodes.def"
162 default:
163 llvm_unreachable("Unhandled node");
164 }
165 }
166
setNthInput(unsigned idx,NodeValue val)167 void Node::setNthInput(unsigned idx, NodeValue val) {
168 switch (getKind()) {
169 #define DEF_NODE(CLASS, NAME) \
170 case glow::Kinded::Kind::CLASS##Kind: \
171 if (getParent()) { \
172 getParent()->getLogContext()->logNodeInputChange( \
173 *this, this->getNthInput(idx), val); \
174 } \
175 return static_cast<CLASS *>(this)->setNthInput(idx, val);
176 #include "glow/AutoGenNodes.def"
177 default:
178 llvm_unreachable("Unhandled node");
179 }
180 }
181
getNthResult(unsigned idx)182 NodeValue Node::getNthResult(unsigned idx) {
183 assert(idx < getNumResults());
184 return NodeValue(this, idx);
185 }
186
getNthResult(unsigned idx) const187 const NodeValue Node::getNthResult(unsigned idx) const {
188 assert(idx < getNumResults());
189 return NodeValue(const_cast<Node *>(this), idx);
190 }
191
getOutputName(unsigned idx) const192 llvm::StringRef Node::getOutputName(unsigned idx) const {
193 switch (getKind()) {
194 #define DEF_NODE(CLASS, NAME) \
195 case glow::Kinded::Kind::CLASS##Kind: \
196 return static_cast<const CLASS *>(this)->getOutputName(idx);
197 #include "glow/AutoGenNodes.def"
198 default:
199 llvm_unreachable("Unhandled node");
200 }
201 }
202
hasSideEffects() const203 bool Node::hasSideEffects() const {
204 switch (getKind()) {
205 #define DEF_NODE(CLASS, NAME) \
206 case glow::Kinded::Kind::CLASS##Kind: \
207 return static_cast<const CLASS *>(this)->hasSideEffects();
208 #include "glow/AutoGenNodes.def"
209 default:
210 llvm_unreachable("Unhandled node");
211 }
212 }
213
isCanonical() const214 bool Node::isCanonical() const {
215 switch (getKind()) {
216 #define DEF_NODE(CLASS, NAME) \
217 case glow::Kinded::Kind::CLASS##Kind: \
218 return static_cast<const CLASS *>(this)->isCanonical();
219 #include "glow/AutoGenNodes.def"
220 default:
221 llvm_unreachable("Unhandled node");
222 }
223 }
224
isDataParallel() const225 bool Node::isDataParallel() const {
226 switch (getKind()) {
227 #define DEF_NODE(CLASS, NAME) \
228 case glow::Kinded::Kind::CLASS##Kind: \
229 return static_cast<const CLASS *>(this)->isDataParallel();
230 #include "glow/AutoGenNodes.def"
231 default:
232 llvm_unreachable("Unhandled node");
233 }
234 }
235
// Returns true if this node's kind is one of the binary kinds listed in the
// switch below (Add, Mul, Sub, Div, Max, Min, CmpLTE, CmpLT, CmpEQ, Pow).
// NOTE: This is used in conjunction with assuming the 1st input is LHS, and 2nd
// input is RHS, and 1st result is Result.
bool Node::isArithmetic() const {
  // Each case includes a static assert that the generated nodes that we
  // consider arithmetic have the expected format/order of LHS, RHS, Result.
  // The static_asserts are compile-time only, so placing them between case
  // labels does not affect the fall-through to `return true`. The kind name
  // is stringified and concatenated with the message, so a failure reads
  // e.g. "AddNode does not match expected arithmetic node format."
#define ARITHMETIC_NODE_CASE(NODE_NAME_) \
  static_assert((NODE_NAME_##Node::LHSIdx == ArithmeticNode::LHSIdx && \
                 NODE_NAME_##Node::RHSIdx == ArithmeticNode::RHSIdx && \
                 NODE_NAME_##Node::ResultIdx == ArithmeticNode::ResultIdx), \
                #NODE_NAME_ \
                "Node does not match expected arithmetic node format."); \
  case glow::Kinded::Kind::NODE_NAME_##NodeKind:

  switch (getKind()) {
    ARITHMETIC_NODE_CASE(Add)
    ARITHMETIC_NODE_CASE(Mul)
    ARITHMETIC_NODE_CASE(Sub)
    ARITHMETIC_NODE_CASE(Div)
    ARITHMETIC_NODE_CASE(Max)
    ARITHMETIC_NODE_CASE(Min)
    ARITHMETIC_NODE_CASE(CmpLTE)
    ARITHMETIC_NODE_CASE(CmpLT)
    ARITHMETIC_NODE_CASE(CmpEQ)
    ARITHMETIC_NODE_CASE(Pow)
    return true;
  default:
    return false;
  }
#undef ARITHMETIC_NODE_CASE
}
266
isOverwrittenNthInput(unsigned idx) const267 bool Node::isOverwrittenNthInput(unsigned idx) const {
268 switch (getKind()) {
269 #define DEF_NODE(CLASS, NAME) \
270 case glow::Kinded::Kind::CLASS##Kind: \
271 return static_cast<const CLASS *>(this)->isOverwrittenNthInput(idx);
272 #include "glow/AutoGenNodes.def"
273 default:
274 llvm_unreachable("Unhandled node");
275 }
276 }
277
getDebugDesc() const278 std::string Node::getDebugDesc() const {
279 switch (getKind()) {
280 #define DEF_NODE(CLASS, NAME) \
281 case glow::Kinded::Kind::CLASS##Kind: \
282 return static_cast<const CLASS *>(this)->getDebugDesc();
283 #include "glow/AutoGenNodes.def"
284 default:
285 llvm_unreachable("Unhandled node");
286 }
287 }
288
dump(llvm::raw_ostream & out) const289 void Node::dump(llvm::raw_ostream &out) const { out << this->getDebugDesc(); }
290
dump() const291 void Node::dump() const { dump(llvm::outs()); }
292
toString() const293 std::string Node::toString() const { return this->getDebugDesc(); }
294
getTotMemSize() const295 size_t Node::getTotMemSize() const {
296 size_t totMemSize = 0;
297 for (unsigned idx = 0, e = getNumInputs(); idx < e; idx++) {
298 totMemSize += getNthInput(idx).getType()->getSizeInBytes();
299 }
300 for (unsigned idx = 0, e = getNumResults(); idx < e; idx++) {
301 totMemSize += getNthResult(idx).getType()->getSizeInBytes();
302 }
303 return totMemSize;
304 }
305
clone() const306 Node *Node::clone() const {
307 switch (getKind()) {
308 #define DEF_NODE(CLASS, NAME) \
309 case glow::Kinded::Kind::CLASS##Kind: \
310 return static_cast<const CLASS *>(this)->clone();
311 #include "glow/AutoGenNodes.def"
312 default:
313 llvm_unreachable("Unhandled node");
314 }
315 }
316
destroyNode(Node * N)317 void Node::destroyNode(Node *N) {
318 switch (N->getKind()) {
319 #define DEF_NODE(CLASS, NAME) \
320 case glow::Kinded::Kind::CLASS##Kind: { \
321 delete static_cast<CLASS *>(N); \
322 break; \
323 }
324 #include "glow/AutoGenNodes.def"
325 default:
326 llvm_unreachable("Unhandled node");
327 }
328 }
329
330 namespace glow {
331
operator <<(llvm::raw_ostream & os,const Node & node)332 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node &node) {
333 node.dump(os);
334 return os;
335 }
336
operator <<(llvm::raw_ostream & os,const Node * node)337 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Node *node) {
338 assert(node != nullptr && "Null Pointer.");
339 node->dump(os);
340 return os;
341 }
342 } // namespace glow
343
344 //===----------------------------------------------------------------------===//
345 // Nodes verification
346 //===----------------------------------------------------------------------===//
347
/// Verifies the invariants shared by all nodes (a rank-1 predicate if one is
/// set, and membership in the parent Function's node list), then dispatches
/// to the concrete node class's verify().
/// \returns true if every check passed.
bool Node::verify() const {
  // Verify the shared members of the node.
  bool isValid = true;

  // Verify the predicate field.
  if (hasPredicate()) {
    auto pred = getPredicate();
    // NOTE(review): hasPredicate() is itself `predicate_.getNode()`, so this
    // check cannot fail here; it only acts as a defensive guard.
    if (!expectCompareTrue("Invalid predicate", bool(pred.getNode()), true,
                           this)) {
      // The following code assumes pred is valid.
      return false;
    }
    auto Ty = pred.getType();
    // The predicate is required to have exactly one dimension.
    isValid &= expectCompareTrue("Predicate must be a vector",
                                 Ty->dims().size(), size_t(1), this);
  }

  // A node that has a parent must appear in that parent's node list.
  // NOTE(review): this is a linear scan comparing list elements against
  // `*this` using Node's equality operator (defined outside this file). If
  // verify() is run on every node of a Function, the total cost is quadratic
  // in the number of nodes.
  if (getParent()) {
    isValid &=
        expectCompareTrue("Node not present in its parent",
                          std::find(getParent()->getNodes().begin(),
                                    getParent()->getNodes().end(),
                                    *this) != getParent()->getNodes().end(),
                          true, this);
  }

  // Verify node-specific properties:
  switch (getKind()) {
#define DEF_NODE(CLASS, NAME) \
  case glow::Kinded::Kind::CLASS##Kind: \
    isValid &= static_cast<const CLASS *>(this)->verify(); \
    break;
#include "glow/AutoGenNodes.def"
  default:
    llvm_unreachable("Unhandled node");
  }
  return isValid;
}
386
387 //===----------------------------------------------------------------------===//
388 // ilist_traits<glow::Node> Implementation
389 //===----------------------------------------------------------------------===//
390
// The trait object is embedded into a Function. Use dirty hacks to
// reconstruct the Function from the 'self' pointer of the trait.
//
// How it works: Function::getNodesMemberPtr() yields a pointer-to-member for
// the Function's node list. Applying it to a null Function* and taking the
// address yields that member's byte offset within Function (an
// offsetof-style computation). Subtracting the offset from the address of the
// list (this traits object, viewed as its iplist<Node>) recovers the owning
// Function.
// NOTE(review): forming a member address through a null pointer is
// technically undefined behavior in C++. This code also assumes that this
// traits object is always the node list embedded in a Function.
Function *llvm::ilist_traits<Node>::getContainingFunction() {
  size_t Offset(size_t(&((Function *)nullptr->*Function::getNodesMemberPtr())));
  iplist<Node> *Anchor(static_cast<iplist<Node> *>(this));
  return reinterpret_cast<Function *>(reinterpret_cast<char *>(Anchor) -
                                      Offset);
}
399
addNodeToList(Node * node)400 void llvm::ilist_traits<Node>::addNodeToList(Node *node) {
401 assert(node->getParent() == nullptr && "Already in a list!");
402 node->setParent(getContainingFunction());
403 }
404
removeNodeFromList(Node * node)405 void llvm::ilist_traits<Node>::removeNodeFromList(Node *node) {
406 // When an instruction is removed from a function, clear the parent pointer.
407 assert(node->getParent() && "Not in a list!");
408 node->setParent(nullptr);
409 }
410
transferNodesFromList(llvm::ilist_traits<Node> & L2,node_iterator first,node_iterator last)411 void llvm::ilist_traits<Node>::transferNodesFromList(
412 llvm::ilist_traits<Node> &L2, node_iterator first, node_iterator last) {
413 // If transferring nodes within the same Function, no reason to
414 // update their parent pointers.
415 Function *ThisParent = getContainingFunction();
416 if (ThisParent == L2.getContainingFunction())
417 return;
418
419 // Update the parent fields in the nodes.
420 for (; first != last; ++first)
421 first->setParent(ThisParent);
422 }
423