1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file ad_simplify.cc
22 * \brief Simplify tensor compute generated by tensor-level autodiff.
23 *
24 * The major simplification we do in this file is to eliminate
25 * the Jacobian tensor created by autodiff.
26 *
27 * Jacobian tensor is sparse because one output element usually relates
28 * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
29 * between input tensor and output tensor, thus the Jacobian is diagonal.
30 *
 * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
 * \alpha and \beta are vectors representing the indices of In and Out respectively,
 * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
 * Therefore we solve the linear equations \beta = A \alpha,
 * as well as the linear inequalities given by their domain ranges.
36 *
37 * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
38 * arXiv preprint arXiv:1711.01348, 2017. for more details.
39 *
 * Implementation-wise, we extract the equations in the compute definition via
 * NonzeronessCondition, replace the compute expression with the solved new axes, and create
 * a selection node (non-zero-condition ? new_compute_expression : 0).
43 *
44 * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
45 *
46 */
47 #include <dmlc/optional.h>
48 #include <tvm/arith/analyzer.h>
49 #include <tvm/arith/int_solver.h>
50 #include <tvm/runtime/registry.h>
51 #include <tvm/te/autodiff.h>
52 #include <tvm/tir/analysis.h>
53 #include <tvm/tir/stmt_functor.h>
54
55 #include <iterator>
56 #include <memory>
57 #include <utility>
58
59 #include "ad_util.h"
60
61 namespace tvm {
62 namespace te {
63
64 using arith::DivMode;
65 using arith::kFloorDiv;
66 using arith::kSimplifyRewriteCanonicalRewrite;
67 using arith::kTruncDiv;
68
69 // Combine all expressions from the container using &&.
70 template <class container>
All(const container & c)71 PrimExpr All(const container& c) {
72 PrimExpr res;
73 for (const auto& e : c) {
74 if (res.get()) {
75 res = res && e;
76 } else {
77 res = e;
78 }
79 }
80 if (res.get()) {
81 return res;
82 } else {
83 return const_true();
84 }
85 }
86
IterVarsToMap(const Array<IterVar> & itervars)87 Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
88 Map<Var, Range> res;
89 for (const IterVar& v : itervars) {
90 res.Set(v->var, v->dom);
91 }
92 return res;
93 }
94
95 // Given a map from vars to ranges create an array of itervars
IterVarsFromMap(const Array<Var> & vars,const Map<Var,Range> & vranges,IterVarType iter_type=kDataPar,std::string thread_tag="")96 Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
97 IterVarType iter_type = kDataPar, std::string thread_tag = "") {
98 Array<IterVar> res;
99 for (const Var& v : vars) {
100 CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
101 << vranges;
102 res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
103 }
104 return res;
105 }
106
// Collect the variable of every itervar, preserving the order of `itervars`.
Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
  Array<Var> vars;
  for (size_t i = 0; i < itervars.size(); ++i) {
    const IterVar iv = itervars[i];
    vars.push_back(iv->var);
  }
  return vars;
}
114
115 template <typename ValueType>
is_const_value(const PrimExpr & e,ValueType value)116 bool is_const_value(const PrimExpr& e, ValueType value) {
117 static_assert(std::is_integral<ValueType>::value,
118 "Comparison to non-integer values is forbidden.");
119 if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
120 return i->value == value;
121 } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
122 return i->value == value;
123 } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
124 return is_const_value(c->value, value);
125 } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
126 return is_const_value(b->value, value);
127 } else {
128 return false;
129 }
130 }
131
132 // Return true if this combiner is just a sum.
IsSumCombiner(const CommReducer & combiner,const Map<Var,Range> & vranges)133 bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
134 arith::Analyzer analyzer;
135 analyzer.Bind(vranges);
136 if (combiner->result.size() != 1) {
137 return false;
138 }
139
140 if (!is_const_value(
141 analyzer.Simplify(combiner->identity_element[0], kSimplifyRewriteCanonicalRewrite), 0)) {
142 return false;
143 }
144
145 PrimExpr combiner_result =
146 analyzer.Simplify(combiner->result[0], kSimplifyRewriteCanonicalRewrite);
147
148 return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
149 tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
150 }
151
CanFactorZeroFromCombiner(const CommReducer & combiner,int value_index,const Map<Var,Range> & vranges)152 bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
153 const Map<Var, Range>& vranges) {
154 arith::Analyzer analyzer;
155 analyzer.Bind(vranges);
156 if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
157 kSimplifyRewriteCanonicalRewrite),
158 0)) {
159 return false;
160 }
161
162 PrimExpr zero = make_zero(combiner->result[value_index].dtype());
163 PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
164 {combiner->rhs[value_index], zero}});
165 in = analyzer.Simplify(in, kSimplifyRewriteCanonicalRewrite);
166
167 return is_const_value(in, 0);
168 }
169
170 struct NonzeroConditionResult {
171 PrimExpr cond;
172 PrimExpr value;
173
to_exprtvm::te::NonzeroConditionResult174 PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
175
operator <<(std::ostream & os,const NonzeroConditionResult & r)176 friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
177 return os << r.to_expr();
178 }
179 };
180
181 // The implementation of NonzeroCondition
182 // transform expression to cond ? value : 0
183 class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
184 public:
NonzeroCondition(const PrimExpr & e)185 NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
186 if (e.dtype().is_bool()) {
187 // Boolean expressions are non-zero whenever they are true themselves
188 return {e, const_true()};
189 } else {
190 return VisitExpr(e);
191 }
192 }
193
194 // Most of the cases are implemented using helpers below
VisitExpr_(const VarNode * op)195 result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
VisitExpr_(const IntImmNode * op)196 result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
VisitExpr_(const FloatImmNode * op)197 result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
VisitExpr_(const StringImmNode * op)198 result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
VisitExpr_(const AddNode * op)199 result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
VisitExpr_(const SubNode * op)200 result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
VisitExpr_(const MulNode * op)201 result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
VisitExpr_(const DivNode * op)202 result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
VisitExpr_(const ModNode * op)203 result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
VisitExpr_(const FloorDivNode * op)204 result_type VisitExpr_(const FloorDivNode* op) final {
205 return BinOpDivLike_(GetRef<FloorDiv>(op));
206 }
VisitExpr_(const FloorModNode * op)207 result_type VisitExpr_(const FloorModNode* op) final {
208 return BinOpDivLike_(GetRef<FloorMod>(op));
209 }
VisitExpr_(const MinNode * op)210 result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
VisitExpr_(const MaxNode * op)211 result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
212
VisitExpr_(const CastNode * op)213 result_type VisitExpr_(const CastNode* op) final {
214 auto nz_a = NonzeroCondition(op->value);
215 return {nz_a.cond, Cast(op->dtype, nz_a.value)};
216 }
217
VisitExpr_(const SelectNode * op)218 result_type VisitExpr_(const SelectNode* op) final {
219 PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
220 auto nz_a = NonzeroCondition(true_val);
221 auto nz_b = NonzeroCondition(false_val);
222
223 // If the false part is zero, we can get rid of the select
224 if (is_const_value(nz_b.value, 0)) {
225 PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && cond, kSimplifyRewriteCanonicalRewrite);
226 return {new_cond, nz_a.value};
227 }
228
229 // If the true part is zero, we can also get rid of the select
230 if (is_const_value(nz_a.value, 0)) {
231 PrimExpr new_cond = analyzer_.Simplify(nz_b.cond && !cond, kSimplifyRewriteCanonicalRewrite);
232 return {new_cond, nz_b.value};
233 }
234
235 // Otherwise we retain the select and combine the conditions into this
236 PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
237 kSimplifyRewriteCanonicalRewrite);
238 if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
239 return {new_cond, GetRef<PrimExpr>(op)};
240 } else {
241 return {new_cond, Select(cond, nz_a.value, nz_b.value)};
242 }
243 }
244
VisitExpr_(const CallNode * op)245 result_type VisitExpr_(const CallNode* op) final {
246 if (op->op.same_as(op_if_then_else_)) {
247 PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
248 auto nz_a = NonzeroCondition(true_val);
249 auto nz_b = NonzeroCondition(false_val);
250
251 // We don't have as much freedom here as in the select case
252 // since the `if` must be preserved in any case
253 PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
254 kSimplifyRewriteCanonicalRewrite);
255 if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
256 return {new_cond, GetRef<PrimExpr>(op)};
257 } else {
258 return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
259 }
260 } else {
261 return Default_(GetRef<PrimExpr>(op));
262 }
263 }
264
VisitExpr_(const ProducerLoadNode * op)265 result_type VisitExpr_(const ProducerLoadNode* op) final {
266 return Default_(GetRef<PrimExpr>(op));
267 }
268
Default_(const PrimExpr & e)269 NonzeroConditionResult Default_(const PrimExpr& e) {
270 // This is always correct, so it's the default
271 return {const_true(), e};
272 }
273
274 template <class T>
Const_(const T & op)275 NonzeroConditionResult Const_(const T& op) {
276 if (op->value == 0) {
277 return {const_false(), op};
278 } else {
279 return {const_true(), op};
280 }
281 }
282
283 template <class T>
BinOpAddLike_(const T & op)284 NonzeroConditionResult BinOpAddLike_(const T& op) {
285 auto nz_a = NonzeroCondition(op->a);
286 auto nz_b = NonzeroCondition(op->b);
287
288 // For addition and similar ops the result may be nonzero if either of the arguments is
289 // nonzero, so we combine the conditions with Or.
290 if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
291 // If the conditions are the same, we don't need Or
292 if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
293 return {nz_a.cond, op};
294 } else {
295 return {nz_a.cond, T(nz_a.value, nz_b.value)};
296 }
297 } else {
298 // Otherwise use Or
299 PrimExpr new_cond =
300 analyzer_.Simplify(nz_a.cond || nz_b.cond, kSimplifyRewriteCanonicalRewrite);
301 // A little optimization: if the combined condition is the same as one of the inner
302 // conditions, we don't need to guard the inner value with a select, otherwise
303 // we create a select in the `to_expr` call.
304 PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
305 PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
306 PrimExpr new_expr = T(new_a, new_b);
307 return {new_cond, new_expr};
308 }
309 }
310
311 template <class T>
BinOpMulLike_(const T & op)312 NonzeroConditionResult BinOpMulLike_(const T& op) {
313 auto nz_a = NonzeroCondition(op->a);
314 auto nz_b = NonzeroCondition(op->b);
315
316 // For multiplication and similar ops the result may be nonzero if
317 // both the arguments are nonzero, so we combine with And.
318 PrimExpr new_cond =
319 analyzer_.Simplify(nz_a.cond && nz_b.cond, kSimplifyRewriteCanonicalRewrite);
320
321 if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
322 return {new_cond, op};
323 } else {
324 return {new_cond, T(nz_a.value, nz_b.value)};
325 }
326 }
327
328 template <class T>
BinOpDivLike_(const T & op)329 NonzeroConditionResult BinOpDivLike_(const T& op) {
330 auto nz_a = NonzeroCondition(op->a);
331
332 // For Div we simply use the condition of the numerator.
333
334 if (nz_a.value.same_as(op->a)) {
335 return {nz_a.cond, op};
336 } else {
337 return {nz_a.cond, T(nz_a.value, op->b)};
338 }
339 }
340
341 private:
342 arith::Analyzer analyzer_;
343 const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
344 };
345
NonzeronessCondition(const PrimExpr & expr)346 inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
347 return NonzeroConditionFunctor().NonzeroCondition(expr);
348 }
349
350 struct FactorOutAtomicFormulasResult {
351 std::vector<PrimExpr> atomic_formulas;
352 PrimExpr rest;
353
to_exprtvm::te::FactorOutAtomicFormulasResult354 PrimExpr to_expr() const {
355 PrimExpr res = rest;
356 for (const PrimExpr& e : atomic_formulas) {
357 res = And(e, res);
358 }
359 return res;
360 }
361
to_arraytvm::te::FactorOutAtomicFormulasResult362 Array<PrimExpr> to_array() const {
363 Array<PrimExpr> res = atomic_formulas;
364 res.push_back(rest);
365 return res;
366 }
367 };
368
369 // The implementation of FactorOutAtomicFormulas
370 class FactorOutAtomicFormulasFunctor
371 : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
372 public:
Atomic_(const PrimExpr & e)373 result_type Atomic_(const PrimExpr& e) {
374 // For atomic expressions the result is the expr itself with True as the residual
375 return {{e}, make_const(e.dtype(), 1)};
376 }
377
378 // This is basically the list of expression kinds that are considered atomic
VisitExpr_(const VarNode * op)379 result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const CallNode * op)380 result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const IntImmNode * op)381 result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const EQNode * op)382 result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const NENode * op)383 result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LENode * op)384 result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LTNode * op)385 result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GENode * op)386 result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GTNode * op)387 result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
388
VisitExpr_(const SelectNode * op)389 result_type VisitExpr_(const SelectNode* op) final {
390 // Select can be rewritten through other logical ops
391 PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
392 return VisitExpr(expr);
393 }
394
VisitExpr_(const NotNode * op)395 result_type VisitExpr_(const NotNode* op) final {
396 // Not should be moved down
397 if (const OrNode* or_expr = op->a.as<OrNode>()) {
398 PrimExpr expr = !or_expr->a && !or_expr->b;
399 return VisitExpr(expr);
400 } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
401 PrimExpr expr = !and_expr->a || !and_expr->b;
402 return VisitExpr(expr);
403 } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
404 PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
405 (sel_expr->condition || !sel_expr->false_value));
406 return VisitExpr(expr);
407 }
408 return Atomic_(GetRef<PrimExpr>(op));
409 }
410
VisitExpr_(const AndNode * op)411 result_type VisitExpr_(const AndNode* op) final {
412 auto res_a = VisitExpr(op->a);
413 auto res_b = VisitExpr(op->b);
414
415 // For the And case we return the union of the sets of atomic formulas
416 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
417 res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
418 std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
419 std::inserter(res_set, res_set.end()));
420 std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
421 std::inserter(res_set, res_set.end()));
422
423 std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
424
425 // And the residuals are combined with &&
426 return {res, res_a.rest && res_b.rest};
427 }
428
VisitExpr_(const MulNode * op)429 result_type VisitExpr_(const MulNode* op) final {
430 // Since we work with bools, for multiplication we do the same thing as for And
431 PrimExpr e_and = op->a && op->b;
432 return VisitExpr(e_and);
433 }
434
VisitExpr_(const OrNode * op)435 result_type VisitExpr_(const OrNode* op) final {
436 auto res_a = VisitExpr(op->a);
437 auto res_b = VisitExpr(op->b);
438
439 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
440 res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
441 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
442 res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
443
444 // For the Or case we intersect the sets of atomic formulas
445 std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
446 res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
447 for (const auto& res_b_formula : res_b_set) {
448 if (res_a_set.count(res_b_formula)) {
449 res_set.insert(res_b_formula);
450 }
451 }
452
453 // Computing the residual is more complex: we have to compute the sets of atomic formulas
454 // which are left behind, and then combine them with the residuals into the new residual.
455 std::vector<PrimExpr> new_cond_a;
456 new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
457 for (const auto& formula : res_a_set) {
458 if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
459 }
460
461 std::vector<PrimExpr> new_cond_b;
462 new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
463 for (const auto& formula : res_b_set) {
464 if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
465 }
466
467 res_a.atomic_formulas = std::move(new_cond_a);
468 res_b.atomic_formulas = std::move(new_cond_b);
469
470 PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
471 std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
472
473 return {res, new_rest};
474 }
475 };
476
477 // Transform the given formula into a conjunction of atomic formulas (represented as an array)
478 // and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
479 // etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
FactorOutAtomicFormulas(const PrimExpr & e)480 FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
481 CHECK(e.dtype().is_bool());
482 return FactorOutAtomicFormulasFunctor().VisitExpr(e);
483 }
484
// Everything produced by EliminateDivMod for one input expression.
struct EliminateDivModResult {
  // The input expression after the mutator has been applied to it.
  PrimExpr expr;
  // For every newly created variable, the div/mod expression it stands for.
  Map<Var, PrimExpr> substitution;
  // The newly created variables (a div variable and a mod variable per eliminated pair).
  Array<Var> new_variables;
  // Extra conditions that define the new variables, e.g. `e == div * c + mod`.
  Array<PrimExpr> conditions;
  // The ranges passed in, extended with the ranges inferred for the new variables.
  Map<Var, Range> ranges;
};
492
ModImpl(PrimExpr a,PrimExpr b,DivMode mode)493 inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
494 if (mode == kTruncDiv) {
495 return truncmod(a, b);
496 } else {
497 CHECK_EQ(mode, kFloorDiv);
498 return floormod(a, b);
499 }
500 }
501
DivImpl(PrimExpr a,PrimExpr b,DivMode mode)502 inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
503 if (mode == kTruncDiv) {
504 return truncdiv(a, b);
505 } else {
506 CHECK_EQ(mode, kFloorDiv);
507 return floordiv(a, b);
508 }
509 }
510
511 class EliminateDivModMutator : public ExprMutator {
512 public:
513 Map<Var, PrimExpr> substitution;
514 Array<Var> new_variables;
515 Array<PrimExpr> conditions;
516 Map<Var, Range> ranges;
517
EliminateDivModMutator(Map<Var,Range> ranges)518 explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
519
VisitExpr_(const DivNode * op)520 virtual PrimExpr VisitExpr_(const DivNode* op) {
521 const IntImmNode* imm = op->b.as<IntImmNode>();
522 if (imm && imm->value != 0) {
523 if (imm->value < 0) {
524 // x / -c == -(x/c) for truncated division
525 return make_zero(op->dtype) -
526 VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
527 }
528
529 // Try to find the already existing variables for this expression
530 auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
531 if (it != expr_to_vars_.end()) {
532 return it->second.first;
533 }
534
535 // Otherwise recursively mutate the left hand side, and create new variables
536 PrimExpr mutated_a = VisitExpr(op->a);
537 if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
538 return var_pair_opt.value().first;
539 } else {
540 return truncdiv(mutated_a, op->b);
541 }
542 }
543
544 return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
545 }
546
VisitExpr_(const ModNode * op)547 virtual PrimExpr VisitExpr_(const ModNode* op) {
548 const IntImmNode* imm = op->b.as<IntImmNode>();
549 if (imm && imm->value != 0) {
550 if (imm->value < 0) {
551 // x % -c == x % c for truncated division
552 return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
553 }
554
555 // Try to find the already existing variables for this expression
556 auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
557 if (it != expr_to_vars_.end()) {
558 return it->second.second;
559 }
560
561 // Otherwise recursively mutate the left hand side, and create new variables
562 PrimExpr mutated_a = VisitExpr(op->a);
563 if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
564 return var_pair_opt.value().second;
565 } else {
566 return truncmod(mutated_a, op->b);
567 }
568 }
569
570 return truncmod(VisitExpr(op->a), VisitExpr(op->b));
571 }
572
VisitExpr_(const FloorDivNode * op)573 virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
574 const IntImmNode* imm = op->b.as<IntImmNode>();
575 if (imm && imm->value != 0) {
576 if (imm->value < 0) {
577 // x / -c == (-x) / c for flooring division
578 return VisitExpr(
579 floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
580 }
581
582 // Try to find the already existing variables for this expression
583 auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
584 if (it != expr_to_vars_.end()) {
585 return it->second.first;
586 }
587
588 // Otherwise recursively mutate the left hand side, and create new variables
589 PrimExpr mutated_a = VisitExpr(op->a);
590 if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
591 return var_pair_opt.value().first;
592 } else {
593 return floordiv(mutated_a, op->b);
594 }
595 }
596
597 return floordiv(VisitExpr(op->a), VisitExpr(op->b));
598 }
599
VisitExpr_(const FloorModNode * op)600 virtual PrimExpr VisitExpr_(const FloorModNode* op) {
601 const IntImmNode* imm = op->b.as<IntImmNode>();
602 if (imm && imm->value != 0) {
603 if (imm->value < 0) {
604 // x % -c == -(-x % c) for flooring division
605 return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
606 make_const(op->dtype, -imm->value)));
607 }
608
609 // Try to find the already existing variables for this expression
610 auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
611 if (it != expr_to_vars_.end()) {
612 return it->second.second;
613 }
614
615 // Otherwise recursively mutate the left hand side, and create new variables
616 PrimExpr mutated_a = VisitExpr(op->a);
617 if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
618 return var_pair_opt.value().second;
619 } else {
620 return floormod(mutated_a, op->b);
621 }
622 }
623
624 return floormod(VisitExpr(op->a), VisitExpr(op->b));
625 }
626
627 private:
AddNewVarPair(const PrimExpr & e,const PrimExpr & mut,int64_t val,DivMode mode)628 dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
629 int64_t val, DivMode mode) {
630 using tresult = dmlc::optional<std::pair<Var, Var>>;
631
632 // Try to find the variables using the mutated expressions
633 if (!e.same_as(mut)) {
634 auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
635 if (it != expr_to_vars_.end()) {
636 return tresult(it->second);
637 }
638 }
639
640 PrimExpr val_e = make_const(e.dtype(), val);
641 idx_ += 1;
642
643 // Convert `ranges` to IntSets
644 std::unordered_map<const VarNode*, IntSet> var_intsets;
645 for (const auto& p : ranges) {
646 var_intsets[p.first.get()] = IntSet::FromRange(p.second);
647 }
648
649 // Infer ranges for the expressions we want to replace with variables
650 Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
651 Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
652
653 // We don't want to add unbounded variables
654 if (!div_range.get() || !mod_range.get()) {
655 LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
656 << " because its bounds cannot be inferred";
657 return tresult();
658 }
659 if (!mod_range.get()) {
660 LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
661 << " because its bounds cannot be inferred";
662 return tresult();
663 }
664
665 // Create new variables for the expressions
666 auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
667 auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
668
669 new_variables.push_back(div);
670 new_variables.push_back(mod);
671
672 // Note that we have to perform substitution to mut because mut may contain new variables
673 substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
674 substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
675
676 ranges.Set(div, div_range);
677 ranges.Set(mod, mod_range);
678
679 // This additional condition works as a definition for the new variables
680 conditions.push_back(mut == div * val_e + mod);
681
682 if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
683 // If we use the C/C++ definition of mod, there may be multiple values of `mod`
684 // satisfying the added condition if the expr `e` may change its sign, so we
685 // have to add another condition.
686 LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
687 << ModImpl(e, val_e, mode) << " probably may change its sign";
688 conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
689 }
690
691 auto p = std::make_pair(div, mod);
692 expr_to_vars_[std::make_tuple(mode, e, val)] = p;
693 if (!e.same_as(mut)) {
694 expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
695 }
696 return tresult(p);
697 }
698
699 class TupleEqual_ {
700 public:
operator ()(const std::tuple<DivMode,PrimExpr,int64_t> & lhs,const std::tuple<DivMode,PrimExpr,int64_t> & rhs) const701 bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
702 const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
703 return std::get<0>(lhs) == std::get<0>(rhs) &&
704 tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
705 std::get<2>(lhs) == std::get<2>(rhs);
706 }
707 };
708
709 class TupleHasher_ {
710 public:
operator ()(const std::tuple<DivMode,PrimExpr,int64_t> & key) const711 size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
712 return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
713 1) ^
714 (std::hash<int64_t>()(std::get<2>(key)) << 1);
715 }
716 };
717
718 // A counter for naming new variables
719 int idx_{0};
720 // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
721 // such that `div = e / n` and `mod = e % n`
722 std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
723 TupleEqual_>
724 expr_to_vars_;
725 arith::Analyzer analyzer_;
726 };
727
728 // Replace every subexpr of the form e/const and e % const with a new variable.
729 // Syntactically equal expressions will be mapped to the same variable.
EliminateDivMod(const PrimExpr & expr,Map<Var,Range> ranges)730 EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
731 EliminateDivModResult res;
732 EliminateDivModMutator mutator(ranges);
733 res.expr = mutator(expr);
734 res.conditions = std::move(mutator.conditions);
735 res.new_variables = std::move(mutator.new_variables);
736 res.substitution = std::move(mutator.substitution);
737 res.ranges = std::move(mutator.ranges);
738 return res;
739 }
740
// Eliminate div/mod by constants from the relations of `domain`. The returned
// transformation goes from `domain` to a domain extended with the new div/mod
// variables, their ranges and their defining conditions.
arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
    const arith::IntConstraints& domain) {
  auto eliminated = EliminateDivMod(All(domain->relations), domain->ranges);

  // The new domain keeps the old variables and appends the created ones; its
  // relations are the rewritten condition together with the variable definitions.
  Array<Var> extended_axis = Concat(domain->variables, eliminated.new_variables);
  Map<Var, Range> extended_ranges = eliminated.ranges;
  PrimExpr extended_cond = eliminated.expr && All(eliminated.conditions);

  arith::IntConstraints new_domain(extended_axis, extended_ranges,
                                   FactorOutAtomicFormulas(extended_cond).to_array());

  // Old variables map to themselves in both directions; new variables are mapped
  // back through the substitution recorded during elimination.
  Map<Var, PrimExpr> dst_to_src = eliminated.substitution;
  Map<Var, PrimExpr> src_to_dst;
  for (const Var& var : domain->variables) {
    src_to_dst.Set(var, var);
    dst_to_src.Set(var, var);
  }

  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
}
761
// A transformation from `domain` to itself that maps every variable to itself.
inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
  Map<Var, PrimExpr> var_to_self;
  for (size_t i = 0; i < domain->variables.size(); ++i) {
    const Var var = domain->variables[i];
    var_to_self.Set(var, var);
  }
  return arith::IntConstraintsTransform(domain, domain, var_to_self, var_to_self);
}
769
770 // Simplify an iteration domain.
SimplifyDomain(const arith::IntConstraints & iter_domains,bool eliminate_div_mod)771 arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
772 bool eliminate_div_mod) {
773 arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
774
775 if (eliminate_div_mod) {
776 transf = transf + EliminateDivModFromDomainConditions(transf->dst);
777 }
778
779 // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
780 // should find a better terminating criterion (like stop when the domain volume stops decreasing)
781 // Also 2 steps seems to be slightly better than 3
782 for (size_t i = 0; i < 2; ++i) {
783 transf = transf + arith::SolveLinearEquations(transf->dst);
784 transf = transf + arith::SolveInequalitiesDeskewRange(transf->dst);
785 }
786
787 return transf;
788 }
789
790 // Use the condition of a reduction op to simplify its domain (axis)
SimplifyReductionDomain(const PrimExpr & expr,const Map<Var,Range> & outer_vranges)791 PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
792 if (const ReduceNode* red = expr.as<ReduceNode>()) {
793 Array<Var> vars = IterVarsToVars(red->axis);
794 Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
795 Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
796
797 arith::IntConstraints domain(vars, vranges, relations);
798 auto res = SimplifyDomain(domain);
799
800 Array<PrimExpr> new_source;
801 for (const PrimExpr& src : red->source) {
802 new_source.push_back(Substitute(src, res->src_to_dst));
803 }
804
805 Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
806
807 // Perform simplification mainly to remove a possibly empty reduction.
808 arith::Analyzer analyzer;
809 return analyzer.Simplify(Reduce(red->combiner, new_source, new_axis, All(res->dst->relations),
810 red->value_index, red->init),
811 kSimplifyRewriteCanonicalRewrite);
812 } else {
813 return expr;
814 }
815 }
816
817 // Extract from cond an implication of cond not containing vars
ImplicationNotContainingVars(const PrimExpr & cond,const std::unordered_set<const VarNode * > & vars)818 std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
819 const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
820 CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
821 // TODO(sgrechanik-h): NOTs could be pushed down using De Morgan laws
822 // before running this function but this case didn't seem to be important enough.
823 if (const AndNode* op = cond.as<AndNode>()) {
824 auto pair_a = ImplicationNotContainingVars(op->a, vars);
825 auto pair_b = ImplicationNotContainingVars(op->b, vars);
826 return {pair_a.first && pair_b.first, pair_a.second && pair_b.second};
827 } else if (const OrNode* op = cond.as<OrNode>()) {
828 auto pair_a = ImplicationNotContainingVars(op->a, vars);
829 auto pair_b = ImplicationNotContainingVars(op->b, vars);
830 return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
831 (pair_b.first || pair_a.second) &&
832 (pair_a.second || pair_b.second)};
833 } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
834 return {cond, const_true()};
835 } else {
836 return {const_true(), cond};
837 }
838 }
839
840 // Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out
841 // (in)equalities which do not depend on the reduction variables.
LiftConditionsThroughReduction(const PrimExpr & cond,const Array<IterVar> & red_axis,const Array<IterVar> & outer_axis)842 std::pair<PrimExpr, PrimExpr> LiftConditionsThroughReduction(const PrimExpr& cond,
843 const Array<IterVar>& red_axis,
844 const Array<IterVar>& outer_axis) {
845 // Factor out atomics so that we can consider this as a system of inequalities
846 auto factor_atomic_res = FactorOutAtomicFormulas(cond);
847 Array<PrimExpr> atomics = factor_atomic_res.atomic_formulas;
848 const PrimExpr& rest = factor_atomic_res.rest;
849
850 Array<Var> allvars;
851 for (const IterVar& v : red_axis) {
852 allvars.push_back(v->var);
853 }
854 for (const IterVar& v : outer_axis) {
855 allvars.push_back(v->var);
856 }
857
858 auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis));
859 // start from reduction vars, so that input vars don't depend on them
860 arith::IntConstraints ineq_to_solve(allvars, vranges, atomics);
861 auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve);
862 atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second);
863
864 // Append the rest part
865 PrimExpr rewritten_cond = All(atomics) && rest;
866
867 std::unordered_set<const VarNode*> vset;
868 for (const IterVar& v : red_axis) {
869 vset.insert(v->var.get());
870 }
871
872 // The outer (first) condition does not contain reduction vars,
873 // the inner (second) condition is everything else
874 auto res = ImplicationNotContainingVars(rewritten_cond, vset);
875 return res;
876 }
877
878 // Convert an array of itervars to an array of inequalities
IterVarsToInequalities(const Array<IterVar> & itervars)879 Array<PrimExpr> IterVarsToInequalities(const Array<IterVar>& itervars) {
880 Array<PrimExpr> res;
881 for (const IterVar& v : itervars) {
882 res.push_back(GE(v->var, v->dom->min));
883 res.push_back(LT(v->var, v->dom->min + v->dom->extent));
884 }
885 return res;
886 }
887
888 class RemoveRedundantInequalitiesMutator : public ExprMutator {
889 public:
RemoveRedundantInequalitiesMutator(Array<PrimExpr> known)890 explicit RemoveRedundantInequalitiesMutator(Array<PrimExpr> known) {
891 for (const PrimExpr& cond : known) {
892 known_.push_back(analyzer_.Simplify(cond, kSimplifyRewriteCanonicalRewrite));
893 }
894 }
895
VisitExpr_(const SelectNode * op)896 virtual PrimExpr VisitExpr_(const SelectNode* op) {
897 bool has_side_effect = (SideEffect(GetRef<PrimExpr>(op)) > CallEffectKind::kReadState);
898 PrimExpr new_cond =
899 analyzer_.Simplify(VisitExpr(op->condition), kSimplifyRewriteCanonicalRewrite);
900 if (is_one(new_cond) && !has_side_effect) {
901 return VisitExpr(op->true_value);
902 } else if (is_zero(new_cond) && !has_side_effect) {
903 return VisitExpr(op->false_value);
904 } else {
905 Array<PrimExpr> new_known = known_;
906 for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
907 new_known.push_back(atomic);
908 }
909 RemoveRedundantInequalitiesMutator new_mutator(new_known);
910 // Note that we mutate only the true value with the new mutator
911 // TODO(sgrechanik-h): Update known conditions for the false value as well
912 return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value));
913 }
914 }
915
VisitExpr_(const CallNode * op)916 virtual PrimExpr VisitExpr_(const CallNode* op) {
917 if (op->op.same_as(op_if_then_else_)) {
918 PrimExpr new_cond =
919 analyzer_.Simplify(VisitExpr(op->args[0]), kSimplifyRewriteCanonicalRewrite);
920 if (is_one(new_cond)) {
921 return VisitExpr(op->args[1]);
922 } else if (is_zero(new_cond)) {
923 return VisitExpr(op->args[2]);
924 } else {
925 Array<PrimExpr> new_known = known_;
926 for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
927 new_known.push_back(atomic);
928 }
929 RemoveRedundantInequalitiesMutator new_mutator(new_known);
930 // Note that we mutate only the true value with the new mutator
931 // TODO(sgrechanik-h): Update known conditions for the false value as well
932 return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2]));
933 }
934 } else {
935 return ExprMutator::VisitExpr_(op);
936 }
937 }
938
VisitExpr_(const ReduceNode * op)939 virtual PrimExpr VisitExpr_(const ReduceNode* op) {
940 Array<PrimExpr> known_with_axes = known_;
941 CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
942 for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
943 known_with_axes.push_back(axis_cond);
944 }
945 RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes);
946
947 PrimExpr new_cond = mutator_with_axes(op->condition);
948
949 Array<PrimExpr> new_known = known_with_axes;
950 for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
951 new_known.push_back(atomic);
952 }
953 RemoveRedundantInequalitiesMutator new_mutator(new_known);
954
955 Array<PrimExpr> new_source;
956 for (const PrimExpr& src : op->source) {
957 new_source.push_back(new_mutator(src));
958 }
959
960 return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index, op->init);
961 }
962
VisitExpr_(const EQNode * op)963 virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const NENode * op)964 virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LTNode * op)965 virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const LENode * op)966 virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GTNode * op)967 virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
VisitExpr_(const GENode * op)968 virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
969
VisitExpr_(const AndNode * op)970 virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); }
971
972 private:
MutateAtomic_(const PrimExpr & e)973 PrimExpr MutateAtomic_(const PrimExpr& e) {
974 PrimExpr simplified = analyzer_.Simplify(e, kSimplifyRewriteCanonicalRewrite);
975 for (const PrimExpr& other : known_) {
976 if (ExprDeepEqual()(simplified, other)) {
977 return const_true();
978 }
979 }
980 return simplified;
981 }
982
983 Array<PrimExpr> known_;
984 arith::Analyzer analyzer_;
985 const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
986 };
987
988 // Propagate information from conditions and remove redundant inequalities
RemoveRedundantInequalities(const PrimExpr & expr,const Array<PrimExpr> & known)989 inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array<PrimExpr>& known) {
990 return RemoveRedundantInequalitiesMutator(known)(expr);
991 }
992
993 // Extract the given expr under the given condition as a separate tensor if the volume of the
994 // extracted tensor will be less than the volume of the outer_axis
TrySimplifyCompute(const PrimExpr & expr,const PrimExpr & cond,const Array<Var> & outer_axis,const Map<Var,Range> & vranges)995 PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
996 const Array<Var>& outer_axis, const Map<Var, Range>& vranges) {
997 // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j)
998 arith::IntConstraints domain_to_solve(outer_axis, vranges,
999 FactorOutAtomicFormulas(cond).to_array());
1000 auto res = SimplifyDomain(domain_to_solve);
1001
1002 arith::Analyzer analyzer;
1003 analyzer.Bind(res->dst->ranges);
1004 PrimExpr new_expr =
1005 analyzer.Simplify(Substitute(expr, res->src_to_dst), kSimplifyRewriteCanonicalRewrite);
1006 // TODO(yzhliu): This is mostly done to simplify if_then_else
1007 // which is not realized by the canonical simplifier
1008 new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations);
1009
1010 // Keep only those variables of the new vars which are used in the new_expr
1011 Array<Var> used_res_variables;
1012 for (const Var& var : res->dst->variables) {
1013 if (ExprUseVar(new_expr, var)) {
1014 CHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
1015 used_res_variables.push_back(var);
1016 }
1017 }
1018
1019 // If the expression does not use vars then it is probably better to keep it inlined
1020 if (used_res_variables.empty()) {
1021 // We can return the new_expr here instead of the old expr because it doesn't use variables
1022 // otherwise we would need to replace the new vars or create a let-expression
1023 return new_expr;
1024 }
1025
1026 // If it's already tensor[...] then it will probably be useless to further simplify it.
1027 if (new_expr.as<ProducerLoadNode>()) {
1028 return expr;
1029 }
1030
1031 // Compute volumes before and after
1032 PrimExpr old_volume = make_const(DataType::Int(64), 1);
1033 for (const Var& var : outer_axis) {
1034 CHECK(vranges.count(var)) << "Range of " << var << " was not provided.";
1035 old_volume = old_volume * vranges[var]->extent;
1036 }
1037
1038 PrimExpr new_volume = make_const(DataType::Int(64), 1);
1039 for (const Var& var : used_res_variables) {
1040 new_volume = new_volume * res->dst->ranges[var]->extent;
1041 }
1042
1043 // if we can prove that the old volume is not greater than the new volume then
1044 // prefer the old expression.
1045 arith::Analyzer ana_vranges;
1046 ana_vranges.Bind(vranges);
1047 if (ana_vranges.CanProve(old_volume <= new_volume)) {
1048 return expr;
1049 }
1050
1051 Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges),
1052 "extracted_tensor");
1053
1054 Array<PrimExpr> args;
1055 for (const Var& var : used_res_variables) {
1056 args.push_back(res->dst_to_src[var]);
1057 }
1058
1059 return ProducerLoad(tensor, args);
1060 }
1061
1062 class ReductionAsTensorAccessMutator : public ExprMutator {
1063 public:
ReductionAsTensorAccessMutator(const Array<Var> & outer_axis,Map<Var,Range> vranges,std::string name="extracted_reduction")1064 explicit ReductionAsTensorAccessMutator(const Array<Var>& outer_axis, Map<Var, Range> vranges,
1065 std::string name = "extracted_reduction")
1066 : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {}
1067
VisitExpr_(const ReduceNode * op)1068 PrimExpr VisitExpr_(const ReduceNode* op) final {
1069 ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
1070 Merge(vranges_, IterVarsToMap(op->axis)), name_);
1071
1072 CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
1073 Array<PrimExpr> new_source;
1074 for (const PrimExpr& src : op->source) {
1075 new_source.push_back(new_mutator(src));
1076 }
1077
1078 PrimExpr new_reduce =
1079 Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index, op->init);
1080
1081 Array<Var> undefined_vars = UndefinedVars(new_reduce);
1082 std::unordered_set<const VarNode*> undefined_var_set;
1083 for (const Var& var : undefined_vars) {
1084 undefined_var_set.insert(var.get());
1085 }
1086
1087 // Vars of the tensor we are going to create for this reduction
1088 Array<Var> vars;
1089 for (const Var& v : outer_axis_) {
1090 // We take variables from the outer_axis_ which are also present in the new reduction
1091 if (undefined_var_set.count(v.get())) {
1092 vars.push_back(v);
1093 }
1094 }
1095
1096 auto new_axis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_));
1097 Array<IterVar> new_axis = new_axis_vmap_pair.first;
1098 arith::Analyzer analyzer;
1099 analyzer.Bind(IterVarsToMap(new_axis));
1100 new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second),
1101 kSimplifyRewriteCanonicalRewrite);
1102
1103 Tensor tensor = TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_);
1104
1105 Array<PrimExpr> args;
1106 for (const Var& v : vars) {
1107 args.push_back(v);
1108 }
1109
1110 return ProducerLoad(tensor, args);
1111 }
1112
1113 private:
1114 Array<Var> outer_axis_;
1115 Map<Var, Range> vranges_;
1116 std::string name_;
1117 std::string tag_;
1118 Map<String, ObjectRef> attrs_;
1119 };
1120
1121 // Extract reductions as separate tensors.
ReductionAsTensorAccess(const PrimExpr & expr,const Array<Var> & outer_axis,const Map<Var,Range> & vranges)1122 inline PrimExpr ReductionAsTensorAccess(const PrimExpr& expr, const Array<Var>& outer_axis,
1123 const Map<Var, Range>& vranges) {
1124 return ReductionAsTensorAccessMutator(outer_axis, vranges)(expr);
1125 }
1126
/*!
 * \brief Replace reductions nested inside an expression with tensor accesses,
 *        keeping a reduction at the top level in place.
 *
 * \param expr The expression to rewrite.
 * \param outer_axis Variables defined outside of \p expr.
 * \param vranges Ranges of the variables.
 * \return When \p expr is a reduction, a new Reduce with the same combiner and
 *         axis whose sources and condition had their reductions extracted;
 *         otherwise \p expr with all of its reductions extracted.
 */
