1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file ad_simplify.cc
22  * \brief Simplify tensor compute generated by tensor-level autodiff.
23  *
24  * The major simplification we do in this file is to eliminate
25  * the Jacobian tensor created by autodiff.
26  *
 * The Jacobian tensor is sparse because one output element usually depends on only
 * a small portion of the inputs. For example, an element-wise function has a one-to-one mapping
 * between the input tensor and the output tensor, thus its Jacobian is diagonal.
30  *
 * Generally, we have Out_{\beta} = f(In_{A \alpha}), in which A is a matrix and
 * \alpha and \beta are vectors representing the indices of In and Out respectively,
 * i.e., the non-zero Jacobian indices are linear combinations of the input indices.
 * Thereby we solve the linear equations \beta = A \alpha,
 * as well as the linear inequalities given by their domain ranges.
36  *
37  * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
38  * arXiv preprint arXiv:1711.01348, 2017. for more details.
39  *
 * Implementation-wise, we extract the equations in the compute definition via
 * NonzeronessCondition, replace the compute expression with the solved new axes, and create
 * a selection node (non-zero-condition ? new_compute_expression : 0).
43  *
 * Due to a restriction in TVM, we also lift the reduction to the top of the compute stage.
45  *
46  */
47 #include <dmlc/optional.h>
48 #include <tvm/arith/analyzer.h>
49 #include <tvm/arith/int_solver.h>
50 #include <tvm/runtime/registry.h>
51 #include <tvm/te/autodiff.h>
52 #include <tvm/tir/analysis.h>
53 #include <tvm/tir/stmt_functor.h>
54 
55 #include <iterator>
56 #include <memory>
57 #include <utility>
58 
59 #include "ad_util.h"
60 
61 namespace tvm {
62 namespace te {
63 
64 using arith::DivMode;
65 using arith::kFloorDiv;
66 using arith::kSimplifyRewriteCanonicalRewrite;
67 using arith::kTruncDiv;
68 
69 // Combine all expressions from the container using &&.
70 template <class container>
All(const container & c)71 PrimExpr All(const container& c) {
72   PrimExpr res;
73   for (const auto& e : c) {
74     if (res.get()) {
75       res = res && e;
76     } else {
77       res = e;
78     }
79   }
80   if (res.get()) {
81     return res;
82   } else {
83     return const_true();
84   }
85 }
86 
IterVarsToMap(const Array<IterVar> & itervars)87 Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
88   Map<Var, Range> res;
89   for (const IterVar& v : itervars) {
90     res.Set(v->var, v->dom);
91   }
92   return res;
93 }
94 
95 // Given a map from vars to ranges create an array of itervars
IterVarsFromMap(const Array<Var> & vars,const Map<Var,Range> & vranges,IterVarType iter_type=kDataPar,std::string thread_tag="")96 Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
97                                IterVarType iter_type = kDataPar, std::string thread_tag = "") {
98   Array<IterVar> res;
99   for (const Var& v : vars) {
100     CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
101                             << vranges;
102     res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
103   }
104   return res;
105 }
106 
// Extract the loop variable of every IterVar, preserving order.
Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
  Array<Var> vars;
  for (size_t i = 0; i < itervars.size(); ++i) {
    vars.push_back(itervars[i]->var);
  }
  return vars;
}
114 
115 template <typename ValueType>
is_const_value(const PrimExpr & e,ValueType value)116 bool is_const_value(const PrimExpr& e, ValueType value) {
117   static_assert(std::is_integral<ValueType>::value,
118                 "Comparison to non-integer values is forbidden.");
119   if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
120     return i->value == value;
121   } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
122     return i->value == value;
123   } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
124     return is_const_value(c->value, value);
125   } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
126     return is_const_value(b->value, value);
127   } else {
128     return false;
129   }
130 }
131 
132 // Return true if this combiner is just a sum.
IsSumCombiner(const CommReducer & combiner,const Map<Var,Range> & vranges)133 bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
134   arith::Analyzer analyzer;
135   analyzer.Bind(vranges);
136   if (combiner->result.size() != 1) {
137     return false;
138   }
139 
140   if (!is_const_value(
141           analyzer.Simplify(combiner->identity_element[0], kSimplifyRewriteCanonicalRewrite), 0)) {
142     return false;
143   }
144 
145   PrimExpr combiner_result =
146       analyzer.Simplify(combiner->result[0], kSimplifyRewriteCanonicalRewrite);
147 
148   return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
149          tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
150 }
151 
CanFactorZeroFromCombiner(const CommReducer & combiner,int value_index,const Map<Var,Range> & vranges)152 bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
153                                const Map<Var, Range>& vranges) {
154   arith::Analyzer analyzer;
155   analyzer.Bind(vranges);
156   if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
157                                         kSimplifyRewriteCanonicalRewrite),
158                       0)) {
159     return false;
160   }
161 
162   PrimExpr zero = make_zero(combiner->result[value_index].dtype());
163   PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
164                                                            {combiner->rhs[value_index], zero}});
165   in = analyzer.Simplify(in, kSimplifyRewriteCanonicalRewrite);
166 
167   return is_const_value(in, 0);
168 }
169 
170 struct NonzeroConditionResult {
171   PrimExpr cond;
172   PrimExpr value;
173 
to_exprtvm::te::NonzeroConditionResult174   PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
175 
operator <<(std::ostream & os,const NonzeroConditionResult & r)176   friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
177     return os << r.to_expr();
178   }
179 };
180 
181 // The implementation of NonzeroCondition
182 // transform expression to cond ? value : 0
183 class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
184  public:
NonzeroCondition(const PrimExpr & e)185   NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
186     if (e.dtype().is_bool()) {
187       // Boolean expressions are non-zero whenever they are true themselves
188       return {e, const_true()};
189     } else {
190       return VisitExpr(e);
191     }
192   }
193 
194   // Most of the cases are implemented using helpers below
VisitExpr_(const VarNode * op)195   result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
VisitExpr_(const IntImmNode * op)196   result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
VisitExpr_(const FloatImmNode * op)197   result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
VisitExpr_(const StringImmNode * op)198   result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
VisitExpr_(const AddNode * op)199   result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
VisitExpr_(const SubNode * op)200   result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
VisitExpr_(const MulNode * op)201   result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
VisitExpr_(const DivNode * op)202   result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
VisitExpr_(const ModNode * op)203   result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
VisitExpr_(const FloorDivNode * op)204   result_type VisitExpr_(const FloorDivNode* op) final {
205     return BinOpDivLike_(GetRef<FloorDiv>(op));
206   }
VisitExpr_(const FloorModNode * op)207   result_type VisitExpr_(const FloorModNode* op) final {
208     return BinOpDivLike_(GetRef<FloorMod>(op));
209   }
VisitExpr_(const MinNode * op)210   result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
VisitExpr_(const MaxNode * op)211   result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
212 
VisitExpr_(const CastNode * op)213   result_type VisitExpr_(const CastNode* op) final {
214     auto nz_a = NonzeroCondition(op->value);
215     return {nz_a.cond, Cast(op->dtype, nz_a.value)};
216   }
217 
VisitExpr_(const SelectNode * op)218   result_type VisitExpr_(const SelectNode* op) final {
219     PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
220     auto nz_a = NonzeroCondition(true_val);
221     auto nz_b = NonzeroCondition(false_val);
222 
223     // If the false part is zero, we can get rid of the select
224     if (is_const_value(nz_b.value, 0)) {
225       PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && cond, kSimplifyRewriteCanonicalRewrite);
226       return {new_cond, nz_a.value};
227     }
228 
229     // If the true part is zero, we can also get rid of the select
230     if (is_const_value(nz_a.value, 0)) {
231       PrimExpr new_cond = analyzer_.Simplify(nz_b.cond && !cond, kSimplifyRewriteCanonicalRewrite);
232       return {new_cond, nz_b.value};
233     }
234 
235     // Otherwise we retain the select and combine the conditions into this
236     PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
237                                            kSimplifyRewriteCanonicalRewrite);
238     if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
239       return {new_cond, GetRef<PrimExpr>(op)};
240     } else {
241       return {new_cond, Select(cond, nz_a.value, nz_b.value)};
242     }
243   }
244 
VisitExpr_(const CallNode * op)245   result_type VisitExpr_(const CallNode* op) final {
246     if (op->op.same_as(op_if_then_else_)) {
247       PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
248       auto nz_a = NonzeroCondition(true_val);
249       auto nz_b = NonzeroCondition(false_val);
250 
251       // We don't have as much freedom here as in the select case
252       // since the `if` must be preserved in any case
253       PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
254                                              kSimplifyRewriteCanonicalRewrite);
255       if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
256         return {new_cond, GetRef<PrimExpr>(op)};
257       } else {
258         return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
259       }
260     } else {
261       return Default_(GetRef<PrimExpr>(op));
262     }
263   }
264 
VisitExpr_(const ProducerLoadNode * op)265   result_type VisitExpr_(const ProducerLoadNode* op) final {
266     return Default_(GetRef<PrimExpr>(op));
267   }
268 
Default_(const PrimExpr & e)269   NonzeroConditionResult Default_(const PrimExpr& e) {
270     // This is always correct, so it's the default
271     return {const_true(), e};
272   }
273 
274   template <class T>
Const_(const T & op)275   NonzeroConditionResult Const_(const T& op) {
276     if (op->value == 0) {
277       return {const_false(), op};
278     } else {
279       return {const_true(), op};
280     }
281   }
282 
283   template <class T>
BinOpAddLike_(const T & op)284   NonzeroConditionResult BinOpAddLike_(const T& op) {
285     auto nz_a = NonzeroCondition(op->a);
286     auto nz_b = NonzeroCondition(op->b);
287 
288     // For addition and similar ops the result may be nonzero if either of the arguments is
289     // nonzero, so we combine the conditions with Or.
290     if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
291       // If the conditions are the same, we don't need Or
292       if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
293         return {nz_a.cond, op};
294       } else {
295         return {nz_a.cond, T(nz_a.value, nz_b.value)};
296       }
297     } else {
298       // Otherwise use Or
299       PrimExpr new_cond =
300           analyzer_.Simplify(nz_a.cond || nz_b.cond, kSimplifyRewriteCanonicalRewrite);
301       // A little optimization: if the combined condition is the same as one of the inner
302       // conditions, we don't need to guard the inner value with a select, otherwise
303       // we create a select in the `to_expr` call.
304       PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
305       PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
306       PrimExpr new_expr = T(new_a, new_b);
307       return {new_cond, new_expr};
308     }
309   }
310 
311   template <class T>
BinOpMulLike_(const T & op)312   NonzeroConditionResult BinOpMulLike_(const T& op) {
313     auto nz_a = NonzeroCondition(op->a);
314     auto nz_b = NonzeroCondition(op->b);
315 
316     // For multiplication and similar ops the result may be nonzero if
317     // both the arguments are nonzero, so we combine with And.
318     PrimExpr new_cond =
319         analyzer_.Simplify(nz_a.cond && nz_b.cond, kSimplifyRewriteCanonicalRewrite);
320 
321     if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
322       return {new_cond, op};
323     } else {
324       return {new_cond, T(nz_a.value, nz_b.value)};
325     }
326   }
327 
328   template <class T>
BinOpDivLike_(const T & op)329   NonzeroConditionResult BinOpDivLike_(const T& op) {
330     auto nz_a = NonzeroCondition(op->a);
331 
332     // For Div we simply use the condition of the numerator.
333 
334     if (nz_a.value.same_as(op->a)) {
335       return {nz_a.cond, op};
336     } else {
337       return {nz_a.cond, T(nz_a.value, op->b)};
338     }
339   }
340 
341  private:
342   arith::Analyzer analyzer_;
343   const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
344 };
345 
NonzeronessCondition(const PrimExpr & expr)346 inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
347   return NonzeroConditionFunctor().NonzeroCondition(expr);
348 }
349 
350 struct FactorOutAtomicFormulasResult {
351   std::vector<PrimExpr> atomic_formulas;
352   PrimExpr rest;
353 
to_exprtvm::te::FactorOutAtomicFormulasResult354   PrimExpr to_expr() const {
355     PrimExpr res = rest;
356     for (const PrimExpr& e : atomic_formulas) {
357       res = And(e, res);
358     }
359     return res;
360   }
361 
to_arraytvm::te::FactorOutAtomicFormulasResult362   Array<PrimExpr> to_array() const {
363     Array<PrimExpr> res = atomic_formulas;
364     res.push_back(rest);
365     return res;
366   }
367 };
368 
369 // The implementation of FactorOutAtomicFormulas
370 class FactorOutAtomicFormulasFunctor
371     : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
372  public:
Atomic_(const PrimExpr & e)373   result_type Atomic_(const PrimExpr& e) {
374     // For atomic expressions the result is the expr itself with True as the residual
375     return {{e}, make_const(e.dtype(), 1)};
376   }
377 
378   // This is basically the list of expression kinds that are considered atomic
VisitExpr_(const VarNode * op)379   result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const CallNode * op)380   result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const IntImmNode * op)381   result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const EQNode * op)382   result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const NENode * op)383   result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LENode * op)384   result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LTNode * op)385   result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GENode * op)386   result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GTNode * op)387   result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
388 
VisitExpr_(const SelectNode * op)389   result_type VisitExpr_(const SelectNode* op) final {
390     // Select can be rewritten through other logical ops
391     PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
392     return VisitExpr(expr);
393   }
394 
VisitExpr_(const NotNode * op)395   result_type VisitExpr_(const NotNode* op) final {
396     // Not should be moved down
397     if (const OrNode* or_expr = op->a.as<OrNode>()) {
398       PrimExpr expr = !or_expr->a && !or_expr->b;
399       return VisitExpr(expr);
400     } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
401       PrimExpr expr = !and_expr->a || !and_expr->b;
402       return VisitExpr(expr);
403     } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
404       PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
405                        (sel_expr->condition || !sel_expr->false_value));
406       return VisitExpr(expr);
407     }
408     return Atomic_(GetRef<PrimExpr>(op));
409   }
410 
VisitExpr_(const AndNode * op)411   result_type VisitExpr_(const AndNode* op) final {
412     auto res_a = VisitExpr(op->a);
413     auto res_b = VisitExpr(op->b);
414 
415     // For the And case we return the union of the sets of atomic formulas
416     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
417     res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
418     std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
419               std::inserter(res_set, res_set.end()));
420     std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
421               std::inserter(res_set, res_set.end()));
422 
423     std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
424 
425     // And the residuals are combined with &&
426     return {res, res_a.rest && res_b.rest};
427   }
428 
VisitExpr_(const MulNode * op)429   result_type VisitExpr_(const MulNode* op) final {
430     // Since we work with bools, for multiplication we do the same thing as for And
431     PrimExpr e_and = op->a && op->b;
432     return VisitExpr(e_and);
433   }
434 
VisitExpr_(const OrNode * op)435   result_type VisitExpr_(const OrNode* op) final {
436     auto res_a = VisitExpr(op->a);
437     auto res_b = VisitExpr(op->b);
438 
439     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
440         res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
441     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
442         res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
443 
444     // For the Or case we intersect the sets of atomic formulas
445     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
446     res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
447     for (const auto& res_b_formula : res_b_set) {
448       if (res_a_set.count(res_b_formula)) {
449         res_set.insert(res_b_formula);
450       }
451     }
452 
453     // Computing the residual is more complex: we have to compute the sets of atomic formulas
454     // which are left behind, and then combine them with the residuals into the new residual.
455     std::vector<PrimExpr> new_cond_a;
456     new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
457     for (const auto& formula : res_a_set) {
458       if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
459     }
460 
461     std::vector<PrimExpr> new_cond_b;
462     new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
463     for (const auto& formula : res_b_set) {
464       if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
465     }
466 
467     res_a.atomic_formulas = std::move(new_cond_a);
468     res_b.atomic_formulas = std::move(new_cond_b);
469 
470     PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
471     std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
472 
473     return {res, new_rest};
474   }
475 };
476 
477 // Transform the given formula into a conjunction of atomic formulas (represented as an array)
478 // and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
479 // etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
FactorOutAtomicFormulas(const PrimExpr & e)480 FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
481   CHECK(e.dtype().is_bool());
482   return FactorOutAtomicFormulasFunctor().VisitExpr(e);
483 }
484 
485 struct EliminateDivModResult {
486   PrimExpr expr;
487   Map<Var, PrimExpr> substitution;
488   Array<Var> new_variables;
489   Array<PrimExpr> conditions;
490   Map<Var, Range> ranges;
491 };
492 
ModImpl(PrimExpr a,PrimExpr b,DivMode mode)493 inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
494   if (mode == kTruncDiv) {
495     return truncmod(a, b);
496   } else {
497     CHECK_EQ(mode, kFloorDiv);
498     return floormod(a, b);
499   }
500 }
501 
DivImpl(PrimExpr a,PrimExpr b,DivMode mode)502 inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
503   if (mode == kTruncDiv) {
504     return truncdiv(a, b);
505   } else {
506     CHECK_EQ(mode, kFloorDiv);
507     return floordiv(a, b);
508   }
509 }
510 
511 class EliminateDivModMutator : public ExprMutator {
512  public:
513   Map<Var, PrimExpr> substitution;
514   Array<Var> new_variables;
515   Array<PrimExpr> conditions;
516   Map<Var, Range> ranges;
517 
EliminateDivModMutator(Map<Var,Range> ranges)518   explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
519 
VisitExpr_(const DivNode * op)520   virtual PrimExpr VisitExpr_(const DivNode* op) {
521     const IntImmNode* imm = op->b.as<IntImmNode>();
522     if (imm && imm->value != 0) {
523       if (imm->value < 0) {
524         // x / -c == -(x/c) for truncated division
525         return make_zero(op->dtype) -
526                VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
527       }
528 
529       // Try to find the already existing variables for this expression
530       auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
531       if (it != expr_to_vars_.end()) {
532         return it->second.first;
533       }
534 
535       // Otherwise recursively mutate the left hand side, and create new variables
536       PrimExpr mutated_a = VisitExpr(op->a);
537       if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
538         return var_pair_opt.value().first;
539       } else {
540         return truncdiv(mutated_a, op->b);
541       }
542     }
543 
544     return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
545   }
546 
VisitExpr_(const ModNode * op)547   virtual PrimExpr VisitExpr_(const ModNode* op) {
548     const IntImmNode* imm = op->b.as<IntImmNode>();
549     if (imm && imm->value != 0) {
550       if (imm->value < 0) {
551         // x % -c == x % c for truncated division
552         return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
553       }
554 
555       // Try to find the already existing variables for this expression
556       auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
557       if (it != expr_to_vars_.end()) {
558         return it->second.second;
559       }
560 
561       // Otherwise recursively mutate the left hand side, and create new variables
562       PrimExpr mutated_a = VisitExpr(op->a);
563       if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
564         return var_pair_opt.value().second;
565       } else {
566         return truncmod(mutated_a, op->b);
567       }
568     }
569 
570     return truncmod(VisitExpr(op->a), VisitExpr(op->b));
571   }
572 
VisitExpr_(const FloorDivNode * op)573   virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
574     const IntImmNode* imm = op->b.as<IntImmNode>();
575     if (imm && imm->value != 0) {
576       if (imm->value < 0) {
577         // x / -c == (-x) / c for flooring division
578         return VisitExpr(
579             floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
580       }
581 
582       // Try to find the already existing variables for this expression
583       auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
584       if (it != expr_to_vars_.end()) {
585         return it->second.first;
586       }
587 
588       // Otherwise recursively mutate the left hand side, and create new variables
589       PrimExpr mutated_a = VisitExpr(op->a);
590       if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
591         return var_pair_opt.value().first;
592       } else {
593         return floordiv(mutated_a, op->b);
594       }
595     }
596 
597     return floordiv(VisitExpr(op->a), VisitExpr(op->b));
598   }
599 
VisitExpr_(const FloorModNode * op)600   virtual PrimExpr VisitExpr_(const FloorModNode* op) {
601     const IntImmNode* imm = op->b.as<IntImmNode>();
602     if (imm && imm->value != 0) {
603       if (imm->value < 0) {
604         // x % -c == -(-x % c) for flooring division
605         return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
606                                                          make_const(op->dtype, -imm->value)));
607       }
608 
609       // Try to find the already existing variables for this expression
610       auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
611       if (it != expr_to_vars_.end()) {
612         return it->second.second;
613       }
614 
615       // Otherwise recursively mutate the left hand side, and create new variables
616       PrimExpr mutated_a = VisitExpr(op->a);
617       if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
618         return var_pair_opt.value().second;
619       } else {
620         return floormod(mutated_a, op->b);
621       }
622     }
623 
624     return floormod(VisitExpr(op->a), VisitExpr(op->b));
625   }
626 
627  private:
AddNewVarPair(const PrimExpr & e,const PrimExpr & mut,int64_t val,DivMode mode)628   dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
629                                                     int64_t val, DivMode mode) {
630     using tresult = dmlc::optional<std::pair<Var, Var>>;
631 
632     // Try to find the variables using the mutated expressions
633     if (!e.same_as(mut)) {
634       auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
635       if (it != expr_to_vars_.end()) {
636         return tresult(it->second);
637       }
638     }
639 
640     PrimExpr val_e = make_const(e.dtype(), val);
641     idx_ += 1;
642 
643     // Convert `ranges` to IntSets
644     std::unordered_map<const VarNode*, IntSet> var_intsets;
645     for (const auto& p : ranges) {
646       var_intsets[p.first.get()] = IntSet::FromRange(p.second);
647     }
648 
649     // Infer ranges for the expressions we want to replace with variables
650     Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
651     Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
652 
653     // We don't want to add unbounded variables
654     if (!div_range.get() || !mod_range.get()) {
655       LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
656                    << "  because its bounds cannot be inferred";
657       return tresult();
658     }
659     if (!mod_range.get()) {
660       LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
661                    << "  because its bounds cannot be inferred";
662       return tresult();
663     }
664 
665     // Create new variables for the expressions
666     auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
667     auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
668 
669     new_variables.push_back(div);
670     new_variables.push_back(mod);
671 
672     // Note that we have to perform substitution to mut because mut may contain new variables
673     substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
674     substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
675 
676     ranges.Set(div, div_range);
677     ranges.Set(mod, mod_range);
678 
679     // This additional condition works as a definition for the new variables
680     conditions.push_back(mut == div * val_e + mod);
681 
682     if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
683       // If we use the C/C++ definition of mod, there may be multiple values of `mod`
684       // satisfying the added condition if the expr `e` may change its sign, so we
685       // have to add another condition.
686       LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
687                    << ModImpl(e, val_e, mode) << "  probably may change its sign";
688       conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
689     }
690 
691     auto p = std::make_pair(div, mod);
692     expr_to_vars_[std::make_tuple(mode, e, val)] = p;
693     if (!e.same_as(mut)) {
694       expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
695     }
696     return tresult(p);
697   }
698 
699   class TupleEqual_ {
700    public:
operator ()(const std::tuple<DivMode,PrimExpr,int64_t> & lhs,const std::tuple<DivMode,PrimExpr,int64_t> & rhs) const701     bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
702                     const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
703       return std::get<0>(lhs) == std::get<0>(rhs) &&
704              tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
705              std::get<2>(lhs) == std::get<2>(rhs);
706     }
707   };
708 
709   class TupleHasher_ {
710    public:
operator ()(const std::tuple<DivMode,PrimExpr,int64_t> & key) const711     size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
712       return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
713               1) ^
714              (std::hash<int64_t>()(std::get<2>(key)) << 1);
715     }
716   };
717 
  // A counter for naming new variables
  int idx_{0};
  // A map from keys (mode, e, n) -- a division mode, a dividend expression and a constant
  // divisor -- to pairs of new vars (div, mod) such that `div = e / n` and `mod = e % n`
  // under that division mode. Keys are hashed with TupleHasher_ and compared with TupleEqual_
  // (structural comparison of `e`), so structurally equal dividends reuse the same pair.
  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
                     TupleEqual_>
      expr_to_vars_;
  // Analyzer used, e.g., to try to prove `mod_range->extent <= val_e` when deciding whether an
  // extra sign condition on the new `mod` variable is required.
  arith::Analyzer analyzer_;
