1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \brief Utility to make loop nest.
22 * \file op_util.cc
23 */
24 #include <tvm/ir.h>
25 #include <tvm/ir_pass.h>
26 #include <tvm/operation.h>
27 #include <tvm/ir_mutator.h>
28 #include <string>
29 #include "op_util.h"
30 #include "../schedule/message_passing.h"
31 #include "../arithmetic/compute_expr.h"
32
33 namespace tvm {
34 namespace op {
35
36 using namespace arith;
37 using namespace ir;
38
39 std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,size_t begin_iter_pos,bool new_loop_var,const std::unordered_set<IterVar> & skip_iter,std::unordered_map<IterVar,Expr> * p_value_map,bool debug_keep_trivial_loop)40 MakeLoopNest(const Stage& stage,
41 const std::unordered_map<IterVar, Range>& dom_map,
42 size_t begin_iter_pos,
43 bool new_loop_var,
44 const std::unordered_set<IterVar>& skip_iter,
45 std::unordered_map<IterVar, Expr>* p_value_map,
46 bool debug_keep_trivial_loop) {
47 auto leaf_iter_vars = stage->leaf_iter_vars;
48 Stmt no_op = Evaluate::make(0);
49 // create the loop nest
50 std::vector<std::vector<Stmt> > nest;
51 nest.resize(leaf_iter_vars.size() + 1);
52 std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
53
54 for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
55 auto iv = leaf_iter_vars[i];
56 if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
57 // skip this iteration.
58 value_map[iv] = iv->var;
59 continue;
60 }
61 // Bind iv could be another thread.
62 IterVar bind_iv = iv;
63 if (stage->iter_var_attrs.count(iv)) {
64 IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
65 if (bind_thread.defined()) bind_iv = bind_thread;
66 }
67
68 Range dom = dom_map.at(iv);
69
70 // initialize the offset and loop_level
71 Var var = bind_iv->var;
72
73 // Mark the iter var in the IR, to remember the point
74 if (bind_iv->thread_tag.length() == 0) {
75 // Only generate new loop if we're not bound to a thread.
76 if (new_loop_var) {
77 var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
78 }
79
80 ForType for_type = ForType::Serial;
81 IterVarAttr it_attr;
82 if (stage->iter_var_attrs.count(iv)) {
83 it_attr = stage->iter_var_attrs[iv];
84 }
85 if (it_attr.defined()) {
86 switch (it_attr->iter_type) {
87 case kUnrolled: for_type = ForType::Unrolled; break;
88 case kVectorized: for_type = ForType::Vectorized; break;
89 case kParallelized: for_type = ForType::Parallel; break;
90 case kDataPar: break;
91 case kTensorized: break;
92 default: LOG(FATAL) << "Unknown iter type"
93 << it_attr->iter_type
94 << " in the iter_var_attrs";
95 }
96 CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
97 for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
98 const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
99 Expr pvalue = it_attr->pragma_values[k];
100 if (!pvalue.defined()) {
101 pvalue = make_const(Int(32), 1);
102 }
103 nest[i + 1].emplace_back(
104 AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
105 }
106 }
107 if (!debug_keep_trivial_loop && is_one(dom->extent)) {
108 nest[i + 1].emplace_back(
109 LetStmt::make(var, dom->min, no_op));
110 value_map[iv] = dom->min;
111 } else if (is_zero(dom->min)) {
112 nest[i + 1].emplace_back(
113 For::make(var, 0, dom->extent,
114 for_type, DeviceAPI::None, no_op));
115 value_map[iv] = var;
116 } else {
117 Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
118 nest[i + 1].emplace_back(
119 For::make(idx, 0, dom->extent,
120 for_type, DeviceAPI::None, no_op));
121 Expr new_value = dom->min + idx;
122 value_map[iv] = new_value;
123 nest[i + 1].emplace_back(
124 LetStmt::make(var, new_value, no_op));
125 }
126 if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
127 CHECK(!is_one(dom->extent))
128 << "Cannot prefetch on trivial loop with extent=1";
129 CHECK_EQ(it_attr->prefetch_data.size(),
130 it_attr->prefetch_offset.size());
131 for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
132 nest[i + 1].emplace_back(
133 AttrStmt::make(it_attr->prefetch_data[j],
134 ir::attr::prefetch_scope,
135 it_attr->prefetch_offset[j], no_op));
136 }
137 }
138 } else if (bind_iv->thread_tag == "vthread" ||
139 bind_iv->thread_tag == "cthread") {
140 // virtual thread
141 // Always restrict threaded IterVar to starts from 0.
142 CHECK(is_zero(dom->min));
143 CHECK(is_positive_const(dom->extent));
144 // annotate the extent of the IterVar
145 nest[i + 1].emplace_back(
146 AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
147 value_map[iv] = var;
148 } else if (bind_iv->thread_tag == "pipeline") {
149 // pipeline marker.
150 CHECK(is_zero(dom->min));
151 CHECK(is_one(dom->extent));
152 // annotate the extent of the IterVar
153 nest[i + 1].emplace_back(
154 AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
155 value_map[iv] = dom->min;
156 } else {
157 // Always restrict threaded IterVar to starts from 0.
158 CHECK(is_zero(dom->min));
159 // annotate the extent of the IterVar
160 nest[i + 1].emplace_back(
161 AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
162 if (!debug_keep_trivial_loop && is_one(dom->extent)) {
163 value_map[iv] = dom->min;
164 } else {
165 value_map[iv] = var;
166 }
167 }
168 // annotate the extent of the IterVar
169 if (!new_loop_var) {
170 nest[i + 1].emplace_back(
171 AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
172 }
173 }
174 // message passing to get offset of root iter vars.
175 schedule::PassUpIndex(stage, dom_map, &value_map);
176 return nest;
177 }
178
MakeIfNest(const std::vector<Expr> & predicates)179 std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
180 Stmt no_op = Evaluate::make(0);
181 std::vector<Stmt> nest;
182 for (const Expr& cond : predicates) {
183 nest.emplace_back(IfThenElse::make(cond, no_op));
184 }
185 return nest;
186 }
187
188 // replacer to replace tensors
189 class TensorReplacer : public ir::IRMutator {
190 public:
TensorReplacer(const std::unordered_map<Tensor,Tensor> & vmap)191 explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
192 : vmap_(vmap) {}
193
Mutate_(const ir::Call * op,const Expr & e)194 Expr Mutate_(const ir::Call* op, const Expr& e) {
195 if (op->call_type == ir::Call::Halide) {
196 Tensor t = Downcast<Operation>(op->func).output(op->value_index);
197 auto it = vmap_.find(t);
198 if (it != vmap_.end()) {
199 Expr ret = ir::Call::make(
200 op->type, it->second->op->name, op->args,
201 op->call_type, it->second->op, it->second->value_index);
202 found = true;
203 return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
204 }
205 }
206 return IRMutator::Mutate_(op, e);
207 }
208
209 // whether it is found.
210 bool found{false};
211
212 private:
213 const std::unordered_map<Tensor, Tensor>& vmap_;
214 };
215
ReplaceTensor(Stmt stmt,const std::unordered_map<Tensor,Tensor> & replace)216 Stmt ReplaceTensor(Stmt stmt,
217 const std::unordered_map<Tensor, Tensor>& replace) {
218 TensorReplacer repl(replace);
219 Stmt ret = repl.Mutate(stmt);
220 return repl.found ? ret : stmt;
221 }
ReplaceTensor(Expr expr,const std::unordered_map<Tensor,Tensor> & replace)222 Expr ReplaceTensor(Expr expr,
223 const std::unordered_map<Tensor, Tensor>& replace) {
224 TensorReplacer repl(replace);
225 Expr ret = repl.Mutate(expr);
226 return repl.found ? ret : expr;
227 }
228
229
Substitute(Stmt s,const std::unordered_map<IterVar,Expr> & value_map)230 Stmt Substitute(Stmt s,
231 const std::unordered_map<IterVar, Expr>& value_map) {
232 std::unordered_map<const Variable*, Expr> init;
233 for (const auto& kv : value_map) {
234 init[kv.first->var.get()] = kv.second;
235 }
236 return ir::Substitute(s, init);
237 }
238
ForTypeToIterVarType(ir::ForType for_type)239 IterVarType ForTypeToIterVarType(ir::ForType for_type) {
240 switch (for_type) {
241 case ForType::Serial:
242 return kDataPar;
243 case ForType::Parallel:
244 return kParallelized;
245 case ForType::Vectorized:
246 return kVectorized;
247 case ForType::Unrolled:
248 return kUnrolled;
249 default:
250 return kDataPar;
251 }
252 }
253
IterVarTypeToForType(IterVarType iter_type)254 ir::ForType IterVarTypeToForType(IterVarType iter_type) {
255 switch (iter_type) {
256 case kDataPar:
257 return ForType::Serial;
258 case kParallelized:
259 return ForType::Parallel;
260 case kVectorized:
261 return ForType::Vectorized;
262 case kUnrolled:
263 return ForType::Unrolled;
264 default:
265 return ForType::Serial;
266 }
267 }
268
269 } // namespace op
270 } // namespace tvm
271