1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
11 /// The header file for the GVN pass that contains expression handling
12 /// classes
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18 
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
34 
35 namespace llvm {
36 
37 class BasicBlock;
38 class Type;
39 
40 namespace GVNExpression {
41 
/// Discriminator for the Expression class hierarchy. The classof() hooks of
/// the classes below compare against it to support isa<>/cast<>/dyn_cast<>.
///
/// The ORDER of the enumerators is significant:
///  - BasicExpression::classof accepts any value strictly between
///    ET_BasicStart and ET_BasicEnd.
///  - MemoryExpression::classof accepts any value strictly between
///    ET_MemoryStart and ET_MemoryEnd.
/// The *Start / *End entries serve as exclusive bounds for those checks. A new
/// expression kind must be placed inside every range whose classof() should
/// accept it.
enum ExpressionType {
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,
  ET_BasicStart, // Exclusive lower bound of the BasicExpression range.
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,
  ET_MemoryStart, // Exclusive lower bound of the MemoryExpression range.
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd, // Exclusive upper bound of the MemoryExpression range.
  ET_BasicEnd // Exclusive upper bound of the BasicExpression range.
};
59 
60 class Expression {
61 private:
62   ExpressionType EType;
63   unsigned Opcode;
64   mutable hash_code HashVal = 0;
65 
66 public:
67   Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68       : EType(ET), Opcode(O) {}
69   Expression(const Expression &) = delete;
70   Expression &operator=(const Expression &) = delete;
71   virtual ~Expression();
72 
73   static unsigned getEmptyKey() { return ~0U; }
74   static unsigned getTombstoneKey() { return ~1U; }
75 
76   bool operator!=(const Expression &Other) const { return !(*this == Other); }
77   bool operator==(const Expression &Other) const {
78     if (getOpcode() != Other.getOpcode())
79       return false;
80     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81       return true;
82     // Compare the expression type for anything but load and store.
83     // For load and store we set the opcode to zero to make them equal.
84     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85         getExpressionType() != Other.getExpressionType())
86       return false;
87 
88     return equals(Other);
89   }
90 
91   hash_code getComputedHash() const {
92     // It's theoretically possible for a thing to hash to zero.  In that case,
93     // we will just compute the hash a few extra times, which is no worse that
94     // we did before, which was to compute it always.
95     if (static_cast<unsigned>(HashVal) == 0)
96       HashVal = getHashValue();
97     return HashVal;
98   }
99 
100   virtual bool equals(const Expression &Other) const { return true; }
101 
102   // Return true if the two expressions are exactly the same, including the
103   // normally ignored fields.
104   virtual bool exactlyEquals(const Expression &Other) const {
105     return getExpressionType() == Other.getExpressionType() && equals(Other);
106   }
107 
108   unsigned getOpcode() const { return Opcode; }
109   void setOpcode(unsigned opcode) { Opcode = opcode; }
110   ExpressionType getExpressionType() const { return EType; }
111 
112   // We deliberately leave the expression type out of the hash value.
113   virtual hash_code getHashValue() const { return getOpcode(); }
114 
115   // Debugging support
116   virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117     if (PrintEType)
118       OS << "etype = " << getExpressionType() << ",";
119     OS << "opcode = " << getOpcode() << ", ";
120   }
121 
122   void print(raw_ostream &OS) const {
123     OS << "{ ";
124     printInternal(OS, true);
125     OS << "}";
126   }
127 
128   LLVM_DUMP_METHOD void dump() const;
129 };
130 
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132   E.print(OS);
133   return OS;
134 }
135 
136 class BasicExpression : public Expression {
137 private:
138   using RecyclerType = ArrayRecycler<Value *>;
139   using RecyclerCapacity = RecyclerType::Capacity;
140 
141   Value **Operands = nullptr;
142   unsigned MaxOperands;
143   unsigned NumOperands = 0;
144   Type *ValueType = nullptr;
145 
146 public:
147   BasicExpression(unsigned NumOperands)
148       : BasicExpression(NumOperands, ET_Basic) {}
149   BasicExpression(unsigned NumOperands, ExpressionType ET)
150       : Expression(ET), MaxOperands(NumOperands) {}
151   BasicExpression() = delete;
152   BasicExpression(const BasicExpression &) = delete;
153   BasicExpression &operator=(const BasicExpression &) = delete;
154   ~BasicExpression() override;
155 
156   static bool classof(const Expression *EB) {
157     ExpressionType ET = EB->getExpressionType();
158     return ET > ET_BasicStart && ET < ET_BasicEnd;
159   }
160 
161   /// Swap two operands. Used during GVN to put commutative operands in
162   /// order.
163   void swapOperands(unsigned First, unsigned Second) {
164     std::swap(Operands[First], Operands[Second]);
165   }
166 
167   Value *getOperand(unsigned N) const {
168     assert(Operands && "Operands not allocated");
169     assert(N < NumOperands && "Operand out of range");
170     return Operands[N];
171   }
172 
173   void setOperand(unsigned N, Value *V) {
174     assert(Operands && "Operands not allocated before setting");
175     assert(N < NumOperands && "Operand out of range");
176     Operands[N] = V;
177   }
178 
179   unsigned getNumOperands() const { return NumOperands; }
180 
181   using op_iterator = Value **;
182   using const_op_iterator = Value *const *;
183 
184   op_iterator op_begin() { return Operands; }
185   op_iterator op_end() { return Operands + NumOperands; }
186   const_op_iterator op_begin() const { return Operands; }
187   const_op_iterator op_end() const { return Operands + NumOperands; }
188   iterator_range<op_iterator> operands() {
189     return iterator_range<op_iterator>(op_begin(), op_end());
190   }
191   iterator_range<const_op_iterator> operands() const {
192     return iterator_range<const_op_iterator>(op_begin(), op_end());
193   }
194 
195   void op_push_back(Value *Arg) {
196     assert(NumOperands < MaxOperands && "Tried to add too many operands");
197     assert(Operands && "Operandss not allocated before pushing");
198     Operands[NumOperands++] = Arg;
199   }
200   bool op_empty() const { return getNumOperands() == 0; }
201 
202   void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203     assert(!Operands && "Operands already allocated");
204     Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
205   }
206   void deallocateOperands(RecyclerType &Recycler) {
207     Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
208   }
209 
210   void setType(Type *T) { ValueType = T; }
211   Type *getType() const { return ValueType; }
212 
213   bool equals(const Expression &Other) const override {
214     if (getOpcode() != Other.getOpcode())
215       return false;
216 
217     const auto &OE = cast<BasicExpression>(Other);
218     return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219            std::equal(op_begin(), op_end(), OE.op_begin());
220   }
221 
222   hash_code getHashValue() const override {
223     return hash_combine(this->Expression::getHashValue(), ValueType,
224                         hash_combine_range(op_begin(), op_end()));
225   }
226 
227   // Debugging support
228   void printInternal(raw_ostream &OS, bool PrintEType) const override {
229     if (PrintEType)
230       OS << "ExpressionTypeBasic, ";
231 
232     this->Expression::printInternal(OS, false);
233     OS << "operands = {";
234     for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235       OS << "[" << i << "] = ";
236       Operands[i]->printAsOperand(OS);
237       OS << "  ";
238     }
239     OS << "} ";
240   }
241 };
242 
243 class op_inserter
244     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
245 private:
246   using Container = BasicExpression;
247 
248   Container *BE;
249 
250 public:
251   explicit op_inserter(BasicExpression &E) : BE(&E) {}
252   explicit op_inserter(BasicExpression *E) : BE(E) {}
253 
254   op_inserter &operator=(Value *val) {
255     BE->op_push_back(val);
256     return *this;
257   }
258   op_inserter &operator*() { return *this; }
259   op_inserter &operator++() { return *this; }
260   op_inserter &operator++(int) { return *this; }
261 };
262 
263 class MemoryExpression : public BasicExpression {
264 private:
265   const MemoryAccess *MemoryLeader;
266 
267 public:
268   MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
269                    const MemoryAccess *MemoryLeader)
270       : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
271   MemoryExpression() = delete;
272   MemoryExpression(const MemoryExpression &) = delete;
273   MemoryExpression &operator=(const MemoryExpression &) = delete;
274 
275   static bool classof(const Expression *EB) {
276     return EB->getExpressionType() > ET_MemoryStart &&
277            EB->getExpressionType() < ET_MemoryEnd;
278   }
279 
280   hash_code getHashValue() const override {
281     return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
282   }
283 
284   bool equals(const Expression &Other) const override {
285     if (!this->BasicExpression::equals(Other))
286       return false;
287     const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
288 
289     return MemoryLeader == OtherMCE.MemoryLeader;
290   }
291 
292   const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
293   void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
294 };
295 
296 class CallExpression final : public MemoryExpression {
297 private:
298   CallInst *Call;
299 
300 public:
301   CallExpression(unsigned NumOperands, CallInst *C,
302                  const MemoryAccess *MemoryLeader)
303       : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
304   CallExpression() = delete;
305   CallExpression(const CallExpression &) = delete;
306   CallExpression &operator=(const CallExpression &) = delete;
307   ~CallExpression() override;
308 
309   static bool classof(const Expression *EB) {
310     return EB->getExpressionType() == ET_Call;
311   }
312 
313   // Debugging support
314   void printInternal(raw_ostream &OS, bool PrintEType) const override {
315     if (PrintEType)
316       OS << "ExpressionTypeCall, ";
317     this->BasicExpression::printInternal(OS, false);
318     OS << " represents call at ";
319     Call->printAsOperand(OS);
320   }
321 };
322 
323 class LoadExpression final : public MemoryExpression {
324 private:
325   LoadInst *Load;
326   MaybeAlign Alignment;
327 
328 public:
329   LoadExpression(unsigned NumOperands, LoadInst *L,
330                  const MemoryAccess *MemoryLeader)
331       : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
332 
333   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
334                  const MemoryAccess *MemoryLeader)
335       : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
336     if (L)
337       Alignment = MaybeAlign(L->getAlignment());
338   }
339 
340   LoadExpression() = delete;
341   LoadExpression(const LoadExpression &) = delete;
342   LoadExpression &operator=(const LoadExpression &) = delete;
343   ~LoadExpression() override;
344 
345   static bool classof(const Expression *EB) {
346     return EB->getExpressionType() == ET_Load;
347   }
348 
349   LoadInst *getLoadInst() const { return Load; }
350   void setLoadInst(LoadInst *L) { Load = L; }
351 
352   MaybeAlign getAlignment() const { return Alignment; }
353   void setAlignment(MaybeAlign Align) { Alignment = Align; }
354 
355   bool equals(const Expression &Other) const override;
356   bool exactlyEquals(const Expression &Other) const override {
357     return Expression::exactlyEquals(Other) &&
358            cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
359   }
360 
361   // Debugging support
362   void printInternal(raw_ostream &OS, bool PrintEType) const override {
363     if (PrintEType)
364       OS << "ExpressionTypeLoad, ";
365     this->BasicExpression::printInternal(OS, false);
366     OS << " represents Load at ";
367     Load->printAsOperand(OS);
368     OS << " with MemoryLeader " << *getMemoryLeader();
369   }
370 };
371 
372 class StoreExpression final : public MemoryExpression {
373 private:
374   StoreInst *Store;
375   Value *StoredValue;
376 
377 public:
378   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
379                   const MemoryAccess *MemoryLeader)
380       : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
381         StoredValue(StoredValue) {}
382   StoreExpression() = delete;
383   StoreExpression(const StoreExpression &) = delete;
384   StoreExpression &operator=(const StoreExpression &) = delete;
385   ~StoreExpression() override;
386 
387   static bool classof(const Expression *EB) {
388     return EB->getExpressionType() == ET_Store;
389   }
390 
391   StoreInst *getStoreInst() const { return Store; }
392   Value *getStoredValue() const { return StoredValue; }
393 
394   bool equals(const Expression &Other) const override;
395 
396   bool exactlyEquals(const Expression &Other) const override {
397     return Expression::exactlyEquals(Other) &&
398            cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
399   }
400 
401   // Debugging support
402   void printInternal(raw_ostream &OS, bool PrintEType) const override {
403     if (PrintEType)
404       OS << "ExpressionTypeStore, ";
405     this->BasicExpression::printInternal(OS, false);
406     OS << " represents Store  " << *Store;
407     OS << " with StoredValue ";
408     StoredValue->printAsOperand(OS);
409     OS << " and MemoryLeader " << *getMemoryLeader();
410   }
411 };
412 
413 class AggregateValueExpression final : public BasicExpression {
414 private:
415   unsigned MaxIntOperands;
416   unsigned NumIntOperands = 0;
417   unsigned *IntOperands = nullptr;
418 
419 public:
420   AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
421       : BasicExpression(NumOperands, ET_AggregateValue),
422         MaxIntOperands(NumIntOperands) {}
423   AggregateValueExpression() = delete;
424   AggregateValueExpression(const AggregateValueExpression &) = delete;
425   AggregateValueExpression &
426   operator=(const AggregateValueExpression &) = delete;
427   ~AggregateValueExpression() override;
428 
429   static bool classof(const Expression *EB) {
430     return EB->getExpressionType() == ET_AggregateValue;
431   }
432 
433   using int_arg_iterator = unsigned *;
434   using const_int_arg_iterator = const unsigned *;
435 
436   int_arg_iterator int_op_begin() { return IntOperands; }
437   int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
438   const_int_arg_iterator int_op_begin() const { return IntOperands; }
439   const_int_arg_iterator int_op_end() const {
440     return IntOperands + NumIntOperands;
441   }
442   unsigned int_op_size() const { return NumIntOperands; }
443   bool int_op_empty() const { return NumIntOperands == 0; }
444   void int_op_push_back(unsigned IntOperand) {
445     assert(NumIntOperands < MaxIntOperands &&
446            "Tried to add too many int operands");
447     assert(IntOperands && "Operands not allocated before pushing");
448     IntOperands[NumIntOperands++] = IntOperand;
449   }
450 
451   virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
452     assert(!IntOperands && "Operands already allocated");
453     IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
454   }
455 
456   bool equals(const Expression &Other) const override {
457     if (!this->BasicExpression::equals(Other))
458       return false;
459     const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
460     return NumIntOperands == OE.NumIntOperands &&
461            std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
462   }
463 
464   hash_code getHashValue() const override {
465     return hash_combine(this->BasicExpression::getHashValue(),
466                         hash_combine_range(int_op_begin(), int_op_end()));
467   }
468 
469   // Debugging support
470   void printInternal(raw_ostream &OS, bool PrintEType) const override {
471     if (PrintEType)
472       OS << "ExpressionTypeAggregateValue, ";
473     this->BasicExpression::printInternal(OS, false);
474     OS << ", intoperands = {";
475     for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
476       OS << "[" << i << "] = " << IntOperands[i] << "  ";
477     }
478     OS << "}";
479   }
480 };
481 
482 class int_op_inserter
483     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
484 private:
485   using Container = AggregateValueExpression;
486 
487   Container *AVE;
488 
489 public:
490   explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
491   explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
492 
493   int_op_inserter &operator=(unsigned int val) {
494     AVE->int_op_push_back(val);
495     return *this;
496   }
497   int_op_inserter &operator*() { return *this; }
498   int_op_inserter &operator++() { return *this; }
499   int_op_inserter &operator++(int) { return *this; }
500 };
501 
502 class PHIExpression final : public BasicExpression {
503 private:
504   BasicBlock *BB;
505 
506 public:
507   PHIExpression(unsigned NumOperands, BasicBlock *B)
508       : BasicExpression(NumOperands, ET_Phi), BB(B) {}
509   PHIExpression() = delete;
510   PHIExpression(const PHIExpression &) = delete;
511   PHIExpression &operator=(const PHIExpression &) = delete;
512   ~PHIExpression() override;
513 
514   static bool classof(const Expression *EB) {
515     return EB->getExpressionType() == ET_Phi;
516   }
517 
518   bool equals(const Expression &Other) const override {
519     if (!this->BasicExpression::equals(Other))
520       return false;
521     const PHIExpression &OE = cast<PHIExpression>(Other);
522     return BB == OE.BB;
523   }
524 
525   hash_code getHashValue() const override {
526     return hash_combine(this->BasicExpression::getHashValue(), BB);
527   }
528 
529   // Debugging support
530   void printInternal(raw_ostream &OS, bool PrintEType) const override {
531     if (PrintEType)
532       OS << "ExpressionTypePhi, ";
533     this->BasicExpression::printInternal(OS, false);
534     OS << "bb = " << BB;
535   }
536 };
537 
538 class DeadExpression final : public Expression {
539 public:
540   DeadExpression() : Expression(ET_Dead) {}
541   DeadExpression(const DeadExpression &) = delete;
542   DeadExpression &operator=(const DeadExpression &) = delete;
543 
544   static bool classof(const Expression *E) {
545     return E->getExpressionType() == ET_Dead;
546   }
547 };
548 
549 class VariableExpression final : public Expression {
550 private:
551   Value *VariableValue;
552 
553 public:
554   VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
555   VariableExpression() = delete;
556   VariableExpression(const VariableExpression &) = delete;
557   VariableExpression &operator=(const VariableExpression &) = delete;
558 
559   static bool classof(const Expression *EB) {
560     return EB->getExpressionType() == ET_Variable;
561   }
562 
563   Value *getVariableValue() const { return VariableValue; }
564   void setVariableValue(Value *V) { VariableValue = V; }
565 
566   bool equals(const Expression &Other) const override {
567     const VariableExpression &OC = cast<VariableExpression>(Other);
568     return VariableValue == OC.VariableValue;
569   }
570 
571   hash_code getHashValue() const override {
572     return hash_combine(this->Expression::getHashValue(),
573                         VariableValue->getType(), VariableValue);
574   }
575 
576   // Debugging support
577   void printInternal(raw_ostream &OS, bool PrintEType) const override {
578     if (PrintEType)
579       OS << "ExpressionTypeVariable, ";
580     this->Expression::printInternal(OS, false);
581     OS << " variable = " << *VariableValue;
582   }
583 };
584 
585 class ConstantExpression final : public Expression {
586 private:
587   Constant *ConstantValue = nullptr;
588 
589 public:
590   ConstantExpression() : Expression(ET_Constant) {}
591   ConstantExpression(Constant *constantValue)
592       : Expression(ET_Constant), ConstantValue(constantValue) {}
593   ConstantExpression(const ConstantExpression &) = delete;
594   ConstantExpression &operator=(const ConstantExpression &) = delete;
595 
596   static bool classof(const Expression *EB) {
597     return EB->getExpressionType() == ET_Constant;
598   }
599 
600   Constant *getConstantValue() const { return ConstantValue; }
601   void setConstantValue(Constant *V) { ConstantValue = V; }
602 
603   bool equals(const Expression &Other) const override {
604     const ConstantExpression &OC = cast<ConstantExpression>(Other);
605     return ConstantValue == OC.ConstantValue;
606   }
607 
608   hash_code getHashValue() const override {
609     return hash_combine(this->Expression::getHashValue(),
610                         ConstantValue->getType(), ConstantValue);
611   }
612 
613   // Debugging support
614   void printInternal(raw_ostream &OS, bool PrintEType) const override {
615     if (PrintEType)
616       OS << "ExpressionTypeConstant, ";
617     this->Expression::printInternal(OS, false);
618     OS << " constant = " << *ConstantValue;
619   }
620 };
621 
622 class UnknownExpression final : public Expression {
623 private:
624   Instruction *Inst;
625 
626 public:
627   UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
628   UnknownExpression() = delete;
629   UnknownExpression(const UnknownExpression &) = delete;
630   UnknownExpression &operator=(const UnknownExpression &) = delete;
631 
632   static bool classof(const Expression *EB) {
633     return EB->getExpressionType() == ET_Unknown;
634   }
635 
636   Instruction *getInstruction() const { return Inst; }
637   void setInstruction(Instruction *I) { Inst = I; }
638 
639   bool equals(const Expression &Other) const override {
640     const auto &OU = cast<UnknownExpression>(Other);
641     return Inst == OU.Inst;
642   }
643 
644   hash_code getHashValue() const override {
645     return hash_combine(this->Expression::getHashValue(), Inst);
646   }
647 
648   // Debugging support
649   void printInternal(raw_ostream &OS, bool PrintEType) const override {
650     if (PrintEType)
651       OS << "ExpressionTypeUnknown, ";
652     this->Expression::printInternal(OS, false);
653     OS << " inst = " << *Inst;
654   }
655 };
656 
657 } // end namespace GVNExpression
658 
659 } // end namespace llvm
660 
661 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
662