1 #include "UnifyDuplicateLets.h"
2 #include "IREquality.h"
3 #include "IRMutator.h"
4 #include <map>
5
6 namespace Halide {
7 namespace Internal {
8
9 using std::map;
10 using std::string;
11
12 class UnifyDuplicateLets : public IRMutator {
13 using IRMutator::visit;
14
15 map<Expr, string, IRDeepCompare> scope;
16 map<string, string> rewrites;
17 string producing;
18
19 public:
20 using IRMutator::mutate;
21
mutate(const Expr & e)22 Expr mutate(const Expr &e) override {
23 if (e.defined()) {
24 map<Expr, string, IRDeepCompare>::iterator iter = scope.find(e);
25 if (iter != scope.end()) {
26 return Variable::make(e.type(), iter->second);
27 } else {
28 return IRMutator::mutate(e);
29 }
30 } else {
31 return Expr();
32 }
33 }
34
35 protected:
visit(const Variable * op)36 Expr visit(const Variable *op) override {
37 map<string, string>::iterator iter = rewrites.find(op->name);
38 if (iter != rewrites.end()) {
39 return Variable::make(op->type, iter->second);
40 } else {
41 return op;
42 }
43 }
44
45 // Can't unify lets where the RHS might be not be pure
46 bool is_impure;
visit(const Call * op)47 Expr visit(const Call *op) override {
48 is_impure |= !op->is_pure();
49 return IRMutator::visit(op);
50 }
51
visit(const Load * op)52 Expr visit(const Load *op) override {
53 is_impure = true;
54 return IRMutator::visit(op);
55 }
56
visit(const ProducerConsumer * op)57 Stmt visit(const ProducerConsumer *op) override {
58 if (op->is_producer) {
59 string old_producing = producing;
60 producing = op->name;
61 Stmt stmt = IRMutator::visit(op);
62 producing = old_producing;
63 return stmt;
64 } else {
65 return IRMutator::visit(op);
66 }
67 }
68
69 template<typename LetStmtOrLet>
visit_let(const LetStmtOrLet * op)70 auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) {
71 is_impure = false;
72 Expr value = mutate(op->value);
73 auto body = op->body;
74
75 bool should_pop = false;
76 bool should_erase = false;
77
78 if (!is_impure) {
79 map<Expr, string, IRDeepCompare>::iterator iter = scope.find(value);
80 if (iter == scope.end()) {
81 scope[value] = op->name;
82 should_pop = true;
83 } else {
84 value = Variable::make(value.type(), iter->second);
85 rewrites[op->name] = iter->second;
86 should_erase = true;
87 }
88 }
89
90 body = mutate(op->body);
91
92 if (should_pop) {
93 scope.erase(value);
94 }
95 if (should_erase) {
96 rewrites.erase(op->name);
97 }
98
99 if (value.same_as(op->value) && body.same_as(op->body)) {
100 return op;
101 } else {
102 return LetStmtOrLet::make(op->name, value, body);
103 }
104 }
105
visit(const Let * op)106 Expr visit(const Let *op) override {
107 return visit_let(op);
108 }
109
visit(const LetStmt * op)110 Stmt visit(const LetStmt *op) override {
111 return visit_let(op);
112 }
113 };
114
// Run the UnifyDuplicateLets mutator over a statement and return the result.
Stmt unify_duplicate_lets(const Stmt &s) {
    // Braces value-initialize the mutator, exactly as the temporary
    // UnifyDuplicateLets() expression would.
    UnifyDuplicateLets unifier{};
    return unifier.mutate(s);
}
118
119 } // namespace Internal
120 } // namespace Halide
121