/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_ops.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <tvm/schedule_pass.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "graph.h"
#include "../op/op_util.h"
#include "../pass/ir_util.h"
35
36 namespace tvm {
37 namespace schedule {
38
39 using namespace ir;
40
MakePipeline(const Stage & s,const std::unordered_map<IterVar,Range> & dom_map,Stmt consumer,bool debug_keep_trivial_loop)41 Stmt MakePipeline(const Stage& s,
42 const std::unordered_map<IterVar, Range>& dom_map,
43 Stmt consumer,
44 bool debug_keep_trivial_loop) {
45 Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
46 if (producer.defined()) {
47 producer = ProducerConsumer::make(s->op, true, producer);
48 }
49 if (s->double_buffer) {
50 producer = AttrStmt::make(
51 s->op, ir::attr::double_buffer_scope, 1, producer);
52 }
53 Stmt pipeline = producer;
54
55 if (consumer.defined() && !is_no_op(consumer)) {
56 consumer = ProducerConsumer::make(s->op, false, consumer);
57 pipeline = Block::make(producer, consumer);
58 }
59 pipeline = s->op->BuildRealize(s, dom_map, pipeline);
60 // use attribute to mark scope of the operation.
61 pipeline = AttrStmt::make(
62 s->op, ir::attr::realize_scope,
63 StringImm::make(s->scope),
64 pipeline);
65
66 if (s->is_opengl) {
67 pipeline = AttrStmt::make(
68 s->op, ir::attr::opengl_stage_scope, StringImm::make(""), pipeline);
69 }
70 return pipeline;
71 }
72
73 // inject the operator's realization on the stmt.
74 class InjectAttach : public IRMutator {
75 public:
InjectAttach(const Stage & stage,const Stage & attach_spec,const std::unordered_map<IterVar,Range> & dom_map,bool debug_keep_trivial_loop)76 InjectAttach(const Stage& stage,
77 const Stage& attach_spec,
78 const std::unordered_map<IterVar, Range>& dom_map,
79 bool debug_keep_trivial_loop)
80 : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
81 debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
82
Mutate(Stmt stmt)83 Stmt Mutate(Stmt stmt) final {
84 CHECK(stmt.defined());
85 stmt = IRMutator::Mutate(stmt);
86 const AttrStmt* op = stmt.as<AttrStmt>();
87 if (op != nullptr &&
88 op->attr_key == attr::loop_scope) {
89 if (attach_spec_->attach_type == kScope &&
90 op->node == attach_spec_->attach_ivar) {
91 CHECK(!found_attach)
92 << "Find IterVar" << attach_spec_->attach_ivar
93 << " in multiple places in the IR";
94 found_attach = true;
95 stmt = AttrStmt::make(
96 op->node, op->attr_key, op->value,
97 MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
98 }
99 }
100 return stmt;
101 }
102 // whether attach point is found
103 bool found_attach{false};
104
105 private:
106 // The stage.
107 const Stage& stage_;
108 // The attach spec, may not contain op.
109 const Stage& attach_spec_;
110 // domain map
111 const std::unordered_map<IterVar, Range>& dom_map_;
112 // Whether keep trivial loops with extent of 1 during lowering.
113 // This is a debug feature for dataflow/axis analysis
114 bool debug_keep_trivial_loop_;
115 };
116
117 // inject the operator's realization on the stmt.
118 class InjectScanStep : public IRMutator {
119 public:
InjectScanStep(const Stage & stage,const Operation & scan_op,const std::unordered_map<IterVar,Range> & dom_map,bool is_init,bool debug_keep_trivial_loop)120 InjectScanStep(const Stage& stage,
121 const Operation& scan_op,
122 const std::unordered_map<IterVar, Range>& dom_map,
123 bool is_init,
124 bool debug_keep_trivial_loop)
125 : stage_(stage), scan_op_(scan_op),
126 dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
127
Mutate(Stmt stmt)128 Stmt Mutate(Stmt stmt) final {
129 CHECK(stmt.defined());
130 stmt = IRMutator::Mutate(stmt);
131 // update
132 const AttrStmt* op = stmt.as<AttrStmt>();
133 if (op != nullptr &&
134 ((op->attr_key == attr::scan_update_scope && !is_init_) ||
135 (op->attr_key == attr::scan_init_scope && is_init_))) {
136 if (op->node.same_as(scan_op_)) {
137 found_attach = true;
138 stmt = AttrStmt::make(
139 op->node, op->attr_key, op->value,
140 MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
141 }
142 }
143 return stmt;
144 }
145
146 // whether attach point is found
147 bool found_attach{false};
148
149 private:
150 // the operations to be carried
151 const Stage& stage_;
152 const Operation& scan_op_;
153 // domain map
154 const std::unordered_map<IterVar, Range>& dom_map_;
155 // whether it is init.
156 bool is_init_;
157 // Whether keep trivial loops with extent of 1 during lowering.
158 // This is a debug feature for dataflow/axis analysis
159 bool debug_keep_trivial_loop_;
160 };
161
162 // Postprocessing of schedule op
163 // Replace the init and update's expression by scan's buffer.
164 class SchedulePostProc : public IRMutator {
165 public:
Mutate_(const ProducerConsumer * op,const Stmt & s)166 Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
167 auto it = replace_op_.find(op->func.get());
168 if (it != replace_op_.end()) {
169 Stmt body = this->Mutate(op->body);
170 if (it->second.defined()) {
171 return ProducerConsumer::make(
172 it->second, op->is_producer, body);
173 } else {
174 return body;
175 }
176 } else {
177 return IRMutator::Mutate_(op, s);
178 }
179 }
Mutate_(const LetStmt * op,const Stmt & s)180 Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
181 if (!HasSideEffect(op->value)) {
182 var_value_[op->var.get()] = Mutate(op->value);
183 return this->Mutate(op->body);
184 } else {
185 return IRMutator::Mutate_(op, s);
186 }
187 }
188
Mutate_(const AttrStmt * op,const Stmt & s)189 Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
190 if (op->attr_key == attr::loop_scope ||
191 op->attr_key == attr::scan_init_scope) {
192 return this->Mutate(op->body);
193 } else if (op->attr_key == attr::scan_update_scope) {
194 const ScanOpNode* scan = op->node.as<ScanOpNode>();
195 CHECK(scan);
196 var_value_[scan->scan_axis->var.get()] = op->value;
197 return this->Mutate(op->body);
198 } else if (op->attr_key == attr::thread_extent) {
199 // delete duplicated thread extent attr
200 auto it = thread_extent_scope_.find(op->node.get());
201 if (it != thread_extent_scope_.end()) {
202 CHECK(is_zero(ir::Simplify(it->second - op->value)));
203 return this->Mutate(op->body);
204 } else {
205 thread_extent_scope_[op->node.get()] = op->value;
206 Stmt ret = IRMutator::Mutate_(op, s);
207 thread_extent_scope_.erase(op->node.get());
208 return ret;
209 }
210 } else if (op->attr_key == ir::attr::realize_scope ||
211 op->attr_key == ir::attr::double_buffer_scope) {
212 auto it = replace_op_.find(op->node.get());
213 if (it != replace_op_.end()) {
214 if (it->second.defined()) {
215 Stmt ret = AttrStmt::make(
216 it->second, op->attr_key, op->value, op->body);
217 return this->Mutate(ret);
218 } else {
219 return this->Mutate(op->body);
220 }
221 }
222 } else if (op->attr_key == ir::attr::buffer_bind_scope) {
223 Array<NodeRef> tuple = Downcast<Array<NodeRef> >(op->node);
224 Tensor tensor = Downcast<Tensor>(tuple[1]);
225 auto it = replace_op_.find(tensor->op.get());
226 if (it != replace_op_.end()) {
227 if (it->second.defined()) {
228 return AttrStmt::make(
229 Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
230 op->attr_key, op->value, Mutate(op->body));
231 } else {
232 return this->Mutate(op->body);
233 }
234 }
235 } else if (op->attr_key == ir::attr::buffer_dim_align) {
236 Tensor tensor = Downcast<Tensor>(op->node);
237 auto it = replace_op_.find(tensor->op.get());
238 if (it != replace_op_.end()) {
239 if (it->second.defined()) {
240 return AttrStmt::make(
241 it->second.output(tensor->value_index),
242 op->attr_key, op->value, Mutate(op->body));
243 } else {
244 return this->Mutate(op->body);
245 }
246 }
247 }
248 return IRMutator::Mutate_(op, s);
249 }
250
Mutate_(const Realize * op,const Stmt & s)251 Stmt Mutate_(const Realize* op, const Stmt& s) final {
252 TensorKey key{op->func, op->value_index};
253 auto it = replace_realize_.find(key);
254 if (it != replace_realize_.end()) {
255 if (it->second.defined()) {
256 Stmt ret = Realize::make(
257 it->second->op, it->second->value_index,
258 op->type, op->bounds, op->condition, op->body);
259 return this->Mutate(ret);
260 } else {
261 return this->Mutate(op->body);
262 }
263 } else {
264 return IRMutator::Mutate_(op, s);
265 }
266 }
267
Mutate_(const Provide * op,const Stmt & s)268 Stmt Mutate_(const Provide* op, const Stmt& s) final {
269 TensorKey key{op->func, op->value_index};
270 auto it = replace_buffer_.find(key);
271 if (it != replace_buffer_.end()) {
272 const Tensor& dst = it->second;
273 Stmt ret = Provide::make(
274 dst->op, dst->value_index, op->value, op->args);
275 return this->Mutate(ret);
276 } else {
277 return IRMutator::Mutate_(op, s);
278 }
279 }
280
Mutate_(const Call * op,const Expr & e)281 Expr Mutate_(const Call* op, const Expr& e) final {
282 if (op->call_type == Call::Halide) {
283 TensorKey key{op->func, op->value_index};
284 auto it = replace_buffer_.find(key);
285 if (it != replace_buffer_.end()) {
286 const Tensor& dst = it->second;
287 Expr ret = Call::make(
288 op->type, dst->op->name, op->args,
289 op->call_type, dst->op, dst->value_index);
290 return this->Mutate(ret);
291 }
292 }
293 return IRMutator::Mutate_(op, e);
294 }
295
Mutate_(const Variable * op,const Expr & e)296 Expr Mutate_(const Variable* op, const Expr& e) final {
297 auto it = var_value_.find(op);
298 if (it != var_value_.end()) {
299 return it->second;
300 } else {
301 return e;
302 }
303 }
304
Init(const Schedule & sch)305 void Init(const Schedule& sch) {
306 for (Stage s : sch->stages) {
307 for (auto kv : s->iter_var_attrs) {
308 // Update bind thread information.
309 if (kv.second->bind_thread.defined()) {
310 const Var& from = kv.first->var;
311 const Var& to = kv.second->bind_thread->var;
312 CHECK(!var_value_.count(from.get()));
313 var_value_[from.get()] = to;
314 }
315 }
316 // This must be checked for all ops, including scan.
317 if (!s->op.same_as(s->origin_op)) {
318 for (int i = 0; i < s->op->num_outputs(); ++i) {
319 Tensor target = s->origin_op.output(i);
320 AddReplace(s->op.output(i), target,
321 target, s->origin_op);
322 }
323 }
324 // Specially add replacements for scan op.
325 if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
326 for (size_t i = 0; i < scan->update.size(); ++i) {
327 Tensor t = s->origin_op.output(i);
328 AddReplace(scan->init[i], t);
329 AddReplace(scan->update[i], t);
330 AddReplace(scan->state_placeholder[i], t);
331 }
332 }
333 }
334 }
335
336 private:
AddReplace(Tensor src,Tensor dst,Tensor repl_realize=Tensor (),Operation repl_op=Operation ())337 void AddReplace(Tensor src,
338 Tensor dst,
339 Tensor repl_realize = Tensor(),
340 Operation repl_op = Operation()) {
341 TensorKey key{src->op, src->value_index};
342 replace_buffer_[key] = dst;
343 replace_realize_[key] = repl_realize;
344 replace_op_[src->op.get()] = repl_op;
345 }
346 // The thread extent scope.
347 std::unordered_map<const Node*, Expr> thread_extent_scope_;
348 // The scan value
349 std::unordered_map<const Variable*, Expr> var_value_;
350 // buffer replacement
351 std::unordered_map<TensorKey, Tensor> replace_buffer_;
352 // buffere realization to be replaced
353 std::unordered_map<TensorKey, Tensor> replace_realize_;
354 // replace producer consumer.
355 std::unordered_map<const Node*, Operation> replace_op_;
356 };
357
ScheduleOps(Schedule sch,Map<IterVar,Range> dom_map_,bool debug_keep_trivial_loop)358 Stmt ScheduleOps(
359 Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
360 Stmt body = Stmt();
361 std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
362 // scan init and scan updates
363 std::unordered_map<Operation, Operation> scan_init;
364 for (Stage s : sch->stages) {
365 const ScanOpNode* scan = s->op.as<ScanOpNode>();
366 if (!scan) continue;
367 for (Tensor t : scan->init) {
368 if (scan_init.count(t->op)) {
369 CHECK(scan_init.at(t->op).same_as(s->op))
370 << "Scan init tensor can only belong to one scan";
371 } else {
372 scan_init[t->op] = s->op;
373 }
374 }
375 }
376 // verify correctness of group.
377 for (Stage g : sch->groups) {
378 CHECK(!g->op.defined());
379 CHECK_EQ(g->leaf_iter_vars.size(), 0U);
380 }
381 // reverse the post DFS order.
382 for (size_t i = sch->stages.size(); i != 0; --i) {
383 Stage s = sch->stages[i - 1];
384 CHECK_NE(s->attach_type, kInline)
385 << "call schedule.normalize before scheduleops";
386 CHECK(s->op.defined());
387 // no need to specify place holder op.
388 if (s->op.as<PlaceholderOpNode>()) continue;
389 // Remove grouping sugar, get the real attach spec.
390 Stage attach_spec = s.GetAttachSpec();
391
392 if (scan_init.count(s->op)) {
393 CHECK(body.defined());
394 InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
395 body = mu.Mutate(body);
396 CHECK(mu.found_attach)
397 << "did not find attachment point for scan.init";
398 } else if (attach_spec->attach_type == kScanUpdate) {
399 // Handle scan update
400 CHECK(body.defined());
401 InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
402 body = mu.Mutate(body);
403 CHECK(mu.found_attach)
404 << "did not find attachment point for scan.update";
405 } else if (attach_spec->attach_type == kInlinedAlready) {
406 // do nothing
407 } else if (attach_spec->attach_type == kGroupRoot) {
408 CHECK(!s->group.defined());
409 body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
410 } else {
411 CHECK_EQ(attach_spec->attach_type, kScope);
412 CHECK(body.defined());
413 InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
414 body = mutator.Mutate(body);
415 CHECK(mutator.found_attach)
416 << "did not find attachment point for " << s << " in "
417 << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
418 << ", body:\n"
419 << body;
420 }
421 }
422 SchedulePostProc post_proc;
423 post_proc.Init(sch);
424 return post_proc.Mutate(body);
425 }
426
427 } // namespace schedule
428 } // namespace tvm
429