/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Compute Op.
 * \file compute_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <string>
#include <utility>
#include "compute_op.h"
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/int_set.h"

namespace tvm {

using namespace ir;

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
    auto* op = static_cast<const ComputeOpNode*>(node.get());
    p->stream << "compute(" << op->name << ", " << op << ")";
});

TVM_REGISTER_NODE_TYPE(ComputeOpNode);

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode *op);

inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  return (a->combiner.same_as(b->combiner)) &&
         (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) &&
         (a->condition.same_as(b->condition));
}

int ComputeOpNode::num_outputs() const {
  return body.size();
}

Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
  if (reduce_axis.size() == 0) return axis;
  Array<IterVar> ret = axis;
  for (IterVar iv : reduce_axis) {
    ret.push_back(iv);
  }
  return ret;
}

Type ComputeOpNode::output_dtype(size_t idx) const {
  CHECK_LT(idx, num_outputs());
  return body[idx].type();
}

Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
  CHECK_LT(idx, num_outputs());
  // for now, all outputs of a BaseComputeOp have the same shape
  Array<Expr> shape;
  for (const auto& ivar : this->axis) {
    const Range& r = ivar->dom;
    shape.push_back(r->extent);
  }
  return shape;
}

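// Example of typical usage from C++ (a sketch; `A`, `m`, `n` are illustrative
// placeholders, not defined in this file):
//
//   Var m("m"), n("n");
//   Tensor A = placeholder({m, n}, Float(32), "A");
//   Tensor B = compute({m, n}, [&](const Array<Var>& i) {
//       return A(i[0], i[1]) + make_const(Float(32), 1);
//     }, "B");
//
// Each dimension of `shape` becomes a kDataPar IterVar named "ax<i>", and the
// single expression returned by `fcompute` becomes output 0 of the new op.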
Tensor compute(Array<Expr> shape,
               FCompute fcompute,
               std::string name,
               std::string tag,
               Map<std::string, NodeRef> attrs) {
  auto op_node = make_node<ComputeOpNode>();
  // compute dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }

  return ComputeOpNode::make(
      name, tag, attrs, axis, {fcompute(args)}).output(0);
}

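// Multi-output variant: `fcompute` returns one expression per output, and all
// outputs share the same axes and shape. A sketch (with an illustrative
// placeholder tensor A):
//
//   Array<Tensor> outs = compute({n}, [&](const Array<Var>& i) {
//       return Array<Expr>{A(i[0]) + 1, A(i[0]) * 2};
//     }, "AB");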
Array<Tensor> compute(Array<Expr> shape,
                      FBatchCompute fcompute,
                      std::string name,
                      std::string tag,
                      Map<std::string, NodeRef> attrs) {
  auto op_node = make_node<ComputeOpNode>();
  // compute dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }

  Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
  Array<Tensor> outputs;
  for (int idx = 0; idx < op->num_outputs(); ++idx) {
    outputs.push_back(op.output(idx));
  }
  return outputs;
}

Operation ComputeOpNode::make(std::string name,
                              std::string tag,
                              Map<std::string, NodeRef> attrs,
                              Array<IterVar> axis,
                              Array<Expr> body) {
  if (!attrs.defined()) {
    attrs = Map<std::string, NodeRef>();
  }
  auto n = make_node<ComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->axis = std::move(axis);
  n->body = std::move(body);
  if (n->body[0]->IsInstance<ir::Reduce>()) {
    const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
    n->reduce_axis = reduce->axis;
  }
  VerifyComputeOp(n.get());
  return Operation(n);
}

// The schedule related logic.
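// Collect every tensor the compute body reads from, deduplicated and in
// order of first use.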
Array<Tensor> ComputeOpNode::InputTensors() const {
  Array<Tensor> ret;
  std::unordered_set<Tensor> visited;
  for (auto& e : body) {
    ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          Tensor t = Downcast<Operation>(call->func).output(call->value_index);
          if (!visited.count(t)) {
            ret.push_back(t);
            visited.insert(t);
          }
        }
      });
  }
  return ret;
}

Operation ComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  VerifyComputeOp(this);
  Array<Expr> arr;
  if (this->body[0]->IsInstance<ir::Reduce>()) {
    // Specially handle reduce so the replaced op
    // still shares all the components.
    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
    if (!new_reduce.same_as(this->body[0])) {
      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
      for (size_t k = 0; k < this->body.size(); ++k) {
        auto n = make_node<ir::Reduce>(*r);
        n->value_index = static_cast<int>(k);
        n->type = r->source[k].type();
        arr.push_back(Expr(n));
      }
    } else {
      arr = this->body;
    }
  } else {
    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
        return op::ReplaceTensor(e, rmap);
      });
  }
  if (!arr.same_as(this->body)) {
    return ComputeOpNode::make(
        this->name, this->tag, this->attrs, this->axis, arr);
  } else {
    return self;
  }
}

void ComputeOpNode::PropBoundToInputs(
    const Operation& self,
    arith::Analyzer* analyzer,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
    auto *call = n.as<ir::Call>();
    if (call != nullptr && call->func.defined()) {
      Tensor t = Downcast<Operation>(call->func).output(call->value_index);
      if (t->op.defined() && out_dom_map->count(t)) {
        TensorDom& dom = out_dom_map->at(t);
        for (size_t i = 0; i < t.ndim(); ++i) {
          // We assume that the value of the argument cannot be out of bounds (otherwise it is
          // undefined behaviour), so we can intersect the estimated set of the argument with the
          // range expected by the tensor. However, intersection may result in overly complex
          // expressions, so we perform a more relaxed form of intersection.
          IntSet arg_intset = EvalSet(call->args[i], dom_map);
          const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
          if (arg_interval) {
            Expr shape_i_min_value = make_zero(t->shape[i].type());
            Expr shape_i_max_value = t->shape[i] - 1;
            Expr min_value = arg_interval->min_value;
            Expr max_value = arg_interval->max_value;
            // Prefer the shape bounds only when we can prove they are tighter.
            if (arith::is_neg_inf(min_value) ||
                analyzer->CanProve(shape_i_min_value >= min_value)) {
              min_value = shape_i_min_value;
            }
            if (arith::is_pos_inf(max_value) ||
                analyzer->CanProve(shape_i_max_value <= max_value)) {
              max_value = shape_i_max_value;
            }
            dom.data[i].push_back(IntSet::interval(min_value, max_value));
          } else {
            dom.data[i].push_back(arg_intset);
          }
        }
      }
    }
  };
  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
}

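// The root iteration range of each data axis is derived from the union of the
// regions requested by all consumers of output(0), covered by the axis'
// original domain; reduction axes always keep their full domain.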
void BaseComputeOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
    CHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    CHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}

Stmt BaseComputeOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  Region bounds;
  for (IterVar iv : this->axis) {
    bounds.push_back(realize_map.at(iv));
  }
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    Tensor t = stage->op.output(i-1);
    realize = ir::Realize::make(t->op, t->value_index,
      t->dtype, bounds, const_true(), realize);
    // alignment requirement, only useful for compute
    for (size_t i = 0; i < num_schedulable_dims(); ++i) {
      auto it = stage->iter_var_attrs.find(this->axis[i]);
      if (it != stage->iter_var_attrs.end()) {
        IterVarAttr attr = (*it).second;
        if (attr->dim_align_factor != 0) {
          Array<Expr> tuple = {static_cast<int>(i),
                               attr->dim_align_factor,
                               attr->dim_align_offset};
          realize = ir::AttrStmt::make(
              t, ir::attr::buffer_dim_align,
              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
              realize);
        }
      }
    }
  }
  return realize;
}

size_t ComputeOpNode::num_schedulable_dims() const {
  return axis.size();
}

// Build a reduction body.
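// For example, for a sum reduction B(i) = sum(A(i, k), k) the two generated
// statements are roughly (a sketch):
//   init:    B(i) = 0
//   provide: B(i) = B(i) + A(i, k)
// with `provide` additionally wrapped in an if-statement when the reduction
// carries a condition.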
void MakeReduction(const ComputeOpNode* op,
                   const Array<Tensor>& tensors,
                   Stmt* init,
                   Stmt* provide) {
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  std::vector<Stmt> inits, provides;

  size_t size = op->body.size();
  const Reduce* reduce = op->body[0].as<Reduce>();
  CHECK(reduce);
  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
  CHECK(combiner);
  Array<Expr> lhs;
  for (size_t i = 0; i < size; ++i) {
    lhs.push_back(tensors[i](args));
  }
  Array<Expr> init_value = combiner->identity_element;
  Array<Expr> update_value = (*combiner)(lhs, reduce->source);
  for (size_t i = 0; i < size; ++i) {
    Tensor t = tensors[i];
    inits.emplace_back(Provide::make(
          t->op, t->value_index, init_value[i], args));
    provides.emplace_back(Provide::make(
          t->op, t->value_index, update_value[i], args));
  }
  *init = Block::make(inits);
  *provide = Block::make(provides);
  if (!is_one(reduce->condition)) {
    *provide = IfThenElse::make(reduce->condition, *provide);
  }
}

// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
                 const Tensor& t) {
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}

Stmt MakeComputeStmt(const ComputeOpNode* self,
                     const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) {
  // grab the nest structure
  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
  // Normal loop structure
  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
  if (self->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < self->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(self, source, &init, &provide);
    init = MergeNest(n.init_nest, init);
    init = op::Substitute(init, n.init_vmap);
    // common nest
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > reduce(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
    provide = MergeNest(reduce, provide);
    if (debug_keep_trivial_loop) {
      provide = MergeNest(common, provide);
    } else {
      provide = MergeNest(common, Block::make(init, provide));
    }
    // Run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return op::Substitute(provide, n.main_vmap);
  } else {
    std::vector<Stmt> provides;
    for (size_t i = 0; i < self->body.size(); ++i) {
      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
    }
    Stmt provide = Block::make(provides);
    provide = MergeNest(n.main_nest, provide);
    // Run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return op::Substitute(provide, n.main_vmap);
  }
}

enum class ComputeType {
  kNormal,
  kCrossThreadReduction,
  kTensorize
};

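// Classify how a stage's compute should be lowered: plain compute, a
// cross-thread reduction (a reduction axis bound to a thread index), or a
// tensorized region.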
ComputeType DetectComputeType(const ComputeOpNode* self,
                              const Stage& stage) {
  // Verify correctness of leaf nest.
  int normal_red = 0, thread_red = 0, tensorize = 0;

  for (IterVar iv : stage->leaf_iter_vars) {
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (attr.defined() && attr->iter_type == kTensorized) {
      ++tensorize;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        ++thread_red;
      } else {
        ++normal_red;
      }
    } else {
      CHECK_EQ(thread_red, 0)
          << "Cross thread reduce cannot swap with normal data axis";
    }
  }
  if (tensorize != 0) {
    CHECK(thread_red == 0)
        << "Cannot mix cross thread reduction with Tensorize";
    return ComputeType::kTensorize;
  }
  CHECK(normal_red == 0 || thread_red == 0)
      << "Cannot mix normal reduction with thread reduce";
  if (thread_red != 0) {
    return ComputeType::kCrossThreadReduction;
  } else {
    return ComputeType::kNormal;
  }
}

// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  ComputeType ctype = DetectComputeType(this, stage);
  if (ctype == ComputeType::kCrossThreadReduction) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
  } else if (ctype == ComputeType::kTensorize) {
    return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
  } else {
    return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
  }
}

ComputeLoopNest ComputeLoopNest::make(
    const BaseComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) {
  CHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
      debug_keep_trivial_loop);
  ret.main_predicates = schedule::MakeBoundCheck(
      stage, dom_map, ret.main_vmap, false,
      std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
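    // update_state keeps a bitmask per IterVar: bit 2 marks reduction axes
    // and bit 1 marks schedulable data axes; PassDownBitMaskOr below
    // propagates these flags onto the leaf iteration variables.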
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
      update_state[self->axis[i]] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    schedule::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // find the first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that are related to reduction and are unrelated to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if (flag == 2) skip_iter.insert(kv.first);
    }
    ret.init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
    ret.init_predicates = schedule::MakeBoundCheck(
        stage, dom_map, ret.init_vmap, true, skip_iter);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
    ret.num_common_loop = stage->leaf_iter_vars.size();
  }
  // copy elision here.
  return ret;
}

namespace {
/*!
 * \brief Verify if ComputeOp is valid with respect to Reduce operations.
 *
 *  The following two properties are verified:
 *  (1) All Reduce operations must exist at the top level.
 *  (2) For a list of operations, if one is Reduce, then the others
 *      must be Reduce as well, and their inputs should have the
 *      same attributes except value_index.
 */
class ComputeVerifier final : protected ir::IRVisitor {
 public:
  /// Special member functions
  //@{
  explicit ComputeVerifier(const ComputeOpNode* compute)
      : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
  virtual ~ComputeVerifier() = default;
  ComputeVerifier(const ComputeVerifier&) = delete;
  ComputeVerifier(ComputeVerifier&&) = delete;
  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
  //@}

  /// Interface to perform compute verification
  void Run() {
    for (const Expr e : compute_->body) {
      // Check for consistency of top level reductions
      const ir::Reduce* reduce = e.as<ir::Reduce>();
      CHECK((reduce && reduce_) || (!reduce && !reduce_))
          << "All bodies of a ComputeOp must consistently "
          << "be Reduce operations or not.";

      if (reduce && reduce_) {
        CHECK(ReduceEqual(reduce, reduce_))
            << "The Reduce inputs of ComputeOp should "
            << "have the same attributes except value_index";
      }

      level_ = 0;
      ir::IRVisitor::Visit(e);
    }
  }

 protected:
  /// Visitor implementation
  //@{
  void Visit(const NodeRef& n) final {
    ++level_;
    ir::IRVisitor::Visit(n);
    --level_;
  }

  void Visit_(const ir::Reduce* op) final {
    // Check for non top level reductions
    CHECK(0 == level_)
        << "Reductions are only allowed at the top level of compute. "
        << "Please create another tensor for further composition.";
  }
  //@}

 private:
  const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
  const ir::Reduce* reduce_{nullptr};      ///< Top level Reduce operation
  int level_{0};                           ///< Level of op being processed
};
}  // namespace

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
  ComputeVerifier v(op);
  v.Run();
}

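// Used when tensorizing an intrinsic that has no separate init step: emit
// `update` once any outer reduction loop has passed its first iteration, and
// `body` (which includes the initialization) on the very first iteration.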
Stmt TransformUpdate(const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     const ComputeLoopNest& n,
                     Stmt body,
                     Stmt update) {
  Array<Expr> conds;
  std::unordered_set<const Variable*> banned;
  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (attr->iter_type == kTensorized) {
        break;
      }
    }
    if (iv->iter_type == kCommReduce) {
      auto vit = dom_map.find(iv);
      CHECK(vit != dom_map.end());
      const Range& vrange = vit->second;
      conds.push_back(likely(iv->var > vrange->min));
      banned.insert(iv->var.get());
    }
  }
  for (const Expr& pred : n.main_predicates) {
    if (ir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize update transform failed, the condition "
                 << pred << " has a conflict with the reset condition";
    }
  }

  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
                          update, body);
}
}  // namespace tvm