1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file schedule_ops.cc
22  */
23 #include <tvm/ir.h>
24 #include <tvm/ir_mutator.h>
25 #include <tvm/ir_pass.h>
26 #include <tvm/ir_visitor.h>
27 #include <tvm/operation.h>
28 #include <tvm/schedule_pass.h>
29 #include <utility>
30 #include <unordered_map>
31 #include <unordered_set>
32 #include "graph.h"
33 #include "../op/op_util.h"
34 #include "../pass/ir_util.h"
35 
36 namespace tvm {
37 namespace schedule {
38 
39 using namespace ir;
40 
MakePipeline(const Stage & s,const std::unordered_map<IterVar,Range> & dom_map,Stmt consumer,bool debug_keep_trivial_loop)41 Stmt MakePipeline(const Stage& s,
42                   const std::unordered_map<IterVar, Range>& dom_map,
43                   Stmt consumer,
44                   bool debug_keep_trivial_loop) {
45   Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
46   if (producer.defined()) {
47     producer = ProducerConsumer::make(s->op, true, producer);
48   }
49   if (s->double_buffer) {
50     producer = AttrStmt::make(
51         s->op, ir::attr::double_buffer_scope, 1, producer);
52   }
53   Stmt pipeline = producer;
54 
55   if (consumer.defined() && !is_no_op(consumer)) {
56     consumer = ProducerConsumer::make(s->op, false, consumer);
57     pipeline = Block::make(producer, consumer);
58   }
59   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
60   // use attribute to mark scope of the operation.
61   pipeline = AttrStmt::make(
62       s->op, ir::attr::realize_scope,
63       StringImm::make(s->scope),
64       pipeline);
65 
66   if (s->is_opengl) {
67     pipeline = AttrStmt::make(
68         s->op, ir::attr::opengl_stage_scope, StringImm::make(""), pipeline);
69   }
70   return pipeline;
71 }
72 
73 // inject the operator's realization on the stmt.
74 class InjectAttach : public IRMutator {
75  public:
InjectAttach(const Stage & stage,const Stage & attach_spec,const std::unordered_map<IterVar,Range> & dom_map,bool debug_keep_trivial_loop)76   InjectAttach(const Stage& stage,
77                const Stage& attach_spec,
78                const std::unordered_map<IterVar, Range>& dom_map,
79                bool debug_keep_trivial_loop)
80       : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
81         debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
82 
Mutate(Stmt stmt)83   Stmt Mutate(Stmt stmt) final {
84     CHECK(stmt.defined());
85     stmt =  IRMutator::Mutate(stmt);
86     const AttrStmt* op = stmt.as<AttrStmt>();
87     if (op != nullptr &&
88         op->attr_key == attr::loop_scope) {
89       if (attach_spec_->attach_type == kScope &&
90           op->node == attach_spec_->attach_ivar) {
91         CHECK(!found_attach)
92             << "Find IterVar" << attach_spec_->attach_ivar
93             << " in multiple places in the IR";
94         found_attach = true;
95         stmt = AttrStmt::make(
96             op->node, op->attr_key, op->value,
97             MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
98       }
99     }
100     return stmt;
101   }
102   // whether attach point is found
103   bool found_attach{false};
104 
105  private:
106   // The stage.
107   const Stage& stage_;
108   // The attach spec, may not contain op.
109   const Stage& attach_spec_;
110   // domain map
111   const std::unordered_map<IterVar, Range>& dom_map_;
112   // Whether keep trivial loops with extent of 1 during lowering.
113   // This is a debug feature for dataflow/axis analysis
114   bool debug_keep_trivial_loop_;
115 };
116 
117 // inject the operator's realization on the stmt.
118 class InjectScanStep : public IRMutator {
119  public:
InjectScanStep(const Stage & stage,const Operation & scan_op,const std::unordered_map<IterVar,Range> & dom_map,bool is_init,bool debug_keep_trivial_loop)120   InjectScanStep(const Stage& stage,
121                  const Operation& scan_op,
122                  const std::unordered_map<IterVar, Range>& dom_map,
123                  bool is_init,
124                  bool debug_keep_trivial_loop)
125       : stage_(stage), scan_op_(scan_op),
126         dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
127 
Mutate(Stmt stmt)128   Stmt Mutate(Stmt stmt) final {
129     CHECK(stmt.defined());
130     stmt =  IRMutator::Mutate(stmt);
131     // update
132     const AttrStmt* op = stmt.as<AttrStmt>();
133     if (op != nullptr &&
134         ((op->attr_key == attr::scan_update_scope && !is_init_) ||
135          (op->attr_key == attr::scan_init_scope && is_init_))) {
136       if (op->node.same_as(scan_op_)) {
137         found_attach = true;
138         stmt = AttrStmt::make(
139             op->node, op->attr_key, op->value,
140             MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
141       }
142     }
143     return stmt;
144   }
145 
146   // whether attach point is found
147   bool found_attach{false};
148 
149  private:
150   // the operations to be carried
151   const Stage& stage_;
152   const Operation& scan_op_;
153   // domain map
154   const std::unordered_map<IterVar, Range>& dom_map_;
155   // whether it is init.
156   bool is_init_;
157   // Whether keep trivial loops with extent of 1 during lowering.
158   // This is a debug feature for dataflow/axis analysis
159   bool debug_keep_trivial_loop_;
160 };
161 
162 // Postprocessing of schedule op
163 // Replace the init and update's expression by scan's buffer.
164 class SchedulePostProc : public IRMutator {
165  public:
Mutate_(const ProducerConsumer * op,const Stmt & s)166   Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
167     auto it = replace_op_.find(op->func.get());
168     if (it != replace_op_.end()) {
169       Stmt body = this->Mutate(op->body);
170       if (it->second.defined()) {
171         return ProducerConsumer::make(
172             it->second, op->is_producer, body);
173       } else {
174         return body;
175       }
176     } else {
177       return IRMutator::Mutate_(op, s);
178     }
179   }
Mutate_(const LetStmt * op,const Stmt & s)180   Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
181     if (!HasSideEffect(op->value)) {
182       var_value_[op->var.get()] = Mutate(op->value);
183       return this->Mutate(op->body);
184     } else {
185       return IRMutator::Mutate_(op, s);
186     }
187   }
188 
Mutate_(const AttrStmt * op,const Stmt & s)189   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
190     if (op->attr_key == attr::loop_scope ||
191         op->attr_key == attr::scan_init_scope) {
192       return this->Mutate(op->body);
193     } else if (op->attr_key == attr::scan_update_scope) {
194       const ScanOpNode* scan = op->node.as<ScanOpNode>();
195       CHECK(scan);
196       var_value_[scan->scan_axis->var.get()] = op->value;
197       return this->Mutate(op->body);
198     } else if (op->attr_key == attr::thread_extent) {
199       // delete duplicated thread extent attr
200       auto it = thread_extent_scope_.find(op->node.get());
201       if (it != thread_extent_scope_.end()) {
202         CHECK(is_zero(ir::Simplify(it->second - op->value)));
203         return this->Mutate(op->body);
204       } else {
205         thread_extent_scope_[op->node.get()] = op->value;
206         Stmt ret = IRMutator::Mutate_(op, s);
207         thread_extent_scope_.erase(op->node.get());
208         return ret;
209       }
210     } else if (op->attr_key == ir::attr::realize_scope ||
211                op->attr_key == ir::attr::double_buffer_scope) {
212       auto it = replace_op_.find(op->node.get());
213       if (it != replace_op_.end()) {
214         if (it->second.defined()) {
215           Stmt ret = AttrStmt::make(
216               it->second, op->attr_key, op->value, op->body);
217           return this->Mutate(ret);
218         } else {
219           return this->Mutate(op->body);
220         }
221       }
222     } else if (op->attr_key == ir::attr::buffer_bind_scope) {
223       Array<NodeRef> tuple = Downcast<Array<NodeRef> >(op->node);
224       Tensor tensor = Downcast<Tensor>(tuple[1]);
225       auto it = replace_op_.find(tensor->op.get());
226       if (it != replace_op_.end()) {
227         if (it->second.defined()) {
228           return AttrStmt::make(
229               Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
230               op->attr_key, op->value, Mutate(op->body));
231         } else {
232           return this->Mutate(op->body);
233         }
234       }
235     } else if (op->attr_key == ir::attr::buffer_dim_align) {
236       Tensor tensor = Downcast<Tensor>(op->node);
237       auto it = replace_op_.find(tensor->op.get());
238       if (it != replace_op_.end()) {
239         if (it->second.defined()) {
240           return AttrStmt::make(
241               it->second.output(tensor->value_index),
242               op->attr_key, op->value, Mutate(op->body));
243         } else {
244           return this->Mutate(op->body);
245         }
246       }
247     }
248     return IRMutator::Mutate_(op, s);
249   }
250 
Mutate_(const Realize * op,const Stmt & s)251   Stmt Mutate_(const Realize* op, const Stmt& s) final {
252     TensorKey key{op->func, op->value_index};
253     auto it = replace_realize_.find(key);
254     if (it != replace_realize_.end()) {
255       if (it->second.defined()) {
256         Stmt ret = Realize::make(
257             it->second->op, it->second->value_index,
258             op->type, op->bounds, op->condition, op->body);
259         return this->Mutate(ret);
260       } else {
261         return this->Mutate(op->body);
262       }
263     } else {
264       return IRMutator::Mutate_(op, s);
265     }
266   }
267 
Mutate_(const Provide * op,const Stmt & s)268   Stmt Mutate_(const Provide* op, const Stmt& s) final {
269     TensorKey key{op->func, op->value_index};
270     auto it = replace_buffer_.find(key);
271     if (it != replace_buffer_.end()) {
272       const Tensor& dst = it->second;
273       Stmt ret = Provide::make(
274           dst->op, dst->value_index, op->value, op->args);
275       return this->Mutate(ret);
276     } else {
277       return IRMutator::Mutate_(op, s);
278     }
279   }
280 
Mutate_(const Call * op,const Expr & e)281   Expr Mutate_(const Call* op, const Expr& e) final {
282     if (op->call_type == Call::Halide) {
283       TensorKey key{op->func, op->value_index};
284       auto it = replace_buffer_.find(key);
285       if (it != replace_buffer_.end()) {
286         const Tensor& dst = it->second;
287         Expr ret = Call::make(
288             op->type, dst->op->name, op->args,
289             op->call_type, dst->op, dst->value_index);
290         return this->Mutate(ret);
291       }
292     }
293     return IRMutator::Mutate_(op, e);
294   }
295 
Mutate_(const Variable * op,const Expr & e)296   Expr Mutate_(const Variable* op, const Expr& e) final {
297     auto it = var_value_.find(op);
298     if (it != var_value_.end()) {
299       return it->second;
300     } else {
301       return e;
302     }
303   }
304 
Init(const Schedule & sch)305   void Init(const Schedule& sch) {
306     for (Stage s : sch->stages) {
307       for (auto kv : s->iter_var_attrs) {
308         // Update bind thread information.
309         if (kv.second->bind_thread.defined()) {
310           const Var& from = kv.first->var;
311           const Var& to = kv.second->bind_thread->var;
312           CHECK(!var_value_.count(from.get()));
313           var_value_[from.get()] = to;
314         }
315       }
316       // This must be checked for all ops, including scan.
317       if (!s->op.same_as(s->origin_op)) {
318         for (int i = 0; i < s->op->num_outputs(); ++i) {
319           Tensor target = s->origin_op.output(i);
320           AddReplace(s->op.output(i), target,
321                      target, s->origin_op);
322         }
323       }
324       // Specially add replacements for scan op.
325       if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
326         for (size_t i = 0; i < scan->update.size(); ++i) {
327           Tensor t = s->origin_op.output(i);
328           AddReplace(scan->init[i], t);
329           AddReplace(scan->update[i], t);
330           AddReplace(scan->state_placeholder[i], t);
331         }
332       }
333     }
334   }
335 
336  private:
AddReplace(Tensor src,Tensor dst,Tensor repl_realize=Tensor (),Operation repl_op=Operation ())337   void AddReplace(Tensor src,
338                   Tensor dst,
339                   Tensor repl_realize = Tensor(),
340                   Operation repl_op = Operation()) {
341     TensorKey key{src->op, src->value_index};
342     replace_buffer_[key] = dst;
343     replace_realize_[key] = repl_realize;
344     replace_op_[src->op.get()] = repl_op;
345   }
346   // The thread extent scope.
347   std::unordered_map<const Node*, Expr> thread_extent_scope_;
348   // The scan value
349   std::unordered_map<const Variable*, Expr> var_value_;
350   // buffer replacement
351   std::unordered_map<TensorKey, Tensor> replace_buffer_;
352   // buffere realization to be replaced
353   std::unordered_map<TensorKey, Tensor> replace_realize_;
354   // replace producer consumer.
355   std::unordered_map<const Node*, Operation> replace_op_;
356 };
357 
ScheduleOps(Schedule sch,Map<IterVar,Range> dom_map_,bool debug_keep_trivial_loop)358 Stmt ScheduleOps(
359     Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
360   Stmt body = Stmt();
361   std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
362   // scan init and scan updates
363   std::unordered_map<Operation, Operation> scan_init;
364   for (Stage s : sch->stages) {
365     const ScanOpNode* scan = s->op.as<ScanOpNode>();
366     if (!scan) continue;
367     for (Tensor t : scan->init) {
368       if (scan_init.count(t->op)) {
369         CHECK(scan_init.at(t->op).same_as(s->op))
370             << "Scan init tensor can only belong to one scan";
371       } else {
372         scan_init[t->op] = s->op;
373       }
374     }
375   }
376   // verify correctness of group.
377   for (Stage g : sch->groups) {
378     CHECK(!g->op.defined());
379     CHECK_EQ(g->leaf_iter_vars.size(), 0U);
380   }
381   // reverse the post DFS order.
382   for (size_t i = sch->stages.size(); i != 0; --i) {
383     Stage s = sch->stages[i - 1];
384     CHECK_NE(s->attach_type, kInline)
385         << "call schedule.normalize before scheduleops";
386     CHECK(s->op.defined());
387     // no need to specify place holder op.
388     if (s->op.as<PlaceholderOpNode>()) continue;
389     // Remove grouping sugar, get the real attach spec.
390     Stage attach_spec = s.GetAttachSpec();
391 
392     if (scan_init.count(s->op)) {
393       CHECK(body.defined());
394       InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
395       body = mu.Mutate(body);
396       CHECK(mu.found_attach)
397           << "did not find attachment point for scan.init";
398     } else if (attach_spec->attach_type == kScanUpdate) {
399       // Handle scan update
400       CHECK(body.defined());
401       InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
402       body = mu.Mutate(body);
403       CHECK(mu.found_attach)
404           << "did not find attachment point for scan.update";
405     } else if (attach_spec->attach_type == kInlinedAlready) {
406       // do nothing
407     } else if (attach_spec->attach_type == kGroupRoot) {
408       CHECK(!s->group.defined());
409       body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
410     } else {
411       CHECK_EQ(attach_spec->attach_type, kScope);
412       CHECK(body.defined());
413       InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
414       body = mutator.Mutate(body);
415       CHECK(mutator.found_attach)
416           << "did not find attachment point for " << s << " in "
417           << attach_spec->attach_stage->op  << " x " << attach_spec->attach_ivar
418           << ", body:\n"
419           << body;
420     }
421   }
422   SchedulePostProc post_proc;
423   post_proc.Init(sch);
424   return post_proc.Mutate(body);
425 }
426 
427 }  // namespace schedule
428 }  // namespace tvm
429