1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file message_passing.cc
22 * \brief The message passing domain.
23 */
24 #include "message_passing.h"
25
26 #include <tvm/arith/analyzer.h>
27 #include <tvm/tir/expr.h>
28
29 namespace tvm {
30 namespace te {
31
32 using namespace tir;
33
Update(std::unordered_map<IterVar,Range> * p_state,const IterVar & iv,Range r,arith::Analyzer * analyzer)34 void Update(std::unordered_map<IterVar, Range>* p_state, const IterVar& iv, Range r,
35 arith::Analyzer* analyzer) {
36 auto it = p_state->find(iv);
37 if (it == p_state->end()) {
38 (*p_state)[iv] = r;
39 analyzer->Bind(iv->var, r);
40 } else {
41 bool match =
42 is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0);
43 CHECK(match) << iv << " domain already inferred,"
44 << " cannot prove their extents are the same " << it->second->extent << " vs "
45 << r->extent;
46 }
47 }
48
49 /*!
50 * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
51 * a thread.
52 *
53 * \param stage The stage to operate on.
54 * \param p_state The propagation result of each IterVar.
55 */
PassUpThreadBinding(const Stage & stage,std::unordered_map<IterVar,bool> * p_state)56 void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
57 auto bound_to_thread = [&stage](const IterVar& iv) {
58 bool bound = false;
59 auto it = stage->iter_var_attrs.find(iv);
60 if (it != stage->iter_var_attrs.end()) {
61 bound = (*it).second->bind_thread.defined();
62 }
63 return bound;
64 };
65
66 auto& state = *p_state;
67 // Fill p_state with leaf itervars
68 for (const IterVar& iv : stage->leaf_iter_vars) {
69 state[iv] = bound_to_thread(iv);
70 }
71 // Traverse the graph bottom-up to propagate thread binding information
72 for (size_t i = stage->relations.size(); i != 0; --i) {
73 IterVarRelation rel = stage->relations[i - 1];
74 if (const SplitNode* s = rel.as<SplitNode>()) {
75 state[s->parent] = state[s->inner] || state[s->outer];
76 } else if (const FuseNode* s = rel.as<FuseNode>()) {
77 state[s->inner] = state[s->fused];
78 state[s->outer] = state[s->fused];
79 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
80 state[s->parent] = state[s->rebased];
81 } else if (rel.as<SingletonNode>()) {
82 } else {
83 LOG(FATAL) << "unknown relation type";
84 }
85 }
86 }
87
PassDownDomain(const Stage & stage,std::unordered_map<IterVar,Range> * p_state,arith::Analyzer * actx,bool allow_missing)88 void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
89 arith::Analyzer* actx, bool allow_missing) {
90 auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
91 if (actx->CanProve(indexmod(a, b) == 0)) {
92 return actx->Simplify(indexdiv(a, b));
93 }
94 return actx->Simplify(indexdiv(a + (b - 1), b));
95 };
96
97 auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
98 if (actx->CanProve(a < b)) {
99 return actx->Simplify(a);
100 }
101 return actx->Simplify(b);
102 };
103
104 std::unordered_map<IterVar, bool> dominating_thread;
105 PassUpThreadBinding(stage, &dominating_thread);
106
107 auto& state = *p_state;
108 // forwar iteration on relations
109 for (IterVarRelation rel : stage->relations) {
110 if (const SplitNode* r = rel.as<SplitNode>()) {
111 if (!state.count(r->parent)) {
112 CHECK(allow_missing);
113 continue;
114 }
115 CHECK(!state.count(r->inner));
116 const Range& range_parent = state.at(r->parent);
117 // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
118 // following conditions are met:
119 // 1. No leaf IterVar derived from iv binds to any thread. People may use split
120 // to force an IterVar extent to match the number of allocated threads to fuse stages
121 // that require different number of threads. We don't want to change these extents.
122 // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
123 // rather than by an early compiler phase, such as rfactor(). We don't want to tighten an
124 // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
125 // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one
126 // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent
127 // IterVar. We don't touch it.
128 auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
129 return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
130 ? factor_or_nparts
131 : minimum_or_later(range_parent->extent, factor_or_nparts);
132 };
133 if (r->factor.defined()) {
134 Update(p_state, r->inner,
135 Range::FromMinExtent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx);
136 Update(p_state, r->outer,
137 Range::FromMinExtent(0, ceil_div(range_parent->extent, r->factor)), actx);
138 } else {
139 Update(p_state, r->outer,
140 Range::FromMinExtent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx);
141 Update(p_state, r->inner,
142 Range::FromMinExtent(0, ceil_div(range_parent->extent, r->nparts)), actx);
143 }
144 } else if (const FuseNode* r = rel.as<FuseNode>()) {
145 if (!state.count(r->outer) || !state.count(r->inner)) {
146 CHECK(allow_missing);
147 continue;
148 }
149 const Range& range_outer = state.at(r->outer);
150 const Range& range_inner = state.at(r->inner);
151 state[r->fused] = Range::FromMinExtent(0, range_outer->extent * range_inner->extent);
152 } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
153 if (!state.count(r->parent)) {
154 CHECK(allow_missing);
155 continue;
156 }
157 Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
158 } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
159 Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
160 } else {
161 LOG(FATAL) << "unknown relation type";
162 }
163 }
164 // update the extents of binded threads.
165 for (auto kv : stage->iter_var_attrs) {
166 if (kv.second->bind_thread.defined()) {
167 CHECK(state.count(kv.first));
168 Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
169 }
170 }
171 }
172
PassUpIndex(const Stage & stage,const Map<IterVar,Range> & dom_map,std::unordered_map<IterVar,PrimExpr> * p_state,bool allow_missing)173 void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
174 std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
175 auto& state = *p_state;
176 for (size_t i = stage->relations.size(); i != 0; --i) {
177 IterVarRelation rel = stage->relations[i - 1];
178 if (const SplitNode* s = rel.as<SplitNode>()) {
179 if (!state.count(s->outer) || !state.count(s->inner)) {
180 CHECK(allow_missing);
181 continue;
182 }
183 PrimExpr outer = state.at(s->outer);
184 PrimExpr inner = state.at(s->inner);
185 PrimExpr factor = dom_map.at(s->inner)->extent;
186 PrimExpr parent_min = dom_map.at(s->parent)->min;
187 state[s->parent] = inner + outer * factor;
188 // add min if they exist
189 if (!is_zero(parent_min)) {
190 state[s->parent] = state[s->parent] + parent_min;
191 }
192 } else if (const FuseNode* s = rel.as<FuseNode>()) {
193 if (!state.count(s->fused)) {
194 CHECK(allow_missing);
195 continue;
196 }
197 PrimExpr value = state.at(s->fused);
198 PrimExpr factor = dom_map.at(s->inner)->extent;
199 PrimExpr outer_min = dom_map.at(s->outer)->min;
200 PrimExpr inner_min = dom_map.at(s->inner)->min;
201 state[s->outer] = indexdiv(value, factor);
202 state[s->inner] = indexmod(value, factor);
203 // add min if they exist
204 if (!is_zero(outer_min)) {
205 state[s->outer] = state[s->outer] + outer_min;
206 }
207 if (!is_zero(inner_min)) {
208 state[s->inner] = state[s->inner] + inner_min;
209 }
210 // s->fused, s->outer and s->inner may be of different dtype,
211 // so we cast the `state` back to its original dtype
212 state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]);
213 state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]);
214 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
215 if (!state.count(s->rebased)) {
216 CHECK(allow_missing);
217 continue;
218 }
219 PrimExpr value = state.at(s->rebased);
220 PrimExpr parent_min = dom_map.at(s->parent)->min;
221 // add min if they exist
222 if (!is_zero(parent_min)) {
223 state[s->parent] = value + parent_min;
224 } else {
225 state[s->parent] = value;
226 }
227 } else if (rel.as<SingletonNode>()) {
228 } else {
229 LOG(FATAL) << "unknown relation type";
230 }
231 }
232 }
233
PassDownIndex(const Stage & stage,const Map<IterVar,Range> & dom_map,std::unordered_map<IterVar,PrimExpr> * p_state,bool allow_missing)234 void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
235 std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
236 auto& state = *p_state;
237 for (IterVarRelation rel : stage->relations) {
238 if (const SplitNode* s = rel.as<SplitNode>()) {
239 if (!state.count(s->parent)) {
240 CHECK(allow_missing);
241 continue;
242 }
243 Range r = dom_map.at(s->inner);
244 CHECK(is_zero(r->min));
245 PrimExpr parent = state.at(s->parent);
246 PrimExpr factor = r->extent;
247 state[s->outer] = indexdiv(parent, factor);
248 state[s->inner] = indexmod(parent, factor);
249 } else if (const FuseNode* s = rel.as<FuseNode>()) {
250 if (!state.count(s->inner) && !state.count(s->outer)) {
251 CHECK(allow_missing);
252 continue;
253 }
254 PrimExpr factor = dom_map.at(s->inner)->extent;
255 PrimExpr outer_min = dom_map.at(s->outer)->min;
256 PrimExpr inner_min = dom_map.at(s->inner)->min;
257 PrimExpr inner = state.at(s->inner);
258 PrimExpr outer = state.at(s->outer);
259 CHECK(is_zero(outer_min));
260 CHECK(is_zero(inner_min));
261 state[s->fused] = outer * factor + inner;
262 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
263 if (!state.count(s->rebased)) {
264 CHECK(allow_missing);
265 continue;
266 }
267 PrimExpr value = state.at(s->parent);
268 PrimExpr parent_min = dom_map.at(s->parent)->min;
269 CHECK(is_zero(parent_min));
270 state[s->rebased] = value;
271 } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
272 state[s->iter] = make_zero(s->iter->var.dtype());
273 } else {
274 LOG(FATAL) << "unknown relation type";
275 }
276 }
277 }
278
279 // Domain message passing.
PassUpDomain(const SplitNode * s,const std::unordered_map<IterVar,Range> & dom_map,const IntSet & outer,const IntSet & inner,IntSet * parent)280 void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map,
281 const IntSet& outer, const IntSet& inner, IntSet* parent) {
282 if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) &&
283 outer.MatchRange(dom_map.at(s->outer)) && inner.MatchRange(dom_map.at(s->inner))) {
284 *parent = IntSet::FromRange(dom_map.at(s->parent));
285 return;
286 }
287 PrimExpr factor = dom_map.at(s->inner)->extent;
288 PrimExpr parent_min = dom_map.at(s->parent)->min;
289 CHECK(outer.defined());
290 CHECK(inner.defined());
291 CHECK(factor.defined());
292 *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min,
293 {{s->outer, outer}, {s->inner, inner}});
294 }
295
PassUpDomain(const FuseNode * s,const std::unordered_map<IterVar,Range> & dom_map,const IntSet & fused,IntSet * outer,IntSet * inner)296 void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
297 const IntSet& fused, IntSet* outer, IntSet* inner) {
298 CHECK(dom_map.count(s->outer));
299 CHECK(dom_map.count(s->inner));
300 CHECK(dom_map.count(s->fused));
301 arith::Analyzer ana;
302
303 if (fused.MatchRange(dom_map.at(s->fused))) {
304 *outer = IntSet::FromRange(dom_map.at(s->outer));
305 *inner = IntSet::FromRange(dom_map.at(s->inner));
306 return;
307 }
308 PrimExpr outer_min = dom_map.at(s->outer)->min;
309 PrimExpr inner_min = dom_map.at(s->inner)->min;
310
311 if (fused.IsSinglePoint()) {
312 PrimExpr value = fused.PointValue();
313 PrimExpr factor = dom_map.at(s->inner)->extent;
314 PrimExpr v_outer = indexdiv(value, factor);
315 PrimExpr v_inner = indexmod(value, factor);
316 if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
317 if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
318 *outer = IntSet::SinglePoint(v_outer);
319 *inner = IntSet::SinglePoint(v_inner);
320 } else {
321 PrimExpr fused_extent = (fused.max() - fused.min() + 1);
322 PrimExpr inner_extent = dom_map.at(s->inner)->extent;
323 *outer = IntSet::Interval(outer_min + indexdiv(fused.min(), inner_extent),
324 outer_min + indexdiv(fused.max(), inner_extent));
325 if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
326 is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
327 // fused never spans multiple rows, make a tight bounding box
328 // there may be other cases when bounding box could be tightened
329 *inner = IntSet::Interval(inner_min + indexmod(fused.min(), inner_extent),
330 inner_min + indexmod(fused.max(), inner_extent));
331 } else { // fused may span multiple rows, use full row widths
332 if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
333 !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) {
334 LOG(WARNING)
335 << "fused and original axes are not aligned, this may cause redundant computations";
336 }
337 *inner = IntSet::FromRange(dom_map.at(s->inner));
338 }
339 return;
340 }
341 }
342
PassUpDomain(const RebaseNode * s,const std::unordered_map<IterVar,Range> & dom_map,const IntSet & rebased,IntSet * parent)343 void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
344 const IntSet& rebased, IntSet* parent) {
345 CHECK(dom_map.count(s->parent));
346 if (rebased.MatchRange(dom_map.at(s->rebased))) {
347 *parent = IntSet::FromRange(dom_map.at(s->parent));
348 return;
349 }
350 PrimExpr parent_min = dom_map.at(s->parent)->min;
351 *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
352 }
353
PassUpDomain(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,std::unordered_map<IterVar,IntSet> * p_state)354 void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
355 std::unordered_map<IterVar, IntSet>* p_state) {
356 auto& state = *p_state;
357 for (size_t i = stage->relations.size(); i != 0; --i) {
358 IterVarRelation rel = stage->relations[i - 1];
359 if (const SplitNode* r = rel.as<SplitNode>()) {
360 IntSet parent;
361 PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent);
362 state[r->parent] = parent;
363 } else if (const FuseNode* r = rel.as<FuseNode>()) {
364 IntSet outer, inner;
365 PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner);
366 state[r->outer] = outer;
367 state[r->inner] = inner;
368 } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
369 IntSet parent;
370 PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
371 state[r->parent] = parent;
372 } else if (rel.as<SingletonNode>()) {
373 } else {
374 LOG(FATAL) << "unknown relation type";
375 }
376 }
377 }
378
379 // Pass up bit mask with or relation.
PassUpBitMaskOr(const Stage & stage,std::unordered_map<IterVar,int> * p_state,bool allow_missing)380 void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
381 bool allow_missing) {
382 auto& state = *p_state;
383 for (size_t i = stage->relations.size(); i != 0; --i) {
384 IterVarRelation rel = stage->relations[i - 1];
385 if (const SplitNode* s = rel.as<SplitNode>()) {
386 if (!state.count(s->inner) && !state.count(s->outer)) {
387 CHECK(allow_missing);
388 continue;
389 }
390 int res = 0;
391 if (state.count(s->parent)) res |= state[s->parent];
392 if (state.count(s->inner)) res |= state[s->inner];
393 if (state.count(s->outer)) res |= state[s->outer];
394 state[s->parent] = res;
395 } else if (const FuseNode* s = rel.as<FuseNode>()) {
396 if (!state.count(s->fused)) {
397 CHECK(allow_missing);
398 continue;
399 }
400 if (!state.count(s->outer)) {
401 state[s->outer] = state[s->fused];
402 } else {
403 state[s->outer] |= state[s->fused];
404 }
405 if (!state.count(s->inner)) {
406 state[s->inner] = state[s->fused];
407 } else {
408 state[s->inner] |= state[s->fused];
409 }
410 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
411 if (!state.count(s->rebased)) {
412 CHECK(allow_missing);
413 continue;
414 }
415 if (!state.count(s->parent)) {
416 state[s->parent] = state[s->rebased];
417 } else {
418 state[s->parent] |= state[s->rebased];
419 }
420 } else if (rel.as<SingletonNode>()) {
421 } else {
422 LOG(FATAL) << "unknown relation type";
423 }
424 }
425 }
426
PassDownBitMaskOr(const Stage & stage,std::unordered_map<IterVar,int> * p_state,bool allow_missing)427 void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
428 bool allow_missing) {
429 auto& state = *p_state;
430 for (IterVarRelation rel : stage->relations) {
431 if (const SplitNode* s = rel.as<SplitNode>()) {
432 if (!state.count(s->parent)) {
433 CHECK(allow_missing);
434 continue;
435 }
436 if (!state.count(s->outer)) {
437 state[s->outer] = state.at(s->parent);
438 } else {
439 state[s->outer] |= state.at(s->parent);
440 }
441 if (!state.count(s->inner)) {
442 state[s->inner] = state.at(s->parent);
443 } else {
444 state[s->inner] |= state.at(s->parent);
445 }
446 } else if (const FuseNode* s = rel.as<FuseNode>()) {
447 if (!state.count(s->outer) && !state.count(s->inner)) {
448 CHECK(allow_missing);
449 continue;
450 }
451 int res = 0;
452 if (state.count(s->outer)) res |= state.at(s->outer);
453 if (state.count(s->inner)) res |= state.at(s->inner);
454 if (state.count(s->fused)) res |= state.at(s->fused);
455 state[s->fused] = res;
456 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
457 if (!state.count(s->parent)) {
458 CHECK(allow_missing);
459 continue;
460 }
461 if (!state.count(s->rebased)) {
462 state[s->rebased] = state.at(s->parent);
463 } else {
464 state[s->rebased] |= state.at(s->parent);
465 }
466 } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
467 state[s->iter] = 0;
468 } else {
469 LOG(FATAL) << "unknown relation type";
470 }
471 }
472 }
473
474 /*!
475 * \brief message passing to find if boundary checking on IterVar is needed.
476 * \param s The stage to be used.
477 * \param p_state The message passing state
478 * IterVar->flag
479 */
PassUpBoundCheck(const Stage & s,const Map<IterVar,Range> & dom_map,std::unordered_map<IterVar,bool> * p_state,arith::Analyzer * analyzer)480 void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
481 std::unordered_map<IterVar, bool>* p_state, arith::Analyzer* analyzer) {
482 auto& state = *p_state;
483 for (size_t i = s->relations.size(); i != 0; --i) {
484 IterVarRelation rel = s->relations[i - 1];
485 if (const SplitNode* s = rel.as<SplitNode>()) {
486 bool outer = state.at(s->outer);
487 bool inner = state.at(s->inner);
488
489 if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
490 PrimExpr factor = dom_map.at(s->inner)->extent;
491 PrimExpr step = dom_map.at(s->outer)->extent;
492 if (outer || inner) {
493 state[s->parent] = true;
494 } else {
495 if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
496 state[s->parent] = false;
497 } else {
498 state[s->parent] = true;
499 }
500 }
501 } else {
502 state[s->parent] = true;
503 }
504 } else if (const FuseNode* s = rel.as<FuseNode>()) {
505 bool fused = state.at(s->fused);
506 state[s->outer] = fused;
507 state[s->inner] = fused;
508 } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
509 state[s->parent] = state.at(s->rebased);
510 } else if (rel.as<SingletonNode>()) {
511 // nop
512 } else {
513 LOG(FATAL) << "unknown relation type";
514 }
515 }
516 }
517
IsRangeSame(const Range input_1,const Range input_2)518 bool IsRangeSame(const Range input_1, const Range input_2) {
519 arith::Analyzer analyzer;
520 if (input_1.same_as(input_2)) return true;
521
522 return (analyzer.CanProve(input_1->min == input_2->min) &&
523 analyzer.CanProve(input_1->extent == input_2->extent));
524 }
525
MakeBoundCheck(const Stage & stage,const Map<IterVar,Range> & dom_map,const std::unordered_map<IterVar,PrimExpr> & value_map,bool skip_ivar_domain,const std::unordered_set<IterVar> & skip_iter)526 std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
527 const std::unordered_map<IterVar, PrimExpr>& value_map,
528 bool skip_ivar_domain,
529 const std::unordered_set<IterVar>& skip_iter) {
530 arith::Analyzer analyzer;
531
532 std::unordered_map<IterVar, bool> bound_state;
533 for (IterVar iv : stage->leaf_iter_vars) {
534 bound_state[iv] = false;
535 }
536 PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
537
538 std::vector<PrimExpr> preds;
539 Map<Var, IntSet> iset_dmap;
540
541 // setup domain map for set analysis
542 for (const auto& kv : dom_map) {
543 iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second));
544 }
545
546 for (auto entry : dom_map) {
547 analyzer.Bind(entry.first->var, entry.second);
548 }
549
550 for (const IterVar& iv : stage->all_iter_vars) {
551 if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
552 if (bound_state.at(iv)) {
553 Range dom = dom_map.at(iv);
554 PrimExpr value = value_map.at(iv) - dom->min;
555 PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
556 if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
557 preds.emplace_back(value < dom->extent);
558 }
559 }
560 }
561 for (const IterVar& iv : stage->op->root_iter_vars()) {
562 if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
563 Range dom = dom_map.at(iv);
564 CHECK(iv->dom.defined());
565 if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
566 PrimExpr value = value_map.at(iv) - iv->dom->min;
567 IntSet s = analyzer.int_set(value, iset_dmap);
568 PrimExpr vmin = s.min();
569 PrimExpr vmax = s.max();
570 // The range of `value` resides in [vmin, vmax]
571 if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
572 preds.emplace_back(value >= 0);
573 }
574 if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
575 preds.emplace_back(value < iv->dom->extent);
576 }
577 }
578 }
579 return preds;
580 }
581 } // namespace te
582 } // namespace tvm
583