1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file graph.cc
22  * \brief Utilities to get information about schedule graph.
23  */
24 #include "graph.h"
25 
26 #include <tvm/runtime/registry.h>
27 #include <tvm/te/operation.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/stmt_functor.h>
30 
31 #include <unordered_map>
32 #include <unordered_set>
33 #include <utility>
34 
35 namespace tvm {
36 namespace te {
37 // key to specific tensor dimension.
38 struct TensorDimKey {
39   Operation op;
40   int value_index;
41   int dim;
TensorDimKeytvm::te::TensorDimKey42   TensorDimKey() {}
TensorDimKeytvm::te::TensorDimKey43   TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKeytvm::te::TensorDimKey44   TensorDimKey(const Tensor& t, size_t dim)
45       : op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
operator ==tvm::te::TensorDimKey46   inline bool operator==(const TensorDimKey& other) const {
47     return op == other.op && value_index == other.value_index && dim == other.dim;
48   }
operator !=tvm::te::TensorDimKey49   inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
50 };
51 }  // namespace te
52 }  // namespace tvm
53 
54 namespace std {
55 template <>
56 struct hash<::tvm::te::TensorDimKey> {
operator ()std::hash57   std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
58     size_t lhs = ::tvm::ObjectPtrHash()(k.op);
59     size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
60     lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
61     return lhs;
62   }
63 };
64 }  // namespace std
65 
66 namespace tvm {
67 namespace te {
68 
69 // construct a read graph that gives readers of each operation
70 // that the root depend on
CreateReadGraph(const Array<Operation> & roots)71 ReadGraph CreateReadGraph(const Array<Operation>& roots) {
72   ReadGraph rmap;
73   std::vector<Operation> stack;
74   std::unordered_set<const Object*> visited;
75   // initialize the roots
76   for (Operation op : roots) {
77     stack.push_back(op);
78     visited.insert(op.get());
79   }
80 
81   while (!stack.empty()) {
82     Operation op = stack.back();
83     stack.pop_back();
84     Array<Tensor> deps = op->InputTensors();
85     rmap.Set(op, deps);
86     for (Tensor t : deps) {
87       if (t->op.defined() && visited.count(t->op.get()) == 0) {
88         visited.insert(t->op.get());
89         stack.push_back(t->op);
90       }
91     }
92   }
93   return rmap;
94 }
95 
96 // Do DFS visit to get the subgraph.
97 // Return if op is inside the subgraph.
GetSubGraphByPostDFS_(const Operation & op,const std::unordered_set<const Object * > & boundary,bool include_bounary,std::unordered_map<const Object *,bool> * visited,Array<Operation> * result)98 bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set<const Object*>& boundary,
99                            bool include_bounary, std::unordered_map<const Object*, bool>* visited,
100                            Array<Operation>* result) {
101   if (visited->count(op.get())) {
102     return visited->at(op.get());
103   }
104   if (boundary.count(op.get())) {
105     (*visited)[op.get()] = true;
106     if (include_bounary) {
107       result->push_back(op);
108     }
109     return true;
110   }
111   // mark to avoid loop
112   // Not necessary for DAG.
113   (*visited)[op.get()] = false;
114   // check if we can reach boundary.
115   bool reach_boundary = false;
116   for (Tensor t : op->InputTensors()) {
117     if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) {
118       reach_boundary = true;
119     }
120   }
121   (*visited)[op.get()] = reach_boundary;
122   if (reach_boundary) {
123     result->push_back(op);
124   }
125   return reach_boundary;
126 }
127 
GetSubGraph(const Array<Tensor> & outputs,const Array<Tensor> & inputs,bool include_inputs)128 Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
129                              bool include_inputs) {
130   Array<Operation> result;
131   std::unordered_set<const Object*> boundary;
132   for (Tensor t : inputs) {
133     boundary.insert(t->op.get());
134   }
135   std::unordered_map<const Object*, bool> visited;
136   for (Tensor t : outputs) {
137     GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result);
138   }
139   return result;
140 }
141 
PostDFSOrder(const Operation & op,const ReadGraph & g,std::unordered_set<Operation> * visited,Array<Operation> * post_order)142 void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set<Operation>* visited,
143                   Array<Operation>* post_order) {
144   if (visited->count(op)) return;
145   visited->insert(op);
146   for (const auto& t : g.at(op)) {
147     PostDFSOrder(t->op, g, visited, post_order);
148   }
149   post_order->push_back(op);
150 }
151 
PostDFSOrder(const Array<Operation> & roots,const ReadGraph & g)152 Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g) {
153   std::unordered_set<Operation> visited;
154   Array<Operation> post_order;
155   for (Operation op : roots) {
156     PostDFSOrder(op, g, &visited, &post_order);
157   }
158   return post_order;
159 }
160 
CreateFeedGraph(const ReadGraph & g)161 FeedGraph CreateFeedGraph(const ReadGraph& g) {
162   FeedGraph fg;
163   for (auto kv : g) {
164     for (Tensor t : kv.second) {
165       fg[t].push_back(kv.first);
166     }
167   }
168   return fg;
169 }
170 
CreateAttachPath(Schedule sch)171 AttachPath CreateAttachPath(Schedule sch) {
172   AttachPath ret;
173   for (Stage stage : sch->stages) {
174     std::unordered_set<const Object*> visited;
175     Array<IterVar> path;
176     for (Stage s = stage; s.defined();) {
177       CHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
178       visited.insert(s.get());
179       Stage spec = s.GetAttachSpec();
180       bool start_attach;
181       IterVar attach_ivar;
182       if (spec->attach_type == kScope) {
183         attach_ivar = spec->attach_ivar;
184         s = spec->attach_stage;
185         start_attach = false;
186         CHECK(attach_ivar.defined());
187       } else if (spec->attach_type == kScanUpdate) {
188         s = spec->attach_stage;
189         start_attach = true;
190       } else {
191         break;
192       }
193       CHECK(s.defined());
194       for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
195         IterVar iv = s->leaf_iter_vars[i - 1];
196         if (!start_attach && iv.same_as(attach_ivar)) {
197           start_attach = true;
198         }
199         if (start_attach) path.push_back(iv);
200       }
201       CHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar
202                           << " in the schedule of " << s->op;
203     }
204     if (!ret.count(stage->op)) {
205       ret.Set(stage->op, path);
206     }
207   }
208   return ret;
209 }
210 
// graph of push reach relation of tensor dimensions
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey>>;

