1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file graph.cc
22 * \brief Utilities to get information about schedule graph.
23 */
24 #include <tvm/ir.h>
25 #include <tvm/ir_visitor.h>
26 #include <tvm/operation.h>
27 #include <utility>
28 #include <unordered_set>
29 #include <unordered_map>
30 #include "graph.h"
31
32 namespace tvm {
33 namespace schedule {
34 // key to specific tensor dimension.
35 struct TensorDimKey {
36 ir::FunctionRef f;
37 int value_index;
38 int dim;
TensorDimKeytvm::schedule::TensorDimKey39 TensorDimKey() {}
TensorDimKeytvm::schedule::TensorDimKey40 TensorDimKey(const ir::Call* op, int dim)
41 : f(op->func), value_index(op->value_index), dim(dim) {
42 }
TensorDimKeytvm::schedule::TensorDimKey43 TensorDimKey(const Tensor& t, int dim)
44 : f(t->op), value_index(t->value_index), dim(dim) {
45 }
TensorDimKeytvm::schedule::TensorDimKey46 TensorDimKey(const Tensor& t, size_t dim)
47 : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
48 }
operator ==tvm::schedule::TensorDimKey49 inline bool operator==(const TensorDimKey& other) const {
50 return f == other.f &&
51 value_index == other.value_index &&
52 dim == other.dim;
53 }
operator !=tvm::schedule::TensorDimKey54 inline bool operator!=(const TensorDimKey& other) const {
55 return !operator==(other);
56 }
57 };
58 } // namespace schedule
59 } // namespace tvm
60
61 namespace std {
62 template <>
63 struct hash<::tvm::schedule::TensorDimKey> {
operator ()std::hash64 std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
65 size_t lhs = ::tvm::NodeHash()(k.f);
66 size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
67 static_cast<size_t>(k.dim);
68 lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
69 return lhs;
70 }
71 };
72 } // namespace std
73
74
75 namespace tvm {
76 namespace schedule {
77
78 // construct a read graph that gives readers of each operation
79 // that the root depend on
CreateReadGraph(const Array<Operation> & roots)80 ReadGraph CreateReadGraph(const Array<Operation>& roots) {
81 ReadGraph rmap;
82 std::vector<Operation> stack;
83 std::unordered_set<const Node*> visited;
84 // initialize the roots
85 for (Operation op : roots) {
86 stack.push_back(op);
87 visited.insert(op.get());
88 }
89
90 while (!stack.empty()) {
91 Operation op = stack.back();
92 stack.pop_back();
93 Array<Tensor> deps = op->InputTensors();
94 rmap.Set(op, deps);
95 for (Tensor t : deps) {
96 if (t->op.defined() && visited.count(t->op.get()) == 0) {
97 visited.insert(t->op.get());
98 stack.push_back(t->op);
99 }
100 }
101 }
102 return rmap;
103 }
104
105 // Do DFS visit to get the subgraph.
106 // Return if op is inside the subgraph.
GetSubGraphByPostDFS_(const Operation & op,const std::unordered_set<const Node * > & boundary,bool include_bounary,std::unordered_map<const Node *,bool> * visited,Array<Operation> * result)107 bool GetSubGraphByPostDFS_(
108 const Operation& op,
109 const std::unordered_set<const Node*>& boundary,
110 bool include_bounary,
111 std::unordered_map<const Node*, bool>* visited,
112 Array<Operation>* result) {
113 if (visited->count(op.get())) {
114 return visited->at(op.get());
115 }
116 if (boundary.count(op.get())) {
117 (*visited)[op.get()] = true;
118 if (include_bounary) {
119 result->push_back(op);
120 }
121 return true;
122 }
123 // mark to avoid loop
124 // Not necessary for DAG.
125 (*visited)[op.get()] = false;
126 // check if we can reach boundary.
127 bool reach_boundary = false;
128 for (Tensor t : op->InputTensors()) {
129 if (GetSubGraphByPostDFS_(t->op, boundary,
130 include_bounary,
131 visited, result)) {
132 reach_boundary = true;
133 }
134 }
135 (*visited)[op.get()] = reach_boundary;
136 if (reach_boundary) {
137 result->push_back(op);
138 }
139 return reach_boundary;
140 }
141
// Collect the operations lying between outputs and inputs; the inputs'
// producers are included only when include_inputs is true.
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs) {
  Array<Operation> result;
  // The producers of the inputs form the DFS boundary.
  std::unordered_set<const Node*> boundary;
  for (Tensor input : inputs) {
    boundary.insert(input->op.get());
  }
  // One DFS per output; `visited` is shared so no op is processed twice.
  std::unordered_map<const Node*, bool> visited;
  for (Tensor output : outputs) {
    GetSubGraphByPostDFS_(output->op, boundary, include_inputs,
                          &visited, &result);
  }
  return result;
}
157
158
PostDFSOrder(const Operation & op,const ReadGraph & g,std::unordered_set<Operation> * visited,Array<Operation> * post_order)159 void PostDFSOrder(const Operation& op,
160 const ReadGraph& g,
161 std::unordered_set<Operation>* visited,
162 Array<Operation>* post_order) {
163 if (visited->count(op)) return;
164 visited->insert(op);
165 for (const auto& t : g.at(op)) {
166 PostDFSOrder(t->op, g, visited, post_order);
167 }
168 post_order->push_back(op);
169 }
170
// Post-DFS ordering of all operations reachable from roots in g.
Array<Operation> PostDFSOrder(
    const Array<Operation>& roots,
    const ReadGraph& g) {
  Array<Operation> post_order;
  std::unordered_set<Operation> visited;
  for (Operation root : roots) {
    PostDFSOrder(root, g, &visited, &post_order);
  }
  return post_order;
}
181
CreateFeedGraph(const ReadGraph & g)182 FeedGraph CreateFeedGraph(const ReadGraph& g) {
183 FeedGraph fg;
184 for (auto kv : g) {
185 for (Tensor t : kv.second) {
186 fg[t].push_back(kv.first);
187 }
188 }
189 return fg;
190 }
191
// For every stage, compute the itervars of the loops of other stages
// that this stage ends up nested inside, due to compute_at / scan
// attachment.  The path is collected innermost-first.
AttachPath CreateAttachPath(Schedule sch) {
  AttachPath ret;
  for (Stage stage : sch->stages) {
    std::unordered_set<const Node*> visited;
    Array<IterVar> path;
    // Walk up the attachment chain starting from this stage.
    for (Stage s = stage; s.defined();) {
      // A stage seen twice means the compute_at relations form a cycle.
      CHECK(!visited.count(s.get()))
          << "Find loop in compute_at attach group";
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      bool start_attach;
      IterVar attach_ivar;
      if (spec->attach_type == kScope) {
        // compute_at: nested under one specific loop of the parent;
        // start collecting only once that loop is reached below.
        attach_ivar = spec->attach_ivar;
        s = spec->attach_stage;
        start_attach = false;
        CHECK(attach_ivar.defined());
      } else if (spec->attach_type == kScanUpdate) {
        // scan update: nested under all leaf loops of the parent stage.
        s = spec->attach_stage;
        start_attach = true;
      } else {
        // Not attached anywhere; the chain ends here.
        break;
      }
      CHECK(s.defined());
      // Scan the parent's leaf itervars from innermost to outermost,
      // collecting from the attach point outward.
      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
        }
        if (start_attach) path.push_back(iv);
      }
      // If the attach itervar was never found, the schedule is invalid.
      CHECK(start_attach)
          << "Invalid Schedule: cannot find attach point " << attach_ivar
          << " in the schedule of " << s->op;
    }
    // Keep only the first path computed for each op (a scan's stages
    // may share the same op).
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, path);
    }
  }
  return ret;
}
233
// Graph of push-reach relations between tensor dimensions: maps each
// tensor dimension to the dimensions its index value flows into.
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
236
// Build the reach graph over the dimensions of the given operations:
// an edge dim_a -> dim_b means the index of dim_a is pushed into dim_b.
ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  // Only record flow between operations inside this set.
  std::unordered_set<const Node*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      // Scan: output dim k is fed by dim k of both update and init.
      // k starts at 1 because dim 0 is the scan (time) axis.
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Compute: map each axis variable to the output dimension it
      // indexes, and start each dimension with an empty edge list.
      std::unordered_map<const Node*, TensorDimKey> vmap;
      const auto& axis = compute_op->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      // For every tensor call in the body: each output dimension whose
      // axis variable appears inside argument i reaches the callee's
      // dimension i.
      auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          // Ignore calls to operations outside the subgraph.
          if (!bset.count(call->func.get())) return;
          for (size_t i = 0; i < call->args.size(); ++i) {
            TensorDimKey dkey(call, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) {
              const Variable *v = node.as<Variable>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            ir::PostOrderVisit(call->args[i], fpush);
          }
        }
      };
      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}
