1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file stmt_simplify.cc
22  * \brief Statement simplifier based on analyzer
23  */
24 #include <tvm/ir.h>
25 #include <tvm/ir_pass.h>
26 #include <tvm/arithmetic.h>
27 #include <tvm/ir_mutator.h>
28 #include <tvm/expr_operator.h>
29 #include <tvm/arithmetic.h>
30 #include "ir_mutator_with_analyzer.h"
31 
32 namespace tvm {
33 namespace arith {
34 
35 using namespace ir;
36 
37 class StmtSimplifier : public IRMutatorWithAnalyzer {
38  public:
StmtSimplifier(Analyzer * analyzer)39   explicit StmtSimplifier(Analyzer* analyzer)
40       : IRMutatorWithAnalyzer(analyzer) {}
41 
42   using Parent = IRMutatorWithAnalyzer;
43   using Parent::Mutate;
44   using Parent::Mutate_;
45 
Mutate(Expr expr)46   Expr Mutate(Expr expr) final {
47     return analyzer_->Simplify(expr);
48   }
49 
Simplify(Stmt stmt)50   Stmt Simplify(Stmt stmt) {
51     return Mutate(stmt);
52   }
53 
Mutate_(const For * op,const Stmt & s)54   Stmt Mutate_(const For* op, const Stmt& s) final {
55     analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
56     With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
57     With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
58     return IRMutator::Mutate_(op, s);
59   }
60 
Mutate_(const LetStmt * op,const Stmt & s)61   Stmt Mutate_(const LetStmt* op, const Stmt& s) {
62     Expr value = this->Mutate(op->value);
63     if (!ir::HasSideEffect(value)) {
64       // it is fine to discard the let binding
65       // because the call to simplify will always inline the var.
66       analyzer_->Bind(op->var, value);
67       return Mutate(op->body);
68     }
69     Stmt body = this->Mutate(op->body);
70     if (value.same_as(op->value) &&
71         body.same_as(op->body)) {
72       return s;
73     } else {
74       return LetStmt::make(op->var, value, body);
75     }
76   }
77 
78   // eliminate useless stores
Mutate_(const Store * op,const Stmt & s)79   Stmt Mutate_(const Store* op, const Stmt& s) final {
80     Stmt stmt = IRMutator::Mutate_(op, s);
81     op = stmt.as<Store>();
82     if (const Load* load = op->value.as<Load>()) {
83       if (load->buffer_var.same_as(op->buffer_var) &&
84           Equal(load->index, op->index)) {
85         return Evaluate::make(0);
86       }
87     }
88     return stmt;
89   }
90 };
91 
92 }  // namespace arith
93 
94 namespace ir {
95 
CanonicalSimplify(Stmt stmt,Map<Var,Range> vrange)96 Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
97   arith::Analyzer analyzer;
98   for (auto kv : vrange) {
99     analyzer.Bind(kv.first, kv.second);
100   }
101   return arith::StmtSimplifier(&analyzer).Simplify(stmt);
102 }
103 
CanonicalSimplify(Expr expr,Map<Var,Range> vrange)104 Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
105   arith::Analyzer analyzer;
106   for (auto kv : vrange) {
107     analyzer.Bind(kv.first, kv.second);
108   }
109   return analyzer.canonical_simplify(expr);
110 }
111 
Simplify(Expr expr,Map<Var,Range> vrange)112 Expr Simplify(Expr expr, Map<Var, Range> vrange) {
113   arith::Analyzer analyzer;
114   for (auto kv : vrange) {
115     analyzer.Bind(kv.first, kv.second);
116   }
117   expr = analyzer.Simplify(expr);
118   return expr;
119 }
120 
Simplify(Stmt stmt,Map<Var,Range> vrange)121 Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
122   return CanonicalSimplify(stmt, vrange);
123 }
124 }  // namespace ir
125 }  // namespace tvm
126