/*========================== begin_copyright_notice ============================

Copyright (C) 2018-2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

/*========================== begin_copyright_notice ============================

This file is distributed under the University of Illinois Open Source License.
See LICENSE.TXT for details.

============================= end_copyright_notice ===========================*/

// This file implements the visit functions for add, fadd, sub, and fsub.

#include "common/LLVMWarningsPush.hpp"
#include "InstCombineInternal.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/AlignOf.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/KnownBits.h"
#include "common/LLVMWarningsPop.hpp"
#include <utility>
#include "Probe/Assertion.h"

using namespace llvm;
using namespace PatternMatch;
using namespace IGCombiner;

#define DEBUG_TYPE "instcombine"

namespace {

    /// Class representing the coefficient of a floating-point addend.
    /// This class needs to be highly efficient, which is especially true for
    /// the constructor. As of this writing, the cost of the default
    /// constructor is merely a 4-byte store of zero (assuming the compiler is
    /// able to perform write-merging).
    ///
    class FAddendCoef {
    public:
        // The constructor has to initialize an APFloat, which is unnecessary
        // for most addends, whose coefficient is either 1 or -1. So, the
        // constructor is expensive. In order to avoid this cost, we should
        // reuse instances whenever possible. The pre-created instances
        // FAddCombine::Add[0-5] embody this idea.
        FAddendCoef() = default;
        ~FAddendCoef();

        // If possible, don't define operator+/operator- etc because these
        // operators inevitably call FAddendCoef's constructor which is not cheap.
        void operator=(const FAddendCoef& A);
        void operator+=(const FAddendCoef& A);
        void operator*=(const FAddendCoef& S);

        void set(short C) {
            IGC_ASSERT_MESSAGE(!insaneIntVal(C), "Insane coefficient");
            IsFp = false; IntVal = C;
        }

        void set(const APFloat& C);

        void negate();

        bool isZero() const { return isInt() ? !IntVal : getFpVal().isZero(); }
        Value* getValue(Type*) const;

        bool isOne() const { return isInt() && IntVal == 1; }
        bool isTwo() const { return isInt() && IntVal == 2; }
        bool isMinusOne() const { return isInt() && IntVal == -1; }
        bool isMinusTwo() const { return isInt() && IntVal == -2; }

    private:
        bool insaneIntVal(int V) { return V > 4 || V < -4; }

        APFloat* getFpValPtr()
        {
            return reinterpret_cast<APFloat*>(&FpValBuf.buffer[0]);
        }

        const APFloat* getFpValPtr() const
        {
            return reinterpret_cast<const APFloat*>(&FpValBuf.buffer[0]);
        }

        const APFloat& getFpVal() const {
            IGC_ASSERT_MESSAGE(IsFp, "Incorrect state");
            IGC_ASSERT_MESSAGE(BufHasFpVal, "Incorrect state");
            return *getFpValPtr();
        }

        APFloat& getFpVal() {
            IGC_ASSERT_MESSAGE(IsFp, "Incorrect state");
            IGC_ASSERT_MESSAGE(BufHasFpVal, "Incorrect state");
            return *getFpValPtr();
        }

        bool isInt() const { return !IsFp; }

        // If the coefficient is represented by an integer, promote it to a
        // floating point.
        void convertToFpType(const fltSemantics& Sem);

        // Construct an APFloat from a signed integer.
        // TODO: We should get rid of this function when APFloat can be constructed
        //       from a *SIGNED* integer.
        APFloat createAPFloatFromInt(const fltSemantics& Sem, int Val);

        bool IsFp = false;

        // True iff FpValBuf contains an instance of APFloat.
        bool BufHasFpVal = false;

        // The integer coefficient of an individual addend is either 1 or -1,
        // and we try to simplify at most 4 addends from at most two
        // neighboring instructions. So the range of <IntVal> falls in [-4, 4].
        // APInt would be overkill for this purpose.
        short IntVal = 0;

        AlignedCharArrayUnion<APFloat> FpValBuf;
    };

    /// FAddend is used to represent a floating-point addend. An addend is
    /// represented as <C, V>, where V is a symbolic value, and C is a
    /// constant coefficient. A constant addend is represented as <C, 0>.
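    /// For example, the expression "3.0 * X" is the addend <3.0, X>, and the
    /// lone constant "5.0" is the addend <5.0, 0>.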
    class FAddend {
    public:
        FAddend() = default;

        void operator+=(const FAddend& T) {
            IGC_ASSERT_MESSAGE((Val == T.Val), "Symbolic-values disagree");
            Coeff += T.Coeff;
        }

        Value* getSymVal() const { return Val; }
        const FAddendCoef& getCoef() const { return Coeff; }

        bool isConstant() const { return Val == nullptr; }
        bool isZero() const { return Coeff.isZero(); }

        void set(short Coefficient, Value* V) {
            Coeff.set(Coefficient);
            Val = V;
        }
        void set(const APFloat& Coefficient, Value* V) {
            Coeff.set(Coefficient);
            Val = V;
        }
        void set(const ConstantFP* Coefficient, Value* V) {
            Coeff.set(Coefficient->getValueAPF());
            Val = V;
        }

        void negate() { Coeff.negate(); }

        /// Drill down the U-D chain one step to find the definition of V, and
        /// try to break the definition into one or two addends.
        static unsigned drillValueDownOneStep(Value* V, FAddend& A0, FAddend& A1);

        /// Similar to FAddend::drillDownOneStep() except that the value being
        /// split is the addend itself.
        unsigned drillAddendDownOneStep(FAddend& Addend0, FAddend& Addend1) const;

    private:
        void Scale(const FAddendCoef& ScaleAmt) { Coeff *= ScaleAmt; }

        // This addend has the value of "Coeff * Val".
        Value* Val = nullptr;
        FAddendCoef Coeff;
    };

    /// FAddCombine is the class for optimizing an unsafe fadd/fsub along
    /// with at most two of its neighboring instructions.
    ///
    class FAddCombine {
    public:
        FAddCombine(InstCombiner::BuilderTy& B) : Builder(B) {}

        Value* simplify(Instruction* FAdd);

    private:
        using AddendVect = SmallVector<const FAddend*, 4>;

        Value* simplifyFAdd(AddendVect& V, unsigned InstrQuota);

        Value* performFactorization(Instruction* I);

        /// Convert given addend to a Value
        Value* createAddendVal(const FAddend& A, bool& NeedNeg);

        /// Return the number of instructions needed to emit the N-ary addition.
        unsigned calcInstrNumber(const AddendVect& Vect);

        Value* createFSub(Value* Opnd0, Value* Opnd1);
        Value* createFAdd(Value* Opnd0, Value* Opnd1);
        Value* createFMul(Value* Opnd0, Value* Opnd1);
        Value* createFDiv(Value* Opnd0, Value* Opnd1);
        Value* createFNeg(Value* V);
        Value* createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota);
        void createInstPostProc(Instruction* NewInst, bool NoNumber = false);

        InstCombiner::BuilderTy& Builder;
        Instruction* Instr = nullptr;

        unsigned InstructionCounter;
    };

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//
// Implementation of
//    {FAddendCoef, FAddend, FAddition, FAddCombine}.
//
//===----------------------------------------------------------------------===//
FAddendCoef::~FAddendCoef() {
    if (BufHasFpVal)
        getFpValPtr()->~APFloat();
}

void FAddendCoef::set(const APFloat& C) {
    APFloat* P = getFpValPtr();

    if (isInt()) {
        // As the buffer is a meaningless byte stream, we cannot call
        // APFloat::operator=().
        new(P) APFloat(C);
    }
    else
        *P = C;

    IsFp = BufHasFpVal = true;
}

void FAddendCoef::convertToFpType(const fltSemantics& Sem) {
    if (!isInt())
        return;

    APFloat* P = getFpValPtr();
    if (IntVal > 0)
        new(P) APFloat(Sem, IntVal);
    else {
        new(P) APFloat(Sem, 0 - IntVal);
        P->changeSign();
    }
    IsFp = BufHasFpVal = true;
}

APFloat FAddendCoef::createAPFloatFromInt(const fltSemantics& Sem, int Val) {
    if (Val >= 0)
        return APFloat(Sem, Val);

    APFloat T(Sem, 0 - Val);
    T.changeSign();

    return T;
}

void FAddendCoef::operator=(const FAddendCoef& That) {
    if (That.isInt())
        set(That.IntVal);
    else
        set(That.getFpVal());
}

void FAddendCoef::operator+=(const FAddendCoef& That) {
    enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven;
    if (isInt() == That.isInt()) {
        if (isInt())
            IntVal += That.IntVal;
        else
            getFpVal().add(That.getFpVal(), RndMode);
        return;
    }

    if (isInt()) {
        const APFloat& T = That.getFpVal();
        convertToFpType(T.getSemantics());
        getFpVal().add(T, RndMode);
        return;
    }

    APFloat& T = getFpVal();
    T.add(createAPFloatFromInt(T.getSemantics(), That.IntVal), RndMode);
}

void FAddendCoef::operator*=(const FAddendCoef& That) {
    if (That.isOne())
        return;

    if (That.isMinusOne()) {
        negate();
        return;
    }

    if (isInt() && That.isInt()) {
        int Res = IntVal * (int)That.IntVal;
        IGC_ASSERT_MESSAGE(!insaneIntVal(Res), "Insane int value");
        IntVal = Res;
        return;
    }

    const fltSemantics& Semantic =
        isInt() ? That.getFpVal().getSemantics() : getFpVal().getSemantics();

    if (isInt())
        convertToFpType(Semantic);
    APFloat& F0 = getFpVal();

    if (That.isInt())
        F0.multiply(createAPFloatFromInt(Semantic, That.IntVal),
            APFloat::rmNearestTiesToEven);
    else
        F0.multiply(That.getFpVal(), APFloat::rmNearestTiesToEven);
}

void FAddendCoef::negate() {
    if (isInt())
        IntVal = 0 - IntVal;
    else
        getFpVal().changeSign();
}

Value* FAddendCoef::getValue(Type* Ty) const {
    return isInt() ?
        ConstantFP::get(Ty, float(IntVal)) :
        ConstantFP::get(Ty->getContext(), getFpVal());
}

// The definition of <Val>     Addends
// =========================================
//  A + B                     <1, A>, <1, B>
//  A - B                     <1, A>, <-1, B>
//  0 - B                     <-1, B>
//  C * A                     <C, A>
//  A + C                     <1, A>, <C, NULL>
//  0 +/- 0                   <0, NULL> (corner case)
//
// Legend: A and B are not constant, C is constant
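// For example, drilling down "X - 5.0" yields the two addends <1, X> and
// <-5.0, NULL>, while "2.0 * Y" yields the single addend <2.0, Y>.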
unsigned FAddend::drillValueDownOneStep
(Value* Val, FAddend& Addend0, FAddend& Addend1) {
    Instruction* I = nullptr;
    if (!Val || !(I = dyn_cast<Instruction>(Val)))
        return 0;

    unsigned Opcode = I->getOpcode();

    if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub) {
        ConstantFP* C0, * C1;
        Value* Opnd0 = I->getOperand(0);
        Value* Opnd1 = I->getOperand(1);
        if ((C0 = dyn_cast<ConstantFP>(Opnd0)) && C0->isZero())
            Opnd0 = nullptr;

        if ((C1 = dyn_cast<ConstantFP>(Opnd1)) && C1->isZero())
            Opnd1 = nullptr;

        if (Opnd0) {
            if (!C0)
                Addend0.set(1, Opnd0);
            else
                Addend0.set(C0, nullptr);
        }

        if (Opnd1) {
            FAddend& Addend = Opnd0 ? Addend1 : Addend0;
            if (!C1)
                Addend.set(1, Opnd1);
            else
                Addend.set(C1, nullptr);
            if (Opcode == Instruction::FSub)
                Addend.negate();
        }

        if (Opnd0 || Opnd1)
            return Opnd0 && Opnd1 ? 2 : 1;

        // Both operands are zero. Weird!
        Addend0.set(APFloat(C0->getValueAPF().getSemantics()), nullptr);
        return 1;
    }

    if (I->getOpcode() == Instruction::FMul) {
        Value* V0 = I->getOperand(0);
        Value* V1 = I->getOperand(1);
        if (ConstantFP * C = dyn_cast<ConstantFP>(V0)) {
            Addend0.set(C, V1);
            return 1;
        }

        if (ConstantFP * C = dyn_cast<ConstantFP>(V1)) {
            Addend0.set(C, V0);
            return 1;
        }
    }

    return 0;
}

// Try to break *this* addend into two addends. e.g. suppose this addend is
// <2.3, V> and V = X + Y; by calling this function, we obtain two addends,
// i.e. <2.3, X> and <2.3, Y>.
unsigned FAddend::drillAddendDownOneStep
(FAddend& Addend0, FAddend& Addend1) const {
    if (isConstant())
        return 0;

    unsigned BreakNum = FAddend::drillValueDownOneStep(Val, Addend0, Addend1);
    if (!BreakNum || Coeff.isOne())
        return BreakNum;

    Addend0.Scale(Coeff);

    if (BreakNum == 2)
        Addend1.Scale(Coeff);

    return BreakNum;
}

// Try to perform the following optimization on the input instruction I. Return
// the simplified expression if it was successful; otherwise, return nullptr.
//
//   Instruction "I" is                Simplified into
// -------------------------------------------------------
//   (x * y) +/- (x * z)               x * (y +/- z)
//   (y / x) +/- (z / x)               (y +/- z) / x
Value* FAddCombine::performFactorization(Instruction* I) {
    IGC_ASSERT_MESSAGE((I->getOpcode() == Instruction::FAdd) || (I->getOpcode() == Instruction::FSub), "Expect add/sub");

    Instruction* I0 = dyn_cast<Instruction>(I->getOperand(0));
    Instruction* I1 = dyn_cast<Instruction>(I->getOperand(1));

    if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
        return nullptr;

    bool isMpy = false;
    if (I0->getOpcode() == Instruction::FMul)
        isMpy = true;
    else if (I0->getOpcode() != Instruction::FDiv)
        return nullptr;

    Value* Opnd0_0 = I0->getOperand(0);
    Value* Opnd0_1 = I0->getOperand(1);
    Value* Opnd1_0 = I1->getOperand(0);
    Value* Opnd1_1 = I1->getOperand(1);

    //  Input Instr I       Factor   AddSub0  AddSub1
    //  ----------------------------------------------
    // (x*y) +/- (x*z)        x        y         z
    // (y/x) +/- (z/x)        x        y         z
    Value* Factor = nullptr;
    Value* AddSub0 = nullptr, * AddSub1 = nullptr;

    if (isMpy) {
        if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1)
            Factor = Opnd0_0;
        else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1)
            Factor = Opnd0_1;

        if (Factor) {
            AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0;
            AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0;
        }
    }
    else if (Opnd0_1 == Opnd1_1) {
        Factor = Opnd0_1;
        AddSub0 = Opnd0_0;
        AddSub1 = Opnd1_0;
    }

    if (!Factor)
        return nullptr;

    FastMathFlags Flags;
    Flags.setFast();
    if (I0) Flags &= I->getFastMathFlags();
    if (I1) Flags &= I->getFastMathFlags();

    // Create the expression "NewAddSub = AddSub0 +/- AddSub1"
    Value* NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
        createFAdd(AddSub0, AddSub1) :
        createFSub(AddSub0, AddSub1);
    if (ConstantFP * CFP = dyn_cast<ConstantFP>(NewAddSub)) {
        const APFloat& F = CFP->getValueAPF();
        if (!F.isNormal())
            return nullptr;
    }
    else if (Instruction * II = dyn_cast<Instruction>(NewAddSub))
        II->setFastMathFlags(Flags);

    if (isMpy) {
        Value* RI = createFMul(Factor, NewAddSub);
        if (Instruction * II = dyn_cast<Instruction>(RI))
            II->setFastMathFlags(Flags);
        return RI;
    }

    Value* RI = createFDiv(NewAddSub, Factor);
    if (Instruction * II = dyn_cast<Instruction>(RI))
        II->setFastMathFlags(Flags);
    return RI;
}

Value* FAddCombine::simplify(Instruction* I) {
    IGC_ASSERT_MESSAGE(I->hasAllowReassoc(), "Expected 'reassoc'+'nsz' instruction");
    IGC_ASSERT_MESSAGE(I->hasNoSignedZeros(), "Expected 'reassoc'+'nsz' instruction");

    // Currently we are not able to handle vector type.
    if (I->getType()->isVectorTy())
        return nullptr;

    IGC_ASSERT_MESSAGE((I->getOpcode() == Instruction::FAdd) || (I->getOpcode() == Instruction::FSub), "Expect add/sub");

    // Save the instruction before calling other member-functions.
    Instr = I;

    FAddend Opnd0, Opnd1, Opnd0_0, Opnd0_1, Opnd1_0, Opnd1_1;

    unsigned OpndNum = FAddend::drillValueDownOneStep(I, Opnd0, Opnd1);

    // Step 1: Expand the 1st addend into Opnd0_0 and Opnd0_1.
    unsigned Opnd0_ExpNum = 0;
    unsigned Opnd1_ExpNum = 0;

    if (!Opnd0.isConstant())
        Opnd0_ExpNum = Opnd0.drillAddendDownOneStep(Opnd0_0, Opnd0_1);

    // Step 2: Expand the 2nd addend into Opnd1_0 and Opnd1_1.
    if (OpndNum == 2 && !Opnd1.isConstant())
        Opnd1_ExpNum = Opnd1.drillAddendDownOneStep(Opnd1_0, Opnd1_1);

    // Step 3: Try to optimize Opnd0_0 + Opnd0_1 + Opnd1_0 + Opnd1_1
    if (Opnd0_ExpNum && Opnd1_ExpNum) {
        AddendVect AllOpnds;
        AllOpnds.push_back(&Opnd0_0);
        AllOpnds.push_back(&Opnd1_0);
        if (Opnd0_ExpNum == 2)
            AllOpnds.push_back(&Opnd0_1);
        if (Opnd1_ExpNum == 2)
            AllOpnds.push_back(&Opnd1_1);

        // Compute instruction quota. We should save at least one instruction.
        unsigned InstQuota = 0;

        Value* V0 = I->getOperand(0);
        Value* V1 = I->getOperand(1);
        InstQuota = ((!isa<Constant>(V0) && V0->hasOneUse()) &&
            (!isa<Constant>(V1) && V1->hasOneUse())) ? 2 : 1;

        if (Value * R = simplifyFAdd(AllOpnds, InstQuota))
            return R;
    }

    if (OpndNum != 2) {
        // The input instruction is: "I = 0.0 +/- V". If the "V" were able to be
        // split into two addends, say "V = X - Y", the instruction would have
        // been optimized into "I = Y - X" in the previous steps.
        //
        const FAddendCoef& CE = Opnd0.getCoef();
        return CE.isOne() ? Opnd0.getSymVal() : nullptr;
    }

    // Step 4: Try to optimize Opnd0 + Opnd1_0 [+ Opnd1_1]
    if (Opnd1_ExpNum) {
        AddendVect AllOpnds;
        AllOpnds.push_back(&Opnd0);
        AllOpnds.push_back(&Opnd1_0);
        if (Opnd1_ExpNum == 2)
            AllOpnds.push_back(&Opnd1_1);

        if (Value * R = simplifyFAdd(AllOpnds, 1))
            return R;
    }

    // Step 5: Try to optimize Opnd1 + Opnd0_0 [+ Opnd0_1]
    if (Opnd0_ExpNum) {
        AddendVect AllOpnds;
        AllOpnds.push_back(&Opnd1);
        AllOpnds.push_back(&Opnd0_0);
        if (Opnd0_ExpNum == 2)
            AllOpnds.push_back(&Opnd0_1);

        if (Value * R = simplifyFAdd(AllOpnds, 1))
            return R;
    }

    // Step 6: Try factorization as the last resort.
    return performFactorization(I);
}

Value* FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
    unsigned AddendNum = Addends.size();
    IGC_ASSERT_MESSAGE(AddendNum <= 4, "Too many addends");

    // For saving intermediate results.
    unsigned NextTmpIdx = 0;
    FAddend TmpResult[3];

    // Points to the constant addend of the resulting simplified expression.
    // If the resulting expression has a constant addend, it is desirable for
    // that constant addend to reside at the top of the resulting expression
    // tree. Placing constants close to super-expr(s) will potentially reveal
    // some optimization opportunities in the super-expr(s).
    const FAddend* ConstAdd = nullptr;

    // Simplified addends are placed in <SimpVect>.
    AddendVect SimpVect;

    // The outer loop works on one symbolic-value at a time. Suppose the input
    // addends are: <a1, x>, <b1, y>, <a2, x>, <c1, z>, <b2, y>, ...
    // The symbolic-values will be processed in this order: x, y, z.
    for (unsigned SymIdx = 0; SymIdx < AddendNum; SymIdx++) {

        const FAddend* ThisAddend = Addends[SymIdx];
        if (!ThisAddend) {
            // This addend was processed before.
            continue;
        }

        Value* Val = ThisAddend->getSymVal();
        unsigned StartIdx = SimpVect.size();
        SimpVect.push_back(ThisAddend);

        // The inner loop collects addends sharing the same symbolic-value, and
        // these addends will later be folded into a single addend. Following the
        // above example, if the symbolic value "y" is being processed, the inner
        // loop will collect the two addends "<b1, y>" and "<b2, y>". These two
        // addends will later be folded into "<b1+b2, y>".
        for (unsigned SameSymIdx = SymIdx + 1;
            SameSymIdx < AddendNum; SameSymIdx++) {
            const FAddend* T = Addends[SameSymIdx];
            if (T && T->getSymVal() == Val) {
                // Set null such that the next iteration of the outer loop will
                // not process this addend again.
                Addends[SameSymIdx] = nullptr;
                SimpVect.push_back(T);
            }
        }

        // If multiple addends share the same symbolic value, fold them together.
        if (StartIdx + 1 != SimpVect.size()) {
            FAddend& R = TmpResult[NextTmpIdx++];
            R = *SimpVect[StartIdx];
            for (unsigned Idx = StartIdx + 1; Idx < SimpVect.size(); Idx++)
                R += *SimpVect[Idx];

            // Pop all addends being folded and push the resulting folded addend.
            SimpVect.resize(StartIdx);
            if (Val) {
                if (!R.isZero()) {
                    SimpVect.push_back(&R);
                }
            }
            else {
                // Don't push the constant addend at this time. It will be the
                // last element of <SimpVect>.
                ConstAdd = &R;
            }
        }
    }

    IGC_ASSERT_MESSAGE((NextTmpIdx <= array_lengthof(TmpResult) + 1), "out-of-bound access");

    if (ConstAdd)
        SimpVect.push_back(ConstAdd);

    Value* Result;
    if (!SimpVect.empty())
        Result = createNaryFAdd(SimpVect, InstrQuota);
    else {
        // The addition is folded to 0.0.
        Result = ConstantFP::get(Instr->getType(), 0.0);
    }

    return Result;
}

Value* FAddCombine::createNaryFAdd
(const AddendVect& Opnds, unsigned InstrQuota) {
    IGC_ASSERT_MESSAGE(!Opnds.empty(), "Expect at least one addend");

    // Step 1: Check if the # of instructions needed exceeds the quota.

    unsigned InstrNeeded = calcInstrNumber(Opnds);
    if (InstrNeeded > InstrQuota)
        return nullptr;

    InstructionCounter = 0;

    // Step 2: Emit the N-ary addition.
    // Note that at most three instructions are involved in Fadd-InstCombine: the
    // addition in question, and at most two neighboring instructions.
    // The resulting optimized addition should have at least one less instruction
    // than the original addition expression tree. This implies that the resulting
    // N-ary addition has at most two instructions, and we don't need to worry
    // about tree-height when constructing the N-ary addition.

    Value* LastVal = nullptr;
    bool LastValNeedNeg = false;

    // Iterate over the addends, creating fadd/fsub from two adjacent addends.
    for (const FAddend* Opnd : Opnds) {
        bool NeedNeg;
        Value* V = createAddendVal(*Opnd, NeedNeg);
        if (!LastVal) {
            LastVal = V;
            LastValNeedNeg = NeedNeg;
            continue;
        }

        if (LastValNeedNeg == NeedNeg) {
            LastVal = createFAdd(LastVal, V);
            continue;
        }

        if (LastValNeedNeg)
            LastVal = createFSub(V, LastVal);
        else
            LastVal = createFSub(LastVal, V);

        LastValNeedNeg = false;
    }

    if (LastValNeedNeg) {
        LastVal = createFNeg(LastVal);
    }

    IGC_ASSERT_MESSAGE((InstructionCounter == InstrNeeded), "Inconsistent in instruction numbers");

    return LastVal;
}

Value* FAddCombine::createFSub(Value* Opnd0, Value* Opnd1) {
    Value* V = Builder.CreateFSub(Opnd0, Opnd1);
    if (Instruction * I = dyn_cast<Instruction>(V))
        createInstPostProc(I);
    return V;
}

Value* FAddCombine::createFNeg(Value* V) {
    Value* Zero = cast<Value>(ConstantFP::getZeroValueForNegation(V->getType()));
    Value* NewV = createFSub(Zero, V);
    if (Instruction * I = dyn_cast<Instruction>(NewV))
        createInstPostProc(I, true); // fneg's don't receive instruction numbers.
    return NewV;
}

Value* FAddCombine::createFAdd(Value* Opnd0, Value* Opnd1) {
    Value* V = Builder.CreateFAdd(Opnd0, Opnd1);
    if (Instruction * I = dyn_cast<Instruction>(V))
        createInstPostProc(I);
    return V;
}

Value* FAddCombine::createFMul(Value* Opnd0, Value* Opnd1) {
    Value* V = Builder.CreateFMul(Opnd0, Opnd1);
    if (Instruction * I = dyn_cast<Instruction>(V))
        createInstPostProc(I);
    return V;
}

Value* FAddCombine::createFDiv(Value* Opnd0, Value* Opnd1) {
    Value* V = Builder.CreateFDiv(Opnd0, Opnd1);
    if (Instruction * I = dyn_cast<Instruction>(V))
        createInstPostProc(I);
    return V;
}

void FAddCombine::createInstPostProc(Instruction* NewInstr, bool NoNumber) {
    NewInstr->setDebugLoc(Instr->getDebugLoc());

    // Keep track of the number of instructions created.
    if (!NoNumber)
        ++InstructionCounter;

    // Propagate fast-math flags.
    NewInstr->setFastMathFlags(Instr->getFastMathFlags());
}

// Return the number of instructions needed to emit the N-ary addition.
// NOTE: Keep this function in sync with createAddendVal().
unsigned FAddCombine::calcInstrNumber(const AddendVect& Opnds) {
    unsigned OpndNum = Opnds.size();
    unsigned InstrNeeded = OpndNum - 1;

    // The number of addends in the form of "(-1)*x".
    unsigned NegOpndNum = 0;

    // Adjust the number of instructions needed to emit the N-ary add.
    for (const FAddend* Opnd : Opnds) {
        if (Opnd->isConstant())
            continue;

        // The constant check above is really for a few special constant
        // coefficients.
        if (isa<UndefValue>(Opnd->getSymVal()))
            continue;

        const FAddendCoef& CE = Opnd->getCoef();
        if (CE.isMinusOne() || CE.isMinusTwo())
            NegOpndNum++;

        // Let the addend be "c * x". If "c == +/-1", the value of the addend
        // is immediately available; otherwise, it needs exactly one instruction
        // to evaluate the value.
        if (!CE.isMinusOne() && !CE.isOne())
            InstrNeeded++;
    }
    if (NegOpndNum == OpndNum)
        InstrNeeded++;
    return InstrNeeded;
}

// Input Addend        Value           NeedNeg(output)
// ================================================================
// Constant C          C               false
// <+/-1, V>           V               coefficient is -1
// <2/-2, V>          "fadd V, V"      coefficient is -2
// <C, V>             "fmul V, C"      false
//
// NOTE: Keep this function in sync with FAddCombine::calcInstrNumber.
Value* FAddCombine::createAddendVal(const FAddend& Opnd, bool& NeedNeg) {
    const FAddendCoef& Coeff = Opnd.getCoef();

    if (Opnd.isConstant()) {
        NeedNeg = false;
        return Coeff.getValue(Instr->getType());
    }

    Value* OpndVal = Opnd.getSymVal();

    if (Coeff.isMinusOne() || Coeff.isOne()) {
        NeedNeg = Coeff.isMinusOne();
        return OpndVal;
    }

    if (Coeff.isTwo() || Coeff.isMinusTwo()) {
        NeedNeg = Coeff.isMinusTwo();
        return createFAdd(OpndVal, OpndVal);
    }

    NeedNeg = false;
    return createFMul(OpndVal, Coeff.getValue(Instr->getType()));
}

// Checks if any operand is negative and we can convert add to sub.
// This function checks for the following negative patterns
//   ADD(XOR(OR(Z, NOT(C)), C), 1) == NEG(AND(Z, C))
//   ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C))
//   XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even
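// For example (illustrative, i8, C == 0x0F so NOT(C) == 0xF0): the first
// pattern gives ADD(XOR(OR(Z, 0xF0), 0x0F), 1) == NEG(AND(Z, 0x0F)), since
// XOR(OR(Z, 0xF0), 0x0F) == NOT(AND(Z, 0x0F)) and ADD(NOT(X), 1) is the
// two's-complement negation of X.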
static Value* checkForNegativeOperand(BinaryOperator& I,
    InstCombiner::BuilderTy& Builder) {
    Value* LHS = I.getOperand(0), * RHS = I.getOperand(1);

    // This function creates 2 instructions to replace ADD; we need at least one
    // of LHS or RHS to have one use to ensure benefit in the transform.
    if (!LHS->hasOneUse() && !RHS->hasOneUse())
        return nullptr;

    Value* X = nullptr, * Y = nullptr, * Z = nullptr;
    const APInt* C1 = nullptr, * C2 = nullptr;

    // if ONE is on the other side, swap
    if (match(RHS, m_Add(m_Value(X), m_One())))
        std::swap(LHS, RHS);

    if (match(LHS, m_Add(m_Value(X), m_One()))) {
        // if XOR is on the other side, swap
        if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
            std::swap(X, RHS);

        if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
            // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
            // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
            if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
                Value* NewAnd = Builder.CreateAnd(Z, *C1);
                return Builder.CreateSub(RHS, NewAnd, "sub");
            }
            else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
                // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1))
                // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1))
                Value* NewOr = Builder.CreateOr(Z, ~(*C1));
                return Builder.CreateSub(RHS, NewOr, "sub");
            }
        }
    }

    // Restore LHS and RHS
    LHS = I.getOperand(0);
    RHS = I.getOperand(1);

    // if XOR is on the other side, swap
    if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
        std::swap(LHS, RHS);

    // C1 is ODD (checked below via countTrailingZeros), so C2 == C1 - 1 is even.
    // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2))
    // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
    if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
        if (C1->countTrailingZeros() == 0)
            if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
                Value* NewOr = Builder.CreateOr(Z, ~(*C2));
                return Builder.CreateSub(RHS, NewOr, "sub");
            }
    return nullptr;
}

Instruction* InstCombiner::foldAddWithConstant(BinaryOperator& Add) {
    Value* Op0 = Add.getOperand(0), * Op1 = Add.getOperand(1);
    Constant* Op1C;
    if (!match(Op1, m_Constant(Op1C)))
        return nullptr;

    if (Instruction * NV = foldBinOpIntoSelectOrPhi(Add))
        return NV;

    Value* X, * Y;

    // add (sub X, Y), -1 --> add (not Y), X
    if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y)))) &&
        match(Op1, m_AllOnes()))
        return BinaryOperator::CreateAdd(Builder.CreateNot(Y), X);

    // zext(bool) + C -> bool ? C + 1 : C
    if (match(Op0, m_ZExt(m_Value(X))) &&
        X->getType()->getScalarSizeInBits() == 1)
        return SelectInst::Create(X, AddOne(Op1C), Op1);

    // ~X + C --> (C-1) - X
    if (match(Op0, m_Not(m_Value(X))))
        return BinaryOperator::CreateSub(SubOne(Op1C), X);

    const APInt* C;
    if (!match(Op1, m_APInt(C)))
        return nullptr;

    if (C->isSignMask()) {
        // If wrapping is not allowed, then the addition must set the sign bit:
        // X + (signmask) --> X | signmask
        if (Add.hasNoSignedWrap() || Add.hasNoUnsignedWrap())
            return BinaryOperator::CreateOr(Op0, Op1);

        // If wrapping is allowed, then the addition flips the sign bit of LHS:
        // X + (signmask) --> X ^ signmask
        return BinaryOperator::CreateXor(Op0, Op1);
    }

    // Is this add the last step in a convoluted sext?
    // add(zext(xor i16 X, -32768), -32768) --> sext X
    Type* Ty = Add.getType();
    const APInt* C2;
    if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
        C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
        return CastInst::Create(Instruction::SExt, X, Ty);

    // (add (zext (add nuw X, C2)), C) --> (zext (add nuw X, C2 + C))
    if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) &&
        C->isNegative() && C->sge(-C2->sext(C->getBitWidth()))) {
        Constant* NewC =
            ConstantInt::get(X->getType(), *C2 + C->trunc(C2->getBitWidth()));
        return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty);
    }

    if (C->isOneValue() && Op0->hasOneUse()) {
        // add (sext i1 X), 1 --> zext (not X)
        // TODO: The smallest IR representation is (select X, 0, 1), and that would
        // not require the one-use check. But we need to remove a transform in
        // visitSelect and make sure that IR value tracking for select is equal or
        // better than for these ops.
        if (match(Op0, m_SExt(m_Value(X))) &&
            X->getType()->getScalarSizeInBits() == 1)
            return new ZExtInst(Builder.CreateNot(X), Ty);

        // Shifts and add used to flip and mask off the low bit:
        // add (ashr (shl i32 X, 31), 31), 1 --> and (not X), 1
        const APInt* C3;
        if (match(Op0, m_AShr(m_Shl(m_Value(X), m_APInt(C2)), m_APInt(C3))) &&
            C2 == C3 && *C2 == Ty->getScalarSizeInBits() - 1) {
            Value* NotX = Builder.CreateNot(X);
            return BinaryOperator::CreateAnd(NotX, ConstantInt::get(Ty, 1));
        }
    }

    return nullptr;
}

// Matches multiplication expression Op * C where C is a constant. Returns the
// constant value in C and the other operand in Op. Returns true if such a
// match is found.
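// Note that a left-shift by a constant is also treated as a multiplication;
// e.g. "shl i32 %x, 3" is matched as %x * 8.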
static bool MatchMul(Value* E, Value*& Op, APInt& C) {
    const APInt* AI;
    if (match(E, m_Mul(m_Value(Op), m_APInt(AI)))) {
        C = *AI;
        return true;
    }
    if (match(E, m_Shl(m_Value(Op), m_APInt(AI)))) {
        C = APInt(AI->getBitWidth(), 1);
        C <<= *AI;
        return true;
    }
    return false;
}

// Matches remainder expression Op % C where C is a constant. Returns the
// constant value in C and the other operand in Op. Returns the signedness of
// the remainder operation in IsSigned. Returns true if such a match is
// found.
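// Note that an "and" with a low-bit mask is also treated as an unsigned
// remainder; e.g. "and i32 %x, 7" is matched as %x u% 8.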
static bool MatchRem(Value* E, Value*& Op, APInt& C, bool& IsSigned) {
    const APInt* AI;
    IsSigned = false;
    if (match(E, m_SRem(m_Value(Op), m_APInt(AI)))) {
        IsSigned = true;
        C = *AI;
        return true;
    }
    if (match(E, m_URem(m_Value(Op), m_APInt(AI)))) {
        C = *AI;
        return true;
    }
    if (match(E, m_And(m_Value(Op), m_APInt(AI))) && (*AI + 1).isPowerOf2()) {
        C = *AI + 1;
        return true;
    }
    return false;
}

// Matches division expression Op / C with the given signedness as indicated
// by IsSigned, where C is a constant. Returns the constant value in C and the
// other operand in Op. Returns true if such a match is found.
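// Note that an unsigned right shift is likewise treated as an unsigned
// division; e.g. "lshr i32 %x, 3" is matched as %x u/ 8.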
static bool MatchDiv(Value* E, Value*& Op, APInt& C, bool IsSigned) {
    const APInt* AI;
    if (IsSigned && match(E, m_SDiv(m_Value(Op), m_APInt(AI)))) {
        C = *AI;
        return true;
    }
    if (!IsSigned) {
        if (match(E, m_UDiv(m_Value(Op), m_APInt(AI)))) {
            C = *AI;
            return true;
        }
        if (match(E, m_LShr(m_Value(Op), m_APInt(AI)))) {
            C = APInt(AI->getBitWidth(), 1);
            C <<= *AI;
            return true;
        }
    }
    return false;
}

// Returns whether C0 * C1 with the given signedness overflows.
static bool MulWillOverflow(APInt& C0, APInt& C1, bool IsSigned) {
    bool overflow;
    if (IsSigned)
        (void)C0.smul_ov(C1, overflow);
    else
        (void)C0.umul_ov(C1, overflow);
    return overflow;
}

// Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
// does not overflow.
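// For example (unsigned, illustrative): with C0 == 3 and C1 == 5,
//   X % 3 + ((X / 3) % 5) * 3  ==  X % 15,
// i.e. the two matched operations reassemble the remainder of X modulo 15.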
Value* InstCombiner::SimplifyAddWithRemainder(BinaryOperator& I) {
    Value* LHS = I.getOperand(0), * RHS = I.getOperand(1);
    Value* X, * MulOpV;
    APInt C0, MulOpC;
    bool IsSigned;
    // Match I = X % C0 + MulOpV * C0
    if (((MatchRem(LHS, X, C0, IsSigned) && MatchMul(RHS, MulOpV, MulOpC)) ||
        (MatchRem(RHS, X, C0, IsSigned) && MatchMul(LHS, MulOpV, MulOpC))) &&
        C0 == MulOpC) {
        Value* RemOpV;
        APInt C1;
        bool Rem2IsSigned;
        // Match MulOpV = RemOpV % C1
        if (MatchRem(MulOpV, RemOpV, C1, Rem2IsSigned) &&
            IsSigned == Rem2IsSigned) {
            Value* DivOpV;
            APInt DivOpC;
            // Match RemOpV = X / C0
            if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV &&
                C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) {
                Value* NewDivisor =
                    ConstantInt::get(X->getType()->getContext(), C0 * C1);
                return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem")
                    : Builder.CreateURem(X, NewDivisor, "urem");
            }
        }
    }

    return nullptr;
}

/// Fold
///   (1 << NBits) - 1
/// Into:
///   ~(-(1 << NBits))
/// Because a 'not' is better for bit-tracking analysis and other transforms
/// than an 'add'. The new shl is always nsw, and is nuw if the old `add` was.
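/// For example (illustrative IR):
///   %onebit = shl i32 1, %nbits
///   %mask   = add i32 %onebit, -1
/// becomes:
///   %notmask = shl nsw i32 -1, %nbits
///   %mask    = xor i32 %notmask, -1    ; i.e. ~(-(1 << %nbits))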
static Instruction* canonicalizeLowbitMask(BinaryOperator& I,
    InstCombiner::BuilderTy& Builder) {
    Value* NBits;
    if (!match(&I, m_Add(m_OneUse(m_Shl(m_One(), m_Value(NBits))), m_AllOnes())))
        return nullptr;

    Constant* MinusOne = Constant::getAllOnesValue(NBits->getType());
    Value* NotMask = Builder.CreateShl(MinusOne, NBits, "notmask");
    // Be wary of constant folding.
    if (auto * BOp = dyn_cast<BinaryOperator>(NotMask)) {
        // Always NSW. But NUW propagates from `add`.
        BOp->setHasNoSignedWrap();
        BOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
    }

    return BinaryOperator::CreateNot(NotMask, I.getName());
}

Instruction* InstCombiner::visitAdd(BinaryOperator& I) {
    if (Value * V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
        I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
        SQ.getWithInstruction(&I)))
        return replaceInstUsesWith(I, V);

    if (SimplifyAssociativeOrCommutative(I))
        return &I;

    if (Instruction * X = foldShuffledBinop(I))
        return X;

    // (A*B)+(A*C) -> A*(B+C) etc
    if (Value * V = SimplifyUsingDistributiveLaws(I))
        return replaceInstUsesWith(I, V);

    if (Instruction * X = foldAddWithConstant(I))
        return X;

    // FIXME: This should be moved into the above helper function to allow these
    // transforms for general constant or constant splat vectors.
    Value* LHS = I.getOperand(0), * RHS = I.getOperand(1);
    Type* Ty = I.getType();
    if (ConstantInt * CI = dyn_cast<ConstantInt>(RHS)) {
        Value* XorLHS = nullptr; ConstantInt* XorRHS = nullptr;
        if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
            unsigned TySizeBits = Ty->getScalarSizeInBits();
            const APInt& RHSVal = CI->getValue();
            unsigned ExtendAmt = 0;
            // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
            // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
            if (XorRHS->getValue() == -RHSVal) {
                if (RHSVal.isPowerOf2())
                    ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
                else if (XorRHS->getValue().isPowerOf2())
                    ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
            }

            if (ExtendAmt) {
                APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
                if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
                    ExtendAmt = 0;
            }

            if (ExtendAmt) {
                Constant* ShAmt = ConstantInt::get(Ty, ExtendAmt);
                Value* NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext");
                return BinaryOperator::CreateAShr(NewShl, ShAmt);
            }

            // If this is a xor that was canonicalized from a sub, turn it back into
            // a sub and fuse this add with it.
            if (LHS->hasOneUse() && (XorRHS->getValue() + 1).isPowerOf2()) {
                KnownBits LHSKnown = computeKnownBits(XorLHS, 0, &I);
                if ((XorRHS->getValue() | LHSKnown.Zero).isAllOnesValue())
                    return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
                        XorLHS);
            }
            // (X + signmask) + C could have gotten canonicalized to (X^signmask) + C,
            // transform them into (X + (signmask ^ C))
            if (XorRHS->getValue().isSignMask())
                return BinaryOperator::CreateAdd(XorLHS,
                    ConstantExpr::getXor(XorRHS, CI));
        }
    }

    if (Ty->isIntOrIntVectorTy(1))
        return BinaryOperator::CreateXor(LHS, RHS);

    // X + X --> X << 1
    if (LHS == RHS) {
        auto* Shl = BinaryOperator::CreateShl(LHS, ConstantInt::get(Ty, 1));
        Shl->setHasNoSignedWrap(I.hasNoSignedWrap());
        Shl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
        return Shl;
    }

    Value* A, * B;
    if (match(LHS, m_Neg(m_Value(A)))) {
        // -A + -B --> -(A + B)
        if (match(RHS, m_Neg(m_Value(B))))
            return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B));

        // -A + B --> B - A
        return BinaryOperator::CreateSub(RHS, A);
    }

    // A + -B  -->  A - B
    if (match(RHS, m_Neg(m_Value(B))))
        return BinaryOperator::CreateSub(LHS, B);

    if (Value * V = checkForNegativeOperand(I, Builder))
        return replaceInstUsesWith(I, V);

    // (A + 1) + ~B --> A - B
    // ~B + (A + 1) --> A - B
    if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B)))))
        return BinaryOperator::CreateSub(A, B);

    // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
    if (Value * V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V);

    // A+B --> A|B iff A and B have no bits set in common.
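    // e.g. (illustrative) 0b0101 + 0b1010 == 0b1111 == 0b0101 | 0b1010.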
    if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
        return BinaryOperator::CreateOr(LHS, RHS);

    // FIXME: We already did a check for ConstantInt RHS above this.
    // FIXME: Is this pattern covered by another fold? No regression tests fail on
    // removal.
    if (ConstantInt * CRHS = dyn_cast<ConstantInt>(RHS)) {
        // (X & FF00) + xx00  -> (X+xx00) & FF00
        Value* X;
        ConstantInt* C2;
        if (LHS->hasOneUse() &&
            match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) &&
            CRHS->getValue() == (CRHS->getValue() & C2->getValue())) {
            // See if all bits from the first bit set in the Add RHS up are included
            // in the mask.  First, get the rightmost bit.
            const APInt& AddRHSV = CRHS->getValue();

            // Form a mask of all bits from the lowest bit added through the top.
            APInt AddRHSHighBits(~((AddRHSV & -AddRHSV) - 1));

            // See if the and mask includes all of these bits.
            APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue());

            if (AddRHSHighBits == AddRHSHighBitsAnd) {
                // Okay, the xform is safe.  Insert the new add pronto.
                Value* NewAdd = Builder.CreateAdd(X, CRHS, LHS->getName());
                return BinaryOperator::CreateAnd(NewAdd, C2);
            }
        }
    }

    // add (select X 0 (sub n A)) A  -->  select X A n
    {
        SelectInst* SI = dyn_cast<SelectInst>(LHS);
        Value* A = RHS;
        if (!SI) {
            SI = dyn_cast<SelectInst>(RHS);
            A = LHS;
        }
        if (SI && SI->hasOneUse()) {
            Value* TV = SI->getTrueValue();
            Value* FV = SI->getFalseValue();
            Value* N;

            // Can we fold the add into the argument of the select?
            // We check both true and false select arguments for a matching subtract.
            if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
                // Fold the add into the true select value.
                return SelectInst::Create(SI->getCondition(), N, A);

            if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
                // Fold the add into the false select value.
                return SelectInst::Create(SI->getCondition(), A, N);
        }
    }

    // Check for (add (sext x), y), see if we can merge this into an
    // integer add followed by a sext.
    if (SExtInst * LHSConv = dyn_cast<SExtInst>(LHS)) {
        // (add (sext x), cst) --> (sext (add x, cst'))
        if (ConstantInt * RHSC = dyn_cast<ConstantInt>(RHS)) {
            if (LHSConv->hasOneUse()) {
                Constant* CI =
                    ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
                if (ConstantExpr::getSExt(CI, Ty) == RHSC &&
                    willNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
                    // Insert the new, smaller add.
                    Value* NewAdd =
                        Builder.CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv");
                    return new SExtInst(NewAdd, Ty);
                }
            }
        }

        // (add (sext x), (sext y)) --> (sext (add int x, y))
        if (SExtInst * RHSConv = dyn_cast<SExtInst>(RHS)) {
            // Only do this if x/y have the same type, if at least one of them has a
            // single use (so we don't increase the number of sexts), and if the
            // integer add will not overflow.
            if (LHSConv->getOperand(0)->getType() ==
                RHSConv->getOperand(0)->getType() &&
                (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
                willNotOverflowSignedAdd(LHSConv->getOperand(0),
                    RHSConv->getOperand(0), I)) {
                // Insert the new integer add.
                Value* NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0),
                    RHSConv->getOperand(0), "addconv");
                return new SExtInst(NewAdd, Ty);
            }
        }
    }

    // Check for (add (zext x), y), see if we can merge this into an
    // integer add followed by a zext.
    if (auto * LHSConv = dyn_cast<ZExtInst>(LHS)) {
        // (add (zext x), cst) --> (zext (add x, cst'))
        if (ConstantInt * RHSC = dyn_cast<ConstantInt>(RHS)) {
            if (LHSConv->hasOneUse()) {
                Constant* CI =
                    ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
                if (ConstantExpr::getZExt(CI, Ty) == RHSC &&
                    willNotOverflowUnsignedAdd(LHSConv->getOperand(0), CI, I)) {
                    // Insert the new, smaller add.
                    Value* NewAdd =
                        Builder.CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv");
                    return new ZExtInst(NewAdd, Ty);
                }
            }
        }

        // (add (zext x), (zext y)) --> (zext (add int x, y))
        if (auto * RHSConv = dyn_cast<ZExtInst>(RHS)) {
            // Only do this if x/y have the same type, if at least one of them has a
            // single use (so we don't increase the number of zexts), and if the
            // integer add will not overflow.
            if (LHSConv->getOperand(0)->getType() ==
                RHSConv->getOperand(0)->getType() &&
                (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
                willNotOverflowUnsignedAdd(LHSConv->getOperand(0),
                    RHSConv->getOperand(0), I)) {
                // Insert the new integer add.
                Value* NewAdd = Builder.CreateNUWAdd(
                    LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv");
                return new ZExtInst(NewAdd, Ty);
            }
        }
    }

    // (add (xor A, B) (and A, B)) --> (or A, B)
    // (add (and A, B) (xor A, B)) --> (or A, B)
    if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)),
        m_c_And(m_Deferred(A), m_Deferred(B)))))
        return BinaryOperator::CreateOr(A, B);

    // (add (or A, B) (and A, B)) --> (add A, B)
    // (add (and A, B) (or A, B)) --> (add A, B)
    if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)),
        m_c_And(m_Deferred(A), m_Deferred(B))))) {
        I.setOperand(0, A);
        I.setOperand(1, B);
        return &I;
    }

    // TODO(jingyue): Consider willNotOverflowSignedAdd and
    // willNotOverflowUnsignedAdd to reduce the number of invocations of
    // computeKnownBits.
    bool Changed = false;
    if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
        Changed = true;
        I.setHasNoSignedWrap(true);
    }
    if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
        Changed = true;
        I.setHasNoUnsignedWrap(true);
    }

    if (Instruction * V = canonicalizeLowbitMask(I, Builder))
        return V;

    return Changed ? &I : nullptr;
}
1397 
visitFAdd(BinaryOperator & I)1398 Instruction* InstCombiner::visitFAdd(BinaryOperator& I) {
1399     if (Value * V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1),
1400         I.getFastMathFlags(),
1401         SQ.getWithInstruction(&I)))
1402         return replaceInstUsesWith(I, V);
1403 
1404     if (SimplifyAssociativeOrCommutative(I))
1405         return &I;
1406 
1407     if (Instruction * X = foldShuffledBinop(I))
1408         return X;
1409 
1410     if (Instruction * FoldedFAdd = foldBinOpIntoSelectOrPhi(I))
1411         return FoldedFAdd;
1412 
1413     Value* LHS = I.getOperand(0), * RHS = I.getOperand(1);
1414     Value* X;
1415     // (-X) + Y --> Y - X
1416     if (match(LHS, m_FNeg(m_Value(X))))
1417         return BinaryOperator::CreateFSubFMF(RHS, X, &I);
1418     // Y + (-X) --> Y - X
1419     if (match(RHS, m_FNeg(m_Value(X))))
1420         return BinaryOperator::CreateFSubFMF(LHS, X, &I);
1421 
1422     // Check for (fadd double (sitofp x), y), see if we can merge this into an
1423     // integer add followed by a promotion.
1424     if (SIToFPInst * LHSConv = dyn_cast<SIToFPInst>(LHS)) {
1425         Value* LHSIntVal = LHSConv->getOperand(0);
1426         Type* FPType = LHSConv->getType();
1427 
1428         // TODO: This check is overly conservative. In many cases known bits
1429         // analysis can tell us that the result of the addition has less significant
1430         // bits than the integer type can hold.
1431         auto IsValidPromotion = [](Type* FTy, Type* ITy) {
1432             Type* FScalarTy = FTy->getScalarType();
1433             Type* IScalarTy = ITy->getScalarType();
1434 
1435             // Do we have enough bits in the significand to represent the result of
1436             // the integer addition?
1437             unsigned MaxRepresentableBits =
1438                 APFloat::semanticsPrecision(FScalarTy->getFltSemantics());
1439             return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits;
1440         };
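        // For example, an i32 add always fits in double's 53-bit significand
        // and can be promoted, whereas an i64 add may not be exactly
        // representable and is rejected.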
        // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
        // ... if the constant fits in the integer value.  This is useful for things
        // like (double)(x & 1234) + 4.0 -> (double)((x & 1234)+4) which no longer
        // requires a constant pool load, and generally allows the add to be better
        // instcombined.
        if (ConstantFP * CFP = dyn_cast<ConstantFP>(RHS))
            if (IsValidPromotion(FPType, LHSIntVal->getType())) {
                Constant* CI =
                    ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
                if (LHSConv->hasOneUse() &&
                    ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
                    willNotOverflowSignedAdd(LHSIntVal, CI, I)) {
                    // Insert the new integer add.
                    Value* NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv");
                    return new SIToFPInst(NewAdd, I.getType());
                }
            }

        // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
        if (SIToFPInst * RHSConv = dyn_cast<SIToFPInst>(RHS)) {
            Value* RHSIntVal = RHSConv->getOperand(0);
            // It's enough to check LHS types only because we require int types to
            // be the same for this transform.
            if (IsValidPromotion(FPType, LHSIntVal->getType())) {
                // Only do this if x/y have the same type, if at least one of them has a
                // single use (so we don't increase the number of int->fp conversions),
                // and if the integer add will not overflow.
                if (LHSIntVal->getType() == RHSIntVal->getType() &&
                    (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
                    willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) {
                    // Insert the new integer add.
                    Value* NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal, "addconv");
                    return new SIToFPInst(NewAdd, I.getType());
                }
            }
        }
    }

    // Handle special cases for FAdd with selects feeding the operation
    if (Value * V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
        return replaceInstUsesWith(I, V);

    if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
        if (Value * V = FAddCombine(Builder).simplify(&I))
            return replaceInstUsesWith(I, V);
    }

    return nullptr;
}

/// Optimize differences of pointers into the same array into a size.  Consider:
///  &A[10] - &A[0]: we should compile this to "10".  LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
Value* InstCombiner::OptimizePointerDifference(Value* LHS, Value* RHS,
    Type* Ty) {
    // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
    // this.
    bool Swapped = false;
    GEPOperator* GEP1 = nullptr, * GEP2 = nullptr;

    // For now we require one side to be the base pointer "A" or a constant
    // GEP derived from it.
    if (GEPOperator * LHSGEP = dyn_cast<GEPOperator>(LHS)) {
        // (gep X, ...) - X
        if (LHSGEP->getOperand(0) == RHS) {
            GEP1 = LHSGEP;
            Swapped = false;
        }
        else if (GEPOperator * RHSGEP = dyn_cast<GEPOperator>(RHS)) {
            // (gep X, ...) - (gep X, ...)
            if (LHSGEP->getOperand(0)->stripPointerCasts() ==
                RHSGEP->getOperand(0)->stripPointerCasts()) {
                GEP2 = RHSGEP;
                GEP1 = LHSGEP;
                Swapped = false;
            }
        }
    }

    if (GEPOperator * RHSGEP = dyn_cast<GEPOperator>(RHS)) {
        // X - (gep X, ...)
        if (RHSGEP->getOperand(0) == LHS) {
            GEP1 = RHSGEP;
            Swapped = true;
        }
        else if (GEPOperator * LHSGEP = dyn_cast<GEPOperator>(LHS)) {
            // (gep X, ...) - (gep X, ...)
            if (RHSGEP->getOperand(0)->stripPointerCasts() ==
                LHSGEP->getOperand(0)->stripPointerCasts()) {
                GEP2 = LHSGEP;
                GEP1 = RHSGEP;
                Swapped = true;
            }
        }
    }

    if (!GEP1)
        // No GEP found.
        return nullptr;

    if (GEP2) {
        // (gep X, ...) - (gep X, ...)
        //
        // Avoid duplicating the offset arithmetic when the two GEPs together
        // have more than one non-constant index and either GEP with a
        // non-constant index has multiple users. With zero non-constant
        // indices the result is a constant, so there is no duplication. With
        // exactly one non-constant index the result is an add or sub with a
        // constant, which is no larger than the original code, so nothing is
        // duplicated even if either GEP has multiple users. With more than
        // one non-constant index, there is still no duplication as long as
        // each GEP with a non-constant index has a single user.
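        // For instance, &A[i] - &A[j] lowers to i*S - j*S (S = element size);
        // if either GEP had other users, that offset computation would be
        // emitted in addition to the GEP's own addressing arithmetic.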
        unsigned NumNonConstantIndices1 = GEP1->countNonConstantIndices();
        unsigned NumNonConstantIndices2 = GEP2->countNonConstantIndices();
        if (NumNonConstantIndices1 + NumNonConstantIndices2 > 1 &&
            ((NumNonConstantIndices1 > 0 && !GEP1->hasOneUse()) ||
            (NumNonConstantIndices2 > 0 && !GEP2->hasOneUse()))) {
            return nullptr;
        }
    }

    // Emit the offset of the GEP as an intptr_t.
    Value* Result = EmitGEPOffset(GEP1);

    // If we had a constant expression GEP on the other side offsetting the
    // pointer, subtract it from the offset we have.
    if (GEP2) {
        Value* Offset = EmitGEPOffset(GEP2);
        Result = Builder.CreateSub(Result, Offset);
    }

    // If we have p - gep(p, ...)  then we have to negate the result.
    if (Swapped)
        Result = Builder.CreateNeg(Result, "diff.neg");

    return Builder.CreateIntCast(Result, Ty, true);
}

Instruction* InstCombiner::visitSub(BinaryOperator& I) {
    if (Value * V = SimplifySubInst(I.getOperand(0), I.getOperand(1),
        I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
        SQ.getWithInstruction(&I)))
        return replaceInstUsesWith(I, V);

    if (Instruction * X = foldShuffledBinop(I))
        return X;

    // (A*B)-(A*C) -> A*(B-C) etc
    if (Value * V = SimplifyUsingDistributiveLaws(I))
        return replaceInstUsesWith(I, V);

    // If this is a 'B = x-(-A)', change to B = x+A.
    Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
    if (Value * V = dyn_castNegVal(Op1)) {
        BinaryOperator* Res = BinaryOperator::CreateAdd(Op0, V);

        if (const auto * BO = dyn_cast<BinaryOperator>(Op1)) {
            IGC_ASSERT_MESSAGE(BO->getOpcode() == Instruction::Sub, "Expected a subtraction operator!");
            if (BO->hasNoSignedWrap() && I.hasNoSignedWrap())
                Res->setHasNoSignedWrap(true);
        }
        else {
            if (cast<Constant>(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap())
                Res->setHasNoSignedWrap(true);
        }

        return Res;
    }

    if (I.getType()->isIntOrIntVectorTy(1))
        return BinaryOperator::CreateXor(Op0, Op1);

    // Replace (-1 - A) with (~A).
    if (match(Op0, m_AllOnes()))
        return BinaryOperator::CreateNot(Op1);

    // (~X) - (~Y) --> Y - X
    Value* X, * Y;
    if (match(Op0, m_Not(m_Value(X))) && match(Op1, m_Not(m_Value(Y))))
        return BinaryOperator::CreateSub(Y, X);

    // (X + -1) - Y --> ~Y + X
    if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes()))))
        return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X);

    // Y - (X + 1) --> ~X + Y
    if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
        return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);

    if (Constant * C = dyn_cast<Constant>(Op0)) {
        bool IsNegate = match(C, m_ZeroInt());
        Value* X;
        if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
            // 0 - (zext bool) --> sext bool
            // C - (zext bool) --> bool ? C - 1 : C
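            // (zext i1 is 0 or 1, so negating it yields 0 or -1, which is
            // exactly sext i1.)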
            if (IsNegate)
                return CastInst::CreateSExtOrBitCast(X, I.getType());
            return SelectInst::Create(X, SubOne(C), C);
        }
        if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
            // 0 - (sext bool) --> zext bool
            // C - (sext bool) --> bool ? C + 1 : C
            if (IsNegate)
                return CastInst::CreateZExtOrBitCast(X, I.getType());
            return SelectInst::Create(X, AddOne(C), C);
        }

        // C - ~X == X + (1+C)
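        // (~X == -X - 1, so C - ~X == C - (-X - 1) == X + (C + 1).)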
        if (match(Op1, m_Not(m_Value(X))))
            return BinaryOperator::CreateAdd(X, AddOne(C));

        // Try to fold constant sub into select arguments.
        if (SelectInst * SI = dyn_cast<SelectInst>(Op1))
            if (Instruction * R = FoldOpIntoSelect(I, SI))
                return R;

        // Try to fold constant sub into PHI values.
        if (PHINode * PN = dyn_cast<PHINode>(Op1))
            if (Instruction * R = foldOpIntoPhi(I, PN))
                return R;

        // C-(X+C2) --> (C-C2)-X
        Constant* C2;
        if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
            return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
    }

    const APInt* Op0C;
    if (match(Op0, m_APInt(Op0C))) {
        unsigned BitWidth = I.getType()->getScalarSizeInBits();

        // -(X >>u 31) -> (X >>s 31)
        // -(X >>s 31) -> (X >>u 31)
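        // (A shift by bitwidth-1 isolates the sign bit: the lshr yields 0 or
        // 1, whose negation is 0 or -1, i.e. the ashr; and vice versa.)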
        if (Op0C->isNullValue()) {
            Value* X;
            const APInt* ShAmt;
            if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
                *ShAmt == BitWidth - 1) {
                Value* ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
                return BinaryOperator::CreateAShr(X, ShAmtOp);
            }
            if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
                *ShAmt == BitWidth - 1) {
                Value* ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
                return BinaryOperator::CreateLShr(X, ShAmtOp);
            }

            if (Op1->hasOneUse()) {
                Value* LHS, * RHS;
                SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
                if (SPF == SPF_ABS || SPF == SPF_NABS) {
                    // This is a negate of an ABS/NABS pattern. Just swap the operands
                    // of the select.
                    SelectInst* SI = cast<SelectInst>(Op1);
                    Value* TrueVal = SI->getTrueValue();
                    Value* FalseVal = SI->getFalseValue();
                    SI->setTrueValue(FalseVal);
                    SI->setFalseValue(TrueVal);
                    // Don't swap prof metadata, we didn't change the branch behavior.
                    return replaceInstUsesWith(I, SI);
                }
            }
        }

        // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
        // zero.
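        // (When no bit of RHS lies above the mask, the subtraction cannot
        // borrow, so each mask bit is simply flipped where RHS is set:
        // e.g. 0b0111 - 0b0101 == 0b0010 == 0b0111 ^ 0b0101.)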
        if (Op0C->isMask()) {
            KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
            if ((*Op0C | RHSKnown.Zero).isAllOnesValue())
                return BinaryOperator::CreateXor(Op1, Op0);
        }
    }

    {
        Value* Y;
        // X-(X+Y) == -Y    X-(Y+X) == -Y
        if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y))))
            return BinaryOperator::CreateNeg(Y);

        // (X-Y)-X == -Y
        if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y))))
            return BinaryOperator::CreateNeg(Y);
    }

    // (sub (or A, B), (xor A, B)) --> (and A, B)
    {
        Value* A, * B;
        if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
            match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
            return BinaryOperator::CreateAnd(A, B);
    }

    {
        Value* Y;
        // ((X | Y) - X) --> (~X & Y)
        if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1)))))
            return BinaryOperator::CreateAnd(
                Y, Builder.CreateNot(Op1, Op1->getName() + ".not"));
    }

    if (Op1->hasOneUse()) {
        Value* X = nullptr, * Y = nullptr, * Z = nullptr;
        Constant* C = nullptr;

        // (X - (Y - Z))  -->  (X + (Z - Y)).
        if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
            return BinaryOperator::CreateAdd(Op0,
                Builder.CreateSub(Z, Y, Op1->getName()));

        // (X - (X & Y))   -->   (X & ~Y)
        if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0))))
            return BinaryOperator::CreateAnd(Op0,
                Builder.CreateNot(Y, Y->getName() + ".not"));

        // 0 - (X sdiv C)  -> (X sdiv -C)  provided the negation doesn't overflow.
        if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && match(Op0, m_Zero()) &&
            C->isNotMinSignedValue() && !C->isOneValue())
            return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C));

        // 0 - (X << Y)  -> (-X << Y)   when X is freely negatable.
        if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero()))
            if (Value * XNeg = dyn_castNegVal(X))
                return BinaryOperator::CreateShl(XNeg, Y);

        // Subtracting -1/0 is the same as adding 1/0:
        // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y)
        // 'nuw' is dropped in favor of the canonical form.
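        // (sext i1 is -1 or 0 while zext i1 is 1 or 0, so subtracting the
        // former equals adding the latter for both values of the bool.)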
        if (match(Op1, m_SExt(m_Value(Y))) &&
            Y->getType()->getScalarSizeInBits() == 1) {
            Value* Zext = Builder.CreateZExt(Y, I.getType());
            BinaryOperator* Add = BinaryOperator::CreateAdd(Op0, Zext);
            Add->setHasNoSignedWrap(I.hasNoSignedWrap());
            return Add;
        }

        // X - A*-B -> X + A*B
        // X - -A*B -> X + A*B
        Value* A, * B;
        Constant* CI;
        if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B)))))
            return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B));

        // X - A*CI -> X + A*-CI
        // No need to handle the commuted multiply because multiply handling will
        // ensure the constant is moved to the right-hand side.
        if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) {
            Value* NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(CI));
            return BinaryOperator::CreateAdd(Op0, NewMul);
        }
    }

    // Optimize differences of pointers into the same array into a size.  Consider:
    //  &A[10] - &A[0]: we should compile this to "10".
    Value* LHSOp, * RHSOp;
    if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
        match(Op1, m_PtrToInt(m_Value(RHSOp))))
        if (Value * Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
            return replaceInstUsesWith(I, Res);

    // trunc(p)-trunc(q) -> trunc(p-q)
    if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
        match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
        if (Value * Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
            return replaceInstUsesWith(I, Res);

    // Canonicalize a shifty way to code absolute value to the common pattern.
    // There are 2 potential commuted variants.
    // We're relying on the fact that we only do this transform when the shift has
    // exactly 2 uses and the xor has exactly 1 use (otherwise, we might increase
    // instructions).
    Value* A;
    const APInt* ShAmt;
    Type* Ty = I.getType();
    if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
        Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
        match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) {
        // B = ashr i32 A, 31 ; smear the sign bit
        // sub (xor A, B), B  ; flip bits if negative and subtract -1 (add 1)
        // --> (A < 0) ? -A : A
        Value* Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty));
        // Copy the nuw/nsw flags from the sub to the negate.
        Value* Neg = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(),
            I.hasNoSignedWrap());
        return SelectInst::Create(Cmp, Neg, A);
    }

    bool Changed = false;
    if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) {
        Changed = true;
        I.setHasNoSignedWrap(true);
    }
    if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) {
        Changed = true;
        I.setHasNoUnsignedWrap(true);
    }

    return Changed ? &I : nullptr;
}

Instruction* InstCombiner::visitFSub(BinaryOperator& I) {
    if (Value * V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1),
        I.getFastMathFlags(),
        SQ.getWithInstruction(&I)))
        return replaceInstUsesWith(I, V);

    if (Instruction * X = foldShuffledBinop(I))
        return X;

    // Subtraction from -0.0 is the canonical form of fneg.
    // fsub nsz 0, X ==> fsub nsz -0.0, X
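    // (-0.0 - X == -X for every X including zeros, whereas 0.0 - 0.0 is +0.0
    // rather than -0.0; the nsz flag lets us ignore that difference.)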
    Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
    if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP()))
        return BinaryOperator::CreateFNegFMF(Op1, &I);

    // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X)
    // Canonicalize to fadd to make analysis easier.
    // This can also help codegen because fadd is commutative.
    // Note that if this fsub was really an fneg, the fadd with -0.0 will get
    // killed later. We still limit that particular transform with 'hasOneUse'
    // because an fneg is assumed better/cheaper than a generic fsub.
    Value* X, * Y;
    if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) {
        if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) {
            Value* NewSub = Builder.CreateFSubFMF(Y, X, &I);
            return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I);
        }
    }

    if (isa<Constant>(Op0))
        if (SelectInst * SI = dyn_cast<SelectInst>(Op1))
            if (Instruction * NV = FoldOpIntoSelect(I, SI))
                return NV;

    // X - C --> X + (-C)
    // But don't transform constant expressions because there's an inverse fold
    // for X + (-Y) --> X - Y.
    Constant* C;
    if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1))
        return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);

    // X - (-Y) --> X + Y
    if (match(Op1, m_FNeg(m_Value(Y))))
        return BinaryOperator::CreateFAddFMF(Op0, Y, &I);

    // Similar to above, but look through a cast of the negated value:
    // X - (fptrunc(-Y)) --> X + fptrunc(Y)
    if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) {
        Value* TruncY = Builder.CreateFPTrunc(Y, I.getType());
        return BinaryOperator::CreateFAddFMF(Op0, TruncY, &I);
    }
    // X - (fpext(-Y)) --> X + fpext(Y)
    if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) {
        Value* ExtY = Builder.CreateFPExt(Y, I.getType());
        return BinaryOperator::CreateFAddFMF(Op0, ExtY, &I);
    }

    // Handle special cases for FSub with selects feeding the operation
    if (Value * V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
        return replaceInstUsesWith(I, V);

    if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
        if (Value * V = FAddCombine(Builder).simplify(&I))
            return replaceInstUsesWith(I, V);
    }

    return nullptr;
}