1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines some vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
15 
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/Support/CheckedArithmetic.h"
21 
22 namespace llvm {
23 
24 /// Describes the type of Parameters
enum class VFParamKind {
  // The comment next to each OMP_* enumerator shows the OpenMP
  // `declare simd` clause that the kind corresponds to.
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};
42 
43 /// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};
55 
56 /// Encapsulates information needed to describe a parameter.
57 ///
58 /// The description of the parameter is not linked directly to
59 /// OpenMP or any other vector function description. This structure
60 /// is extendible to handle other paradigms that describe vector
61 /// functions and their parameters.
62 struct VFParameter {
63   unsigned ParamPos;         // Parameter Position in Scalar Function.
64   VFParamKind ParamKind;     // Kind of Parameter.
65   int LinearStepOrPos = 0;   // Step or Position of the Parameter.
66   Align Alignment = Align(); // Optional aligment in bytes, defaulted to 1.
67 
68   // Comparison operator.
69   bool operator==(const VFParameter &Other) const {
70     return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
71            std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
72                     Other.Alignment);
73   }
74 };
75 
76 /// Contains the information about the kind of vectorization
77 /// available.
78 ///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
82 struct VFShape {
83   unsigned VF;     // Vectorization factor.
84   bool IsScalable; // True if the function is a scalable function.
85   SmallVector<VFParameter, 8> Parameters; // List of parameter informations.
86   // Comparison operator.
87   bool operator==(const VFShape &Other) const {
88     return std::tie(VF, IsScalable, Parameters) ==
89            std::tie(Other.VF, Other.IsScalable, Other.Parameters);
90   }
91 
92   /// Update the parameter in position P.ParamPos to P.
93   void updateParam(VFParameter P) {
94     assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
95     Parameters[P.ParamPos] = P;
96     assert(hasValidParameterList() && "Invalid parameter list");
97   }
98 
99   // Retrieve the basic vectorization shape of the function, where all
100   // parameters are mapped to VFParamKind::Vector with \p EC
101   // lanes. Specifies whether the function has a Global Predicate
102   // argument via \p HasGlobalPred.
103   static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
104     SmallVector<VFParameter, 8> Parameters;
105     for (unsigned I = 0; I < CI.arg_size(); ++I)
106       Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
107     if (HasGlobalPred)
108       Parameters.push_back(
109           VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
110 
111     return {EC.Min, EC.Scalable, Parameters};
112   }
113   /// Sanity check on the Parameters in the VFShape.
114   bool hasValidParameterList() const;
115 };
116 
117 /// Holds the VFShape for a specific scalar to vector function mapping.
118 struct VFInfo {
119   VFShape Shape;        // Classification of the vector function.
120   StringRef ScalarName; // Scalar Function Name.
121   StringRef VectorName; // Vector Function Name associated to this VFInfo.
122   VFISAKind ISA;        // Instruction Set Architecture.
123 
124   // Comparison operator.
125   bool operator==(const VFInfo &Other) const {
126     return std::tie(Shape, ScalarName, VectorName, ISA) ==
127            std::tie(Shape, Other.ScalarName, Other.VectorName, Other.ISA);
128   }
129 };
130 
namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
/// <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///  https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName);

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. The names are written to
/// \p VariantMappings.
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI
169 
170 template <typename T> class ArrayRef;
171 class DemandedBits;
172 class GetElementPtrInst;
173 template <typename InstTy> class InterleaveGroup;
174 class Loop;
175 class ScalarEvolution;
176 class TargetTransformInfo;
177 class Type;
178 class Value;
179 
namespace Intrinsic {
// Intrinsic identifiers are represented as plain unsigned integers here.
using ID = unsigned;
} // end namespace Intrinsic
183 
184 /// Identify if the intrinsic is trivially vectorizable.
185 /// This method returns true if the intrinsic's argument types are all scalars
186 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
187 /// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic.
188 bool isTriviallyVectorizable(Intrinsic::ID ID);
189 
190 /// Identifies if the vector form of the intrinsic has a scalar operand.
191 bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx);
192 
/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
196 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
197                                           const TargetLibraryInfo *TLI);
198 
199 /// Find the operand of the GEP that should be checked for consecutive
200 /// stores. This ignores trailing indices that have no effect on the final
201 /// pointer.
202 unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);
203 
204 /// If the argument is a GEP, then returns the operand identified by
205 /// getGEPInductionOperand. However, if there is some other non-loop-invariant
206 /// operand, it returns that instead.
207 Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
208 
209 /// If a value has only one user that is a CastInst, return it.
210 Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);
211 
212 /// Get the stride of a pointer access in a loop. Looks for symbolic
213 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
214 Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
215 
216 /// Given a vector and an element number, see if the scalar value is
217 /// already around as a register, for example if it were inserted then extracted
218 /// from the vector.
219 Value *findScalarElement(Value *V, unsigned EltNo);
220 
221 /// Get splat value if the input is a splat vector or return nullptr.
222 /// The value may be extracted from a splat constants vector or from
223 /// a sequence of instructions that broadcast a single value into a vector.
224 const Value *getSplatValue(const Value *V);
225 
226 /// Return true if the input value is known to be a vector with all identical
227 /// elements (potentially including undefined elements).
228 /// This may be more powerful than the related getSplatValue() because it is
229 /// not limited by finding a scalar source value to a splatted vector.
230 bool isSplatValue(const Value *V, unsigned Depth = 0);
231 
232 /// Compute a map of integer instructions to their minimum legal type
233 /// size.
234 ///
235 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
236 /// type (e.g. i32) whenever arithmetic is performed on them.
237 ///
238 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
239 /// the arithmetic type down again. However InstCombine refuses to create
240 /// illegal types, so for targets without i8 or i16 registers, the lengthening
241 /// and shrinking remains.
242 ///
243 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
244 /// their scalar equivalents do not, so during vectorization it is important to
245 /// remove these lengthens and truncates when deciding the profitability of
246 /// vectorization.
247 ///
248 /// This function analyzes the given range of instructions and determines the
249 /// minimum type size each can be converted to. It attempts to remove or
250 /// minimize type size changes across each def-use chain, so for example in the
251 /// following code:
252 ///
253 ///   %1 = load i8, i8*
254 ///   %2 = add i8 %1, 2
255 ///   %3 = load i16, i16*
256 ///   %4 = zext i8 %2 to i32
257 ///   %5 = zext i16 %3 to i32
258 ///   %6 = add i32 %4, %5
259 ///   %7 = trunc i32 %6 to i16
260 ///
261 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
262 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
263 ///
264 /// If the optional TargetTransformInfo is provided, this function tries harder
265 /// to do less work by only looking at illegal types.
266 MapVector<Instruction*, uint64_t>
267 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
268                          DemandedBits &DB,
269                          const TargetTransformInfo *TTI=nullptr);
270 
271 /// Compute the union of two access-group lists.
272 ///
273 /// If the list contains just one access group, it is returned directly. If the
274 /// list is empty, returns nullptr.
275 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
276 
277 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
278 /// are both in. If either instruction does not access memory at all, it is
279 /// considered to be in every list.
280 ///
281 /// If the list contains just one access group, it is returned directly. If the
282 /// list is empty, returns nullptr.
283 MDNode *intersectAccessGroups(const Instruction *Inst1,
284                               const Instruction *Inst2);
285 
286 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
287 /// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
292 ///
293 /// This function always sets a (possibly null) value for each K in Kinds.
294 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
295 
296 /// Create a mask that filters the members of an interleave group where there
297 /// are gaps.
298 ///
299 /// For example, the mask for \p Group with interleave-factor 3
300 /// and \p VF 4, that has only its first member present is:
301 ///
302 ///   <1,0,0,1,0,0,1,0,0,1,0,0>
303 ///
304 /// Note: The result is a mask of 0's and 1's, as opposed to the other
305 /// create[*]Mask() utilities which create a shuffle mask (mask that
306 /// consists of indices).
307 Constant *createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF,
308                                const InterleaveGroup<Instruction> &Group);
309 
310 /// Create a mask with replicated elements.
311 ///
312 /// This function creates a shuffle mask for replicating each of the \p VF
313 /// elements in a vector \p ReplicationFactor times. It can be used to
314 /// transform a mask of \p VF elements into a mask of
315 /// \p VF * \p ReplicationFactor elements used by a predicated
316 /// interleaved-group of loads/stores whose Interleaved-factor ==
317 /// \p ReplicationFactor.
318 ///
319 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
320 ///
321 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
322 Constant *createReplicatedMask(IRBuilder<> &Builder, unsigned ReplicationFactor,
323                                unsigned VF);
324 
325 /// Create an interleave shuffle mask.
326 ///
327 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
328 /// vectorization factor \p VF into a single wide vector. The mask is of the
329 /// form:
330 ///
331 ///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
332 ///
333 /// For example, the mask for VF = 4 and NumVecs = 2 is:
334 ///
335 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
336 Constant *createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
337                                unsigned NumVecs);
338 
339 /// Create a stride shuffle mask.
340 ///
341 /// This function creates a shuffle mask whose elements begin at \p Start and
342 /// are incremented by \p Stride. The mask can be used to deinterleave an
343 /// interleaved vector into separate vectors of vectorization factor \p VF. The
344 /// mask is of the form:
345 ///
346 ///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
347 ///
348 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
349 ///
350 ///   <0, 2, 4, 6>
351 Constant *createStrideMask(IRBuilder<> &Builder, unsigned Start,
352                            unsigned Stride, unsigned VF);
353 
354 /// Create a sequential shuffle mask.
355 ///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start.  The mask contains \p NumInts integers and is padded
/// with \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
363 ///
364 ///   <0, 1, 2, 3, undef, undef, undef, undef>
365 Constant *createSequentialMask(IRBuilder<> &Builder, unsigned Start,
366                                unsigned NumInts, unsigned NumUndefs);
367 
368 /// Concatenate a list of vectors.
369 ///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one, and
372 /// their element types should be the same. The number of elements in the
373 /// vectors should also be the same; however, if the last vector has fewer
374 /// elements, it will be padded with undefs.
375 Value *concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs);
376 
/// Given a mask vector of the form <Y x i1>, return true if all of the
378 /// elements of this predicate mask are false or undef.  That is, return true
379 /// if all lanes can be assumed inactive.
380 bool maskIsAllZeroOrUndef(Value *Mask);
381 
/// Given a mask vector of the form <Y x i1>, return true if all of the
383 /// elements of this predicate mask are true or undef.  That is, return true
384 /// if all lanes can be assumed active.
385 bool maskIsAllOneOrUndef(Value *Mask);
386 
387 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
388 /// for each lane which may be active.
389 APInt possiblyDemandedEltsInMask(Value *Mask);
390 
391 /// The group of interleaved loads/stores sharing the same stride and
392 /// close to each other.
393 ///
394 /// Each member in this group has an index starting from 0, and the largest
395 /// index should be less than interleaved factor, which is equal to the absolute
396 /// value of the access's stride.
397 ///
398 /// E.g. An interleaved load group of factor 4:
399 ///        for (unsigned i = 0; i < 1024; i+=4) {
400 ///          a = A[i];                           // Member of index 0
401 ///          b = A[i+1];                         // Member of index 1
402 ///          d = A[i+3];                         // Member of index 3
403 ///          ...
404 ///        }
405 ///
406 ///      An interleaved store group of factor 4:
407 ///        for (unsigned i = 0; i < 1024; i+=4) {
408 ///          ...
409 ///          A[i]   = a;                         // Member of index 0
410 ///          A[i+1] = b;                         // Member of index 1
411 ///          A[i+2] = c;                         // Member of index 2
412 ///          A[i+3] = d;                         // Member of index 3
413 ///        }
414 ///
415 /// Note: the interleaved load group could have gaps (missing members), but
416 /// the interleaved store group doesn't allow gaps.
417 template <typename InstTy> class InterleaveGroup {
418 public:
419   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
420       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
421         InsertPos(nullptr) {}
422 
423   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
424       : Alignment(Alignment), InsertPos(Instr) {
425     Factor = std::abs(Stride);
426     assert(Factor > 1 && "Invalid interleave factor");
427 
428     Reverse = Stride < 0;
429     Members[0] = Instr;
430   }
431 
432   bool isReverse() const { return Reverse; }
433   uint32_t getFactor() const { return Factor; }
434   uint32_t getAlignment() const { return Alignment.value(); }
435   uint32_t getNumMembers() const { return Members.size(); }
436 
437   /// Try to insert a new member \p Instr with index \p Index and
438   /// alignment \p NewAlign. The index is related to the leader and it could be
439   /// negative if it is the new leader.
440   ///
441   /// \returns false if the instruction doesn't belong to the group.
442   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
443     // Make sure the key fits in an int32_t.
444     Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
445     if (!MaybeKey)
446       return false;
447     int32_t Key = *MaybeKey;
448 
449     // Skip if there is already a member with the same index.
450     if (Members.find(Key) != Members.end())
451       return false;
452 
453     if (Key > LargestKey) {
454       // The largest index is always less than the interleave factor.
455       if (Index >= static_cast<int32_t>(Factor))
456         return false;
457 
458       LargestKey = Key;
459     } else if (Key < SmallestKey) {
460 
461       // Make sure the largest index fits in an int32_t.
462       Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
463       if (!MaybeLargestIndex)
464         return false;
465 
466       // The largest index is always less than the interleave factor.
467       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
468         return false;
469 
470       SmallestKey = Key;
471     }
472 
473     // It's always safe to select the minimum alignment.
474     Alignment = std::min(Alignment, NewAlign);
475     Members[Key] = Instr;
476     return true;
477   }
478 
479   /// Get the member with the given index \p Index
480   ///
481   /// \returns nullptr if contains no such member.
482   InstTy *getMember(uint32_t Index) const {
483     int32_t Key = SmallestKey + Index;
484     auto Member = Members.find(Key);
485     if (Member == Members.end())
486       return nullptr;
487 
488     return Member->second;
489   }
490 
491   /// Get the index for the given member. Unlike the key in the member
492   /// map, the index starts from 0.
493   uint32_t getIndex(const InstTy *Instr) const {
494     for (auto I : Members) {
495       if (I.second == Instr)
496         return I.first - SmallestKey;
497     }
498 
499     llvm_unreachable("InterleaveGroup contains no such member");
500   }
501 
502   InstTy *getInsertPos() const { return InsertPos; }
503   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
504 
505   /// Add metadata (e.g. alias info) from the instructions in this group to \p
506   /// NewInst.
507   ///
508   /// FIXME: this function currently does not add noalias metadata a'la
509   /// addNewMedata.  To do that we need to compute the intersection of the
510   /// noalias info from all members.
511   void addMetadata(InstTy *NewInst) const;
512 
513   /// Returns true if this Group requires a scalar iteration to handle gaps.
514   bool requiresScalarEpilogue() const {
515     // If the last member of the Group exists, then a scalar epilog is not
516     // needed for this group.
517     if (getMember(getFactor() - 1))
518       return false;
519 
520     // We have a group with gaps. It therefore cannot be a group of stores,
521     // and it can't be a reversed access, because such groups get invalidated.
522     assert(!getMember(0)->mayWriteToMemory() &&
523            "Group should have been invalidated");
524     assert(!isReverse() && "Group should have been invalidated");
525 
526     // This is a group of loads, with gaps, and without a last-member
527     return true;
528   }
529 
530 private:
531   uint32_t Factor; // Interleave Factor.
532   bool Reverse;
533   Align Alignment;
534   DenseMap<int32_t, InstTy *> Members;
535   int32_t SmallestKey = 0;
536   int32_t LargestKey = 0;
537 
538   // To avoid breaking dependences, vectorized instructions of an interleave
539   // group should be inserted at either the first load or the last store in
540   // program order.
541   //
542   // E.g. %even = load i32             // Insert Position
543   //      %add = add i32 %even         // Use of %even
544   //      %odd = load i32
545   //
546   //      store i32 %even
547   //      %odd = add i32               // Def of %odd
548   //      store i32 %odd               // Insert Position
549   InstTy *InsertPos;
550 };
551 
552 /// Drive the analysis of interleaved memory accesses in the loop.
553 ///
554 /// Use this class to analyze interleaved accesses only when we can vectorize
555 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
556 /// on interleaved accesses is unsafe.
557 ///
558 /// The analysis collects interleave groups and records the relationships
559 /// between the member and the group in a map.
560 class InterleavedAccessInfo {
561 public:
562   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
563                         DominatorTree *DT, LoopInfo *LI,
564                         const LoopAccessInfo *LAI)
565       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
566 
567   ~InterleavedAccessInfo() { reset(); }
568 
569   /// Analyze the interleaved accesses and collect them in interleave
570   /// groups. Substitute symbolic strides using \p Strides.
571   /// Consider also predicated loads/stores in the analysis if
572   /// \p EnableMaskedInterleavedGroup is true.
573   void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
574 
575   /// Invalidate groups, e.g., in case all blocks in loop will be predicated
576   /// contrary to original assumption. Although we currently prevent group
577   /// formation for predicated accesses, we may be able to relax this limitation
578   /// in the future once we handle more complicated blocks.
579   void reset() {
580     InterleaveGroupMap.clear();
581     for (auto *Ptr : InterleaveGroups)
582       delete Ptr;
583     InterleaveGroups.clear();
584     RequiresScalarEpilogue = false;
585   }
586 
587 
588   /// Check if \p Instr belongs to any interleave group.
589   bool isInterleaved(Instruction *Instr) const {
590     return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
591   }
592 
593   /// Get the interleave group that \p Instr belongs to.
594   ///
595   /// \returns nullptr if doesn't have such group.
596   InterleaveGroup<Instruction> *
597   getInterleaveGroup(const Instruction *Instr) const {
598     if (InterleaveGroupMap.count(Instr))
599       return InterleaveGroupMap.find(Instr)->second;
600     return nullptr;
601   }
602 
603   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
604   getInterleaveGroups() {
605     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
606   }
607 
608   /// Returns true if an interleaved group that may access memory
609   /// out-of-bounds requires a scalar epilogue iteration for correctness.
610   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
611 
612   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
613   /// happen when optimizing for size forbids a scalar epilogue, and the gap
614   /// cannot be filtered by masking the load/store.
615   void invalidateGroupsRequiringScalarEpilogue();
616 
617 private:
618   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
619   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
620   /// The interleaved access analysis can also add new predicates (for example
621   /// by versioning strides of pointers).
622   PredicatedScalarEvolution &PSE;
623 
624   Loop *TheLoop;
625   DominatorTree *DT;
626   LoopInfo *LI;
627   const LoopAccessInfo *LAI;
628 
629   /// True if the loop may contain non-reversed interleaved groups with
630   /// out-of-bounds accesses. We ensure we don't speculatively access memory
631   /// out-of-bounds by executing at least one scalar epilogue iteration.
632   bool RequiresScalarEpilogue = false;
633 
634   /// Holds the relationships between the members and the interleave group.
635   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
636 
637   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
638 
639   /// Holds dependences among the memory accesses in the loop. It maps a source
640   /// access to a set of dependent sink accesses.
641   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
642 
643   /// The descriptor for a strided memory access.
644   struct StrideDescriptor {
645     StrideDescriptor() = default;
646     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
647                      Align Alignment)
648         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
649 
650     // The access's stride. It is negative for a reverse access.
651     int64_t Stride = 0;
652 
653     // The scalar expression of this access.
654     const SCEV *Scev = nullptr;
655 
656     // The size of the memory object.
657     uint64_t Size = 0;
658 
659     // The alignment of this access.
660     Align Alignment;
661   };
662 
663   /// A type for holding instructions and their stride descriptors.
664   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
665 
666   /// Create a new interleave group with the given instruction \p Instr,
667   /// stride \p Stride and alignment \p Align.
668   ///
669   /// \returns the newly created interleave group.
670   InterleaveGroup<Instruction> *
671   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
672     assert(!InterleaveGroupMap.count(Instr) &&
673            "Already in an interleaved access group");
674     InterleaveGroupMap[Instr] =
675         new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
676     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
677     return InterleaveGroupMap[Instr];
678   }
679 
680   /// Release the group and remove all the relationships.
681   void releaseGroup(InterleaveGroup<Instruction> *Group) {
682     for (unsigned i = 0; i < Group->getFactor(); i++)
683       if (Instruction *Member = Group->getMember(i))
684         InterleaveGroupMap.erase(Member);
685 
686     InterleaveGroups.erase(Group);
687     delete Group;
688   }
689 
690   /// Collect all the accesses with a constant stride in program order.
691   void collectConstStrideAccesses(
692       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
693       const ValueToValueMap &Strides);
694 
695   /// Returns true if \p Stride is allowed in an interleaved group.
696   static bool isStrided(int Stride);
697 
698   /// Returns true if \p BB is a predicated block.
699   bool isPredicated(BasicBlock *BB) const {
700     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
701   }
702 
703   /// Returns true if LoopAccessInfo can be used for dependence queries.
704   bool areDependencesValid() const {
705     return LAI && LAI->getDepChecker().getDependences();
706   }
707 
708   /// Returns true if memory accesses \p A and \p B can be reordered, if
709   /// necessary, when constructing interleaved groups.
710   ///
711   /// \p A must precede \p B in program order. We return false if reordering is
712   /// not necessary or is prevented because \p A and \p B may be dependent.
713   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
714                                                  StrideEntry *B) const {
715     // Code motion for interleaved accesses can potentially hoist strided loads
716     // and sink strided stores. The code below checks the legality of the
717     // following two conditions:
718     //
719     // 1. Potentially moving a strided load (B) before any store (A) that
720     //    precedes B, or
721     //
722     // 2. Potentially moving a strided store (A) after any load or store (B)
723     //    that A precedes.
724     //
725     // It's legal to reorder A and B if we know there isn't a dependence from A
726     // to B. Note that this determination is conservative since some
727     // dependences could potentially be reordered safely.
728 
729     // A is potentially the source of a dependence.
730     auto *Src = A->first;
731     auto SrcDes = A->second;
732 
733     // B is potentially the sink of a dependence.
734     auto *Sink = B->first;
735     auto SinkDes = B->second;
736 
737     // Code motion for interleaved accesses can't violate WAR dependences.
738     // Thus, reordering is legal if the source isn't a write.
739     if (!Src->mayWriteToMemory())
740       return true;
741 
742     // At least one of the accesses must be strided.
743     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
744       return true;
745 
746     // If dependence information is not available from LoopAccessInfo,
747     // conservatively assume the instructions can't be reordered.
748     if (!areDependencesValid())
749       return false;
750 
751     // If we know there is a dependence from source to sink, assume the
752     // instructions can't be reordered. Otherwise, reordering is legal.
753     return Dependences.find(Src) == Dependences.end() ||
754            !Dependences.lookup(Src).count(Sink);
755   }
756 
757   /// Collect the dependences from LoopAccessInfo.
758   ///
759   /// We process the dependences once during the interleaved access analysis to
760   /// enable constant-time dependence queries.
761   void collectDependences() {
762     if (!areDependencesValid())
763       return;
764     auto *Deps = LAI->getDepChecker().getDependences();
765     for (auto Dep : *Deps)
766       Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
767   }
768 };
769 
770 } // llvm namespace
771 
772 #endif
773