/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Utility to make loop nest.
 * \file op_util.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <string>
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace op {

using namespace arith;
using namespace ir;

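// Build the loop nest for stage: nest[i + 1] holds the statements generated
// for the i-th leaf iter var, and value_map records the Expr each iter var
// evaluates to inside the nest.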
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
             const std::unordered_map<IterVar, Range>& dom_map,
             size_t begin_iter_pos,
             bool new_loop_var,
             const std::unordered_set<IterVar>& skip_iter,
             std::unordered_map<IterVar, Expr>* p_value_map,
             bool debug_keep_trivial_loop) {
  auto leaf_iter_vars = stage->leaf_iter_vars;
  Stmt no_op = Evaluate::make(0);
  // create the loop nest
  std::vector<std::vector<Stmt> > nest;
  nest.resize(leaf_iter_vars.size() + 1);
  std::unordered_map<IterVar, Expr>& value_map = *p_value_map;

  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
      // skip this iter var; it is handled elsewhere.
      value_map[iv] = iv->var;
      continue;
    }
    // iv may be bound to a thread; if so, loop over the bound thread instead.
    IterVar bind_iv = iv;
    if (stage->iter_var_attrs.count(iv)) {
      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
      if (bind_thread.defined()) bind_iv = bind_thread;
    }

    Range dom = dom_map.at(iv);

    // initialize the offset and loop_level
    Var var = bind_iv->var;

    // Mark the iter var in the IR to remember this point.
    if (bind_iv->thread_tag.length() == 0) {
      // Only generate a new loop if we are not bound to a thread.
      if (new_loop_var) {
        var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
      }

      ForType for_type = ForType::Serial;
      IterVarAttr it_attr;
      if (stage->iter_var_attrs.count(iv)) {
        it_attr = stage->iter_var_attrs[iv];
      }
      if (it_attr.defined()) {
        switch (it_attr->iter_type) {
          case kUnrolled: for_type = ForType::Unrolled; break;
          case kVectorized: for_type = ForType::Vectorized; break;
          case kParallelized: for_type = ForType::Parallel; break;
          case kDataPar: break;
          case kTensorized: break;
          default: LOG(FATAL) << "Unknown iter type "
                              << it_attr->iter_type
                              << " in the iter_var_attrs";
        }
        CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
          const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
          Expr pvalue = it_attr->pragma_values[k];
          if (!pvalue.defined()) {
            pvalue = make_const(Int(32), 1);
          }
          nest[i + 1].emplace_back(
              AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
        }
      }
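      // Emit the loop itself. Three cases: a trivial extent-1 loop collapses
      // to a Let binding, a zero-based domain becomes a For over [0, extent),
      // and a shifted domain gets a fresh index plus a Let adding dom->min.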
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
        nest[i + 1].emplace_back(
            LetStmt::make(var, dom->min, no_op));
        value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
        nest[i + 1].emplace_back(
            For::make(var, 0, dom->extent,
                      for_type, DeviceAPI::None, no_op));
        value_map[iv] = var;
      } else {
        Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
        nest[i + 1].emplace_back(
            For::make(idx, 0, dom->extent,
                      for_type, DeviceAPI::None, no_op));
        Expr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(
            LetStmt::make(var, new_value, no_op));
      }
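      // Attach any prefetch hints requested for this loop level as
      // prefetch_scope annotations.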
      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
        CHECK(!is_one(dom->extent))
            << "Cannot prefetch on trivial loop with extent=1";
        CHECK_EQ(it_attr->prefetch_data.size(),
                 it_attr->prefetch_offset.size());
        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
          nest[i + 1].emplace_back(
              AttrStmt::make(it_attr->prefetch_data[j],
                             ir::attr::prefetch_scope,
                             it_attr->prefetch_offset[j], no_op));
        }
      }
    } else if (bind_iv->thread_tag == "vthread" ||
               bind_iv->thread_tag == "cthread") {
      // virtual thread
      // Always restrict threaded IterVar to start from 0.
      CHECK(is_zero(dom->min));
      CHECK(is_positive_const(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
      value_map[iv] = var;
    } else if (bind_iv->thread_tag == "pipeline") {
      // pipeline marker.
      CHECK(is_zero(dom->min));
      CHECK(is_one(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
      value_map[iv] = dom->min;
    } else {
      // Always restrict threaded IterVar to start from 0.
      CHECK(is_zero(dom->min));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = var;
      }
    }
    // mark the scope of this loop so later passes can locate it
    if (!new_loop_var) {
      nest[i + 1].emplace_back(
          AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
    }
  }
  // message passing to get offset of root iter vars.
  schedule::PassUpIndex(stage, dom_map, &value_map);
  return nest;
}

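// Build a vector of IfThenElse guards, one per predicate; each body is a
// no-op placeholder to be filled in when the nest is merged with a real body.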
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
  Stmt no_op = Evaluate::make(0);
  std::vector<Stmt> nest;
  for (const Expr& cond : predicates) {
    nest.emplace_back(IfThenElse::make(cond, no_op));
  }
  return nest;
}

// Mutator that replaces tensor references according to a Tensor-to-Tensor map.
class TensorReplacer : public ir::IRMutator {
 public:
  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
      : vmap_(vmap) {}

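  // Rewrite Halide tensor calls: when the called tensor is in the
  // replacement map, re-emit the call against the replacement tensor.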
  Expr Mutate_(const ir::Call* op, const Expr& e) {
    if (op->call_type == ir::Call::Halide) {
      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
      auto it = vmap_.find(t);
      if (it != vmap_.end()) {
        Expr ret = ir::Call::make(
            op->type, it->second->op->name, op->args,
            op->call_type, it->second->op, it->second->value_index);
        found = true;
        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  // whether any replacement has been applied.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

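// Replace tensors referenced by stmt according to the map; returns the
// original stmt untouched when no replacement applies.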
Stmt ReplaceTensor(Stmt stmt,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer repl(replace);
  Stmt ret = repl.Mutate(stmt);
  return repl.found ? ret : stmt;
}
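// Expr counterpart of the Stmt overload above.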
Expr ReplaceTensor(Expr expr,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer repl(replace);
  Expr ret = repl.Mutate(expr);
  return repl.found ? ret : expr;
}

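// Substitute the iter vars appearing in s with the mapped expressions,
// keyed by each iter var's underlying Variable node.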
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  std::unordered_map<const Variable*, Expr> init;
  for (const auto& kv : value_map) {
    init[kv.first->var.get()] = kv.second;
  }
  return ir::Substitute(s, init);
}

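// Convert a loop annotation back to the matching iter var type.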
IterVarType ForTypeToIterVarType(ir::ForType for_type) {
  switch (for_type) {
  case ForType::Serial:
    return kDataPar;
  case ForType::Parallel:
    return kParallelized;
  case ForType::Vectorized:
    return kVectorized;
  case ForType::Unrolled:
    return kUnrolled;
  default:
    return kDataPar;
  }
}

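// Convert an iter var type to the loop annotation used when lowering it.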
ir::ForType IterVarTypeToForType(IterVarType iter_type) {
  switch (iter_type) {
  case kDataPar:
    return ForType::Serial;
  case kParallelized:
    return ForType::Parallel;
  case kVectorized:
    return ForType::Vectorized;
  case kUnrolled:
    return ForType::Unrolled;
  default:
    return ForType::Serial;
  }
}

}  // namespace op
}  // namespace tvm