/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "message_passing.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {

// Find the first occurrence location of v in the array node; returns the array size if not found.
template<typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Node* n = v.get();
  for (size_t i = 0; i < array_node->data.size(); ++i) {
    if (array_node->data[i].get() == n) return i;
  }
  return array_node->data.size();
}

// Variable replacer used when rewriting cached stages: substitutes variables
// according to the given substitution map.
class VarReplacer : public ir::IRMutator {
 public:
  explicit VarReplacer(
      const std::unordered_map<const Variable*, Expr>& vsub)
      : vsub_(vsub) {}
  Expr Mutate_(const Variable* op, const Expr& e) {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return e;
  }

  ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
    // Replace free variables in combiner
    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
      return this->Mutate(e);
      });
    auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
      return this->Mutate(e);
      });

    if (combiner->identity_element.same_as(new_identity) &&
        combiner->result.same_as(new_result)) {
      return combiner;
    } else {
      return ir::CommReducerNode::make(
        combiner->lhs, combiner->rhs, new_result, new_identity);
    }
  }

  Expr Mutate_(const ir::Reduce* op, const Expr& e) {
    Expr new_e = IRMutator::Mutate_(op, e);
    const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
    ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
    if (op->combiner.same_as(new_combiner)) {
      return new_e;
    } else {
      return ir::Reduce::make(
        new_combiner,
        new_reduce->source,
        new_reduce->axis,
        new_reduce->condition,
        new_reduce->value_index);
    }
  }

 private:
  const std::unordered_map<const Variable*, Expr>& vsub_;
};

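// Fold the given predicates into the body expression. For a Reduce body the
// predicate is merged into the reduction condition; otherwise the body is
// wrapped in a Select that yields zero when the predicate does not hold.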
Expr InjectPredicate(const Array<Expr>& predicates,
                     Expr body) {
  using ir::Reduce;
  using ir::Select;
  if (predicates.size() == 0) return body;
  const Reduce* reduce = body.as<Reduce>();
  if (reduce) {
    auto n = make_node<Reduce>(*reduce);
    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
    return Expr(n);
  }
  return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
                      body,
                      make_zero(body.type()));
}

// Replace the dataflow that appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// A reverse map (rvmap) keeps the transitive-closure property of vmap up to date.
void ReplaceDataFlow(const Array<Stage>& stages,
                     std::unordered_map<Tensor, Tensor>* vmap,
                     std::unordered_map<Tensor, Tensor>* rvmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        auto it = rvmap->find(s->op.output(i));
        if (it != rvmap->end()) {
          (*vmap)[it->second] = op.output(i);
        } else {
          (*vmap)[s->op.output(i)] = op.output(i);
          (*rvmap)[op.output(i)] = s->op.output(i);
        }
      }
      s->op = op;
    }
  }
}

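// Two Reduce nodes are considered equal when they share the same combiner,
// source, axis, and condition; value_index is intentionally ignored.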
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  return (a->combiner.same_as(b->combiner)) &&
         (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) &&
         (a->condition.same_as(b->condition));
}

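// Create a cache stage that reads `tensor` into the given storage scope and
// redirect every reader op to the cached copy. The cache stage is inserted
// right after the stage of the original tensor.
//
// Rough usage sketch (illustrative names, assuming the C++ tensor DSL from
// tvm/operation.h and tvm/schedule.h):
//
//   Var m("m"), n("n");
//   Tensor A = placeholder({m, n}, Float(32), "A");
//   Tensor B = compute({m, n},
//       [&](Var i, Var j) { return A(i, j) * 2; }, "B");
//   Schedule s = create_schedule({B->op});
//   Tensor AA = s.cache_read(A, "shared", {B->op});  // B now reads from AA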
Tensor Schedule::cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  std::unordered_map<Tensor, Tensor> vsub;
  Stage s = operator[](tensor->op);
  Tensor sugar_tensor = s->op.output(tensor->value_index);
  Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
      return sugar_tensor(Array<Expr>(i.begin(), i.end()));
    }, os.str());
  vsub[sugar_tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    CHECK(!repl_op.same_as(s->op))
        << "Cannot find " << tensor
        << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    rvmap[repl_op.output(0)] = s->op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages, op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos + 1,
                      cache_stage);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}

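// Prepare the axis mapping between the original stage and its cache stage:
// collect the reduction axes, create the new ".c"-suffixed axes, build the
// domain map of the original leaf iteration, the variable substitutions in
// both directions, and the bound-check predicates.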
template<typename OpType>
void PrepareAxisMapping(Stage orig_stage,
                        OpType* op,
                        std::unordered_set<IterVar>* p_red_axis,
                        Array<IterVar>* p_new_axis,
                        std::unordered_map<IterVar, Range>* p_dom_map,
                        std::unordered_map<const Variable*, Expr>* p_vsub,
                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
                        std::vector<Expr>* p_predicates) {
  auto& red_axis = *p_red_axis;
  auto& new_axis = *p_new_axis;
  auto& dom_map = *p_dom_map;
  auto& vsub = *p_vsub;
  auto& vsub2newvar = *p_vsub2newvar;
  auto& predicates = *p_predicates;
  arith::Analyzer analyzer;

  for (IterVar iv : op->reduce_axis) {
    red_axis.insert(iv);
  }
  for (IterVar iv : op->axis) {
    dom_map[iv] = iv->dom;
    analyzer.Bind(iv->var, iv->dom);
  }
  schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
  {
    // The source->cache
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      CHECK_EQ(iv->iter_type, kDataPar)
          << "Can only relayout within data parallel dimensions";
      Range dom = dom_map.at(iv);
      IterVar new_iv = IterVarNode::make(
          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
      new_axis.push_back(new_iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
        vsub2newvar[iv->var.get()] = new_iv->var;
      }
    }
    // skip reduction iteration.
    std::unordered_set<IterVar> skip_bound_check;
    for (IterVar iv : op->reduce_axis) {
      skip_bound_check.insert(iv);
    }
    schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
    predicates = schedule::MakeBoundCheck(
        orig_stage, dom_map, value_map, true, skip_bound_check);
    // The root axis
    for (IterVar iv : op->axis) {
      if (value_map.count(iv)) {
        vsub[iv->var.get()] = value_map.at(iv);
      }  // to handle tensor axis
    }
  }
}

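// Swap the original op for its cache-write pair: register the dataflow
// replacement, rewrite the original stage to read from the cache, and insert
// the new cache stage right before it in the stage list.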
Array<Tensor> ReplaceOriginalOp(Schedule sch,
                                Stage orig_stage,
                                const std::string& scope,
                                Operation cache_op,
                                Operation orig_new_op,
                                size_t tensor_size) {
  Array<Tensor> cache_tensor_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_tensor_list.push_back(cache_tensor);
  }
  // Replace the dataflow.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
  for (size_t i = 0; i < tensor_size; i++) {
    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
  }
  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
  ArrayNode* stages = sch->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos,
                      cache_stage);
  sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor_list;
}


// Cache write and relayout the data according to the loop pattern.
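// The body of the original op moves into a new "name.scope" cache stage whose
// axes follow the current leaf iteration order of the original stage; the
// original op is rewritten to simply copy from the cache.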
Array<Tensor> CacheWriteWithReLayout(Schedule sch,
                                     const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const Variable*, Expr> vsub;
  std::unordered_map<const Variable*, Expr> vsub2newvar;
  std::vector<Expr> predicates;

  PrepareAxisMapping(orig_stage, compute,
    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);

  Expr body;
  Array<Expr> body_list;
  const ir::Reduce* first_reduce = nullptr;
  for (auto cbody : compute->body) {
    body = VarReplacer(vsub).Mutate(cbody);
    body = InjectPredicate(predicates, body);
    body = VarReplacer(vsub2newvar).Mutate(body);
    // Reduce nodes in ONE compute op must be the same except for value_index.
    // This is correct only if the original body ensures the Reduce nodes are the same.
    if (body->IsInstance<ir::Reduce>()) {
      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
      if (first_reduce != nullptr) {
        CHECK(ReduceEqual(reduce_body, first_reduce));
        body = ir::Reduce::make(first_reduce->combiner,
                                first_reduce->source,
                                first_reduce->axis,
                                first_reduce->condition,
                                reduce_body->value_index);
      } else {
        first_reduce = reduce_body;
      }
    } else {
      CHECK(first_reduce == nullptr)
        << "cannot mix reduce and other nodes in ONE compute body";
    }
    body_list.push_back(body);
  }
  // The reader args
  Array<Expr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : compute->axis) {
      value_map[iv] = iv->var;
    }
    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
  }
  Operation cache_op = ComputeOpNode::make(
      compute->name + "." + scope, compute->tag, compute->attrs,
      new_axis, body_list);

  Array<Expr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op = ComputeOpNode::make(
      compute->name, compute->tag, compute->attrs,
      compute->axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope,
    cache_op, orig_new_op, tensor_size);
}


// Cache write for tensor compute ops.
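// Same idea as CacheWriteWithReLayout, but the cache stage remains a
// TensorComputeOp (with rewritten input regions and scalar inputs) while the
// replacement stage that copies from it is a plain ComputeOp.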
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
                                           const Array<Tensor>& tensor_array,
                                           const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
  CHECK_EQ(tensor_op->num_outputs(), 1)
      << "cache_write only supports a single-output tensor_compute_op";

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const Variable*, Expr> vsub;
  std::unordered_map<const Variable*, Expr> vsub2newvar;
  std::vector<Expr> predicates;

  PrepareAxisMapping(orig_stage, tensor_op,
    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);


  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar new_iv = IterVarNode::make(
      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
  }
  Array<Region> new_regions;
  for (Region old_region : tensor_op->input_regions) {
    Region region;
    for (Range r : old_region) {
      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
      region.push_back(Range::make_by_min_extent(min, extent));
    }
    new_regions.push_back(region);
  }

  Array<Expr> new_scalar_inputs;
  for (Expr old_input : tensor_op->scalar_inputs) {
    new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
  }

  Operation cache_op = TensorComputeOpNode::make(
      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
      tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);

  // axis will be used in generating compute op
  Array<IterVar> compute_axis = tensor_op->axis;
  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
    compute_axis.Set(i, aiv);
  }

  // The reader args
  Array<Expr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : compute_axis) {
      value_map[iv] = iv->var;
    }
    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
    // tensorized region axis
    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
      IterVar iv = compute_axis[i];
      args.push_back(value_map.at(iv));
    }
  }

  Array<Expr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op = ComputeOpNode::make(
      tensor_op->name, tensor_op->tag, {},
      compute_axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope,
    cache_op, orig_new_op, tensor_size);
}


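// Create a cache-write stage for all outputs of one compute op. The tensors
// in tensor_array must all be produced by the same ComputeOp.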
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                             const std::string& scope) {
  (*this)->InvalidateCache();
  CHECK(tensor_array.size() > 0)
      << "size of tensor_array must be greater than 0";
  Tensor tensor = tensor_array[0];
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of input tensor list must be the same as the number of stage outputs";
  for (size_t i = 1; i < tensor_array.size(); i++) {
    Stage tmp_stage = operator[](tensor_array[i]->op);
    CHECK(orig_stage.same_as(tmp_stage))
        << "Input tensor list must be generated by ONE computeOp";
  }
  return CacheWriteWithReLayout(*this, tensor_array, scope);
}


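// Single-tensor cache_write: dispatches to the compute-op or tensor-compute-op
// implementation depending on the writer's op type.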
Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  // supports both plain compute ops and tensor compute ops
  (*this)->InvalidateCache();
  if (tensor->op.as<ComputeOpNode>()) {
    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
  } else if (tensor->op.as<TensorComputeOpNode>()) {
    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
  } else {
    LOG(FATAL) << "cache_write only takes ComputeOp or TensorComputeOp as writers";
    return Tensor();
  }
}


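// Rebase every unbound leaf iteration variable onto a fresh IterVar so that
// downstream passes can assume loops start from zero; thread-bound axes are
// left untouched and attach points are remapped to the rebased variables.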
void RebaseNonZeroMinLoop(const Schedule& sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it  = s->iter_var_attrs.find(iv);
      // no need to rebase paths that are bound to threads.
      if (it != s->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->data.size()) {
        // insert rebase
        IterVar rebased = IterVarNode::make(
            Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(RebaseNode::make(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->data[idx] = rebased;
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

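// Inline every stage marked kInline into its consumers: substitute the inlined
// body into downstream compute and hybrid ops (handling multi-output reductions
// specially), then rewrite the dataflow of the remaining stages.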
void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();

  std::vector<Array<Expr> > new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  std::vector<Stmt> new_hybrid_body(sch->stages.size());
  std::vector<bool> hybrid_changed(sch->stages.size(), false);
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      Expr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        CHECK(compute)
            << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        CHECK_EQ(compute->body.size(), 1U)
            << "can only inline compute op with 1 output";
        body = compute->body[0];
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = compute->body;
          }
          if (new_body[j][0]->IsInstance<ir::Reduce>()) {
            // specially handle reduction inline for multiple reductions.
            const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
            for (size_t k = 1; k < new_body[j].size(); ++k) {
              const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
              CHECK(reduce_);
              CHECK(ReduceEqual(reduce_, reduce))
                  << "The Reduce inputs of ComputeOp should "
                  << "have the same attributes except for value_index";
            }
            Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
                                        stage->op, args, body).as<ir::Evaluate>()->value;
            if (!new_value.same_as(new_body[j][0])) {
              changed[j] = true;
              const ir::Reduce* r = new_value.as<ir::Reduce>();
              CHECK(r != nullptr);
              CHECK_EQ(new_body[j].size(), r->source.size());
              for (size_t k = 0; k < new_body[j].size(); ++k) {
                auto n = make_node<ir::Reduce>(*r);
                n->value_index = static_cast<int>(k);
                n->type = r->source[k].type();
                new_body[j].Set(k, Expr(n));
              }
            }
          } else {
            for (size_t k = 0; k < new_body[j].size(); ++k) {
              Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
                                          stage->op, args, body).as<ir::Evaluate>()->value;
              if (!new_value.same_as(new_body[j][k])) {
                new_body[j].Set(k, new_value);
                changed[j] = true;
              }
            }
          }
        } else if (hybrid) {
          if (!new_hybrid_body[j].defined()) {
            new_hybrid_body[j] = hybrid->body;
          }
          Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
          if (!new_stmt.same_as(new_hybrid_body[j])) {
            new_hybrid_body[j] = new_stmt;
            hybrid_changed[j] = true;
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logic from ReplaceDataFlow
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      CHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOpNode::make(
            compute->name, compute->tag, compute->attrs,
            compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else if (hybrid_changed[i]) {
      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
      CHECK(hybrid);
      Operation op = HybridOpNode::make(
              hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
              hybrid->outputs, new_hybrid_body[i]);
      op = op->ReplaceInputs(op, repl);
      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
        repl[s->op.output(idx)] = op.output(idx);
      }
      s->op = op;
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}

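// Normalization copies the schedule, applies all pending inline requests, and
// rebases loops so that every leaf iteration starts from zero.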
Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->());
  RebaseNonZeroMinLoop(sn);
  return sn;
}

// Handle reduction factorization (rfactor).
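// rfactor splits the reduction of `tensor` along `axis`: a new ".rf" stage
// computes partial results with `axis` turned into a data-parallel dimension
// placed at `factor_axis`, and the original stage is rewritten to reduce the
// partial results over a fresh ".v" reduction axis.
//
// Rough usage sketch (illustrative names, assuming the C++ DSL from
// tvm/operation.h and tvm/schedule.h):
//
//   Var n("n"), m("m");
//   Tensor A = placeholder({n, m}, Float(32), "A");
//   IterVar k = reduce_axis(Range(0, m), "k");
//   Tensor B = compute({n}, [&](Var i) { return sum(A(i, k), {k}); }, "B");
//   Schedule s = create_schedule({B->op});
//   IterVar ko, ki;
//   s[B].split(k, 16, &ko, &ki);
//   Array<Tensor> BF = s.rfactor(B, ki);  // BF holds partial sums over ko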
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
                                const IterVar& axis,
                                int factor_axis) {
  (*this)->InvalidateCache();
  using ir::Reduce;
  CHECK_EQ(axis->iter_type, kCommReduce)
      << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  CHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    CHECK_NE(axis_pos, leaf_vars->data.size())
        << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // skip reduction iteration.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify that normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv))
        << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // get analyzer.
  arith::Analyzer analyzer;
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, Expr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
    analyzer.Bind(iv->var, iv->dom);
  }
  schedule::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<Expr> predicates = schedule::MakeBoundCheck(
      reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  CHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_node<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    auto iv_node = make_node<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    CHECK(is_zero(iv_node->dom->min))
        << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;

    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(IterVar(iv_node));
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(IterVar(iv_node));
    }
  }
  // predicate generation, copy untouched axes.
  int idx = tensor->value_index;
  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);
  Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));

  std::unordered_map<const Variable*, Expr> vsub;

  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      CHECK(value_map.count(iv));
      Expr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axes.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      CHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = make_node<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  VarReplacer replacer(vsub);
  Array<Expr> new_source = ir::UpdateArray(reduce->source,
    [&replacer] (const Expr& e) { return replacer.Mutate(e); });

  Expr new_pred = replacer.Mutate(predicate);

  std::vector<Expr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce::make(reduce->combiner,
                                   new_source,
                                   n->reduce_axis,
                                   new_pred,
                                   idx));
  }
  n->body = Array<Expr>(body);
  // refresh relations, keep the untouched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t stage_pos = FindNodeRef(stages, reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  CHECK_LT(stage_pos, stages->data.size());
  stages->data.insert(stages->data.begin() + stage_pos,
                      factor_stage);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(
      dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
    [&](const Array<Var>& i) {
      Array<Expr> indices;
      const int idx_size = static_cast<int>(i.size());
      for (int idx = 0; idx < idx_size; ++idx) {
        if (factor_axis_pos == idx) {
          indices.push_back(repl_red_axis->var);
        }
        indices.push_back(i[idx]);
      }
      if (factor_axis_pos == idx_size) {
          indices.push_back(repl_red_axis->var);
      }
      Array<Expr> factor_exprs;
      for (int idx = 0; idx < size; ++idx) {
        factor_exprs.push_back(factor_tensors[idx](indices));
      }
      Array<Expr> reductions;
      Array<IterVar> axis = {repl_red_axis};
      Expr cond = const_true();
      for (int idx = 0; idx < size; ++idx) {
        reductions.push_back(Reduce::make(reduce->combiner,
          factor_exprs, axis, cond, idx));
      }
      return reductions;
    }, reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}

}  // namespace tvm