1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines some vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
15 
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/IR/VFABIDemangler.h"
20 #include "llvm/Support/CheckedArithmetic.h"
21 
22 namespace llvm {
23 class TargetLibraryInfo;
24 
25 /// The Vector Function Database.
26 ///
27 /// Helper class used to find the vector functions associated to a
28 /// scalar CallInst.
29 class VFDatabase {
30   /// The Module of the CallInst CI.
31   const Module *M;
32   /// The CallInst instance being queried for scalar to vector mappings.
33   const CallInst &CI;
34   /// List of vector functions descriptors associated to the call
35   /// instruction.
36   const SmallVector<VFInfo, 8> ScalarToVectorMappings;
37 
38   /// Retrieve the scalar-to-vector mappings associated to the rule of
39   /// a vector Function ABI.
getVFABIMappings(const CallInst & CI,SmallVectorImpl<VFInfo> & Mappings)40   static void getVFABIMappings(const CallInst &CI,
41                                SmallVectorImpl<VFInfo> &Mappings) {
42     if (!CI.getCalledFunction())
43       return;
44 
45     const StringRef ScalarName = CI.getCalledFunction()->getName();
46 
47     SmallVector<std::string, 8> ListOfStrings;
48     // The check for the vector-function-abi-variant attribute is done when
49     // retrieving the vector variant names here.
50     VFABI::getVectorVariantNames(CI, ListOfStrings);
51     if (ListOfStrings.empty())
52       return;
53     for (const auto &MangledName : ListOfStrings) {
54       const std::optional<VFInfo> Shape =
55           VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
56       // A match is found via scalar and vector names, and also by
57       // ensuring that the variant described in the attribute has a
58       // corresponding definition or declaration of the vector
59       // function in the Module M.
60       if (Shape && (Shape->ScalarName == ScalarName)) {
61         assert(CI.getModule()->getFunction(Shape->VectorName) &&
62                "Vector function is missing.");
63         Mappings.push_back(*Shape);
64       }
65     }
66   }
67 
68 public:
69   /// Retrieve all the VFInfo instances associated to the CallInst CI.
getMappings(const CallInst & CI)70   static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
71     SmallVector<VFInfo, 8> Ret;
72 
73     // Get mappings from the Vector Function ABI variants.
74     getVFABIMappings(CI, Ret);
75 
76     // Other non-VFABI variants should be retrieved here.
77 
78     return Ret;
79   }
80 
81   static bool hasMaskedVariant(const CallInst &CI,
82                                std::optional<ElementCount> VF = std::nullopt) {
83     // Check whether we have at least one masked vector version of a scalar
84     // function. If no VF is specified then we check for any masked variant,
85     // otherwise we look for one that matches the supplied VF.
86     auto Mappings = VFDatabase::getMappings(CI);
87     for (VFInfo Info : Mappings)
88       if (!VF || Info.Shape.VF == *VF)
89         if (Info.isMasked())
90           return true;
91 
92     return false;
93   }
94 
95   /// Constructor, requires a CallInst instance.
VFDatabase(CallInst & CI)96   VFDatabase(CallInst &CI)
97       : M(CI.getModule()), CI(CI),
98         ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
99   /// \defgroup VFDatabase query interface.
100   ///
101   /// @{
102   /// Retrieve the Function with VFShape \p Shape.
getVectorizedFunction(const VFShape & Shape)103   Function *getVectorizedFunction(const VFShape &Shape) const {
104     if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
105       return CI.getCalledFunction();
106 
107     for (const auto &Info : ScalarToVectorMappings)
108       if (Info.Shape == Shape)
109         return M->getFunction(Info.VectorName);
110 
111     return nullptr;
112   }
113   /// @}
114 };
115 
// Forward declarations for types referenced by the declarations below.
template <typename T> class ArrayRef;
template <typename InstTy> class InterleaveGroup;
class DemandedBits;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
/// Intrinsic identifiers are plain unsigned integers.
using ID = unsigned;
} // namespace Intrinsic
129 
130 /// A helper function for converting Scalar types to vector types. If
131 /// the incoming type is void, we return void. If the EC represents a
132 /// scalar, we return the scalar type.
ToVectorTy(Type * Scalar,ElementCount EC)133 inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
134   if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
135     return Scalar;
136   return VectorType::get(Scalar, EC);
137 }
138 
ToVectorTy(Type * Scalar,unsigned VF)139 inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
140   return ToVectorTy(Scalar, ElementCount::getFixed(VF));
141 }
142 
143 /// Identify if the intrinsic is trivially vectorizable.
144 /// This method returns true if the intrinsic's argument types are all scalars
145 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
146 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
147 bool isTriviallyVectorizable(Intrinsic::ID ID);
148 
149 /// Identifies if the vector form of the intrinsic has a scalar operand.
150 bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
151                                         unsigned ScalarOpdIdx);
152 
153 /// Identifies if the vector form of the intrinsic is overloaded on the type of
154 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
155 bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
156 
/// Returns the intrinsic ID for a call.
/// For the input call instruction, finds the matching intrinsic and returns
/// its intrinsic ID; if none is found, returns Intrinsic::not_intrinsic.
160 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
161                                           const TargetLibraryInfo *TLI);
162 
163 /// Given a vector and an element number, see if the scalar value is
164 /// already around as a register, for example if it were inserted then extracted
165 /// from the vector.
166 Value *findScalarElement(Value *V, unsigned EltNo);
167 
168 /// If all non-negative \p Mask elements are the same value, return that value.
169 /// If all elements are negative (undefined) or \p Mask contains different
170 /// non-negative values, return -1.
171 int getSplatIndex(ArrayRef<int> Mask);
172 
173 /// Get splat value if the input is a splat vector or return nullptr.
174 /// The value may be extracted from a splat constants vector or from
175 /// a sequence of instructions that broadcast a single value into a vector.
176 Value *getSplatValue(const Value *V);
177 
178 /// Return true if each element of the vector value \p V is poisoned or equal to
179 /// every other non-poisoned element. If an index element is specified, either
180 /// every element of the vector is poisoned or the element at that index is not
181 /// poisoned and equal to every other non-poisoned element.
182 /// This may be more powerful than the related getSplatValue() because it is
183 /// not limited by finding a scalar source value to a splatted vector.
184 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
185 
186 /// Transform a shuffle mask's output demanded element mask into demanded
187 /// element masks for the 2 operands, returns false if the mask isn't valid.
188 /// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
189 /// \p AllowUndefElts permits "-1" indices to be treated as undef.
190 bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
191                             const APInt &DemandedElts, APInt &DemandedLHS,
192                             APInt &DemandedRHS, bool AllowUndefElts = false);
193 
194 /// Replace each shuffle mask index with the scaled sequential indices for an
195 /// equivalent mask of narrowed elements. Mask elements that are less than 0
196 /// (sentinel values) are repeated in the output mask.
197 ///
198 /// Example with Scale = 4:
199 ///   <4 x i32> <3, 2, 0, -1> -->
200 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
201 ///
202 /// This is the reverse process of widening shuffle mask elements, but it always
203 /// succeeds because the indexes can always be multiplied (scaled up) to map to
204 /// narrower vector elements.
205 void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
206                            SmallVectorImpl<int> &ScaledMask);
207 
208 /// Try to transform a shuffle mask by replacing elements with the scaled index
209 /// for an equivalent mask of widened elements. If all mask elements that would
210 /// map to a wider element of the new mask are the same negative number
211 /// (sentinel value), that element of the new mask is the same value. If any
212 /// element in a given slice is negative and some other element in that slice is
213 /// not the same value, return false (partial matches with sentinel values are
214 /// not allowed).
215 ///
216 /// Example with Scale = 4:
217 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
218 ///   <4 x i32> <3, 2, 0, -1>
219 ///
220 /// This is the reverse process of narrowing shuffle mask elements if it
221 /// succeeds. This transform is not always possible because indexes may not
222 /// divide evenly (scale down) to map to wider vector elements.
223 bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
224                           SmallVectorImpl<int> &ScaledMask);
225 
226 /// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
227 /// to get the shuffle mask with widest possible elements.
228 void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
229                                   SmallVectorImpl<int> &ScaledMask);
230 
/// Splits and processes a shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) does the mask analysis
/// to identify which real registers are permuted. Then the function processes
/// the resulting register masks using the provided actions. If no input
/// register is used, \p NoInputAction is invoked. If only 1 input register is
/// used, \p SingleInputAction is invoked; otherwise (2 or more input
/// registers) \p ManyInputsAction is used to process the registers and masks.
239 /// \param Mask Original shuffle mask.
240 /// \param NumOfSrcRegs Number of source registers.
241 /// \param NumOfDestRegs Number of destination registers.
242 /// \param NumOfUsedRegs Number of actually used destination registers.
243 void processShuffleMasks(
244     ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
245     unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
246     function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
247     function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
248 
249 /// Compute a map of integer instructions to their minimum legal type
250 /// size.
251 ///
252 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
253 /// type (e.g. i32) whenever arithmetic is performed on them.
254 ///
255 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
256 /// the arithmetic type down again. However InstCombine refuses to create
257 /// illegal types, so for targets without i8 or i16 registers, the lengthening
258 /// and shrinking remains.
259 ///
260 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
261 /// their scalar equivalents do not, so during vectorization it is important to
262 /// remove these lengthens and truncates when deciding the profitability of
263 /// vectorization.
264 ///
265 /// This function analyzes the given range of instructions and determines the
266 /// minimum type size each can be converted to. It attempts to remove or
267 /// minimize type size changes across each def-use chain, so for example in the
268 /// following code:
269 ///
270 ///   %1 = load i8, i8*
271 ///   %2 = add i8 %1, 2
272 ///   %3 = load i16, i16*
273 ///   %4 = zext i8 %2 to i32
274 ///   %5 = zext i16 %3 to i32
275 ///   %6 = add i32 %4, %5
276 ///   %7 = trunc i32 %6 to i16
277 ///
278 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
279 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
280 ///
281 /// If the optional TargetTransformInfo is provided, this function tries harder
282 /// to do less work by only looking at illegal types.
283 MapVector<Instruction*, uint64_t>
284 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
285                          DemandedBits &DB,
286                          const TargetTransformInfo *TTI=nullptr);
287 
288 /// Compute the union of two access-group lists.
289 ///
290 /// If the list contains just one access group, it is returned directly. If the
291 /// list is empty, returns nullptr.
292 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
293 
294 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
295 /// are both in. If either instruction does not access memory at all, it is
296 /// considered to be in every list.
297 ///
298 /// If the list contains just one access group, it is returned directly. If the
299 /// list is empty, returns nullptr.
300 MDNode *intersectAccessGroups(const Instruction *Inst1,
301                               const Instruction *Inst2);
302 
/// Propagate metadata from the values in \p VL to the instruction \p I.
///
/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
309 ///
310 /// This function always sets a (possibly null) value for each K in Kinds.
311 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
312 
313 /// Create a mask that filters the members of an interleave group where there
314 /// are gaps.
315 ///
316 /// For example, the mask for \p Group with interleave-factor 3
317 /// and \p VF 4, that has only its first member present is:
318 ///
319 ///   <1,0,0,1,0,0,1,0,0,1,0,0>
320 ///
321 /// Note: The result is a mask of 0's and 1's, as opposed to the other
322 /// create[*]Mask() utilities which create a shuffle mask (mask that
323 /// consists of indices).
324 Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
325                                const InterleaveGroup<Instruction> &Group);
326 
327 /// Create a mask with replicated elements.
328 ///
329 /// This function creates a shuffle mask for replicating each of the \p VF
330 /// elements in a vector \p ReplicationFactor times. It can be used to
331 /// transform a mask of \p VF elements into a mask of
332 /// \p VF * \p ReplicationFactor elements used by a predicated
333 /// interleaved-group of loads/stores whose Interleaved-factor ==
334 /// \p ReplicationFactor.
335 ///
336 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
337 ///
338 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
339 llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
340                                                 unsigned VF);
341 
342 /// Create an interleave shuffle mask.
343 ///
344 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
345 /// vectorization factor \p VF into a single wide vector. The mask is of the
346 /// form:
347 ///
348 ///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
349 ///
350 /// For example, the mask for VF = 4 and NumVecs = 2 is:
351 ///
352 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
353 llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);
354 
355 /// Create a stride shuffle mask.
356 ///
357 /// This function creates a shuffle mask whose elements begin at \p Start and
358 /// are incremented by \p Stride. The mask can be used to deinterleave an
359 /// interleaved vector into separate vectors of vectorization factor \p VF. The
360 /// mask is of the form:
361 ///
362 ///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
363 ///
364 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
365 ///
366 ///   <0, 2, 4, 6>
367 llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
368                                             unsigned VF);
369 
370 /// Create a sequential shuffle mask.
371 ///
372 /// This function creates shuffle mask whose elements are sequential and begin
373 /// at \p Start.  The mask contains \p NumInts integers and is padded with \p
374 /// NumUndefs undef values. The mask is of the form:
375 ///
376 ///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
377 ///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
379 ///
380 ///   <0, 1, 2, 3, undef, undef, undef, undef>
381 llvm::SmallVector<int, 16>
382 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
383 
384 /// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
385 /// mask assuming both operands are identical. This assumes that the unary
386 /// shuffle will use elements from operand 0 (operand 1 will be unused).
387 llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
388                                            unsigned NumElts);
389 
390 /// Concatenate a list of vectors.
391 ///
/// This function generates code that concatenates the vectors in \p Vecs into a
393 /// single large vector. The number of vectors should be greater than one, and
394 /// their element types should be the same. The number of elements in the
395 /// vectors should also be the same; however, if the last vector has fewer
396 /// elements, it will be padded with undefs.
397 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
398 
/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef.  That is, return true if all
/// lanes can be assumed inactive.
402 bool maskIsAllZeroOrUndef(Value *Mask);
403 
/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef.  That is, return true if all
/// lanes can be assumed active.
407 bool maskIsAllOneOrUndef(Value *Mask);
408 
/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef.  That is, return true if at
/// least one lane can be assumed active.
412 bool maskContainsAllOneOrUndef(Value *Mask);
413 
/// Given a mask vector of the form <Y x i1>, return an APInt of bitwidth Y
/// with a bit set for each lane which may be active.
416 APInt possiblyDemandedEltsInMask(Value *Mask);
417 
418 /// The group of interleaved loads/stores sharing the same stride and
419 /// close to each other.
420 ///
421 /// Each member in this group has an index starting from 0, and the largest
422 /// index should be less than interleaved factor, which is equal to the absolute
423 /// value of the access's stride.
424 ///
425 /// E.g. An interleaved load group of factor 4:
426 ///        for (unsigned i = 0; i < 1024; i+=4) {
427 ///          a = A[i];                           // Member of index 0
428 ///          b = A[i+1];                         // Member of index 1
429 ///          d = A[i+3];                         // Member of index 3
430 ///          ...
431 ///        }
432 ///
433 ///      An interleaved store group of factor 4:
434 ///        for (unsigned i = 0; i < 1024; i+=4) {
435 ///          ...
436 ///          A[i]   = a;                         // Member of index 0
437 ///          A[i+1] = b;                         // Member of index 1
438 ///          A[i+2] = c;                         // Member of index 2
439 ///          A[i+3] = d;                         // Member of index 3
440 ///        }
441 ///
442 /// Note: the interleaved load group could have gaps (missing members), but
443 /// the interleaved store group doesn't allow gaps.
444 template <typename InstTy> class InterleaveGroup {
445 public:
InterleaveGroup(uint32_t Factor,bool Reverse,Align Alignment)446   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
447       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
448         InsertPos(nullptr) {}
449 
InterleaveGroup(InstTy * Instr,int32_t Stride,Align Alignment)450   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
451       : Alignment(Alignment), InsertPos(Instr) {
452     Factor = std::abs(Stride);
453     assert(Factor > 1 && "Invalid interleave factor");
454 
455     Reverse = Stride < 0;
456     Members[0] = Instr;
457   }
458 
isReverse()459   bool isReverse() const { return Reverse; }
getFactor()460   uint32_t getFactor() const { return Factor; }
getAlign()461   Align getAlign() const { return Alignment; }
getNumMembers()462   uint32_t getNumMembers() const { return Members.size(); }
463 
464   /// Try to insert a new member \p Instr with index \p Index and
465   /// alignment \p NewAlign. The index is related to the leader and it could be
466   /// negative if it is the new leader.
467   ///
468   /// \returns false if the instruction doesn't belong to the group.
insertMember(InstTy * Instr,int32_t Index,Align NewAlign)469   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
470     // Make sure the key fits in an int32_t.
471     std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
472     if (!MaybeKey)
473       return false;
474     int32_t Key = *MaybeKey;
475 
476     // Skip if the key is used for either the tombstone or empty special values.
477     if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
478         DenseMapInfo<int32_t>::getEmptyKey() == Key)
479       return false;
480 
481     // Skip if there is already a member with the same index.
482     if (Members.contains(Key))
483       return false;
484 
485     if (Key > LargestKey) {
486       // The largest index is always less than the interleave factor.
487       if (Index >= static_cast<int32_t>(Factor))
488         return false;
489 
490       LargestKey = Key;
491     } else if (Key < SmallestKey) {
492 
493       // Make sure the largest index fits in an int32_t.
494       std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
495       if (!MaybeLargestIndex)
496         return false;
497 
498       // The largest index is always less than the interleave factor.
499       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
500         return false;
501 
502       SmallestKey = Key;
503     }
504 
505     // It's always safe to select the minimum alignment.
506     Alignment = std::min(Alignment, NewAlign);
507     Members[Key] = Instr;
508     return true;
509   }
510 
511   /// Get the member with the given index \p Index
512   ///
513   /// \returns nullptr if contains no such member.
getMember(uint32_t Index)514   InstTy *getMember(uint32_t Index) const {
515     int32_t Key = SmallestKey + Index;
516     return Members.lookup(Key);
517   }
518 
519   /// Get the index for the given member. Unlike the key in the member
520   /// map, the index starts from 0.
getIndex(const InstTy * Instr)521   uint32_t getIndex(const InstTy *Instr) const {
522     for (auto I : Members) {
523       if (I.second == Instr)
524         return I.first - SmallestKey;
525     }
526 
527     llvm_unreachable("InterleaveGroup contains no such member");
528   }
529 
getInsertPos()530   InstTy *getInsertPos() const { return InsertPos; }
setInsertPos(InstTy * Inst)531   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
532 
533   /// Add metadata (e.g. alias info) from the instructions in this group to \p
534   /// NewInst.
535   ///
536   /// FIXME: this function currently does not add noalias metadata a'la
537   /// addNewMedata.  To do that we need to compute the intersection of the
538   /// noalias info from all members.
539   void addMetadata(InstTy *NewInst) const;
540 
541   /// Returns true if this Group requires a scalar iteration to handle gaps.
requiresScalarEpilogue()542   bool requiresScalarEpilogue() const {
543     // If the last member of the Group exists, then a scalar epilog is not
544     // needed for this group.
545     if (getMember(getFactor() - 1))
546       return false;
547 
548     // We have a group with gaps. It therefore can't be a reversed access,
549     // because such groups get invalidated (TODO).
550     assert(!isReverse() && "Group should have been invalidated");
551 
552     // This is a group of loads, with gaps, and without a last-member
553     return true;
554   }
555 
556 private:
557   uint32_t Factor; // Interleave Factor.
558   bool Reverse;
559   Align Alignment;
560   DenseMap<int32_t, InstTy *> Members;
561   int32_t SmallestKey = 0;
562   int32_t LargestKey = 0;
563 
564   // To avoid breaking dependences, vectorized instructions of an interleave
565   // group should be inserted at either the first load or the last store in
566   // program order.
567   //
568   // E.g. %even = load i32             // Insert Position
569   //      %add = add i32 %even         // Use of %even
570   //      %odd = load i32
571   //
572   //      store i32 %even
573   //      %odd = add i32               // Def of %odd
574   //      store i32 %odd               // Insert Position
575   InstTy *InsertPos;
576 };
577 
578 /// Drive the analysis of interleaved memory accesses in the loop.
579 ///
580 /// Use this class to analyze interleaved accesses only when we can vectorize
581 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
582 /// on interleaved accesses is unsafe.
583 ///
584 /// The analysis collects interleave groups and records the relationships
585 /// between the member and the group in a map.
586 class InterleavedAccessInfo {
587 public:
InterleavedAccessInfo(PredicatedScalarEvolution & PSE,Loop * L,DominatorTree * DT,LoopInfo * LI,const LoopAccessInfo * LAI)588   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
589                         DominatorTree *DT, LoopInfo *LI,
590                         const LoopAccessInfo *LAI)
591       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
592 
~InterleavedAccessInfo()593   ~InterleavedAccessInfo() { invalidateGroups(); }
594 
  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