289
ScanGetBody(const Operation & scan_op)290 Array<Operation> ScanGetBody(const Operation& scan_op) {
291 const ScanOpNode* scan = scan_op.as<ScanOpNode>();
292 // Get the body.
293 Array<Tensor> inputs;
294 for (Tensor t : scan->state_placeholder) {
295 inputs.push_back(t);
296 }
297 for (Tensor t : scan->inputs) {
298 inputs.push_back(t);
299 }
300 return GetSubGraph(scan->update, inputs, false);
301 }
302
// For each spatial axis of a scan, decide whether the update is a
// "fix point" along that axis: element i of the new state depends only
// on element i of the previous state.  Returns a map from each spatial
// IterVar to the constant 1 (fix point proven) or 0 (not proven).
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  Array<Operation> body = ScanGetBody(scan_op);

  // exact_reach: tensor dimension -> the unique spatial axis node known
  // to index it exactly (if any).
  std::unordered_map<TensorDimKey, const Node*> exact_reach;
  // Spatial axes for which the exact-match proof has failed.
  std::unordered_set<const Node*> fail_set;

  // Seed: dim k (k >= 1; dim 0 is the time axis) of each state
  // placeholder is exactly its corresponding spatial axis.
  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // merge exact reach
  // Propagate src's known axis to dst; if dst already maps to a
  // different axis, both axes fail the proof.
  auto f_merge_key = [&exact_reach, &fail_set](
      const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      if (dit->second != sit->second) {
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // prop exact reach back.
  for (size_t i = 0; i < body.size(); ++i) {
    const Operation& op = body[i];
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      // Nested scan: dim k of the output mirrors dim k of update/init.
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each axis variable to the dimension keys of every output
      // it indexes.
      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
      const auto& axis = compute_op->axis;
      for (size_t i = 0; i < axis.size(); ++i) {
        std::vector<TensorDimKey> keys;
        for (int j = 0; j < op->num_outputs(); ++j) {
          keys.emplace_back(op.output(j), i);
        }
        vmap[axis[i]->var.get()] = std::move(keys);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
          const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          for (size_t i = 0; i < call->args.size(); ++i) {
            auto it = vmap.find(call->args[i].get());
            TensorDimKey src(call, static_cast<int>(i));
            if (it != vmap.end()) {
              // Argument i is a plain axis variable: merge exactly.
              const std::vector<TensorDimKey>& keys = it->second;
              for (const auto& key : keys) {
                f_merge_key(key, src);
              }
            } else {
              // Argument is a compound expression: the exact-match
              // property (if any) is broken for that axis.
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  // Lazily-built reach graph for the interference check below.
  ReachGraph reach;
  Map<IterVar, Expr> ret;
  // All dimensions of the state placeholders (including the time axis).
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      if (fail_set.count(sp_iv.get()) ||
          !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        ret.Set(sp_iv, make_const(Int(32), 0));
      } else {
        // now we proved exact match, need to prove no interference with other graph.
        if (reach.size() == 0) reach = GetReachGraph(body);
        // do a DFS
        // Follow reach edges from the update dimension; reaching any
        // placeholder dimension other than the matching target means
        // interference (leave it on the stack and stop).
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey k = stack.back();
          if (k != target && place_holder_ref.count(k)) break;
          stack.pop_back();
          if (!reach.count(k)) {
            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
          }

          for (TensorDimKey kk : reach.at(k)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        // A non-empty stack means the DFS broke out on interference.
        if (!stack.empty()) {
          // failed the prove.
          ret.Set(sp_iv, make_const(Int(32), 0));
        } else {
          ret.Set(sp_iv, make_const(Int(32), 1));
        }
      }
    }
  }
  return ret;
}
431
432 } // namespace schedule
433 } // namespace tvm
434