1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/arith/solve_linear_inequality.cc
22 * \brief Solve linear inequalities.
23 */
24 #include <tvm/arith/analyzer.h>
25 #include <tvm/arith/int_solver.h>
26 #include <tvm/arith/pattern.h>
27 #include <tvm/runtime/data_type.h>
28 #include <tvm/runtime/registry.h>
29 #include <tvm/tir/analysis.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/op.h>
32 #include <tvm/tir/stmt_functor.h>
33
34 #include "int_operator.h"
35
36 namespace tvm {
37 namespace arith {
38
39 using namespace tvm::runtime;
40 using namespace tvm::tir;
41
42 #define PLUS_ONE(OP) \
43 void VisitExpr_(const OP* op) final { num_symbols_++; }
44
45 #define PLUS_ONE_BINARY(OP) \
46 void VisitExpr_(const OP* op) final { \
47 num_symbols_++; \
48 VisitExpr(op->a); \
49 VisitExpr(op->b); \
50 }
51
52 /*!
53 * \brief Calculate the expresion complexity based on number of symbols it contains.
54 */
55 class ExprComplexity : public ExprVisitor {
56 public:
Eval(const PrimExpr & expr)57 size_t Eval(const PrimExpr& expr) {
58 VisitExpr(expr);
59 return num_symbols_;
60 }
61
62 PLUS_ONE_BINARY(AddNode)
PLUS_ONE_BINARY(SubNode)63 PLUS_ONE_BINARY(SubNode)
64 PLUS_ONE_BINARY(MulNode)
65 PLUS_ONE_BINARY(DivNode)
66 PLUS_ONE_BINARY(ModNode)
67 PLUS_ONE_BINARY(FloorDivNode)
68 PLUS_ONE_BINARY(FloorModNode)
69 PLUS_ONE_BINARY(MinNode)
70 PLUS_ONE_BINARY(MaxNode)
71 PLUS_ONE_BINARY(EQNode)
72 PLUS_ONE_BINARY(NENode)
73 PLUS_ONE_BINARY(LTNode)
74 PLUS_ONE_BINARY(LENode)
75 PLUS_ONE_BINARY(GTNode)
76 PLUS_ONE_BINARY(GENode)
77 PLUS_ONE_BINARY(AndNode)
78 PLUS_ONE_BINARY(OrNode)
79 PLUS_ONE(VarNode)
80 PLUS_ONE(FloatImmNode)
81 PLUS_ONE(IntImmNode)
82 void VisitExpr_(const NotNode* op) final {
83 num_symbols_++;
84 VisitExpr(op->a);
85 }
86
87 private:
88 size_t num_symbols_{0};
89 };
90
91 struct ExprLess {
operator ()tvm::arith::ExprLess92 bool operator()(const PrimExpr& l, const PrimExpr& r) const {
93 return ExprComplexity().Eval(l) < ExprComplexity().Eval(r);
94 }
95 };
96
DebugPrint(const std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> & current_ineq_set,const std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> & next_ineq_set,const std::vector<PrimExpr> & rest,const std::vector<std::pair<int64_t,PrimExpr>> & coef_pos,const std::vector<std::pair<int64_t,PrimExpr>> & coef_neg)97 void DebugPrint(
98 const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
99 const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
100 const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
101 const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
102 std::cout << "Current ineq set:\n[";
103 for (auto& ineq : current_ineq_set) {
104 std::cout << ineq << ", ";
105 }
106 std::cout << "]\n";
107
108 std::cout << "Next ineq set:\n[";
109 for (auto& ineq : next_ineq_set) {
110 std::cout << ineq << ", ";
111 }
112 std::cout << "]\n";
113
114 std::cout << "coef_pos:\n[";
115 for (auto& coef : coef_pos) {
116 std::cout << "(" << coef.first << ", " << coef.second << "), ";
117 }
118 std::cout << "]\n";
119
120 std::cout << "coef_neg:\n[";
121 for (auto& coef : coef_neg) {
122 std::cout << "(" << coef.first << ", " << coef.second << "), ";
123 }
124 std::cout << "]\n";
125 }
126
127 /*!
128 * \brief normalize to the form `expr <= 0`
129 */
130 class NormalizeComparisons : public ExprMutator {
131 public:
VisitExpr_(const EQNode * op)132 PrimExpr VisitExpr_(const EQNode* op) override { return Make<EQ>(op->a, op->b); }
VisitExpr_(const NENode * op)133 PrimExpr VisitExpr_(const NENode* op) override { return Make<NE>(op->a, op->b); }
VisitExpr_(const LTNode * op)134 PrimExpr VisitExpr_(const LTNode* op) override { return Make<LT>(op->a, op->b); }
VisitExpr_(const LENode * op)135 PrimExpr VisitExpr_(const LENode* op) override { return Make<LE>(op->a, op->b); }
VisitExpr_(const GTNode * op)136 PrimExpr VisitExpr_(const GTNode* op) override { return Make<LT>(op->b, op->a); }
VisitExpr_(const GENode * op)137 PrimExpr VisitExpr_(const GENode* op) override { return Make<LE>(op->b, op->a); }
138
139 private:
140 template <class T>
Make(const PrimExpr & a,const PrimExpr & b)141 PrimExpr Make(const PrimExpr& a, const PrimExpr& b) {
142 // rewrite LT to LE for ints
143 if (std::is_same<T, LT>::value && (a.dtype().is_int() || a.dtype().is_uint())) {
144 return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype()));
145 }
146 return T(analyzer_.Simplify(a - b), make_zero(a.dtype()));
147 }
148 arith::Analyzer analyzer_;
149 };
150
AddInequality(std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * inequality_set,const PrimExpr & new_ineq,Analyzer * analyzer)151 void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set,
152 const PrimExpr& new_ineq, Analyzer* analyzer) {
153 if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) {
154 // redundant: follows from the vranges
155 // or has already been added
156 return;
157 }
158 if (const LENode* new_le = new_ineq.as<LENode>()) {
159 for (auto iter = inequality_set->begin(); iter != inequality_set->end();) {
160 const LENode* le = iter->as<LENode>();
161 if (le && analyzer->CanProve(new_le->a - le->a <= 0)) {
162 return;
163 } else if (le && analyzer->CanProve(le->a - new_le->a <= 0)) {
164 iter = inequality_set->erase(iter);
165 } else {
166 ++iter;
167 }
168 }
169 }
170
171 inequality_set->insert(new_ineq);
172 }
173
ClassifyByPolarity(const Var & var,const std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> & current_ineq_set,std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * next_ineq_set,std::vector<PrimExpr> * rest,std::vector<std::pair<int64_t,PrimExpr>> * coef_pos,std::vector<std::pair<int64_t,PrimExpr>> * coef_neg,Analyzer * analyzer)174 void ClassifyByPolarity(
175 const Var& var,
176 const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
177 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* next_ineq_set,
178 std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
179 std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
180 // Take formulas from current_ineq_set and classify them according to polarity wrt var
181 // and store to coef_pos and coef_neg respectively.
182 for (const PrimExpr& ineq : current_ineq_set) {
183 if (const LENode* le = ineq.as<LENode>()) {
184 Array<PrimExpr> coef = arith::DetectLinearEquation(le->a, {var});
185 if (!coef.empty() && is_const_int(coef[0])) {
186 int64_t coef0 = *as_const_int(coef[0]);
187 if (coef0 == 0) {
188 // zero polarity, straight to next_ineq_set
189 AddInequality(next_ineq_set, ineq, analyzer);
190 } else if (coef0 > 0) {
191 coef_pos->push_back({coef0, coef[1]});
192 } else if (coef0 < 0) {
193 coef_neg->push_back({coef0, coef[1]});
194 }
195 continue;
196 }
197 } else if (const EQNode* eq = ineq.as<EQNode>()) {
198 Array<PrimExpr> coef = arith::DetectLinearEquation(eq->a, {var});
199 if (!coef.empty() && is_const_int(coef[0])) {
200 int64_t coef0 = *as_const_int(coef[0]);
201 if (coef0 == 0) {
202 // zero polarity, straight to next_ineq_set
203 AddInequality(next_ineq_set, ineq, analyzer);
204 } else if (coef0 > 0) {
205 // Equalities may be considered as pairs of two inequalities
206 coef_pos->push_back({coef0, coef[1]});
207 coef_neg->push_back({-coef0, -coef[1]});
208 } else if (coef0 < 0) {
209 coef_pos->push_back({-coef0, -coef[1]});
210 coef_neg->push_back({coef0, coef[1]});
211 }
212 continue;
213 }
214 }
215
216 // if nothing worked, put it in rest
217 rest->push_back(ineq);
218 }
219 }
220
MoveEquality(std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * upper_bounds,std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * lower_bounds,std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * equalities)221 void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* upper_bounds,
222 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* lower_bounds,
223 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* equalities) {
224 // those exist in both upper & lower bounds will be moved to equalities
225 for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
226 auto lb = lower_bounds->find(*ub);
227 if (lb != lower_bounds->end()) {
228 equalities->insert(*lb);
229 lower_bounds->erase(lb);
230 ub = upper_bounds->erase(ub);
231 } else {
232 ++ub;
233 }
234 }
235 }
236
SolveLinearInequalities(const IntConstraints & system_to_solve)237 PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) {
238 arith::Analyzer analyzer;
239 analyzer.Bind(system_to_solve->ranges);
240
241 // The algorithm consists in doing the following things for each variable v
242 // - Take formulas from `current_ineq_set_to_solve` and
243 // classify them according to polarity wrt v.
244 // - Combine each formula of positive polarity (wrt v)
245 // with each formula of negative polarity.
246 // - Put the resulting combinations into `next_ineq_set_to_solve`
247 // along with unclassifiable formulas.
248 // - Replace `current_ineq_set_to_solve` with `next_ineq_set_to_solve`
249 // and move to the next variable.
250
251 // normalized inequality
252 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> current_ineq_set_to_solve;
253 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> next_ineq_set_to_solve;
254 // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0
255 std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
256 // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0
257 std::vector<std::pair<int64_t, PrimExpr>> coef_neg;
258
259 // formulas we don't know what to do with
260 std::vector<PrimExpr> rest;
261
262 // Simplify each inequality into the form `expr <= 0` and add to current formulas
263 for (const PrimExpr& ineq : system_to_solve->relations) {
264 AddInequality(¤t_ineq_set_to_solve,
265 NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)),
266 &analyzer);
267 }
268
269 Map<Var, IntGroupBounds> res_bounds;
270 for (const Var& v : system_to_solve->variables) {
271 CHECK(!res_bounds.count(v))
272 << "Variable " << v
273 << " appears more than one time in the `variables` which might be a bug";
274
275 next_ineq_set_to_solve.clear();
276 coef_pos.clear();
277 coef_neg.clear();
278
279 // Add bounds from vranges
280 if (system_to_solve->ranges.count(v)) {
281 const Range& range = system_to_solve->ranges[v];
282 PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite);
283 PrimExpr range_ubound =
284 analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite);
285 coef_neg.push_back({-1, range_lbound});
286 coef_pos.push_back({1, -range_ubound});
287 }
288
289 ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, &rest, &coef_pos,
290 &coef_neg, &analyzer);
291
292 // Combine each positive inequality with each negative one (by adding them together)
293 int64_t gcd_x, gcd_y;
294 for (const auto& pos : coef_pos) {
295 for (const auto& neg : coef_neg) {
296 auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y);
297 PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd);
298 PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd);
299 // eliminate the current variable
300 PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second;
301 PrimExpr new_ineq = LE(new_lhs, make_zero(pos.second.dtype()));
302 // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify
303 // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0
304 // with steps = 2 it's (y*2) - 10 <= 0
305 new_ineq =
306 NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite));
307 AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer);
308 }
309 }
310
311 // Now we have to generate resulting (in)equalities for the variable v
312
313 // Find the common denominator in a sense
314 // We will generate formulas of the form coef_lcm*v <= bound
315 int64_t coef_lcm = 1;
316 for (const auto& pos : coef_pos) {
317 coef_lcm = LeastCommonMultiple(coef_lcm, pos.first);
318 }
319 for (const auto& neg : coef_neg) {
320 coef_lcm = LeastCommonMultiple(coef_lcm, -neg.first);
321 }
322
323 // The resulting lower and upper bounds
324 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
325 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
326 upper_bounds.reserve(coef_pos.size());
327 lower_bounds.reserve(coef_neg.size());
328
329 for (const auto& pos : coef_pos) {
330 PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second;
331 bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
332 // Don't add if any of the existing bounds is better
333 if (std::any_of(upper_bounds.begin(), upper_bounds.end(),
334 [&bound, &analyzer](const PrimExpr& o) {
335 return analyzer.CanProve(o - bound <= 0);
336 })) {
337 continue;
338 }
339 // Erase all worse bounds
340 for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) {
341 if (analyzer.CanProve(*iter - bound >= 0)) {
342 iter = upper_bounds.erase(iter);
343 } else {
344 ++iter;
345 }
346 }
347 // Add the upper bound
348 upper_bounds.insert(bound);
349 }
350 for (const auto& neg : coef_neg) {
351 PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
352 bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
353 // Don't add if any of the existing bounds is better
354 if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
355 [&bound, &analyzer](const PrimExpr& o) {
356 return analyzer.CanProve(o - bound >= 0);
357 })) {
358 continue;
359 }
360 // Erase all worse bounds
361 for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) {
362 if (analyzer.CanProve(*iter - bound <= 0)) {
363 iter = lower_bounds.erase(iter);
364 } else {
365 ++iter;
366 }
367 }
368 // Add the lower bound
369 lower_bounds.insert(bound);
370 }
371
372 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
373 equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
374 MoveEquality(&upper_bounds, &lower_bounds, &equal);
375 std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
376 std::sort(equal_list.begin(), equal_list.end(), ExprLess());
377
378 // Write it to the result.
379 IntGroupBounds bnds(make_const(v.dtype(), coef_lcm),
380 Array<PrimExpr>(lower_bounds.begin(), lower_bounds.end()),
381 Array<PrimExpr>(equal_list.begin(), equal_list.end()),
382 Array<PrimExpr>(upper_bounds.begin(), upper_bounds.end()));
383 res_bounds.Set(v, bnds);
384
385 std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve);
386 }
387
388 // Everything that is left goes to res.relations
389 Array<PrimExpr> other_conditions;
390 for (const PrimExpr& e : current_ineq_set_to_solve) {
391 PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite);
392 if (is_const_int(e_simp, 0)) {
393 // contradiction detected
394 other_conditions = {const_false()};
395 break;
396 } else if (is_const_int(e_simp, 1)) {
397 continue;
398 } else {
399 other_conditions.push_back(e_simp);
400 }
401 }
402
403 for (const PrimExpr& e : rest) {
404 other_conditions.push_back(e);
405 }
406
407 return {res_bounds, other_conditions};
408 }
409
// MSVC optimizations are turned off around this function: per the note inside it, the
// expression `bnd->equal[0]` otherwise triggers an internal compiler error.
#ifdef _MSC_VER
#pragma optimize("g", off)
#endif
/*!
 * \brief Solve \p inequalities and express each variable's solution as a Range.
 *
 * Variables are processed in reverse order of `inequalities->variables`.
 *  - If a variable's bounds have a unit coefficient and at least one equality, it gets the
 *    single-point range [equal[0], equal[0] + 1).
 *  - Otherwise, its bounds (merged with its original range, if any) are turned into a range
 *    by IntGroupBounds::FindBestRange.
 * Each range found is bound for the analysis of the variables processed afterwards.
 *
 * \param inequalities Variables, their ranges, and the relations to solve.
 * \return IntConstraints over the original variables, with the computed ranges and the
 *         conditions not already provable from those ranges. If a computed range's extent is
 *         provably <= 0, an IntConstraints whose only relation is `false` is returned.
 */
IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
  // Resulting ranges will contain ranges for the new variables and for the variables that are
  // not in the inequalities->variables but are in inequalities->ranges
  // It will be useful when solving Jacobian axes (jac_xxx)
  Map<Var, Range> res_ranges;
  // we get a set of equality, lower, upper bound of each variable.
  auto solved_system = SolveLinearInequalities(inequalities);

  Map<Var, IntGroupBounds> solved_bounds = solved_system.first;
  Array<PrimExpr> solved_other_relations = solved_system.second;

  Array<PrimExpr> res_relations;

  // this keeps being updated during determining the range of each variable.
  Map<Var, Range> vranges;
  for (std::pair<Var, Range> vr : inequalities->ranges) {
    vranges.Set(vr.first, vr.second);
  }

  // We process variables in the reverse direction to start with the most independent one.
  // This order is needed to compute new ranges.
  for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) {
    // A fresh analyzer per variable, bound to all ranges known so far.
    arith::Analyzer analyzer;
    analyzer.Bind(vranges);

    const Var& var = *it;
    CHECK(solved_bounds.count(var));
    auto bnd = solved_bounds[var];
    if (is_one(bnd->coef) && !bnd->equal.empty()) {
      // There is an equation of the form `v == expr`, so this variable can be completely removed.
      // Note that we use the 0-th expression because they are ordered by complexity,
      // so it must be the simplest one.
      // The MSVC compiler optimization must be disabled for the expression `bnd->equal[0]` which
      // triggers an internal compiler error.
      Range best_range(bnd->equal[0],
                       analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite));
      res_ranges.Set(var, best_range);
      vranges.Set(var, best_range);
    } else {
      // Merge the variable's original range (if any) into its bounds.
      if (vranges.count(var) > 0) {
        bnd = bnd + vranges[var];
      }

      auto best_range = bnd.FindBestRange(vranges);

      if (best_range.defined()) {
        if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) {
          // range.extent <= 0 implies the input inequality system is unsolvable
          return IntConstraints(/*variables=*/{}, /*ranges=*/{},
                                /*relations=*/{tir::make_zero(DataType::Bool())});
        }
        res_ranges.Set(var, best_range);
        vranges.Set(var, best_range);
      }
    }
  }

  // Add the original conditions to the resulting conditions
  arith::Analyzer analyzer;
  analyzer.Bind(vranges);
  for (const PrimExpr& old_cond :
       AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
    if (!analyzer.CanProve(old_cond)) {
      // those not represented in vranges (res_ranges)
      res_relations.push_back(old_cond);
    }
  }

  IntConstraints system(inequalities->variables, res_ranges, res_relations);

  return system;
}
#ifdef _MSC_VER
#pragma optimize("g", on)
#endif
488
SolveInequalitiesDeskewRange(const IntConstraints & inequalities)489 IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) {
490 // Resulting ranges will contain ranges for the new variables and for the variables that are
491 // not in the inequalities->variables but are in inequalities->ranges (jac_xxx)
492 Map<Var, Range> res_ranges;
493 // we get a set of equality, lower, upper bound of each variable.
494 auto solved_system = SolveLinearInequalities(inequalities);
495 Map<Var, IntGroupBounds> solved_bounds = solved_system.first;
496 Array<PrimExpr> solved_other_relations = solved_system.second;
497
498 arith::Analyzer analyzer;
499
500 Map<Var, PrimExpr> res_src_to_dst;
501 Map<Var, PrimExpr> res_dst_to_src;
502 Array<Var> res_variables;
503 Array<PrimExpr> res_relations;
504
505 // this keeps being updated during determining the range of each variable.
506 Map<Var, Range> vranges;
507 for (std::pair<Var, Range> vr : inequalities->ranges) {
508 vranges.Set(vr.first, vr.second);
509 }
510 analyzer.Bind(vranges);
511
512 // We process variables in the reverse direction to start with the most independent one.
513 // This order is needed to compute new ranges.
514 for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) {
515 const Var& var = *it;
516 auto bnd = solved_bounds[var];
517 // Note that we replace old vars with new ones
518 bnd = bnd.Substitute(res_src_to_dst);
519
520 if (is_one(bnd->coef) && !bnd->equal.empty()) {
521 // There is an equation of the form `v == expr`,
522 // so this variable can be completely removed.
523 // Note that we use the 0-th expression because they are ordered by complexity,
524 // so it must be the simplest one.
525 res_src_to_dst.Set(var, bnd->equal[0]);
526 } else {
527 if (vranges.count(var) > 0) {
528 bnd = bnd + vranges[var];
529 }
530
531 auto best_range = bnd.FindBestRange(vranges);
532
533 Var new_var = var.copy_with_suffix(".shifted");
534 if (!best_range.defined()) {
535 res_src_to_dst.Set(var, var);
536 res_dst_to_src.Set(var, var);
537 res_variables.push_back(var);
538 } else if (is_const_int(best_range->extent, 1)) {
539 // Don't create an itervar, just replace it everywhere with its min
540 res_src_to_dst.Set(var, best_range->min);
541 } else if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) {
542 // range.extent <= 0 implies the input inequality system is unsolvable
543 return IntConstraintsTransform(inequalities,
544 IntConstraints(
545 /*variables=*/{},
546 /*ranges=*/{},
547 /*relations=*/{tir::make_zero(DataType::Bool())}),
548 {}, {});
549 } else {
550 // created new_var starts from 0
551 res_src_to_dst.Set(var, new_var + best_range->min);
552 // Note that we are substituting old with new, so best_range contains new var,
553 // that is we have to substitute new with old in best_range here
554 res_dst_to_src.Set(new_var,
555 analyzer.Simplify(var - Substitute(best_range->min, res_dst_to_src)));
556
557 // Add the new var to the resulting axis
558 auto range = Range(make_zero(new_var.dtype()), best_range->extent);
559 res_variables.push_back(new_var);
560 res_ranges.Set(new_var, range);
561
562 vranges.Set(new_var, range);
563 analyzer.Bind(new_var, range);
564 }
565 }
566 }
567
568 // Add the original conditions (with variables substituted) to the resulting conditions
569 for (const PrimExpr& old_cond :
570 AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
571 PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst));
572 if (!is_const_int(new_cond, 1)) {
573 // those not represented in vranges (res_ranges)
574 res_relations.push_back(new_cond);
575 }
576 }
577
578 // Reverse the axis so that it matches the order of the original variables
579 res_variables = Array<Var>(res_variables.rbegin(), res_variables.rend());
580
581 IntConstraints new_inequalities(res_variables, res_ranges, res_relations);
582 IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src);
583
584 return transform;
585 }
586
587 TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition")
__anon1c4334d90302(TVMArgs args, TVMRetValue* ret) 588 .set_body([](TVMArgs args, TVMRetValue* ret) {
589 IntConstraints problem;
590 PartialSolvedInequalities ret_ineq;
591 if (args.size() == 1) {
592 problem = args[0];
593 ret_ineq = SolveLinearInequalities(problem);
594 } else if (args.size() == 3) {
595 problem = IntConstraints(args[0], args[1], args[2]);
596 ret_ineq = SolveLinearInequalities(problem);
597 } else {
598 LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets "
599 << args.size();
600 }
601 *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second);
602 });
603
// Packed function wrapper of SolveInequalitiesToRange.
// Accepts either one IntConstraints or its (variables, ranges, relations) components.
TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) {
  switch (args.size()) {
    case 1:
      *ret = SolveInequalitiesToRange(args[0]);
      break;
    case 3: {
      IntConstraints problem(args[0], args[1], args[2]);
      *ret = SolveInequalitiesToRange(problem);
      break;
    }
    default:
      LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " << args.size();
  }
});
614
615 TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange")
__anon1c4334d90502(TVMArgs args, TVMRetValue* ret) 616 .set_body([](TVMArgs args, TVMRetValue* ret) {
617 if (args.size() == 1) {
618 *ret = SolveInequalitiesDeskewRange(args[0]);
619 } else if (args.size() == 3) {
620 IntConstraints problem(args[0], args[1], args[2]);
621 *ret = SolveInequalitiesDeskewRange(problem);
622 } else {
623 LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets "
624 << args.size();
625 }
626 });
627
628 } // namespace arith
629 } // namespace tvm
630