599   void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
600 
601   /// Invalidate groups, e.g., in case all blocks in loop will be predicated
602   /// contrary to original assumption. Although we currently prevent group
603   /// formation for predicated accesses, we may be able to relax this limitation
604   /// in the future once we handle more complicated blocks. Returns true if any
605   /// groups were invalidated.
invalidateGroups()606   bool invalidateGroups() {
607     if (InterleaveGroups.empty()) {
608       assert(
609           !RequiresScalarEpilogue &&
610           "RequiresScalarEpilog should not be set without interleave groups");
611       return false;
612     }
613 
614     InterleaveGroupMap.clear();
615     for (auto *Ptr : InterleaveGroups)
616       delete Ptr;
617     InterleaveGroups.clear();
618     RequiresScalarEpilogue = false;
619     return true;
620   }
621 
622   /// Check if \p Instr belongs to any interleave group.
isInterleaved(Instruction * Instr)623   bool isInterleaved(Instruction *Instr) const {
624     return InterleaveGroupMap.contains(Instr);
625   }
626 
627   /// Get the interleave group that \p Instr belongs to.
628   ///
629   /// \returns nullptr if doesn't have such group.
630   InterleaveGroup<Instruction> *
getInterleaveGroup(const Instruction * Instr)631   getInterleaveGroup(const Instruction *Instr) const {
632     return InterleaveGroupMap.lookup(Instr);
633   }
634 
635   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
getInterleaveGroups()636   getInterleaveGroups() {
637     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
638   }
639 
640   /// Returns true if an interleaved group that may access memory
641   /// out-of-bounds requires a scalar epilogue iteration for correctness.
requiresScalarEpilogue()642   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
643 
644   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
645   /// happen when optimizing for size forbids a scalar epilogue, and the gap
646   /// cannot be filtered by masking the load/store.
647   void invalidateGroupsRequiringScalarEpilogue();
648 
649   /// Returns true if we have any interleave groups.
hasGroups()650   bool hasGroups() const { return !InterleaveGroups.empty(); }
651 
652 private:
653   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
654   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
655   /// The interleaved access analysis can also add new predicates (for example
656   /// by versioning strides of pointers).
657   PredicatedScalarEvolution &PSE;
658 
659   Loop *TheLoop;
660   DominatorTree *DT;
661   LoopInfo *LI;
662   const LoopAccessInfo *LAI;
663 
664   /// True if the loop may contain non-reversed interleaved groups with
665   /// out-of-bounds accesses. We ensure we don't speculatively access memory
666   /// out-of-bounds by executing at least one scalar epilogue iteration.
667   bool RequiresScalarEpilogue = false;
668 
669   /// Holds the relationships between the members and the interleave group.
670   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
671 
672   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
673 
674   /// Holds dependences among the memory accesses in the loop. It maps a source
675   /// access to a set of dependent sink accesses.
676   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
677 
678   /// The descriptor for a strided memory access.
679   struct StrideDescriptor {
680     StrideDescriptor() = default;
StrideDescriptorStrideDescriptor681     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
682                      Align Alignment)
683         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
684 
685     // The access's stride. It is negative for a reverse access.
686     int64_t Stride = 0;
687 
688     // The scalar expression of this access.
689     const SCEV *Scev = nullptr;
690 
691     // The size of the memory object.
692     uint64_t Size = 0;
693 
694     // The alignment of this access.
695     Align Alignment;
696   };
697 
698   /// A type for holding instructions and their stride descriptors.
699   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
700 
701   /// Create a new interleave group with the given instruction \p Instr,
702   /// stride \p Stride and alignment \p Align.
703   ///
704   /// \returns the newly created interleave group.
705   InterleaveGroup<Instruction> *
createInterleaveGroup(Instruction * Instr,int Stride,Align Alignment)706   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
707     assert(!InterleaveGroupMap.count(Instr) &&
708            "Already in an interleaved access group");
709     InterleaveGroupMap[Instr] =
710         new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
711     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
712     return InterleaveGroupMap[Instr];
713   }
714 
715   /// Release the group and remove all the relationships.
releaseGroup(InterleaveGroup<Instruction> * Group)716   void releaseGroup(InterleaveGroup<Instruction> *Group) {
717     for (unsigned i = 0; i < Group->getFactor(); i++)
718       if (Instruction *Member = Group->getMember(i))
719         InterleaveGroupMap.erase(Member);
720 
721     InterleaveGroups.erase(Group);
722     delete Group;
723   }
724 
  /// Collect all the accesses with a constant stride in program order.
  ///
  /// \p AccessStrideInfo is taken by non-const reference and is presumably
  /// the output map (access instruction -> StrideDescriptor); MapVector keeps
  /// insertion order, matching the "program order" above. \p Strides is
  /// presumably the map of symbolic strides known to LoopAccessInfo — TODO
  /// confirm against the out-of-line definition.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const DenseMap<Value *, const SCEV *> &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  ///
  /// NOTE(review): the parameter is `int`, whereas StrideDescriptor::Stride is
  /// `int64_t`, so passing a descriptor's stride narrows it implicitly.
  static bool isStrided(int Stride);