// Build the reach graph over tensor dimensions for `ops`: an edge
// src -> dst means an index along dimension src of a producer's output
// flows into (is used to index) dimension dst of a tensor it reads.
// Only scan and compute ops contribute edges.
ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  // Membership set: only loads from tensors produced inside `ops` count.
  std::unordered_set<const Object*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        // Dimension 0 is the scan (time) axis; each spatial dimension of
        // the scan output reaches the same dimension of update and init.
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each axis variable to the output dimension it indexes.
      std::unordered_map<const Object*, TensorDimKey> vmap;
      const auto& axis = compute_op->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      // For every producer load in the body: any axis var appearing in
      // the i-th index expression makes the corresponding output dim
      // reach dimension i of the loaded tensor.
      auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
          Tensor t = Downcast<Tensor>(pload->producer);
          // Ignore loads from tensors outside the given op set.
          if (!bset.count(t->op.get())) return;
          for (size_t i = 0; i < pload->indices.size(); ++i) {
            TensorDimKey dkey(t, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
              const VarNode* v = node.as<VarNode>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            tir::PostOrderVisit(pload->indices[i], fpush);
          }
        }
      };
      for (auto& e : compute_op->body) {
        tir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}
264 
ScanGetBody(const Operation & scan_op)265 Array<Operation> ScanGetBody(const Operation& scan_op) {
266   const ScanOpNode* scan = scan_op.as<ScanOpNode>();
267   // Get the body.
268   Array<Tensor> inputs;
269   for (Tensor t : scan->state_placeholder) {
270     inputs.push_back(t);
271   }
272   for (Tensor t : scan->inputs) {
273     inputs.push_back(t);
274   }
275   return GetSubGraph(scan->update, inputs, false);
276 }
277 
ScanFixPointAnalysis(const Operation & scan_op)278 Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
279   const ScanOpNode* scan = scan_op.as<ScanOpNode>();
280   Array<Operation> body = ScanGetBody(scan_op);
281 
282   std::unordered_map<TensorDimKey, const Object*> exact_reach;
283   std::unordered_set<const Object*> fail_set;
284 
285   for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
286     for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
287       TensorDimKey key(scan->state_placeholder[i], k);
288       exact_reach[key] = scan->spatial_axis_[sp_idx].get();
289     }
290   }
291   // merge exact reach
292   auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) {
293     auto sit = exact_reach.find(src);
294     if (sit == exact_reach.end()) return;
295     auto dit = exact_reach.find(dst);
296     if (dit == exact_reach.end()) {
297       exact_reach[dst] = sit->second;
298     } else {
299       if (dit->second != sit->second) {
300         fail_set.insert(dit->second);
301         fail_set.insert(sit->second);
302       }
303     }
304   };
305   // prop exact reach back.
306   for (size_t i = 0; i < body.size(); ++i) {
307     const Operation& op = body[i];
308     if (const auto* scan_op = op.as<ScanOpNode>()) {
309       const auto& update = scan_op->update;
310       const auto& init = scan_op->init;
311       for (size_t i = 0; i < update.size(); ++i) {
312         Tensor t = op.output(i);
313         for (size_t k = 1; k < update[i]->shape.size(); ++k) {
314           f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
315           f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
316         }
317       }
318     } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
319       std::unordered_map<const Object*, std::vector<TensorDimKey>> vmap;
320       const auto& axis = compute_op->axis;
321       for (size_t i = 0; i < axis.size(); ++i) {
322         std::vector<TensorDimKey> keys;
323         for (int j = 0; j < op->num_outputs(); ++j) {
324           keys.emplace_back(op.output(j), i);
325         }
326         vmap[axis[i]->var.get()] = std::move(keys);
327       }
328       auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) {
329         if (auto* pload = n.as<tir::ProducerLoadNode>()) {
330           Tensor t = Downcast<Tensor>(pload->producer);
331           for (size_t i = 0; i < pload->indices.size(); ++i) {
332             auto it = vmap.find(pload->indices[i].get());
333             TensorDimKey src(t, static_cast<int>(i));
334             if (it != vmap.end()) {
335               const std::vector<TensorDimKey>& keys = it->second;
336               for (const auto& key : keys) {
337                 f_merge_key(key, src);
338               }
339             } else {
340               if (exact_reach.count(src)) {
341                 fail_set.insert(exact_reach.at(src));
342               }
343             }
344           }
345         }
346       };
347       for (auto& e : compute_op->body) {
348         tir::PostOrderVisit(e, fvisit);
349       }
350     }
351   }
352   ReachGraph reach;
353   Map<IterVar, PrimExpr> ret;
354   std::unordered_set<TensorDimKey> place_holder_ref;
355   for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
356     for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
357       place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
358     }
359   }
360 
361   for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
362     for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
363       TensorDimKey key(scan->update[i], k);
364       TensorDimKey target(scan->state_placeholder[i], k);
365       IterVar sp_iv = scan->spatial_axis_[sp_idx];
366       if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) ||
367           exact_reach.at(key) != sp_iv.get()) {
368         ret.Set(sp_iv, make_const(DataType::Int(32), 0));
369       } else {
370         // now we proved exact match, need to prove no interference with other graph.
371         if (reach.size() == 0) reach = GetReachGraph(body);
372         // do a DFS
373         std::unordered_set<TensorDimKey> visited;
374         std::vector<TensorDimKey> stack{key};
375         visited.insert(key);
376         while (!stack.empty()) {
377           TensorDimKey k = stack.back();
378           if (k != target && place_holder_ref.count(k)) break;
379           stack.pop_back();
380           if (!reach.count(k)) {
381             LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
382           }
383 
384           for (TensorDimKey kk : reach.at(k)) {
385             if (visited.count(kk)) {
386               continue;
387             }
388             visited.insert(kk);
389             stack.push_back(kk);
390           }
391         }
392         if (!stack.empty()) {
393           // failed the prove.
394           ret.Set(sp_iv, make_const(DataType::Int(32), 0));
395         } else {
396           ret.Set(sp_iv, make_const(DataType::Int(32), 1));
397         }
398       }
399     }
400   }
401   return ret;
402 }
403 
// FFI registrations exposing the graph utilities under the "schedule.*"
// namespace (callable from the Python frontend).
TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph);

// Wrapped in a lambda to select the Array<Operation> overload of
// PostDFSOrder; the bare name would be ambiguous.
TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
    .set_body_typed([](const Array<Operation>& roots, const ReadGraph& g) {
      return PostDFSOrder(roots, g);
    });

TVM_REGISTER_GLOBAL("schedule.CreateAttachPath").set_body_typed(CreateAttachPath);

TVM_REGISTER_GLOBAL("schedule.ScanGetBody").set_body_typed(ScanGetBody);

TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis").set_body_typed(ScanFixPointAnalysis);
416 
417 }  // namespace te
418 }  // namespace tvm
419