/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Compute Op.
 * \file compute_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <string>
#include <utility>
#include "compute_op.h"
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/int_set.h"

namespace tvm {

using namespace ir;

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
    auto* op = static_cast<const ComputeOpNode*>(node.get());
    p->stream << "compute(" << op->name << ", " << op << ")";
  });

TVM_REGISTER_NODE_TYPE(ComputeOpNode);

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op);

inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  return (a->combiner.same_as(b->combiner)) &&
         (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) &&
         (a->condition.same_as(b->condition));
}
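
// Note that ReduceEqual uses reference equality (same_as), not structural
// equality: the verifier below requires every Reduce in a multi-output body to
// share the very same combiner/source/axis/condition nodes, differing only in
// value_index. Structurally identical but separately constructed Reduce nodes
// do not pass this check.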

int ComputeOpNode::num_outputs() const {
  return body.size();
}

Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
  if (reduce_axis.size() == 0) return axis;
  Array<IterVar> ret = axis;
  for (IterVar iv : reduce_axis) {
    ret.push_back(iv);
  }
  return ret;
}

Type ComputeOpNode::output_dtype(size_t idx) const {
  CHECK_LT(idx, num_outputs());
  return body[idx].type();
}

Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
  CHECK_LT(idx, num_outputs());
  // For now, all outputs of a BaseComputeOp share the same shape,
  // given by the extents of the data-parallel axes.
  Array<Expr> shape;
  for (const auto& ivar : this->axis) {
    const Range& r = ivar->dom;
    shape.push_back(r->extent);
  }
  return shape;
}

Tensor compute(Array<Expr> shape,
               FCompute fcompute,
               std::string name,
               std::string tag,
               Map<std::string, NodeRef> attrs) {
  // Construct one data-parallel IterVar ("ax0", "ax1", ...) per output dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }

  return ComputeOpNode::make(
      name, tag, attrs, axis, {fcompute(args)}).output(0);
}
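
// A minimal usage sketch of the single-output overload (A, m, n are
// hypothetical names, not defined in this file):
//
//   Tensor A = placeholder({m, n}, Float(32), "A");
//   Tensor B = compute({m, n}, [&](const Array<Var>& i) {
//       return A(i[0], i[1]) + 1.0f;
//     }, "B");
//
// fcompute receives one Var per output dimension ("ax0", "ax1", ...) and
// returns the expression that becomes the single element of the op's body.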

Array<Tensor> compute(Array<Expr> shape,
                      FBatchCompute fcompute,
                      std::string name,
                      std::string tag,
                      Map<std::string, NodeRef> attrs) {
  // Construct one data-parallel IterVar ("ax0", "ax1", ...) per output dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }

  Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
  Array<Tensor> outputs;
  for (int idx = 0; idx < op->num_outputs(); ++idx) {
    outputs.push_back(op.output(idx));
  }
  return outputs;
}
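
// The batched overload differs only in that fcompute returns one expression
// per output; all outputs share the same axes and shape. A sketch (names are
// hypothetical; multi-valued reducers such as argmax are the typical client):
//
//   Array<Tensor> out = compute({n}, [&](const Array<Var>& i) {
//       return Array<Expr>{A(i[0]), B(i[0])};
//     }, "pair");
//   // out[0] and out[1] are two tensors produced by a single ComputeOp.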

Operation ComputeOpNode::make(std::string name,
                              std::string tag,
                              Map<std::string, NodeRef> attrs,
                              Array<IterVar> axis,
                              Array<Expr> body) {
  if (!attrs.defined()) {
    attrs = Map<std::string, NodeRef>();
  }
  auto n = make_node<ComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->axis = std::move(axis);
  n->body = std::move(body);
  if (n->body[0]->IsInstance<ir::Reduce>()) {
    const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
    n->reduce_axis = reduce->axis;
  }
  VerifyComputeOp(n.get());
  return Operation(n);
}

// Schedule-related logic
Array<Tensor> ComputeOpNode::InputTensors() const {
  Array<Tensor> ret;
  std::unordered_set<Tensor> visited;
  for (auto& e : body) {
    ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
        const ir::Call* call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          Tensor t = Downcast<Operation>(call->func).output(call->value_index);
          if (!visited.count(t)) {
            ret.push_back(t);
            visited.insert(t);
          }
        }
      });
  }
  return ret;
}
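
// InputTensors reports each tensor read by the body exactly once, in order of
// first occurrence. For a body like C(i) = A(i) * A(i) + B(i) (hypothetical
// names) the result is [A, B]; the repeated read of A is filtered out by the
// visited set.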

Operation ComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  VerifyComputeOp(this);
  Array<Expr> arr;
  if (this->body[0]->IsInstance<ir::Reduce>()) {
    // Specially handle reduce so that the replaced op
    // still shares all the components.
    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
    if (!new_reduce.same_as(this->body[0])) {
      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
      for (size_t k = 0; k < this->body.size(); ++k) {
        auto n = make_node<ir::Reduce>(*r);
        n->value_index = static_cast<int>(k);
        n->type = r->source[k].type();
        arr.push_back(Expr(n));
      }
    } else {
      arr = this->body;
    }
  } else {
    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
        return op::ReplaceTensor(e, rmap);
      });
  }
  if (!arr.same_as(this->body)) {
    return ComputeOpNode::make(
        this->name, this->tag, this->attrs, this->axis, arr);
  } else {
    return self;
  }
}

void ComputeOpNode::PropBoundToInputs(
    const Operation& self,
    arith::Analyzer* analyzer,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
    auto* call = n.as<ir::Call>();
    if (call != nullptr && call->func.defined()) {
      Tensor t = Downcast<Operation>(call->func).output(call->value_index);
      if (t->op.defined() && out_dom_map->count(t)) {
        TensorDom& dom = out_dom_map->at(t);
        for (size_t i = 0; i < t.ndim(); ++i) {
          // We assume that the value of the argument cannot be out of bounds
          // (otherwise it is undefined behaviour), so we can intersect the
          // estimated set of the argument with the range expected by the
          // tensor. However, intersection may result in overly complex
          // expressions, so we perform a more relaxed form of intersection.
          IntSet arg_intset = EvalSet(call->args[i], dom_map);
          const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
          if (arg_interval) {
            Expr shape_i_min_value = make_zero(t->shape[i].type());
            Expr shape_i_max_value = t->shape[i] - 1;
            Expr min_value = arg_interval->min_value;
            Expr max_value = arg_interval->max_value;
            // Prefer the shape bounds only when we can prove they are tighter.
            if (arith::is_neg_inf(min_value) ||
                analyzer->CanProve(shape_i_min_value >= min_value)) {
              min_value = shape_i_min_value;
            }
            if (arith::is_pos_inf(max_value) ||
                analyzer->CanProve(shape_i_max_value <= max_value)) {
              max_value = shape_i_max_value;
            }
            dom.data[i].push_back(IntSet::interval(min_value, max_value));
          } else {
            dom.data[i].push_back(arg_intset);
          }
        }
      }
    }
  };
  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
}
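
// A concrete instance of the relaxed intersection above: suppose a read
// A[i - 1] where i ranges over [0, n), so the estimated argument set is
// [-1, n - 2], and A has shape [n]. The shape lower bound 0 is provably
// >= -1, so the minimum is clamped to 0; the shape upper bound n - 1 is not
// provably <= n - 2, so the estimated maximum n - 2 is kept. The result is
// [0, n - 2], without building the max/min expressions a full intersection
// would introduce.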

void BaseComputeOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
    CHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    CHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}
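
// GatherBound is the consumer-to-producer half of bound inference: each
// data-parallel axis is set to a range covering the union of all consumer
// demands (cover_range falls back to the declared domain), while every reduce
// axis keeps its full declared domain, since a reduction cannot be computed
// over a partial range of its reduce axis.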

Stmt BaseComputeOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  Region bounds;
  for (IterVar iv : this->axis) {
    bounds.push_back(realize_map.at(iv));
  }
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    Tensor t = stage->op.output(i-1);
    realize = ir::Realize::make(t->op, t->value_index,
      t->dtype, bounds, const_true(), realize);
    // alignment requirement, only useful for compute
    for (size_t dim = 0; dim < num_schedulable_dims(); ++dim) {
      auto it = stage->iter_var_attrs.find(this->axis[dim]);
      if (it != stage->iter_var_attrs.end()) {
        IterVarAttr attr = (*it).second;
        if (attr->dim_align_factor != 0) {
          Array<Expr> tuple = {static_cast<int>(dim),
                               attr->dim_align_factor,
                               attr->dim_align_offset};
          realize = ir::AttrStmt::make(
              t, ir::attr::buffer_dim_align,
              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
              realize);
        }
      }
    }
  }
  return realize;
}
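
// For each output tensor, BuildRealize wraps the body in a Realize node over
// the data-parallel extents. When a dimension carries an alignment request
// (set through the schedule, e.g. storage_align), it additionally emits
//   attr [t] buffer_dim_align = tvm_tuple(dim, factor, offset)
// which downstream storage planning uses to pad that dimension's pitch.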

size_t ComputeOpNode::num_schedulable_dims() const {
  return axis.size();
}

// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
                   const Array<Tensor>& tensors,
                   Stmt* init,
                   Stmt* provide) {
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  std::vector<Stmt> inits, provides;

  size_t size = op->body.size();
  const Reduce* reduce = op->body[0].as<Reduce>();
  CHECK(reduce);
  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
  CHECK(combiner);
  Array<Expr> lhs;
  for (size_t i = 0; i < size; ++i) {
    lhs.push_back(tensors[i](args));
  }
  Array<Expr> init_value = combiner->identity_element;
  Array<Expr> update_value = (*combiner)(lhs, reduce->source);
  for (size_t i = 0; i < size; ++i) {
    Tensor t = tensors[i];
    inits.emplace_back(Provide::make(
        t->op, t->value_index, init_value[i], args));
    provides.emplace_back(Provide::make(
        t->op, t->value_index, update_value[i], args));
  }
  *init = Block::make(inits);
  *provide = Block::make(provides);
  if (!is_one(reduce->condition)) {
    *provide = IfThenElse::make(reduce->condition, *provide);
  }
}
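
// For a plain sum B(i) = sum(A(i, k), k), the two statements built here are,
// schematically:
//   init:    B(i) = 0                   // the combiner's identity element
//   provide: B(i) = B(i) + A(i, k)      // the combiner applied to (lhs, source)
// A non-trivial reduce condition wraps the provide in an IfThenElse; the
// surrounding loop nests are attached later in MakeComputeStmt.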

// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
                 const Tensor& t) {
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}

Stmt MakeComputeStmt(const ComputeOpNode* self,
                     const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) {
  // grab the nest structure
  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
  // Normal loop structure
  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
  if (self->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < self->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(self, source, &init, &provide);
    init = MergeNest(n.init_nest, init);
    init = op::Substitute(init, n.init_vmap);
    // common nest
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > reduce(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
    provide = MergeNest(reduce, provide);
    if (debug_keep_trivial_loop) {
      provide = MergeNest(common, provide);
    } else {
      provide = MergeNest(common, Block::make(init, provide));
    }
    // Run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return op::Substitute(provide, n.main_vmap);
  } else {
    std::vector<Stmt> provides;
    for (size_t i = 0; i < self->body.size(); ++i) {
      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
    }
    Stmt provide = Block::make(provides);
    provide = MergeNest(n.main_nest, provide);
    // Run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return op::Substitute(provide, n.main_vmap);
  }
}
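
// Schematically, the reduction path above assembles:
//   for (common loops: leaves before the first reduce-related loop) {
//     for (init-only loops)  B(i) = identity;      // init nest
//     for (reduce loops)     B(i) = combine(...);  // reduce nest
//   }
// Initialization is fused under the outermost loops that do not touch a
// reduce axis, so only the inner reduce loops wrap the update statement.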

enum class ComputeType {
  kNormal,
  kCrossThreadReduction,
  kTensorize
};

ComputeType DetectComputeType(const ComputeOpNode* self,
                              const Stage& stage) {
  // Verify correctness of the leaf nest.
  int normal_red = 0, thread_red = 0, tensorize = 0;

  for (IterVar iv : stage->leaf_iter_vars) {
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (attr.defined() && attr->iter_type == kTensorized) {
      ++tensorize;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        ++thread_red;
      } else {
        ++normal_red;
      }
    } else {
      CHECK_EQ(thread_red, 0)
          << "Cross thread reduce cannot swap with normal data axis";
    }
  }
  if (tensorize != 0) {
    CHECK(thread_red == 0)
        << "Cannot mix cross thread reduction with Tensorize";
    return ComputeType::kTensorize;
  }
  CHECK(normal_red == 0 || thread_red == 0)
      << "Cannot mix normal reduction with thread reduce";
  if (thread_red != 0) {
    return ComputeType::kCrossThreadReduction;
  } else {
    return ComputeType::kNormal;
  }
}
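
// Classification rules, as enforced above: any kTensorized leaf makes the
// stage kTensorize; a reduce axis bound to a thread (e.g. threadIdx.x) makes
// it kCrossThreadReduction; otherwise kNormal. Thread reductions must stay
// innermost relative to data axes, and the two reduction styles cannot be
// mixed within one stage.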

// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  ComputeType ctype = DetectComputeType(this, stage);
  if (ctype == ComputeType::kCrossThreadReduction) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
  } else if (ctype == ComputeType::kTensorize) {
    return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
  } else {
    return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
  }
}

ComputeLoopNest ComputeLoopNest::make(
    const BaseComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) {
  CHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
      debug_keep_trivial_loop);
  ret.main_predicates = schedule::MakeBoundCheck(
      stage, dom_map, ret.main_vmap, false,
      std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // Try to find the location to insert the initialization,
    // fusing the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
      update_state[self->axis[i]] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    schedule::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // find the first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that are related to reduction and are unrelated to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if (flag == 2) skip_iter.insert(kv.first);
    }
    ret.init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
    ret.init_predicates = schedule::MakeBoundCheck(
        stage, dom_map, ret.init_vmap, true, skip_iter);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
    ret.num_common_loop = stage->leaf_iter_vars.size();
  }
  // copy elision here.
  return ret;
}
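
// The update_state bitmask drives the init/update fusion: bit 2 marks
// "reachable from a reduce axis", bit 1 marks "reachable from a data axis",
// and PassDownBitMaskOr ORs the flags from root iter vars down to the leaves.
// A leaf whose flags equal exactly 2 belongs only to the reduction and is
// skipped when building the initialization nest; the first leaf with bit 2
// set marks where the common (shared) loops end.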

namespace {
/*!
 * \brief Verify if ComputeOp is valid with respect to Reduce operations.
 *
 * The following two properties are verified:
 * (1) All Reduce operations must exist at the top level.
 * (2) For a list of operations, if one is Reduce, then the others
 *     must be Reduce as well; and their inputs should have the
 *     same attribute except value_index.
 */
class ComputeVerifier final : protected ir::IRVisitor {
 public:
  /// Special member functions
  //@{
  explicit ComputeVerifier(const ComputeOpNode* compute)
      : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
  virtual ~ComputeVerifier() = default;
  ComputeVerifier(const ComputeVerifier&) = delete;
  ComputeVerifier(ComputeVerifier&&) = delete;
  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
  //@}

  /// Interface to perform compute verification
  void Run() {
    for (const Expr& e : compute_->body) {
      // Check for consistency of top level reductions
      const ir::Reduce* reduce = e.as<ir::Reduce>();
      CHECK((reduce && reduce_) || (!reduce && !reduce_))
          << "All bodies of a ComputeOp must consistently be "
          << "Reduce operations, or consistently not be.";

      if (reduce && reduce_) {
        CHECK(ReduceEqual(reduce, reduce_))
            << "The Reduce inputs of ComputeOp should "
            << "have the same attribute except value_index";
      }

      level_ = 0;
      ir::IRVisitor::Visit(e);
    }
  }

 protected:
  /// Visitor implementation
  //@{
  void Visit(const NodeRef& n) final {
    ++level_;
    ir::IRVisitor::Visit(n);
    --level_;
  }

  void Visit_(const ir::Reduce* op) final {
    // Check for non top level reductions
    CHECK(0 == level_)
        << "Reductions are only allowed at the top level of compute. "
        << "Please create another tensor for further composition.";
  }
  //@}

 private:
  const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
  const ir::Reduce* reduce_{nullptr};      ///< Top level Reduce operation
  int level_{0};                           ///< Level of op being processed
};
}  // namespace

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
  ComputeVerifier v(op);
  v.Run();
}

Stmt TransformUpdate(const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     const ComputeLoopNest& n,
                     Stmt body,
                     Stmt update) {
  Array<Expr> conds;
  std::unordered_set<const Variable*> banned;
  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (attr->iter_type == kTensorized) {
        break;
      }
    }
    if (iv->iter_type == kCommReduce) {
      auto vit = dom_map.find(iv);
      CHECK(vit != dom_map.end());
      const Range& vrange = vit->second;
      conds.push_back(likely(iv->var > vrange->min));
      banned.insert(iv->var.get());
    }
  }
  for (const Expr& pred : n.main_predicates) {
    if (ir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize update transform failed, the condition "
                 << pred << " has a conflict with the reset condition";
    }
  }

  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
                          update, body);
}
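
// TransformUpdate guards a tensorized update against the first reduction
// iteration; the emitted statement has the shape
//   if (likely(k0 > min0) || likely(k1 > min1) || ...) update else body
// so body runs only when every outer reduce loop is at its first iteration,
// and update runs on all subsequent iterations.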
}  // namespace tvm