1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \brief Hybrid computation rule.
22  * \file hybrid_op.cc
23  */
24 #include <tvm/operation.h>
25 #include <tvm/arithmetic.h>
26 #include <tvm/ir.h>
27 #include <tvm/ir_mutator.h>
28 #include <tvm/ir_pass.h>
29 #include <tvm/expr_operator.h>
30 #include <unordered_set>
31 #include <string>
32 #include <utility>
33 #include "op_util.h"
34 #include "hybrid_op.h"
35 
36 namespace tvm {
37 using namespace ir;
38 // HybridOpNode
// Pretty-printer hook: renders a HybridOpNode as "hybrid(<name>, <node ptr>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, IRPrinter* p) {
    auto* op = static_cast<const HybridOpNode*>(node.get());
    p->stream << "hybrid(" << op->name << ", " << op << ")";
  });

// Register HybridOpNode with the node system (reflection/serialization).
TVM_REGISTER_NODE_TYPE(HybridOpNode);
46 
num_outputs() const47 int HybridOpNode::num_outputs() const {
48   return static_cast<int>(outputs.size());
49 }
50 
root_iter_vars() const51 Array<IterVar> HybridOpNode::root_iter_vars() const {
52   return this->axis;
53 }
54 
output_dtype(size_t i) const55 Type HybridOpNode::output_dtype(size_t i) const {
56   return outputs[i]->dtype;
57 }
58 
output_shape(size_t i) const59 Array<Expr> HybridOpNode::output_shape(size_t i) const {
60   return outputs[i]->shape;
61 }
62 
63 
make(std::string name,std::string tag,Map<std::string,NodeRef> attrs,Array<Tensor> inputs,Array<Tensor> outputs,Stmt body)64 Operation HybridOpNode::make(std::string name,
65                              std::string tag,
66                              Map<std::string, NodeRef> attrs,
67                              Array<Tensor> inputs,
68                              Array<Tensor> outputs,
69                              Stmt body) {
70   if (!attrs.defined()) {
71     attrs = Map<std::string, NodeRef>();
72   }
73   auto n = make_node<HybridOpNode>();
74   n->name = std::move(name);
75   n->tag = std::move(tag);
76   n->attrs = std::move(attrs);
77   n->inputs = std::move(inputs);
78   n->outputs = std::move(outputs);
79   n->axis = op::GatherLoopVars(body);
80   n->body = std::move(body);
81   Operation res = Operation(n);
82   return res;
83 }
84 
InputTensors() const85 Array<Tensor> HybridOpNode::InputTensors() const {
86   // Because input tensors could be potentially inlined into hybrid scripts,
87   // we need to check if all input tensors are used in the body.
88   std::unordered_set<Tensor> orig_inputs;
89   for (auto t : inputs) {
90     orig_inputs.insert(t);
91   }
92   std::unordered_set<Tensor> visited;
93   Array<Tensor> curr_inputs;
94   ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
95       const ir::Call *call = n.as<ir::Call>();
96       if (call != nullptr && call->func.defined()) {
97         Tensor t = Downcast<Operation>(call->func).output(call->value_index);
98         if (orig_inputs.count(t) && !visited.count(t)) {
99           curr_inputs.push_back(t);
100           visited.insert(t);
101         }
102       }
103   });
104   return curr_inputs;
105 }
106 
// Substitute input tensors (both in the body and in the declared input list)
// according to `rmap`; return `self` unchanged if nothing was replaced.
Operation HybridOpNode::ReplaceInputs(
    const Operation &self,
    const std::unordered_map<Tensor, Tensor> &rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_node<HybridOpNode>(*this);
  n->body = op::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    auto it = rmap.find(n->inputs[i]);
    if (it != rmap.end()) {
      n->inputs.Set(i, it->second);
    }
  }

  // Only create a fresh operation when something actually changed.
  if (!body.same_as(n->body) || !inputs.same_as(n->inputs)) {
    return Operation(n);
  }
  return self;
}
127 
PropBoundToInputs(const Operation & self,arith::Analyzer * analyzer,const std::unordered_map<const Variable *,IntSet> & dom_map,std::unordered_map<Tensor,TensorDom> * out_dom_map) const128 void HybridOpNode::PropBoundToInputs(
129     const Operation &self,
130     arith::Analyzer* analyzer,
131     const std::unordered_map<const Variable*, IntSet> &dom_map,
132     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
133   auto curr_inputs = InputTensors();
134   for (Tensor t : curr_inputs) {
135     auto it = out_dom_map->find(t);
136     if (it == out_dom_map->end()) continue;
137     TensorDom &dom = it->second;
138     for (size_t i = 0; i < t->shape.size(); ++i) {
139       dom.data[i].emplace_back(IntSet::range(
140           Range::make_by_min_extent(
141               make_const(t->shape[i].type(), 0), t->shape[i])));
142     }
143   }
144 }
145 
GatherBound(const Operation & self,const std::unordered_map<Tensor,TensorDom> & tensor_dom,std::unordered_map<IterVar,Range> * out_dom_map) const146 void HybridOpNode::GatherBound(
147     const Operation &self,
148     const std::unordered_map<Tensor, TensorDom> &tensor_dom,
149     std::unordered_map<IterVar, Range>* out_dom_map) const {
150   for (auto iter_var : axis) {
151     CHECK(!out_dom_map->count(iter_var));
152     out_dom_map->operator[](iter_var) = iter_var->dom;
153   }
154 }
155 
// Wrap `body` in one Realize node per output tensor, each covering the
// output's full shape with min 0 and an always-true condition.
Stmt HybridOpNode::BuildRealize(
    const Stage &stage,
    const std::unordered_map<IterVar, Range> &realize_map,
    const Stmt &body) const {
  // TODO(@were): Add attribute inject here and remove it from hybrid parser.
  CHECK_EQ(stage->op.get(), this);
  Stmt ret = body;
  for (int i = 0; i < num_outputs(); ++i) {
    Tensor out = stage->op.output(i);
    Region bounds;
    for (const auto& dim : out->shape) {
      bounds.push_back(
          Range::make_by_min_extent(make_const(dim.type(), 0), dim));
    }
    ret = ir::Realize::make(out->op, out->value_index, out->dtype,
                            bounds, const_true(), ret);
  }
  return ret;
}
177 
BuildProvide(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,bool debug_keep_trivial_loop) const178 Stmt HybridOpNode::BuildProvide(
179     const Stage &stage,
180     const std::unordered_map<IterVar, Range> &dom_map,
181     bool debug_keep_trivial_loop) const {
182   CHECK_EQ(stage->op.operator->(), this);
183   Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
184   std::unordered_map<Tensor, Tensor> rmap;
185   for (int i = 0; i < this->num_outputs(); ++i) {
186     rmap[outputs[i]] = stage->op.output(i);
187   }
188   auto n = make_node<HybridOpNode>(*this);
189   /* This is a story little bit complicated.
190    * The following two lines of codes replace output tensors' usage.
191    * This is the simplest way I (@were) can come up with to glue
192    * hybrid operation node to TVM op system.
193    * In hybrid script all the tensors, especially the output tensors,
194    * have their own names defined by the users. However, In TVM
195    * conventional ops:
196    *   1. Output tensors refer the corresponding op node so that the output
197    *      tensors have the same names as the operation produces them.
198    *   2. Once OpNode is wrapped up by an Operation node, it is finalized.
199    *      Later access will be from a const OpNode*.
200    * This is a chicken-egg paradox. It is impossible to put the output
201    * tensors into the function body without forming the op node. The
202    * function body is immutable after the node is formed.
203    *
204    * Finally, I decided to resolve this issue "lazily". During the
205    * pipeline of compilation, this stage is a very preliminary stage.
206    * Technically, it is before Phase 0. The actual tensors will be replaced
207    * here.
208    * Thus, the operation body is slightly different from the Phase 0 body.
209    * This is a major difference that HybridOpNode is NOT the same as
210    * ExternOpNode.
211    * */
212   ret = op::ReplaceTensor(ret, rmap);
213   ret = op::ReplaceProvideTensor(ret, rmap);
214 
215   ret = op::ApplySchedule(stage, dom_map, ret);
216   return ret;
217 }
218 
219 namespace op {
220 
221 
ApplyLoopShapes(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,Stmt stmt)222 Stmt ApplyLoopShapes(const Stage &stage,
223                  const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
224   class LoopSpliter : public IRMutator {
225     Expr factor;
226     const Variable *parent;
227     IterVar inner, outer;
228 
229    public:
230     bool splitted;
231     LoopSpliter(const SplitNode *split,
232                 const std::unordered_map<IterVar, Range> &dom_map) :
233       factor(split->factor), splitted(false) {
234       parent = split->parent->var.get();
235 
236       auto &inner_ = split->inner;
237       CHECK(dom_map.count(inner_));
238       auto &inner_dom = dom_map.find(inner_)->second;
239       CHECK(is_const_int(inner_dom->min, 0));
240 
241       auto &outer_ = split->outer;
242       CHECK(dom_map.count(outer_));
243       auto &outer_dom = dom_map.find(outer_)->second;
244       CHECK(is_const_int(outer_dom->min, 0));
245 
246       inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
247       outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
248     }
249 
250     Stmt Mutate_(const For *op, const Stmt &stmt) {
251       if (op->loop_var.get() == parent) {
252         std::unordered_map<const Variable *, Expr> rmap;
253         rmap[op->loop_var.get()] = inner + outer * factor;
254         Stmt ret = ir::Substitute(op->body, rmap);
255         Expr cond = likely(outer * factor < (op->extent - inner));
256         ret = IfThenElse::make(cond, ret);
257         ret = For::make(inner->var, Expr(0), inner->dom->extent,
258                         IterVarTypeToForType(inner->iter_type), op->device_api, ret);
259         ret = For::make(outer->var, Expr(0), outer->dom->extent,
260                         IterVarTypeToForType(outer->iter_type), op->device_api, ret);
261         splitted = true;
262         return ret;
263       }
264       return IRMutator::Mutate_(op, stmt);
265     }
266   };
267 
268   class LoopFuser : public IRMutator {
269     const IterVar &parent;
270     const Variable *inner;
271     const Variable *outer;
272     bool under_outer;
273     Expr extent;
274 
275    public:
276     bool fused;
277     explicit LoopFuser(const FuseNode *fuse_)
278       : parent(fuse_->fused), inner(fuse_->inner->var.get()),
279         outer(fuse_->outer->var.get()), under_outer(false),
280         extent(0), fused(false) {}
281 
282     // TODO(@were): Handle imperfect loops
283 
284     Stmt Mutate_(const For *op, const Stmt &stmt) {
285       if (op->loop_var.get() == inner) {
286         CHECK(under_outer);
287         std::unordered_map<const Variable *, Expr> rmap;
288         rmap[op->loop_var.get()] = indexmod(parent, op->extent);
289         extent = op->extent;
290         fused = true;
291         return ir::Substitute(op->body, rmap);
292       } else if (op->loop_var.get() == outer) {
293         under_outer = true;
294         Stmt body = IRMutator::Mutate(op->body);
295         std::unordered_map<const Variable *, Expr> rmap;
296         rmap[op->loop_var.get()] = indexdiv(parent, extent);
297         body = ir::Substitute(body, rmap);
298         under_outer = false;
299         return For::make(parent->var, Expr(0), extent * op->extent,
300                          op->for_type, op->device_api, body);
301       } else if (under_outer) {
302         Stmt body = IRMutator::Mutate(op->body);
303         std::unordered_map<const Variable *, Expr> rmap;
304         rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
305         body = ir::Substitute(body, rmap);
306         extent = extent * op->extent;
307         return body;
308       }
309       return IRMutator::Mutate(stmt);
310     }
311   };
312 
313   for (auto &rel : stage->relations) {
314     if (const SplitNode *split = rel.as<SplitNode>()) {
315       LoopSpliter Spliter(split, dom_map);
316       stmt = Spliter.Mutate(stmt);
317       CHECK(Spliter.splitted);
318     } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
319       LoopFuser Fuser(fuse);
320       stmt = Fuser.Mutate(stmt);
321       CHECK(Fuser.fused);
322     }
323   }
324 
325   return stmt;
326 }
327 
// Apply per-iter-var annotations from the schedule onto the For loops in
// `stmt`: thread binding rewrites the loop into a thread_extent AttrStmt,
// while a differing iter type rewrites the loop's ForType.
Stmt ApplyLoopAnnotations(const Stage &stage,
                          const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
  // Rewrites the single For loop whose variable matches `var`.
  class LoopAnnotator : public IRMutator {
    const Variable *var;
    const IterVarAttr &attr;

   public:
    LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}

    Stmt Mutate_(const For *op, const Stmt &stmt) {
      if (op->loop_var.get() == var) {
        if (attr->bind_thread.defined()) {
          // Bound to a thread axis: replace loop-var uses with the thread
          // iter var and emit an AttrStmt instead of the loop itself.
          const auto &iter_var = attr->bind_thread;
          if (iter_var->dom.defined()) {
            CHECK(is_const_int(iter_var->dom->min, 0));
            CHECK(Equal(iter_var->dom->extent, op->extent))
              << "Thread extent and loop extent mismatch!\n";
          }
          std::unordered_map<const Variable *, Expr> rmap;
          rmap[op->loop_var.get()] = iter_var;
          Stmt body = ir::Substitute(op->body, rmap);
          return AttrStmt::make(iter_var, "thread_extent", op->extent, body);
        } else {
          // Only the loop kind (serial/parallel/vectorized/unrolled) changes.
          return For::make(op->loop_var, op->min, op->extent,
                           IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
        }
      }
      return IRMutator::Mutate_(op, stmt);
    }
  };

  for (auto &iter_var : stage->leaf_iter_vars) {
    bool need_change = false;
    int found = 0;

    // When the leaf var was rebased, the loop in `stmt` carries the mapped
    // variable, so match against that one.
    const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    const Variable *var = actual->var.get();
    ForType expected = IterVarTypeToForType(iter_var->iter_type);
    IterVarAttr attr;
    if (stage->iter_var_attrs.count(iter_var)) {
      attr = stage->iter_var_attrs[iter_var];
      expected = IterVarTypeToForType(attr->iter_type);
    }

    // Scan for the loop; a rewrite is needed when the loop kind disagrees
    // with the schedule or the var is bound to a thread axis.
    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) {
      if (const For *op = node.as<For>()) {
        if (op->loop_var.get() == var) {
          ++found;
          need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
        }
      }
    });

    CHECK_EQ(found, 1) << " iter var should be found exactly once!";
    if (need_change) {
      stmt = LoopAnnotator(var, attr).Mutate(stmt);
    }
  }
  return stmt;
}
388 
// Reorder the loop nest in `stmt` so that, from outermost to innermost, the
// loops follow the stage's leaf_iter_vars order (with rebased vars mapped to
// the vars actually present in the body).
Stmt ApplyLoopOrder(const Stage &stage,
                    const std::unordered_map<IterVar, Range> &dom_map,
                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
  // Collect the current loop vars; post-order visits inner loops first, so
  // reverse to get outermost-first order.
  std::vector<const Variable*> current_order;
  PostOrderVisit(stmt, [&current_order](const NodeRef &node) {
    if (const For *op = node.as<For>())
      current_order.push_back(op->loop_var.get());
  });
  std::reverse(current_order.begin(), current_order.end());
  auto &required_ord = stage->leaf_iter_vars;
  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
  // Map each existing loop var (by nesting position) to the iter var that
  // should occupy that position after reordering.
  std::unordered_map<const Variable *, IterVar> reorder;
  bool need_reorder = false;
  for (size_t i = 0; i < current_order.size(); ++i) {
    auto &current = current_order[i];
    const IterVar &iter_var = required_ord[i];
    const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
    reorder[current] = required;
    if (current != required->var.get()) {
      need_reorder = true;
    }
  }

  // Rebuilds each For with the loop var assigned to its nesting position.
  class LoopReorder : public IRMutator {
    const Stage &stage;
    const std::unordered_map<IterVar, Range> &dom_map;
    const std::unordered_map<const Variable *, IterVar> &reorder;

   public:
    LoopReorder(const Stage &stage,
                const std::unordered_map<IterVar, Range> &dom_map,
                const std::unordered_map<const Variable*, IterVar> &reorder)
      : stage(stage), dom_map(dom_map), reorder(reorder) {}

    Stmt Mutate_(const For *op, const Stmt &stmt) {
      // Reorder from in to out
      Stmt body_ = IRMutator::Mutate(op->body);
      CHECK(reorder.count(op->loop_var.get()));
      auto target = reorder.find(op->loop_var.get())->second;
      // Already the right var at this level and nothing changed below: keep.
      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
        return stmt;
      const Stmt &body = op->body.same_as(body_) ? op->body : body_;
      // Pick the loop kind from the schedule's attr when present.
      ForType for_type = IterVarTypeToForType(target->iter_type);
      if (stage->iter_var_attrs.count(target)) {
        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
      }
      const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
      return For::make(target->var, range->min, range->extent,
                       for_type, DeviceAPI::None, body);
    }
  };

  if (need_reorder)
    return LoopReorder(stage, dom_map, reorder).Mutate(stmt);

  return stmt;
}
447 
ApplySchedule(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,Stmt stmt)448 Stmt ApplySchedule(const Stage &stage,
449                    const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
450   // TODO(@were): Eliminate loop rebase in script parser and move the burden here
451   // Gather rebased variables
452   std::unordered_map<IterVar, IterVar> rebased;
453   for (auto rel : stage->relations) {
454     if (const auto* rebase = rel.as<RebaseNode>()) {
455       rebased[rebase->rebased] = rebase->parent;
456       CHECK(rebase->parent->dom.defined());
457       CHECK(dom_map.count(rebase->rebased));
458     }
459   }
460   stmt = ApplyLoopShapes(stage, dom_map, stmt);
461   stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
462   stmt = ApplyLoopAnnotations(stage, rebased, stmt);
463   return stmt;
464 }
465 
GatherLoopVars(Stmt stmt)466 std::vector<IterVar> GatherLoopVars(Stmt stmt) {
467   // TODO(@were): Write a comprehensive pass to analyze iter var types
468   std::vector<IterVar> res_;
469   PostOrderVisit(stmt, [&res_](const NodeRef &node) {
470     if (const For *op = node.as<For>()) {
471       Var loop_var(op->loop_var);
472       Range dom = Range::make_by_min_extent(op->min, op->extent);
473       res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
474     }
475   });
476   std::reverse(res_.begin(), res_.end());
477   return res_;
478 }
479 
480 // replacer to replace tensors' usage in Provide
481 class ProviderReplacer : public ir::IRMutator {
482  public:
ProviderReplacer(const std::unordered_map<Tensor,Tensor> & vmap)483   explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
484       : vmap_(vmap) {}
485 
Mutate_(const ir::Provide * op,const Stmt & s)486   Stmt Mutate_(const ir::Provide* op, const Stmt &s) {
487     Tensor t = Downcast<Operation>(op->func).output(op->value_index);
488     auto it = vmap_.find(t);
489     if (it != vmap_.end()) {
490       Stmt ret = ir::Provide::make(
491         it->second->op, it->second->value_index, op->value, op->args);
492       found = true;
493       return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
494     }
495     return IRMutator::Mutate_(op, s);
496   }
497 
498   // whether it is found.
499   bool found{false};
500 
501  private:
502   const std::unordered_map<Tensor, Tensor> &vmap_;
503 };
504 
ReplaceProvideTensor(Stmt stmt,const std::unordered_map<Tensor,Tensor> & replace)505 Stmt ReplaceProvideTensor(Stmt stmt,
506                    const std::unordered_map<Tensor, Tensor> &replace) {
507   ProviderReplacer repl(replace);
508   Stmt ret = repl.Mutate(stmt);
509   return repl.found ? ret : stmt;
510 }
511 }  // namespace op
512 }  // namespace tvm
513