1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2020-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 #include "Compiler/Optimizer/IntDivRemCombine.hpp"
10 #include "Compiler/IGCPassSupport.h"
11 
12 #include "common/LLVMWarningsPush.hpp"
13 #include "common/igc_regkeys.hpp"
14 #include <llvm/IR/Constants.h>
15 #include <llvm/IR/Operator.h>
16 #include <llvm/IR/Dominators.h>
17 #include <llvm/IR/Function.h>
18 #include <llvm/IR/IRBuilder.h>
19 #include <llvm/IR/Instructions.h>
20 #include <llvm/IR/InstIterator.h>
21 #include <llvm/IR/PatternMatch.h>
22 #include <llvm/Pass.h>
23 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
24 #include "common/LLVMWarningsPop.hpp"
25 
26 #include <cmath>
27 
28 using namespace llvm;
29 
// This replaces the remainder of div/rem pairs with the same operands with
// a multiply and subtract.  This presumes that integer division/remainder
// are significantly more expensive than multiplication and subtraction.
//  Given n/d:   n = q*d + r ==> r = n - q*d
//
//  I.e.
//   %q = sdiv i32 %n, %d
//   ... stuff
//   %r1 = srem i32 %n, %d
//   ...
//   %r2 = srem i32 %n, %d
// ==>
//
//   %q   = sdiv i32 %n, %d
//   ... stuff
//   %t1  = mul i32 %q, %d
//   %r1  = sub i32 %n, %t1
//   ...
//   %t2  = mul i32 %q, %d
//   %r2  = sub i32 %n, %t2
//
// NOTE: this also handles the case where rem precedes div; the div is
// first moved up to just before the rem.
// I.e.
//   %r = srem i32 %n, %d
//   ... stuff
//   %q = sdiv i32 %n, %d
//
// NOTE: we constrain this to a basic block at the moment
//
57 struct IntDivRemCombine : public FunctionPass
58 {
59     static char ID;
60 
61     int options;
62 
63     IntDivRemCombine();
64 
65     /// @brief  Provides name of pass
getPassNameIntDivRemCombine66     virtual StringRef getPassName() const override {
67         return "IntDivRemCombine";
68     }
69 
getAnalysisUsageIntDivRemCombine70     void getAnalysisUsage(llvm::AnalysisUsage& AU) const override {
71         AU.setPreservesCFG();
72         //    AU.addRequired<MetaDataUtilsWrapper>();
73     }
74 
runOnFunctionIntDivRemCombine75     virtual bool runOnFunction(Function& F) override {
76         bool changed = false;
77 
78         for (Function::iterator bi = F.begin(), bie = F.end();
79             bi != bie; ++bi)
80         {
81             SmallVector<BinaryOperator*,4> divs;
82             BasicBlock &BB = *bi;
83             bool blockChanged;
84             do {
85                 blockChanged = false;
86                 for (BasicBlock::iterator ii = BB.begin(), ie = BB.end();
87                     ii != ie && !blockChanged; ii++)
88                 {
89                     Instruction *I = &*ii;
90                     switch (I->getOpcode()) {
91                     case Instruction::SDiv:
92                     case Instruction::UDiv:
93                         if (shouldSimplify(cast<BinaryOperator>(I))) {
94                             blockChanged |= replaceAllRemsInBlock(ii, ie);
95                             break;
96                         }
97                     case Instruction::SRem:
98                     case Instruction::URem:
99                         if (shouldSimplify(cast<BinaryOperator>(I))) {
100                             blockChanged |= hoistMatchingDivAbove(ii, ie);
101                             break;
102                         }
103                     }
104                 }
105                 changed |= blockChanged;
106             } while(blockChanged);
107         }
108 
109 #if 0
110         struct Merge {
111             BinaryOperator *anchor = nullptr;
112             //
113             BinaryOperator *div = nullptr;
114             SmallVector<BinaryOperator*,16> rems;
115         };
116         // for global we could
117         DominatorTree DT;
118         DT.recalculate(F);
119 #endif
120         return changed;
121     } // runOnFunction
122 
shouldSimplifyIntDivRemCombine123     bool shouldSimplify(BinaryOperator *b) const {
124         if (ConstantInt *divisor = dyn_cast<ConstantInt>(b->getOperand(1))) {
125             return !divisor->getValue().isPowerOf2();
126         } else {
127             return true;
128         }
129     }
130 
131     // replace all successive rems with divs
replaceAllRemsInBlockIntDivRemCombine132     bool replaceAllRemsInBlock(
133         BasicBlock::iterator ii, BasicBlock::iterator ie) const
134     {
135         bool changed = false;
136 
137         SmallVector<Instruction*,4> deleteMe;
138 
139         BinaryOperator *Q = cast<BinaryOperator>(&*ii);
140         auto targetOp = Q->getOpcode() == Instruction::SDiv ?
141             Instruction::SRem : Instruction::URem;
142         for (; ii != ie; ii++) {
143             Instruction *R = &*ii;
144             // TODO: can we use match(...) here?
145             // (c.f. PatternMatch.hpp)
146             if (R->getOpcode() == targetOp &&
147                 Q->getOperand(0) == R->getOperand(0) &&
148                 Q->getOperand(1) == R->getOperand(1))
149             {
150                 emulateRem(Q, R);
151                 //
152                 deleteMe.push_back(R);
153                 //
154                 changed = true;
155             }
156         }
157 
158         // delete all the replaced instructions
159         for (auto *R : deleteMe) {
160             R->eraseFromParent();
161         }
162 
163         return changed;
164     }
165 
166     // find a matching divide to place above us
167     // (we'll restart the block)
168     //
169     // E.g. given
170     //   %R = srem %N, %D
171     //   ...
172     //
173     // scan for a
174     //   %Q = sdiv %N, %D
175     // down below and move it above the remainder op
hoistMatchingDivAboveIntDivRemCombine176     bool hoistMatchingDivAbove(
177         BasicBlock::iterator ii, BasicBlock::iterator ie) const
178     {
179         BinaryOperator *R = cast<BinaryOperator>(&*ii);
180         auto targetOp = R->getOpcode() == Instruction::SRem ?
181             Instruction::SDiv : Instruction::UDiv;
182         for (; ii != ie; ii++) {
183             Instruction *Q = &*ii;
184             // TODO: can we use match(...) here?
185             // (c.f. PatternMatch.hpp)
186             if (Q->getOpcode() == targetOp &&
187                 Q->getOperand(0) == R->getOperand(0) &&
188                 Q->getOperand(1) == R->getOperand(1))
189             {
190                 Q->removeFromParent();
191                 Q->insertBefore(R);
192 
193                 // since we're here, we'll fix this one
194                 emulateRem(Q, R);
195                 R->eraseFromParent();
196 
197                 // return to restart iteration over the block
198                 return true;
199             }
200         }
201 
202         return false;
203     }
204 
205     // assumes Q precedes R
emulateRemIntDivRemCombine206     void emulateRem(Instruction *Q, Instruction *R) const {
207         IRBuilder<> B(R);
208         auto *Q_D = B.CreateMul(Q, Q->getOperand(1));
209         auto *R1 = B.CreateSub(Q->getOperand(0), Q_D, R->getName());
210         //
211         R->replaceAllUsesWith(R1);
212         R->dropAllReferences();
213     }
214 }; // class IntDivRemCombine
215 
216 
217 char IntDivRemCombine::ID = 0;
218 
219 // Register pass to igc-opt
220 #define PASS_FLAG "igc-divrem-combine"
221 #define PASS_DESCRIPTION "Integer Division/Remainder Combine"
222 #define PASS_CFG_ONLY false
223 #define PASS_ANALYSIS false
IGC_INITIALIZE_PASS_BEGIN(IntDivRemCombine,PASS_FLAG,PASS_DESCRIPTION,PASS_CFG_ONLY,PASS_ANALYSIS)224 IGC_INITIALIZE_PASS_BEGIN(IntDivRemCombine,
225     PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
226 IGC_INITIALIZE_PASS_END(IntDivRemCombine,
227     PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
228 
229 IntDivRemCombine::IntDivRemCombine()
230     : FunctionPass(ID)
231     , options((int)(IGC_GET_FLAG_VALUE(EnableIntDivRemCombine)))
232 {
233     initializeIntDivRemCombinePass(*PassRegistry::getPassRegistry());
234 }
235 
createIntDivRemCombinePass()236 llvm::FunctionPass* IGC::createIntDivRemCombinePass()
237 {
238     return new IntDivRemCombine();
239 }
240 
241 
242