1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2018-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 /*========================== begin_copyright_notice ============================
10 
11 This file is distributed under the University of Illinois Open Source License.
12 See LICENSE.TXT for details.
13 
14 ============================= end_copyright_notice ===========================*/
15 
16 // This file implements the visit functions for mul, fmul, sdiv, udiv, fdiv,
17 // srem, urem, frem.
18 
19 #include "common/LLVMWarningsPush.hpp"
20 #include "InstCombineInternal.h"
21 #include "llvm/ADT/APFloat.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Analysis/InstructionSimplify.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constant.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/InstrTypes.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/Intrinsics.h"
33 #include "llvm/IR/Operator.h"
34 #include "llvm/IR/PatternMatch.h"
35 #include "llvm/IR/Type.h"
36 #include "llvm/IR/Value.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/KnownBits.h"
39 #include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
40 #include "llvm/Transforms/Utils/BuildLibCalls.h"
41 #include "common/LLVMWarningsPop.hpp"
42 #include <cstddef>
43 #include <cstdint>
44 #include <utility>
45 #include "Probe/Assertion.h"
46 
47 using namespace llvm;
48 using namespace PatternMatch;
49 using namespace IGCombiner;
50 
51 #define DEBUG_TYPE "instcombine"
52 
53 /// The specific integer value is used in a context where it is known to be
54 /// non-zero.  If this allows us to simplify the computation, do so and return
55 /// the new operand, otherwise return null.
simplifyValueKnownNonZero(Value * V,InstCombiner & IC,Instruction & CxtI)56 static Value* simplifyValueKnownNonZero(Value* V, InstCombiner& IC,
57     Instruction& CxtI) {
58     // If V has multiple uses, then we would have to do more analysis to determine
59     // if this is safe.  For example, the use could be in dynamically unreached
60     // code.
61     if (!V->hasOneUse()) return nullptr;
62 
63     bool MadeChange = false;
64 
65     // ((1 << A) >>u B) --> (1 << (A-B))
66     // Because V cannot be zero, we know that B is less than A.
67     Value* A = nullptr, * B = nullptr, * One = nullptr;
68     if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) &&
69         match(One, m_One())) {
70         A = IC.Builder.CreateSub(A, B);
71         return IC.Builder.CreateShl(One, A);
72     }
73 
74     // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it
75     // inexact.  Similarly for <<.
76     BinaryOperator* I = dyn_cast<BinaryOperator>(V);
77     if (I && I->isLogicalShift() &&
78         IC.isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0, &CxtI)) {
79         // We know that this is an exact/nuw shift and that the input is a
80         // non-zero context as well.
81         if (Value * V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) {
82             I->setOperand(0, V2);
83             MadeChange = true;
84         }
85 
86         if (I->getOpcode() == Instruction::LShr && !I->isExact()) {
87             I->setIsExact();
88             MadeChange = true;
89         }
90 
91         if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) {
92             I->setHasNoUnsignedWrap();
93             MadeChange = true;
94         }
95     }
96 
97     // TODO: Lots more we could do here:
98     //    If V is a phi node, we can call this on each of its operands.
99     //    "select cond, X, 0" can simplify to "X".
100 
101     return MadeChange ? V : nullptr;
102 }
103 
104 /// A helper routine of InstCombiner::visitMul().
105 ///
106 /// If C is a scalar/vector of known powers of 2, then this function returns
107 /// a new scalar/vector obtained from logBase2 of C.
108 /// Return a null pointer otherwise.
getLogBase2(Type * Ty,Constant * C)109 static Constant* getLogBase2(Type* Ty, Constant* C) {
110     const APInt* IVal;
111     if (match(C, m_APInt(IVal)) && IVal->isPowerOf2())
112         return ConstantInt::get(Ty, IVal->logBase2());
113 
114     if (!Ty->isVectorTy())
115         return nullptr;
116 
117     SmallVector<Constant*, 4> Elts;
118     for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) {
119         Constant* Elt = C->getAggregateElement(I);
120         if (!Elt)
121             return nullptr;
122         if (isa<UndefValue>(Elt)) {
123             Elts.push_back(UndefValue::get(Ty->getScalarType()));
124             continue;
125         }
126         if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2())
127             return nullptr;
128         Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2()));
129     }
130 
131     return ConstantVector::get(Elts);
132 }
133 
visitMul(BinaryOperator & I)134 Instruction* InstCombiner::visitMul(BinaryOperator& I) {
135     if (Value * V = SimplifyMulInst(I.getOperand(0), I.getOperand(1),
136         SQ.getWithInstruction(&I)))
137         return replaceInstUsesWith(I, V);
138 
139     if (SimplifyAssociativeOrCommutative(I))
140         return &I;
141 
142     if (Instruction * X = foldShuffledBinop(I))
143         return X;
144 
145     if (Value * V = SimplifyUsingDistributiveLaws(I))
146         return replaceInstUsesWith(I, V);
147 
148     // X * -1 == 0 - X
149     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
150     if (match(Op1, m_AllOnes())) {
151         BinaryOperator* BO = BinaryOperator::CreateNeg(Op0, I.getName());
152         if (I.hasNoSignedWrap())
153             BO->setHasNoSignedWrap();
154         return BO;
155     }
156 
157     // Also allow combining multiply instructions on vectors.
158     {
159         Value* NewOp;
160         Constant* C1, * C2;
161         const APInt* IVal;
162         if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)),
163             m_Constant(C1))) &&
164             match(C1, m_APInt(IVal))) {
165             // ((X << C2)*C1) == (X * (C1 << C2))
166             Constant* Shl = ConstantExpr::getShl(C1, C2);
167             BinaryOperator* Mul = cast<BinaryOperator>(I.getOperand(0));
168             BinaryOperator* BO = BinaryOperator::CreateMul(NewOp, Shl);
169             if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap())
170                 BO->setHasNoUnsignedWrap();
171             if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() &&
172                 Shl->isNotMinSignedValue())
173                 BO->setHasNoSignedWrap();
174             return BO;
175         }
176 
177         if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) {
178             // Replace X*(2^C) with X << C, where C is either a scalar or a vector.
179             if (Constant * NewCst = getLogBase2(NewOp->getType(), C1)) {
180                 unsigned Width = NewCst->getType()->getPrimitiveSizeInBits();
181                 BinaryOperator* Shl = BinaryOperator::CreateShl(NewOp, NewCst);
182 
183                 if (I.hasNoUnsignedWrap())
184                     Shl->setHasNoUnsignedWrap();
185                 if (I.hasNoSignedWrap()) {
186                     const APInt* V;
187                     if (match(NewCst, m_APInt(V)) && *V != Width - 1)
188                         Shl->setHasNoSignedWrap();
189                 }
190 
191                 return Shl;
192             }
193         }
194     }
195 
196     if (ConstantInt * CI = dyn_cast<ConstantInt>(Op1)) {
197         // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n
198         // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n
199         // The "* (2**n)" thus becomes a potential shifting opportunity.
200         {
201             const APInt& Val = CI->getValue();
202             const APInt& PosVal = Val.abs();
203             if (Val.isNegative() && PosVal.isPowerOf2()) {
204                 Value* X = nullptr, * Y = nullptr;
205                 if (Op0->hasOneUse()) {
206                     ConstantInt* C1;
207                     Value* Sub = nullptr;
208                     if (match(Op0, m_Sub(m_Value(Y), m_Value(X))))
209                         Sub = Builder.CreateSub(X, Y, "suba");
210                     else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1))))
211                         Sub = Builder.CreateSub(Builder.CreateNeg(C1), Y, "subc");
212                     if (Sub)
213                         return
214                         BinaryOperator::CreateMul(Sub,
215                             ConstantInt::get(Y->getType(), PosVal));
216                 }
217             }
218         }
219     }
220 
221     if (Instruction * FoldedMul = foldBinOpIntoSelectOrPhi(I))
222         return FoldedMul;
223 
224     // Simplify mul instructions with a constant RHS.
225     if (isa<Constant>(Op1)) {
226         // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
227         Value* X;
228         Constant* C1;
229         if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) {
230             Value* Mul = Builder.CreateMul(C1, Op1);
231             // Only go forward with the transform if C1*CI simplifies to a tidier
232             // constant.
233             if (!match(Mul, m_Mul(m_Value(), m_Value())))
234                 return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul);
235         }
236     }
237 
238     // -X * C --> X * -C
239     Value* X, * Y;
240     Constant* Op1C;
241     if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Constant(Op1C)))
242         return BinaryOperator::CreateMul(X, ConstantExpr::getNeg(Op1C));
243 
244     // -X * -Y --> X * Y
245     if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) {
246         auto* NewMul = BinaryOperator::CreateMul(X, Y);
247         if (I.hasNoSignedWrap() &&
248             cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() &&
249             cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap())
250             NewMul->setHasNoSignedWrap();
251         return NewMul;
252     }
253 
254     // (X / Y) *  Y = X - (X % Y)
255     // (X / Y) * -Y = (X % Y) - X
256     {
257         Value* Y = Op1;
258         BinaryOperator* Div = dyn_cast<BinaryOperator>(Op0);
259         if (!Div || (Div->getOpcode() != Instruction::UDiv &&
260             Div->getOpcode() != Instruction::SDiv)) {
261             Y = Op0;
262             Div = dyn_cast<BinaryOperator>(Op1);
263         }
264         Value* Neg = dyn_castNegVal(Y);
265         if (Div && Div->hasOneUse() &&
266             (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) &&
267             (Div->getOpcode() == Instruction::UDiv ||
268                 Div->getOpcode() == Instruction::SDiv)) {
269             Value* X = Div->getOperand(0), * DivOp1 = Div->getOperand(1);
270 
271             // If the division is exact, X % Y is zero, so we end up with X or -X.
272             if (Div->isExact()) {
273                 if (DivOp1 == Y)
274                     return replaceInstUsesWith(I, X);
275                 return BinaryOperator::CreateNeg(X);
276             }
277 
278             auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem
279                 : Instruction::SRem;
280             Value* Rem = Builder.CreateBinOp(RemOpc, X, DivOp1);
281             if (DivOp1 == Y)
282                 return BinaryOperator::CreateSub(X, Rem);
283             return BinaryOperator::CreateSub(Rem, X);
284         }
285     }
286 
287     /// i1 mul -> i1 and.
288     if (I.getType()->isIntOrIntVectorTy(1))
289         return BinaryOperator::CreateAnd(Op0, Op1);
290 
291     // X*(1 << Y) --> X << Y
292     // (1 << Y)*X --> X << Y
293     {
294         Value* Y;
295         BinaryOperator* BO = nullptr;
296         bool ShlNSW = false;
297         if (match(Op0, m_Shl(m_One(), m_Value(Y)))) {
298             BO = BinaryOperator::CreateShl(Op1, Y);
299             ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap();
300         }
301         else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) {
302             BO = BinaryOperator::CreateShl(Op0, Y);
303             ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap();
304         }
305         if (BO) {
306             if (I.hasNoUnsignedWrap())
307                 BO->setHasNoUnsignedWrap();
308             if (I.hasNoSignedWrap() && ShlNSW)
309                 BO->setHasNoSignedWrap();
310             return BO;
311         }
312     }
313 
314     // (bool X) * Y --> X ? Y : 0
315     // Y * (bool X) --> X ? Y : 0
316     if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
317         return SelectInst::Create(X, Op1, ConstantInt::get(I.getType(), 0));
318     if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
319         return SelectInst::Create(X, Op0, ConstantInt::get(I.getType(), 0));
320 
321     // (lshr X, 31) * Y --> (ashr X, 31) & Y
322     // Y * (lshr X, 31) --> (ashr X, 31) & Y
323     // TODO: We are not checking one-use because the elimination of the multiply
324     //       is better for analysis?
325     // TODO: Should we canonicalize to '(X < 0) ? Y : 0' instead? That would be
326     //       more similar to what we're doing above.
327     const APInt* C;
328     if (match(Op0, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1)
329         return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op1);
330     if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1)
331         return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0);
332 
333     // Check for (mul (sext x), y), see if we can merge this into an
334     // integer mul followed by a sext.
335     if (SExtInst * Op0Conv = dyn_cast<SExtInst>(Op0)) {
336         // (mul (sext x), cst) --> (sext (mul x, cst'))
337         if (ConstantInt * Op1C = dyn_cast<ConstantInt>(Op1)) {
338             if (Op0Conv->hasOneUse()) {
339                 Constant* CI =
340                     ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType());
341                 if (ConstantExpr::getSExt(CI, I.getType()) == Op1C &&
342                     willNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) {
343                     // Insert the new, smaller mul.
344                     Value* NewMul =
345                         Builder.CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv");
346                     return new SExtInst(NewMul, I.getType());
347                 }
348             }
349         }
350 
351         // (mul (sext x), (sext y)) --> (sext (mul int x, y))
352         if (SExtInst * Op1Conv = dyn_cast<SExtInst>(Op1)) {
353             // Only do this if x/y have the same type, if at last one of them has a
354             // single use (so we don't increase the number of sexts), and if the
355             // integer mul will not overflow.
356             if (Op0Conv->getOperand(0)->getType() ==
357                 Op1Conv->getOperand(0)->getType() &&
358                 (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) &&
359                 willNotOverflowSignedMul(Op0Conv->getOperand(0),
360                     Op1Conv->getOperand(0), I)) {
361                 // Insert the new integer mul.
362                 Value* NewMul = Builder.CreateNSWMul(
363                     Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv");
364                 return new SExtInst(NewMul, I.getType());
365             }
366         }
367     }
368 
369     // Check for (mul (zext x), y), see if we can merge this into an
370     // integer mul followed by a zext.
371     if (auto * Op0Conv = dyn_cast<ZExtInst>(Op0)) {
372         // (mul (zext x), cst) --> (zext (mul x, cst'))
373         if (ConstantInt * Op1C = dyn_cast<ConstantInt>(Op1)) {
374             if (Op0Conv->hasOneUse()) {
375                 Constant* CI =
376                     ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType());
377                 if (ConstantExpr::getZExt(CI, I.getType()) == Op1C &&
378                     willNotOverflowUnsignedMul(Op0Conv->getOperand(0), CI, I)) {
379                     // Insert the new, smaller mul.
380                     Value* NewMul =
381                         Builder.CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv");
382                     return new ZExtInst(NewMul, I.getType());
383                 }
384             }
385         }
386 
387         // (mul (zext x), (zext y)) --> (zext (mul int x, y))
388         if (auto * Op1Conv = dyn_cast<ZExtInst>(Op1)) {
389             // Only do this if x/y have the same type, if at last one of them has a
390             // single use (so we don't increase the number of zexts), and if the
391             // integer mul will not overflow.
392             if (Op0Conv->getOperand(0)->getType() ==
393                 Op1Conv->getOperand(0)->getType() &&
394                 (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) &&
395                 willNotOverflowUnsignedMul(Op0Conv->getOperand(0),
396                     Op1Conv->getOperand(0), I)) {
397                 // Insert the new integer mul.
398                 Value* NewMul = Builder.CreateNUWMul(
399                     Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv");
400                 return new ZExtInst(NewMul, I.getType());
401             }
402         }
403     }
404 
405     bool Changed = false;
406     if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) {
407         Changed = true;
408         I.setHasNoSignedWrap(true);
409     }
410 
411     if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedMul(Op0, Op1, I)) {
412         Changed = true;
413         I.setHasNoUnsignedWrap(true);
414     }
415 
416     return Changed ? &I : nullptr;
417 }
418 
visitFMul(BinaryOperator & I)419 Instruction* InstCombiner::visitFMul(BinaryOperator& I) {
420     if (Value * V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1),
421         I.getFastMathFlags(),
422         SQ.getWithInstruction(&I)))
423         return replaceInstUsesWith(I, V);
424 
425     if (SimplifyAssociativeOrCommutative(I))
426         return &I;
427 
428     if (Instruction * X = foldShuffledBinop(I))
429         return X;
430 
431     if (Instruction * FoldedMul = foldBinOpIntoSelectOrPhi(I))
432         return FoldedMul;
433 
434     // X * -1.0 --> -X
435     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
436     if (match(Op1, m_SpecificFP(-1.0)))
437         return BinaryOperator::CreateFNegFMF(Op0, &I);
438 
439     // -X * -Y --> X * Y
440     Value* X, * Y;
441     if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
442         return BinaryOperator::CreateFMulFMF(X, Y, &I);
443 
444     // -X * C --> X * -C
445     Constant* C;
446     if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C)))
447         return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
448 
449     // Sink negation: -X * Y --> -(X * Y)
450     if (match(Op0, m_OneUse(m_FNeg(m_Value(X)))))
451         return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I);
452 
453     // Sink negation: Y * -X --> -(X * Y)
454     if (match(Op1, m_OneUse(m_FNeg(m_Value(X)))))
455         return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I);
456 
457     // fabs(X) * fabs(X) -> X * X
458     if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))))
459         return BinaryOperator::CreateFMulFMF(X, X, &I);
460 
461     // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E)
462     if (Value * V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
463         return replaceInstUsesWith(I, V);
464 
465     if (I.hasAllowReassoc()) {
466         // Reassociate constant RHS with another constant to form constant
467         // expression.
468         if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
469             Constant* C1;
470             if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
471                 // (C1 / X) * C --> (C * C1) / X
472                 Constant* CC1 = ConstantExpr::getFMul(C, C1);
473                 if (CC1->isNormalFP())
474                     return BinaryOperator::CreateFDivFMF(CC1, X, &I);
475             }
476             if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
477                 // (X / C1) * C --> X * (C / C1)
478                 Constant* CDivC1 = ConstantExpr::getFDiv(C, C1);
479                 if (CDivC1->isNormalFP())
480                     return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
481 
482                 // If the constant was a denormal, try reassociating differently.
483                 // (X / C1) * C --> X / (C1 / C)
484                 Constant* C1DivC = ConstantExpr::getFDiv(C1, C);
485                 if (Op0->hasOneUse() && C1DivC->isNormalFP())
486                     return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
487             }
488 
489             // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
490             // canonicalized to 'fadd X, C'. Distributing the multiply may allow
491             // further folds and (X * C) + C2 is 'fma'.
492             if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
493                 // (X + C1) * C --> (X * C) + (C * C1)
494                 Constant* CC1 = ConstantExpr::getFMul(C, C1);
495                 Value* XC = Builder.CreateFMulFMF(X, C, &I);
496                 return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
497             }
498             if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
499                 // (C1 - X) * C --> (C * C1) - (X * C)
500                 Constant* CC1 = ConstantExpr::getFMul(C, C1);
501                 Value* XC = Builder.CreateFMulFMF(X, C, &I);
502                 return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
503             }
504         }
505 
506         // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
507         // nnan disallows the possibility of returning a number if both operands are
508         // negative (in that case, we should return NaN).
509         if (I.hasNoNaNs() &&
510             match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) &&
511             match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) {
512             Value* XY = Builder.CreateFMulFMF(X, Y, &I);
513             Value* Sqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, { XY }, &I);
514             return replaceInstUsesWith(I, Sqrt);
515         }
516 
517         // (X*Y) * X => (X*X) * Y where Y != X
518         //  The purpose is two-fold:
519         //   1) to form a power expression (of X).
520         //   2) potentially shorten the critical path: After transformation, the
521         //  latency of the instruction Y is amortized by the expression of X*X,
522         //  and therefore Y is in a "less critical" position compared to what it
523         //  was before the transformation.
524         if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) &&
525             Op1 != Y) {
526             Value* XX = Builder.CreateFMulFMF(Op1, Op1, &I);
527             return BinaryOperator::CreateFMulFMF(XX, Y, &I);
528         }
529         if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) &&
530             Op0 != Y) {
531             Value* XX = Builder.CreateFMulFMF(Op0, Op0, &I);
532             return BinaryOperator::CreateFMulFMF(XX, Y, &I);
533         }
534     }
535 
536     // log2(X * 0.5) * Y = log2(X) * Y - Y
537     if (I.isFast()) {
538         IntrinsicInst* Log2 = nullptr;
539         if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::log2>(
540             m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
541             Log2 = cast<IntrinsicInst>(Op0);
542             Y = Op1;
543         }
544         if (match(Op1, m_OneUse(m_Intrinsic<Intrinsic::log2>(
545             m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
546             Log2 = cast<IntrinsicInst>(Op1);
547             Y = Op0;
548         }
549         if (Log2) {
550             Log2->setArgOperand(0, X);
551             Log2->copyFastMathFlags(&I);
552             Value* LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I);
553             return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I);
554         }
555     }
556 
557     return nullptr;
558 }
559 
560 /// Fold a divide or remainder with a select instruction divisor when one of the
561 /// select operands is zero. In that case, we can use the other select operand
562 /// because div/rem by zero is undefined.
simplifyDivRemOfSelectWithZeroOp(BinaryOperator & I)563 bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator& I) {
564     SelectInst* SI = dyn_cast<SelectInst>(I.getOperand(1));
565     if (!SI)
566         return false;
567 
568     int NonNullOperand;
569     if (match(SI->getTrueValue(), m_Zero()))
570         // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y
571         NonNullOperand = 2;
572     else if (match(SI->getFalseValue(), m_Zero()))
573         // div/rem X, (Cond ? Y : 0) -> div/rem X, Y
574         NonNullOperand = 1;
575     else
576         return false;
577 
578     // Change the div/rem to use 'Y' instead of the select.
579     I.setOperand(1, SI->getOperand(NonNullOperand));
580 
581     // Okay, we know we replace the operand of the div/rem with 'Y' with no
582     // problem.  However, the select, or the condition of the select may have
583     // multiple uses.  Based on our knowledge that the operand must be non-zero,
584     // propagate the known value for the select into other uses of it, and
585     // propagate a known value of the condition into its other users.
586 
587     // If the select and condition only have a single use, don't bother with this,
588     // early exit.
589     Value* SelectCond = SI->getCondition();
590     if (SI->use_empty() && SelectCond->hasOneUse())
591         return true;
592 
593     // Scan the current block backward, looking for other uses of SI.
594     BasicBlock::iterator BBI = I.getIterator(), BBFront = I.getParent()->begin();
595     Type* CondTy = SelectCond->getType();
596     while (BBI != BBFront) {
597         --BBI;
598         // If we found an instruction that we can't assume will return, so
599         // information from below it cannot be propagated above it.
600         if (!isGuaranteedToTransferExecutionToSuccessor(&*BBI))
601             break;
602 
603         // Replace uses of the select or its condition with the known values.
604         for (Instruction::op_iterator I = BBI->op_begin(), E = BBI->op_end();
605             I != E; ++I) {
606             if (*I == SI) {
607                 *I = SI->getOperand(NonNullOperand);
608                 Worklist.Add(&*BBI);
609             }
610             else if (*I == SelectCond) {
611                 *I = NonNullOperand == 1 ? ConstantInt::getTrue(CondTy)
612                     : ConstantInt::getFalse(CondTy);
613                 Worklist.Add(&*BBI);
614             }
615         }
616 
617         // If we past the instruction, quit looking for it.
618         if (&*BBI == SI)
619             SI = nullptr;
620         if (&*BBI == SelectCond)
621             SelectCond = nullptr;
622 
623         // If we ran out of things to eliminate, break out of the loop.
624         if (!SelectCond && !SI)
625             break;
626 
627     }
628     return true;
629 }
630 
631 /// True if the multiply can not be expressed in an int this size.
multiplyOverflows(const APInt & C1,const APInt & C2,APInt & Product,bool IsSigned)632 static bool multiplyOverflows(const APInt& C1, const APInt& C2, APInt& Product,
633     bool IsSigned) {
634     bool Overflow;
635     Product = IsSigned ? C1.smul_ov(C2, Overflow) : C1.umul_ov(C2, Overflow);
636     return Overflow;
637 }
638 
639 /// True if C1 is a multiple of C2. Quotient contains C1/C2.
isMultiple(const APInt & C1,const APInt & C2,APInt & Quotient,bool IsSigned)640 static bool isMultiple(const APInt& C1, const APInt& C2, APInt& Quotient,
641     bool IsSigned) {
642     IGC_ASSERT_MESSAGE(C1.getBitWidth() == C2.getBitWidth(), "Constant widths not equal");
643 
644     // Bail if we will divide by zero.
645     if (C2.isNullValue())
646         return false;
647 
648     // Bail if we would divide INT_MIN by -1.
649     if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue())
650         return false;
651 
652     APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned);
653     if (IsSigned)
654         APInt::sdivrem(C1, C2, Quotient, Remainder);
655     else
656         APInt::udivrem(C1, C2, Quotient, Remainder);
657 
658     return Remainder.isMinValue();
659 }
660 
661 /// This function implements the transforms common to both integer division
662 /// instructions (udiv and sdiv). It is called by the visitors to those integer
663 /// division instructions.
664 /// Common integer divide transforms
commonIDivTransforms(BinaryOperator & I)665 Instruction* InstCombiner::commonIDivTransforms(BinaryOperator& I) {
666     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
667     bool IsSigned = I.getOpcode() == Instruction::SDiv;
668     Type* Ty = I.getType();
669 
670     // The RHS is known non-zero.
671     if (Value * V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) {
672         I.setOperand(1, V);
673         return &I;
674     }
675 
676     // Handle cases involving: [su]div X, (select Cond, Y, Z)
677     // This does not apply for fdiv.
678     if (simplifyDivRemOfSelectWithZeroOp(I))
679         return &I;
680 
681     const APInt* C2;
682     if (match(Op1, m_APInt(C2))) {
683         Value* X;
684         const APInt* C1;
685 
686         // (X / C1) / C2  -> X / (C1*C2)
687         if ((IsSigned && match(Op0, m_SDiv(m_Value(X), m_APInt(C1)))) ||
688             (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_APInt(C1))))) {
689             APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
690             if (!multiplyOverflows(*C1, *C2, Product, IsSigned))
691                 return BinaryOperator::Create(I.getOpcode(), X,
692                     ConstantInt::get(Ty, Product));
693         }
694 
695         if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
696             (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) {
697             APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
698 
699             // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
700             if (isMultiple(*C2, *C1, Quotient, IsSigned)) {
701                 auto* NewDiv = BinaryOperator::Create(I.getOpcode(), X,
702                     ConstantInt::get(Ty, Quotient));
703                 NewDiv->setIsExact(I.isExact());
704                 return NewDiv;
705             }
706 
707             // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2.
708             if (isMultiple(*C1, *C2, Quotient, IsSigned)) {
709                 auto* Mul = BinaryOperator::Create(Instruction::Mul, X,
710                     ConstantInt::get(Ty, Quotient));
711                 auto* OBO = cast<OverflowingBinaryOperator>(Op0);
712                 Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap());
713                 Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap());
714                 return Mul;
715             }
716         }
717 
718         if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) &&
719             *C1 != C1->getBitWidth() - 1) ||
720             (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))))) {
721             APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
722             APInt C1Shifted = APInt::getOneBitSet(
723                 C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue()));
724 
725             // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of 1 << C1.
726             if (isMultiple(*C2, C1Shifted, Quotient, IsSigned)) {
727                 auto* BO = BinaryOperator::Create(I.getOpcode(), X,
728                     ConstantInt::get(Ty, Quotient));
729                 BO->setIsExact(I.isExact());
730                 return BO;
731             }
732 
733             // (X << C1) / C2 -> X * ((1 << C1) / C2) if 1 << C1 is a multiple of C2.
734             if (isMultiple(C1Shifted, *C2, Quotient, IsSigned)) {
735                 auto* Mul = BinaryOperator::Create(Instruction::Mul, X,
736                     ConstantInt::get(Ty, Quotient));
737                 auto* OBO = cast<OverflowingBinaryOperator>(Op0);
738                 Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap());
739                 Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap());
740                 return Mul;
741             }
742         }
743 
744         if (!C2->isNullValue()) // avoid X udiv 0
745             if (Instruction * FoldedDiv = foldBinOpIntoSelectOrPhi(I))
746                 return FoldedDiv;
747     }
748 
749     if (match(Op0, m_One())) {
750         IGC_ASSERT_MESSAGE(!Ty->isIntOrIntVectorTy(1), "i1 divide not removed?");
751         if (IsSigned) {
752             // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the
753             // result is one, if Op1 is -1 then the result is minus one, otherwise
754             // it's zero.
755             Value* Inc = Builder.CreateAdd(Op1, Op0);
756             Value* Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(Ty, 3));
757             return SelectInst::Create(Cmp, Op1, ConstantInt::get(Ty, 0));
758         }
759         else {
760             // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the
761             // result is one, otherwise it's zero.
762             return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), Ty);
763         }
764     }
765 
766     // See if we can fold away this div instruction.
767     if (SimplifyDemandedInstructionBits(I))
768         return &I;
769 
770     // (X - (X rem Y)) / Y -> X / Y; usually originates as ((X / Y) * Y) / Y
771     Value* X, * Z;
772     if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) // (X - Z) / Y; Y = Op1
773         if ((IsSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) ||
774             (!IsSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1)))))
775             return BinaryOperator::Create(I.getOpcode(), X, Op1);
776 
777     // (X << Y) / X -> 1 << Y
778     Value* Y;
779     if (IsSigned && match(Op0, m_NSWShl(m_Specific(Op1), m_Value(Y))))
780         return BinaryOperator::CreateNSWShl(ConstantInt::get(Ty, 1), Y);
781     if (!IsSigned && match(Op0, m_NUWShl(m_Specific(Op1), m_Value(Y))))
782         return BinaryOperator::CreateNUWShl(ConstantInt::get(Ty, 1), Y);
783 
784     // X / (X * Y) -> 1 / Y if the multiplication does not overflow.
785     if (match(Op1, m_c_Mul(m_Specific(Op0), m_Value(Y)))) {
786         bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap();
787         bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap();
788         if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) {
789             I.setOperand(0, ConstantInt::get(Ty, 1));
790             I.setOperand(1, Y);
791             return &I;
792         }
793     }
794 
795     return nullptr;
796 }
797 
/// Recursion limit for visitUDivOperand(): once Depth reaches this value the
/// recursive (select) cases are no longer explored.
static constexpr unsigned MaxDepth = 6;
799 
800 namespace {
801 
802     using FoldUDivOperandCb = Instruction * (*)(Value* Op0, Value* Op1,
803         const BinaryOperator& I,
804         InstCombiner& IC);
805 
806     /// Used to maintain state for visitUDivOperand().
807     struct UDivFoldAction {
808         /// Informs visitUDiv() how to fold this operand.  This can be zero if this
809         /// action joins two actions together.
810         FoldUDivOperandCb FoldAction;
811 
812         /// Which operand to fold.
813         Value* OperandToFold;
814 
815         union {
816             /// The instruction returned when FoldAction is invoked.
817             Instruction* FoldResult;
818 
819             /// Stores the LHS action index if this action joins two actions together.
820             size_t SelectLHSIdx;
821         };
822 
UDivFoldAction__anon7caf39df0111::UDivFoldAction823         UDivFoldAction(FoldUDivOperandCb FA, Value* InputOperand)
824             : FoldAction(FA), OperandToFold(InputOperand), FoldResult(nullptr) {}
UDivFoldAction__anon7caf39df0111::UDivFoldAction825         UDivFoldAction(FoldUDivOperandCb FA, Value* InputOperand, size_t SLHS)
826             : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {}
827     };
828 
829 } // end anonymous namespace
830 
831 // X udiv 2^C -> X >> C
foldUDivPow2Cst(Value * Op0,Value * Op1,const BinaryOperator & I,InstCombiner & IC)832 static Instruction* foldUDivPow2Cst(Value* Op0, Value* Op1,
833     const BinaryOperator& I, InstCombiner& IC) {
834     Constant* C1 = getLogBase2(Op0->getType(), cast<Constant>(Op1));
835     IGC_ASSERT_EXIT_MESSAGE(nullptr != C1, "Failed to constant fold udiv -> logbase2");
836     BinaryOperator* LShr = BinaryOperator::CreateLShr(Op0, C1);
837     if (I.isExact())
838         LShr->setIsExact();
839     return LShr;
840 }
841 
842 // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
843 // X udiv (zext (C1 << N)), where C1 is "1<<C2"  -->  X >> (N+C2)
foldUDivShl(Value * Op0,Value * Op1,const BinaryOperator & I,InstCombiner & IC)844 static Instruction* foldUDivShl(Value* Op0, Value* Op1, const BinaryOperator& I,
845     InstCombiner& IC) {
846     Value* ShiftLeft;
847     if (!match(Op1, m_ZExt(m_Value(ShiftLeft))))
848         ShiftLeft = Op1;
849 
850     Constant* CI;
851     Value* N = nullptr;
852     IGC_ASSERT_EXIT_MESSAGE(match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N))), "match should never fail here!");
853     Constant* Log2Base = getLogBase2(N->getType(), CI);
854     IGC_ASSERT_EXIT_MESSAGE(nullptr != Log2Base, "getLogBase2 should never fail here!");
855     N = IC.Builder.CreateAdd(N, Log2Base);
856     if (Op1 != ShiftLeft)
857         N = IC.Builder.CreateZExt(N, Op1->getType());
858     BinaryOperator* LShr = BinaryOperator::CreateLShr(Op0, N);
859     if (I.isExact())
860         LShr->setIsExact();
861     return LShr;
862 }
863 
864 // Recursively visits the possible right hand operands of a udiv
865 // instruction, seeing through select instructions, to determine if we can
866 // replace the udiv with something simpler.  If we find that an operand is not
867 // able to simplify the udiv, we abort the entire transformation.
visitUDivOperand(Value * Op0,Value * Op1,const BinaryOperator & I,SmallVectorImpl<UDivFoldAction> & Actions,unsigned Depth=0)868 static size_t visitUDivOperand(Value* Op0, Value* Op1, const BinaryOperator& I,
869     SmallVectorImpl<UDivFoldAction>& Actions,
870     unsigned Depth = 0) {
871     // Check to see if this is an unsigned division with an exact power of 2,
872     // if so, convert to a right shift.
873     if (match(Op1, m_Power2())) {
874         Actions.push_back(UDivFoldAction(foldUDivPow2Cst, Op1));
875         return Actions.size();
876     }
877 
878     // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
879     if (match(Op1, m_Shl(m_Power2(), m_Value())) ||
880         match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) {
881         Actions.push_back(UDivFoldAction(foldUDivShl, Op1));
882         return Actions.size();
883     }
884 
885     // The remaining tests are all recursive, so bail out if we hit the limit.
886     if (Depth++ == MaxDepth)
887         return 0;
888 
889     if (SelectInst * SI = dyn_cast<SelectInst>(Op1))
890         if (size_t LHSIdx =
891             visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth))
892             if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) {
893                 Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1));
894                 return Actions.size();
895             }
896 
897     return 0;
898 }
899 
900 /// If we have zero-extended operands of an unsigned div or rem, we may be able
901 /// to narrow the operation (sink the zext below the math).
narrowUDivURem(BinaryOperator & I,InstCombiner::BuilderTy & Builder)902 static Instruction* narrowUDivURem(BinaryOperator& I,
903     InstCombiner::BuilderTy& Builder) {
904     Instruction::BinaryOps Opcode = I.getOpcode();
905     Value* N = I.getOperand(0);
906     Value* D = I.getOperand(1);
907     Type* Ty = I.getType();
908     Value* X, * Y;
909     if (match(N, m_ZExt(m_Value(X))) && match(D, m_ZExt(m_Value(Y))) &&
910         X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) {
911         // udiv (zext X), (zext Y) --> zext (udiv X, Y)
912         // urem (zext X), (zext Y) --> zext (urem X, Y)
913         Value* NarrowOp = Builder.CreateBinOp(Opcode, X, Y);
914         return new ZExtInst(NarrowOp, Ty);
915     }
916 
917     Constant* C;
918     if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) ||
919         (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) {
920         // If the constant is the same in the smaller type, use the narrow version.
921         Constant* TruncC = ConstantExpr::getTrunc(C, X->getType());
922         if (ConstantExpr::getZExt(TruncC, Ty) != C)
923             return nullptr;
924 
925         // udiv (zext X), C --> zext (udiv X, C')
926         // urem (zext X), C --> zext (urem X, C')
927         // udiv C, (zext X) --> zext (udiv C', X)
928         // urem C, (zext X) --> zext (urem C', X)
929         Value* NarrowOp = isa<Constant>(D) ? Builder.CreateBinOp(Opcode, X, TruncC)
930             : Builder.CreateBinOp(Opcode, TruncC, X);
931         return new ZExtInst(NarrowOp, Ty);
932     }
933 
934     return nullptr;
935 }
936 
visitUDiv(BinaryOperator & I)937 Instruction* InstCombiner::visitUDiv(BinaryOperator& I) {
938     if (Value * V = SimplifyUDivInst(I.getOperand(0), I.getOperand(1),
939         SQ.getWithInstruction(&I)))
940         return replaceInstUsesWith(I, V);
941 
942     if (Instruction * X = foldShuffledBinop(I))
943         return X;
944 
945     // Handle the integer div common cases
946     if (Instruction * Common = commonIDivTransforms(I))
947         return Common;
948 
949     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
950     Value* X;
951     const APInt* C1, * C2;
952     if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && match(Op1, m_APInt(C2))) {
953         // (X lshr C1) udiv C2 --> X udiv (C2 << C1)
954         bool Overflow;
955         APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow);
956         if (!Overflow) {
957             bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value()));
958             BinaryOperator* BO = BinaryOperator::CreateUDiv(
959                 X, ConstantInt::get(X->getType(), C2ShlC1));
960             if (IsExact)
961                 BO->setIsExact();
962             return BO;
963         }
964     }
965 
966     // Op0 / C where C is large (negative) --> zext (Op0 >= C)
967     // TODO: Could use isKnownNegative() to handle non-constant values.
968     Type* Ty = I.getType();
969     if (match(Op1, m_Negative())) {
970         Value* Cmp = Builder.CreateICmpUGE(Op0, Op1);
971         return CastInst::CreateZExtOrBitCast(Cmp, Ty);
972     }
973     // Op0 / (sext i1 X) --> zext (Op0 == -1) (if X is 0, the div is undefined)
974     if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
975         Value* Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
976         return CastInst::CreateZExtOrBitCast(Cmp, Ty);
977     }
978 
979     if (Instruction * NarrowDiv = narrowUDivURem(I, Builder))
980         return NarrowDiv;
981 
982     // If the udiv operands are non-overflowing multiplies with a common operand,
983     // then eliminate the common factor:
984     // (A * B) / (A * X) --> B / X (and commuted variants)
985     // TODO: The code would be reduced if we had m_c_NUWMul pattern matching.
986     // TODO: If -reassociation handled this generally, we could remove this.
987     Value* A, * B;
988     if (match(Op0, m_NUWMul(m_Value(A), m_Value(B)))) {
989         if (match(Op1, m_NUWMul(m_Specific(A), m_Value(X))) ||
990             match(Op1, m_NUWMul(m_Value(X), m_Specific(A))))
991             return BinaryOperator::CreateUDiv(B, X);
992         if (match(Op1, m_NUWMul(m_Specific(B), m_Value(X))) ||
993             match(Op1, m_NUWMul(m_Value(X), m_Specific(B))))
994             return BinaryOperator::CreateUDiv(A, X);
995     }
996 
997     // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...))))
998     SmallVector<UDivFoldAction, 6> UDivActions;
999     if (visitUDivOperand(Op0, Op1, I, UDivActions))
1000         for (unsigned i = 0, e = UDivActions.size(); i != e; ++i) {
1001             FoldUDivOperandCb Action = UDivActions[i].FoldAction;
1002             Value* ActionOp1 = UDivActions[i].OperandToFold;
1003             Instruction* Inst;
1004             if (Action)
1005                 Inst = Action(Op0, ActionOp1, I, *this);
1006             else {
1007                 // This action joins two actions together.  The RHS of this action is
1008                 // simply the last action we processed, we saved the LHS action index in
1009                 // the joining action.
1010                 size_t SelectRHSIdx = i - 1;
1011                 Value* SelectRHS = UDivActions[SelectRHSIdx].FoldResult;
1012                 size_t SelectLHSIdx = UDivActions[i].SelectLHSIdx;
1013                 Value* SelectLHS = UDivActions[SelectLHSIdx].FoldResult;
1014                 Inst = SelectInst::Create(cast<SelectInst>(ActionOp1)->getCondition(),
1015                     SelectLHS, SelectRHS);
1016             }
1017 
1018             // If this is the last action to process, return it to the InstCombiner.
1019             // Otherwise, we insert it before the UDiv and record it so that we may
1020             // use it as part of a joining action (i.e., a SelectInst).
1021             if (e - i != 1) {
1022                 Inst->insertBefore(&I);
1023                 UDivActions[i].FoldResult = Inst;
1024             }
1025             else
1026                 return Inst;
1027         }
1028 
1029     return nullptr;
1030 }
1031 
visitSDiv(BinaryOperator & I)1032 Instruction* InstCombiner::visitSDiv(BinaryOperator& I) {
1033     if (Value * V = SimplifySDivInst(I.getOperand(0), I.getOperand(1),
1034         SQ.getWithInstruction(&I)))
1035         return replaceInstUsesWith(I, V);
1036 
1037     if (Instruction * X = foldShuffledBinop(I))
1038         return X;
1039 
1040     // Handle the integer div common cases
1041     if (Instruction * Common = commonIDivTransforms(I))
1042         return Common;
1043 
1044     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
1045     Value* X;
1046     // sdiv Op0, -1 --> -Op0
1047     // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined)
1048     if (match(Op1, m_AllOnes()) ||
1049         (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
1050         return BinaryOperator::CreateNeg(Op0);
1051 
1052     const APInt* Op1C;
1053     if (match(Op1, m_APInt(Op1C))) {
1054         // sdiv exact X, C  -->  ashr exact X, log2(C)
1055         if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) {
1056             Value* ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2());
1057             return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName());
1058         }
1059 
1060         // If the dividend is sign-extended and the constant divisor is small enough
1061         // to fit in the source type, shrink the division to the narrower type:
1062         // (sext X) sdiv C --> sext (X sdiv C)
1063         Value* Op0Src;
1064         if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
1065             Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {
1066 
1067             // In the general case, we need to make sure that the dividend is not the
1068             // minimum signed value because dividing that by -1 is UB. But here, we
1069             // know that the -1 divisor case is already handled above.
1070 
1071             Constant* NarrowDivisor =
1072                 ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType());
1073             Value* NarrowOp = Builder.CreateSDiv(Op0Src, NarrowDivisor);
1074             return new SExtInst(NarrowOp, Op0->getType());
1075         }
1076     }
1077 
1078     if (Constant * RHS = dyn_cast<Constant>(Op1)) {
1079         // X/INT_MIN -> X == INT_MIN
1080         if (RHS->isMinSignedValue())
1081             return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), I.getType());
1082 
1083         // -X/C  -->  X/-C  provided the negation doesn't overflow.
1084         Value* X;
1085         if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) {
1086             auto* BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS));
1087             BO->setIsExact(I.isExact());
1088             return BO;
1089         }
1090     }
1091 
1092     // If the sign bits of both operands are zero (i.e. we can prove they are
1093     // unsigned inputs), turn this into a udiv.
1094     APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits()));
1095     if (MaskedValueIsZero(Op0, Mask, 0, &I)) {
1096         if (MaskedValueIsZero(Op1, Mask, 0, &I)) {
1097             // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
1098             auto* BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
1099             BO->setIsExact(I.isExact());
1100             return BO;
1101         }
1102 
1103         if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
1104             // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
1105             // Safe because the only negative value (1 << Y) can take on is
1106             // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have
1107             // the sign bit set.
1108             auto* BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
1109             BO->setIsExact(I.isExact());
1110             return BO;
1111         }
1112     }
1113 
1114     return nullptr;
1115 }
1116 
1117 /// Remove negation and try to convert division into multiplication.
foldFDivConstantDivisor(BinaryOperator & I)1118 static Instruction* foldFDivConstantDivisor(BinaryOperator& I) {
1119     Constant* C;
1120     if (!match(I.getOperand(1), m_Constant(C)))
1121         return nullptr;
1122 
1123     // -X / C --> X / -C
1124     Value* X;
1125     if (match(I.getOperand(0), m_FNeg(m_Value(X))))
1126         return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);
1127 
1128     // If the constant divisor has an exact inverse, this is always safe. If not,
1129     // then we can still create a reciprocal if fast-math-flags allow it and the
1130     // constant is a regular number (not zero, infinite, or denormal).
1131     if (!(C->hasExactInverseFP() || (I.hasAllowReciprocal() && C->isNormalFP())))
1132         return nullptr;
1133 
1134     // Disallow denormal constants because we don't know what would happen
1135     // on all targets.
1136     // TODO: Use Intrinsic::canonicalize or let function attributes tell us that
1137     // denorms are flushed?
1138     auto* RecipC = ConstantExpr::getFDiv(ConstantFP::get(I.getType(), 1.0), C);
1139     if (!RecipC->isNormalFP())
1140         return nullptr;
1141 
1142     // X / C --> X * (1 / C)
1143     return BinaryOperator::CreateFMulFMF(I.getOperand(0), RecipC, &I);
1144 }
1145 
1146 /// Remove negation and try to reassociate constant math.
foldFDivConstantDividend(BinaryOperator & I)1147 static Instruction* foldFDivConstantDividend(BinaryOperator& I) {
1148     Constant* C;
1149     if (!match(I.getOperand(0), m_Constant(C)))
1150         return nullptr;
1151 
1152     // C / -X --> -C / X
1153     Value* X;
1154     if (match(I.getOperand(1), m_FNeg(m_Value(X))))
1155         return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);
1156 
1157     if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
1158         return nullptr;
1159 
1160     // Try to reassociate C / X expressions where X includes another constant.
1161     Constant* C2, * NewC = nullptr;
1162     if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) {
1163         // C / (X * C2) --> (C / C2) / X
1164         NewC = ConstantExpr::getFDiv(C, C2);
1165     }
1166     else if (match(I.getOperand(1), m_FDiv(m_Value(X), m_Constant(C2)))) {
1167         // C / (X / C2) --> (C * C2) / X
1168         NewC = ConstantExpr::getFMul(C, C2);
1169     }
1170     // Disallow denormal constants because we don't know what would happen
1171     // on all targets.
1172     // TODO: Use Intrinsic::canonicalize or let function attributes tell us that
1173     // denorms are flushed?
1174     if (!NewC || !NewC->isNormalFP())
1175         return nullptr;
1176 
1177     return BinaryOperator::CreateFDivFMF(NewC, X, &I);
1178 }
1179 
visitFDiv(BinaryOperator & I)1180 Instruction* InstCombiner::visitFDiv(BinaryOperator& I) {
1181     if (Value * V = SimplifyFDivInst(I.getOperand(0), I.getOperand(1),
1182         I.getFastMathFlags(),
1183         SQ.getWithInstruction(&I)))
1184         return replaceInstUsesWith(I, V);
1185 
1186     if (Instruction * X = foldShuffledBinop(I))
1187         return X;
1188 
1189     if (Instruction * R = foldFDivConstantDivisor(I))
1190         return R;
1191 
1192     if (Instruction * R = foldFDivConstantDividend(I))
1193         return R;
1194 
1195     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
1196     if (isa<Constant>(Op0))
1197         if (SelectInst * SI = dyn_cast<SelectInst>(Op1))
1198             if (Instruction * R = FoldOpIntoSelect(I, SI))
1199                 return R;
1200 
1201     if (isa<Constant>(Op1))
1202         if (SelectInst * SI = dyn_cast<SelectInst>(Op0))
1203             if (Instruction * R = FoldOpIntoSelect(I, SI))
1204                 return R;
1205 
1206     if (I.hasAllowReassoc() && I.hasAllowReciprocal()) {
1207         Value* X, * Y;
1208         if (match(Op0, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) &&
1209             (!isa<Constant>(Y) || !isa<Constant>(Op1))) {
1210             // (X / Y) / Z => X / (Y * Z)
1211             Value* YZ = Builder.CreateFMulFMF(Y, Op1, &I);
1212             return BinaryOperator::CreateFDivFMF(X, YZ, &I);
1213         }
1214         if (match(Op1, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) &&
1215             (!isa<Constant>(Y) || !isa<Constant>(Op0))) {
1216             // Z / (X / Y) => (Y * Z) / X
1217             Value* YZ = Builder.CreateFMulFMF(Y, Op0, &I);
1218             return BinaryOperator::CreateFDivFMF(YZ, X, &I);
1219         }
1220     }
1221 
1222     if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) {
1223         // sin(X) / cos(X) -> tan(X)
1224         // cos(X) / sin(X) -> 1/tan(X) (cotangent)
1225         Value* X;
1226         bool IsTan = match(Op0, m_Intrinsic<Intrinsic::sin>(m_Value(X))) &&
1227             match(Op1, m_Intrinsic<Intrinsic::cos>(m_Specific(X)));
1228         bool IsCot =
1229             !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) &&
1230             match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X)));
1231 
1232         if ((IsTan || IsCot) && hasUnaryFloatFn(&TLI, I.getType(), LibFunc_tan,
1233             LibFunc_tanf, LibFunc_tanl)) {
1234             IRBuilder<> B(&I);
1235             IRBuilder<>::FastMathFlagGuard FMFGuard(B);
1236             B.setFastMathFlags(I.getFastMathFlags());
1237             AttributeList Attrs = CallSite(Op0).getCalledFunction()->getAttributes();
1238             Value* Res = emitUnaryFloatFnCall(X, TLI.getName(LibFunc_tan), B, Attrs);
1239             if (IsCot)
1240                 Res = B.CreateFDiv(ConstantFP::get(I.getType(), 1.0), Res);
1241             return replaceInstUsesWith(I, Res);
1242         }
1243     }
1244 
1245     // -X / -Y -> X / Y
1246     Value* X, * Y;
1247     if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) {
1248         I.setOperand(0, X);
1249         I.setOperand(1, Y);
1250         return &I;
1251     }
1252 
1253     // X / (X * Y) --> 1.0 / Y
1254     // Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed.
1255     // We can ignore the possibility that X is infinity because INF/INF is NaN.
1256     if (I.hasNoNaNs() && I.hasAllowReassoc() &&
1257         match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) {
1258         I.setOperand(0, ConstantFP::get(I.getType(), 1.0));
1259         I.setOperand(1, Y);
1260         return &I;
1261     }
1262 
1263     return nullptr;
1264 }
1265 
1266 /// This function implements the transforms common to both integer remainder
1267 /// instructions (urem and srem). It is called by the visitors to those integer
1268 /// remainder instructions.
1269 /// Common integer remainder transforms
commonIRemTransforms(BinaryOperator & I)1270 Instruction* InstCombiner::commonIRemTransforms(BinaryOperator& I) {
1271     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
1272 
1273     // The RHS is known non-zero.
1274     if (Value * V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) {
1275         I.setOperand(1, V);
1276         return &I;
1277     }
1278 
1279     // Handle cases involving: rem X, (select Cond, Y, Z)
1280     if (simplifyDivRemOfSelectWithZeroOp(I))
1281         return &I;
1282 
1283     if (isa<Constant>(Op1)) {
1284         if (Instruction * Op0I = dyn_cast<Instruction>(Op0)) {
1285             if (SelectInst * SI = dyn_cast<SelectInst>(Op0I)) {
1286                 if (Instruction * R = FoldOpIntoSelect(I, SI))
1287                     return R;
1288             }
1289             else if (auto * PN = dyn_cast<PHINode>(Op0I)) {
1290                 const APInt* Op1Int;
1291                 if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
1292                     (I.getOpcode() == Instruction::URem ||
1293                         !Op1Int->isMinSignedValue())) {
1294                     // foldOpIntoPhi will speculate instructions to the end of the PHI's
1295                     // predecessor blocks, so do this only if we know the srem or urem
1296                     // will not fault.
1297                     if (Instruction * NV = foldOpIntoPhi(I, PN))
1298                         return NV;
1299                 }
1300             }
1301 
1302             // See if we can fold away this rem instruction.
1303             if (SimplifyDemandedInstructionBits(I))
1304                 return &I;
1305         }
1306     }
1307 
1308     return nullptr;
1309 }
1310 
visitURem(BinaryOperator & I)1311 Instruction* InstCombiner::visitURem(BinaryOperator& I) {
1312     if (Value * V = SimplifyURemInst(I.getOperand(0), I.getOperand(1),
1313         SQ.getWithInstruction(&I)))
1314         return replaceInstUsesWith(I, V);
1315 
1316     if (Instruction * X = foldShuffledBinop(I))
1317         return X;
1318 
1319     if (Instruction * common = commonIRemTransforms(I))
1320         return common;
1321 
1322     if (Instruction * NarrowRem = narrowUDivURem(I, Builder))
1323         return NarrowRem;
1324 
1325     // X urem Y -> X and Y-1, where Y is a power of 2,
1326     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
1327     Type* Ty = I.getType();
1328     if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
1329         Constant* N1 = Constant::getAllOnesValue(Ty);
1330         Value* Add = Builder.CreateAdd(Op1, N1);
1331         return BinaryOperator::CreateAnd(Op0, Add);
1332     }
1333 
1334     // 1 urem X -> zext(X != 1)
1335     if (match(Op0, m_One()))
1336         return CastInst::CreateZExtOrBitCast(Builder.CreateICmpNE(Op1, Op0), Ty);
1337 
1338     // X urem C -> X < C ? X : X - C, where C >= signbit.
1339     if (match(Op1, m_Negative())) {
1340         Value* Cmp = Builder.CreateICmpULT(Op0, Op1);
1341         Value* Sub = Builder.CreateSub(Op0, Op1);
1342         return SelectInst::Create(Cmp, Op0, Sub);
1343     }
1344 
1345     // If the divisor is a sext of a boolean, then the divisor must be max
1346     // unsigned value (-1). Therefore, the remainder is Op0 unless Op0 is also
1347     // max unsigned value. In that case, the remainder is 0:
1348     // urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0
1349     Value* X;
1350     if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
1351         Value* Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
1352         return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0);
1353     }
1354 
1355     return nullptr;
1356 }
1357 
visitSRem(BinaryOperator & I)1358 Instruction* InstCombiner::visitSRem(BinaryOperator& I) {
1359     if (Value * V = SimplifySRemInst(I.getOperand(0), I.getOperand(1),
1360         SQ.getWithInstruction(&I)))
1361         return replaceInstUsesWith(I, V);
1362 
1363     if (Instruction * X = foldShuffledBinop(I))
1364         return X;
1365 
1366     // Handle the integer rem common cases
1367     if (Instruction * Common = commonIRemTransforms(I))
1368         return Common;
1369 
1370     Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
1371     {
1372         const APInt* Y;
1373         // X % -Y -> X % Y
1374         if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) {
1375             Worklist.AddValue(I.getOperand(1));
1376             I.setOperand(1, ConstantInt::get(I.getType(), -*Y));
1377             return &I;
1378         }
1379     }
1380 
1381     // If the sign bits of both operands are zero (i.e. we can prove they are
1382     // unsigned inputs), turn this into a urem.
1383     APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits()));
1384     if (MaskedValueIsZero(Op1, Mask, 0, &I) &&
1385         MaskedValueIsZero(Op0, Mask, 0, &I)) {
1386         // X srem Y -> X urem Y, iff X and Y don't have sign bit set
1387         return BinaryOperator::CreateURem(Op0, Op1, I.getName());
1388     }
1389 
1390     // If it's a constant vector, flip any negative values positive.
1391     if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
1392         Constant* C = cast<Constant>(Op1);
1393         unsigned VWidth = C->getType()->getVectorNumElements();
1394 
1395         bool hasNegative = false;
1396         bool hasMissing = false;
1397         for (unsigned i = 0; i != VWidth; ++i) {
1398             Constant* Elt = C->getAggregateElement(i);
1399             if (!Elt) {
1400                 hasMissing = true;
1401                 break;
1402             }
1403 
1404             if (ConstantInt * RHS = dyn_cast<ConstantInt>(Elt))
1405                 if (RHS->isNegative())
1406                     hasNegative = true;
1407         }
1408 
1409         if (hasNegative && !hasMissing) {
1410             SmallVector<Constant*, 16> Elts(VWidth);
1411             for (unsigned i = 0; i != VWidth; ++i) {
1412                 Elts[i] = C->getAggregateElement(i);  // Handle undef, etc.
1413                 if (ConstantInt * RHS = dyn_cast<ConstantInt>(Elts[i])) {
1414                     if (RHS->isNegative())
1415                         Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
1416                 }
1417             }
1418 
1419             Constant* NewRHSV = ConstantVector::get(Elts);
1420             if (NewRHSV != C) {  // Don't loop on -MININT
1421                 Worklist.AddValue(I.getOperand(1));
1422                 I.setOperand(1, NewRHSV);
1423                 return &I;
1424             }
1425         }
1426     }
1427 
1428     return nullptr;
1429 }
1430 
visitFRem(BinaryOperator & I)1431 Instruction* InstCombiner::visitFRem(BinaryOperator& I) {
1432     if (Value * V = SimplifyFRemInst(I.getOperand(0), I.getOperand(1),
1433         I.getFastMathFlags(),
1434         SQ.getWithInstruction(&I)))
1435         return replaceInstUsesWith(I, V);
1436 
1437     if (Instruction * X = foldShuffledBinop(I))
1438         return X;
1439 
1440     return nullptr;
1441 }
1442