1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 #pragma once
10 #include "Compiler/CISACodeGen/CISACodeGen.h"
11 #include "Compiler/CISACodeGen/LivenessAnalysis.hpp"
12 #include "Compiler/IGCPassSupport.h"
13 #include "common/LLVMWarningsPush.hpp"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/ADT/SetVector.h"
16 #include "llvm/ADT/SparseBitVector.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/BasicBlock.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/Value.h"
21 #include "common/LLVMWarningsPop.hpp"
22 #include "Probe/Assertion.h"
23 
24 namespace IGC
25 {
26     enum RegClass : uint8_t {
27         REGISTER_CLASS_GRF = 0,
28         REGISTER_CLASS_FLAG = 1,
29 
30         REGISTER_CLASS_TOTAL = 2   // GRF only for now
31     };
32 
33     // various constants
34     enum {
35         GRF_TOTAL_NUM = 128,       // total number per thread
36         GRF_NUM_THRESHOLD = 50,    // used to see if register pressure is high
37         GRF_SIZE_IN_BYTE = 32,
38         DWORD_SIZE_IN_BYTE = 4,
39         FLAG_TOTAL_NUM = 4,
40         FLAG_TOTAL_NUM_SIMD32 = 2
41     };
42 
43     // Register Use info
44     struct RegUse
45     {
46         RegClass rClass;
47         //uint16_t nregs_simd8;
48         uint16_t nregs_simd16;
49         uint16_t uniformInBytes;
50 
RegUseIGC::RegUse51         RegUse() :
52             rClass(REGISTER_CLASS_GRF),
53             nregs_simd16(0), uniformInBytes(0)
54         {}
55 
RegUseIGC::RegUse56         RegUse(const RegUse& rhs) :
57             rClass(rhs.rClass),
58             nregs_simd16(rhs.nregs_simd16),
59             uniformInBytes(rhs.uniformInBytes)
60         {}
61 
operator +=IGC::RegUse62         RegUse& operator += (const RegUse& rhs)
63         {
64             nregs_simd16 += rhs.nregs_simd16;
65             uniformInBytes += rhs.uniformInBytes;
66             return *this;
67         }
68 
operator -=IGC::RegUse69         RegUse& operator -= (const RegUse& rhs)
70         {
71             if (nregs_simd16 > rhs.nregs_simd16)
72             {
73                 nregs_simd16 -= rhs.nregs_simd16;
74             }
75             else
76             {
77                 nregs_simd16 = 0;
78             }
79             if (uniformInBytes > rhs.uniformInBytes)
80             {
81                 uniformInBytes -= rhs.uniformInBytes;
82             }
83             else
84             {
85                 uniformInBytes = 0;
86             }
87             return *this;
88         }
89 
operator =IGC::RegUse90         RegUse& operator = (const RegUse& rhs) {
91             rClass = rhs.rClass;
92             nregs_simd16 = rhs.nregs_simd16;
93             uniformInBytes = rhs.uniformInBytes;
94             return *this;
95         }
96 
operator <IGC::RegUse97         bool operator < (const RegUse& rhs) const {
98             uint32_t n0 = (nregs_simd16 << 5) + uniformInBytes;
99             uint32_t n1 = (rhs.nregs_simd16 << 5) + rhs.uniformInBytes;
100             return n0 < n1;
101         }
102 
clearIGC::RegUse103         void clear(RegClass rc = REGISTER_CLASS_GRF) {
104             rClass = rc;
105             nregs_simd16 = 0;
106             uniformInBytes = 0;
107         }
108     };
109 
110     struct RegUsage {
111         RegUse allUses[REGISTER_CLASS_TOTAL];
112 
RegUsageIGC::RegUsage113         RegUsage() { clear(); }
114 
operator =IGC::RegUsage115         RegUsage& operator = (const RegUsage& rhs) {
116             for (int i = 0; i < REGISTER_CLASS_TOTAL; ++i) {
117                 allUses[(RegClass)i] = rhs.allUses[(RegClass)i];
118 
119             }
120             return *this;
121         }
122 
clearIGC::RegUsage123         void clear()
124         {
125             for (int i = 0; i < REGISTER_CLASS_TOTAL; ++i) {
126                 allUses[(RegClass)i].clear((RegClass)i);
127 
128             }
129         }
130     };
131 
132     class RegisterEstimator : public llvm::FunctionPass
133     {
134         friend class RegPressureTracker;
135     public:
136         typedef llvm::SmallVector<RegUse, 32>  ValueToRegUseMap;
137         typedef llvm::DenseMap<llvm::Instruction*, RegUsage>   InstToRegUsageMap;
138         typedef llvm::DenseMap<llvm::BasicBlock*, RegUsage> BBToRegUsageMap;
139 
140         // Either GRF or flag. FLAG value has LLVM type of i1
getValueRegClass(llvm::Value * V)141         static RegClass getValueRegClass(llvm::Value* V) {
142             bool isBool = V->getType()->getScalarType()->isIntegerTy(1);
143             return isBool ? REGISTER_CLASS_FLAG : REGISTER_CLASS_GRF;
144         }
145 
146         static char ID; // Pass identification, replacement for typeid
147 
RegisterEstimator()148         RegisterEstimator() :
149             llvm::FunctionPass(ID),
150             m_DL(nullptr),
151             m_LVA(nullptr),
152             m_F(nullptr),
153             m_WIA(nullptr)
154         {
155             initializeRegisterEstimatorPass(*llvm::PassRegistry::getPassRegistry());
156         }
157 
158         bool runOnFunction(llvm::Function& F) override;
159 
releaseMemory()160         void releaseMemory() override { clear(); }
161 
getPassName() const162         llvm::StringRef getPassName() const override { return "RegisterEstimator"; }
163 
getAnalysisUsage(llvm::AnalysisUsage & AU) const164         void getAnalysisUsage(llvm::AnalysisUsage& AU) const override
165         {
166             AU.addRequired<LivenessAnalysis>();
167             AU.setPreservesAll();
168         }
169 
getLivenessAnalysis() const170         LivenessAnalysis* getLivenessAnalysis() const { return m_LVA; }
171 
172         RegUse estimateNumOfRegs(llvm::Value* V) const;
173 
174         // This will compute Register pressure estimates. It also saves
175         // register pressure estimate per instruction if "doRPPerInst"
176         // is true.
177         void calculate(bool doRPEPerInst = false);
178 
179         // Once MAX register estimate of a function is computed, check
180         // if there is GRF pressure.  If the number of estimated registers
181         // is larger than a given threshold (threshold is selected based on
182         // the result of running SA PS shaders), it would think the register
183         // pressure could be high.  This is used to turn on/off register
184         // pressure tracking.
isGRFPressureLow(uint16_t simdsize=32) const185         bool isGRFPressureLow(uint16_t simdsize = 32) const
186         {
187             return isGRFPressureLow(simdsize, m_MaxRegs);
188         }
189 
190         // Return true if this function has no GRF pressure at all.
191         // A quick check to see if LivenessAnalysis is needed at all.
hasNoGRFPressure() const192         bool hasNoGRFPressure() const { return m_noGRFPressure; }
193 
194         uint32_t getNumLiveGRFAtInst(llvm::Instruction* I, uint16_t simdsize = 16);
195 
196         // Return the max number of GRF needed for a BB
getMaxLiveGRFAtBB(llvm::BasicBlock * BB,uint16_t simdsize=16)197         uint32_t getMaxLiveGRFAtBB(llvm::BasicBlock* BB, uint16_t simdsize = 16) {
198             RegUsage& ruse = m_BBMaxLiveVirtRegs[BB];
199             RegUse& grfuse = ruse.allUses[REGISTER_CLASS_GRF];
200             return getNumRegs(grfuse, simdsize);
201         }
202 
203         // Return the number of GRF needed at entry to a BB
getNumLiveInGRFAtBB(llvm::BasicBlock * BB,uint16_t simdsize=16)204         uint32_t getNumLiveInGRFAtBB(llvm::BasicBlock* BB, uint16_t simdsize = 16) {
205             RegUsage& ruse = m_BBLiveInVirtRegs[BB];
206             RegUse& grfuse = ruse.allUses[REGISTER_CLASS_GRF];
207             return getNumRegs(grfuse, simdsize);
208         }
209 
getNumValues() const210         uint32_t getNumValues() const {
211             return (uint32_t)m_ValueRegUses.capacity();
212         }
213 
214     private:
215 
216         bool m_RPEComputed;
217         const llvm::DataLayout* m_DL;
218         LivenessAnalysis* m_LVA;
219         llvm::Function* m_F;
220         WIAnalysis* m_WIA;   // optional
221 
222         // The number of live registers needed at each instruction
223         InstToRegUsageMap m_LiveVirtRegs;
224 
225         // Max live registers for each BB
226         BBToRegUsageMap m_BBMaxLiveVirtRegs;
227 
228         // The number of live-in registers for each BB
229         BBToRegUsageMap m_BBLiveInVirtRegs;
230 
231         // The max registers for this function, derived from RPE computation.
232         RegUsage m_MaxRegs;
233 
234         // Set it to true when the GRF needed is low even we assume all values are
235         // live from the entry to the end. This is used to skip RPE calucation entirely.
236         bool m_noGRFPressure;
237 
238         // Register needed for each value. Computed once for each value.
239         // Used to avoid recomputing the same value again. It is also
240         // used to check if Register Estimation for each BB/each Inst
241         // can be skipped completedly.
242         ValueToRegUseMap m_ValueRegUses;
243 
244         // Temporary use.
245         llvm::DenseMap<llvm::BasicBlock*, int> m_pBB2ID;
246 
247         void addRegUsage(RegUsage& RUsage, SBitVector& BV);
248 
getNumGRF(RegUsage & rusage,uint16_t simdsize=16)249         uint32_t getNumGRF(RegUsage& rusage, uint16_t simdsize = 16) {
250             RegUse& grfuse = rusage.allUses[REGISTER_CLASS_GRF];
251             return getNumRegs(grfuse, simdsize);
252         }
253 
254         int getNUsesInBB(llvm::Value* V, llvm::BasicBlock* BB);
255 
getLiveinRegsAtBB(RegUsage & RUsage,llvm::BasicBlock * BB)256         void getLiveinRegsAtBB(RegUsage& RUsage, llvm::BasicBlock* BB)
257         {
258             RUsage = m_BBLiveInVirtRegs[BB];
259             return;
260         }
getMaxLiveinRegsAtBB(RegUsage & RUsage,llvm::BasicBlock * BB)261         void getMaxLiveinRegsAtBB(RegUsage& RUsage, llvm::BasicBlock* BB)
262         {
263             RUsage = m_BBMaxLiveVirtRegs[BB];
264             return;
265         }
266 
isGRFPressureLow(uint16_t simdsize,const RegUsage & Regs) const267         bool isGRFPressureLow(uint16_t simdsize, const RegUsage& Regs) const
268         {
269             const RegUse& ruse = Regs.allUses[REGISTER_CLASS_GRF];
270             return (getNumRegs(ruse, simdsize) < (uint32_t)GRF_NUM_THRESHOLD);
271         }
272 
getRegUse(uint32_t valId)273         const RegUse* getRegUse(uint32_t valId)
274         {
275             return &(m_ValueRegUses[valId]);
276         }
277 
getRegUse(llvm::Value * V)278         const RegUse* getRegUse(llvm::Value* V)
279         {
280             ValueToIntMap::iterator II = m_LVA->ValueIds.find(V);
281             if (II == m_LVA->ValueIds.end())
282             {
283                 IGC_ASSERT_MESSAGE(0, "Value is not part of LivenessAnalysis");
284                 return nullptr;
285             }
286             uint32_t valId = II->second;
287             return getRegUse(valId);
288         }
289 
getNumRegs(const RegUse & RUse,uint16_t simdsize) const290         uint32_t getNumRegs(const RegUse& RUse, uint16_t simdsize) const
291         {
292             uint32_t uniformRegs =
293                 (RUse.uniformInBytes + GRF_SIZE_IN_BYTE - 1) / GRF_SIZE_IN_BYTE;
294             switch (simdsize) {
295             case 16:
296                 return RUse.nregs_simd16 + uniformRegs;
297             case 32:
298                 return 2 * RUse.nregs_simd16 + uniformRegs;
299             default:
300                 return (RUse.nregs_simd16 + 1) / 2 + uniformRegs;
301             }
302         }
303 
clear()304         void clear()
305         {
306             m_LiveVirtRegs.clear();
307             m_MaxRegs.clear();
308             m_BBMaxLiveVirtRegs.clear();
309             m_BBLiveInVirtRegs.clear();
310         }
311 
312     public:
313         /// print - Convert to human readable form
314         void print(llvm::raw_ostream& OS, int dumpLevel);
315         void print(llvm::raw_ostream& OS, llvm::BasicBlock* BB, int dumpLevl);
316 
317 #if defined( _DEBUG )
318         /// dump - Dump RPE info to dbgs(), used in debugger.
319         void dump();
320         void dump(int dumpLevel);
321         void dump(llvm::BasicBlock* BB);
322         void dump(llvm::BasicBlock* BB, int dumpLevel);
323 #endif
324     };
325 
    // This is used to track the register pressure for a given BB or a
    // stream of instructions. It requires RegisterEstimator to
    // provide the basic register estimation functionality and liveness
    // information.
    // To use it, an object of RegPressureTracker must be created first;
    // then follow these steps:
    //   1.  the object is initialized by init(), which sets up initial
    //       live-in sets, with an empty initial instruction stream.
    //   2.  advance() adds one instruction at a time, sequentially, to the
    //       head of the instruction stream.
    //   3.  getCurrNumGRF() returns the number of GRFs needed at the head
    //       of the stream.
    class RegPressureTracker {
    public:
        // RPE supplies per-value register estimates and liveness.
        // NOTE(review): ownership/lifetime of RPE is not visible in this
        // header; presumably the pass outlives the tracker — confirm.
        RegPressureTracker(RegisterEstimator* RPE);

        // init() will set up the initial live-in with BB's live-in.
        // If "doMaxRegInBB" is true, it will use the max live-in in this
        // BB instead of BB's live-in, which is more accurate, but also
        // more expensive.
        void init(llvm::BasicBlock* BB, bool doMaxRegInBB = false);
        // Add instruction I to the head of the tracked stream (see step 2 above).
        void advance(llvm::Instruction* I);

        // GRFs needed at the head of the current stream for the given SIMD
        // size, computed from m_RUsage by RegisterEstimator::getNumGRF.
        uint32_t getCurrNumGRF(uint16_t simdsize = 16)
        {
            return m_pRPE->getNumGRF(m_RUsage, simdsize);
        }

        // Returns m_TrackRegPressure; where it is set is not visible here.
        bool isTrackingRegPressure() const { return m_TrackRegPressure; }

    private:
        // Presumably the BB passed to init() — confirm in the .cpp.
        llvm::BasicBlock* m_BB;
        // Estimator passed to the constructor.
        RegisterEstimator* m_pRPE;
        // Presumably the live-out value set of m_BB — confirm in the .cpp.
        SBitVector  m_LiveOutSet;
        bool m_TrackRegPressure;

        // register usage at the head of the current instruction stream.
        RegUsage m_RUsage;

        // For a value that will be dead at the end of BB, keep the
        // number of uses within the stream so we can have more
        // accurate live information.
        llvm::DenseMap<llvm::Value*, int> m_DeadValueNumUses;
    };
370 }
371