1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2018-2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 /*========================== begin_copyright_notice ============================
10
11 This file is distributed under the University of Illinois Open Source License.
12 See LICENSE.TXT for details.
13
14 ============================= end_copyright_notice ===========================*/
15
16 // This file implements the visitShl, visitLShr, and visitAShr functions.
17
18 #include "common/LLVMWarningsPush.hpp"
19 #include "InstCombineInternal.h"
20 #include "llvm/Analysis/ConstantFolding.h"
21 #include "llvm/Analysis/InstructionSimplify.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/PatternMatch.h"
24 #include "Probe/Assertion.h"
25
26 using namespace llvm;
27 using namespace PatternMatch;
28 using namespace IGCombiner;
29
30 #define DEBUG_TYPE "instcombine"
31
/// Transforms common to all three shift opcodes (shl, lshr, ashr).
/// Returns the replacement instruction, &I if I was updated in place,
/// or nullptr if no transform applied.
Instruction* InstCombiner::commonShiftTransforms(BinaryOperator& I) {
    Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
    IGC_ASSERT(Op0->getType() == Op1->getType());

    // See if we can fold away this shift.
    if (SimplifyDemandedInstructionBits(I))
        return &I;

    // Try to fold constant and into select arguments.
    if (isa<Constant>(Op0))
        if (SelectInst * SI = dyn_cast<SelectInst>(Op1))
            if (Instruction * R = FoldOpIntoSelect(I, SI))
                return R;

    // Shift-by-constant (scalar or splat) gets its own large set of folds.
    if (Constant * CUI = dyn_cast<Constant>(Op1))
        if (Instruction * Res = FoldShiftByConstant(Op0, CUI, I))
            return Res;

    // (C1 shift (A add C2)) -> (C1 shift C2) shift A)
    // iff A and C2 are both positive.
    Value* A;
    Constant* C;
    if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C))))
        if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) &&
            isKnownNonNegative(C, DL, 0, &AC, &I, &DT))
            return BinaryOperator::Create(
                I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A);

    // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2.
    // Because shifts by negative values (which could occur if A were negative)
    // are undefined.
    const APInt* B;
    if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) {
        // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
        // demand the sign bit (and many others) here??
        Value* Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1),
            Op1->getName());
        I.setOperand(1, Rem);
        return &I;
    }

    return nullptr;
}
75
76 /// Return true if we can simplify two logical (either left or right) shifts
77 /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
/// Return true if we can simplify two logical (either left or right) shifts
/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
/// \p OuterShAmt  constant shift amount of the outer shift.
/// \p IsOuterShl  true when the outer shift is a shl (otherwise lshr).
/// \p InnerShift  the inner logical shift instruction.
static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
    Instruction* InnerShift, InstCombiner& IC,
    Instruction* CxtI) {
    IGC_ASSERT_MESSAGE(InnerShift->isLogicalShift(), "Unexpected instruction type");

    // We need constant scalar or constant splat shifts.
    const APInt* InnerShiftConst;
    if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
        return false;

    // Two logical shifts in the same direction:
    // shl (shl X, C1), C2 --> shl X, C1 + C2
    // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
    bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
    if (IsInnerShl == IsOuterShl)
        return true;

    // Equal shift amounts in opposite directions become bitwise 'and':
    // lshr (shl X, C), C --> and X, C'
    // shl (lshr X, C), C --> and X, C'
    if (*InnerShiftConst == OuterShAmt)
        return true;

    // If the 2nd shift is bigger than the 1st, we can fold:
    // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
    // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
    // but it isn't profitable unless we know the and'd out bits are already zero.
    // Also, check that the inner shift is valid (less than the type width) or
    // we'll crash trying to produce the bit mask for the 'and'.
    unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
    if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
        unsigned InnerShAmt = InnerShiftConst->getZExtValue();
        // Mask of the bits that the composite shift would discard; the fold is
        // only valid when those bits of the inner shift's input are known zero.
        unsigned MaskShift =
            IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
        APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
        if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
            return true;
    }

    return false;
}
119
120 /// See if we can compute the specified value, but shifted logically to the left
121 /// or right by some number of bits. This should return true if the expression
122 /// can be computed for the same cost as the current expression tree. This is
123 /// used to eliminate extraneous shifting from things like:
124 /// %C = shl i128 %A, 64
125 /// %D = shl i128 %B, 96
126 /// %E = or i128 %C, %D
127 /// %F = lshr i128 %E, 64
128 /// where the client will ask if E can be computed shifted right by 64-bits. If
129 /// this succeeds, getShiftedValue() will be called to produce the value.
/// See if we can compute the specified value, but shifted logically to the left
/// or right by some number of bits. This should return true if the expression
/// can be computed for the same cost as the current expression tree. This is
/// used to eliminate extraneous shifting from things like:
///   %C = shl i128 %A, 64
///   %D = shl i128 %B, 96
///   %E = or i128 %C, %D
///   %F = lshr i128 %E, 64
/// where the client will ask if E can be computed shifted right by 64-bits. If
/// this succeeds, getShiftedValue() will be called to produce the value.
static bool canEvaluateShifted(Value* V, unsigned NumBits, bool IsLeftShift,
    InstCombiner& IC, Instruction* CxtI) {
    // We can always evaluate constants shifted.
    if (isa<Constant>(V))
        return true;

    Instruction* I = dyn_cast<Instruction>(V);
    if (!I) return false;

    // If this is the opposite shift, we can directly reuse the input of the shift
    // if the needed bits are already zero in the input. This allows us to reuse
    // the value which means that we don't care if the shift has multiple uses.
    // TODO: Handle opposite shift by exact value.
    ConstantInt* CI = nullptr;
    if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
        (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
        if (CI->getValue() == NumBits) {
            // TODO: Check that the input bits are already zero with MaskedValueIsZero
#if 0
            // If this is a truncate of a logical shr, we can truncate it to a smaller
            // lshr iff we know that the bits we would otherwise be shifting in are
            // already zeros.
            uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
            uint32_t BitWidth = Ty->getScalarSizeInBits();
            if (MaskedValueIsZero(I->getOperand(0),
                APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth - BitWidth)) &&
                CI->getLimitedValue(BitWidth) < BitWidth) {
                return CanEvaluateTruncated(I->getOperand(0), Ty);
            }
#endif

        }
    }

    // We can't mutate something that has multiple uses: doing so would
    // require duplicating the instruction in general, which isn't profitable.
    if (!I->hasOneUse()) return false;

    switch (I->getOpcode()) {
    default: return false;
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
        // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
        return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
            canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);

    case Instruction::Shl:
    case Instruction::LShr:
        return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);

    case Instruction::Select: {
        // A select is shiftable if both of its arms are.
        SelectInst* SI = cast<SelectInst>(I);
        Value* TrueVal = SI->getTrueValue();
        Value* FalseVal = SI->getFalseValue();
        return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
            canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
    }
    case Instruction::PHI: {
        // We can change a phi if we can change all operands. Note that we never
        // get into trouble with cyclic PHIs here because we only consider
        // instructions with a single use.
        PHINode* PN = cast<PHINode>(I);
        for (Value* IncValue : PN->incoming_values())
            if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
                return false;
        return true;
    }
    }
}
200
201 /// Fold OuterShift (InnerShift X, C1), C2.
202 /// See canEvaluateShiftedShift() for the constraints on these instructions.
/// Fold OuterShift (InnerShift X, C1), C2.
/// See canEvaluateShiftedShift() for the constraints on these instructions.
/// Note: this mutates InnerShift in place for most cases and returns it (or a
/// fresh 'and' / zero constant), so it must only run after the corresponding
/// canEvaluateShiftedShift() succeeded.
static Value* foldShiftedShift(BinaryOperator* InnerShift, unsigned OuterShAmt,
    bool IsOuterShl,
    InstCombiner::BuilderTy& Builder) {
    bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
    Type* ShType = InnerShift->getType();
    unsigned TypeWidth = ShType->getScalarSizeInBits();

    // We only accept shifts-by-a-constant in canEvaluateShifted(),
    // so this match is guaranteed to succeed here.
    const APInt* C1;
    match(InnerShift->getOperand(1), m_APInt(C1));
    unsigned InnerShAmt = C1->getZExtValue();

    // Change the shift amount and clear the appropriate IR flags.
    // The nuw/nsw/exact flags described the old shift amount and may no
    // longer hold for the new one.
    auto NewInnerShift = [&](unsigned ShAmt) {
        InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
        if (IsInnerShl) {
            InnerShift->setHasNoUnsignedWrap(false);
            InnerShift->setHasNoSignedWrap(false);
        }
        else {
            InnerShift->setIsExact(false);
        }
        return InnerShift;
    };

    // Two logical shifts in the same direction:
    // shl (shl X, C1), C2 --> shl X, C1 + C2
    // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
    if (IsInnerShl == IsOuterShl) {
        // If this is an oversized composite shift, then unsigned shifts get 0.
        if (InnerShAmt + OuterShAmt >= TypeWidth)
            return Constant::getNullValue(ShType);

        return NewInnerShift(InnerShAmt + OuterShAmt);
    }

    // Equal shift amounts in opposite directions become bitwise 'and':
    // lshr (shl X, C), C --> and X, C'
    // shl (lshr X, C), C --> and X, C'
    if (InnerShAmt == OuterShAmt) {
        APInt Mask = IsInnerShl
            ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
            : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
        Value* And = Builder.CreateAnd(InnerShift->getOperand(0),
            ConstantInt::get(ShType, Mask));
        // Place the 'and' where the shift pair was so later uses still dominate.
        if (auto * AndI = dyn_cast<Instruction>(And)) {
            AndI->moveBefore(InnerShift);
            AndI->takeName(InnerShift);
        }
        return And;
    }

    IGC_ASSERT_MESSAGE(InnerShAmt > OuterShAmt, "Unexpected opposite direction logical shift pair");

    // In general, we would need an 'and' for this transform, but
    // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
    // lshr (shl X, C1), C2 --> shl X, C1 - C2
    // shl (lshr X, C1), C2 --> lshr X, C1 - C2
    return NewInnerShift(InnerShAmt - OuterShAmt);
}
263
264 /// When canEvaluateShifted() returns true for an expression, this function
265 /// inserts the new computation that produces the shifted value.
/// When canEvaluateShifted() returns true for an expression, this function
/// inserts the new computation that produces the shifted value.
/// It rewrites the expression tree in place (operands of single-use
/// instructions are replaced recursively) and returns the shifted value.
static Value* getShiftedValue(Value* V, unsigned NumBits, bool isLeftShift,
    InstCombiner& IC, const DataLayout& DL) {
    // We can always evaluate constants shifted.
    if (Constant * C = dyn_cast<Constant>(V)) {
        if (isLeftShift)
            V = IC.Builder.CreateShl(C, NumBits);
        else
            V = IC.Builder.CreateLShr(C, NumBits);
        // If we got a constantexpr back, try to simplify it with TD info.
        if (auto * C = dyn_cast<Constant>(V))
            if (auto * FoldedC =
                ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo()))
                V = FoldedC;
        return V;
    }

    Instruction* I = cast<Instruction>(V);
    // Re-queue the mutated instruction so instcombine revisits it.
    IC.Worklist.Add(I);

    switch (I->getOpcode()) {
    default: IGC_ASSERT_EXIT_MESSAGE(0, "Inconsistency with CanEvaluateShifted");
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
        // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
        I->setOperand(
            0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
        I->setOperand(
            1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
        return I;

    case Instruction::Shl:
    case Instruction::LShr:
        return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
            IC.Builder);

    case Instruction::Select:
        // Shift both arms of the select; the condition (operand 0) is untouched.
        I->setOperand(
            1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
        I->setOperand(
            2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
        return I;
    case Instruction::PHI: {
        // We can change a phi if we can change all operands. Note that we never
        // get into trouble with cyclic PHIs here because we only consider
        // instructions with a single use.
        PHINode* PN = cast<PHINode>(I);
        for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
            PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
                isLeftShift, IC, DL));
        return PN;
    }
    }
}
320
321 // If this is a bitwise operator or add with a constant RHS we might be able
322 // to pull it through a shift.
canShiftBinOpWithConstantRHS(BinaryOperator & Shift,BinaryOperator * BO,const APInt & C)323 static bool canShiftBinOpWithConstantRHS(BinaryOperator& Shift,
324 BinaryOperator* BO,
325 const APInt& C) {
326 bool IsValid = true; // Valid only for And, Or Xor,
327 bool HighBitSet = false; // Transform ifhigh bit of constant set?
328
329 switch (BO->getOpcode()) {
330 default: IsValid = false; break; // Do not perform transform!
331 case Instruction::Add:
332 IsValid = Shift.getOpcode() == Instruction::Shl;
333 break;
334 case Instruction::Or:
335 case Instruction::Xor:
336 HighBitSet = false;
337 break;
338 case Instruction::And:
339 HighBitSet = true;
340 break;
341 }
342
343 // If this is a signed shift right, and the high bit is modified
344 // by the logical operation, do not perform the transformation.
345 // The HighBitSet boolean indicates the value of the high bit of
346 // the constant which would cause it to be modified for this
347 // operation.
348 //
349 if (IsValid && Shift.getOpcode() == Instruction::AShr)
350 IsValid = C.isNegative() == HighBitSet;
351
352 return IsValid;
353 }
354
/// Fold the shift 'I' (any of shl/lshr/ashr) whose amount 'Op1' is a constant
/// scalar or splat. 'Op0' is the shifted value. Returns the replacement
/// instruction or nullptr if no fold applied.
Instruction* InstCombiner::FoldShiftByConstant(Value* Op0, Constant* Op1,
    BinaryOperator& I) {
    bool isLeftShift = I.getOpcode() == Instruction::Shl;

    const APInt* Op1C;
    if (!match(Op1, m_APInt(Op1C)))
        return nullptr;

    // See if we can propagate this shift into the input, this covers the trivial
    // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
    if (I.getOpcode() != Instruction::AShr &&
        canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
        LLVM_DEBUG(
            dbgs() << "ICE: GetShiftedValue propagating shift through expression"
            " to eliminate shift:\n IN: "
            << *Op0 << "\n SH: " << I << "\n");

        return replaceInstUsesWith(
            I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
    }

    // See if we can simplify any instructions used by the instruction whose sole
    // purpose is to compute bits we don't care about.
    unsigned TypeBits = Op0->getType()->getScalarSizeInBits();

    IGC_ASSERT_MESSAGE(!Op1C->uge(TypeBits), "Shift over the type width should have been removed already");

    if (Instruction * FoldedShift = foldBinOpIntoSelectOrPhi(I))
        return FoldedShift;

    // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
    if (TruncInst * TI = dyn_cast<TruncInst>(Op0)) {
        Instruction* TrOp = dyn_cast<Instruction>(TI->getOperand(0));
        // If 'shift2' is an ashr, we would have to get the sign bit into a funny
        // place. Don't try to do this transformation in this case. Also, we
        // require that the input operand is a shift-by-constant so that we have
        // confidence that the shifts will get folded together. We could do this
        // xform in more cases, but it is unlikely to be profitable.
        if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
            isa<ConstantInt>(TrOp->getOperand(1))) {
            // Okay, we'll do this xform. Make the shift of shift.
            // Widen the shift amount to the pre-trunc type first.
            Constant* ShAmt =
                ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
            // (shift2 (shift1 & 0x00FF), c2)
            Value* NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());

            // For logical shifts, the truncation has the effect of making the high
            // part of the register be zeros. Emulate this by inserting an AND to
            // clear the top bits as needed. This 'and' will usually be zapped by
            // other xforms later if dead.
            unsigned SrcSize = TrOp->getType()->getScalarSizeInBits();
            unsigned DstSize = TI->getType()->getScalarSizeInBits();
            APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize));

            // The mask we constructed says what the trunc would do if occurring
            // between the shifts. We want to know the effect *after* the second
            // shift. We know that it is a logical shift by a constant, so adjust the
            // mask as appropriate.
            if (I.getOpcode() == Instruction::Shl)
                MaskV <<= Op1C->getZExtValue();
            else {
                IGC_ASSERT_MESSAGE(I.getOpcode() == Instruction::LShr, "Unknown logical shift");
                MaskV.lshrInPlace(Op1C->getZExtValue());
            }

            // shift1 & 0x00FF
            Value* And = Builder.CreateAnd(NSh,
                ConstantInt::get(I.getContext(), MaskV),
                TI->getName());

            // Return the value truncated to the interesting size.
            return new TruncInst(And, I.getType());
        }
    }

    if (Op0->hasOneUse()) {
        if (BinaryOperator * Op0BO = dyn_cast<BinaryOperator>(Op0)) {
            // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C)
            Value* V1, * V2;
            ConstantInt* CC;
            switch (Op0BO->getOpcode()) {
            default: break;
            case Instruction::Add:
            case Instruction::And:
            case Instruction::Or:
            case Instruction::Xor: {
                // These operators commute.
                // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C)
                if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
                    match(Op0BO->getOperand(1), m_Shr(m_Value(V1),
                        m_Specific(Op1)))) {
                    Value* YS = // (Y << C)
                        Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
                    // (X + (Y << C))
                    Value* X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
                        Op0BO->getOperand(1)->getName());
                    unsigned Op1Val = Op1C->getLimitedValue(TypeBits);

                    // Mask off the low bits that the original shr/shl pair zeroed.
                    APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
                    Constant* Mask = ConstantInt::get(I.getContext(), Bits);
                    if (VectorType * VT = dyn_cast<VectorType>(X->getType()))
                        Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
                    return BinaryOperator::CreateAnd(X, Mask);
                }

                // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C))
                Value* Op0BOOp1 = Op0BO->getOperand(1);
                if (isLeftShift && Op0BOOp1->hasOneUse() &&
                    match(Op0BOOp1,
                        m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
                            m_ConstantInt(CC)))) {
                    Value* YS = // (Y << C)
                        Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
                    // X & (CC << C)
                    Value* XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
                        V1->getName() + ".mask");
                    return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
                }
                LLVM_FALLTHROUGH;
            }

            case Instruction::Sub: {
                // 'sub' does not commute, so here the shr must be operand 0.
                // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C)
                if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
                    match(Op0BO->getOperand(0), m_Shr(m_Value(V1),
                        m_Specific(Op1)))) {
                    Value* YS = // (Y << C)
                        Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
                    // (X + (Y << C))
                    Value* X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
                        Op0BO->getOperand(0)->getName());
                    unsigned Op1Val = Op1C->getLimitedValue(TypeBits);

                    APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
                    Constant* Mask = ConstantInt::get(I.getContext(), Bits);
                    if (VectorType * VT = dyn_cast<VectorType>(X->getType()))
                        Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
                    return BinaryOperator::CreateAnd(X, Mask);
                }

                // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C)
                if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
                    match(Op0BO->getOperand(0),
                        m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))),
                            m_ConstantInt(CC))) && V2 == Op1) {
                    Value* YS = // (Y << C)
                        Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
                    // X & (CC << C)
                    Value* XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
                        V1->getName() + ".mask");

                    return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
                }

                break;
            }
            }


            // If the operand is a bitwise operator with a constant RHS, and the
            // shift is the only use, we can pull it out of the shift.
            const APInt* Op0C;
            if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
                if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) {
                    // (BO X, C1) shift C2 --> BO (X shift C2), (C1 shift C2)
                    Constant* NewRHS = ConstantExpr::get(I.getOpcode(),
                        cast<Constant>(Op0BO->getOperand(1)), Op1);

                    Value* NewShift =
                        Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1);
                    NewShift->takeName(Op0BO);

                    return BinaryOperator::Create(Op0BO->getOpcode(), NewShift,
                        NewRHS);
                }
            }

            // If the operand is a subtract with a constant LHS, and the shift
            // is the only use, we can pull it out of the shift.
            // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
            if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
                match(Op0BO->getOperand(0), m_APInt(Op0C))) {
                Constant* NewRHS = ConstantExpr::get(I.getOpcode(),
                    cast<Constant>(Op0BO->getOperand(0)), Op1);

                Value* NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
                NewShift->takeName(Op0BO);

                return BinaryOperator::CreateSub(NewRHS, NewShift);
            }
        }

        // If we have a select that conditionally executes some binary operator,
        // see if we can pull it the select and operator through the shift.
        //
        // For example, turning:
        // shl (select C, (add X, C1), X), C2
        // Into:
        // Y = shl X, C2
        // select C, (add Y, C1 << C2), Y
        Value* Cond;
        BinaryOperator* TBO;
        Value* FalseVal;
        if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)),
            m_Value(FalseVal)))) {
            const APInt* C;
            // The binop's LHS must be the same value as the other select arm so
            // the shifted value is shared between both arms.
            if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal &&
                match(TBO->getOperand(1), m_APInt(C)) &&
                canShiftBinOpWithConstantRHS(I, TBO, *C)) {
                Constant* NewRHS = ConstantExpr::get(I.getOpcode(),
                    cast<Constant>(TBO->getOperand(1)), Op1);

                Value* NewShift =
                    Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1);
                Value* NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift,
                    NewRHS);
                return SelectInst::Create(Cond, NewOp, NewShift);
            }
        }

        // Same pattern, but with the binary operator on the false arm.
        BinaryOperator* FBO;
        Value* TrueVal;
        if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal),
            m_OneUse(m_BinOp(FBO))))) {
            const APInt* C;
            if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal &&
                match(FBO->getOperand(1), m_APInt(C)) &&
                canShiftBinOpWithConstantRHS(I, FBO, *C)) {
                Constant* NewRHS = ConstantExpr::get(I.getOpcode(),
                    cast<Constant>(FBO->getOperand(1)), Op1);

                Value* NewShift =
                    Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1);
                Value* NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift,
                    NewRHS);
                return SelectInst::Create(Cond, NewShift, NewOp);
            }
        }
    }

    return nullptr;
}
596
/// InstCombine visitor for 'shl'. Returns the replacement instruction,
/// &I if I was modified in place, or nullptr if nothing changed.
Instruction* InstCombiner::visitShl(BinaryOperator& I) {
    // First let InstSimplify fold away trivial cases (shift of 0, by 0, etc.).
    if (Value * V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
        I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
        SQ.getWithInstruction(&I)))
        return replaceInstUsesWith(I, V);

    if (Instruction * X = foldShuffledBinop(I))
        return X;

    if (Instruction * V = commonShiftTransforms(I))
        return V;

    Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
    Type* Ty = I.getType();
    const APInt* ShAmtAPInt;
    if (match(Op1, m_APInt(ShAmtAPInt))) {
        unsigned ShAmt = ShAmtAPInt->getZExtValue();
        unsigned BitWidth = Ty->getScalarSizeInBits();

        // shl (zext X), ShAmt --> zext (shl X, ShAmt)
        // This is only valid if X would have zeros shifted out.
        Value* X;
        if (match(Op0, m_ZExt(m_Value(X)))) {
            unsigned SrcWidth = X->getType()->getScalarSizeInBits();
            if (ShAmt < SrcWidth &&
                MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
                return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty);
        }

        // (X >> C) << C --> X & (-1 << C)
        if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
            APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
            return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
        }

        // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine)
        // needs a few fixes for the rotate pattern recognition first.
        const APInt* ShOp1;
        if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) {
            unsigned ShrAmt = ShOp1->getZExtValue();
            if (ShrAmt < ShAmt) {
                // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
                Constant* ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
                auto* NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
                NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
                NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
                return NewShl;
            }
            if (ShrAmt > ShAmt) {
                // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
                Constant* ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
                auto* NewShr = BinaryOperator::Create(
                    cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
                NewShr->setIsExact(true);
                return NewShr;
            }
        }

        if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
            unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
            // Oversized shifts are simplified to zero in InstSimplify.
            if (AmtSum < BitWidth)
                // (X << C1) << C2 --> X << (C1 + C2)
                return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
        }

        // If the shifted-out value is known-zero, then this is a NUW shift.
        if (!I.hasNoUnsignedWrap() &&
            MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
            I.setHasNoUnsignedWrap();
            return &I;
        }

        // If the shifted-out value is all signbits, then this is a NSW shift.
        if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
            I.setHasNoSignedWrap();
            return &I;
        }
    }

    // Transform (x >> y) << y to x & (-1 << y)
    // Valid for any type of right-shift.
    Value* X;
    if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
        Constant* AllOnes = ConstantInt::getAllOnesValue(Ty);
        Value* Mask = Builder.CreateShl(AllOnes, Op1);
        return BinaryOperator::CreateAnd(Mask, X);
    }

    Constant* C1;
    if (match(Op1, m_Constant(C1))) {
        Constant* C2;
        Value* X;
        // (C2 << X) << C1 --> (C2 << C1) << X
        if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
            return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);

        // (X * C2) << C1 --> X * (C2 << C1)
        if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
            return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
    }

    return nullptr;
}
701
/// InstCombine visitor for 'lshr'. Returns the replacement instruction,
/// &I if I was modified in place, or nullptr if nothing changed.
Instruction* InstCombiner::visitLShr(BinaryOperator& I) {
    // First let InstSimplify handle trivially-foldable cases.
    if (Value * V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
        SQ.getWithInstruction(&I)))
        return replaceInstUsesWith(I, V);

    if (Instruction * X = foldShuffledBinop(I))
        return X;

    if (Instruction * R = commonShiftTransforms(I))
        return R;

    Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
    Type* Ty = I.getType();
    const APInt* ShAmtAPInt;
    if (match(Op1, m_APInt(ShAmtAPInt))) {
        unsigned ShAmt = ShAmtAPInt->getZExtValue();
        unsigned BitWidth = Ty->getScalarSizeInBits();
        // Shifting a bit-count intrinsic right by log2(BitWidth) isolates its
        // only possibly-set high bit, which is equivalent to a compare + zext.
        auto* II = dyn_cast<IntrinsicInst>(Op0);
        if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
            (II->getIntrinsicID() == Intrinsic::ctlz ||
                II->getIntrinsicID() == Intrinsic::cttz ||
                II->getIntrinsicID() == Intrinsic::ctpop)) {
            // ctlz.i32(x)>>5 --> zext(x == 0)
            // cttz.i32(x)>>5 --> zext(x == 0)
            // ctpop.i32(x)>>5 --> zext(x == -1)
            bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
            Constant* RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
            Value* Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
            return new ZExtInst(Cmp, Ty);
        }

        Value* X;
        const APInt* ShOp1;
        if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
            unsigned ShlAmt = ShOp1->getZExtValue();
            if (ShlAmt < ShAmt) {
                Constant* ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
                if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
                    // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
                    auto* NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
                    NewLShr->setIsExact(I.isExact());
                    return NewLShr;
                }
                // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2)
                Value* NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
                APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
                return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
            }
            if (ShlAmt > ShAmt) {
                Constant* ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
                if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
                    // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
                    auto* NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
                    NewShl->setHasNoUnsignedWrap(true);
                    return NewShl;
                }
                // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2)
                Value* NewShl = Builder.CreateShl(X, ShiftDiff);
                APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
                return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
            }
            IGC_ASSERT(ShlAmt == ShAmt);
            // (X << C) >>u C --> X & (-1 >>u C)
            APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
            return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
        }

        if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
            (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
            IGC_ASSERT_MESSAGE(ShAmt < X->getType()->getScalarSizeInBits(), "Big shift not simplified to zero?");
            // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
            Value* NewLShr = Builder.CreateLShr(X, ShAmt);
            return new ZExtInst(NewLShr, Ty);
        }

        if (match(Op0, m_SExt(m_Value(X))) &&
            (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
            // Are we moving the sign bit to the low bit and widening with high zeros?
            unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
            if (ShAmt == BitWidth - 1) {
                // lshr (sext i1 X to iN), N-1 --> zext X to iN
                if (SrcTyBitWidth == 1)
                    return new ZExtInst(X, Ty);

                // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
                if (Op0->hasOneUse()) {
                    Value* NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
                    return new ZExtInst(NewLShr, Ty);
                }
            }

            // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
            if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
                // The new shift amount can't be more than the narrow source type.
                unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1);
                Value* AShr = Builder.CreateAShr(X, NewShAmt);
                return new ZExtInst(AShr, Ty);
            }
        }

        if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
            unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
            // Oversized shifts are simplified to zero in InstSimplify.
            if (AmtSum < BitWidth)
                // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
                return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
        }

        // If the shifted-out value is known-zero, then this is an exact shift.
        if (!I.isExact() &&
            MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
            I.setIsExact();
            return &I;
        }
    }

    // Transform (x << y) >> y to x & (-1 >> y)
    Value* X;
    if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
        Constant* AllOnes = ConstantInt::getAllOnesValue(Ty);
        Value* Mask = Builder.CreateLShr(AllOnes, Op1);
        return BinaryOperator::CreateAnd(Mask, X);
    }

    return nullptr;
}
828
/// InstCombine visitor for 'ashr'. Returns the replacement instruction,
/// &I if I was modified in place, or nullptr if nothing changed.
Instruction* InstCombiner::visitAShr(BinaryOperator& I) {
    // First let InstSimplify handle trivially-foldable cases.
    if (Value * V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
        SQ.getWithInstruction(&I)))
        return replaceInstUsesWith(I, V);

    if (Instruction * X = foldShuffledBinop(I))
        return X;

    if (Instruction * R = commonShiftTransforms(I))
        return R;

    Value* Op0 = I.getOperand(0), * Op1 = I.getOperand(1);
    Type* Ty = I.getType();
    unsigned BitWidth = Ty->getScalarSizeInBits();
    const APInt* ShAmtAPInt;
    if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
        unsigned ShAmt = ShAmtAPInt->getZExtValue();

        // If the shift amount equals the difference in width of the destination
        // and source scalar types:
        // ashr (shl (zext X), C), C --> sext X
        Value* X;
        if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
            ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
            return new SExtInst(X, Ty);

        // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
        // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
        const APInt* ShOp1;
        if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
            ShOp1->ult(BitWidth)) {
            unsigned ShlAmt = ShOp1->getZExtValue();
            if (ShlAmt < ShAmt) {
                // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
                Constant* ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
                auto* NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
                NewAShr->setIsExact(I.isExact());
                return NewAShr;
            }
            if (ShlAmt > ShAmt) {
                // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
                Constant* ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
                auto* NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
                NewShl->setHasNoSignedWrap(true);
                return NewShl;
            }
        }

        if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
            ShOp1->ult(BitWidth)) {
            unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
            // Oversized arithmetic shifts replicate the sign bit.
            AmtSum = std::min(AmtSum, BitWidth - 1);
            // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
            return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
        }

        if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
            (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
            // ashr (sext X), C --> sext (ashr X, C')
            Type* SrcTy = X->getType();
            // Clamp to the narrow type's max meaningful arithmetic shift.
            ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
            Value* NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
            return new SExtInst(NewSh, Ty);
        }

        // If the shifted-out value is known-zero, then this is an exact shift.
        if (!I.isExact() &&
            MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
            I.setIsExact();
            return &I;
        }
    }

    // See if we can turn a signed shr into an unsigned shr.
    if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
        return BinaryOperator::CreateLShr(Op0, Op1);

    return nullptr;
}
909 #include "common/LLVMWarningsPop.hpp"
910