1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
11 /// The header file for the GVN pass that contains expression handling
12 /// classes
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18 
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
34 
35 namespace llvm {
36 
37 class BasicBlock;
38 class Type;
39 
40 namespace GVNExpression {
41 
/// Discriminator stored in every Expression and used by the classof()
/// implementations for LLVM-style RTTI (isa/cast/dyn_cast).
///
/// The order of the enumerators matters: BasicExpression::classof accepts the
/// values strictly between ET_BasicStart and ET_BasicEnd, and
/// MemoryExpression::classof accepts the values strictly between
/// ET_MemoryStart and ET_MemoryEnd. The *Start / *End entries only delimit
/// those ranges; no class in this header uses them as its own type.
enum ExpressionType {
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,
  // Begin of the BasicExpression range.
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,
  // Begin of the MemoryExpression range (nested inside the basic range).
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  // End of the MemoryExpression range.
  ET_MemoryEnd,
  // End of the BasicExpression range.
  ET_BasicEnd
};
59 
60 class Expression {
61 private:
62   ExpressionType EType;
63   unsigned Opcode;
64   mutable hash_code HashVal = 0;
65 
66 public:
67   Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68       : EType(ET), Opcode(O) {}
69   Expression(const Expression &) = delete;
70   Expression &operator=(const Expression &) = delete;
71   virtual ~Expression();
72 
73   static unsigned getEmptyKey() { return ~0U; }
74   static unsigned getTombstoneKey() { return ~1U; }
75 
76   bool operator!=(const Expression &Other) const { return !(*this == Other); }
77   bool operator==(const Expression &Other) const {
78     if (getOpcode() != Other.getOpcode())
79       return false;
80     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81       return true;
82     // Compare the expression type for anything but load and store.
83     // For load and store we set the opcode to zero to make them equal.
84     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85         getExpressionType() != Other.getExpressionType())
86       return false;
87 
88     return equals(Other);
89   }
90 
91   hash_code getComputedHash() const {
92     // It's theoretically possible for a thing to hash to zero.  In that case,
93     // we will just compute the hash a few extra times, which is no worse that
94     // we did before, which was to compute it always.
95     if (static_cast<unsigned>(HashVal) == 0)
96       HashVal = getHashValue();
97     return HashVal;
98   }
99 
100   virtual bool equals(const Expression &Other) const { return true; }
101 
102   // Return true if the two expressions are exactly the same, including the
103   // normally ignored fields.
104   virtual bool exactlyEquals(const Expression &Other) const {
105     return getExpressionType() == Other.getExpressionType() && equals(Other);
106   }
107 
108   unsigned getOpcode() const { return Opcode; }
109   void setOpcode(unsigned opcode) { Opcode = opcode; }
110   ExpressionType getExpressionType() const { return EType; }
111 
112   // We deliberately leave the expression type out of the hash value.
113   virtual hash_code getHashValue() const { return getOpcode(); }
114 
115   // Debugging support
116   virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117     if (PrintEType)
118       OS << "etype = " << getExpressionType() << ",";
119     OS << "opcode = " << getOpcode() << ", ";
120   }
121 
122   void print(raw_ostream &OS) const {
123     OS << "{ ";
124     printInternal(OS, true);
125     OS << "}";
126   }
127 
128   LLVM_DUMP_METHOD void dump() const;
129 };
130 
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132   E.print(OS);
133   return OS;
134 }
135 
136 class BasicExpression : public Expression {
137 private:
138   using RecyclerType = ArrayRecycler<Value *>;
139   using RecyclerCapacity = RecyclerType::Capacity;
140 
141   Value **Operands = nullptr;
142   unsigned MaxOperands;
143   unsigned NumOperands = 0;
144   Type *ValueType = nullptr;
145 
146 public:
147   BasicExpression(unsigned NumOperands)
148       : BasicExpression(NumOperands, ET_Basic) {}
149   BasicExpression(unsigned NumOperands, ExpressionType ET)
150       : Expression(ET), MaxOperands(NumOperands) {}
151   BasicExpression() = delete;
152   BasicExpression(const BasicExpression &) = delete;
153   BasicExpression &operator=(const BasicExpression &) = delete;
154   ~BasicExpression() override;
155 
156   static bool classof(const Expression *EB) {
157     ExpressionType ET = EB->getExpressionType();
158     return ET > ET_BasicStart && ET < ET_BasicEnd;
159   }
160 
161   /// Swap two operands. Used during GVN to put commutative operands in
162   /// order.
163   void swapOperands(unsigned First, unsigned Second) {
164     std::swap(Operands[First], Operands[Second]);
165   }
166 
167   Value *getOperand(unsigned N) const {
168     assert(Operands && "Operands not allocated");
169     assert(N < NumOperands && "Operand out of range");
170     return Operands[N];
171   }
172 
173   void setOperand(unsigned N, Value *V) {
174     assert(Operands && "Operands not allocated before setting");
175     assert(N < NumOperands && "Operand out of range");
176     Operands[N] = V;
177   }
178 
179   unsigned getNumOperands() const { return NumOperands; }
180 
181   using op_iterator = Value **;
182   using const_op_iterator = Value *const *;
183 
184   op_iterator op_begin() { return Operands; }
185   op_iterator op_end() { return Operands + NumOperands; }
186   const_op_iterator op_begin() const { return Operands; }
187   const_op_iterator op_end() const { return Operands + NumOperands; }
188   iterator_range<op_iterator> operands() {
189     return iterator_range<op_iterator>(op_begin(), op_end());
190   }
191   iterator_range<const_op_iterator> operands() const {
192     return iterator_range<const_op_iterator>(op_begin(), op_end());
193   }
194 
195   void op_push_back(Value *Arg) {
196     assert(NumOperands < MaxOperands && "Tried to add too many operands");
197     assert(Operands && "Operandss not allocated before pushing");
198     Operands[NumOperands++] = Arg;
199   }
200   bool op_empty() const { return getNumOperands() == 0; }
201 
202   void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203     assert(!Operands && "Operands already allocated");
204     Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
205   }
206   void deallocateOperands(RecyclerType &Recycler) {
207     Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
208   }
209 
210   void setType(Type *T) { ValueType = T; }
211   Type *getType() const { return ValueType; }
212 
213   bool equals(const Expression &Other) const override {
214     if (getOpcode() != Other.getOpcode())
215       return false;
216 
217     const auto &OE = cast<BasicExpression>(Other);
218     return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219            std::equal(op_begin(), op_end(), OE.op_begin());
220   }
221 
222   hash_code getHashValue() const override {
223     return hash_combine(this->Expression::getHashValue(), ValueType,
224                         hash_combine_range(op_begin(), op_end()));
225   }
226 
227   // Debugging support
228   void printInternal(raw_ostream &OS, bool PrintEType) const override {
229     if (PrintEType)
230       OS << "ExpressionTypeBasic, ";
231 
232     this->Expression::printInternal(OS, false);
233     OS << "operands = {";
234     for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235       OS << "[" << i << "] = ";
236       Operands[i]->printAsOperand(OS);
237       OS << "  ";
238     }
239     OS << "} ";
240   }
241 };
242 
243 class op_inserter
244     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
245 private:
246   using Container = BasicExpression;
247 
248   Container *BE;
249 
250 public:
251   explicit op_inserter(BasicExpression &E) : BE(&E) {}
252   explicit op_inserter(BasicExpression *E) : BE(E) {}
253 
254   op_inserter &operator=(Value *val) {
255     BE->op_push_back(val);
256     return *this;
257   }
258   op_inserter &operator*() { return *this; }
259   op_inserter &operator++() { return *this; }
260   op_inserter &operator++(int) { return *this; }
261 };
262 
263 class MemoryExpression : public BasicExpression {
264 private:
265   const MemoryAccess *MemoryLeader;
266 
267 public:
268   MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
269                    const MemoryAccess *MemoryLeader)
270       : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
271   MemoryExpression() = delete;
272   MemoryExpression(const MemoryExpression &) = delete;
273   MemoryExpression &operator=(const MemoryExpression &) = delete;
274 
275   static bool classof(const Expression *EB) {
276     return EB->getExpressionType() > ET_MemoryStart &&
277            EB->getExpressionType() < ET_MemoryEnd;
278   }
279 
280   hash_code getHashValue() const override {
281     return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
282   }
283 
284   bool equals(const Expression &Other) const override {
285     if (!this->BasicExpression::equals(Other))
286       return false;
287     const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
288 
289     return MemoryLeader == OtherMCE.MemoryLeader;
290   }
291 
292   const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
293   void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
294 };
295 
296 class CallExpression final : public MemoryExpression {
297 private:
298   CallInst *Call;
299 
300 public:
301   CallExpression(unsigned NumOperands, CallInst *C,
302                  const MemoryAccess *MemoryLeader)
303       : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
304   CallExpression() = delete;
305   CallExpression(const CallExpression &) = delete;
306   CallExpression &operator=(const CallExpression &) = delete;
307   ~CallExpression() override;
308 
309   static bool classof(const Expression *EB) {
310     return EB->getExpressionType() == ET_Call;
311   }
312 
313   // Debugging support
314   void printInternal(raw_ostream &OS, bool PrintEType) const override {
315     if (PrintEType)
316       OS << "ExpressionTypeCall, ";
317     this->BasicExpression::printInternal(OS, false);
318     OS << " represents call at ";
319     Call->printAsOperand(OS);
320   }
321 };
322 
323 class LoadExpression final : public MemoryExpression {
324 private:
325   LoadInst *Load;
326 
327 public:
328   LoadExpression(unsigned NumOperands, LoadInst *L,
329                  const MemoryAccess *MemoryLeader)
330       : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
331 
332   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
333                  const MemoryAccess *MemoryLeader)
334       : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {}
335 
336   LoadExpression() = delete;
337   LoadExpression(const LoadExpression &) = delete;
338   LoadExpression &operator=(const LoadExpression &) = delete;
339   ~LoadExpression() override;
340 
341   static bool classof(const Expression *EB) {
342     return EB->getExpressionType() == ET_Load;
343   }
344 
345   LoadInst *getLoadInst() const { return Load; }
346   void setLoadInst(LoadInst *L) { Load = L; }
347 
348   bool equals(const Expression &Other) const override;
349   bool exactlyEquals(const Expression &Other) const override {
350     return Expression::exactlyEquals(Other) &&
351            cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
352   }
353 
354   // Debugging support
355   void printInternal(raw_ostream &OS, bool PrintEType) const override {
356     if (PrintEType)
357       OS << "ExpressionTypeLoad, ";
358     this->BasicExpression::printInternal(OS, false);
359     OS << " represents Load at ";
360     Load->printAsOperand(OS);
361     OS << " with MemoryLeader " << *getMemoryLeader();
362   }
363 };
364 
365 class StoreExpression final : public MemoryExpression {
366 private:
367   StoreInst *Store;
368   Value *StoredValue;
369 
370 public:
371   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
372                   const MemoryAccess *MemoryLeader)
373       : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
374         StoredValue(StoredValue) {}
375   StoreExpression() = delete;
376   StoreExpression(const StoreExpression &) = delete;
377   StoreExpression &operator=(const StoreExpression &) = delete;
378   ~StoreExpression() override;
379 
380   static bool classof(const Expression *EB) {
381     return EB->getExpressionType() == ET_Store;
382   }
383 
384   StoreInst *getStoreInst() const { return Store; }
385   Value *getStoredValue() const { return StoredValue; }
386 
387   bool equals(const Expression &Other) const override;
388 
389   bool exactlyEquals(const Expression &Other) const override {
390     return Expression::exactlyEquals(Other) &&
391            cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
392   }
393 
394   // Debugging support
395   void printInternal(raw_ostream &OS, bool PrintEType) const override {
396     if (PrintEType)
397       OS << "ExpressionTypeStore, ";
398     this->BasicExpression::printInternal(OS, false);
399     OS << " represents Store  " << *Store;
400     OS << " with StoredValue ";
401     StoredValue->printAsOperand(OS);
402     OS << " and MemoryLeader " << *getMemoryLeader();
403   }
404 };
405 
406 class AggregateValueExpression final : public BasicExpression {
407 private:
408   unsigned MaxIntOperands;
409   unsigned NumIntOperands = 0;
410   unsigned *IntOperands = nullptr;
411 
412 public:
413   AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
414       : BasicExpression(NumOperands, ET_AggregateValue),
415         MaxIntOperands(NumIntOperands) {}
416   AggregateValueExpression() = delete;
417   AggregateValueExpression(const AggregateValueExpression &) = delete;
418   AggregateValueExpression &
419   operator=(const AggregateValueExpression &) = delete;
420   ~AggregateValueExpression() override;
421 
422   static bool classof(const Expression *EB) {
423     return EB->getExpressionType() == ET_AggregateValue;
424   }
425 
426   using int_arg_iterator = unsigned *;
427   using const_int_arg_iterator = const unsigned *;
428 
429   int_arg_iterator int_op_begin() { return IntOperands; }
430   int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
431   const_int_arg_iterator int_op_begin() const { return IntOperands; }
432   const_int_arg_iterator int_op_end() const {
433     return IntOperands + NumIntOperands;
434   }
435   unsigned int_op_size() const { return NumIntOperands; }
436   bool int_op_empty() const { return NumIntOperands == 0; }
437   void int_op_push_back(unsigned IntOperand) {
438     assert(NumIntOperands < MaxIntOperands &&
439            "Tried to add too many int operands");
440     assert(IntOperands && "Operands not allocated before pushing");
441     IntOperands[NumIntOperands++] = IntOperand;
442   }
443 
444   virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
445     assert(!IntOperands && "Operands already allocated");
446     IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
447   }
448 
449   bool equals(const Expression &Other) const override {
450     if (!this->BasicExpression::equals(Other))
451       return false;
452     const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
453     return NumIntOperands == OE.NumIntOperands &&
454            std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
455   }
456 
457   hash_code getHashValue() const override {
458     return hash_combine(this->BasicExpression::getHashValue(),
459                         hash_combine_range(int_op_begin(), int_op_end()));
460   }
461 
462   // Debugging support
463   void printInternal(raw_ostream &OS, bool PrintEType) const override {
464     if (PrintEType)
465       OS << "ExpressionTypeAggregateValue, ";
466     this->BasicExpression::printInternal(OS, false);
467     OS << ", intoperands = {";
468     for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
469       OS << "[" << i << "] = " << IntOperands[i] << "  ";
470     }
471     OS << "}";
472   }
473 };
474 
475 class int_op_inserter
476     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
477 private:
478   using Container = AggregateValueExpression;
479 
480   Container *AVE;
481 
482 public:
483   explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
484   explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
485 
486   int_op_inserter &operator=(unsigned int val) {
487     AVE->int_op_push_back(val);
488     return *this;
489   }
490   int_op_inserter &operator*() { return *this; }
491   int_op_inserter &operator++() { return *this; }
492   int_op_inserter &operator++(int) { return *this; }
493 };
494 
495 class PHIExpression final : public BasicExpression {
496 private:
497   BasicBlock *BB;
498 
499 public:
500   PHIExpression(unsigned NumOperands, BasicBlock *B)
501       : BasicExpression(NumOperands, ET_Phi), BB(B) {}
502   PHIExpression() = delete;
503   PHIExpression(const PHIExpression &) = delete;
504   PHIExpression &operator=(const PHIExpression &) = delete;
505   ~PHIExpression() override;
506 
507   static bool classof(const Expression *EB) {
508     return EB->getExpressionType() == ET_Phi;
509   }
510 
511   bool equals(const Expression &Other) const override {
512     if (!this->BasicExpression::equals(Other))
513       return false;
514     const PHIExpression &OE = cast<PHIExpression>(Other);
515     return BB == OE.BB;
516   }
517 
518   hash_code getHashValue() const override {
519     return hash_combine(this->BasicExpression::getHashValue(), BB);
520   }
521 
522   // Debugging support
523   void printInternal(raw_ostream &OS, bool PrintEType) const override {
524     if (PrintEType)
525       OS << "ExpressionTypePhi, ";
526     this->BasicExpression::printInternal(OS, false);
527     OS << "bb = " << BB;
528   }
529 };
530 
531 class DeadExpression final : public Expression {
532 public:
533   DeadExpression() : Expression(ET_Dead) {}
534   DeadExpression(const DeadExpression &) = delete;
535   DeadExpression &operator=(const DeadExpression &) = delete;
536 
537   static bool classof(const Expression *E) {
538     return E->getExpressionType() == ET_Dead;
539   }
540 };
541 
542 class VariableExpression final : public Expression {
543 private:
544   Value *VariableValue;
545 
546 public:
547   VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
548   VariableExpression() = delete;
549   VariableExpression(const VariableExpression &) = delete;
550   VariableExpression &operator=(const VariableExpression &) = delete;
551 
552   static bool classof(const Expression *EB) {
553     return EB->getExpressionType() == ET_Variable;
554   }
555 
556   Value *getVariableValue() const { return VariableValue; }
557   void setVariableValue(Value *V) { VariableValue = V; }
558 
559   bool equals(const Expression &Other) const override {
560     const VariableExpression &OC = cast<VariableExpression>(Other);
561     return VariableValue == OC.VariableValue;
562   }
563 
564   hash_code getHashValue() const override {
565     return hash_combine(this->Expression::getHashValue(),
566                         VariableValue->getType(), VariableValue);
567   }
568 
569   // Debugging support
570   void printInternal(raw_ostream &OS, bool PrintEType) const override {
571     if (PrintEType)
572       OS << "ExpressionTypeVariable, ";
573     this->Expression::printInternal(OS, false);
574     OS << " variable = " << *VariableValue;
575   }
576 };
577 
578 class ConstantExpression final : public Expression {
579 private:
580   Constant *ConstantValue = nullptr;
581 
582 public:
583   ConstantExpression() : Expression(ET_Constant) {}
584   ConstantExpression(Constant *constantValue)
585       : Expression(ET_Constant), ConstantValue(constantValue) {}
586   ConstantExpression(const ConstantExpression &) = delete;
587   ConstantExpression &operator=(const ConstantExpression &) = delete;
588 
589   static bool classof(const Expression *EB) {
590     return EB->getExpressionType() == ET_Constant;
591   }
592 
593   Constant *getConstantValue() const { return ConstantValue; }
594   void setConstantValue(Constant *V) { ConstantValue = V; }
595 
596   bool equals(const Expression &Other) const override {
597     const ConstantExpression &OC = cast<ConstantExpression>(Other);
598     return ConstantValue == OC.ConstantValue;
599   }
600 
601   hash_code getHashValue() const override {
602     return hash_combine(this->Expression::getHashValue(),
603                         ConstantValue->getType(), ConstantValue);
604   }
605 
606   // Debugging support
607   void printInternal(raw_ostream &OS, bool PrintEType) const override {
608     if (PrintEType)
609       OS << "ExpressionTypeConstant, ";
610     this->Expression::printInternal(OS, false);
611     OS << " constant = " << *ConstantValue;
612   }
613 };
614 
615 class UnknownExpression final : public Expression {
616 private:
617   Instruction *Inst;
618 
619 public:
620   UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
621   UnknownExpression() = delete;
622   UnknownExpression(const UnknownExpression &) = delete;
623   UnknownExpression &operator=(const UnknownExpression &) = delete;
624 
625   static bool classof(const Expression *EB) {
626     return EB->getExpressionType() == ET_Unknown;
627   }
628 
629   Instruction *getInstruction() const { return Inst; }
630   void setInstruction(Instruction *I) { Inst = I; }
631 
632   bool equals(const Expression &Other) const override {
633     const auto &OU = cast<UnknownExpression>(Other);
634     return Inst == OU.Inst;
635   }
636 
637   hash_code getHashValue() const override {
638     return hash_combine(this->Expression::getHashValue(), Inst);
639   }
640 
641   // Debugging support
642   void printInternal(raw_ostream &OS, bool PrintEType) const override {
643     if (PrintEType)
644       OS << "ExpressionTypeUnknown, ";
645     this->Expression::printInternal(OS, false);
646     OS << " inst = " << *Inst;
647   }
648 };
649 
650 } // end namespace GVNExpression
651 
652 } // end namespace llvm
653 
654 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
655