1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // General infrastructure for keeping track of the values that according to
10 // the SPIR-V binary layout should be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16 
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22 #include "llvm/CodeGen/MachineModuleInfo.h"
23 
24 #include <type_traits>
25 
26 namespace llvm {
27 namespace SPIRV {
// NOTE: MapVector is used instead of DenseMap because it keeps everything in
// a stable (insertion) order, at the price of extra (NumKeys)*PtrSize memory
// and expensive removals, which do not happen anyway.
31 class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
32   SmallVector<DTSortableEntry *, 2> Deps;
33 
34   struct FlagsTy {
35     unsigned IsFunc : 1;
36     unsigned IsGV : 1;
37     // NOTE: bit-field default init is a C++20 feature.
38     FlagsTy() : IsFunc(0), IsGV(0) {}
39   };
40   FlagsTy Flags;
41 
42 public:
43   // Common hoisting utility doesn't support function, because their hoisting
44   // require hoisting of params as well.
45   bool getIsFunc() const { return Flags.IsFunc; }
46   bool getIsGV() const { return Flags.IsGV; }
47   void setIsFunc(bool V) { Flags.IsFunc = V; }
48   void setIsGV(bool V) { Flags.IsGV = V; }
49 
50   const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
51   void addDep(DTSortableEntry *E) { Deps.push_back(E); }
52 };
53 } // namespace SPIRV
54 
55 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
56 public:
57   // NOTE: using MapVector instead of DenseMap helps getting everything ordered
58   // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
59   // expensive removals which don't happen anyway.
60   using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
61 
62 private:
63   StorageTy Storage;
64 
65 public:
66   void add(KeyTy V, const MachineFunction *MF, Register R) {
67     if (find(V, MF).isValid())
68       return;
69 
70     Storage[V][MF] = R;
71     if (std::is_same<Function,
72                      typename std::remove_const<
73                          typename std::remove_pointer<KeyTy>::type>::type>() ||
74         std::is_same<Argument,
75                      typename std::remove_const<
76                          typename std::remove_pointer<KeyTy>::type>::type>())
77       Storage[V].setIsFunc(true);
78     if (std::is_same<GlobalVariable,
79                      typename std::remove_const<
80                          typename std::remove_pointer<KeyTy>::type>::type>())
81       Storage[V].setIsGV(true);
82   }
83 
84   Register find(KeyTy V, const MachineFunction *MF) const {
85     auto iter = Storage.find(V);
86     if (iter != Storage.end()) {
87       auto Map = iter->second;
88       auto iter2 = Map.find(MF);
89       if (iter2 != Map.end())
90         return iter2->second;
91     }
92     return Register();
93   }
94 
95   const StorageTy &getAllUses() const { return Storage; }
96 
97 private:
98   StorageTy &getAllUses() { return Storage; }
99 
100   // The friend class needs to have access to the internal storage
101   // to be able to build dependency graph, can't declare only one
102   // function a 'friend' due to the incomplete declaration at this point
103   // and mutual dependency problems.
104   friend class SPIRVGeneralDuplicatesTracker;
105 };
106 
107 template <typename T>
108 class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
109 
110 class SPIRVGeneralDuplicatesTracker {
111   SPIRVDuplicatesTracker<Type> TT;
112   SPIRVDuplicatesTracker<Constant> CT;
113   SPIRVDuplicatesTracker<GlobalVariable> GT;
114   SPIRVDuplicatesTracker<Function> FT;
115   SPIRVDuplicatesTracker<Argument> AT;
116 
117   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
118   // to use flat data structure.
119   // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
120   // but makes LITs more stable, should prefer DenseMap still due to
121   // significant perf difference.
122   using SPIRVReg2EntryTy =
123       MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
124 
125   template <typename T>
126   void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
127                          SPIRVReg2EntryTy &Reg2Entry);
128 
129 public:
130   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
131                       MachineModuleInfo *MMI);
132 
133   void add(const Type *T, const MachineFunction *MF, Register R) {
134     TT.add(T, MF, R);
135   }
136 
137   void add(const Constant *C, const MachineFunction *MF, Register R) {
138     CT.add(C, MF, R);
139   }
140 
141   void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
142     GT.add(GV, MF, R);
143   }
144 
145   void add(const Function *F, const MachineFunction *MF, Register R) {
146     FT.add(F, MF, R);
147   }
148 
149   void add(const Argument *Arg, const MachineFunction *MF, Register R) {
150     AT.add(Arg, MF, R);
151   }
152 
153   Register find(const Type *T, const MachineFunction *MF) {
154     return TT.find(const_cast<Type *>(T), MF);
155   }
156 
157   Register find(const Constant *C, const MachineFunction *MF) {
158     return CT.find(const_cast<Constant *>(C), MF);
159   }
160 
161   Register find(const GlobalVariable *GV, const MachineFunction *MF) {
162     return GT.find(const_cast<GlobalVariable *>(GV), MF);
163   }
164 
165   Register find(const Function *F, const MachineFunction *MF) {
166     return FT.find(const_cast<Function *>(F), MF);
167   }
168 
169   Register find(const Argument *Arg, const MachineFunction *MF) {
170     return AT.find(const_cast<Argument *>(Arg), MF);
171   }
172 };
173 } // namespace llvm
174 #endif