/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Hybrid computation rule.
 * \file hybrid_op.cc
 */
#include "hybrid_op.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <string>
#include <unordered_set>
#include <utility>

#include "op_util.h"

namespace tvm {
namespace te {
using namespace tir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const HybridOpNode*>(node.get());
      p->stream << "hybrid(" << op->name << ", " << op << ")";
    });

TVM_REGISTER_NODE_TYPE(HybridOpNode);

int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }

Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }

DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }

Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }

HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
  if (!attrs.defined()) {
    attrs = Map<String, ObjectRef>();
  }
  auto n = make_object<HybridOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->inputs = std::move(inputs);
  n->outputs = std::move(outputs);
  n->axis = te::GatherLoopVars(body);
  n->body = std::move(body);
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("te.HybridOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       Array<Tensor> inputs, Array<Tensor> outputs,
                       Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); });

Array<Tensor> HybridOpNode::InputTensors() const {
  // Because input tensors may be inlined into hybrid scripts,
  // we need to check which of the input tensors are actually used in the body.
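  // For example, if a hybrid script is handed tensors (a, b) but its body
  // only ever reads a[...], only `a` is reported here even though both
  // tensors were passed in as inputs.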
  std::unordered_set<Tensor> orig_inputs;
  for (auto t : inputs) {
    orig_inputs.insert(t);
  }
  std::unordered_set<Tensor> visited;
  Array<Tensor> curr_inputs;
  tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
    if (auto* pload = n.as<tir::ProducerLoadNode>()) {
      Tensor t = Downcast<Tensor>(pload->producer);
      if (orig_inputs.count(t) && !visited.count(t)) {
        curr_inputs.push_back(t);
        visited.insert(t);
      }
    }
  });
  return curr_inputs;
}

Operation HybridOpNode::ReplaceInputs(const Operation& self,
                                      const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_object<HybridOpNode>(*this);
  n->body = te::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  if (body.same_as(n->body) && inputs.same_as(n->inputs)) {
    return self;
  } else {
    return Operation(n);
  }
}

void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                     const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
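  // The hybrid body is opaque to bound inference, so conservatively demand
  // the full extent [0, shape[i]) of every dimension of each used input.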
  auto curr_inputs = InputTensors();
  for (Tensor t : curr_inputs) {
    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      dom.data[i].emplace_back(
          IntSet::FromRange(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
    }
  }
}

void HybridOpNode::GatherBound(const Operation& self,
                               const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                               std::unordered_map<IterVar, Range>* out_dom_map) const {
  for (auto iter_var : axis) {
    CHECK(!out_dom_map->count(iter_var));
    out_dom_map->operator[](iter_var) = iter_var->dom;
  }
}

Stmt HybridOpNode::BuildRealize(const Stage& stage,
                                const std::unordered_map<IterVar, Range>& realize_map,
                                const Stmt& body) const {
  // TODO(@were): Add attribute inject here and remove it from hybrid parser.
  CHECK_EQ(stage->op.get(), this);
  Stmt realize_body = body;
  for (int k = 0; k < num_outputs(); ++k) {
    Tensor t = stage->op.output(k);
    Region bounds;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
    }
    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
  }
  return realize_body;
}

Stmt HybridOpNode::BuildProvide(const Stage& stage,
                                const std::unordered_map<IterVar, Range>& dom_map,
                                bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
  std::unordered_map<Tensor, Tensor> rmap;
  for (int i = 0; i < this->num_outputs(); ++i) {
    rmap[outputs[i]] = stage->op.output(i);
  }
  auto n = make_object<HybridOpNode>(*this);
  /* This story is a little bit complicated.
   * The following two lines of code replace the usage of the output tensors.
   * This is the simplest way I (@were) could come up with to glue the
   * hybrid operation node into the TVM op system.
   * In a hybrid script all the tensors, especially the output tensors,
   * have their own names defined by the user. However, in TVM's
   * conventional ops:
   *   1. Output tensors refer to the corresponding op node, so the output
   *      tensors carry the name of the operation that produces them.
   *   2. Once an OpNode is wrapped up by an Operation node, it is finalized.
   *      Later access will be through a const OpNode*.
   * This is a chicken-and-egg paradox: it is impossible to put the output
   * tensors into the function body without forming the op node, yet the
   * function body is immutable after the node is formed.
   *
   * Finally, I decided to resolve this issue "lazily". This stage is very
   * early in the compilation pipeline, technically before Phase 0, so the
   * actual tensors are replaced here.
   * Thus, the operation body is slightly different from the Phase 0 body.
   * This is a major difference between HybridOpNode and ExternOpNode.
   */
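  /* An illustrative sketch: for an output the user named `c` in the script,
   *   rmap = {c: stage->op.output(0)}
   * rewrites both reads (ProducerLoad) and writes (ProducerStore) of `c` in
   * the body to the canonical output tensor of this operation. */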
  ret = te::ReplaceTensor(ret, rmap);
  ret = te::ReplaceProvideTensor(ret, rmap);

  ret = te::ApplySchedule(stage, dom_map, ret);
  return ret;
}

Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     Stmt stmt) {
  class LoopSpliter : public StmtExprMutator {
    PrimExpr factor;
    const VarNode* parent;
    IterVar inner, outer;

   public:
    bool splitted;
    LoopSpliter(const SplitNode* split, const std::unordered_map<IterVar, Range>& dom_map)
        : factor(split->factor), splitted(false) {
      parent = split->parent->var.get();

      auto& inner_ = split->inner;
      CHECK(dom_map.count(inner_));
      auto& inner_dom = dom_map.find(inner_)->second;
      CHECK(is_const_int(inner_dom->min, 0));

      auto& outer_ = split->outer;
      CHECK(dom_map.count(outer_));
      auto& outer_dom = dom_map.find(outer_)->second;
      CHECK(is_const_int(outer_dom->min, 0));

      inner = IterVar(inner_dom, inner_->var, inner_->iter_type);
      outer = IterVar(outer_dom, outer_->var, outer_->iter_type);
    }

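    // Roughly, splitting `for (i, 0, n)` by `factor` rewrites it to
    //   for (i.outer, 0, ceil(n / factor))
    //     for (i.inner, 0, factor)
    //       if (likely(i.outer * factor + i.inner < n))
    //         body[i := i.inner + i.outer * factor]
    // where the likely() guard keeps the tail iterations in bounds.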
    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == parent) {
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = inner + outer * factor;
        Stmt ret = tir::Substitute(op->body, rmap);
        PrimExpr cond = likely(outer * factor < (op->extent - inner));
        ret = IfThenElse(cond, ret);
        ret = For(inner->var, PrimExpr(0), inner->dom->extent,
                  IterVarTypeToForType(inner->iter_type), op->device_api, ret);
        ret = For(outer->var, PrimExpr(0), outer->dom->extent,
                  IterVarTypeToForType(outer->iter_type), op->device_api, ret);
        splitted = true;
        return ret;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };

  class LoopFuser : public StmtExprMutator {
    const IterVar& parent;
    const VarNode* inner;
    const VarNode* outer;
    bool under_outer;
    PrimExpr extent;

   public:
    bool fused;
    explicit LoopFuser(const FuseNode* fuse_)
        : parent(fuse_->fused),
          inner(fuse_->inner->var.get()),
          outer(fuse_->outer->var.get()),
          under_outer(false),
          extent(0),
          fused(false) {}

    // TODO(@were): Handle imperfect loops
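    // Roughly, fusing `for (i, 0, m) { for (j, 0, n) body }` rewrites it to
    //   for (i.j.fused, 0, m * n)
    //     body[i := i.j.fused / n, j := i.j.fused % n]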
    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == inner) {
        CHECK(under_outer);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
        extent = op->extent;
        fused = true;
        return tir::Substitute(op->body, rmap);
      } else if (op->loop_var.get() == outer) {
        under_outer = true;
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexdiv(parent, extent);
        body = tir::Substitute(body, rmap);
        under_outer = false;
        return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api,
                   body);
      } else if (under_outer) {
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
        body = tir::Substitute(body, rmap);
        extent = extent * op->extent;
        return body;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };

  for (auto& rel : stage->relations) {
    if (const SplitNode* split = rel.as<SplitNode>()) {
      LoopSpliter Spliter(split, dom_map);
      stmt = Spliter(stmt);
      CHECK(Spliter.splitted);
    } else if (const FuseNode* fuse = rel.as<FuseNode>()) {
      LoopFuser Fuser(fuse);
      stmt = Fuser(stmt);
      CHECK(Fuser.fused);
    }
  }

  return stmt;
}

Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar, IterVar>& rebased,
                          Stmt stmt) {
  class LoopAnnotator : public StmtMutator {
    const VarNode* var;
    const IterVarAttr& attr;

   public:
    LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {}

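    // Roughly, binding a loop to a thread IterVar such as threadIdx.x turns
    //   for (i, 0, n) body
    // into
    //   AttrStmt(threadIdx.x, "thread_extent", n, body[i := threadIdx.x])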
    Stmt VisitStmt_(const ForNode* op) final {
      tir::ExprDeepEqual expr_equal;

      if (op->loop_var.get() == var) {
        if (attr->bind_thread.defined()) {
          const auto& iter_var = attr->bind_thread;
          if (iter_var->dom.defined()) {
            CHECK(is_const_int(iter_var->dom->min, 0));
            CHECK(expr_equal(iter_var->dom->extent, op->extent))
                << "Thread extent and loop extent mismatch!\n";
          }
          std::unordered_map<const VarNode*, PrimExpr> rmap;
          rmap[op->loop_var.get()] = iter_var;
          Stmt body = tir::Substitute(op->body, rmap);
          return AttrStmt(iter_var, "thread_extent", op->extent, body);
        } else {
          return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type),
                     op->device_api, op->body);
        }
      }
      return StmtMutator::VisitStmt_(op);
    }
  };

  for (auto& iter_var : stage->leaf_iter_vars) {
    bool need_change = false;
    int found = 0;

    const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    const VarNode* var = actual->var.get();
    ForType expected = IterVarTypeToForType(iter_var->iter_type);
    IterVarAttr attr;
    if (stage->iter_var_attrs.count(iter_var)) {
      attr = stage->iter_var_attrs[iter_var];
      expected = IterVarTypeToForType(attr->iter_type);
    }

    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
      if (const ForNode* op = node.as<ForNode>()) {
        if (op->loop_var.get() == var) {
          ++found;
          need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
        }
      }
    });

    CHECK_EQ(found, 1) << "Iter var should be found exactly once!";
    if (need_change) {
      stmt = LoopAnnotator(var, attr)(std::move(stmt));
    }
  }
  return stmt;
}

Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt) {
  std::vector<const VarNode*> current_order;
  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
    if (const ForNode* op = node.as<ForNode>()) current_order.push_back(op->loop_var.get());
  });
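  // PostOrderVisit reaches the innermost loop first, so reverse the list to
  // get the outermost-to-innermost order that leaf_iter_vars uses.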
  std::reverse(current_order.begin(), current_order.end());
  auto& required_ord = stage->leaf_iter_vars;
  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
  std::unordered_map<const VarNode*, IterVar> reorder;
  bool need_reorder = false;
  for (size_t i = 0; i < current_order.size(); ++i) {
    auto& current = current_order[i];
    const IterVar& iter_var = required_ord[i];
    const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
    reorder[current] = required;
    if (current != required->var.get()) {
      need_reorder = true;
    }
  }

  class LoopReorder : public StmtMutator {
    const Stage& stage;
    const std::unordered_map<IterVar, Range>& dom_map;
    const std::unordered_map<const VarNode*, IterVar>& reorder;

   public:
    LoopReorder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                const std::unordered_map<const VarNode*, IterVar>& reorder)
        : stage(stage), dom_map(dom_map), reorder(reorder) {}

    Stmt VisitStmt_(const ForNode* op) final {
      // Reorder from inner to outer
      Stmt body_ = this->VisitStmt(op->body);
      CHECK(reorder.count(op->loop_var.get()));
      auto target = reorder.find(op->loop_var.get())->second;
      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
        return GetRef<Stmt>(op);
      const Stmt& body = op->body.same_as(body_) ? op->body : body_;
      ForType for_type = IterVarTypeToForType(target->iter_type);
      if (stage->iter_var_attrs.count(target)) {
        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
      }
      const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
      return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
    }
  };

  if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt);

  return stmt;
}

Stmt ApplySchedule(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                   Stmt stmt) {
  // TODO(@were): Eliminate loop rebase in script parser and move the burden here
  // Gather rebased variables
  std::unordered_map<IterVar, IterVar> rebased;
  for (auto rel : stage->relations) {
    if (const auto* rebase = rel.as<RebaseNode>()) {
      rebased[rebase->rebased] = rebase->parent;
      CHECK(rebase->parent->dom.defined());
      CHECK(dom_map.count(rebase->rebased));
    }
  }
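  // Loop shapes (split/fuse) must be materialized before reordering and
  // annotating, since the later passes look up the loop variables that
  // splitting and fusing introduce.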
  stmt = ApplyLoopShapes(stage, dom_map, stmt);
  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
  return stmt;
}

std::vector<IterVar> GatherLoopVars(Stmt stmt) {
  // TODO(@were): Write a comprehensive pass to analyze iter var types
  std::vector<IterVar> res_;
  PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
    if (const ForNode* op = node.as<ForNode>()) {
      Var loop_var(op->loop_var);
      Range dom = Range::FromMinExtent(op->min, op->extent);
      res_.push_back(IterVar(dom, loop_var, ForTypeToIterVarType(op->for_type)));
    }
  });
  std::reverse(res_.begin(), res_.end());
  return res_;
}

// Replacer that rewrites which tensor a ProducerStore (Provide) writes to.
class ProviderReplacer : public tir::StmtMutator {
 public:
  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {}

  Stmt VisitStmt_(const tir::ProducerStoreNode* op) final {
    Tensor t = Downcast<Tensor>(op->producer);
    auto it = vmap_.find(t);
    if (it != vmap_.end()) {
      Stmt ret = tir::ProducerStore(it->second, op->value, op->indices);
      found = true;
      return this->VisitStmt(ret);
    }
    return StmtMutator::VisitStmt_(op);
  }

  // Whether any replacement was found.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) {
  ProviderReplacer repl(replace);
  Stmt ret = repl(stmt);
  return repl.found ? ret : stmt;
}
}  // namespace te
}  // namespace tvm