1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \brief Scan Operator.
22  * \file scan_op.cc
23  */
24 #include <tvm/operation.h>
25 #include <tvm/ir.h>
26 #include <tvm/ir_pass.h>
27 #include "op_util.h"
28 #include "../schedule/graph.h"
29 
30 namespace tvm {
31 
32 using namespace ir;
33 
34 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon18723ed70102(const ObjectRef& node, IRPrinter* p) 35 .set_dispatch<ScanOpNode>([](const ObjectRef& node, IRPrinter* p) {
36     auto* op = static_cast<const ScanOpNode*>(node.get());
37     p->stream << "scan(" << op->name << ", " << op << ")";
38 });
39 TVM_REGISTER_NODE_TYPE(ScanOpNode);
40 
prove_equal(Expr lhs,Expr rhs)41 inline bool prove_equal(Expr lhs, Expr rhs) {
42   return is_zero(ir::Simplify(lhs - rhs));
43 }
44 
num_outputs() const45 int ScanOpNode::num_outputs() const {
46   return static_cast<int>(update.size());
47 }
root_iter_vars() const48 Array<IterVar> ScanOpNode::root_iter_vars() const {
49   Array<IterVar> ret{scan_axis};
50   for (IterVar iv : spatial_axis_) {
51     ret.push_back(iv);
52   }
53   return ret;
54 }
55 
output_dtype(size_t i) const56 Type ScanOpNode::output_dtype(size_t i) const {
57   return update[i]->dtype;
58 }
59 
output_shape(size_t i) const60 Array<Expr> ScanOpNode::output_shape(size_t i) const {
61   CHECK_LT(i, state_placeholder.size());
62   return state_placeholder[i]->shape;
63 }
64 
make(std::string name,std::string tag,Map<std::string,NodeRef> attrs,IterVar axis,Array<Tensor> init,Array<Tensor> update,Array<Tensor> state_placeholder,Array<Tensor> inputs)65 Operation ScanOpNode::make(std::string name,
66                            std::string tag,
67                            Map<std::string, NodeRef> attrs,
68                            IterVar axis,
69                            Array<Tensor> init,
70                            Array<Tensor> update,
71                            Array<Tensor> state_placeholder,
72                            Array<Tensor> inputs) {
73   if (!attrs.defined()) {
74     attrs = Map<std::string, NodeRef>();
75   }
76   auto n = make_node<ScanOpNode>();
77   CHECK_EQ(init.size(), update.size());
78   CHECK_EQ(init.size(), state_placeholder.size());
79 
80   for (size_t i = 0; i < init.size(); ++i) {
81     CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
82     CHECK_EQ(init[i]->dtype, update[i]->dtype);
83     CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
84         << "init.shape[0] need to match scan_axis.dom.min";
85     CHECK(prove_equal(
86         state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
87         << "state_placeholder.shape[0] need to match"
88         << " scan_axis.dom.min + scan_axis.dom.extent";
89     CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
90         << "The dimension of init need to match state_placeholder";
91     CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
92         << "The update.ndim need to be state_placeholder.ndim - 1";
93     for (size_t k = 0;  k < update[i].ndim(); ++k) {
94       CHECK(prove_equal(
95           update[i]->shape[k], state_placeholder[i]->shape[k]));
96       if (k != 0) {
97         // setup spatial axis
98         std::ostringstream spatial_name;
99         spatial_name << name << ".out" << i << ".i" << k;
100         n->spatial_axis_.push_back(
101             IterVarNode::make(
102                 Range::make_by_min_extent(0, update[i]->shape[k]),
103                 Var(spatial_name.str()), kOpaque));
104       }
105     }
106 
107     for (size_t k = 1;  k < init[i].ndim(); ++k) {
108       CHECK(prove_equal(
109           init[i]->shape[k], state_placeholder[i]->shape[k]));
110     }
111   }
112   n->name = std::move(name);
113   n->tag = std::move(tag);
114   n->attrs = std::move(attrs);
115   n->scan_axis = std::move(axis);
116   n->init = std::move(init);
117   n->update = std::move(update);
118   n->state_placeholder = std::move(state_placeholder);
119   n->inputs = std::move(inputs);
120   return Operation(n);
121 }
122 
scan(Array<Tensor> init,Array<Tensor> update,Array<Tensor> state_placeholder,Array<Tensor> inputs,std::string name,std::string tag,Map<std::string,NodeRef> attrs)123 Array<Tensor> scan(Array<Tensor> init,
124                    Array<Tensor> update,
125                    Array<Tensor> state_placeholder,
126                    Array<Tensor> inputs,
127                    std::string name,
128                    std::string tag,
129                    Map<std::string, NodeRef> attrs) {
130   IterVar scan_axis =
131       IterVarNode::make(
132           Range::make_by_min_extent(
133               init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
134           Var(name + ".idx"), kOrdered);
135   Operation op = ScanOpNode::make(
136       name, tag, attrs, scan_axis,
137       init, update, state_placeholder, inputs);
138   Array<Tensor> res;
139   for (int i = 0; i < op->num_outputs(); ++i) {
140     res.push_back(op.output(i));
141   }
142   return res;
143 }
144 
InputTensors() const145 Array<Tensor> ScanOpNode::InputTensors() const {
146   Array<Tensor> ret;
147   for (Tensor t : init) {
148     ret.push_back(t);
149   }
150   for (Tensor t : update) {
151     ret.push_back(t);
152   }
153   return ret;
154 }
155 
/*!
 * \brief Return a copy of this op with init/update tensors substituted via rmap.
 * \param self The Operation wrapping this node.
 * \param rmap Map from tensors to their replacements.
 * \return A new Operation if any init or update tensor was replaced;
 *         otherwise `self` unchanged.
 */
Operation ScanOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_node<ScanOpNode>(*this);
  // Replace element `idx` of `tensors` in place if rmap has an entry for it.
  auto substitute = [&rmap](Array<Tensor>* tensors, size_t idx) {
    auto hit = rmap.find((*tensors)[idx]);
    if (hit != rmap.end()) {
      tensors->Set(idx, hit->second);
    }
  };
  for (size_t idx = 0; idx < n->init.size(); ++idx) {
    substitute(&n->init, idx);
    substitute(&n->update, idx);
  }
  // Set() copies on write, so an unchanged array still shares storage with
  // the original; in that case keep the existing operation.
  const bool modified = !n->init.same_as(init) || !n->update.same_as(update);
  return modified ? Operation(n) : self;
}
176 
PropBoundToInputs(const Operation & self,arith::Analyzer * analyzer,const std::unordered_map<const Variable *,IntSet> & dom_map,std::unordered_map<Tensor,TensorDom> * out_dom_map) const177 void ScanOpNode::PropBoundToInputs(
178     const Operation& self,
179     arith::Analyzer* analyzer,
180     const std::unordered_map<const Variable*, IntSet>& dom_map,
181     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
182   CHECK_EQ(self.operator->(), this);
183   for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
184     TensorDom* init_dom = nullptr;
185     TensorDom* update_dom = nullptr;
186     if (out_dom_map->count(this->init[i])) {
187       init_dom = &out_dom_map->at(this->init[i]);
188     }
189     if (out_dom_map->count(this->update[i])) {
190       update_dom = &out_dom_map->at(this->update[i]);
191     }
192     // first dimension, always needed.
193     if (init_dom) {
194       init_dom->data[0].push_back(IntSet::range(
195           Range::make_by_min_extent(0, this->init[i]->shape[0])));
196     }
197     if (update_dom) {
198       update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
199     }
200     // The update dimensions
201     for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
202       IterVar sp_ax = this->spatial_axis_[sp_idx];
203       if (init_dom) {
204         init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
205       }
206       if (update_dom) {
207         update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
208       }
209     }
210   }
211 }
212 
GatherBound(const Operation & self,const std::unordered_map<Tensor,TensorDom> & tensor_dom,std::unordered_map<IterVar,Range> * out_dom_map) const213 void ScanOpNode::GatherBound(
214     const Operation& self,
215     const std::unordered_map<Tensor, TensorDom>& tensor_dom,
216     std::unordered_map<IterVar, Range>* out_dom_map) const {
217   CHECK_EQ(self.operator->(), this);
218   using namespace schedule;
219   CHECK(!out_dom_map->count(this->scan_axis));
220   std::vector<Tensor> output(this->num_outputs());
221   for (size_t i = 0; i < output.size(); ++i) {
222     output[i] = self.output(i);
223   }
224   // Update for time axis.
225   std::vector<IntSet> time_dom;
226   for (size_t i = 0; i < output.size(); ++i) {
227     const TensorDom& d = tensor_dom.at(output[i]);
228     time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
229   }
230   CHECK(!out_dom_map->count(this->scan_axis));
231   Range sdom = this->scan_axis->dom;
232   Range r = arith::Union(time_dom).cover_range(sdom);
233   (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
234       sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
235   Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
236   // Update for spatial axis.
237   size_t sp_idx = 0;
238   for (size_t i = 0; i < output.size(); ++i) {
239     const TensorDom& d = tensor_dom.at(output[i]);
240     for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
241       IterVar sp_ax = this->spatial_axis_[sp_idx];
242       CHECK(!out_dom_map->count(sp_ax));
243       CHECK(fix_pt.count(sp_ax));
244       if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
245         // fix point, we can slice it.
246         (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
247       } else {
248         // not a fix point, need to include everything.
249         (*out_dom_map)[sp_ax] = sp_ax->dom;
250       }
251     }
252   }
253 }
254 
/*!
 * \brief Wrap `body` in one Realize per output.
 *
 * Each output's realized region is [0, scan_min + scan_extent) along
 * dimension 0, using the scan axis range from dom_map. The remaining
 * dimensions use the dom_map ranges of the corresponding spatial axes.
 *
 * \param stage The stage of this op.
 * \param dom_map Ranges of the iteration variables.
 * \param body The statement to wrap.
 * \return body nested inside the Realize statements.
 */
Stmt ScanOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  const Range scan_range = dom_map.at(this->scan_axis);
  const Range time_range = Range::make_by_min_extent(
      0, ir::Simplify(scan_range->extent + scan_range->min));
  Stmt wrapped = body;
  size_t spatial_pos = 0;
  for (size_t out_idx = 0; out_idx < update.size(); ++out_idx) {
    const Tensor out = stage->op.output(out_idx);
    CHECK_EQ(static_cast<size_t>(out->value_index), out_idx);
    Region region;
    region.push_back(time_range);
    const size_t rank = this->update[out_idx]->shape.size();
    for (size_t k = 1; k < rank; ++k) {
      region.push_back(dom_map.at(this->spatial_axis_[spatial_pos++]));
    }
    wrapped = ir::Realize::make(out->op, out->value_index, out->dtype,
                                region, const_true(), wrapped);
  }
  return wrapped;
}
279 
BuildProvide(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,bool debug_keep_trivial_loop) const280 Stmt ScanOpNode::BuildProvide(
281     const Stage& stage,
282     const std::unordered_map<IterVar, Range>& dom_map,
283     bool debug_keep_trivial_loop) const {
284   CHECK_EQ(stage->op.operator->(), this);
285   Stmt provide = AttrStmt::make(
286       stage->op, attr::scan_update_scope, this->scan_axis->var,
287       Evaluate::make(0));
288   Stmt init = AttrStmt::make(
289       stage->op, attr::scan_init_scope, 0,
290       Evaluate::make(0));
291   size_t begin_scan = 0;
292   for (size_t  i = 0; i < stage->leaf_iter_vars.size(); ++i) {
293     if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
294       CHECK_EQ(begin_scan, i);
295       begin_scan = i + 1;
296     }
297   }
298   std::unordered_map<IterVar, Expr> vmap;
299   std::unordered_set<IterVar> empty;
300   auto nest = op::MakeLoopNest(
301       stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
302   nest[begin_scan].push_back(init);
303   nest.push_back(
304       op::MakeIfNest(
305           schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty)));
306   return MergeNest(nest, provide);
307 }
308 }  // namespace tvm
309