1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 #include "Compiler/CISACodeGen/PromoteConstantStructs.hpp"
10 #include "common/LLVMWarningsPush.hpp"
11 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/Analysis/PtrUseVisitor.h"
15 #include "llvm/Analysis/CFG.h"
16 #include "common/LLVMWarningsPop.hpp"
17 
18 using namespace llvm;
19 using namespace IGC;
20 
21 namespace {
22 
23     // This class intends to promote constants that are saved into large structs
24     // e.g. specialization constants
25     //
26     // This task is not done by SROA because this structs can be used(loaded) with
27     // non-constant indices if they have an array inside
28     //
29     // It is not done by GVN either because GVN uses MemoryDependenceAnalysis which
30     // does not look deeper than 100 instructions to get a memdep which does not
31     // work for large structures. This parameter can not be tweaked inside compiler
32     //
33     // Restrictions to this implementation to keep it simple and save compile time:
34     // -- The structure does not escape
35     // -- We promote only constant values of the same type and store->load only
36     //
37     // How it works:
38     // -- iterate over allocas to check stores and collect potential loads
39     // -- iterate over loads and use MemDep to find defining store
40 
41     class PromoteConstantStructs : public FunctionPass {
42 
43     public:
44         static char ID;
45 
PromoteConstantStructs()46         PromoteConstantStructs() : FunctionPass(ID) {
47             initializePromoteConstantStructsPass(*PassRegistry::getPassRegistry());
48         }
49 
50         bool runOnFunction(Function& F) override;
51 
52     private:
53 
54         const unsigned int InstructionsLimit = 1000;
55 
56         MemoryDependenceResults *MD = nullptr;
57         DominatorTree *DT = nullptr;
58         LoopInfo* LPI = nullptr;
59 
getAnalysisUsage(AnalysisUsage & AU) const60         void getAnalysisUsage(AnalysisUsage &AU) const override {
61             AU.addRequired<DominatorTreeWrapperPass>();
62             AU.addRequired<MemoryDependenceWrapperPass>();
63             AU.addRequired<LoopInfoWrapperPass>();
64             AU.setPreservesCFG();
65             AU.addPreserved<DominatorTreeWrapperPass>();
66         }
67 
68         bool processAlloca(AllocaInst &AI);
69 
70         bool processLoad(LoadInst *LI, SetVector<BasicBlock*>& StoreBBs);
71 
72     };
73 
74     char PromoteConstantStructs::ID = 0;
75 
76 } // End anonymous namespace
77 
createPromoteConstantStructsPass()78 llvm::FunctionPass* createPromoteConstantStructsPass() {
79     return new PromoteConstantStructs();
80 }
81 
// Registration of PromoteConstantStructs with the PassRegistry (the pass
// constructor calls initializePromoteConstantStructsPass). The flag string is
// the pass's command-line name; the listed dependencies mirror the analyses
// required in getAnalysisUsage (MemDep, dominator tree, loop info).
#define PASS_FLAG     "igc-promote-constant-structs"
#define PASS_DESC     "Promote constant structs"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
IGC_INITIALIZE_PASS_BEGIN(PromoteConstantStructs, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY, PASS_ANALYSIS)
INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
IGC_INITIALIZE_PASS_END(PromoteConstantStructs, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY, PASS_ANALYSIS)
91 
92 // This class visits all alloca uses to check that
93 // -- it does not escape
94 // and collect all loads with constant offsets from alloca
95 
96 class AllocaChecker : public PtrUseVisitor<AllocaChecker> {
97     friend class PtrUseVisitor<AllocaChecker>;
98     friend class InstVisitor<AllocaChecker>;
99 
100 public:
AllocaChecker(const DataLayout & DL)101     AllocaChecker(const DataLayout &DL)
102         : PtrUseVisitor<AllocaChecker>(DL), StoreBBs() {}
103 
getPotentialLoads()104     SmallVector<LoadInst*, 8>& getPotentialLoads() {
105         return Loads;
106     }
107 
getStoreBBs()108     SetVector<BasicBlock*>& getStoreBBs() {
109         return StoreBBs;
110     }
111 
112 private:
113     SmallVector<LoadInst*, 8> Loads;
114 
115     SetVector<BasicBlock*> StoreBBs;
116 
visitMemIntrinsic(MemIntrinsic & I)117     void visitMemIntrinsic(MemIntrinsic& I) {
118         StoreBBs.insert(I.getParent());
119     }
120 
visitIntrinsicInst(IntrinsicInst & II)121     void visitIntrinsicInst(IntrinsicInst& II) {
122         auto IID = II.getIntrinsicID();
123         if (IID == Intrinsic::lifetime_start || IID == Intrinsic::lifetime_end) {
124             return;
125         }
126 
127         if (!II.onlyReadsMemory()) {
128             StoreBBs.insert(II.getParent());
129         }
130     }
131 
visitStoreInst(StoreInst & SI)132     void visitStoreInst(StoreInst &SI) {
133         StoreBBs.insert(SI.getParent());
134     }
135 
visitLoadInst(LoadInst & LI)136     void visitLoadInst(LoadInst &LI) {
137         if (LI.isUnordered() && IsOffsetKnown) {
138             Loads.push_back(&LI);
139         }
140     }
141 };
142 
runOnFunction(Function & F)143 bool PromoteConstantStructs::runOnFunction(Function &F) {
144     DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
145     MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
146     LPI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
147 
148     bool Changed = false;
149     BasicBlock &EntryBB = F.getEntryBlock();
150     for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());
151         I != E; ++I) {
152         if (AllocaInst *AI = dyn_cast<AllocaInst>(I))
153             Changed |= processAlloca(*AI);
154     }
155 
156     return Changed;
157 }
158 
processAlloca(AllocaInst & AI)159 bool PromoteConstantStructs::processAlloca(AllocaInst &AI) {
160     // we do not process single or array allocas
161     if (!AI.getAllocatedType()->isStructTy())
162         return false;
163 
164     AllocaChecker AC(AI.getModule()->getDataLayout());
165     AllocaChecker::PtrInfo PtrI = AC.visitPtr(AI);
166     if (PtrI.isEscaped() || PtrI.isAborted())
167         return false;
168 
169     // if we don't have any stores, nothing to do
170     if (AC.getStoreBBs().empty())
171         return false;
172 
173     bool Changed = false;
174     bool LocalChanged = true;
175     while (LocalChanged) {
176         LocalChanged = false;
177 
178         auto LII = AC.getPotentialLoads().begin();
179         while (LII != AC.getPotentialLoads().end()) {
180             if (processLoad(*LII, AC.getStoreBBs())) {
181                 LII = AC.getPotentialLoads().erase(LII);
182                 LocalChanged = true;
183             } else {
184                 ++LII;
185             }
186         }
187         Changed |= LocalChanged;
188     }
189 
190     return Changed;
191 }
192 
processLoad(LoadInst * LI,SetVector<BasicBlock * > & StoreBBs)193 bool PromoteConstantStructs::processLoad(LoadInst *LI, SetVector<BasicBlock*>& StoreBBs) {
194     unsigned limit = InstructionsLimit;
195     StoreInst* SI = nullptr;
196 
197     auto ML = MemoryLocation::get(LI);
198     for (auto StBB : StoreBBs) {
199         SmallVector<BasicBlock*, 32> Worklist;
200         Worklist.push_back(StBB);
201 
202         if (!isPotentiallyReachableFromMany(Worklist, LI->getParent(), nullptr, DT, LPI))
203             continue;
204 
205         Instruction* InstPt = StBB->getTerminator();
206         if (StBB == LI->getParent())
207             InstPt = LI;
208 
209         MemDepResult Dep = MD->getPointerDependencyFrom(ML, true,
210             InstPt->getIterator(), StBB, LI, &limit);
211 
212         if (Dep.isDef()) {
213             // skip if more than one def
214             if (SI)
215                 return false;
216 
217             // we search only for stores
218             SI = dyn_cast<StoreInst>(Dep.getInst());
219             if (!SI)
220                 return false;
221 
222             if (!DT->dominates(SI, LI))
223                 return false;
224         } else if (!Dep.isNonLocal()) {
225             return false;
226         }
227         // else no memdep in this BB, can move on
228     }
229 
230     if (!SI)
231         return false;
232 
233     // we search only for constants being stored
234     Constant *SC = dyn_cast<Constant>(SI->getValueOperand());
235     if (!SC)
236         return false;
237 
238     // no type casts
239     if (SC->getType() != LI->getType())
240         return false;
241 
242     LI->replaceAllUsesWith(SC);
243     LI->eraseFromParent();
244 
245     return true;
246 }
247