/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_set>

#include "../../tir/transforms/ir_util.h"
#include "message_passing.h"
#include "operation_inline.h"

namespace tvm {
namespace te {
// Find the first occurrence location of v in the leaf array.
template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Object* n = v.get();
  for (size_t i = 0; i < array_node->size(); ++i) {
    if (array_node->at(i).get() == n) return i;
  }
  return array_node->size();
}

// Variable replacer used when creating cache stages:
// substitutes variables according to the given var->expr map.
class VarReplacer : public tir::StmtExprMutator {
 public:
  explicit VarReplacer(const std::unordered_map<const VarNode*, PrimExpr>& vsub) : vsub_(vsub) {}
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return GetRef<PrimExpr>(op);
  }

  tir::CommReducer MutateCommReducer(tir::CommReducer combiner) {
    // Replace free variables in combiner
    auto new_identity = tir::UpdateArray(combiner->identity_element,
                                         [this](const PrimExpr& e) { return this->VisitExpr(e); });
    auto new_result = tir::UpdateArray(combiner->result,
                                       [this](const PrimExpr& e) { return this->VisitExpr(e); });

    if (combiner->identity_element.same_as(new_identity) &&
        combiner->result.same_as(new_result)) {
      return combiner;
    } else {
      return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity);
    }
  }

  PrimExpr VisitExpr_(const tir::ReduceNode* op) final {
    PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
    const tir::ReduceNode* new_reduce = new_e.as<tir::ReduceNode>();
    tir::CommReducer new_combiner = MutateCommReducer(op->combiner);
    if (op->combiner.same_as(new_combiner)) {
      return new_e;
    } else {
      return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition,
                         new_reduce->value_index, new_reduce->init);
    }
  }

 private:
  const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
};

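// Conjoin the given predicates into the body: for a Reduce body the predicates
// are folded into the reduction condition; otherwise the body is wrapped in a
// Select that yields zero when the predicate does not hold.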
PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
  using tir::ReduceNode;
  using tir::SelectNode;
  if (predicates.size() == 0) return body;
  const ReduceNode* reduce = body.as<ReduceNode>();
  auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };

  if (reduce) {
    auto n = make_object<ReduceNode>(*reduce);
    n->condition = foldl(fand, n->condition, predicates);
    return PrimExpr(n);
  }
  return Select(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype()));
}

// Replace the dataflow that appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// A reverse map keeps the transitive-closure property of vmap up to date.
void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tensor>* vmap,
                     std::unordered_map<Tensor, Tensor>* rvmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        auto it = rvmap->find(s->op.output(i));
        if (it != rvmap->end()) {
          (*vmap)[it->second] = op.output(i);
        } else {
          (*vmap)[s->op.output(i)] = op.output(i);
          (*rvmap)[op.output(i)] = s->op.output(i);
        }
      }
      s->op = op;
    }
  }
}

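// Whether two ReduceNodes are structurally identical except possibly for value_index.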
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
  return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}

Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  std::unordered_map<Tensor, Tensor> vsub;
  Stage s = operator[](tensor->op);
  Tensor sugar_tensor = s->op.output(tensor->value_index);
  Tensor cache = compute(
      sugar_tensor->shape,
      [&sugar_tensor](const Array<Var>& i) {
        return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
      },
      os.str());
  vsub[sugar_tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    rvmap[repl_op.output(0)] = s->op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  Array<Stage>& stages = (*this)->stages;
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages.size());
  stages.insert(stages.begin() + pos + 1, cache_stage);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}

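// Collect the axis mapping used by cache_write: record the reduction axes, build the
// new (cached) data-parallel axes, the iter-var domain map, the variable substitutions
// for the cached body, and the bound-check predicates.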
template <typename OpType>
void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set<IterVar>* p_red_axis,
                        Array<IterVar>* p_new_axis, std::unordered_map<IterVar, Range>* p_dom_map,
                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
                        std::vector<PrimExpr>* p_predicates) {
  auto& red_axis = *p_red_axis;
  auto& new_axis = *p_new_axis;
  auto& dom_map = *p_dom_map;
  auto& vsub = *p_vsub;
  auto& vsub2newvar = *p_vsub2newvar;
  auto& predicates = *p_predicates;
  arith::Analyzer analyzer;

  for (IterVar iv : op->reduce_axis) {
    red_axis.insert(iv);
  }
  for (IterVar iv : op->axis) {
    dom_map[iv] = iv->dom;
    analyzer.Bind(iv->var, iv->dom);
  }
  te::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
  {
    // The source->cache
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout within data parallel dimensions";
      Range dom = dom_map.at(iv);
      IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
      new_axis.push_back(new_iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
        vsub2newvar[iv->var.get()] = new_iv->var;
      }
    }
    // skip reduction iteration.
    std::unordered_set<IterVar> skip_bound_check;
    for (IterVar iv : op->reduce_axis) {
      skip_bound_check.insert(iv);
    }
    PassUpIndex(orig_stage, dom_map, &value_map, true);
    predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check);
    // The root axis
    for (IterVar iv : op->axis) {
      if (value_map.count(iv)) {
        vsub[iv->var.get()] = value_map.at(iv);
      }  // to handle tensor axis
    }
  }
}

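// Swap the original stage's op for orig_new_op, redirect the downstream dataflow
// to the new outputs, and insert the cache stage right before the original stage
// in the schedule. Returns the cache tensors.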
Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope,
                                Operation cache_op, Operation orig_new_op, size_t tensor_size) {
  Array<Tensor> cache_tensor_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_tensor_list.push_back(cache_tensor);
  }
  // Replace the dataflow: map every old output to the corresponding new output.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (size_t i = 0; i < tensor_size; i++) {
    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
  }
  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
  Array<Stage>& stages = sch->stages;
  size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages.size());
  stages.insert(stages.begin() + pos, cache_stage);
  sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor_list;
}

// Cache write and relayout the data according to loop pattern
Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const VarNode*, PrimExpr> vsub;
  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
  std::vector<PrimExpr> predicates;

  PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
                     &predicates);

  PrimExpr body;
  Array<PrimExpr> body_list;
  const tir::ReduceNode* first_reduce = nullptr;
  for (auto cbody : compute->body) {
    body = VarReplacer(vsub)(cbody);
    body = InjectPredicate(predicates, body);
    body = VarReplacer(vsub2newvar)(body);
    // Reduce nodes in ONE computeOp must be the same except for value_index.
    // This is right only if the original body ensures the Reduce nodes are the same.
    if (body->IsInstance<tir::ReduceNode>()) {
      const tir::ReduceNode* reduce_body = body.as<tir::ReduceNode>();
      if (first_reduce != nullptr) {
        CHECK(ReduceEqual(reduce_body, first_reduce));
        body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis,
                           first_reduce->condition, reduce_body->value_index, reduce_body->init);
      } else {
        first_reduce = reduce_body;
      }
    } else {
      CHECK(first_reduce == nullptr) << "cannot mix reduce and other nodes in ONE compute body";
    }
    body_list.push_back(body);
  }
  // The reader args
  Array<PrimExpr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : compute->axis) {
      value_map[iv] = iv->var;
    }
    te::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
  }
  Operation cache_op =
      ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list);

  Array<PrimExpr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op =
      ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}

// Cache write for tensor compute op.
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& tensor_array,
                                           const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
  CHECK_EQ(tensor_op->num_outputs(), 1)
      << "cache write only supports single output tensor_compute_op";

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const VarNode*, PrimExpr> vsub;
  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
  std::vector<PrimExpr> predicates;

  PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
                     &predicates);

  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
  }
  Array<Region> new_regions;
  for (Region old_region : tensor_op->input_regions) {
    Region region;
    for (Range r : old_region) {
      PrimExpr min = VarReplacer(vsub2newvar)(r->min);
      PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
      region.push_back(Range::FromMinExtent(min, extent));
    }
    new_regions.push_back(region);
  }

  Array<PrimExpr> new_scalar_inputs;
  for (PrimExpr old_input : tensor_op->scalar_inputs) {
    new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
  }

  Operation cache_op =
      TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis,
                      tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin,
                      tensor_op->inputs, new_regions, new_scalar_inputs);

  // axis will be used in generating compute op
  Array<IterVar> compute_axis = tensor_op->axis;
  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar aiv = IterVar(iv->dom, iv->var, kDataPar);
    compute_axis.Set(i, aiv);
  }

  // The reader args
  Array<PrimExpr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : compute_axis) {
      value_map[iv] = iv->var;
    }
    PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
    // tensorized region axis
    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
      IterVar iv = compute_axis[i];
      args.push_back(value_map.at(iv));
    }
  }

  Array<PrimExpr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op =
      ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}

Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) {
  (*this)->InvalidateCache();
  CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0";
  Tensor tensor = tensor_array[0];
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of input tensor list must be same as number of stage outputs";
  for (size_t i = 1; i < tensor_array.size(); i++) {
    Stage tmp_stage = operator[](tensor_array[i]->op);
    CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp";
  }
  return CacheWriteWithReLayout(*this, tensor_array, scope);
}

Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
  // support both original compute op and tensor compute op
  (*this)->InvalidateCache();
  if (tensor->op.as<ComputeOpNode>()) {
    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
  } else if (tensor->op.as<TensorComputeOpNode>()) {
    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
  } else {
    LOG(FATAL) << "cache write only takes ComputeOp or TensorComputeOp as writers";
    return Tensor();
  }
}

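// Rebase root iter vars that remain as leaf iter vars onto fresh iter vars so the
// generated loops start from zero; iter vars bound to threads are left untouched.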
void RebaseNonZeroMinLoop(ScheduleNode* sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it = s->iter_var_attrs.find(iv);
      // no need to rebase iter vars that are bound to threads.
      if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->size()) {
        // insert rebase
        IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(te::Rebase(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->SetItem(idx, rebased);
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

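// Inline every stage marked kInline into its consumers, then rewrite the dataflow
// of the remaining stages to use the updated operations.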
void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();

  std::vector<Array<PrimExpr> > new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  std::vector<Stmt> new_hybrid_body(sch->stages.size());
  std::vector<bool> hybrid_changed(sch->stages.size(), false);
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      PrimExpr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        CHECK(compute) << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
        body = compute->body[0];
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = compute->body;
          }
          if (new_body[j][0]->IsInstance<tir::ReduceNode>()) {
            // specially handle reduction inline for multiple reductions.
            const tir::ReduceNode* reduce = new_body[j][0].as<tir::ReduceNode>();
            for (size_t k = 1; k < new_body[j].size(); ++k) {
              const tir::ReduceNode* reduce_ = new_body[j][k].as<tir::ReduceNode>();
              CHECK(reduce_);
              CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
                                                  << "have the same attribute except value_index";
            }
            PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body)
                                     .as<tir::EvaluateNode>()
                                     ->value;
            if (!new_value.same_as(new_body[j][0])) {
              changed[j] = true;
              const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
              CHECK(r != nullptr);
              CHECK_EQ(new_body[j].size(), r->source.size());
              for (size_t k = 0; k < new_body[j].size(); ++k) {
                auto n = make_object<tir::ReduceNode>(*r);
                n->value_index = static_cast<int>(k);
                n->dtype = r->source[k].dtype();
                new_body[j].Set(k, PrimExpr(n));
              }
            }
          } else {
            for (size_t k = 0; k < new_body[j].size(); ++k) {
              PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body)
                                       .as<tir::EvaluateNode>()
                                       ->value;
              if (!new_value.same_as(new_body[j][k])) {
                new_body[j].Set(k, new_value);
                changed[j] = true;
              }
            }
          }
        } else if (hybrid) {
          if (!new_hybrid_body[j].defined()) {
            new_hybrid_body[j] = hybrid->body;
          }
          Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
          if (!new_stmt.same_as(new_hybrid_body[j])) {
            new_hybrid_body[j] = new_stmt;
            hybrid_changed[j] = true;
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logic from ReplaceDataFlow
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      CHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else if (hybrid_changed[i]) {
      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
      CHECK(hybrid);
      Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
                              hybrid->outputs, new_hybrid_body[i]);
      op = op->ReplaceInputs(op, repl);
      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
        repl[s->op.output(idx)] = op.output(idx);
      }
      s->op = op;
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}

void LegalizeInvalidAttach(ScheduleNode* sch) {
  // Legalize the compute_at location if the target iterator of compute_at is split or fused.
  // Case 1: If the target of compute_at is split,
  //         we will move the compute_at location to the inner iterator.
  // Case 2: If the target of compute_at is fused,
  //         we will move the compute_at location to the newly fused iterator.
  // Note that case 2 can only happen if the target of compute_at
  // is the innermost operand of fuse operation.

  // Map an old invalid attach point to its new valid attach point
  std::unordered_map<IterVar, IterVar> replace_map;

  for (Stage stage : sch->stages) {
    for (Stage s = stage; s.defined();) {
      // The following logic is similar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
      // because we follow the validation check in that function to legalize the attach.
      Stage spec = s.GetAttachSpec();
      if (spec->attach_type != kScope) {
        break;
      }
      bool start_attach = false;
      IterVar attach_ivar = spec->attach_ivar;
      s = spec->attach_stage;
      CHECK(attach_ivar.defined());
      CHECK(s.defined());

      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
          break;
        }
      }

      if (!start_attach) {
        IterVar new_attach_ivar = attach_ivar;
        bool updated = true;
        // recursively update the relations
        while (updated) {
          updated = false;
          for (const auto& rel : s->relations) {
            if (const FuseNode* r = rel.as<FuseNode>()) {
              if (new_attach_ivar.same_as(r->inner)) {
                new_attach_ivar = r->fused;
                updated = true;
              }
            } else if (const SplitNode* r = rel.as<SplitNode>()) {
              if (new_attach_ivar.same_as(r->parent)) {
                new_attach_ivar = r->inner;
                updated = true;
              }
            }
          }
          replace_map[attach_ivar] = new_attach_ivar;
        }
      }
    }
  }

  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (replace_map.count(s->attach_ivar)) {
      s->attach_ivar = replace_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (replace_map.count(s->attach_ivar)) {
      s->attach_ivar = replace_map.at(s->attach_ivar);
    }
  }
}

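// Normalize a schedule: inline stages marked for inlining, rebase loops so they
// start at zero, and legalize compute_at attach points invalidated by split/fuse.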
Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->());
  RebaseNonZeroMinLoop(sn.operator->());
  LegalizeInvalidAttach(sn.operator->());
  return sn;
}

// Handle reduction factor.
Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
  (*this)->InvalidateCache();
  using tir::ReduceNode;
  CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  CHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    CHECK_NE(axis_pos, leaf_vars->size()) << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  te::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  te::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // skip reduction iteration.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // get analyzer.
  arith::Analyzer analyzer;
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, PrimExpr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
    analyzer.Bind(iv->var, iv->dom);
  }
  te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<PrimExpr> predicates =
      MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  CHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_object<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    auto iv_node = make_object<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;

    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(IterVar(iv_node));
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(IterVar(iv_node));
    }
  }
  // predicate generation, copy the untouched axes.
  int idx = tensor->value_index;
  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);
  auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };

  PrimExpr predicate = likely(foldl(fand, const_true(1), predicates));

  std::unordered_map<const VarNode*, PrimExpr> vsub;

  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      CHECK(value_map.count(iv));
      PrimExpr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axis.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      CHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = make_object<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  VarReplacer replacer(vsub);
  Array<PrimExpr> new_source =
      tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });

  PrimExpr new_pred = replacer(predicate);

  std::vector<PrimExpr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
  }
  n->body = Array<PrimExpr>(body);
  // refresh relations, keep the un-touched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  Array<Stage>& stages = (*this)->stages;
  size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  CHECK_LT(stage_pos, stages.size());
  stages.insert(stages.begin() + stage_pos, factor_stage);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(
      old_tensors[0]->shape,
      [&](const Array<Var>& i) {
        Array<PrimExpr> indices;
        const int idx_size = static_cast<int>(i.size());
        for (int idx = 0; idx < idx_size; ++idx) {
          if (factor_axis_pos == idx) {
            indices.push_back(repl_red_axis->var);
          }
          indices.push_back(i[idx]);
        }
        Array<PrimExpr> new_init = reduce->init;
        if (!reduce->init.empty()) {
          std::unordered_map<const VarNode*, PrimExpr> init_vsub;
          for (const auto& init : reduce->init) {
            if (init->IsInstance<ProducerLoadNode>()) {
              CHECK_EQ(compute_op->axis.size(), idx_size)
                  << "'init' should have the same number of dimensions as the output when used "
                     "with rfactor";
              for (int idx = 0; idx < idx_size; idx++) {
                init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
              }
            }
          }
          VarReplacer init_replacer(init_vsub);
          new_init = tir::UpdateArray(
              reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
        }
        if (factor_axis_pos == idx_size) {
          indices.push_back(repl_red_axis->var);
        }
        Array<PrimExpr> factor_exprs;
        for (int idx = 0; idx < size; ++idx) {
          factor_exprs.push_back(factor_tensors[idx](indices));
        }
        Array<PrimExpr> reductions;
        Array<IterVar> axis = {repl_red_axis};
        PrimExpr cond = const_true();
        for (int idx = 0; idx < size; ++idx) {
          reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
        }
        return reductions;
      },
      reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}
}  // namespace te
}  // namespace tvm