PrimExpr LiftReductions(const PrimExpr& expr, const Array<Var>& outer_axis,
                        const Map<Var, Range>& vranges) {
  const ReduceNode* reduce = expr.as<ReduceNode>();
  if (reduce == nullptr) {
    return ReductionAsTensorAccess(expr, outer_axis, vranges);
  }

  // Inside the top-level reduction its own axis counts as outer variables.
  Array<Var> inner_outer_axis = Concat(IterVarsToVars(reduce->axis), outer_axis);
  Map<Var, Range> inner_vranges = Merge(vranges, IterVarsToMap(reduce->axis));

  Array<PrimExpr> lifted_sources;
  for (const PrimExpr& source : reduce->source) {
    lifted_sources.push_back(ReductionAsTensorAccess(source, inner_outer_axis, inner_vranges));
  }
  PrimExpr lifted_condition =
      ReductionAsTensorAccess(reduce->condition, inner_outer_axis, inner_vranges);

  return Reduce(reduce->combiner, lifted_sources, reduce->axis, lifted_condition,
                reduce->value_index, reduce->init);
}
1143
/*!
 * \brief Simplify the body of a compute by lifting nonzeroness conditions out
 *        of it and extracting the remaining computation where profitable.
 *
 * For a reduction whose combiner is a sum, or from which zero can be factored,
 * the reduction domain is simplified, conditions that do not depend on the
 * reduction variables are lifted outside of the reduction, and the result is
 * wrapped in a Select on the lifted condition. Other reductions only get their
 * domain simplified. Non-reduction expressions are wrapped in a Select on
 * their nonzeroness condition.
 *
 * \param expr_orig The body expression.
 * \param axis The axis of the compute the body belongs to.
 * \param vranges Ranges of additional outer variables.
 * \return The simplified body.
 */
PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const Array<IterVar>& axis,
                                              const Map<Var, Range>& vranges) {
  PrimExpr result;
  Map<Var, Range> combined_vranges = Merge(vranges, IterVarsToMap(axis));
  arith::Analyzer analyzer;
  analyzer.Bind(combined_vranges);

  // Simplify the original expression first, mostly to simplify combiners
  PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite);

  if (const ReduceNode* red = expr.as<ReduceNode>()) {
    CHECK(red->init.empty()) << "Derivative of Reduction with initialization is not implemented";
    // TODO(sgrechanik-h): There are some other operations which behave like sum
    bool is_sum = IsSumCombiner(red->combiner, vranges);
    if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
      // Here we simplify the reduction
      PrimExpr cond = red->condition;
      Array<PrimExpr> source = red->source;

      // If it is a summation then we can lift nonzeroness conditions from the source
      // and add them to the reduction conditions
      if (is_sum) {
        auto nz = NonzeronessCondition(red->source[red->value_index]);
        cond = nz.cond && cond;
        // Write the stripped value back to the slot it was read from
        // (value_index), matching the non-sum path below.
        source.Set(red->value_index, nz.value);
      }

      PrimExpr new_red =
          Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init);
      new_red = SimplifyReductionDomain(new_red, combined_vranges);
      // Update original red pointer for later use.
      red = new_red.as<ReduceNode>();
      // If the reduction disappears completely then transform the result as a non-reduction
      if (!red) {
        return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
      }

      PrimExpr new_outer_cond, new_reduce_cond;
      Array<PrimExpr> new_source = red->source;

      // Partially lift conditions from the reduce condition
      std::tie(new_outer_cond, new_reduce_cond) =
          LiftConditionsThroughReduction(red->condition, red->axis, axis);

      // If it's not sum then we haven't yet lifted nonzeroness cond from the source
      if (!is_sum) {
        PrimExpr outer_nz_cond, nz_cond, nz_source;
        auto nz = NonzeronessCondition(red->source[red->value_index]);
        // Append conditions from the reduction
        nz_cond = new_reduce_cond && nz.cond;
        nz_source = nz.value;
        std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
        new_outer_cond = new_outer_cond && outer_nz_cond;
        // The part of the condition that still depends on the reduction
        // variables stays inside the source as a Select.
        new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
      }

      PrimExpr new_reduce = Reduce(red->combiner, new_source, red->axis, new_reduce_cond,
                                   red->value_index, red->init);
      new_reduce =
          TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges);
      result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype()));
    } else {
      return SimplifyReductionDomain(expr, combined_vranges);
    }
  } else {
    auto nz = NonzeronessCondition(expr);
    PrimExpr new_expr =
        TrySimplifyCompute(nz.value, nz.cond, IterVarsToVars(axis), combined_vranges);
    result = Select(nz.cond, new_expr, make_zero(new_expr.dtype()));
  }

  // Note that RemoveRedundantInequalities can sometimes propagate equalities which
  // other simplifiers cannot, like (i % 3) == 0.
  Array<PrimExpr> axis_conds = IterVarsToInequalities(axis);
  result = RemoveRedundantInequalities(result, axis_conds);

  // Currently in TVM reductions are only allowed at the top level of compute,
  // we need to extract intermediate inlined reduction as a separate stage (tensor).
  // Sometimes TrySimplifyCompute doesn't perform lift / extraction,
  // so there may be some non-top reductions left, take care of them.
  result = LiftReductions(result, IterVarsToVars(axis), combined_vranges);
  return analyzer.Simplify(result, kSimplifyRewriteCanonicalRewrite);
}
1228
/*!
 * \brief Apply RemoveJacobianAndLiftNonzeroCondImpl to the body of a tensor.
 * \param tensor The tensor whose body is transformed.
 * \param vranges Ranges of additional outer variables, forwarded unchanged.
 * \return The tensor returned by TransformTensorBody.
 */
Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, const Map<Var, Range>& vranges) {
  // vranges is captured by reference; the callback is only handed to
  // TransformTensorBody within this call.
  auto simplify_body = [&vranges](const PrimExpr& body, const Array<IterVar>& body_axis) {
    return RemoveJacobianAndLiftNonzeroCondImpl(body, body_axis, vranges);
  };
  return TransformTensorBody(tensor, simplify_body);
}
1235
1236 } // namespace te
1237 } // namespace tvm
1238