1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
/// The header file for the GVN pass that contains the expression-handling
/// classes.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18 
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
34 
35 namespace llvm {
36 
37 class BasicBlock;
38 class Type;
39 
40 namespace GVNExpression {
41 
/// Discriminator stored in every Expression and consumed by the classof()
/// helpers below (LLVM-style RTTI).
///
/// Enumerator order matters: BasicExpression::classof accepts every kind
/// strictly between ET_BasicStart and ET_BasicEnd, and
/// MemoryExpression::classof accepts every kind strictly between
/// ET_MemoryStart and ET_MemoryEnd. New kinds must be placed inside the
/// appropriate range.
enum ExpressionType {
  // Kinds that are not BasicExpressions.
  ET_Base = 0,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,

  // BasicExpression kinds: open interval (ET_BasicStart, ET_BasicEnd).
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,

  // MemoryExpression kinds: open interval (ET_MemoryStart, ET_MemoryEnd),
  // nested inside the BasicExpression interval.
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd,

  ET_BasicEnd
};
59 
60 class Expression {
61 private:
62   ExpressionType EType;
63   unsigned Opcode;
64   mutable hash_code HashVal = 0;
65 
66 public:
67   Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68       : EType(ET), Opcode(O) {}
69   Expression(const Expression &) = delete;
70   Expression &operator=(const Expression &) = delete;
71   virtual ~Expression();
72 
73   static unsigned getEmptyKey() { return ~0U; }
74   static unsigned getTombstoneKey() { return ~1U; }
75 
76   bool operator!=(const Expression &Other) const { return !(*this == Other); }
77   bool operator==(const Expression &Other) const {
78     if (getOpcode() != Other.getOpcode())
79       return false;
80     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81       return true;
82     // Compare the expression type for anything but load and store.
83     // For load and store we set the opcode to zero to make them equal.
84     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85         getExpressionType() != Other.getExpressionType())
86       return false;
87 
88     return equals(Other);
89   }
90 
91   hash_code getComputedHash() const {
92     // It's theoretically possible for a thing to hash to zero.  In that case,
93     // we will just compute the hash a few extra times, which is no worse that
94     // we did before, which was to compute it always.
95     if (static_cast<unsigned>(HashVal) == 0)
96       HashVal = getHashValue();
97     return HashVal;
98   }
99 
100   virtual bool equals(const Expression &Other) const { return true; }
101 
102   // Return true if the two expressions are exactly the same, including the
103   // normally ignored fields.
104   virtual bool exactlyEquals(const Expression &Other) const {
105     return getExpressionType() == Other.getExpressionType() && equals(Other);
106   }
107 
108   unsigned getOpcode() const { return Opcode; }
109   void setOpcode(unsigned opcode) { Opcode = opcode; }
110   ExpressionType getExpressionType() const { return EType; }
111 
112   // We deliberately leave the expression type out of the hash value.
113   virtual hash_code getHashValue() const { return getOpcode(); }
114 
115   // Debugging support
116   virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117     if (PrintEType)
118       OS << "etype = " << getExpressionType() << ",";
119     OS << "opcode = " << getOpcode() << ", ";
120   }
121 
122   void print(raw_ostream &OS) const {
123     OS << "{ ";
124     printInternal(OS, true);
125     OS << "}";
126   }
127 
128   LLVM_DUMP_METHOD void dump() const;
129 };
130 
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132   E.print(OS);
133   return OS;
134 }
135 
136 class BasicExpression : public Expression {
137 private:
138   using RecyclerType = ArrayRecycler<Value *>;
139   using RecyclerCapacity = RecyclerType::Capacity;
140 
141   Value **Operands = nullptr;
142   unsigned MaxOperands;
143   unsigned NumOperands = 0;
144   Type *ValueType = nullptr;
145 
146 public:
147   BasicExpression(unsigned NumOperands)
148       : BasicExpression(NumOperands, ET_Basic) {}
149   BasicExpression(unsigned NumOperands, ExpressionType ET)
150       : Expression(ET), MaxOperands(NumOperands) {}
151   BasicExpression() = delete;
152   BasicExpression(const BasicExpression &) = delete;
153   BasicExpression &operator=(const BasicExpression &) = delete;
154   ~BasicExpression() override;
155 
156   static bool classof(const Expression *EB) {
157     ExpressionType ET = EB->getExpressionType();
158     return ET > ET_BasicStart && ET < ET_BasicEnd;
159   }
160 
161   /// Swap two operands. Used during GVN to put commutative operands in
162   /// order.
163   void swapOperands(unsigned First, unsigned Second) {
164     std::swap(Operands[First], Operands[Second]);
165   }
166 
167   Value *getOperand(unsigned N) const {
168     assert(Operands && "Operands not allocated");
169     assert(N < NumOperands && "Operand out of range");
170     return Operands[N];
171   }
172 
173   void setOperand(unsigned N, Value *V) {
174     assert(Operands && "Operands not allocated before setting");
175     assert(N < NumOperands && "Operand out of range");
176     Operands[N] = V;
177   }
178 
179   unsigned getNumOperands() const { return NumOperands; }
180 
181   using op_iterator = Value **;
182   using const_op_iterator = Value *const *;
183 
184   op_iterator op_begin() { return Operands; }
185   op_iterator op_end() { return Operands + NumOperands; }
186   const_op_iterator op_begin() const { return Operands; }
187   const_op_iterator op_end() const { return Operands + NumOperands; }
188   iterator_range<op_iterator> operands() {
189     return iterator_range<op_iterator>(op_begin(), op_end());
190   }
191   iterator_range<const_op_iterator> operands() const {
192     return iterator_range<const_op_iterator>(op_begin(), op_end());
193   }
194 
195   void op_push_back(Value *Arg) {
196     assert(NumOperands < MaxOperands && "Tried to add too many operands");
197     assert(Operands && "Operandss not allocated before pushing");
198     Operands[NumOperands++] = Arg;
199   }
200   bool op_empty() const { return getNumOperands() == 0; }
201 
202   void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203     assert(!Operands && "Operands already allocated");
204     Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
205   }
206   void deallocateOperands(RecyclerType &Recycler) {
207     Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
208   }
209 
210   void setType(Type *T) { ValueType = T; }
211   Type *getType() const { return ValueType; }
212 
213   bool equals(const Expression &Other) const override {
214     if (getOpcode() != Other.getOpcode())
215       return false;
216 
217     const auto &OE = cast<BasicExpression>(Other);
218     return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219            std::equal(op_begin(), op_end(), OE.op_begin());
220   }
221 
222   hash_code getHashValue() const override {
223     return hash_combine(this->Expression::getHashValue(), ValueType,
224                         hash_combine_range(op_begin(), op_end()));
225   }
226 
227   // Debugging support
228   void printInternal(raw_ostream &OS, bool PrintEType) const override {
229     if (PrintEType)
230       OS << "ExpressionTypeBasic, ";
231 
232     this->Expression::printInternal(OS, false);
233     OS << "operands = {";
234     for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235       OS << "[" << i << "] = ";
236       Operands[i]->printAsOperand(OS);
237       OS << "  ";
238     }
239     OS << "} ";
240   }
241 };
242 
243 class op_inserter {
244 private:
245   using Container = BasicExpression;
246 
247   Container *BE;
248 
249 public:
250   using iterator_category = std::output_iterator_tag;
251   using value_type = void;
252   using difference_type = void;
253   using pointer = void;
254   using reference = void;
255 
256   explicit op_inserter(BasicExpression &E) : BE(&E) {}
257   explicit op_inserter(BasicExpression *E) : BE(E) {}
258 
259   op_inserter &operator=(Value *val) {
260     BE->op_push_back(val);
261     return *this;
262   }
263   op_inserter &operator*() { return *this; }
264   op_inserter &operator++() { return *this; }
265   op_inserter &operator++(int) { return *this; }
266 };
267 
268 class MemoryExpression : public BasicExpression {
269 private:
270   const MemoryAccess *MemoryLeader;
271 
272 public:
273   MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
274                    const MemoryAccess *MemoryLeader)
275       : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
276   MemoryExpression() = delete;
277   MemoryExpression(const MemoryExpression &) = delete;
278   MemoryExpression &operator=(const MemoryExpression &) = delete;
279 
280   static bool classof(const Expression *EB) {
281     return EB->getExpressionType() > ET_MemoryStart &&
282            EB->getExpressionType() < ET_MemoryEnd;
283   }
284 
285   hash_code getHashValue() const override {
286     return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
287   }
288 
289   bool equals(const Expression &Other) const override {
290     if (!this->BasicExpression::equals(Other))
291       return false;
292     const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
293 
294     return MemoryLeader == OtherMCE.MemoryLeader;
295   }
296 
297   const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
298   void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
299 };
300 
301 class CallExpression final : public MemoryExpression {
302 private:
303   CallInst *Call;
304 
305 public:
306   CallExpression(unsigned NumOperands, CallInst *C,
307                  const MemoryAccess *MemoryLeader)
308       : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
309   CallExpression() = delete;
310   CallExpression(const CallExpression &) = delete;
311   CallExpression &operator=(const CallExpression &) = delete;
312   ~CallExpression() override;
313 
314   static bool classof(const Expression *EB) {
315     return EB->getExpressionType() == ET_Call;
316   }
317 
318   // Debugging support
319   void printInternal(raw_ostream &OS, bool PrintEType) const override {
320     if (PrintEType)
321       OS << "ExpressionTypeCall, ";
322     this->BasicExpression::printInternal(OS, false);
323     OS << " represents call at ";
324     Call->printAsOperand(OS);
325   }
326 };
327 
328 class LoadExpression final : public MemoryExpression {
329 private:
330   LoadInst *Load;
331 
332 public:
333   LoadExpression(unsigned NumOperands, LoadInst *L,
334                  const MemoryAccess *MemoryLeader)
335       : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
336 
337   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
338                  const MemoryAccess *MemoryLeader)
339       : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {}
340 
341   LoadExpression() = delete;
342   LoadExpression(const LoadExpression &) = delete;
343   LoadExpression &operator=(const LoadExpression &) = delete;
344   ~LoadExpression() override;
345 
346   static bool classof(const Expression *EB) {
347     return EB->getExpressionType() == ET_Load;
348   }
349 
350   LoadInst *getLoadInst() const { return Load; }
351   void setLoadInst(LoadInst *L) { Load = L; }
352 
353   bool equals(const Expression &Other) const override;
354   bool exactlyEquals(const Expression &Other) const override {
355     return Expression::exactlyEquals(Other) &&
356            cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
357   }
358 
359   // Debugging support
360   void printInternal(raw_ostream &OS, bool PrintEType) const override {
361     if (PrintEType)
362       OS << "ExpressionTypeLoad, ";
363     this->BasicExpression::printInternal(OS, false);
364     OS << " represents Load at ";
365     Load->printAsOperand(OS);
366     OS << " with MemoryLeader " << *getMemoryLeader();
367   }
368 };
369 
370 class StoreExpression final : public MemoryExpression {
371 private:
372   StoreInst *Store;
373   Value *StoredValue;
374 
375 public:
376   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
377                   const MemoryAccess *MemoryLeader)
378       : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
379         StoredValue(StoredValue) {}
380   StoreExpression() = delete;
381   StoreExpression(const StoreExpression &) = delete;
382   StoreExpression &operator=(const StoreExpression &) = delete;
383   ~StoreExpression() override;
384 
385   static bool classof(const Expression *EB) {
386     return EB->getExpressionType() == ET_Store;
387   }
388 
389   StoreInst *getStoreInst() const { return Store; }
390   Value *getStoredValue() const { return StoredValue; }
391 
392   bool equals(const Expression &Other) const override;
393 
394   bool exactlyEquals(const Expression &Other) const override {
395     return Expression::exactlyEquals(Other) &&
396            cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
397   }
398 
399   // Debugging support
400   void printInternal(raw_ostream &OS, bool PrintEType) const override {
401     if (PrintEType)
402       OS << "ExpressionTypeStore, ";
403     this->BasicExpression::printInternal(OS, false);
404     OS << " represents Store  " << *Store;
405     OS << " with StoredValue ";
406     StoredValue->printAsOperand(OS);
407     OS << " and MemoryLeader " << *getMemoryLeader();
408   }
409 };
410 
411 class AggregateValueExpression final : public BasicExpression {
412 private:
413   unsigned MaxIntOperands;
414   unsigned NumIntOperands = 0;
415   unsigned *IntOperands = nullptr;
416 
417 public:
418   AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
419       : BasicExpression(NumOperands, ET_AggregateValue),
420         MaxIntOperands(NumIntOperands) {}
421   AggregateValueExpression() = delete;
422   AggregateValueExpression(const AggregateValueExpression &) = delete;
423   AggregateValueExpression &
424   operator=(const AggregateValueExpression &) = delete;
425   ~AggregateValueExpression() override;
426 
427   static bool classof(const Expression *EB) {
428     return EB->getExpressionType() == ET_AggregateValue;
429   }
430 
431   using int_arg_iterator = unsigned *;
432   using const_int_arg_iterator = const unsigned *;
433 
434   int_arg_iterator int_op_begin() { return IntOperands; }
435   int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
436   const_int_arg_iterator int_op_begin() const { return IntOperands; }
437   const_int_arg_iterator int_op_end() const {
438     return IntOperands + NumIntOperands;
439   }
440   unsigned int_op_size() const { return NumIntOperands; }
441   bool int_op_empty() const { return NumIntOperands == 0; }
442   void int_op_push_back(unsigned IntOperand) {
443     assert(NumIntOperands < MaxIntOperands &&
444            "Tried to add too many int operands");
445     assert(IntOperands && "Operands not allocated before pushing");
446     IntOperands[NumIntOperands++] = IntOperand;
447   }
448 
449   virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
450     assert(!IntOperands && "Operands already allocated");
451     IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
452   }
453 
454   bool equals(const Expression &Other) const override {
455     if (!this->BasicExpression::equals(Other))
456       return false;
457     const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
458     return NumIntOperands == OE.NumIntOperands &&
459            std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
460   }
461 
462   hash_code getHashValue() const override {
463     return hash_combine(this->BasicExpression::getHashValue(),
464                         hash_combine_range(int_op_begin(), int_op_end()));
465   }
466 
467   // Debugging support
468   void printInternal(raw_ostream &OS, bool PrintEType) const override {
469     if (PrintEType)
470       OS << "ExpressionTypeAggregateValue, ";
471     this->BasicExpression::printInternal(OS, false);
472     OS << ", intoperands = {";
473     for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
474       OS << "[" << i << "] = " << IntOperands[i] << "  ";
475     }
476     OS << "}";
477   }
478 };
479 
480 class int_op_inserter {
481 private:
482   using Container = AggregateValueExpression;
483 
484   Container *AVE;
485 
486 public:
487   using iterator_category = std::output_iterator_tag;
488   using value_type = void;
489   using difference_type = void;
490   using pointer = void;
491   using reference = void;
492 
493   explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
494   explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
495 
496   int_op_inserter &operator=(unsigned int val) {
497     AVE->int_op_push_back(val);
498     return *this;
499   }
500   int_op_inserter &operator*() { return *this; }
501   int_op_inserter &operator++() { return *this; }
502   int_op_inserter &operator++(int) { return *this; }
503 };
504 
505 class PHIExpression final : public BasicExpression {
506 private:
507   BasicBlock *BB;
508 
509 public:
510   PHIExpression(unsigned NumOperands, BasicBlock *B)
511       : BasicExpression(NumOperands, ET_Phi), BB(B) {}
512   PHIExpression() = delete;
513   PHIExpression(const PHIExpression &) = delete;
514   PHIExpression &operator=(const PHIExpression &) = delete;
515   ~PHIExpression() override;
516 
517   static bool classof(const Expression *EB) {
518     return EB->getExpressionType() == ET_Phi;
519   }
520 
521   bool equals(const Expression &Other) const override {
522     if (!this->BasicExpression::equals(Other))
523       return false;
524     const PHIExpression &OE = cast<PHIExpression>(Other);
525     return BB == OE.BB;
526   }
527 
528   hash_code getHashValue() const override {
529     return hash_combine(this->BasicExpression::getHashValue(), BB);
530   }
531 
532   // Debugging support
533   void printInternal(raw_ostream &OS, bool PrintEType) const override {
534     if (PrintEType)
535       OS << "ExpressionTypePhi, ";
536     this->BasicExpression::printInternal(OS, false);
537     OS << "bb = " << BB;
538   }
539 };
540 
541 class DeadExpression final : public Expression {
542 public:
543   DeadExpression() : Expression(ET_Dead) {}
544   DeadExpression(const DeadExpression &) = delete;
545   DeadExpression &operator=(const DeadExpression &) = delete;
546 
547   static bool classof(const Expression *E) {
548     return E->getExpressionType() == ET_Dead;
549   }
550 };
551 
552 class VariableExpression final : public Expression {
553 private:
554   Value *VariableValue;
555 
556 public:
557   VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
558   VariableExpression() = delete;
559   VariableExpression(const VariableExpression &) = delete;
560   VariableExpression &operator=(const VariableExpression &) = delete;
561 
562   static bool classof(const Expression *EB) {
563     return EB->getExpressionType() == ET_Variable;
564   }
565 
566   Value *getVariableValue() const { return VariableValue; }
567   void setVariableValue(Value *V) { VariableValue = V; }
568 
569   bool equals(const Expression &Other) const override {
570     const VariableExpression &OC = cast<VariableExpression>(Other);
571     return VariableValue == OC.VariableValue;
572   }
573 
574   hash_code getHashValue() const override {
575     return hash_combine(this->Expression::getHashValue(),
576                         VariableValue->getType(), VariableValue);
577   }
578 
579   // Debugging support
580   void printInternal(raw_ostream &OS, bool PrintEType) const override {
581     if (PrintEType)
582       OS << "ExpressionTypeVariable, ";
583     this->Expression::printInternal(OS, false);
584     OS << " variable = " << *VariableValue;
585   }
586 };
587 
588 class ConstantExpression final : public Expression {
589 private:
590   Constant *ConstantValue = nullptr;
591 
592 public:
593   ConstantExpression() : Expression(ET_Constant) {}
594   ConstantExpression(Constant *constantValue)
595       : Expression(ET_Constant), ConstantValue(constantValue) {}
596   ConstantExpression(const ConstantExpression &) = delete;
597   ConstantExpression &operator=(const ConstantExpression &) = delete;
598 
599   static bool classof(const Expression *EB) {
600     return EB->getExpressionType() == ET_Constant;
601   }
602 
603   Constant *getConstantValue() const { return ConstantValue; }
604   void setConstantValue(Constant *V) { ConstantValue = V; }
605 
606   bool equals(const Expression &Other) const override {
607     const ConstantExpression &OC = cast<ConstantExpression>(Other);
608     return ConstantValue == OC.ConstantValue;
609   }
610 
611   hash_code getHashValue() const override {
612     return hash_combine(this->Expression::getHashValue(),
613                         ConstantValue->getType(), ConstantValue);
614   }
615 
616   // Debugging support
617   void printInternal(raw_ostream &OS, bool PrintEType) const override {
618     if (PrintEType)
619       OS << "ExpressionTypeConstant, ";
620     this->Expression::printInternal(OS, false);
621     OS << " constant = " << *ConstantValue;
622   }
623 };
624 
625 class UnknownExpression final : public Expression {
626 private:
627   Instruction *Inst;
628 
629 public:
630   UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
631   UnknownExpression() = delete;
632   UnknownExpression(const UnknownExpression &) = delete;
633   UnknownExpression &operator=(const UnknownExpression &) = delete;
634 
635   static bool classof(const Expression *EB) {
636     return EB->getExpressionType() == ET_Unknown;
637   }
638 
639   Instruction *getInstruction() const { return Inst; }
640   void setInstruction(Instruction *I) { Inst = I; }
641 
642   bool equals(const Expression &Other) const override {
643     const auto &OU = cast<UnknownExpression>(Other);
644     return Inst == OU.Inst;
645   }
646 
647   hash_code getHashValue() const override {
648     return hash_combine(this->Expression::getHashValue(), Inst);
649   }
650 
651   // Debugging support
652   void printInternal(raw_ostream &OS, bool PrintEType) const override {
653     if (PrintEType)
654       OS << "ExpressionTypeUnknown, ";
655     this->Expression::printInternal(OS, false);
656     OS << " inst = " << *Inst;
657   }
658 };
659 
660 } // end namespace GVNExpression
661 
662 } // end namespace llvm
663 
664 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
665