1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file message_passing.cc
22  * \brief The message passing domain.
23  */
24 #include "message_passing.h"
25 
26 #include <tvm/arith/analyzer.h>
27 #include <tvm/tir/expr.h>
28 
29 namespace tvm {
30 namespace te {
31 
32 using namespace tir;
33 
Update(std::unordered_map<IterVar,Range> * p_state,const IterVar & iv,Range r,arith::Analyzer * analyzer)34 void Update(std::unordered_map<IterVar, Range>* p_state, const IterVar& iv, Range r,
35             arith::Analyzer* analyzer) {
36   auto it = p_state->find(iv);
37   if (it == p_state->end()) {
38     (*p_state)[iv] = r;
39     analyzer->Bind(iv->var, r);
40   } else {
41     bool match =
42         is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0);
43     CHECK(match) << iv << " domain already inferred,"
44                  << " cannot prove their extents are the same " << it->second->extent << " vs "
45                  << r->extent;
46   }
47 }
48 
49 /*!
50  * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
51  * a thread.
52  *
53  * \param stage The stage to operate on.
54  * \param p_state The propagation result of each IterVar.
55  */
PassUpThreadBinding(const Stage & stage,std::unordered_map<IterVar,bool> * p_state)56 void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
57   auto bound_to_thread = [&stage](const IterVar& iv) {
58     bool bound = false;
59     auto it = stage->iter_var_attrs.find(iv);
60     if (it != stage->iter_var_attrs.end()) {
61       bound = (*it).second->bind_thread.defined();
62     }
63     return bound;
64   };
65 
66   auto& state = *p_state;
67   // Fill p_state with leaf itervars
68   for (const IterVar& iv : stage->leaf_iter_vars) {
69     state[iv] = bound_to_thread(iv);
70   }
71   // Traverse the graph bottom-up to propagate thread binding information
72   for (size_t i = stage->relations.size(); i != 0; --i) {
73     IterVarRelation rel = stage->relations[i - 1];
74     if (const SplitNode* s = rel.as<SplitNode>()) {
75       state[s->parent] = state[s->inner] || state[s->outer];
76     } else if (const FuseNode* s = rel.as<FuseNode>()) {
77       state[s->inner] = state[s->fused];
78       state[s->outer] = state[s->fused];
79     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
80       state[s->parent] = state[s->rebased];
81     } else if (rel.as<SingletonNode>()) {
82     } else {
83       LOG(FATAL) << "unknown relation type";
84     }
85   }
86 }
87 
PassDownDomain(const Stage & stage,std::unordered_map<IterVar,Range> * p_state,arith::Analyzer * actx,bool allow_missing)88 void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
89                     arith::Analyzer* actx, bool allow_missing) {
90   auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
91     if (actx->CanProve(indexmod(a, b) == 0)) {
92       return actx->Simplify(indexdiv(a, b));
93     }
94     return actx->Simplify(indexdiv(a + (b - 1), b));
95   };
96 
97   auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
98     if (actx->CanProve(a < b)) {
99       return actx->Simplify(a);
100     }
101     return actx->Simplify(b);
102   };
103 
104   std::unordered_map<IterVar, bool> dominating_thread;
105   PassUpThreadBinding(stage, &dominating_thread);
106 
107   auto& state = *p_state;
108   // forwar iteration on relations
109   for (IterVarRelation rel : stage->relations) {
110     if (const SplitNode* r = rel.as<SplitNode>()) {
111       if (!state.count(r->parent)) {
112         CHECK(allow_missing);
113         continue;
114       }
115       CHECK(!state.count(r->inner));
116       const Range& range_parent = state.at(r->parent);
117       // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
118       // following conditions are met:
119       // 1. No leaf IterVar derived from iv binds to any thread.  People may use split
120       // to force an IterVar extent to match the number of allocated threads to fuse stages
121       // that require different number of threads.  We don't want to change these extents.
122       // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
123       // rather than by an early compiler phase, such as rfactor().  We don't want to tighten an
124       // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
125       // 3. range_parent's extent is not 0.  At lest one Topi test has a case where a tensor has one
126       // zero-sized dimension.  Split creates iv with a positive extent to avoid zero-extent
127       // IterVar.  We don't touch it.
128       auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
129         return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
130                    ? factor_or_nparts
131                    : minimum_or_later(range_parent->extent, factor_or_nparts);
132       };
133       if (r->factor.defined()) {
134         Update(p_state, r->inner,
135                Range::FromMinExtent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx);
136         Update(p_state, r->outer,
137                Range::FromMinExtent(0, ceil_div(range_parent->extent, r->factor)), actx);
138       } else {
139         Update(p_state, r->outer,
140                Range::FromMinExtent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx);
141         Update(p_state, r->inner,
142                Range::FromMinExtent(0, ceil_div(range_parent->extent, r->nparts)), actx);
143       }
144     } else if (const FuseNode* r = rel.as<FuseNode>()) {
145       if (!state.count(r->outer) || !state.count(r->inner)) {
146         CHECK(allow_missing);
147         continue;
148       }
149       const Range& range_outer = state.at(r->outer);
150       const Range& range_inner = state.at(r->inner);
151       state[r->fused] = Range::FromMinExtent(0, range_outer->extent * range_inner->extent);
152     } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
153       if (!state.count(r->parent)) {
154         CHECK(allow_missing);
155         continue;
156       }
157       Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
158     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
159       Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
160     } else {
161       LOG(FATAL) << "unknown relation type";
162     }
163   }
164   // update the extents of binded threads.
165   for (auto kv : stage->iter_var_attrs) {
166     if (kv.second->bind_thread.defined()) {
167       CHECK(state.count(kv.first));
168       Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
169     }
170   }
171 }
172 
PassUpIndex(const Stage & stage,const Map<IterVar,Range> & dom_map,std::unordered_map<IterVar,PrimExpr> * p_state,bool allow_missing)173 void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
174                  std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
175   auto& state = *p_state;
176   for (size_t i = stage->relations.size(); i != 0; --i) {
177     IterVarRelation rel = stage->relations[i - 1];
178     if (const SplitNode* s = rel.as<SplitNode>()) {
179       if (!state.count(s->outer) || !state.count(s->inner)) {
180         CHECK(allow_missing);
181         continue;
182       }
183       PrimExpr outer = state.at(s->outer);
184       PrimExpr inner = state.at(s->inner);
185       PrimExpr factor = dom_map.at(s->inner)->extent;
186       PrimExpr parent_min = dom_map.at(s->parent)->min;
187       state[s->parent] = inner + outer * factor;
188       // add min if they exist
189       if (!is_zero(parent_min)) {
190         state[s->parent] = state[s->parent] + parent_min;
191       }
192     } else if (const FuseNode* s = rel.as<FuseNode>()) {
193       if (!state.count(s->fused)) {
194         CHECK(allow_missing);
195         continue;
196       }
197       PrimExpr value = state.at(s->fused);
198       PrimExpr factor = dom_map.at(s->inner)->extent;
199       PrimExpr outer_min = dom_map.at(s->outer)->min;
200       PrimExpr inner_min = dom_map.at(s->inner)->min;
201       state[s->outer] = indexdiv(value, factor);
202       state[s->inner] = indexmod(value, factor);
203       // add min if they exist
204       if (!is_zero(outer_min)) {
205         state[s->outer] = state[s->outer] + outer_min;
206       }
207       if (!is_zero(inner_min)) {
208         state[s->inner] = state[s->inner] + inner_min;
209       }
210       // s->fused, s->outer and s->inner may be of different dtype,
211       // so we cast the `state` back to its original dtype
212       state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]);
213       state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]);
214     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
215       if (!state.count(s->rebased)) {
216         CHECK(allow_missing);
217         continue;
218       }
219       PrimExpr value = state.at(s->rebased);
220       PrimExpr parent_min = dom_map.at(s->parent)->min;
221       // add min if they exist
222       if (!is_zero(parent_min)) {
223         state[s->parent] = value + parent_min;
224       } else {
225         state[s->parent] = value;
226       }
227     } else if (rel.as<SingletonNode>()) {
228     } else {
229       LOG(FATAL) << "unknown relation type";
230     }
231   }
232 }
233 
PassDownIndex(const Stage & stage,const Map<IterVar,Range> & dom_map,std::unordered_map<IterVar,PrimExpr> * p_state,bool allow_missing)234 void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
235                    std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing) {
236   auto& state = *p_state;
237   for (IterVarRelation rel : stage->relations) {
238     if (const SplitNode* s = rel.as<SplitNode>()) {
239       if (!state.count(s->parent)) {
240         CHECK(allow_missing);
241         continue;
242       }
243       Range r = dom_map.at(s->inner);
244       CHECK(is_zero(r->min));
245       PrimExpr parent = state.at(s->parent);
246       PrimExpr factor = r->extent;
247       state[s->outer] = indexdiv(parent, factor);
248       state[s->inner] = indexmod(parent, factor);
249     } else if (const FuseNode* s = rel.as<FuseNode>()) {
250       if (!state.count(s->inner) && !state.count(s->outer)) {
251         CHECK(allow_missing);
252         continue;
253       }
254       PrimExpr factor = dom_map.at(s->inner)->extent;
255       PrimExpr outer_min = dom_map.at(s->outer)->min;
256       PrimExpr inner_min = dom_map.at(s->inner)->min;
257       PrimExpr inner = state.at(s->inner);
258       PrimExpr outer = state.at(s->outer);
259       CHECK(is_zero(outer_min));
260       CHECK(is_zero(inner_min));
261       state[s->fused] = outer * factor + inner;
262     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
263       if (!state.count(s->rebased)) {
264         CHECK(allow_missing);
265         continue;
266       }
267       PrimExpr value = state.at(s->parent);
268       PrimExpr parent_min = dom_map.at(s->parent)->min;
269       CHECK(is_zero(parent_min));
270       state[s->rebased] = value;
271     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
272       state[s->iter] = make_zero(s->iter->var.dtype());
273     } else {
274       LOG(FATAL) << "unknown relation type";
275     }
276   }
277 }
278 
279 // Domain message passing.
PassUpDomain(const SplitNode * s,const std::unordered_map<IterVar,Range> & dom_map,const IntSet & outer,const IntSet & inner,IntSet * parent)280 void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map,
281                   const IntSet& outer, const IntSet& inner, IntSet* parent) {
282   if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) &&
283       outer.MatchRange(dom_map.at(s->outer)) && inner.MatchRange(dom_map.at(s->inner))) {
284     *parent = IntSet::FromRange(dom_map.at(s->parent));
285     return;
286   }
287   PrimExpr factor = dom_map.at(s->inner)->extent;
288   PrimExpr parent_min = dom_map.at(s->parent)->min;
289   CHECK(outer.defined());
290   CHECK(inner.defined());
291   CHECK(factor.defined());
292   *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min,
293                            {{s->outer, outer}, {s->inner, inner}});
294 }
295 
PassUpDomain(const FuseNode * s,const std::unordered_map<IterVar,Range> & dom_map,const IntSet & fused,IntSet * outer,IntSet * inner)296 void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
297                   const IntSet& fused, IntSet* outer, IntSet* inner) {
298   CHECK(dom_map.count(s->outer));
299   CHECK(dom_map.count(s->inner));
300   CHECK(dom_map.count(s->fused));
301   arith::Analyzer ana;
302 
303   if (fused.MatchRange(dom_map.at(s->fused))) {
304     *outer = IntSet::FromRange(dom_map.at(s->outer));
305     *inner = IntSet::FromRange(dom_map.at(s->inner));
306     return;
307   }
308   PrimExpr outer_min = dom_map.at(s->outer)->min;
309   PrimExpr inner_min = dom_map.at(s->inner)->min;
310 
311   if (fused.IsSinglePoint()) {
312     PrimExpr value = fused.PointValue();
313     PrimExpr factor = dom_map.at(s->inner)->extent;
314     PrimExpr v_outer = indexdiv(value, factor);
315     PrimExpr v_inner = indexmod(value, factor);
316     if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
317     if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
318     *outer = IntSet::SinglePoint(v_outer);
319     *inner = IntSet::SinglePoint(v_inner);
320   } else {
321     PrimExpr fused_extent = (fused.max() - fused.min() + 1);
322     PrimExpr inner_extent = dom_map.at(s->inner)->extent;
323     *outer = IntSet::Interval(outer_min + indexdiv(fused.min(), inner_extent),
324                               outer_min + indexdiv(fused.max(), inner_extent));
325     if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
326         is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
327       // fused never spans multiple rows, make a tight bounding box
328       // there may be other cases when bounding box could be tightened
329       *inner = IntSet::Interval(inner_min + indexmod(fused.min(), inner_extent),
330                                 inner_min + indexmod(fused.max(), inner_extent));
331     } else {  // fused may span multiple rows, use full row widths
332       if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
333           !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) {
334         LOG(WARNING)
335             << "fused and original axes are not aligned, this may cause redundant computations";
336       }
337       *inner = IntSet::FromRange(dom_map.at(s->inner));
338     }
339     return;
340   }
341 }
342 
PassUpDomain(const RebaseNode * s,const std::unordered_map<IterVar,Range> & dom_map,const IntSet & rebased,IntSet * parent)343 void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
344                   const IntSet& rebased, IntSet* parent) {
345   CHECK(dom_map.count(s->parent));
346   if (rebased.MatchRange(dom_map.at(s->rebased))) {
347     *parent = IntSet::FromRange(dom_map.at(s->parent));
348     return;
349   }
350   PrimExpr parent_min = dom_map.at(s->parent)->min;
351   *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
352 }
353 
PassUpDomain(const Stage & stage,const std::unordered_map<IterVar,Range> & dom_map,std::unordered_map<IterVar,IntSet> * p_state)354 void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
355                   std::unordered_map<IterVar, IntSet>* p_state) {
356   auto& state = *p_state;
357   for (size_t i = stage->relations.size(); i != 0; --i) {
358     IterVarRelation rel = stage->relations[i - 1];
359     if (const SplitNode* r = rel.as<SplitNode>()) {
360       IntSet parent;
361       PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent);
362       state[r->parent] = parent;
363     } else if (const FuseNode* r = rel.as<FuseNode>()) {
364       IntSet outer, inner;
365       PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner);
366       state[r->outer] = outer;
367       state[r->inner] = inner;
368     } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
369       IntSet parent;
370       PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
371       state[r->parent] = parent;
372     } else if (rel.as<SingletonNode>()) {
373     } else {
374       LOG(FATAL) << "unknown relation type";
375     }
376   }
377 }
378 
379 // Pass up bit mask with or relation.
PassUpBitMaskOr(const Stage & stage,std::unordered_map<IterVar,int> * p_state,bool allow_missing)380 void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
381                      bool allow_missing) {
382   auto& state = *p_state;
383   for (size_t i = stage->relations.size(); i != 0; --i) {
384     IterVarRelation rel = stage->relations[i - 1];
385     if (const SplitNode* s = rel.as<SplitNode>()) {
386       if (!state.count(s->inner) && !state.count(s->outer)) {
387         CHECK(allow_missing);
388         continue;
389       }
390       int res = 0;
391       if (state.count(s->parent)) res |= state[s->parent];
392       if (state.count(s->inner)) res |= state[s->inner];
393       if (state.count(s->outer)) res |= state[s->outer];
394       state[s->parent] = res;
395     } else if (const FuseNode* s = rel.as<FuseNode>()) {
396       if (!state.count(s->fused)) {
397         CHECK(allow_missing);
398         continue;
399       }
400       if (!state.count(s->outer)) {
401         state[s->outer] = state[s->fused];
402       } else {
403         state[s->outer] |= state[s->fused];
404       }
405       if (!state.count(s->inner)) {
406         state[s->inner] = state[s->fused];
407       } else {
408         state[s->inner] |= state[s->fused];
409       }
410     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
411       if (!state.count(s->rebased)) {
412         CHECK(allow_missing);
413         continue;
414       }
415       if (!state.count(s->parent)) {
416         state[s->parent] = state[s->rebased];
417       } else {
418         state[s->parent] |= state[s->rebased];
419       }
420     } else if (rel.as<SingletonNode>()) {
421     } else {
422       LOG(FATAL) << "unknown relation type";
423     }
424   }
425 }
426 
PassDownBitMaskOr(const Stage & stage,std::unordered_map<IterVar,int> * p_state,bool allow_missing)427 void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
428                        bool allow_missing) {
429   auto& state = *p_state;
430   for (IterVarRelation rel : stage->relations) {
431     if (const SplitNode* s = rel.as<SplitNode>()) {
432       if (!state.count(s->parent)) {
433         CHECK(allow_missing);
434         continue;
435       }
436       if (!state.count(s->outer)) {
437         state[s->outer] = state.at(s->parent);
438       } else {
439         state[s->outer] |= state.at(s->parent);
440       }
441       if (!state.count(s->inner)) {
442         state[s->inner] = state.at(s->parent);
443       } else {
444         state[s->inner] |= state.at(s->parent);
445       }
446     } else if (const FuseNode* s = rel.as<FuseNode>()) {
447       if (!state.count(s->outer) && !state.count(s->inner)) {
448         CHECK(allow_missing);
449         continue;
450       }
451       int res = 0;
452       if (state.count(s->outer)) res |= state.at(s->outer);
453       if (state.count(s->inner)) res |= state.at(s->inner);
454       if (state.count(s->fused)) res |= state.at(s->fused);
455       state[s->fused] = res;
456     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
457       if (!state.count(s->parent)) {
458         CHECK(allow_missing);
459         continue;
460       }
461       if (!state.count(s->rebased)) {
462         state[s->rebased] = state.at(s->parent);
463       } else {
464         state[s->rebased] |= state.at(s->parent);
465       }
466     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
467       state[s->iter] = 0;
468     } else {
469       LOG(FATAL) << "unknown relation type";
470     }
471   }
472 }
473 
474 /*!
475  * \brief message passing to find if boundary checking on IterVar is needed.
476  * \param s The stage to be used.
477  * \param p_state The message passing state
478  *     IterVar->flag
479  */
PassUpBoundCheck(const Stage & s,const Map<IterVar,Range> & dom_map,std::unordered_map<IterVar,bool> * p_state,arith::Analyzer * analyzer)480 void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
481                       std::unordered_map<IterVar, bool>* p_state, arith::Analyzer* analyzer) {
482   auto& state = *p_state;
483   for (size_t i = s->relations.size(); i != 0; --i) {
484     IterVarRelation rel = s->relations[i - 1];
485     if (const SplitNode* s = rel.as<SplitNode>()) {
486       bool outer = state.at(s->outer);
487       bool inner = state.at(s->inner);
488 
489       if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
490         PrimExpr factor = dom_map.at(s->inner)->extent;
491         PrimExpr step = dom_map.at(s->outer)->extent;
492         if (outer || inner) {
493           state[s->parent] = true;
494         } else {
495           if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
496             state[s->parent] = false;
497           } else {
498             state[s->parent] = true;
499           }
500         }
501       } else {
502         state[s->parent] = true;
503       }
504     } else if (const FuseNode* s = rel.as<FuseNode>()) {
505       bool fused = state.at(s->fused);
506       state[s->outer] = fused;
507       state[s->inner] = fused;
508     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
509       state[s->parent] = state.at(s->rebased);
510     } else if (rel.as<SingletonNode>()) {
511       // nop
512     } else {
513       LOG(FATAL) << "unknown relation type";
514     }
515   }
516 }
517 
IsRangeSame(const Range input_1,const Range input_2)518 bool IsRangeSame(const Range input_1, const Range input_2) {
519   arith::Analyzer analyzer;
520   if (input_1.same_as(input_2)) return true;
521 
522   return (analyzer.CanProve(input_1->min == input_2->min) &&
523           analyzer.CanProve(input_1->extent == input_2->extent));
524 }
525 
MakeBoundCheck(const Stage & stage,const Map<IterVar,Range> & dom_map,const std::unordered_map<IterVar,PrimExpr> & value_map,bool skip_ivar_domain,const std::unordered_set<IterVar> & skip_iter)526 std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
527                                      const std::unordered_map<IterVar, PrimExpr>& value_map,
528                                      bool skip_ivar_domain,
529                                      const std::unordered_set<IterVar>& skip_iter) {
530   arith::Analyzer analyzer;
531 
532   std::unordered_map<IterVar, bool> bound_state;
533   for (IterVar iv : stage->leaf_iter_vars) {
534     bound_state[iv] = false;
535   }
536   PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
537 
538   std::vector<PrimExpr> preds;
539   Map<Var, IntSet> iset_dmap;
540 
541   // setup domain map for set analysis
542   for (const auto& kv : dom_map) {
543     iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second));
544   }
545 
546   for (auto entry : dom_map) {
547     analyzer.Bind(entry.first->var, entry.second);
548   }
549 
550   for (const IterVar& iv : stage->all_iter_vars) {
551     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
552     if (bound_state.at(iv)) {
553       Range dom = dom_map.at(iv);
554       PrimExpr value = value_map.at(iv) - dom->min;
555       PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
556       if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
557         preds.emplace_back(value < dom->extent);
558       }
559     }
560   }
561   for (const IterVar& iv : stage->op->root_iter_vars()) {
562     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
563     Range dom = dom_map.at(iv);
564     CHECK(iv->dom.defined());
565     if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
566       PrimExpr value = value_map.at(iv) - iv->dom->min;
567       IntSet s = analyzer.int_set(value, iset_dmap);
568       PrimExpr vmin = s.min();
569       PrimExpr vmax = s.max();
570       // The range of `value` resides in [vmin, vmax]
571       if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
572         preds.emplace_back(value >= 0);
573       }
574       if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
575         preds.emplace_back(value < iv->dom->extent);
576       }
577     }
578   }
579   return preds;
580 }
581 }  // namespace te
582 }  // namespace tvm
583