1 #include "UnifyDuplicateLets.h"
2 #include "IREquality.h"
3 #include "IRMutator.h"
4 #include <map>
5 
6 namespace Halide {
7 namespace Internal {
8 
9 using std::map;
10 using std::string;
11 
12 class UnifyDuplicateLets : public IRMutator {
13     using IRMutator::visit;
14 
15     map<Expr, string, IRDeepCompare> scope;
16     map<string, string> rewrites;
17     string producing;
18 
19 public:
20     using IRMutator::mutate;
21 
mutate(const Expr & e)22     Expr mutate(const Expr &e) override {
23         if (e.defined()) {
24             map<Expr, string, IRDeepCompare>::iterator iter = scope.find(e);
25             if (iter != scope.end()) {
26                 return Variable::make(e.type(), iter->second);
27             } else {
28                 return IRMutator::mutate(e);
29             }
30         } else {
31             return Expr();
32         }
33     }
34 
35 protected:
visit(const Variable * op)36     Expr visit(const Variable *op) override {
37         map<string, string>::iterator iter = rewrites.find(op->name);
38         if (iter != rewrites.end()) {
39             return Variable::make(op->type, iter->second);
40         } else {
41             return op;
42         }
43     }
44 
45     // Can't unify lets where the RHS might be not be pure
46     bool is_impure;
visit(const Call * op)47     Expr visit(const Call *op) override {
48         is_impure |= !op->is_pure();
49         return IRMutator::visit(op);
50     }
51 
visit(const Load * op)52     Expr visit(const Load *op) override {
53         is_impure = true;
54         return IRMutator::visit(op);
55     }
56 
visit(const ProducerConsumer * op)57     Stmt visit(const ProducerConsumer *op) override {
58         if (op->is_producer) {
59             string old_producing = producing;
60             producing = op->name;
61             Stmt stmt = IRMutator::visit(op);
62             producing = old_producing;
63             return stmt;
64         } else {
65             return IRMutator::visit(op);
66         }
67     }
68 
69     template<typename LetStmtOrLet>
visit_let(const LetStmtOrLet * op)70     auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) {
71         is_impure = false;
72         Expr value = mutate(op->value);
73         auto body = op->body;
74 
75         bool should_pop = false;
76         bool should_erase = false;
77 
78         if (!is_impure) {
79             map<Expr, string, IRDeepCompare>::iterator iter = scope.find(value);
80             if (iter == scope.end()) {
81                 scope[value] = op->name;
82                 should_pop = true;
83             } else {
84                 value = Variable::make(value.type(), iter->second);
85                 rewrites[op->name] = iter->second;
86                 should_erase = true;
87             }
88         }
89 
90         body = mutate(op->body);
91 
92         if (should_pop) {
93             scope.erase(value);
94         }
95         if (should_erase) {
96             rewrites.erase(op->name);
97         }
98 
99         if (value.same_as(op->value) && body.same_as(op->body)) {
100             return op;
101         } else {
102             return LetStmtOrLet::make(op->name, value, body);
103         }
104     }
105 
visit(const Let * op)106     Expr visit(const Let *op) override {
107         return visit_let(op);
108     }
109 
visit(const LetStmt * op)110     Stmt visit(const LetStmt *op) override {
111         return visit_let(op);
112     }
113 };
114 
unify_duplicate_lets(const Stmt & s)115 Stmt unify_duplicate_lets(const Stmt &s) {
116     return UnifyDuplicateLets().mutate(s);
117 }
118 
119 }  // namespace Internal
120 }  // namespace Halide
121