/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "message_passing.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {

// Find the first occurrence of a node in the leaf array.
template<typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Node* n = v.get();
  for (size_t i = 0; i < array_node->data.size(); ++i) {
    if (array_node->data[i].get() == n) return i;
  }
  return array_node->data.size();
}

// Variable replacer used by the cache stages: substitutes variables
// in an expression according to the given variable -> expression map.
class VarReplacer : public ir::IRMutator {
 public:
  explicit VarReplacer(
      const std::unordered_map<const Variable*, Expr>& vsub)
      : vsub_(vsub) {}
  Expr Mutate_(const Variable* op, const Expr& e) {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return e;
  }

  ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
    // Replace free variables in the combiner.
    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
        return this->Mutate(e);
      });
    auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
        return this->Mutate(e);
      });

    if (combiner->identity_element.same_as(new_identity) &&
        combiner->result.same_as(new_result)) {
      return combiner;
    } else {
      return ir::CommReducerNode::make(
        combiner->lhs, combiner->rhs, new_result, new_identity);
    }
  }

  Expr Mutate_(const ir::Reduce* op, const Expr& e) {
    Expr new_e = IRMutator::Mutate_(op, e);
    const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
    ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
    if (op->combiner.same_as(new_combiner)) {
      return new_e;
    } else {
      return ir::Reduce::make(
        new_combiner,
        new_reduce->source,
        new_reduce->axis,
        new_reduce->condition,
        new_reduce->value_index);
    }
  }

 private:
  const std::unordered_map<const Variable*, Expr>& vsub_;
};

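// Inject the given predicates into a body expression.
// For a Reduce node the predicates are and-ed into its condition; otherwise
// the body is wrapped in a Select that yields zero when the predicates fail.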
Expr InjectPredicate(const Array<Expr>& predicates,
                     Expr body) {
  using ir::Reduce;
  using ir::Select;
  if (predicates.size() == 0) return body;
  const Reduce* reduce = body.as<Reduce>();
  if (reduce) {
    auto n = make_node<Reduce>(*reduce);
    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
    return Expr(n);
  }
  return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
                      body,
                      make_zero(body.type()));
}

// Replace the dataflow in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// The vmap is kept transitively closed (up to date) by maintaining a reverse map.
void ReplaceDataFlow(const Array<Stage>& stages,
                     std::unordered_map<Tensor, Tensor>* vmap,
                     std::unordered_map<Tensor, Tensor>* rvmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        auto it = rvmap->find(s->op.output(i));
        if (it != rvmap->end()) {
          (*vmap)[it->second] = op.output(i);
        } else {
          (*vmap)[s->op.output(i)] = op.output(i);
          (*rvmap)[op.output(i)] = s->op.output(i);
        }
      }
      s->op = op;
    }
  }
}

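// Check whether two Reduce nodes are identical except for their value_index.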
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  return (a->combiner.same_as(b->combiner)) &&
         (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) &&
         (a->condition.same_as(b->condition));
}

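// Create a stage that reads `tensor` into the given storage scope and make the
// listed reader operations consume the cache instead. The cache stage is
// inserted right after the stage that produces `tensor`.
// A minimal usage sketch (illustrative names only):
//   Tensor AA = sch.cache_read(A, "shared", {B->op});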
Tensor Schedule::cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  std::unordered_map<Tensor, Tensor> vsub;
  Stage s = operator[](tensor->op);
  Tensor sugar_tensor = s->op.output(tensor->value_index);
  Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
      return sugar_tensor(Array<Expr>(i.begin(), i.end()));
    }, os.str());
  vsub[sugar_tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    CHECK(!repl_op.same_as(s->op))
        << "Cannot find " << tensor
        << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    rvmap[repl_op.output(0)] = s->op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages, op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos + 1,
                      cache_stage);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}

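// Prepare the axis mapping shared by the cache_write implementations:
// collect reduction axes, create the new (cache) axes, and fill the domain
// map, variable substitutions, and bound-check predicates derived from the
// original stage's leaf iteration variables.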
template<typename OpType>
void PrepareAxisMapping(Stage orig_stage,
                        OpType* op,
                        std::unordered_set<IterVar>* p_red_axis,
                        Array<IterVar>* p_new_axis,
                        std::unordered_map<IterVar, Range>* p_dom_map,
                        std::unordered_map<const Variable*, Expr>* p_vsub,
                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
                        std::vector<Expr>* p_predicates) {
  auto& red_axis = *p_red_axis;
  auto& new_axis = *p_new_axis;
  auto& dom_map = *p_dom_map;
  auto& vsub = *p_vsub;
  auto& vsub2newvar = *p_vsub2newvar;
  auto& predicates = *p_predicates;
  arith::Analyzer analyzer;

  for (IterVar iv : op->reduce_axis) {
    red_axis.insert(iv);
  }
  for (IterVar iv : op->axis) {
    dom_map[iv] = iv->dom;
    analyzer.Bind(iv->var, iv->dom);
  }
  schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
  {
    // The source->cache
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      CHECK_EQ(iv->iter_type, kDataPar)
          << "Can only relayout within data parallel dimensions";
      Range dom = dom_map.at(iv);
      IterVar new_iv = IterVarNode::make(
          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
      new_axis.push_back(new_iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
        vsub2newvar[iv->var.get()] = new_iv->var;
      }
    }
    // skip reduction iteration.
    std::unordered_set<IterVar> skip_bound_check;
    for (IterVar iv : op->reduce_axis) {
      skip_bound_check.insert(iv);
    }
    schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
    predicates = schedule::MakeBoundCheck(
        orig_stage, dom_map, value_map, true, skip_bound_check);
    // The root axis
    for (IterVar iv : op->axis) {
      if (value_map.count(iv)) {
        vsub[iv->var.get()] = value_map.at(iv);
      }  // to handle tensor axis
    }
  }
}

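// Swap the original stage's operation for the rewritten one, rewire the
// dataflow of all stages, and insert the new cache stage right before the
// original stage. Returns the cache output tensors.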
Array<Tensor> ReplaceOriginalOp(Schedule sch,
                                Stage orig_stage,
                                const std::string& scope,
                                Operation cache_op,
                                Operation orig_new_op,
                                size_t tensor_size) {
  Array<Tensor> cache_tensor_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_tensor_list.push_back(cache_tensor);
  }
  // Replace the dataflow: map every output of the original op to the
  // corresponding output of the rewritten op.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (size_t i = 0; i < tensor_size; i++) {
    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
  }
  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
  ArrayNode* stages = sch->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos,
                      cache_stage);
  sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor_list;
}


// Cache write and relayout the data according to loop pattern
Array<Tensor> CacheWriteWithReLayout(Schedule sch,
                                     const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const Variable*, Expr> vsub;
  std::unordered_map<const Variable*, Expr> vsub2newvar;
  std::vector<Expr> predicates;

  PrepareAxisMapping(orig_stage, compute,
    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);

  Expr body;
  Array<Expr> body_list;
  const ir::Reduce* first_reduce = nullptr;
  for (auto cbody : compute->body) {
    body = VarReplacer(vsub).Mutate(cbody);
    body = InjectPredicate(predicates, body);
    body = VarReplacer(vsub2newvar).Mutate(body);
    // All Reduce nodes in ONE ComputeOp must be the same except for value_index.
    // This holds only if the original body already keeps its Reduce nodes identical.
    if (body->IsInstance<ir::Reduce>()) {
      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
      if (first_reduce != nullptr) {
        CHECK(ReduceEqual(reduce_body, first_reduce));
        body = ir::Reduce::make(first_reduce->combiner,
                                first_reduce->source,
                                first_reduce->axis,
                                first_reduce->condition,
                                reduce_body->value_index);
      } else {
        first_reduce = reduce_body;
      }
    } else {
      CHECK(first_reduce == nullptr)
          << "cannot mix reduce and other nodes in ONE compute body";
    }
    body_list.push_back(body);
  }
  // The reader args
  Array<Expr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : compute->axis) {
      value_map[iv] = iv->var;
    }
    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
  }
  Operation cache_op = ComputeOpNode::make(
      compute->name + "." + scope, compute->tag, compute->attrs,
      new_axis, body_list);

  Array<Expr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op = ComputeOpNode::make(
      compute->name, compute->tag, compute->attrs,
      compute->axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope,
    cache_op, orig_new_op, tensor_size);
}


// Cache write for TensorComputeOp: relayout according to the loop pattern.
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
                                           const Array<Tensor>& tensor_array,
                                           const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
  CHECK_EQ(tensor_op->num_outputs(), 1)
      << "cache write only supports single-output tensor_compute_op";

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const Variable*, Expr> vsub;
  std::unordered_map<const Variable*, Expr> vsub2newvar;
  std::vector<Expr> predicates;

  PrepareAxisMapping(orig_stage, tensor_op,
    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);

  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar new_iv = IterVarNode::make(
        iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
  }
  Array<Region> new_regions;
  for (Region old_region : tensor_op->input_regions) {
    Region region;
    for (Range r : old_region) {
      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
      region.push_back(Range::make_by_min_extent(min, extent));
    }
    new_regions.push_back(region);
  }

  Array<Expr> new_scalar_inputs;
  for (Expr old_input : tensor_op->scalar_inputs) {
    new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
  }

  Operation cache_op = TensorComputeOpNode::make(
      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
      tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);

  // axis will be used in generating compute op
  Array<IterVar> compute_axis = tensor_op->axis;
  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
    compute_axis.Set(i, aiv);
  }

  // The reader args
  Array<Expr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : compute_axis) {
      value_map[iv] = iv->var;
    }
    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
    // tensorized region axis
    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
      IterVar iv = compute_axis[i];
      args.push_back(value_map.at(iv));
    }
  }

  Array<Expr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op = ComputeOpNode::make(
      tensor_op->name, tensor_op->tag, {},
      compute_axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope,
    cache_op, orig_new_op, tensor_size);
}


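// Write the outputs of one ComputeOp through a cache stage in the given
// storage scope. All tensors in tensor_array must be outputs of the same
// ComputeOp stage.
// A minimal usage sketch (illustrative names, multi-output case):
//   Array<Tensor> cached = sch.cache_write({B0, B1}, "local");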
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                                    const std::string& scope) {
  (*this)->InvalidateCache();
  CHECK(tensor_array.size() > 0)
      << "size of tensor_array must be greater than 0";
  Tensor tensor = tensor_array[0];
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of input tensor list must be same as number of stage outputs";
  for (size_t i = 1; i < tensor_array.size(); i++) {
    Stage tmp_stage = operator[](tensor_array[i]->op);
    CHECK(orig_stage.same_as(tmp_stage))
        << "Input tensor list must be generated by ONE computeOp";
  }
  return CacheWriteWithReLayout(*this, tensor_array, scope);
}


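// Single-tensor overload of cache_write.
// A minimal usage sketch (illustrative names):
//   Tensor BL = sch.cache_write(B, "local");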
Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  // supports both ComputeOp and TensorComputeOp writers
  (*this)->InvalidateCache();
  if (tensor->op.as<ComputeOpNode>()) {
    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
  } else if (tensor->op.as<TensorComputeOpNode>()) {
    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
  } else {
    LOG(FATAL) << "cache write only takes ComputeOp or TensorComputeOp as writers";
    return Tensor();
  }
}


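// Rebase loop iteration variables so that every loop starts from zero.
// Thread-bound axes are skipped, and stage/group attach points are remapped
// to the rebased iteration variables.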
void RebaseNonZeroMinLoop(const Schedule& sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it = s->iter_var_attrs.find(iv);
      // no need to rebase axes that are bound to threads.
      if (it != s->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->data.size()) {
        // insert rebase
        IterVar rebased = IterVarNode::make(
            Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(RebaseNode::make(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->data[idx] = rebased;
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

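// Inline every stage marked as kInline into its downstream consumers,
// substituting the inlined body into compute and hybrid op bodies, then
// rewrite the dataflow so later stages refer to the updated operations.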
void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();

  std::vector<Array<Expr> > new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  std::vector<Stmt> new_hybrid_body(sch->stages.size());
  std::vector<bool> hybrid_changed(sch->stages.size(), false);
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      Expr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        CHECK(compute)
            << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        CHECK_EQ(compute->body.size(), 1U)
            << "can only inline compute op with 1 output";
        body = compute->body[0];
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = compute->body;
          }
          if (new_body[j][0]->IsInstance<ir::Reduce>()) {
            // specially handle reduction inline for multiple reductions.
            const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
            for (size_t k = 1; k < new_body[j].size(); ++k) {
              const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
              CHECK(reduce_);
              CHECK(ReduceEqual(reduce_, reduce))
                  << "The Reduce inputs of ComputeOp should "
                  << "have the same attribute except value_index";
            }
            Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
                                        stage->op, args, body).as<ir::Evaluate>()->value;
            if (!new_value.same_as(new_body[j][0])) {
              changed[j] = true;
              const ir::Reduce* r = new_value.as<ir::Reduce>();
              CHECK(r != nullptr);
              CHECK_EQ(new_body[j].size(), r->source.size());
              for (size_t k = 0; k < new_body[j].size(); ++k) {
                auto n = make_node<ir::Reduce>(*r);
                n->value_index = static_cast<int>(k);
                n->type = r->source[k].type();
                new_body[j].Set(k, Expr(n));
              }
            }
          } else {
            for (size_t k = 0; k < new_body[j].size(); ++k) {
              Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
                                          stage->op, args, body).as<ir::Evaluate>()->value;
              if (!new_value.same_as(new_body[j][k])) {
                new_body[j].Set(k, new_value);
                changed[j] = true;
              }
            }
          }
        } else if (hybrid) {
          if (!new_hybrid_body[j].defined()) {
            new_hybrid_body[j] = hybrid->body;
          }
          Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
          if (!new_stmt.same_as(new_hybrid_body[j])) {
            new_hybrid_body[j] = new_stmt;
            hybrid_changed[j] = true;
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logic from ReplaceDataFlow
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      CHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOpNode::make(
            compute->name, compute->tag, compute->attrs,
            compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else if (hybrid_changed[i]) {
      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
      CHECK(hybrid);
      Operation op = HybridOpNode::make(
          hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
          hybrid->outputs, new_hybrid_body[i]);
      op = op->ReplaceInputs(op, repl);
      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
        repl[s->op.output(idx)] = op.output(idx);
      }
      s->op = op;
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}

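// Normalize the schedule before lowering: apply all pending inlines and
// rebase loops so that every iteration domain starts at zero.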
Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->());
  RebaseNonZeroMinLoop(sn);
  return sn;
}

// Handle reduction factor.
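// A minimal rfactor usage sketch (illustrative names; `ko` is a reduction
// IterVar, e.g. obtained from splitting a reduction axis):
//   Array<Tensor> BF = sch.rfactor(B, ko, 0);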
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
                                const IterVar& axis,
                                int factor_axis) {
  (*this)->InvalidateCache();
  using ir::Reduce;
  CHECK_EQ(axis->iter_type, kCommReduce)
      << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  CHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    CHECK_NE(axis_pos, leaf_vars->data.size())
        << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // skip reduction iteration.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify that normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv))
        << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // get analyzer.
  arith::Analyzer analyzer;
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, Expr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
    analyzer.Bind(iv->var, iv->dom);
  }
  schedule::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<Expr> predicates = schedule::MakeBoundCheck(
      reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  CHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_node<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    auto iv_node = make_node<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    CHECK(is_zero(iv_node->dom->min))
        << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;

    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(IterVar(iv_node));
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(IterVar(iv_node));
    }
  }
  // predicate generation, copy the untouched axes.
  int idx = tensor->value_index;
  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);
  Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));

  std::unordered_map<const Variable*, Expr> vsub;

  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      CHECK(value_map.count(iv));
      Expr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axis.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      CHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = make_node<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  VarReplacer replacer(vsub);
  Array<Expr> new_source = ir::UpdateArray(reduce->source,
    [&replacer] (const Expr& e) { return replacer.Mutate(e); });

  Expr new_pred = replacer.Mutate(predicate);

  std::vector<Expr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce::make(reduce->combiner,
                                   new_source,
                                   n->reduce_axis,
                                   new_pred,
                                   idx));
  }
  n->body = Array<Expr>(body);
  // refresh relations, keep the un-touched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t stage_pos = FindNodeRef(stages, reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  CHECK_LT(stage_pos, stages->data.size());
  stages->data.insert(stages->data.begin() + stage_pos,
                      factor_stage);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(
      dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
    [&](const Array<Var>& i) {
      Array<Expr> indices;
      const int idx_size = static_cast<int>(i.size());
      for (int idx = 0; idx < idx_size; ++idx) {
        if (factor_axis_pos == idx) {
          indices.push_back(repl_red_axis->var);
        }
        indices.push_back(i[idx]);
      }
      if (factor_axis_pos == idx_size) {
        indices.push_back(repl_red_axis->var);
      }
      Array<Expr> factor_exprs;
      for (int idx = 0; idx < size; ++idx) {
        factor_exprs.push_back(factor_tensors[idx](indices));
      }
      Array<Expr> reductions;
      Array<IterVar> axis = {repl_red_axis};
      Expr cond = const_true();
      for (int idx = 0; idx < size; ++idx) {
        reductions.push_back(Reduce::make(reduce->combiner,
          factor_exprs, axis, cond, idx));
      }
      return reductions;
    }, reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}

}  // namespace tvm