1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AssumptionCache.h"
18 #include "llvm/Analysis/BasicAliasAnalysis.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/Loads.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/Analysis/ValueTracking.h"
23 #include "llvm/Analysis/VectorUtils.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/PatternMatch.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Transforms/Utils/Local.h"
#include <numeric>
#include <queue>
31 
32 #define DEBUG_TYPE "vector-combine"
33 #include "llvm/Transforms/Utils/InstructionWorklist.h"
34 
35 using namespace llvm;
36 using namespace llvm::PatternMatch;
37 
38 STATISTIC(NumVecLoad, "Number of vector loads formed");
39 STATISTIC(NumVecCmp, "Number of vector compares formed");
40 STATISTIC(NumVecBO, "Number of vector binops formed");
41 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
42 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
43 STATISTIC(NumScalarBO, "Number of scalar binops formed");
44 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
45 
46 static cl::opt<bool> DisableVectorCombine(
47     "disable-vector-combine", cl::init(false), cl::Hidden,
48     cl::desc("Disable all vector combine transforms"));
49 
50 static cl::opt<bool> DisableBinopExtractShuffle(
51     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
52     cl::desc("Disable binop extract to shuffle transforms"));
53 
54 static cl::opt<unsigned> MaxInstrsToScan(
55     "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
56     cl::desc("Max number of instructions to scan for vector combining."));
57 
58 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
59 
60 namespace {
61 class VectorCombine {
62 public:
63   VectorCombine(Function &F, const TargetTransformInfo &TTI,
64                 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
65                 bool TryEarlyFoldsOnly)
66       : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
67         TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
68 
69   bool run();
70 
71 private:
72   Function &F;
73   IRBuilder<> Builder;
74   const TargetTransformInfo &TTI;
75   const DominatorTree &DT;
76   AAResults &AA;
77   AssumptionCache &AC;
78 
79   /// If true, only perform beneficial early IR transforms. Do not introduce new
80   /// vector operations.
81   bool TryEarlyFoldsOnly;
82 
83   InstructionWorklist Worklist;
84 
85   // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
86   //       parameter. That should be updated to specific sub-classes because the
87   //       run loop was changed to dispatch on opcode.
88   bool vectorizeLoadInsert(Instruction &I);
89   bool widenSubvectorLoad(Instruction &I);
90   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
91                                         ExtractElementInst *Ext1,
92                                         unsigned PreferredExtractIndex) const;
93   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
94                              const Instruction &I,
95                              ExtractElementInst *&ConvertToShuffle,
96                              unsigned PreferredExtractIndex);
97   void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
98                      Instruction &I);
99   void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
100                        Instruction &I);
101   bool foldExtractExtract(Instruction &I);
102   bool foldInsExtFNeg(Instruction &I);
103   bool foldBitcastShuf(Instruction &I);
104   bool scalarizeBinopOrCmp(Instruction &I);
105   bool foldExtractedCmps(Instruction &I);
106   bool foldSingleElementStore(Instruction &I);
107   bool scalarizeLoadExtract(Instruction &I);
108   bool foldShuffleOfBinops(Instruction &I);
109   bool foldShuffleFromReductions(Instruction &I);
110   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
111 
112   void replaceValue(Value &Old, Value &New) {
113     Old.replaceAllUsesWith(&New);
114     if (auto *NewI = dyn_cast<Instruction>(&New)) {
115       New.takeName(&Old);
116       Worklist.pushUsersToWorkList(*NewI);
117       Worklist.pushValue(NewI);
118     }
119     Worklist.pushValue(&Old);
120   }
121 
122   void eraseInstruction(Instruction &I) {
123     for (Value *Op : I.operands())
124       Worklist.pushValue(Op);
125     Worklist.remove(&I);
126     I.eraseFromParent();
127   }
128 };
129 } // namespace
130 
131 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
132   // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
133   // The widened load may load data from dirty regions or create data races
134   // non-existent in the source.
135   if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
136       Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
137       mustSuppressSpeculation(*Load))
138     return false;
139 
140   // We are potentially transforming byte-sized (8-bit) memory accesses, so make
141   // sure we have all of our type-based constraints in place for this target.
142   Type *ScalarTy = Load->getType()->getScalarType();
143   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
144   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
145   if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
146       ScalarSize % 8 != 0)
147     return false;
148 
149   return true;
150 }
151 
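/// Try to replace "insertelement undef, (load scalar), 0" with a load of the
/// minimum legal vector width so that later folds can operate on a full
/// vector. Example (illustrative; assumes a 128-bit minimum vector register
/// width and a sufficiently dereferenceable source pointer - the target cost
/// model makes the final call):
///   %s = load float, ptr %p
///   %i = insertelement <4 x float> undef, float %s, i64 0
/// -->
///   %w = load <4 x float>, ptr %p
///   %i = shufflevector <4 x float> %w, <4 x float> poison,
///                      <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>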
152 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
153   // Match insert into fixed vector of scalar value.
154   // TODO: Handle non-zero insert index.
155   Value *Scalar;
156   if (!match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
157       !Scalar->hasOneUse())
158     return false;
159 
160   // Optionally match an extract from another vector.
161   Value *X;
162   bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
163   if (!HasExtract)
164     X = Scalar;
165 
166   auto *Load = dyn_cast<LoadInst>(X);
167   if (!canWidenLoad(Load, TTI))
168     return false;
169 
170   Type *ScalarTy = Scalar->getType();
171   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
172   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
173 
174   // Check safety of replacing the scalar load with a larger vector load.
175   // We use minimal alignment (maximum flexibility) because we only care about
176   // the dereferenceable region. When calculating cost and creating a new op,
177   // we may use a larger value based on alignment attributes.
178   const DataLayout &DL = I.getModule()->getDataLayout();
179   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
180   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
181 
182   unsigned MinVecNumElts = MinVectorSize / ScalarSize;
183   auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
184   unsigned OffsetEltIndex = 0;
185   Align Alignment = Load->getAlign();
186   if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC,
187                                    &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
192     unsigned OffsetBitWidth = DL.getIndexTypeSizeInBits(SrcPtr->getType());
193     APInt Offset(OffsetBitWidth, 0);
194     SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL, Offset);
195 
196     // We want to shuffle the result down from a high element of a vector, so
197     // the offset must be positive.
198     if (Offset.isNegative())
199       return false;
200 
    // The offset must be a multiple of the scalar element size, so the loaded
    // data can be shuffled cleanly by whole elements.
203     uint64_t ScalarSizeInBytes = ScalarSize / 8;
204     if (Offset.urem(ScalarSizeInBytes) != 0)
205       return false;
206 
207     // If we load MinVecNumElts, will our target element still be loaded?
208     OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
209     if (OffsetEltIndex >= MinVecNumElts)
210       return false;
211 
212     if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC,
213                                      &DT))
214       return false;
215 
216     // Update alignment with offset value. Note that the offset could be negated
217     // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
218     // negation does not change the result of the alignment calculation.
219     Alignment = commonAlignment(Alignment, Offset.getZExtValue());
220   }
221 
222   // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
223   // Use the greater of the alignment on the load or its source pointer.
224   Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment);
225   Type *LoadTy = Load->getType();
226   unsigned AS = Load->getPointerAddressSpace();
227   InstructionCost OldCost =
228       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
229   APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
230   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
231   OldCost +=
232       TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
233                                    /* Insert */ true, HasExtract, CostKind);
234 
235   // New pattern: load VecPtr
236   InstructionCost NewCost =
237       TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);
238   // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask, set everything but element 0 to poison to prevent poison
  // from propagating from the extra loaded memory. This will also optionally
241   // shrink/grow the vector from the loaded size to the output size.
242   // We assume this operation has no cost in codegen if there was no offset.
243   // Note that we could use freeze to avoid poison problems, but then we might
244   // still need a shuffle to change the vector size.
245   auto *Ty = cast<FixedVectorType>(I.getType());
246   unsigned OutputNumElts = Ty->getNumElements();
247   SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
248   assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
249   Mask[0] = OffsetEltIndex;
250   if (OffsetEltIndex)
251     NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask);
252 
253   // We can aggressively convert to the vector form because the backend can
254   // invert this transform if it does not result in a performance win.
255   if (OldCost < NewCost || !NewCost.isValid())
256     return false;
257 
258   // It is safe and potentially profitable to load a vector directly:
259   // inselt undef, load Scalar, 0 --> load VecPtr
260   IRBuilder<> Builder(Load);
261   Value *CastedPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
262       SrcPtr, MinVecTy->getPointerTo(AS));
263   Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
264   VecLd = Builder.CreateShuffleVector(VecLd, Mask);
265 
266   replaceValue(I, *VecLd);
267   ++NumVecLoad;
268   return true;
269 }
270 
271 /// If we are loading a vector and then inserting it into a larger vector with
272 /// undefined elements, try to load the larger vector and eliminate the insert.
273 /// This removes a shuffle in IR and may allow combining of other loaded values.
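/// Example (illustrative; assumes the wider load is known dereferenceable -
/// the cost model makes the final call):
///   %v = load <2 x float>, ptr %p
///   %w = shufflevector <2 x float> %v, <2 x float> poison,
///                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
/// -->
///   %w = load <4 x float>, ptr %p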
274 bool VectorCombine::widenSubvectorLoad(Instruction &I) {
275   // Match subvector insert of fixed vector.
276   auto *Shuf = cast<ShuffleVectorInst>(&I);
277   if (!Shuf->isIdentityWithPadding())
278     return false;
279 
280   // Allow a non-canonical shuffle mask that is choosing elements from op1.
281   unsigned NumOpElts =
282       cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
283   unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
284     return M >= (int)(NumOpElts);
285   });
286 
287   auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
288   if (!canWidenLoad(Load, TTI))
289     return false;
290 
291   // We use minimal alignment (maximum flexibility) because we only care about
292   // the dereferenceable region. When calculating cost and creating a new op,
293   // we may use a larger value based on alignment attributes.
294   auto *Ty = cast<FixedVectorType>(I.getType());
295   const DataLayout &DL = I.getModule()->getDataLayout();
296   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
297   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
298   Align Alignment = Load->getAlign();
299   if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), DL, Load, &AC, &DT))
300     return false;
301 
302   Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment);
303   Type *LoadTy = Load->getType();
304   unsigned AS = Load->getPointerAddressSpace();
305 
306   // Original pattern: insert_subvector (load PtrOp)
307   // This conservatively assumes that the cost of a subvector insert into an
308   // undef value is 0. We could add that cost if the cost model accurately
309   // reflects the real cost of that operation.
310   InstructionCost OldCost =
311       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
312 
313   // New pattern: load PtrOp
314   InstructionCost NewCost =
315       TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS);
316 
317   // We can aggressively convert to the vector form because the backend can
318   // invert this transform if it does not result in a performance win.
319   if (OldCost < NewCost || !NewCost.isValid())
320     return false;
321 
322   IRBuilder<> Builder(Load);
323   Value *CastedPtr =
324       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Ty->getPointerTo(AS));
325   Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
326   replaceValue(I, *VecLd);
327   ++NumVecLoad;
328   return true;
329 }
330 
331 /// Determine which, if any, of the inputs should be replaced by a shuffle
332 /// followed by extract from a different index.
333 ExtractElementInst *VectorCombine::getShuffleExtract(
334     ExtractElementInst *Ext0, ExtractElementInst *Ext1,
335     unsigned PreferredExtractIndex = InvalidIndex) const {
336   auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
337   auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
338   assert(Index0C && Index1C && "Expected constant extract indexes");
339 
340   unsigned Index0 = Index0C->getZExtValue();
341   unsigned Index1 = Index1C->getZExtValue();
342 
343   // If the extract indexes are identical, no shuffle is needed.
344   if (Index0 == Index1)
345     return nullptr;
346 
347   Type *VecTy = Ext0->getVectorOperand()->getType();
348   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
349   assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
350   InstructionCost Cost0 =
351       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
352   InstructionCost Cost1 =
353       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
354 
  // If both costs are invalid, no shuffle is needed.
356   if (!Cost0.isValid() && !Cost1.isValid())
357     return nullptr;
358 
359   // We are extracting from 2 different indexes, so one operand must be shuffled
360   // before performing a vector operation and/or extract. The more expensive
361   // extract will be replaced by a shuffle.
362   if (Cost0 > Cost1)
363     return Ext0;
364   if (Cost1 > Cost0)
365     return Ext1;
366 
367   // If the costs are equal and there is a preferred extract index, shuffle the
368   // opposite operand.
369   if (PreferredExtractIndex == Index0)
370     return Ext1;
371   if (PreferredExtractIndex == Index1)
372     return Ext0;
373 
374   // Otherwise, replace the extract with the higher index.
375   return Index0 > Index1 ? Ext0 : Ext1;
376 }
377 
378 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
379 /// vector operation(s) followed by extract. Return true if the existing
380 /// instructions are cheaper than a vector alternative. Otherwise, return false
381 /// and if one of the extracts should be transformed to a shufflevector, set
382 /// \p ConvertToShuffle to that extract instruction.
383 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
384                                           ExtractElementInst *Ext1,
385                                           const Instruction &I,
386                                           ExtractElementInst *&ConvertToShuffle,
387                                           unsigned PreferredExtractIndex) {
388   auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getOperand(1));
389   auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getOperand(1));
390   assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
391 
392   unsigned Opcode = I.getOpcode();
393   Type *ScalarTy = Ext0->getType();
394   auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
395   InstructionCost ScalarOpCost, VectorOpCost;
396 
397   // Get cost estimates for scalar and vector versions of the operation.
398   bool IsBinOp = Instruction::isBinaryOp(Opcode);
399   if (IsBinOp) {
400     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
401     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
402   } else {
403     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
404            "Expected a compare");
405     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
406     ScalarOpCost = TTI.getCmpSelInstrCost(
407         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
408     VectorOpCost = TTI.getCmpSelInstrCost(
409         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
410   }
411 
412   // Get cost estimates for the extract elements. These costs will factor into
413   // both sequences.
414   unsigned Ext0Index = Ext0IndexC->getZExtValue();
415   unsigned Ext1Index = Ext1IndexC->getZExtValue();
416   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
417 
418   InstructionCost Extract0Cost =
419       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
420   InstructionCost Extract1Cost =
421       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
422 
423   // A more expensive extract will always be replaced by a splat shuffle.
424   // For example, if Ext0 is more expensive:
425   // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
426   // extelt (opcode (splat V0, Ext0), V1), Ext1
427   // TODO: Evaluate whether that always results in lowest cost. Alternatively,
428   //       check the cost of creating a broadcast shuffle and shuffling both
429   //       operands to element 0.
430   InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
431 
432   // Extra uses of the extracts mean that we include those costs in the
433   // vector total because those instructions will not be eliminated.
434   InstructionCost OldCost, NewCost;
435   if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
436     // Handle a special case. If the 2 extracts are identical, adjust the
437     // formulas to account for that. The extra use charge allows for either the
438     // CSE'd pattern or an unoptimized form with identical values:
439     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
440     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
441                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
442     OldCost = CheapExtractCost + ScalarOpCost;
443     NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
444   } else {
445     // Handle the general case. Each extract is actually a different value:
446     // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
447     OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
448     NewCost = VectorOpCost + CheapExtractCost +
449               !Ext0->hasOneUse() * Extract0Cost +
450               !Ext1->hasOneUse() * Extract1Cost;
451   }
452 
453   ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
454   if (ConvertToShuffle) {
455     if (IsBinOp && DisableBinopExtractShuffle)
456       return true;
457 
458     // If we are extracting from 2 different indexes, then one operand must be
459     // shuffled before performing the vector operation. The shuffle mask is
460     // poison except for 1 lane that is being translated to the remaining
461     // extraction lane. Therefore, it is a splat shuffle. Ex:
462     // ShufMask = { poison, poison, 0, poison }
463     // TODO: The cost model has an option for a "broadcast" shuffle
464     //       (splat-from-element-0), but no option for a more general splat.
465     NewCost +=
466         TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
467   }
468 
469   // Aggressively form a vector op if the cost is equal because the transform
470   // may enable further optimization.
471   // Codegen can reverse this transform (scalarize) if it was not profitable.
472   return OldCost < NewCost;
473 }
474 
475 /// Create a shuffle that translates (shifts) 1 element from the input vector
476 /// to a new element location.
477 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
478                                  unsigned NewIndex, IRBuilder<> &Builder) {
479   // The shuffle mask is poison except for 1 lane that is being translated
480   // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
481   // ShufMask = { 2, poison, poison, poison }
482   auto *VecTy = cast<FixedVectorType>(Vec->getType());
483   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
484   ShufMask[NewIndex] = OldIndex;
485   return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
486 }
487 
488 /// Given an extract element instruction with constant index operand, shuffle
489 /// the source vector (shift the scalar element) to a NewIndex for extraction.
490 /// Return null if the input can be constant folded, so that we are not creating
491 /// unnecessary instructions.
492 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
493                                             unsigned NewIndex,
494                                             IRBuilder<> &Builder) {
495   // Shufflevectors can only be created for fixed-width vectors.
496   if (!isa<FixedVectorType>(ExtElt->getOperand(0)->getType()))
497     return nullptr;
498 
  // If the extract can be constant-folded, the input IR is not fully
  // simplified; defer to other passes to handle that.
501   Value *X = ExtElt->getVectorOperand();
502   Value *C = ExtElt->getIndexOperand();
503   assert(isa<ConstantInt>(C) && "Expected a constant index operand");
504   if (isa<Constant>(X))
505     return nullptr;
506 
507   Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
508                                    NewIndex, Builder);
509   return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
510 }
511 
512 /// Try to reduce extract element costs by converting scalar compares to vector
513 /// compares followed by extract.
514 /// cmp (ext0 V0, C), (ext1 V1, C)
515 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
516                                   ExtractElementInst *Ext1, Instruction &I) {
517   assert(isa<CmpInst>(&I) && "Expected a compare");
518   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
519              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
520          "Expected matching constant extract indexes");
521 
522   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
523   ++NumVecCmp;
524   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
525   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
526   Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
527   Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
528   replaceValue(I, *NewExt);
529 }
530 
531 /// Try to reduce extract element costs by converting scalar binops to vector
532 /// binops followed by extract.
533 /// bo (ext0 V0, C), (ext1 V1, C)
534 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
535                                     ExtractElementInst *Ext1, Instruction &I) {
536   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
537   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
538              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
539          "Expected matching constant extract indexes");
540 
541   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
542   ++NumVecBO;
543   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
544   Value *VecBO =
545       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
546 
547   // All IR flags are safe to back-propagate because any potential poison
548   // created in unused vector elements is discarded by the extract.
549   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
550     VecBOInst->copyIRFlags(&I);
551 
552   Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
553   replaceValue(I, *NewExt);
554 }
555 
556 /// Match an instruction with extracted vector operands.
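/// Example of the binop case (illustrative; the compare case is analogous):
///   %e0 = extractelement <4 x i32> %x, i32 0
///   %e1 = extractelement <4 x i32> %y, i32 0
///   %r  = add i32 %e0, %e1
/// -->
///   %v = add <4 x i32> %x, %y
///   %r = extractelement <4 x i32> %v, i32 0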
557 bool VectorCombine::foldExtractExtract(Instruction &I) {
558   // It is not safe to transform things like div, urem, etc. because we may
559   // create undefined behavior when executing those on unknown vector elements.
560   if (!isSafeToSpeculativelyExecute(&I))
561     return false;
562 
563   Instruction *I0, *I1;
564   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
565   if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
566       !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
567     return false;
568 
569   Value *V0, *V1;
570   uint64_t C0, C1;
571   if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
572       !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
573       V0->getType() != V1->getType())
574     return false;
575 
576   // If the scalar value 'I' is going to be re-inserted into a vector, then try
577   // to create an extract to that same element. The extract/insert can be
578   // reduced to a "select shuffle".
579   // TODO: If we add a larger pattern match that starts from an insert, this
580   //       probably becomes unnecessary.
581   auto *Ext0 = cast<ExtractElementInst>(I0);
582   auto *Ext1 = cast<ExtractElementInst>(I1);
583   uint64_t InsertIndex = InvalidIndex;
584   if (I.hasOneUse())
585     match(I.user_back(),
586           m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
587 
588   ExtractElementInst *ExtractToChange;
589   if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
590     return false;
591 
592   if (ExtractToChange) {
593     unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
594     ExtractElementInst *NewExtract =
595         translateExtract(ExtractToChange, CheapExtractIdx, Builder);
596     if (!NewExtract)
597       return false;
598     if (ExtractToChange == Ext0)
599       Ext0 = NewExtract;
600     else
601       Ext1 = NewExtract;
602   }
603 
604   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
605     foldExtExtCmp(Ext0, Ext1, I);
606   else
607     foldExtExtBinop(Ext0, Ext1, I);
608 
609   Worklist.push(Ext0);
610   Worklist.push(Ext1);
611   return true;
612 }
613 
614 /// Try to replace an extract + scalar fneg + insert with a vector fneg +
615 /// shuffle.
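/// Example (illustrative; same lane and matching vector types):
///   %e = extractelement <4 x float> %src, i64 2
///   %n = fneg float %e
///   %r = insertelement <4 x float> %dst, float %n, i64 2
/// -->
///   %neg = fneg <4 x float> %src
///   %r = shufflevector <4 x float> %dst, <4 x float> %neg,
///                      <4 x i32> <i32 0, i32 1, i32 6, i32 3>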
616 bool VectorCombine::foldInsExtFNeg(Instruction &I) {
617   // Match an insert (op (extract)) pattern.
618   Value *DestVec;
619   uint64_t Index;
620   Instruction *FNeg;
621   if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
622                              m_ConstantInt(Index))))
623     return false;
624 
625   // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
626   Value *SrcVec;
627   Instruction *Extract;
628   if (!match(FNeg, m_FNeg(m_CombineAnd(
629                        m_Instruction(Extract),
630                        m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
631     return false;
632 
633   // TODO: We could handle this with a length-changing shuffle.
634   auto *VecTy = cast<FixedVectorType>(I.getType());
635   if (SrcVec->getType() != VecTy)
636     return false;
637 
638   // Ignore bogus insert/extract index.
639   unsigned NumElts = VecTy->getNumElements();
640   if (Index >= NumElts)
641     return false;
642 
643   // We are inserting the negated element into the same lane that we extracted
644   // from. This is equivalent to a select-shuffle that chooses all but the
645   // negated element from the destination vector.
646   SmallVector<int> Mask(NumElts);
647   std::iota(Mask.begin(), Mask.end(), 0);
648   Mask[Index] = Index + NumElts;
649 
650   Type *ScalarTy = VecTy->getScalarType();
651   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
652   InstructionCost OldCost =
653       TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy) +
654       TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
655 
656   // If the extract has one use, it will be eliminated, so count it in the
657   // original cost. If it has more than one use, ignore the cost because it will
658   // be the same before/after.
659   if (Extract->hasOneUse())
660     OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
661 
662   InstructionCost NewCost =
663       TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy) +
664       TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask);
665 
666   if (NewCost > OldCost)
667     return false;
668 
669   // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index -->
670   // shuffle DestVec, (fneg SrcVec), Mask
671   Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
672   Value *Shuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
673   replaceValue(I, *Shuf);
674   return true;
675 }
676 
677 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
678 /// destination type followed by shuffle. This can enable further transforms by
679 /// moving bitcasts or shuffles together.
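/// Example of the narrow-to-wide case (illustrative; the mask must select
/// consecutive elements so the cast can happen first):
///   %s = shufflevector <4 x i16> %v, <4 x i16> poison,
///                      <4 x i32> <i32 2, i32 3, i32 0, i32 1>
///   %b = bitcast <4 x i16> %s to <2 x i32>
/// -->
///   %c = bitcast <4 x i16> %v to <2 x i32>
///   %b = shufflevector <2 x i32> %c, <2 x i32> poison, <2 x i32> <i32 1, i32 0>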
680 bool VectorCombine::foldBitcastShuf(Instruction &I) {
681   Value *V;
682   ArrayRef<int> Mask;
683   if (!match(&I, m_BitCast(
684                      m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
685     return false;
686 
687   // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
688   // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
689   // mask for scalable type is a splat or not.
690   // 2) Disallow non-vector casts and length-changing shuffles.
691   // TODO: We could allow any shuffle.
692   auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
693   if (!SrcTy || I.getOperand(0)->getType() != SrcTy)
694     return false;
695 
696   auto *DestTy = cast<FixedVectorType>(I.getType());
697   unsigned DestNumElts = DestTy->getNumElements();
698   unsigned SrcNumElts = SrcTy->getNumElements();
699   SmallVector<int, 16> NewMask;
700   if (SrcNumElts <= DestNumElts) {
701     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
702     // always be expanded to the equivalent form choosing narrower elements.
703     assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
704     unsigned ScaleFactor = DestNumElts / SrcNumElts;
705     narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
706   } else {
707     // The bitcast is from narrow elements to wide elements. The shuffle mask
708     // must choose consecutive elements to allow casting first.
709     assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
710     unsigned ScaleFactor = SrcNumElts / DestNumElts;
711     if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
712       return false;
713   }
714 
715   // The new shuffle must not cost more than the old shuffle. The bitcast is
716   // moved ahead of the shuffle, so assume that it has the same cost as before.
717   InstructionCost DestCost = TTI.getShuffleCost(
718       TargetTransformInfo::SK_PermuteSingleSrc, DestTy, NewMask);
719   InstructionCost SrcCost =
720       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
721   if (DestCost > SrcCost || !DestCost.isValid())
722     return false;
723 
724   // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
725   ++NumShufOfBitcast;
726   Value *CastV = Builder.CreateBitCast(V, DestTy);
727   Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
728   replaceValue(I, *Shuf);
729   return true;
730 }
731 
732 /// Match a vector binop or compare instruction with at least one inserted
733 /// scalar operand and convert to scalar binop/cmp followed by insertelement.
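/// Example (illustrative; one inserted scalar and one constant operand):
///   %i = insertelement <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i32 %x, i64 1
///   %r = add <4 x i32> %i, <i32 10, i32 20, i32 30, i32 40>
/// -->
///   %s = add i32 %x, 20
///   %r = insertelement <4 x i32> <i32 11, i32 22, i32 33, i32 44>, i32 %s, i64 1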
734 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
735   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
736   Value *Ins0, *Ins1;
737   if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
738       !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
739     return false;
740 
741   // Do not convert the vector condition of a vector select into a scalar
742   // condition. That may cause problems for codegen because of differences in
743   // boolean formats and register-file transfers.
744   // TODO: Can we account for that in the cost model?
745   bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
746   if (IsCmp)
747     for (User *U : I.users())
748       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
749         return false;
750 
751   // Match against one or both scalar values being inserted into constant
752   // vectors:
753   // vec_op VecC0, (inselt VecC1, V1, Index)
754   // vec_op (inselt VecC0, V0, Index), VecC1
755   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
756   // TODO: Deal with mismatched index constants and variable indexes?
757   Constant *VecC0 = nullptr, *VecC1 = nullptr;
758   Value *V0 = nullptr, *V1 = nullptr;
759   uint64_t Index0 = 0, Index1 = 0;
760   if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
761                                m_ConstantInt(Index0))) &&
762       !match(Ins0, m_Constant(VecC0)))
763     return false;
764   if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
765                                m_ConstantInt(Index1))) &&
766       !match(Ins1, m_Constant(VecC1)))
767     return false;
768 
769   bool IsConst0 = !V0;
770   bool IsConst1 = !V1;
771   if (IsConst0 && IsConst1)
772     return false;
773   if (!IsConst0 && !IsConst1 && Index0 != Index1)
774     return false;
775 
776   // Bail for single insertion if it is a load.
777   // TODO: Handle this once getVectorInstrCost can cost for load/stores.
778   auto *I0 = dyn_cast_or_null<Instruction>(V0);
779   auto *I1 = dyn_cast_or_null<Instruction>(V1);
780   if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
781       (IsConst1 && I0 && I0->mayReadFromMemory()))
782     return false;
783 
784   uint64_t Index = IsConst0 ? Index1 : Index0;
785   Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
786   Type *VecTy = I.getType();
787   assert(VecTy->isVectorTy() &&
788          (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
789          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
790           ScalarTy->isPointerTy()) &&
791          "Unexpected types for insert element into binop or cmp");
792 
793   unsigned Opcode = I.getOpcode();
794   InstructionCost ScalarOpCost, VectorOpCost;
795   if (IsCmp) {
796     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
797     ScalarOpCost = TTI.getCmpSelInstrCost(
798         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
799     VectorOpCost = TTI.getCmpSelInstrCost(
800         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
801   } else {
802     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
803     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
804   }
805 
806   // Get cost estimate for the insert element. This cost will factor into
807   // both sequences.
808   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
809   InstructionCost InsertCost = TTI.getVectorInstrCost(
810       Instruction::InsertElement, VecTy, CostKind, Index);
811   InstructionCost OldCost =
812       (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
813   InstructionCost NewCost = ScalarOpCost + InsertCost +
814                             (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
815                             (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
816 
817   // We want to scalarize unless the vector variant actually has lower cost.
818   if (OldCost < NewCost || !NewCost.isValid())
819     return false;
820 
821   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
822   // inselt NewVecC, (scalar_op V0, V1), Index
823   if (IsCmp)
824     ++NumScalarCmp;
825   else
826     ++NumScalarBO;
827 
  // For constant cases, extract the scalar element; this should constant fold.
829   if (IsConst0)
830     V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
831   if (IsConst1)
832     V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
833 
834   Value *Scalar =
835       IsCmp ? Builder.CreateCmp(Pred, V0, V1)
836             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
837 
838   Scalar->setName(I.getName() + ".scalar");
839 
840   // All IR flags are safe to back-propagate. There is no potential for extra
841   // poison to be created by the scalar instruction.
842   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
843     ScalarInst->copyIRFlags(&I);
844 
845   // Fold the vector constants in the original vectors into a new base vector.
846   Value *NewVecC =
847       IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
848             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
849   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
850   replaceValue(I, *Insert);
851   return true;
852 }
853 
854 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
855 /// a vector into vector operations followed by extract. Note: The SLP pass
856 /// may miss this pattern because of implementation problems.
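/// Example (illustrative; assumes equal extract costs so the higher-index
/// extract is the one converted to a shuffle):
///   %e0 = extractelement <4 x i32> %x, i32 0
///   %e1 = extractelement <4 x i32> %x, i32 3
///   %c0 = icmp sgt i32 %e0, 42
///   %c1 = icmp sgt i32 %e1, 99
///   %r  = and i1 %c0, %c1
/// -->
///   %vcmp = icmp sgt <4 x i32> %x, <i32 42, i32 poison, i32 poison, i32 99>
///   %shuf = shufflevector <4 x i1> %vcmp, <4 x i1> poison,
///                         <4 x i32> <i32 3, i32 poison, i32 poison, i32 poison>
///   %vand = and <4 x i1> %vcmp, %shuf
///   %r    = extractelement <4 x i1> %vand, i32 0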
857 bool VectorCombine::foldExtractedCmps(Instruction &I) {
858   // We are looking for a scalar binop of booleans.
859   // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
860   if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
861     return false;
862 
863   // The compare predicates should match, and each compare should have a
864   // constant operand.
865   // TODO: Relax the one-use constraints.
866   Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
867   Instruction *I0, *I1;
868   Constant *C0, *C1;
869   CmpInst::Predicate P0, P1;
870   if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
871       !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
872       P0 != P1)
873     return false;
874 
875   // The compare operands must be extracts of the same vector with constant
876   // extract indexes.
877   // TODO: Relax the one-use constraints.
878   Value *X;
879   uint64_t Index0, Index1;
880   if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
881       !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
882     return false;
883 
884   auto *Ext0 = cast<ExtractElementInst>(I0);
885   auto *Ext1 = cast<ExtractElementInst>(I1);
886   ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
887   if (!ConvertToShuf)
888     return false;
889 
890   // The original scalar pattern is:
891   // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
892   CmpInst::Predicate Pred = P0;
893   unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
894                                                     : Instruction::ICmp;
895   auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
896   if (!VecTy)
897     return false;
898 
899   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
900   InstructionCost OldCost =
901       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
902   OldCost += TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
903   OldCost +=
904       TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
905                              CmpInst::makeCmpResultType(I0->getType()), Pred) *
906       2;
907   OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
908 
909   // The proposed vector pattern is:
910   // vcmp = cmp Pred X, VecC
911   // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
912   int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
913   int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
914   auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
915   InstructionCost NewCost = TTI.getCmpSelInstrCost(
916       CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
917   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
918   ShufMask[CheapIndex] = ExpensiveIndex;
919   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
920                                 ShufMask);
921   NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
922   NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
923 
924   // Aggressively form vector ops if the cost is equal because the transform
925   // may enable further optimization.
926   // Codegen can reverse this transform (scalarize) if it was not profitable.
927   if (OldCost < NewCost || !NewCost.isValid())
928     return false;
929 
930   // Create a vector constant from the 2 scalar constants.
931   SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
932                                    PoisonValue::get(VecTy->getElementType()));
933   CmpC[Index0] = C0;
934   CmpC[Index1] = C1;
935   Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
936 
937   Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
938   Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
939                                         VCmp, Shuf);
940   Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
941   replaceValue(I, *NewExt);
942   ++NumVecCmpBO;
943   return true;
944 }
945 
// Check if the memory location is modified between two instructions in the
// same basic block.
947 static bool isMemModifiedBetween(BasicBlock::iterator Begin,
948                                  BasicBlock::iterator End,
949                                  const MemoryLocation &Loc, AAResults &AA) {
950   unsigned NumScanned = 0;
951   return std::any_of(Begin, End, [&](const Instruction &Instr) {
952     return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
953            ++NumScanned > MaxInstrsToScan;
954   });
955 }
956 
957 namespace {
958 /// Helper class to indicate whether a vector index can be safely scalarized and
959 /// if a freeze needs to be inserted.
960 class ScalarizationResult {
961   enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
962 
963   StatusTy Status;
964   Value *ToFreeze;
965 
966   ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
967       : Status(Status), ToFreeze(ToFreeze) {}
968 
969 public:
970   ScalarizationResult(const ScalarizationResult &Other) = default;
971   ~ScalarizationResult() {
972     assert(!ToFreeze && "freeze() not called with ToFreeze being set");
973   }
974 
975   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
976   static ScalarizationResult safe() { return {StatusTy::Safe}; }
977   static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
978     return {StatusTy::SafeWithFreeze, ToFreeze};
979   }
980 
  /// Returns true if the index can be scalarized without requiring a freeze.
982   bool isSafe() const { return Status == StatusTy::Safe; }
983   /// Returns true if the index cannot be scalarized.
984   bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
986   /// freeze.
987   bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
988 
  /// Reset the state to Unsafe and clear ToFreeze if set.
990   void discard() {
991     ToFreeze = nullptr;
992     Status = StatusTy::Unsafe;
993   }
994 
  /// Freeze ToFreeze and update the use in \p UserI to use it.
996   void freeze(IRBuilder<> &Builder, Instruction &UserI) {
997     assert(isSafeWithFreeze() &&
998            "should only be used when freezing is required");
999     assert(is_contained(ToFreeze->users(), &UserI) &&
1000            "UserI must be a user of ToFreeze");
1001     IRBuilder<>::InsertPointGuard Guard(Builder);
1002     Builder.SetInsertPoint(cast<Instruction>(&UserI));
1003     Value *Frozen =
1004         Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
    for (Use &U : make_early_inc_range(UserI.operands()))
1006       if (U.get() == ToFreeze)
1007         U.set(Frozen);
1008 
1009     ToFreeze = nullptr;
1010   }
1011 };
1012 } // namespace
1013 
1014 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1015 /// Idx. \p Idx must access a valid vector element.
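/// Example (illustrative): for a <4 x i32> vector, an index such as
/// "%idx = urem i64 %j, 4" is always in the range [0, 4), so the access can be
/// scalarized once %j is frozen to rule out a poison index.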
1016 static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
1017                                               Value *Idx, Instruction *CtxI,
1018                                               AssumptionCache &AC,
1019                                               const DominatorTree &DT) {
1020   if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1021     if (C->getValue().ult(VecTy->getNumElements()))
1022       return ScalarizationResult::safe();
1023     return ScalarizationResult::unsafe();
1024   }
1025 
1026   unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1027   APInt Zero(IntWidth, 0);
1028   APInt MaxElts(IntWidth, VecTy->getNumElements());
1029   ConstantRange ValidIndices(Zero, MaxElts);
1030   ConstantRange IdxRange(IntWidth, true);
1031 
1032   if (isGuaranteedNotToBePoison(Idx, &AC)) {
1033     if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
1034                                                    true, &AC, CtxI, &DT)))
1035       return ScalarizationResult::safe();
1036     return ScalarizationResult::unsafe();
1037   }
1038 
  // If the index may be poison, check if we can insert a freeze on the base
  // value before its range is restricted (e.g. by an 'and' or 'urem' with a
  // constant).
1041   Value *IdxBase;
1042   ConstantInt *CI;
1043   if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1044     IdxRange = IdxRange.binaryAnd(CI->getValue());
1045   } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1046     IdxRange = IdxRange.urem(CI->getValue());
1047   }
1048 
1049   if (ValidIndices.contains(IdxRange))
1050     return ScalarizationResult::safeWithFreeze(IdxBase);
1051   return ScalarizationResult::unsafe();
1052 }
1053 
1054 /// The memory operation on a vector of \p ScalarType had alignment of
1055 /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1056 /// alignment that will be valid for the memory operation on a single scalar
1057 /// element of the same type with index \p Idx.
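/// Example (illustrative): for a <4 x i32> access with alignment 16, the
/// scalar access of element 0 keeps alignment 16, element 1 (byte offset 4)
/// gets alignment 4, and an unknown (variable) index falls back to the element
/// store size (4).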
1058 static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1059                                                 Type *ScalarType, Value *Idx,
1060                                                 const DataLayout &DL) {
1061   if (auto *C = dyn_cast<ConstantInt>(Idx))
1062     return commonAlignment(VectorAlignment,
1063                            C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1064   return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1065 }
1066 
1067 // Combine patterns like:
1068 //   %0 = load <4 x i32>, <4 x i32>* %a
1069 //   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1070 //   store <4 x i32> %1, <4 x i32>* %a
1071 // to:
1072 //   %0 = bitcast <4 x i32>* %a to i32*
1073 //   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1074 //   store i32 %b, i32* %1
1075 bool VectorCombine::foldSingleElementStore(Instruction &I) {
1076   auto *SI = cast<StoreInst>(&I);
1077   if (!SI->isSimple() ||
1078       !isa<FixedVectorType>(SI->getValueOperand()->getType()))
1079     return false;
1080 
1081   // TODO: Combine more complicated patterns (multiple insert) by referencing
1082   // TargetTransformInfo.
1083   Instruction *Source;
1084   Value *NewElement;
1085   Value *Idx;
1086   if (!match(SI->getValueOperand(),
1087              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1088                          m_Value(Idx))))
1089     return false;
1090 
1091   if (auto *Load = dyn_cast<LoadInst>(Source)) {
1092     auto VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType());
1093     const DataLayout &DL = I.getModule()->getDataLayout();
1094     Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile load or store. Ensure memory is not
    // modified between the load and the store, the vector type size matches
    // its store size, and the index is inbounds.
1097     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1098         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
1099         SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1100       return false;
1101 
1102     auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1103     if (ScalarizableIdx.isUnsafe() ||
1104         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1105                              MemoryLocation::get(SI), AA))
1106       return false;
1107 
1108     if (ScalarizableIdx.isSafeWithFreeze())
1109       ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1110     Value *GEP = Builder.CreateInBoundsGEP(
1111         SI->getValueOperand()->getType(), SI->getPointerOperand(),
1112         {ConstantInt::get(Idx->getType(), 0), Idx});
1113     StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1114     NSI->copyMetadata(*SI);
1115     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1116         std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1117         DL);
1118     NSI->setAlignment(ScalarOpAlignment);
1119     replaceValue(I, *NSI);
1120     eraseInstruction(I);
1121     return true;
1122   }
1123 
1124   return false;
1125 }
1126 
1127 /// Try to scalarize vector loads feeding extractelement instructions.
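/// Example (illustrative; assumes the index is known in-bounds and non-poison
/// and the scalar loads are cheaper per the cost model):
///   %v = load <4 x i32>, ptr %p
///   %e = extractelement <4 x i32> %v, i64 %idx
/// -->
///   %gep = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 %idx
///   %e = load i32, ptr %gep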
1128 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1129   Value *Ptr;
1130   if (!match(&I, m_Load(m_Value(Ptr))))
1131     return false;
1132 
1133   auto *FixedVT = cast<FixedVectorType>(I.getType());
1134   auto *LI = cast<LoadInst>(&I);
1135   const DataLayout &DL = I.getModule()->getDataLayout();
1136   if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
1137     return false;
1138 
1139   InstructionCost OriginalCost =
1140       TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
1141                           LI->getPointerAddressSpace());
1142   InstructionCost ScalarizedCost = 0;
1143 
1144   Instruction *LastCheckedInst = LI;
1145   unsigned NumInstChecked = 0;
1146   // Check if all users of the load are extracts with no memory modifications
1147   // between the load and the extract. Compute the cost of both the original
1148   // code and the scalarized version.
1149   for (User *U : LI->users()) {
1150     auto *UI = dyn_cast<ExtractElementInst>(U);
1151     if (!UI || UI->getParent() != LI->getParent())
1152       return false;
1153 
1154     if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
1155       return false;
1156 
1157     // Check if any instruction between the load and the extract may modify
1158     // memory.
1159     if (LastCheckedInst->comesBefore(UI)) {
1160       for (Instruction &I :
1161            make_range(std::next(LI->getIterator()), UI->getIterator())) {
1162         // Bail out if we reached the check limit or the instruction may write
1163         // to memory.
1164         if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1165           return false;
1166         NumInstChecked++;
1167       }
1168       LastCheckedInst = UI;
1169     }
1170 
1171     auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
1172     if (!ScalarIdx.isSafe()) {
1173       // TODO: Freeze index if it is safe to do so.
1174       ScalarIdx.discard();
1175       return false;
1176     }
1177 
1178     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
1179     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1180     OriginalCost +=
1181         TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
1182                                Index ? Index->getZExtValue() : -1);
1183     ScalarizedCost +=
1184         TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
1185                             Align(1), LI->getPointerAddressSpace());
1186     ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
1187   }
1188 
1189   if (ScalarizedCost >= OriginalCost)
1190     return false;
1191 
1192   // Replace extracts with narrow scalar loads.
1193   for (User *U : LI->users()) {
1194     auto *EI = cast<ExtractElementInst>(U);
1195     Builder.SetInsertPoint(EI);
1196 
1197     Value *Idx = EI->getOperand(1);
1198     Value *GEP =
1199         Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
1200     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1201         FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
1202 
1203     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1204         LI->getAlign(), FixedVT->getElementType(), Idx, DL);
1205     NewLoad->setAlignment(ScalarOpAlignment);
1206 
1207     replaceValue(*EI, *NewLoad);
1208   }
1209 
1210   return true;
1211 }
1212 
1213 /// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
1214 /// "binop (shuffle), (shuffle)".
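/// Example (illustrative; %x is the shared operand, and profitability depends
/// on the cost of the new single-source shuffle):
///   %b0 = add <4 x i32> %x, %y
///   %b1 = add <4 x i32> %x, %w
///   %s  = shufflevector <4 x i32> %b0, <4 x i32> %b1,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
/// -->
///   %sx = shufflevector <4 x i32> %x, <4 x i32> poison,
///                       <4 x i32> <i32 0, i32 0, i32 1, i32 1>
///   %sy = shufflevector <4 x i32> %y, <4 x i32> %w,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %s  = add <4 x i32> %sx, %sy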
1215 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1216   auto *VecTy = cast<FixedVectorType>(I.getType());
1217   BinaryOperator *B0, *B1;
1218   ArrayRef<int> Mask;
1219   if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
1220                            m_Mask(Mask))) ||
1221       B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
1222     return false;
1223 
1224   // Try to replace a binop with a shuffle if the shuffle is not costly.
1225   // The new shuffle will choose from a single, common operand, so it may be
1226   // cheaper than the existing two-operand shuffle.
1227   SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
1228   Instruction::BinaryOps Opcode = B0->getOpcode();
1229   InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
1230   InstructionCost ShufCost = TTI.getShuffleCost(
1231       TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
1232   if (ShufCost > BinopCost)
1233     return false;
1234 
1235   // If we have something like "add X, Y" and "add Z, X", swap ops to match.
1236   Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
1237   Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
1238   if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
1239     std::swap(X, Y);
1240 
1241   Value *Shuf0, *Shuf1;
1242   if (X == Z) {
1243     // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
1244     Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
1245     Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
1246   } else if (Y == W) {
1247     // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
1248     Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
1249     Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
1250   } else {
1251     return false;
1252   }
1253 
1254   Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1255   // Intersect flags from the old binops.
1256   if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
1257     NewInst->copyIRFlags(B0);
1258     NewInst->andIRFlags(B1);
1259   }
1260   replaceValue(I, *NewBO);
1261   return true;
1262 }
1263 
1264 /// Given a commutative reduction, the order of the input lanes does not alter
1265 /// the results. We can use this to remove certain shuffles feeding the
1266 /// reduction, removing the need to shuffle at all.
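/// For example (illustrative IR), a reverse shuffle feeding an add reduction:
///   %s = shufflevector <2 x i32> %v, <2 x i32> poison, <2 x i32> <i32 1, i32 0>
///   %r = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %s)
/// yields the same value as reducing %v directly, so the shuffle can be
/// replaced with a cheaper in-order form (an identity mask here) when the
/// cost model agrees.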
1267 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
1268   auto *II = dyn_cast<IntrinsicInst>(&I);
1269   if (!II)
1270     return false;
1271   switch (II->getIntrinsicID()) {
1272   case Intrinsic::vector_reduce_add:
1273   case Intrinsic::vector_reduce_mul:
1274   case Intrinsic::vector_reduce_and:
1275   case Intrinsic::vector_reduce_or:
1276   case Intrinsic::vector_reduce_xor:
1277   case Intrinsic::vector_reduce_smin:
1278   case Intrinsic::vector_reduce_smax:
1279   case Intrinsic::vector_reduce_umin:
1280   case Intrinsic::vector_reduce_umax:
1281     break;
1282   default:
1283     return false;
1284   }
1285 
1286   // Find all the inputs when looking through operations that do not alter the
1287   // lane order (binops, for example). Currently we look for a single shuffle,
1288   // and can ignore splat values.
1289   std::queue<Value *> Worklist;
1290   SmallPtrSet<Value *, 4> Visited;
1291   ShuffleVectorInst *Shuffle = nullptr;
1292   if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
1293     Worklist.push(Op);
1294 
1295   while (!Worklist.empty()) {
1296     Value *CV = Worklist.front();
1297     Worklist.pop();
1298     if (Visited.contains(CV))
1299       continue;
1300 
1301     // Splats don't change the order, so can be safely ignored.
1302     if (isSplatValue(CV))
1303       continue;
1304 
1305     Visited.insert(CV);
1306 
1307     if (auto *CI = dyn_cast<Instruction>(CV)) {
1308       if (CI->isBinaryOp()) {
1309         for (auto *Op : CI->operand_values())
1310           Worklist.push(Op);
1311         continue;
1312       } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
1313         if (Shuffle && Shuffle != SV)
1314           return false;
1315         Shuffle = SV;
1316         continue;
1317       }
1318     }
1319 
1320     // Anything else is currently an unknown node.
1321     return false;
1322   }
1323 
1324   if (!Shuffle)
1325     return false;
1326 
  // Check that all uses of the binary ops and shuffles are also included in
  // the lane-invariant operations (Visited should be the list of lanewise
  // instructions, including the shuffle that we found).
1330   for (auto *V : Visited)
1331     for (auto *U : V->users())
1332       if (!Visited.contains(U) && U != &I)
1333         return false;
1334 
1335   FixedVectorType *VecType =
1336       dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
1337   if (!VecType)
1338     return false;
1339   FixedVectorType *ShuffleInputType =
1340       dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
1341   if (!ShuffleInputType)
1342     return false;
1343   int NumInputElts = ShuffleInputType->getNumElements();
1344 
1345   // Find the mask from sorting the lanes into order. This is most likely to
  // become an identity or concat mask. Undef elements are pushed to the end.
1347   SmallVector<int> ConcatMask;
1348   Shuffle->getShuffleMask(ConcatMask);
1349   sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
1350   bool UsesSecondVec =
1351       any_of(ConcatMask, [&](int M) { return M >= NumInputElts; });
1352   InstructionCost OldCost = TTI.getShuffleCost(
1353       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
1354       Shuffle->getShuffleMask());
1355   InstructionCost NewCost = TTI.getShuffleCost(
1356       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
1357       ConcatMask);
1358 
1359   LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
1360                     << "\n");
1361   LLVM_DEBUG(dbgs() << "  OldCost: " << OldCost << " vs NewCost: " << NewCost
1362                     << "\n");
1363   if (NewCost < OldCost) {
1364     Builder.SetInsertPoint(Shuffle);
1365     Value *NewShuffle = Builder.CreateShuffleVector(
1366         Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
1367     LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
1368     replaceValue(*Shuffle, *NewShuffle);
1369   }
1370 
1371   // See if we can re-use foldSelectShuffle, getting it to reduce the size of
1372   // the shuffle into a nicer order, as it can ignore the order of the shuffles.
1373   return foldSelectShuffle(*Shuffle, true);
1374 }
1375 
1376 /// This method looks for groups of shuffles acting on binops, of the form:
1377 ///  %x = shuffle ...
1378 ///  %y = shuffle ...
1379 ///  %a = binop %x, %y
1380 ///  %b = binop %x, %y
1381 ///  shuffle %a, %b, selectmask
1382 /// We may, especially if the shuffle is wider than legal, be able to convert
1383 /// the shuffle to a form where only parts of a and b need to be computed. On
1384 /// architectures with no obvious "select" shuffle, this can reduce the total
1385 /// number of operations if the target reports them as cheaper.
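/// A hypothetical sketch of the resulting shape (names are illustrative, not
/// the exact emitted IR): the input lanes that feed the used lanes of %a and
/// %b are packed to the front of new, simpler input shuffles, and a final
/// reconstruction shuffle restores the original lane order:
///   %xa = shuffle ...   ; packed %x lanes feeding the used %a lanes
///   %ya = shuffle ...
///   %xb = shuffle ...   ; packed %x lanes feeding the used %b lanes
///   %yb = shuffle ...
///   %a1 = binop %xa, %ya
///   %b1 = binop %xb, %yb
///   shuffle %a1, %b1, reconstructmask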
1386 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
1387   auto *SVI = cast<ShuffleVectorInst>(&I);
1388   auto *VT = cast<FixedVectorType>(I.getType());
1389   auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
1390   auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
1391   if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
1392       VT != Op0->getType())
1393     return false;
1394 
1395   auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
1396   auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
1397   auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
1398   auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
1399   SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
1400   auto checkSVNonOpUses = [&](Instruction *I) {
1401     if (!I || I->getOperand(0)->getType() != VT)
1402       return true;
1403     return any_of(I->users(), [&](User *U) {
1404       return U != Op0 && U != Op1 &&
1405              !(isa<ShuffleVectorInst>(U) &&
1406                (InputShuffles.contains(cast<Instruction>(U)) ||
1407                 isInstructionTriviallyDead(cast<Instruction>(U))));
1408     });
1409   };
1410   if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
1411       checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
1412     return false;
1413 
1414   // Collect all the uses that are shuffles that we can transform together. We
1415   // may not have a single shuffle, but a group that can all be transformed
1416   // together profitably.
1417   SmallVector<ShuffleVectorInst *> Shuffles;
1418   auto collectShuffles = [&](Instruction *I) {
1419     for (auto *U : I->users()) {
1420       auto *SV = dyn_cast<ShuffleVectorInst>(U);
1421       if (!SV || SV->getType() != VT)
1422         return false;
1423       if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
1424           (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
1425         return false;
1426       if (!llvm::is_contained(Shuffles, SV))
1427         Shuffles.push_back(SV);
1428     }
1429     return true;
1430   };
1431   if (!collectShuffles(Op0) || !collectShuffles(Op1))
1432     return false;
  // When coming from a reduction, we need to be processing a single shuffle;
  // otherwise the other uses will not be lane-invariant.
1435   if (FromReduction && Shuffles.size() > 1)
1436     return false;
1437 
1438   // Add any shuffle uses for the shuffles we have found, to include them in our
1439   // cost calculations.
1440   if (!FromReduction) {
1441     for (ShuffleVectorInst *SV : Shuffles) {
1442       for (auto *U : SV->users()) {
1443         ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
1444         if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
1445           Shuffles.push_back(SSV);
1446       }
1447     }
1448   }
1449 
  // For each of the output shuffles, we try to sort all the first vector's
  // elements to the beginning, followed by the second vector's elements at the
  // end. If the binops are legalized to smaller vectors, this may reduce the
  // total number of binops. We compute the ReconstructMask needed to convert
  // back to the original lane order.
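  // A hypothetical example with NumElts == 4: an output mask <0,5,1,4> packs
  // lanes {0,1} of Op0 into V1 and lanes {1,0} of Op1 into V2, giving an
  // (original) ReconstructMask of <0,4,1,5> over the packed vectors.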
1455   SmallVector<std::pair<int, int>> V1, V2;
1456   SmallVector<SmallVector<int>> OrigReconstructMasks;
1457   int MaxV1Elt = 0, MaxV2Elt = 0;
1458   unsigned NumElts = VT->getNumElements();
1459   for (ShuffleVectorInst *SVN : Shuffles) {
1460     SmallVector<int> Mask;
1461     SVN->getShuffleMask(Mask);
1462 
    // Check that the operands are the same as the original, or reversed (in
    // which case we need to commute the mask).
1465     Value *SVOp0 = SVN->getOperand(0);
1466     Value *SVOp1 = SVN->getOperand(1);
1467     if (isa<UndefValue>(SVOp1)) {
1468       auto *SSV = cast<ShuffleVectorInst>(SVOp0);
1469       SVOp0 = SSV->getOperand(0);
1470       SVOp1 = SSV->getOperand(1);
1471       for (unsigned I = 0, E = Mask.size(); I != E; I++) {
1472         if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
1473           return false;
1474         Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
1475       }
1476     }
1477     if (SVOp0 == Op1 && SVOp1 == Op0) {
1478       std::swap(SVOp0, SVOp1);
1479       ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
1480     }
1481     if (SVOp0 != Op0 || SVOp1 != Op1)
1482       return false;
1483 
    // Calculate the reconstruction mask for this shuffle, as the mask needed
    // to take the packed values from Op0/Op1 and reconstruct the original
    // lane order.
1487     SmallVector<int> ReconstructMask;
1488     for (unsigned I = 0; I < Mask.size(); I++) {
1489       if (Mask[I] < 0) {
1490         ReconstructMask.push_back(-1);
1491       } else if (Mask[I] < static_cast<int>(NumElts)) {
1492         MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
1493         auto It = find_if(V1, [&](const std::pair<int, int> &A) {
1494           return Mask[I] == A.first;
1495         });
1496         if (It != V1.end())
1497           ReconstructMask.push_back(It - V1.begin());
1498         else {
1499           ReconstructMask.push_back(V1.size());
1500           V1.emplace_back(Mask[I], V1.size());
1501         }
1502       } else {
1503         MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
1504         auto It = find_if(V2, [&](const std::pair<int, int> &A) {
1505           return Mask[I] - static_cast<int>(NumElts) == A.first;
1506         });
1507         if (It != V2.end())
1508           ReconstructMask.push_back(NumElts + It - V2.begin());
1509         else {
1510           ReconstructMask.push_back(NumElts + V2.size());
1511           V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
1512         }
1513       }
1514     }
1515 
    // For reductions, we know that the output lane ordering doesn't alter the
    // result, so sorting the mask can help simplify the shuffle away.
1518     if (FromReduction)
1519       sort(ReconstructMask);
1520     OrigReconstructMasks.push_back(std::move(ReconstructMask));
1521   }
1522 
  // If the maximum elements used from V1 and V2 are not larger than the new
  // vectors, the vectors are already packed and performing the optimization
  // again will likely not help any further. This also prevents us from getting
  // stuck in a cycle in case the costs do not also rule it out.
1527   if (V1.empty() || V2.empty() ||
1528       (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
1529        MaxV2Elt == static_cast<int>(V2.size()) - 1))
1530     return false;
1531 
  // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
  // shuffle of another shuffle, or not a shuffle (which is treated like an
  // identity shuffle).
1535   auto GetBaseMaskValue = [&](Instruction *I, int M) {
1536     auto *SV = dyn_cast<ShuffleVectorInst>(I);
1537     if (!SV)
1538       return M;
1539     if (isa<UndefValue>(SV->getOperand(1)))
1540       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
1541         if (InputShuffles.contains(SSV))
1542           return SSV->getMaskValue(SV->getMaskValue(M));
1543     return SV->getMaskValue(M);
1544   };
1545 
  // Attempt to sort the inputs by ascending mask values to make simpler input
1547   // shuffles and push complex shuffles down to the uses. We sort on the first
1548   // of the two input shuffle orders, to try and get at least one input into a
1549   // nice order.
1550   auto SortBase = [&](Instruction *A, std::pair<int, int> X,
1551                       std::pair<int, int> Y) {
1552     int MXA = GetBaseMaskValue(A, X.first);
1553     int MYA = GetBaseMaskValue(A, Y.first);
1554     return MXA < MYA;
1555   };
1556   stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
1557     return SortBase(SVI0A, A, B);
1558   });
1559   stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
1560     return SortBase(SVI1A, A, B);
1561   });
1562   // Calculate our ReconstructMasks from the OrigReconstructMasks and the
1563   // modified order of the input shuffles.
1564   SmallVector<SmallVector<int>> ReconstructMasks;
1565   for (const auto &Mask : OrigReconstructMasks) {
1566     SmallVector<int> ReconstructMask;
1567     for (int M : Mask) {
1568       auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
1569         auto It = find_if(V, [M](auto A) { return A.second == M; });
1570         assert(It != V.end() && "Expected all entries in Mask");
1571         return std::distance(V.begin(), It);
1572       };
1573       if (M < 0)
1574         ReconstructMask.push_back(-1);
1575       else if (M < static_cast<int>(NumElts)) {
1576         ReconstructMask.push_back(FindIndex(V1, M));
1577       } else {
1578         ReconstructMask.push_back(NumElts + FindIndex(V2, M));
1579       }
1580     }
1581     ReconstructMasks.push_back(std::move(ReconstructMask));
1582   }
1583 
  // Calculate the masks needed for the new input shuffles, which get padded
  // out with poison elements.
1586   SmallVector<int> V1A, V1B, V2A, V2B;
1587   for (unsigned I = 0; I < V1.size(); I++) {
1588     V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
1589     V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
1590   }
1591   for (unsigned I = 0; I < V2.size(); I++) {
1592     V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
1593     V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
1594   }
1595   while (V1A.size() < NumElts) {
1596     V1A.push_back(PoisonMaskElem);
1597     V1B.push_back(PoisonMaskElem);
1598   }
1599   while (V2A.size() < NumElts) {
1600     V2A.push_back(PoisonMaskElem);
1601     V2B.push_back(PoisonMaskElem);
1602   }
1603 
1604   auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
1605     auto *SV = dyn_cast<ShuffleVectorInst>(I);
1606     if (!SV)
1607       return C;
1608     return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
1609                                       ? TTI::SK_PermuteSingleSrc
1610                                       : TTI::SK_PermuteTwoSrc,
1611                                   VT, SV->getShuffleMask());
1612   };
1613   auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
1614     return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask);
1615   };
1616 
1617   // Get the costs of the shuffles + binops before and after with the new
1618   // shuffle masks.
1619   InstructionCost CostBefore =
1620       TTI.getArithmeticInstrCost(Op0->getOpcode(), VT) +
1621       TTI.getArithmeticInstrCost(Op1->getOpcode(), VT);
1622   CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
1623                                 InstructionCost(0), AddShuffleCost);
1624   CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
1625                                 InstructionCost(0), AddShuffleCost);
1626 
  // The new binops will be unused for lanes past the used shuffle lengths.
  // These narrower types attempt to get the correct cost for that from the
  // target.
1629   FixedVectorType *Op0SmallVT =
1630       FixedVectorType::get(VT->getScalarType(), V1.size());
1631   FixedVectorType *Op1SmallVT =
1632       FixedVectorType::get(VT->getScalarType(), V2.size());
1633   InstructionCost CostAfter =
1634       TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT) +
1635       TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT);
1636   CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
1637                                InstructionCost(0), AddShuffleMaskCost);
1638   std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
1639   CostAfter +=
1640       std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
1641                       InstructionCost(0), AddShuffleMaskCost);
1642 
1643   LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
1644   LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
1645                     << " vs CostAfter: " << CostAfter << "\n");
1646   if (CostBefore <= CostAfter)
1647     return false;
1648 
  // The cost model has passed; create the new instructions.
1650   auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
1651     auto *SV = dyn_cast<ShuffleVectorInst>(I);
1652     if (!SV)
1653       return I;
1654     if (isa<UndefValue>(SV->getOperand(1)))
1655       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
1656         if (InputShuffles.contains(SSV))
1657           return SSV->getOperand(Op);
1658     return SV->getOperand(Op);
1659   };
1660   Builder.SetInsertPoint(SVI0A->getInsertionPointAfterDef());
1661   Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
1662                                              GetShuffleOperand(SVI0A, 1), V1A);
1663   Builder.SetInsertPoint(SVI0B->getInsertionPointAfterDef());
1664   Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
1665                                              GetShuffleOperand(SVI0B, 1), V1B);
1666   Builder.SetInsertPoint(SVI1A->getInsertionPointAfterDef());
1667   Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
1668                                              GetShuffleOperand(SVI1A, 1), V2A);
1669   Builder.SetInsertPoint(SVI1B->getInsertionPointAfterDef());
1670   Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
1671                                              GetShuffleOperand(SVI1B, 1), V2B);
1672   Builder.SetInsertPoint(Op0);
1673   Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
1674                                     NSV0A, NSV0B);
1675   if (auto *I = dyn_cast<Instruction>(NOp0))
1676     I->copyIRFlags(Op0, true);
1677   Builder.SetInsertPoint(Op1);
1678   Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
1679                                     NSV1A, NSV1B);
1680   if (auto *I = dyn_cast<Instruction>(NOp1))
1681     I->copyIRFlags(Op1, true);
1682 
1683   for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
1684     Builder.SetInsertPoint(Shuffles[S]);
1685     Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
1686     replaceValue(*Shuffles[S], *NSV);
1687   }
1688 
1689   Worklist.pushValue(NSV0A);
1690   Worklist.pushValue(NSV0B);
1691   Worklist.pushValue(NSV1A);
1692   Worklist.pushValue(NSV1B);
1693   for (auto *S : Shuffles)
1694     Worklist.add(S);
1695   return true;
1696 }
1697 
1698 /// This is the entry point for all transforms. Pass manager differences are
1699 /// handled in the callers of this function.
1700 bool VectorCombine::run() {
1701   if (DisableVectorCombine)
1702     return false;
1703 
1704   // Don't attempt vectorization if the target does not support vectors.
1705   if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
1706     return false;
1707 
1708   bool MadeChange = false;
1709   auto FoldInst = [this, &MadeChange](Instruction &I) {
1710     Builder.SetInsertPoint(&I);
1711     bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
1712     auto Opcode = I.getOpcode();
1713 
1714     // These folds should be beneficial regardless of when this pass is run
1715     // in the optimization pipeline.
1716     // The type checking is for run-time efficiency. We can avoid wasting time
1717     // dispatching to folding functions if there's no chance of matching.
1718     if (IsFixedVectorType) {
1719       switch (Opcode) {
1720       case Instruction::InsertElement:
1721         MadeChange |= vectorizeLoadInsert(I);
1722         break;
1723       case Instruction::ShuffleVector:
1724         MadeChange |= widenSubvectorLoad(I);
1725         break;
1726       case Instruction::Load:
1727         MadeChange |= scalarizeLoadExtract(I);
1728         break;
1729       default:
1730         break;
1731       }
1732     }
1733 
1734     // This transform works with scalable and fixed vectors
1735     // TODO: Identify and allow other scalable transforms
1736     if (isa<VectorType>(I.getType()))
1737       MadeChange |= scalarizeBinopOrCmp(I);
1738 
1739     if (Opcode == Instruction::Store)
1740       MadeChange |= foldSingleElementStore(I);
1741 
1743     // If this is an early pipeline invocation of this pass, we are done.
1744     if (TryEarlyFoldsOnly)
1745       return;
1746 
1747     // Otherwise, try folds that improve codegen but may interfere with
1748     // early IR canonicalizations.
1749     // The type checking is for run-time efficiency. We can avoid wasting time
1750     // dispatching to folding functions if there's no chance of matching.
1751     if (IsFixedVectorType) {
1752       switch (Opcode) {
1753       case Instruction::InsertElement:
1754         MadeChange |= foldInsExtFNeg(I);
1755         break;
1756       case Instruction::ShuffleVector:
1757         MadeChange |= foldShuffleOfBinops(I);
1758         MadeChange |= foldSelectShuffle(I);
1759         break;
1760       case Instruction::BitCast:
1761         MadeChange |= foldBitcastShuf(I);
1762         break;
1763       }
1764     } else {
1765       switch (Opcode) {
1766       case Instruction::Call:
1767         MadeChange |= foldShuffleFromReductions(I);
1768         break;
1769       case Instruction::ICmp:
1770       case Instruction::FCmp:
1771         MadeChange |= foldExtractExtract(I);
1772         break;
1773       default:
1774         if (Instruction::isBinaryOp(Opcode)) {
1775           MadeChange |= foldExtractExtract(I);
1776           MadeChange |= foldExtractedCmps(I);
1777         }
1778         break;
1779       }
1780     }
1781   };
1782 
1783   for (BasicBlock &BB : F) {
1784     // Ignore unreachable basic blocks.
1785     if (!DT.isReachableFromEntry(&BB))
1786       continue;
1787     // Use early increment range so that we can erase instructions in loop.
1788     for (Instruction &I : make_early_inc_range(BB)) {
1789       if (I.isDebugOrPseudoInst())
1790         continue;
1791       FoldInst(I);
1792     }
1793   }
1794 
1795   while (!Worklist.isEmpty()) {
1796     Instruction *I = Worklist.removeOne();
1797     if (!I)
1798       continue;
1799 
1800     if (isInstructionTriviallyDead(I)) {
1801       eraseInstruction(*I);
1802       continue;
1803     }
1804 
1805     FoldInst(*I);
1806   }
1807 
1808   return MadeChange;
1809 }
1810 
1811 PreservedAnalyses VectorCombinePass::run(Function &F,
1812                                          FunctionAnalysisManager &FAM) {
1813   auto &AC = FAM.getResult<AssumptionAnalysis>(F);
1814   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
1815   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
1816   AAResults &AA = FAM.getResult<AAManager>(F);
1817   VectorCombine Combiner(F, TTI, DT, AA, AC, TryEarlyFoldsOnly);
1818   if (!Combiner.run())
1819     return PreservedAnalyses::all();
1820   PreservedAnalyses PA;
1821   PA.preserveSet<CFGAnalyses>();
1822   return PA;
1823 }
1824