//===------- VectorCombine.cpp - Optimize partial vector operations ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;

  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             unsigned Opcode,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
};
} // namespace

static void replaceValue(Value &Old, Value &New) {
  Old.replaceAllUsesWith(&New);
  New.takeName(&Old);
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
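/// For example (a sketch, assuming the target reports equal costs for both
/// extracts and no index is preferred): given extracts from lanes 0 and 3,
/// the lane-3 extract is returned, and the caller rewrites it as a splat
/// shuffle plus an extract from the cheap lane:
///   %e1 = extractelement <4 x i32> %v1, i32 3
///   -->
///   %s  = shufflevector <4 x i32> %v1, <4 x i32> undef,
///                       <4 x i32> <i32 3, i32 undef, i32 undef, i32 undef>
///   %e1 = extractelement <4 x i32> %s, i32 0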
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");

  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  int Cost0 = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  int Cost1 = TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // We are extracting from 2 different indexes, so one operand must be
  // shuffled before performing a vector operation and/or extract. The more
  // expensive extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
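/// As a worked example with hypothetical target costs -- extract from lane 0
/// costs 1, extract from lane 3 costs 3, and the scalar and vector ops cost 1
/// each -- for two one-use extracts at different indexes: OldCost = 3 + 1 + 1
/// = 5 and NewCost = 1 + 1 = 2 (plus one splat-shuffle cost added below), so
/// the vector sequence wins and this returns false.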
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          unsigned Opcode,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  int Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  Value *Undef = UndefValue::get(VecTy);
  return Builder.CreateShuffleVector(Vec, Undef, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
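/// For example, shifting an extract from lane 2 to lane 0 (mirroring the
/// ShufMask example in createShiftShuffle):
///   %e = extractelement <4 x i32> %x, i32 2
///   -->
///   %shift = shufflevector <4 x i32> %x, <4 x i32> undef,
///                          <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
///   %e = extractelement <4 x i32> %shift, i32 0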
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
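/// A sketch of the matching-index case (handled by foldExtExtBinop above; %a
/// and %b are illustrative):
///   %e0 = extractelement <4 x float> %a, i32 1
///   %e1 = extractelement <4 x float> %b, i32 1
///   %f  = fadd float %e0, %e1
///   -->
///   %v  = fadd <4 x float> %a, %b
///   %f  = extractelement <4 x float> %v, i32 1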
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
                            InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
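/// For example, in the wide-to-narrow case the mask <1, 0> is rescaled by the
/// element-count ratio (2 here) to <2, 3, 0, 1>:
///   %shuf = shufflevector <2 x i64> %v, <2 x i64> undef, <2 x i32> <i32 1, i32 0>
///   %cast = bitcast <2 x i64> %shuf to <4 x i32>
///   -->
///   %cast = bitcast <2 x i64> %v to <4 x i32>
///   %shuf = shufflevector <4 x i32> %cast, <4 x i32> undef,
///                         <4 x i32> <i32 2, i32 3, i32 0, i32 1>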
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<VectorType>(I.getType());
  auto *SrcTy = cast<VectorType>(V->getType());
  if (!DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf =
      Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy), NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
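/// A sketch of the single-inserted-operand case (constants are illustrative;
/// note how the two vector constants fold into the new base vector):
///   %ins = insertelement <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i32 %x, i32 0
///   %bo  = add <4 x i32> %ins, <i32 10, i32 20, i32 30, i32 40>
///   -->
///   %s   = add i32 %x, 10
///   %bo  = insertelement <4 x i32> <i32 11, i32 22, i32 33, i32 44>, i32 %s, i32 0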
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  int ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  int OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) +
                VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element; this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
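/// A sketch with 'and' as the boolean binop (assuming equal per-lane extract
/// costs, so the higher lane 3 is the one converted to a shuffle down to
/// lane 0):
///   %e0 = extractelement <4 x i32> %x, i32 0
///   %e1 = extractelement <4 x i32> %x, i32 3
///   %c0 = icmp sgt i32 %e0, 42
///   %c1 = icmp sgt i32 %e1, 7
///   %r  = and i1 %c0, %c1
///   -->
///   %vc = icmp sgt <4 x i32> %x, <i32 42, i32 undef, i32 undef, i32 7>
///   %sh = shufflevector <4 x i1> %vc, <4 x i1> undef,
///                       <4 x i32> <i32 3, i32 undef, i32 undef, i32 undef>
///   %vb = and <4 x i1> %vc, %sh
///   %r  = extractelement <4 x i1> %vb, i32 0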
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  int OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  int NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
  NewCost +=
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost)
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block forwards to enable simple iterative chains of transforms.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      Builder.SetInsertPoint(&I);
      MadeChange |= foldExtractExtract(I);
      MadeChange |= foldBitcastShuf(I);
      MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= foldExtractedCmps(I);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    VectorCombine Combiner(F, TTI, DT);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  VectorCombine Combiner(F, TTI, DT);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  PA.preserve<AAManager>();
  PA.preserve<BasicAA>();
  return PA;
}