726 };
727 
728 // Replace every subexpr of the form e/const and e % const with a new variable.
729 // Syntactically equal expressions will be mapped to the same variable.
EliminateDivMod(const PrimExpr & expr,Map<Var,Range> ranges)730 EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
731   EliminateDivModResult res;
732   EliminateDivModMutator mutator(ranges);
733   res.expr = mutator(expr);
734   res.conditions = std::move(mutator.conditions);
735   res.new_variables = std::move(mutator.new_variables);
736   res.substitution = std::move(mutator.substitution);
737   res.ranges = std::move(mutator.ranges);
738   return res;
739 }
740 
// Rewrite the relations of `domain` so that they no longer contain div/mod by constants.
// The destination domain has the original variables followed by the fresh div/mod variables,
// and its relations are the rewritten relations plus the defining conditions of the new
// variables, split into atomic formulas.
arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
    const arith::IntConstraints& domain) {
  EliminateDivModResult elim = EliminateDivMod(All(domain->relations), domain->ranges);

  // Original variables map to themselves in both directions; the fresh variables are expressed
  // in terms of the original ones only in the dst -> src direction.
  Map<Var, PrimExpr> forward;
  Map<Var, PrimExpr> backward = elim.substitution;
  for (const Var& v : domain->variables) {
    forward.Set(v, v);
    backward.Set(v, v);
  }

  PrimExpr combined_cond = elim.expr && All(elim.conditions);
  arith::IntConstraints new_domain(Concat(domain->variables, elim.new_variables), elim.ranges,
                                   FactorOutAtomicFormulas(combined_cond).to_array());

  return arith::IntConstraintsTransform(domain, new_domain, forward, backward);
}
761 
// Build the identity transformation of `domain`: source and destination are both `domain`,
// and every variable maps to itself in both directions.
inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
  Map<Var, PrimExpr> self_map;
  for (const Var& var : domain->variables) self_map.Set(var, var);
  return arith::IntConstraintsTransform(domain, domain, self_map, self_map);
}
769 
770 // Simplify an iteration domain.
SimplifyDomain(const arith::IntConstraints & iter_domains,bool eliminate_div_mod)771 arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
772                                               bool eliminate_div_mod) {
773   arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
774 
775   if (eliminate_div_mod) {
776     transf = transf + EliminateDivModFromDomainConditions(transf->dst);
777   }
778 
779   // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
780   // should find a better terminating criterion (like stop when the domain volume stops decreasing)
781   // Also 2 steps seems to be slightly better than 3
782   for (size_t i = 0; i < 2; ++i) {
783     transf = transf + arith::SolveLinearEquations(transf->dst);
784     transf = transf + arith::SolveInequalitiesDeskewRange(transf->dst);
785   }
786 
787   return transf;
788 }
789 
790 // Use the condition of a reduction op to simplify its domain (axis)
SimplifyReductionDomain(const PrimExpr & expr,const Map<Var,Range> & outer_vranges)791 PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
792   if (const ReduceNode* red = expr.as<ReduceNode>()) {
793     Array<Var> vars = IterVarsToVars(red->axis);
794     Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
795     Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
796 
797     arith::IntConstraints domain(vars, vranges, relations);
798     auto res = SimplifyDomain(domain);
799 
800     Array<PrimExpr> new_source;
801     for (const PrimExpr& src : red->source) {
802       new_source.push_back(Substitute(src, res->src_to_dst));
803     }
804 
805     Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
806 
807     // Perform simplification mainly to remove a possibly empty reduction.
808     arith::Analyzer analyzer;
809     return analyzer.Simplify(Reduce(red->combiner, new_source, new_axis, All(res->dst->relations),
810                                     red->value_index, red->init),
811                              kSimplifyRewriteCanonicalRewrite);
812   } else {
813     return expr;
814   }
815 }
816 
817 // Extract from cond an implication of cond not containing vars
ImplicationNotContainingVars(const PrimExpr & cond,const std::unordered_set<const VarNode * > & vars)818 std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
819     const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
820   CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
821   // TODO(sgrechanik-h): NOTs could be pushed down using De Morgan laws
822   // before running this function but this case didn't seem to be important enough.
823   if (const AndNode* op = cond.as<AndNode>()) {
824     auto pair_a = ImplicationNotContainingVars(op->a, vars);
825     auto pair_b = ImplicationNotContainingVars(op->b, vars);
826     return {pair_a.first && pair_b.first, pair_a.second && pair_b.second};
827   } else if (const OrNode* op = cond.as<OrNode>()) {
828     auto pair_a = ImplicationNotContainingVars(op->a, vars);
829     auto pair_b = ImplicationNotContainingVars(op->b, vars);
830     return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
831                                               (pair_b.first || pair_a.second) &&
832                                               (pair_a.second || pair_b.second)};
833   } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
834     return {cond, const_true()};
835   } else {
836     return {const_true(), cond};
837   }
838 }
839 
840 // Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out
841 // (in)equalities which do not depend on the reduction variables.
LiftConditionsThroughReduction(const PrimExpr & cond,const Array<IterVar> & red_axis,const Array<IterVar> & outer_axis)842 std::pair<PrimExpr, PrimExpr> LiftConditionsThroughReduction(const PrimExpr& cond,
843                                                              const Array<IterVar>& red_axis,
844                                                              const Array<IterVar>& outer_axis) {
845   // Factor out atomics so that we can consider this as a system of inequalities
846   auto factor_atomic_res = FactorOutAtomicFormulas(cond);
847   Array<PrimExpr> atomics = factor_atomic_res.atomic_formulas;
848   const PrimExpr& rest = factor_atomic_res.rest;
849 
850   Array<Var> allvars;
851   for (const IterVar& v : red_axis) {
852     allvars.push_back(v->var);
853   }
854   for (const IterVar& v : outer_axis) {
855     allvars.push_back(v->var);
856   }
857 
858   auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis));
859   // start from reduction vars, so that input vars don't depend on them
860   arith::IntConstraints ineq_to_solve(allvars, vranges, atomics);
861   auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve);
862   atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second);
863 
864   // Append the rest part
865   PrimExpr rewritten_cond = All(atomics) && rest;
866 
867   std::unordered_set<const VarNode*> vset;
868   for (const IterVar& v : red_axis) {
869     vset.insert(v->var.get());
870   }
871 
872   // The outer (first) condition does not contain reduction vars,
873   // the inner (second) condition is everything else
874   auto res = ImplicationNotContainingVars(rewritten_cond, vset);
875   return res;
876 }
877 
878 // Convert an array of itervars to an array of inequalities
IterVarsToInequalities(const Array<IterVar> & itervars)879 Array<PrimExpr> IterVarsToInequalities(const Array<IterVar>& itervars) {
880   Array<PrimExpr> res;
881   for (const IterVar& v : itervars) {
882     res.push_back(GE(v->var, v->dom->min));
883     res.push_back(LT(v->var, v->dom->min + v->dom->extent));
884   }
885   return res;
886 }
887 
888 class RemoveRedundantInequalitiesMutator : public ExprMutator {
889  public:
RemoveRedundantInequalitiesMutator(Array<PrimExpr> known)890   explicit RemoveRedundantInequalitiesMutator(Array<PrimExpr> known) {
891     for (const PrimExpr& cond : known) {
892       known_.push_back(analyzer_.Simplify(cond, kSimplifyRewriteCanonicalRewrite));
893     }
894   }
895 
VisitExpr_(const SelectNode * op)896   virtual PrimExpr VisitExpr_(const SelectNode* op) {
897     bool has_side_effect = (SideEffect(GetRef<PrimExpr>(op)) > CallEffectKind::kReadState);
898     PrimExpr new_cond =
899         analyzer_.Simplify(VisitExpr(op->condition), kSimplifyRewriteCanonicalRewrite);
900     if (is_one(new_cond) && !has_side_effect) {
901       return VisitExpr(op->true_value);
902     } else if (is_zero(new_cond) && !has_side_effect) {
903       return VisitExpr(op->false_value);
904     } else {
905       Array<PrimExpr> new_known = known_;
906       for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
907         new_known.push_back(atomic);
908       }
909       RemoveRedundantInequalitiesMutator new_mutator(new_known);
910       // Note that we mutate only the true value with the new mutator
911       // TODO(sgrechanik-h): Update known conditions for the false value as well
912       return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value));
913     }
914   }
915 
VisitExpr_(const CallNode * op)916   virtual PrimExpr VisitExpr_(const CallNode* op) {
917     if (op->op.same_as(op_if_then_else_)) {
918       PrimExpr new_cond =
919           analyzer_.Simplify(VisitExpr(op->args[0]), kSimplifyRewriteCanonicalRewrite);
920       if (is_one(new_cond)) {
921         return VisitExpr(op->args[1]);
922       } else if (is_zero(new_cond)) {
923         return VisitExpr(op->args[2]);
924       } else {
925         Array<PrimExpr> new_known = known_;
926         for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
927           new_known.push_back(atomic);
928         }
929         RemoveRedundantInequalitiesMutator new_mutator(new_known);
930         // Note that we mutate only the true value with the new mutator
931         // TODO(sgrechanik-h): Update known conditions for the false value as well
932         return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2]));
933       }
934     } else {
935       return ExprMutator::VisitExpr_(op);
936     }
937   }
938 
VisitExpr_(const ReduceNode * op)939   virtual PrimExpr VisitExpr_(const ReduceNode* op) {
940     Array<PrimExpr> known_with_axes = known_;
941     CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
942     for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
943       known_with_axes.push_back(axis_cond);
944     }
945     RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes);
946 
947     PrimExpr new_cond = mutator_with_axes(op->condition);
948 
949     Array<PrimExpr> new_known = known_with_axes;
950     for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
951       new_known.push_back(atomic);
952     }
953     RemoveRedundantInequalitiesMutator new_mutator(new_known);
954 
955     Array<PrimExpr> new_source;
956     for (const PrimExpr& src : op->source) {
957       new_source.push_back(new_mutator(src));
958     }
959 
960     return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index, op->init);
961   }
962 
VisitExpr_(const EQNode * op)963   virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const NENode * op)964   virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LTNode * op)965   virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LENode * op)966   virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GTNode * op)967   virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GENode * op)968   virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
969 
VisitExpr_(const AndNode * op)970   virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); }
971 
972  private:
MutateAtomic_(const PrimExpr & e)973   PrimExpr MutateAtomic_(const PrimExpr& e) {
974     PrimExpr simplified = analyzer_.Simplify(e, kSimplifyRewriteCanonicalRewrite);
975     for (const PrimExpr& other : known_) {
976       if (ExprDeepEqual()(simplified, other)) {
977         return const_true();
978       }
979     }
980     return simplified;
981   }
982 
983   Array<PrimExpr> known_;
984   arith::Analyzer analyzer_;
985   const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
986 };
987 
988 // Propagate information from conditions and remove redundant inequalities
RemoveRedundantInequalities(const PrimExpr & expr,const Array<PrimExpr> & known)989 inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array<PrimExpr>& known) {
990   return RemoveRedundantInequalitiesMutator(known)(expr);
991 }
992 
993 // Extract the given expr under the given condition as a separate tensor if the volume of the
994 // extracted tensor will be less than the volume of the outer_axis
TrySimplifyCompute(const PrimExpr & expr,const PrimExpr & cond,const Array<Var> & outer_axis,const Map<Var,Range> & vranges)995 PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
996                             const Array<Var>& outer_axis, const Map<Var, Range>& vranges) {
997   // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j)
998   arith::IntConstraints domain_to_solve(outer_axis, vranges,
999                                         FactorOutAtomicFormulas(cond).to_array());
1000   auto res = SimplifyDomain(domain_to_solve);
1001 
1002   arith::Analyzer analyzer;
1003   analyzer.Bind(res->dst->ranges);
1004   PrimExpr new_expr =
1005       analyzer.Simplify(Substitute(expr, res->src_to_dst), kSimplifyRewriteCanonicalRewrite);
1006   // TODO(yzhliu): This is mostly done to simplify if_then_else
1007   // which is not realized by the canonical simplifier
1008   new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations);
1009 
1010   // Keep only those variables of the new vars which are used in the new_expr
1011   Array<Var> used_res_variables;
1012   for (const Var& var : res->dst->variables) {
1013     if (ExprUseVar(new_expr, var)) {
1014       CHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
1015       used_res_variables.push_back(var);
1016     }
1017   }
1018 
1019   // If the expression does not use vars then it is probably better to keep it inlined
1020   if (used_res_variables.empty()) {
1021     // We can return the new_expr here instead of the old expr because it doesn't use variables
1022     // otherwise we would need to replace the new vars or create a let-expression
1023     return new_expr;
1024   }
1025 
1026   // If it's already tensor[...] then it will probably be useless to further simplify it.
1027   if (new_expr.as<ProducerLoadNode>()) {
1028     return expr;
1029   }
1030 
1031   // Compute volumes before and after
1032   PrimExpr old_volume = make_const(DataType::Int(64), 1);
1033   for (const Var& var : outer_axis) {
1034     CHECK(vranges.count(var)) << "Range of " << var << " was not provided.";
1035     old_volume = old_volume * vranges[var]->extent;
1036   }
1037 
1038   PrimExpr new_volume = make_const(DataType::Int(64), 1);
1039   for (const Var& var : used_res_variables) {
1040     new_volume = new_volume * res->dst->ranges[var]->extent;
1041   }
1042 
1043   // if we can prove that the old volume is not greater than the new volume then
1044   // prefer the old expression.
1045   arith::Analyzer ana_vranges;
1046   ana_vranges.Bind(vranges);
1047   if (ana_vranges.CanProve(old_volume <= new_volume)) {
1048     return expr;
1049   }
1050 
1051   Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges),
1052                                  "extracted_tensor");
1053 
1054   Array<PrimExpr> args;
1055   for (const Var& var : used_res_variables) {
1056     args.push_back(res->dst_to_src[var]);
1057   }
1058 
1059   return ProducerLoad(tensor, args);
1060 }
1061 
1062 class ReductionAsTensorAccessMutator : public ExprMutator {
1063  public:
ReductionAsTensorAccessMutator(const Array<Var> & outer_axis,Map<Var,Range> vranges,std::string name="extracted_reduction")1064   explicit ReductionAsTensorAccessMutator(const Array<Var>& outer_axis, Map<Var, Range> vranges,
1065                                           std::string name = "extracted_reduction")
1066       : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {}
1067 
VisitExpr_(const ReduceNode * op)1068   PrimExpr VisitExpr_(const ReduceNode* op) final {
1069     ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
1070                                                Merge(vranges_, IterVarsToMap(op->axis)), name_);
1071 
1072     CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
1073     Array<PrimExpr> new_source;
1074     for (const PrimExpr& src : op->source) {
1075       new_source.push_back(new_mutator(src));
1076     }
1077 
1078     PrimExpr new_reduce =
1079         Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index, op->init);
1080 
1081     Array<Var> undefined_vars = UndefinedVars(new_reduce);
1082     std::unordered_set<const VarNode*> undefined_var_set;
1083     for (const Var& var : undefined_vars) {
1084       undefined_var_set.insert(var.get());
1085     }
1086 
1087     // Vars of the tensor we are going to create for this reduction
1088     Array<Var> vars;
1089     for (const Var& v : outer_axis_) {
1090       // We take variables from the outer_axis_ which are also present in the new reduction
1091       if (undefined_var_set.count(v.get())) {
1092         vars.push_back(v);
1093       }
1094     }
1095 
1096     auto new_axis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_));
1097     Array<IterVar> new_axis = new_axis_vmap_pair.first;
1098     arith::Analyzer analyzer;
1099     analyzer.Bind(IterVarsToMap(new_axis));
1100     new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second),
1101                                    kSimplifyRewriteCanonicalRewrite);
1102 
1103     Tensor tensor = TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_);
1104 
1105     Array<PrimExpr> args;
1106     for (const Var& v : vars) {
1107       args.push_back(v);
1108     }
1109 
1110     return ProducerLoad(tensor, args);
1111   }
1112 
1113  private:
1114   Array<Var> outer_axis_;
1115   Map<Var, Range> vranges_;
1116   std::string name_;
1117   std::string tag_;
1118   Map<String, ObjectRef> attrs_;
1119 };
1120 
1121 // Extract reductions as separate tensors.
ReductionAsTensorAccess(const PrimExpr & expr,const Array<Var> & outer_axis,const Map<Var,Range> & vranges)1122 inline PrimExpr ReductionAsTensorAccess(const PrimExpr& expr, const Array<Var>& outer_axis,
1123                                         const Map<Var, Range>& vranges) {
1124   return ReductionAsTensorAccessMutator(outer_axis, vranges)(expr);
1125 }
1126 
// Ensure reductions appear only at the top level of `expr`.
// If `expr` is itself a Reduce, reductions nested in its sources and condition are extracted
// into separate tensors, with the reduction axis joining the outer axis. Otherwise every
// reduction in `expr` is extracted.
PrimExpr LiftReductions(const PrimExpr& expr, const Array<Var>& outer_axis,
                        const Map<Var, Range>& vranges) {
  const ReduceNode* red = expr.as<ReduceNode>();
  if (red == nullptr) {
    return ReductionAsTensorAccess(expr, outer_axis, vranges);
  }

  Array<Var> axis_with_red = Concat(IterVarsToVars(red->axis), outer_axis);
  Map<Var, Range> ranges_with_red = Merge(vranges, IterVarsToMap(red->axis));
  auto extract = [&](const PrimExpr& e) {
    return ReductionAsTensorAccess(e, axis_with_red, ranges_with_red);
  };

  Array<PrimExpr> sources;
  for (const PrimExpr& src : red->source) {
    sources.push_back(extract(src));
  }
  PrimExpr condition = extract(red->condition);

  return Reduce(red->combiner, sources, red->axis, condition, red->value_index, red->init);
}
1143 
RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr & expr_orig,const Array<IterVar> & axis,const Map<Var,Range> & vranges)1144 PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const Array<IterVar>& axis,
1145                                               const Map<Var, Range>& vranges) {
1146   PrimExpr result;
1147   Map<Var, Range> combined_vranges = Merge(vranges, IterVarsToMap(axis));
1148   arith::Analyzer analyzer;
1149   analyzer.Bind(combined_vranges);
1150 
1151   // Simplify the original expression first, mostly to simplify combiners
1152   PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite);
1153 
1154   if (const ReduceNode* red = expr.as<ReduceNode>()) {
1155     CHECK(red->init.empty()) << "Derivative of Reduction with initialization is not implemented";
1156     // TODO(sgrechanik-h): There are some other operations which behave like sum
1157     bool is_sum = IsSumCombiner(red->combiner, vranges);
1158     if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
1159       PrimExpr new_red = expr;
1160 
1161       // Here we simplify the reduction
1162       PrimExpr cond = red->condition;
1163       Array<PrimExpr> source = red->source;
1164 
1165       // If it is a summation then we can lift nonzeroness conditions from the source
1166       // and add them to the reduction conditions
1167       if (is_sum) {
1168         auto nz = NonzeronessCondition(red->source[red->value_index]);
1169         cond = nz.cond && cond;
1170         source.Set(0, nz.value);
1171       }
1172 
1173       new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init);
1174       new_red = SimplifyReductionDomain(new_red, combined_vranges);
1175       // Update original red pointer for later use.
1176       red = new_red.as<ReduceNode>();
1177       // If the reduction disappears completely then transform the result as a non-reduction
1178       if (!red) {
1179         return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
1180       }
1181 
1182       PrimExpr new_outer_cond, new_reduce_cond;
1183       Array<PrimExpr> new_source = red->source;
1184 
1185       // Partially lift conditions from the reduce condition
1186       std::tie(new_outer_cond, new_reduce_cond) =
1187           LiftConditionsThroughReduction(red->condition, red->axis, axis);
1188 
1189       // If it's not sum then we haven't yet lifted nonzeroness cond from the source
1190       if (!is_sum) {
1191         PrimExpr outer_nz_cond, nz_cond, nz_source;
1192         auto nz = NonzeronessCondition(red->source[red->value_index]);
1193         // Append conditions from the reduction
1194         nz_cond = new_reduce_cond && nz.cond;
1195         nz_source = nz.value;
1196         std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
1197         new_outer_cond = new_outer_cond && outer_nz_cond;
1198         new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
1199       }
1200 
1201       PrimExpr new_reduce = Reduce(red->combiner, new_source, red->axis, new_reduce_cond,
1202                                    red->value_index, red->init);
1203       new_reduce =
1204           TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges);
1205       result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype()));
1206     } else {
1207       return SimplifyReductionDomain(expr, combined_vranges);
1208     }
1209   } else {
1210     auto nz = NonzeronessCondition(expr);
1211     PrimExpr new_expr =
1212         TrySimplifyCompute(nz.value, nz.cond, IterVarsToVars(axis), combined_vranges);
1213     result = Select(nz.cond, new_expr, make_zero(new_expr.dtype()));
1214   }
1215 
1216   // Note that RemoveRedundantInequalities can sometimes propagate equalities which
1217   // other simplifiers cannot, like (i % 3) == 0.
1218   Array<PrimExpr> axis_conds = IterVarsToInequalities(axis);
1219   result = RemoveRedundantInequalities(result, axis_conds);
1220 
1221   // Currently in TVM reductions are only allowed at the top level of compute,
1222   // we need to extract intermediate inlined reduction as a separate stage (tensor).
1223   // Sometimes TrySimplifyCompute doesn't perform lift / extraction,
1224   // so there may be some non-top reductions left, take care of them.
1225   result = LiftReductions(result, IterVarsToVars(axis), combined_vranges);
1226   return analyzer.Simplify(result, kSimplifyRewriteCanonicalRewrite);
1227 }
1228 
// Rewrite the body of `tensor` with RemoveJacobianAndLiftNonzeroCondImpl via
// TransformTensorBody; `vranges` supplies ranges for variables other than the tensor's axis.
Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, const Map<Var, Range>& vranges) {
  return TransformTensorBody(
      tensor, [&vranges](const PrimExpr& body, const Array<IterVar>& axis) -> PrimExpr {
        return RemoveJacobianAndLiftNonzeroCondImpl(body, axis, vranges);
      });
}
1235 
1236 }  // namespace te
1237 }  // namespace tvm
1238