/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Hybrid computation rule.
 * \file hybrid_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <unordered_set>
#include <string>
#include <utility>
#include "op_util.h"
#include "hybrid_op.h"

namespace tvm {
using namespace ir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, IRPrinter* p) {
    auto* op = static_cast<const HybridOpNode*>(node.get());
    p->stream << "hybrid(" << op->name << ", " << op << ")";
  });

TVM_REGISTER_NODE_TYPE(HybridOpNode);

int HybridOpNode::num_outputs() const {
  return static_cast<int>(outputs.size());
}

Array<IterVar> HybridOpNode::root_iter_vars() const {
  return this->axis;
}

Type HybridOpNode::output_dtype(size_t i) const {
  return outputs[i]->dtype;
}

Array<Expr> HybridOpNode::output_shape(size_t i) const {
  return outputs[i]->shape;
}


Operation HybridOpNode::make(std::string name,
                             std::string tag,
                             Map<std::string, NodeRef> attrs,
                             Array<Tensor> inputs,
                             Array<Tensor> outputs,
                             Stmt body) {
  if (!attrs.defined()) {
    attrs = Map<std::string, NodeRef>();
  }
  auto n = make_node<HybridOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->inputs = std::move(inputs);
  n->outputs = std::move(outputs);
  n->axis = op::GatherLoopVars(body);
  n->body = std::move(body);
  Operation res = Operation(n);
  return res;
}

Array<Tensor> HybridOpNode::InputTensors() const {
  // Because input tensors could potentially be inlined into hybrid scripts,
  // we need to check which of the given input tensors are actually used in the body.
  std::unordered_set<Tensor> orig_inputs;
  for (auto t : inputs) {
    orig_inputs.insert(t);
  }
  std::unordered_set<Tensor> visited;
  Array<Tensor> curr_inputs;
  ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
      const ir::Call *call = n.as<ir::Call>();
      if (call != nullptr && call->func.defined()) {
        Tensor t = Downcast<Operation>(call->func).output(call->value_index);
        if (orig_inputs.count(t) && !visited.count(t)) {
          curr_inputs.push_back(t);
          visited.insert(t);
        }
      }
  });
  return curr_inputs;
}

Operation HybridOpNode::ReplaceInputs(
    const Operation &self,
    const std::unordered_map<Tensor, Tensor> &rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_node<HybridOpNode>(*this);
  n->body = op::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  if (body.same_as(n->body) &&
      inputs.same_as(n->inputs)) {
    return self;
  } else {
    return Operation(n);
  }
}

void HybridOpNode::PropBoundToInputs(
    const Operation &self,
    arith::Analyzer* analyzer,
    const std::unordered_map<const Variable*, IntSet> &dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  auto curr_inputs = InputTensors();
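  // No finer-grained analysis of the hybrid body is done here: for every
  // input tensor that is actually used, relay the full [0, shape[i]) range
  // of each of its dimensions as the required region.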
  for (Tensor t : curr_inputs) {
    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom &dom = it->second;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      dom.data[i].emplace_back(IntSet::range(
          Range::make_by_min_extent(
              make_const(t->shape[i].type(), 0), t->shape[i])));
    }
  }
}

void HybridOpNode::GatherBound(
    const Operation &self,
    const std::unordered_map<Tensor, TensorDom> &tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  for (auto iter_var : axis) {
    CHECK(!out_dom_map->count(iter_var));
    out_dom_map->operator[](iter_var) = iter_var->dom;
  }
}

Stmt HybridOpNode::BuildRealize(
    const Stage &stage,
    const std::unordered_map<IterVar, Range> &realize_map,
    const Stmt &body) const {
  // TODO(@were): Add attribute injection here and remove it from the hybrid parser.
  CHECK_EQ(stage->op.get(), this);
  Stmt realize_body = body;
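  // Wrap the body in one Realize node per output, covering the output's
  // full shape [0, shape[i]) in every dimension.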
  for (int k = 0; k < num_outputs(); ++k) {
    Tensor t = stage->op.output(k);
    Region bounds;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      bounds.push_back(
          Range::make_by_min_extent(
              make_const(t->shape[i].type(), 0), t->shape[i]));
    }
    realize_body = ir::Realize::make(
        t->op, t->value_index, t->dtype,
        bounds, const_true(), realize_body);
  }
  return realize_body;
}

Stmt HybridOpNode::BuildProvide(
    const Stage &stage,
    const std::unordered_map<IterVar, Range> &dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
  std::unordered_map<Tensor, Tensor> rmap;
  for (int i = 0; i < this->num_outputs(); ++i) {
    rmap[outputs[i]] = stage->op.output(i);
  }
  auto n = make_node<HybridOpNode>(*this);
  /* The story here is a little bit complicated.
   * The following two lines of code replace the usage of the output tensors.
   * This is the simplest way I (@were) could come up with to glue the
   * hybrid operation node to the TVM op system.
   * In a hybrid script all tensors, especially the output tensors,
   * have their own names defined by the user. However, in TVM's
   * conventional ops:
   *   1. Output tensors refer to the op node that produces them, so the
   *      output tensors carry the name of that operation.
   *   2. Once an OpNode is wrapped up by an Operation node, it is finalized.
   *      Later access will be through a const OpNode*.
   * This is a chicken-and-egg paradox: it is impossible to put the output
   * tensors into the function body without first forming the op node, yet
   * the function body is immutable once the node is formed.
   *
   * Therefore, this issue is resolved "lazily". In the compilation
   * pipeline, this stage is very early, technically before Phase 0, and
   * the actual output tensors are substituted here.
   * Thus, the operation body is slightly different from the Phase 0 body.
   * This is a major way in which HybridOpNode differs from ExternOpNode.
   */
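  // For example, if the script writes to an output the user named `c`
  // (an illustrative name), the two Replace* calls below rewrite every
  // Provide/Call on `c` so that it refers to the corresponding
  // stage->op.output(i) instead.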
  ret = op::ReplaceTensor(ret, rmap);
  ret = op::ReplaceProvideTensor(ret, rmap);

  ret = op::ApplySchedule(stage, dom_map, ret);
  return ret;
}

namespace op {


Stmt ApplyLoopShapes(const Stage &stage,
                     const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
  class LoopSpliter : public IRMutator {
    Expr factor;
    const Variable *parent;
    IterVar inner, outer;

   public:
    bool splitted;
    LoopSpliter(const SplitNode *split,
                const std::unordered_map<IterVar, Range> &dom_map) :
      factor(split->factor), splitted(false) {
      parent = split->parent->var.get();

      auto &inner_ = split->inner;
      CHECK(dom_map.count(inner_));
      auto &inner_dom = dom_map.find(inner_)->second;
      CHECK(is_const_int(inner_dom->min, 0));

      auto &outer_ = split->outer;
      CHECK(dom_map.count(outer_));
      auto &outer_dom = dom_map.find(outer_)->second;
      CHECK(is_const_int(outer_dom->min, 0));

      inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
      outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
    }

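    // A sketch of the rewrite performed below: a loop
    //   for parent in [0, extent)
    // becomes
    //   for outer in [0, outer_extent):
    //     for inner in [0, factor):
    //       if likely(outer * factor < extent - inner):
    //         <body with parent := inner + outer * factor>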
    Stmt Mutate_(const For *op, const Stmt &stmt) {
      if (op->loop_var.get() == parent) {
        std::unordered_map<const Variable *, Expr> rmap;
        rmap[op->loop_var.get()] = inner + outer * factor;
        Stmt ret = ir::Substitute(op->body, rmap);
        Expr cond = likely(outer * factor < (op->extent - inner));
        ret = IfThenElse::make(cond, ret);
        ret = For::make(inner->var, Expr(0), inner->dom->extent,
                        IterVarTypeToForType(inner->iter_type), op->device_api, ret);
        ret = For::make(outer->var, Expr(0), outer->dom->extent,
                        IterVarTypeToForType(outer->iter_type), op->device_api, ret);
        splitted = true;
        return ret;
      }
      return IRMutator::Mutate_(op, stmt);
    }
  };

  class LoopFuser : public IRMutator {
    const IterVar &parent;
    const Variable *inner;
    const Variable *outer;
    bool under_outer;
    Expr extent;

   public:
    bool fused;
    explicit LoopFuser(const FuseNode *fuse_)
      : parent(fuse_->fused), inner(fuse_->inner->var.get()),
        outer(fuse_->outer->var.get()), under_outer(false),
        extent(0), fused(false) {}

    // TODO(@were): Handle imperfect loops

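    // A sketch of the rewrite performed below: the nest
    //   for outer in [0, outer_extent):
    //     for inner in [0, inner_extent):
    // becomes a single loop
    //   for fused in [0, outer_extent * inner_extent):
    // with inner := fused % inner_extent and outer := fused / inner_extent;
    // loops sitting between outer and inner are folded into the fused loop
    // in the same way.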
    Stmt Mutate_(const For *op, const Stmt &stmt) {
      if (op->loop_var.get() == inner) {
        CHECK(under_outer);
        std::unordered_map<const Variable *, Expr> rmap;
        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
        extent = op->extent;
        fused = true;
        return ir::Substitute(op->body, rmap);
      } else if (op->loop_var.get() == outer) {
        under_outer = true;
        Stmt body = IRMutator::Mutate(op->body);
        std::unordered_map<const Variable *, Expr> rmap;
        rmap[op->loop_var.get()] = indexdiv(parent, extent);
        body = ir::Substitute(body, rmap);
        under_outer = false;
        return For::make(parent->var, Expr(0), extent * op->extent,
                         op->for_type, op->device_api, body);
      } else if (under_outer) {
        Stmt body = IRMutator::Mutate(op->body);
        std::unordered_map<const Variable *, Expr> rmap;
        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
        body = ir::Substitute(body, rmap);
        extent = extent * op->extent;
        return body;
      }
      return IRMutator::Mutate(stmt);
    }
  };

  for (auto &rel : stage->relations) {
    if (const SplitNode *split = rel.as<SplitNode>()) {
      LoopSpliter Spliter(split, dom_map);
      stmt = Spliter.Mutate(stmt);
      CHECK(Spliter.splitted);
    } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
      LoopFuser Fuser(fuse);
      stmt = Fuser.Mutate(stmt);
      CHECK(Fuser.fused);
    }
  }

  return stmt;
}

Stmt ApplyLoopAnnotations(const Stage &stage,
                          const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
  class LoopAnnotator : public IRMutator {
    const Variable *var;
    const IterVarAttr &attr;

   public:
    LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}

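    // Either bind the loop to the thread IterVar given in the attribute
    // (turning the For into an AttrStmt with key "thread_extent"), or simply
    // rewrite the ForType of the loop (e.g. serial -> parallel/vectorized).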
    Stmt Mutate_(const For *op, const Stmt &stmt) {
      if (op->loop_var.get() == var) {
        if (attr->bind_thread.defined()) {
          const auto &iter_var = attr->bind_thread;
          if (iter_var->dom.defined()) {
            CHECK(is_const_int(iter_var->dom->min, 0));
            CHECK(Equal(iter_var->dom->extent, op->extent))
              << "Thread extent and loop extent mismatch!\n";
          }
          std::unordered_map<const Variable *, Expr> rmap;
          rmap[op->loop_var.get()] = iter_var;
          Stmt body = ir::Substitute(op->body, rmap);
          return AttrStmt::make(iter_var, "thread_extent", op->extent, body);
        } else {
          return For::make(op->loop_var, op->min, op->extent,
                           IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
        }
      }
      return IRMutator::Mutate_(op, stmt);
    }
  };

  for (auto &iter_var : stage->leaf_iter_vars) {
    bool need_change = false;
    int found = 0;

    const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    const Variable *var = actual->var.get();
    ForType expected = IterVarTypeToForType(iter_var->iter_type);
    IterVarAttr attr;
    if (stage->iter_var_attrs.count(iter_var)) {
      attr = stage->iter_var_attrs[iter_var];
      expected = IterVarTypeToForType(attr->iter_type);
    }

    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) {
        if (const For *op = node.as<For>()) {
          if (op->loop_var.get() == var) {
            ++found;
            need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
          }
        }
    });

    CHECK_EQ(found, 1) << " iter var should be found exactly once!";
    if (need_change) {
      stmt = LoopAnnotator(var, attr).Mutate(stmt);
    }
  }
  return stmt;
}

Stmt ApplyLoopOrder(const Stage &stage,
                    const std::unordered_map<IterVar, Range> &dom_map,
                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
  std::vector<const Variable*> current_order;
  PostOrderVisit(stmt, [&current_order](const NodeRef &node) {
      if (const For *op = node.as<For>())
        current_order.push_back(op->loop_var.get());
  });
  std::reverse(current_order.begin(), current_order.end());
  auto &required_ord = stage->leaf_iter_vars;
  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
  std::unordered_map<const Variable *, IterVar> reorder;
  bool need_reorder = false;
  for (size_t i = 0; i < current_order.size(); ++i) {
    auto &current = current_order[i];
    const IterVar &iter_var = required_ord[i];
    const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
    reorder[current] = required;
    if (current != required->var.get()) {
      need_reorder = true;
    }
  }

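  // Rebuild the loop nest from the innermost loop outwards: the i-th loop
  // (counting from the outside) is re-made over the i-th required leaf iter
  // var, keeping the already-mutated body underneath it.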
  class LoopReorder : public IRMutator {
    const Stage &stage;
    const std::unordered_map<IterVar, Range> &dom_map;
    const std::unordered_map<const Variable *, IterVar> &reorder;

   public:
    LoopReorder(const Stage &stage,
                const std::unordered_map<IterVar, Range> &dom_map,
                const std::unordered_map<const Variable*, IterVar> &reorder)
      : stage(stage), dom_map(dom_map), reorder(reorder) {}

    Stmt Mutate_(const For *op, const Stmt &stmt) {
      // Reorder from the innermost loop to the outermost one
      Stmt body_ = IRMutator::Mutate(op->body);
      CHECK(reorder.count(op->loop_var.get()));
      auto target = reorder.find(op->loop_var.get())->second;
      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
        return stmt;
      const Stmt &body = op->body.same_as(body_) ? op->body : body_;
      ForType for_type = IterVarTypeToForType(target->iter_type);
      if (stage->iter_var_attrs.count(target)) {
        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
      }
      const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
      return For::make(target->var, range->min, range->extent,
                       for_type, DeviceAPI::None, body);
    }
  };

  if (need_reorder)
    return LoopReorder(stage, dom_map, reorder).Mutate(stmt);

  return stmt;
}

Stmt ApplySchedule(const Stage &stage,
                   const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
  // TODO(@were): Eliminate loop rebase in the script parser and move the burden here
  // Gather rebased variables
  std::unordered_map<IterVar, IterVar> rebased;
  for (auto rel : stage->relations) {
    if (const auto* rebase = rel.as<RebaseNode>()) {
      rebased[rebase->rebased] = rebase->parent;
      CHECK(rebase->parent->dom.defined());
      CHECK(dom_map.count(rebase->rebased));
    }
  }
  stmt = ApplyLoopShapes(stage, dom_map, stmt);
  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
  return stmt;
}

std::vector<IterVar> GatherLoopVars(Stmt stmt) {
  // TODO(@were): Write a comprehensive pass to analyze iter var types
  std::vector<IterVar> res_;
  PostOrderVisit(stmt, [&res_](const NodeRef &node) {
      if (const For *op = node.as<For>()) {
        Var loop_var(op->loop_var);
        Range dom = Range::make_by_min_extent(op->min, op->extent);
        res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
      }
  });
  std::reverse(res_.begin(), res_.end());
  return res_;
}

// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
 public:
  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
      : vmap_(vmap) {}

  Stmt Mutate_(const ir::Provide* op, const Stmt &s) {
    Tensor t = Downcast<Operation>(op->func).output(op->value_index);
    auto it = vmap_.find(t);
    if (it != vmap_.end()) {
      Stmt ret = ir::Provide::make(
          it->second->op, it->second->value_index, op->value, op->args);
      found = true;
      return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
    }
    return IRMutator::Mutate_(op, s);
  }

  // Whether any Provide statement was actually replaced.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor> &vmap_;
};

Stmt ReplaceProvideTensor(Stmt stmt,
                          const std::unordered_map<Tensor, Tensor> &replace) {
  ProviderReplacer repl(replace);
  Stmt ret = repl.Mutate(stmt);
  return repl.found ? ret : stmt;
}
}  // namespace op
}  // namespace tvm