/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph.cc
 * \brief Utilities to get information about the schedule graph.
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include "graph.h"

namespace tvm {
namespace schedule {
// Key to a specific dimension of a tensor (an output of an operation).
struct TensorDimKey {
  ir::FunctionRef f;
  int value_index;
  int dim;
  TensorDimKey() {}
  TensorDimKey(const ir::Call* op, int dim)
      : f(op->func), value_index(op->value_index), dim(dim) {
  }
  TensorDimKey(const Tensor& t, int dim)
      : f(t->op), value_index(t->value_index), dim(dim) {
  }
  TensorDimKey(const Tensor& t, size_t dim)
      : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
  }
  inline bool operator==(const TensorDimKey& other) const {
    return f == other.f &&
        value_index == other.value_index &&
        dim == other.dim;
  }
  inline bool operator!=(const TensorDimKey& other) const {
    return !operator==(other);
  }
};
}  // namespace schedule
}  // namespace tvm

namespace std {
template <>
struct hash<::tvm::schedule::TensorDimKey> {
  std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
    // Combine the hash of the function reference with (value_index, dim),
    // using the boost::hash_combine mixing constant, so TensorDimKey can be
    // used as a key in the unordered containers below.
    size_t lhs = ::tvm::NodeHash()(k.f);
    size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
        static_cast<size_t>(k.dim);
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
  }
};
}  // namespace std


namespace tvm {
namespace schedule {

// Construct a read graph that maps each operation the roots depend on
// to the list of tensors it reads.
ReadGraph CreateReadGraph(const Array<Operation>& roots) {
  ReadGraph rmap;
  std::vector<Operation> stack;
  std::unordered_set<const Node*> visited;
  // initialize the roots
  for (Operation op : roots) {
    stack.push_back(op);
    visited.insert(op.get());
  }

  while (!stack.empty()) {
    Operation op = stack.back();
    stack.pop_back();
    Array<Tensor> deps = op->InputTensors();
    rmap.Set(op, deps);
    for (Tensor t : deps) {
      if (t->op.defined() && visited.count(t->op.get()) == 0) {
        visited.insert(t->op.get());
        stack.push_back(t->op);
      }
    }
  }
  return rmap;
}
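
// Illustrative usage sketch (not part of this file; `outs` is a hypothetical
// list of output tensors of some computation):
//
//   Array<Operation> roots;
//   for (Tensor t : outs) roots.push_back(t->op);
//   ReadGraph rmap = CreateReadGraph(roots);
//   // rmap now maps every operation reachable from the roots to the
//   // tensors it reads.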

// Do a DFS visit to get the subgraph between the boundary and op.
// Returns whether op is inside the subgraph, i.e. whether it can reach the boundary.
bool GetSubGraphByPostDFS_(
    const Operation& op,
    const std::unordered_set<const Node*>& boundary,
    bool include_boundary,
    std::unordered_map<const Node*, bool>* visited,
    Array<Operation>* result) {
  if (visited->count(op.get())) {
    return visited->at(op.get());
  }
  if (boundary.count(op.get())) {
    (*visited)[op.get()] = true;
    if (include_boundary) {
      result->push_back(op);
    }
    return true;
  }
  // Mark as visited (not reaching the boundary yet) to avoid looping.
  // Not necessary for a DAG.
  (*visited)[op.get()] = false;
  // Check whether we can reach the boundary through any input.
  bool reach_boundary = false;
  for (Tensor t : op->InputTensors()) {
    if (GetSubGraphByPostDFS_(t->op, boundary,
                              include_boundary,
                              visited, result)) {
      reach_boundary = true;
    }
  }
  (*visited)[op.get()] = reach_boundary;
  if (reach_boundary) {
    result->push_back(op);
  }
  return reach_boundary;
}

Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs) {
  Array<Operation> result;
  std::unordered_set<const Node*> boundary;
  for (Tensor t : inputs) {
    boundary.insert(t->op.get());
  }
  std::unordered_map<const Node*, bool> visited;
  for (Tensor t : outputs) {
    GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
                          &visited, &result);
  }
  return result;
}
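
// Illustrative sketch (not part of this file): with hypothetical tensors
// where C reads B and B reads A, the operations strictly between A and C are
//
//   Array<Operation> body = GetSubGraph({C}, {A}, /*include_inputs=*/false);
//   // body == {B->op, C->op}; passing true for include_inputs would also
//   // include A->op.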


// Visit the read graph in DFS order, appending each operation after all of
// the operations it reads from (post-order).
void PostDFSOrder(const Operation& op,
                  const ReadGraph& g,
                  std::unordered_set<Operation>* visited,
                  Array<Operation>* post_order) {
  if (visited->count(op)) return;
  visited->insert(op);
  for (const auto& t : g.at(op)) {
    PostDFSOrder(t->op, g, visited, post_order);
  }
  post_order->push_back(op);
}

Array<Operation> PostDFSOrder(
    const Array<Operation>& roots,
    const ReadGraph& g) {
  std::unordered_set<Operation> visited;
  Array<Operation> post_order;
  for (Operation op : roots) {
    PostDFSOrder(op, g, &visited, &post_order);
  }
  return post_order;
}
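
// Illustrative sketch (not part of this file): combined with CreateReadGraph,
// this gives a producers-before-consumers ordering of every operation the
// roots depend on (`roots` and `rmap` as in the sketch above):
//
//   Array<Operation> order = PostDFSOrder(roots, rmap);
//   // Each operation appears after the operations that produce its inputs.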

FeedGraph CreateFeedGraph(const ReadGraph& g) {
  FeedGraph fg;
  for (auto kv : g) {
    for (Tensor t : kv.second) {
      fg[t].push_back(kv.first);
    }
  }
  return fg;
}
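
// Illustrative sketch (not part of this file): the feed graph is the inverse
// of the read graph (`rmap` as in the sketches above):
//
//   FeedGraph fg = CreateFeedGraph(rmap);
//   // For every tensor t read by an operation op in rmap, fg[t] contains
//   // op, i.e. the feed graph maps each tensor to its consumers.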

AttachPath CreateAttachPath(Schedule sch) {
  AttachPath ret;
  for (Stage stage : sch->stages) {
    std::unordered_set<const Node*> visited;
    Array<IterVar> path;
    // Walk up the attach chain (compute_at / scan update), collecting the
    // loop iteration variables this stage is nested inside.
    for (Stage s = stage; s.defined();) {
      CHECK(!visited.count(s.get()))
          << "Found a loop in the compute_at attach group";
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      bool start_attach;
      IterVar attach_ivar;
      if (spec->attach_type == kScope) {
        attach_ivar = spec->attach_ivar;
        s = spec->attach_stage;
        start_attach = false;
        CHECK(attach_ivar.defined());
      } else if (spec->attach_type == kScanUpdate) {
        s = spec->attach_stage;
        start_attach = true;
      } else {
        break;
      }
      CHECK(s.defined());
      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
        }
        if (start_attach) path.push_back(iv);
      }
      CHECK(start_attach)
          << "Invalid Schedule: cannot find attach point " << attach_ivar
          << " in the schedule of " << s->op;
    }
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, path);
    }
  }
  return ret;
}
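
// Illustrative sketch (not part of this file): if a hypothetical stage of
// tensor B is compute_at-attached to the stage of C at axis `co`, and C's
// leaf iteration variables are {co, ci}, then
//
//   AttachPath amap = CreateAttachPath(sch);  // `sch`: hypothetical Schedule
//   // amap[B->op] starts with co (the attach axis) followed by any loops
//   // that C's own stage is nested inside, from inner to outer.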

// Graph of the reach relation between tensor dimensions: reach[key] lists
// the tensor dimensions that `key` reaches, i.e. that are indexed using it
// within the given operations.
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;

ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  std::unordered_set<const Node*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      std::unordered_map<const Node*, TensorDimKey> vmap;
      const auto& axis = compute_op->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      // For each read of another tensor in ops, record which of this op's
      // output dimensions (via their axis variables) appear in the read's
      // index expressions.
      auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          if (!bset.count(call->func.get())) return;
          for (size_t i = 0; i < call->args.size(); ++i) {
            TensorDimKey dkey(call, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) {
              const Variable *v = node.as<Variable>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            ir::PostOrderVisit(call->args[i], fpush);
          }
        }
      };
      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}
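
// Illustrative sketch (not part of this file): for a hypothetical compute op
// with output B and body B(i, j) = A(j, i) + 1, calling GetReachGraph on
// {A->op, B->op} yields, roughly:
//
//   reach[TensorDimKey(B, 0)] == {TensorDimKey(A, 1)}  // axis i indexes A's dim 1
//   reach[TensorDimKey(B, 1)] == {TensorDimKey(A, 0)}  // axis j indexes A's dim 0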

// Get the operations that form the body of a scan: the subgraph between the
// scan's update tensors and its state placeholders / external inputs,
// excluding those inputs themselves.
Array<Operation> ScanGetBody(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  CHECK(scan != nullptr);
  // Gather the scan's inputs: state placeholders and external inputs.
  Array<Tensor> inputs;
  for (Tensor t : scan->state_placeholder) {
    inputs.push_back(t);
  }
  for (Tensor t : scan->inputs) {
    inputs.push_back(t);
  }
  return GetSubGraph(scan->update, inputs, false);
}
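
// Illustrative sketch (not part of this file; `s_op` is a hypothetical scan
// operation):
//
//   Array<Operation> body = ScanGetBody(s_op);
//   ReachGraph reach = GetReachGraph(body);
//
// This is how ScanFixPointAnalysis below lazily builds its reach graph.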

// For each spatial axis of the scan, decide whether the update along that
// axis only depends on the same spatial index of the state (a fix point
// along that axis). The result maps each spatial IterVar to 1 if this is
// proved and to 0 otherwise.
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  CHECK(scan != nullptr);
  Array<Operation> body = ScanGetBody(scan_op);

  // exact_reach maps a tensor dimension to the spatial axis that is known to
  // reach it exactly; fail_set records spatial axes for which exactness failed.
  std::unordered_map<TensorDimKey, const Node*> exact_reach;
  std::unordered_set<const Node*> fail_set;

  // Seed: spatial dimension k of each state placeholder is exactly reached by
  // the corresponding spatial axis.
  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // Merge the exact-reach information of src into dst.
  auto f_merge_key = [&exact_reach, &fail_set](
      const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      if (dit->second != sit->second) {
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // Propagate exact-reach information backward through the scan body.
  for (size_t i = 0; i < body.size(); ++i) {
    const Operation& op = body[i];
    if (const auto* body_scan = op.as<ScanOpNode>()) {
      const auto& update = body_scan->update;
      const auto& init = body_scan->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
      const auto& axis = compute_op->axis;
      for (size_t i = 0; i < axis.size(); ++i) {
        std::vector<TensorDimKey> keys;
        for (int j = 0; j < op->num_outputs(); ++j) {
          keys.emplace_back(op.output(j), i);
        }
        vmap[axis[i]->var.get()] = std::move(keys);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
          const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          for (size_t i = 0; i < call->args.size(); ++i) {
            auto it = vmap.find(call->args[i].get());
            TensorDimKey src(call, static_cast<int>(i));
            if (it != vmap.end()) {
              const std::vector<TensorDimKey>& keys = it->second;
              for (const auto& key : keys) {
                f_merge_key(key, src);
              }
            } else {
              // The index is not a plain axis variable of this op; any
              // exact reach recorded for this dimension cannot be kept.
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  ReachGraph reach;
  Map<IterVar, Expr> ret;
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      if (fail_set.count(sp_iv.get()) ||
          !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        ret.Set(sp_iv, make_const(Int(32), 0));
      } else {
        // Exact match proved; still need to prove that no other part of the
        // graph interferes through this dimension.
        if (reach.empty()) reach = GetReachGraph(body);
        // Do a DFS over the reach graph, stopping if we hit a state
        // placeholder dimension other than the target.
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey cur = stack.back();
          if (cur != target && place_holder_ref.count(cur)) break;
          stack.pop_back();
          if (!reach.count(cur)) {
            LOG(FATAL) << "cannot find reach of " << cur.f << "-" << cur.dim;
          }

          for (TensorDimKey kk : reach.at(cur)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        if (!stack.empty()) {
          // The DFS stopped early: the proof failed.
          ret.Set(sp_iv, make_const(Int(32), 0));
        } else {
          ret.Set(sp_iv, make_const(Int(32), 1));
        }
      }
    }
  }
  return ret;
}
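
// Illustrative sketch (not part of this file; `s_op` is a hypothetical scan
// operation): if the update at spatial index j reads only the previous state
// at the same index j, the corresponding axis is a fix point:
//
//   Map<IterVar, Expr> fix = ScanFixPointAnalysis(s_op);
//   // fix maps each spatial IterVar to an Int(32) constant: 1 when the scan
//   // provably touches only the same spatial index of its state, 0 otherwise.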

}  // namespace schedule
}  // namespace tvm