732 
733   /// Returns true if \p BB is a predicated block.
isPredicated(BasicBlock * BB)734   bool isPredicated(BasicBlock *BB) const {
735     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
736   }
737 
738   /// Returns true if LoopAccessInfo can be used for dependence queries.
areDependencesValid()739   bool areDependencesValid() const {
740     return LAI && LAI->getDepChecker().getDependences();
741   }
742 
743   /// Returns true if memory accesses \p A and \p B can be reordered, if
744   /// necessary, when constructing interleaved groups.
745   ///
746   /// \p A must precede \p B in program order. We return false if reordering is
747   /// not necessary or is prevented because \p A and \p B may be dependent.
canReorderMemAccessesForInterleavedGroups(StrideEntry * A,StrideEntry * B)748   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
749                                                  StrideEntry *B) const {
750     // Code motion for interleaved accesses can potentially hoist strided loads
751     // and sink strided stores. The code below checks the legality of the
752     // following two conditions:
753     //
754     // 1. Potentially moving a strided load (B) before any store (A) that
755     //    precedes B, or
756     //
757     // 2. Potentially moving a strided store (A) after any load or store (B)
758     //    that A precedes.
759     //
760     // It's legal to reorder A and B if we know there isn't a dependence from A
761     // to B. Note that this determination is conservative since some
762     // dependences could potentially be reordered safely.
763 
764     // A is potentially the source of a dependence.
765     auto *Src = A->first;
766     auto SrcDes = A->second;
767 
768     // B is potentially the sink of a dependence.
769     auto *Sink = B->first;
770     auto SinkDes = B->second;
771 
772     // Code motion for interleaved accesses can't violate WAR dependences.
773     // Thus, reordering is legal if the source isn't a write.
774     if (!Src->mayWriteToMemory())
775       return true;
776 
777     // At least one of the accesses must be strided.
778     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
779       return true;
780 
781     // If dependence information is not available from LoopAccessInfo,
782     // conservatively assume the instructions can't be reordered.
783     if (!areDependencesValid())
784       return false;
785 
786     // If we know there is a dependence from source to sink, assume the
787     // instructions can't be reordered. Otherwise, reordering is legal.
788     return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
789   }
790 
791   /// Collect the dependences from LoopAccessInfo.
792   ///
793   /// We process the dependences once during the interleaved access analysis to
794   /// enable constant-time dependence queries.
collectDependences()795   void collectDependences() {
796     if (!areDependencesValid())
797       return;
798     auto *Deps = LAI->getDepChecker().getDependences();
799     for (auto Dep : *Deps)
800       Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
801   }
802 };
803 
804 } // llvm namespace
805 
806 #endif
807