1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines some vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
15 
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/Support/CheckedArithmetic.h"
20 
21 namespace llvm {
22 class TargetLibraryInfo;
23 
/// Describes the kind of a parameter of a vector function variant.
///
/// The OMP_* enumerators correspond to the OpenMP `declare simd` clauses noted
/// next to each of them; `Vector` carries no additional semantic information.
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};
42 
/// Describes the Instruction Set Architecture targeted by a vector function
/// variant.
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown // Unknown ISA
};
55 
56 /// Encapsulates information needed to describe a parameter.
57 ///
58 /// The description of the parameter is not linked directly to
59 /// OpenMP or any other vector function description. This structure
60 /// is extendible to handle other paradigms that describe vector
61 /// functions and their parameters.
62 struct VFParameter {
63   unsigned ParamPos;         // Parameter Position in Scalar Function.
64   VFParamKind ParamKind;     // Kind of Parameter.
65   int LinearStepOrPos = 0;   // Step or Position of the Parameter.
66   Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.
67 
68   // Comparison operator.
69   bool operator==(const VFParameter &Other) const {
70     return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
71            std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
72                     Other.Alignment);
73   }
74 };
75 
76 /// Contains the information about the kind of vectorization
77 /// available.
78 ///
79 /// This object in independent on the paradigm used to
80 /// represent vector functions. in particular, it is not attached to
81 /// any target-specific ABI.
82 struct VFShape {
83   ElementCount VF;                        // Vectorization factor.
84   SmallVector<VFParameter, 8> Parameters; // List of parameter information.
85   // Comparison operator.
86   bool operator==(const VFShape &Other) const {
87     return std::tie(VF, Parameters) == std::tie(Other.VF, Other.Parameters);
88   }
89 
90   /// Update the parameter in position P.ParamPos to P.
91   void updateParam(VFParameter P) {
92     assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
93     Parameters[P.ParamPos] = P;
94     assert(hasValidParameterList() && "Invalid parameter list");
95   }
96 
97   // Retrieve the VFShape that can be used to map a (scalar) function to itself,
98   // with VF = 1.
99   static VFShape getScalarShape(const CallInst &CI) {
100     return VFShape::get(CI, ElementCount::getFixed(1),
101                         /*HasGlobalPredicate*/ false);
102   }
103 
104   // Retrieve the basic vectorization shape of the function, where all
105   // parameters are mapped to VFParamKind::Vector with \p EC
106   // lanes. Specifies whether the function has a Global Predicate
107   // argument via \p HasGlobalPred.
108   static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
109     SmallVector<VFParameter, 8> Parameters;
110     for (unsigned I = 0; I < CI.arg_size(); ++I)
111       Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
112     if (HasGlobalPred)
113       Parameters.push_back(
114           VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
115 
116     return {EC, Parameters};
117   }
118   /// Validation check on the Parameters in the VFShape.
119   bool hasValidParameterList() const;
120 };
121 
122 /// Holds the VFShape for a specific scalar to vector function mapping.
123 struct VFInfo {
124   VFShape Shape;          /// Classification of the vector function.
125   std::string ScalarName; /// Scalar Function Name.
126   std::string VectorName; /// Vector Function Name associated to this VFInfo.
127   VFISAKind ISA;          /// Instruction Set Architecture.
128 
129   /// Returns the index of the first parameter with the kind 'GlobalPredicate',
130   /// if any exist.
131   std::optional<unsigned> getParamIndexForOptionalMask() const {
132     unsigned ParamCount = Shape.Parameters.size();
133     for (unsigned i = 0; i < ParamCount; ++i)
134       if (Shape.Parameters[i].ParamKind == VFParamKind::GlobalPredicate)
135         return i;
136 
137     return std::nullopt;
138   }
139 
140   /// Returns true if at least one of the operands to the vectorized function
141   /// has the kind 'GlobalPredicate'.
142   bool isMasked() const { return getParamIndexForOptionalMask().has_value(); }
143 };
144 
145 namespace VFABI {
146 /// LLVM Internal VFABI ISA token for vector functions.
147 static constexpr char const *_LLVM_ = "_LLVM_";
148 /// Prefix for internal name redirection for vector function that
149 /// tells the compiler to scalarize the call using the scalar name
150 /// of the function. For example, a mangled name like
151 /// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
152 /// vectorizer to vectorize the scalar call `foo`, and to scalarize
153 /// it once vectorization is done.
154 static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
155 
156 /// Function to construct a VFInfo out of a mangled names in the
157 /// following format:
158 ///
159 /// <VFABI_name>{(<redirection>)}
160 ///
161 /// where <VFABI_name> is the name of the vector function, mangled according
162 /// to the rules described in the Vector Function ABI of the target vector
163 /// extension (or <isa> from now on). The <VFABI_name> is in the following
164 /// format:
165 ///
166 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
167 ///
168 /// This methods support demangling rules for the following <isa>:
169 ///
170 /// * AArch64: https://developer.arm.com/docs/101129/latest
171 ///
172 /// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
173 ///  https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
174 ///
175 /// \param MangledName -> input string in the format
176 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
177 /// \param M -> Module used to retrieve informations about the vector
178 /// function that are not possible to retrieve from the mangled
179 /// name. At the moment, this parameter is needed only to retrieve the
180 /// Vectorization Factor of scalable vector functions from their
181 /// respective IR declarations.
182 std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
183                                           const Module &M);
184 
185 /// This routine mangles the given VectorName according to the LangRef
186 /// specification for vector-function-abi-variant attribute and is specific to
187 /// the TLI mappings. It is the responsibility of the caller to make sure that
188 /// this is only used if all parameters in the vector function are vector type.
189 /// This returned string holds scalar-to-vector mapping:
190 ///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
191 ///
192 /// where:
193 ///
194 /// <isa> = "_LLVM_"
195 /// <mask> = "M" if masked, "N" if no mask.
196 /// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
197 ///          field of the `VecDesc` struct. If the number of lanes is scalable
198 ///          then 'x' is printed instead.
199 /// <vparams> = "v", as many as are the numArgs.
200 /// <scalarname> = the name of the scalar function.
201 /// <vectorname> = the name of the vector function.
202 std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
203                                 unsigned numArgs, ElementCount VF,
204                                 bool Masked = false);
205 
206 /// Retrieve the `VFParamKind` from a string token.
207 VFParamKind getVFParamKindFromString(const StringRef Token);
208 
209 // Name of the attribute where the variant mappings are stored.
210 static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
211 
212 /// Populates a set of strings representing the Vector Function ABI variants
213 /// associated to the CallInst CI. If the CI does not contain the
214 /// vector-function-abi-variant attribute, we return without populating
215 /// VariantMappings, i.e. callers of getVectorVariantNames need not check for
216 /// the presence of the attribute (see InjectTLIMappings).
217 void getVectorVariantNames(const CallInst &CI,
218                            SmallVectorImpl<std::string> &VariantMappings);
219 } // end namespace VFABI
220 
221 /// The Vector Function Database.
222 ///
223 /// Helper class used to find the vector functions associated to a
224 /// scalar CallInst.
225 class VFDatabase {
226   /// The Module of the CallInst CI.
227   const Module *M;
228   /// The CallInst instance being queried for scalar to vector mappings.
229   const CallInst &CI;
230   /// List of vector functions descriptors associated to the call
231   /// instruction.
232   const SmallVector<VFInfo, 8> ScalarToVectorMappings;
233 
234   /// Retrieve the scalar-to-vector mappings associated to the rule of
235   /// a vector Function ABI.
236   static void getVFABIMappings(const CallInst &CI,
237                                SmallVectorImpl<VFInfo> &Mappings) {
238     if (!CI.getCalledFunction())
239       return;
240 
241     const StringRef ScalarName = CI.getCalledFunction()->getName();
242 
243     SmallVector<std::string, 8> ListOfStrings;
244     // The check for the vector-function-abi-variant attribute is done when
245     // retrieving the vector variant names here.
246     VFABI::getVectorVariantNames(CI, ListOfStrings);
247     if (ListOfStrings.empty())
248       return;
249     for (const auto &MangledName : ListOfStrings) {
250       const std::optional<VFInfo> Shape =
251           VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
252       // A match is found via scalar and vector names, and also by
253       // ensuring that the variant described in the attribute has a
254       // corresponding definition or declaration of the vector
255       // function in the Module M.
256       if (Shape && (Shape->ScalarName == ScalarName)) {
257         assert(CI.getModule()->getFunction(Shape->VectorName) &&
258                "Vector function is missing.");
259         Mappings.push_back(*Shape);
260       }
261     }
262   }
263 
264 public:
265   /// Retrieve all the VFInfo instances associated to the CallInst CI.
266   static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
267     SmallVector<VFInfo, 8> Ret;
268 
269     // Get mappings from the Vector Function ABI variants.
270     getVFABIMappings(CI, Ret);
271 
272     // Other non-VFABI variants should be retrieved here.
273 
274     return Ret;
275   }
276 
277   static bool hasMaskedVariant(const CallInst &CI,
278                                std::optional<ElementCount> VF = std::nullopt) {
279     // Check whether we have at least one masked vector version of a scalar
280     // function. If no VF is specified then we check for any masked variant,
281     // otherwise we look for one that matches the supplied VF.
282     auto Mappings = VFDatabase::getMappings(CI);
283     for (VFInfo Info : Mappings)
284       if (!VF || Info.Shape.VF == *VF)
285         if (Info.isMasked())
286           return true;
287 
288     return false;
289   }
290 
291   /// Constructor, requires a CallInst instance.
292   VFDatabase(CallInst &CI)
293       : M(CI.getModule()), CI(CI),
294         ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
295   /// \defgroup VFDatabase query interface.
296   ///
297   /// @{
298   /// Retrieve the Function with VFShape \p Shape.
299   Function *getVectorizedFunction(const VFShape &Shape) const {
300     if (Shape == VFShape::getScalarShape(CI))
301       return CI.getCalledFunction();
302 
303     for (const auto &Info : ScalarToVectorMappings)
304       if (Info.Shape == Shape)
305         return M->getFunction(Info.VectorName);
306 
307     return nullptr;
308   }
309   /// @}
310 };
311 
// Forward declarations for types referenced by the declarations below.
template <typename InstTy> class InterleaveGroup;
template <typename T> class ArrayRef;
class DemandedBits;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
/// Intrinsic identifiers are plain unsigned integers.
using ID = unsigned;
} // end namespace Intrinsic
325 
326 /// A helper function for converting Scalar types to vector types. If
327 /// the incoming type is void, we return void. If the EC represents a
328 /// scalar, we return the scalar type.
329 inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
330   if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
331     return Scalar;
332   return VectorType::get(Scalar, EC);
333 }
334 
335 inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
336   return ToVectorTy(Scalar, ElementCount::getFixed(VF));
337 }
338 
/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic \p ID has a scalar operand
/// at index \p ScalarOpdIdx.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Returns the intrinsic ID for a call.
/// For the input call instruction \p CI, finds the matching intrinsic and
/// returns its intrinsic ID; if none is found, returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);
358 
/// Given a vector \p V and an element number \p EltNo, see if the scalar value
/// is already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands; returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);
389 
/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) does the mask analysis
/// to identify which real registers are permuted. Then the function processes
/// the resulting register masks using the provided action items. If no input
/// register is defined, \p NoInputAction is used. If only 1 input register is
/// used, \p SingleInputAction is used; otherwise (two or more input registers)
/// \p ManyInputsAction is used to process the input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
/// \param NoInputAction Action invoked when no input register is defined.
/// \param SingleInputAction Action invoked when a single input register is
/// used; it receives a mask and two unsigned values (presumably register
/// indices -- NOTE(review): confirm against the implementation).
/// \param ManyInputsAction Action invoked when two or more input registers
/// are used; it has the same callback signature as \p SingleInputAction.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
444 
/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the instructions in the given basic blocks and
/// determines the minimum type size each can be converted to. It attempts to
/// remove or minimize type size changes across each def-use chain, so for
/// example in the following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
///
/// \param Blocks The basic blocks whose instructions are analyzed.
/// \param DB Demanded-bits analysis consulted by the computation.
/// \param TTI Optional target information; may be nullptr.
MapVector<Instruction*, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
                         DemandedBits &DB,
                         const TargetTransformInfo *TTI=nullptr);
483 
/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Propagate selected kinds of metadata from the values in \p VL to \p I.
///
/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
508 
/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start. The mask contains \p NumInts integers and is padded with
/// \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);
585 
/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// with a bit set for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
608 
609 /// The group of interleaved loads/stores sharing the same stride and
610 /// close to each other.
611 ///
612 /// Each member in this group has an index starting from 0, and the largest
613 /// index should be less than interleaved factor, which is equal to the absolute
614 /// value of the access's stride.
615 ///
616 /// E.g. An interleaved load group of factor 4:
617 ///        for (unsigned i = 0; i < 1024; i+=4) {
618 ///          a = A[i];                           // Member of index 0
619 ///          b = A[i+1];                         // Member of index 1
620 ///          d = A[i+3];                         // Member of index 3
621 ///          ...
622 ///        }
623 ///
624 ///      An interleaved store group of factor 4:
625 ///        for (unsigned i = 0; i < 1024; i+=4) {
626 ///          ...
627 ///          A[i]   = a;                         // Member of index 0
628 ///          A[i+1] = b;                         // Member of index 1
629 ///          A[i+2] = c;                         // Member of index 2
630 ///          A[i+3] = d;                         // Member of index 3
631 ///        }
632 ///
633 /// Note: the interleaved load group could have gaps (missing members), but
634 /// the interleaved store group doesn't allow gaps.
635 template <typename InstTy> class InterleaveGroup {
636 public:
637   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
638       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
639         InsertPos(nullptr) {}
640 
641   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
642       : Alignment(Alignment), InsertPos(Instr) {
643     Factor = std::abs(Stride);
644     assert(Factor > 1 && "Invalid interleave factor");
645 
646     Reverse = Stride < 0;
647     Members[0] = Instr;
648   }
649 
650   bool isReverse() const { return Reverse; }
651   uint32_t getFactor() const { return Factor; }
652   Align getAlign() const { return Alignment; }
653   uint32_t getNumMembers() const { return Members.size(); }
654 
655   /// Try to insert a new member \p Instr with index \p Index and
656   /// alignment \p NewAlign. The index is related to the leader and it could be
657   /// negative if it is the new leader.
658   ///
659   /// \returns false if the instruction doesn't belong to the group.
660   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
661     // Make sure the key fits in an int32_t.
662     std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
663     if (!MaybeKey)
664       return false;
665     int32_t Key = *MaybeKey;
666 
667     // Skip if the key is used for either the tombstone or empty special values.
668     if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
669         DenseMapInfo<int32_t>::getEmptyKey() == Key)
670       return false;
671 
672     // Skip if there is already a member with the same index.
673     if (Members.find(Key) != Members.end())
674       return false;
675 
676     if (Key > LargestKey) {
677       // The largest index is always less than the interleave factor.
678       if (Index >= static_cast<int32_t>(Factor))
679         return false;
680 
681       LargestKey = Key;
682     } else if (Key < SmallestKey) {
683 
684       // Make sure the largest index fits in an int32_t.
685       std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
686       if (!MaybeLargestIndex)
687         return false;
688 
689       // The largest index is always less than the interleave factor.
690       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
691         return false;
692 
693       SmallestKey = Key;
694     }
695 
696     // It's always safe to select the minimum alignment.
697     Alignment = std::min(Alignment, NewAlign);
698     Members[Key] = Instr;
699     return true;
700   }
701 
702   /// Get the member with the given index \p Index
703   ///
704   /// \returns nullptr if contains no such member.
705   InstTy *getMember(uint32_t Index) const {
706     int32_t Key = SmallestKey + Index;
707     return Members.lookup(Key);
708   }
709 
710   /// Get the index for the given member. Unlike the key in the member
711   /// map, the index starts from 0.
712   uint32_t getIndex(const InstTy *Instr) const {
713     for (auto I : Members) {
714       if (I.second == Instr)
715         return I.first - SmallestKey;
716     }
717 
718     llvm_unreachable("InterleaveGroup contains no such member");
719   }
720 
721   InstTy *getInsertPos() const { return InsertPos; }
722   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
723 
  /// Add metadata (e.g. alias info) from the instructions in this group to
  /// \p NewInst. Only declared here; the definition lives out of line.
  ///
  /// FIXME: unlike addNewMetadata, this function currently does not add
  /// noalias metadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;
731 
732   /// Returns true if this Group requires a scalar iteration to handle gaps.
733   bool requiresScalarEpilogue() const {
734     // If the last member of the Group exists, then a scalar epilog is not
735     // needed for this group.
736     if (getMember(getFactor() - 1))
737       return false;
738 
739     // We have a group with gaps. It therefore can't be a reversed access,
740     // because such groups get invalidated (TODO).
741     assert(!isReverse() && "Group should have been invalidated");
742 
743     // This is a group of loads, with gaps, and without a last-member
744     return true;
745   }
746 
private:
  // Interleave factor: the number of member slots in the group. The leader
  // constructor sets it to |Stride|.
  uint32_t Factor;
  // True for a reverse group; the leader constructor sets it when Stride < 0.
  bool Reverse;
  // Group alignment; insertMember() lowers it to the minimum alignment seen.
  Align Alignment;
  // Members keyed relative to the original leader (which is stored at key 0).
  // The member at zero-based position I lives under key SmallestKey + I.
  DenseMap<int32_t, InstTy *> Members;
  // Smallest and largest keys currently occupied; both start at 0 and are
  // widened by insertMember().
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
767 };
768 
769 /// Drive the analysis of interleaved memory accesses in the loop.
770 ///
771 /// Use this class to analyze interleaved accesses only when we can vectorize
772 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
773 /// on interleaved accesses is unsafe.
774 ///
775 /// The analysis collects interleave groups and records the relationships
776 /// between the member and the group in a map.
777 class InterleavedAccessInfo {
778 public:
779   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
780                         DominatorTree *DT, LoopInfo *LI,
781                         const LoopAccessInfo *LAI)
782       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
783 
784   ~InterleavedAccessInfo() { invalidateGroups(); }
785 
786   /// Analyze the interleaved accesses and collect them in interleave
787   /// groups. Substitute symbolic strides using \p Strides.
788   /// Consider also predicated loads/stores in the analysis if
789   /// \p EnableMaskedInterleavedGroup is true.
790   void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
791 
792   /// Invalidate groups, e.g., in case all blocks in loop will be predicated
793   /// contrary to original assumption. Although we currently prevent group
794   /// formation for predicated accesses, we may be able to relax this limitation
795   /// in the future once we handle more complicated blocks. Returns true if any
796   /// groups were invalidated.
797   bool invalidateGroups() {
798     if (InterleaveGroups.empty()) {
799       assert(
800           !RequiresScalarEpilogue &&
801           "RequiresScalarEpilog should not be set without interleave groups");
802       return false;
803     }
804 
805     InterleaveGroupMap.clear();
806     for (auto *Ptr : InterleaveGroups)
807       delete Ptr;
808     InterleaveGroups.clear();
809     RequiresScalarEpilogue = false;
810     return true;
811   }
812 
813   /// Check if \p Instr belongs to any interleave group.
814   bool isInterleaved(Instruction *Instr) const {
815     return InterleaveGroupMap.contains(Instr);
816   }
817 
818   /// Get the interleave group that \p Instr belongs to.
819   ///
820   /// \returns nullptr if doesn't have such group.
821   InterleaveGroup<Instruction> *
822   getInterleaveGroup(const Instruction *Instr) const {
823     return InterleaveGroupMap.lookup(Instr);
824   }
825 
826   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
827   getInterleaveGroups() {
828     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
829   }
830 
831   /// Returns true if an interleaved group that may access memory
832   /// out-of-bounds requires a scalar epilogue iteration for correctness.
833   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
834 
835   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
836   /// happen when optimizing for size forbids a scalar epilogue, and the gap
837   /// cannot be filtered by masking the load/store.
838   void invalidateGroupsRequiringScalarEpilogue();
839 
840   /// Returns true if we have any interleave groups.
841   bool hasGroups() const { return !InterleaveGroups.empty(); }
842 
843 private:
844   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
845   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
846   /// The interleaved access analysis can also add new predicates (for example
847   /// by versioning strides of pointers).
848   PredicatedScalarEvolution &PSE;
849 
850   Loop *TheLoop;
851   DominatorTree *DT;
852   LoopInfo *LI;
853   const LoopAccessInfo *LAI;
854 
855   /// True if the loop may contain non-reversed interleaved groups with
856   /// out-of-bounds accesses. We ensure we don't speculatively access memory
857   /// out-of-bounds by executing at least one scalar epilogue iteration.
858   bool RequiresScalarEpilogue = false;
859 
860   /// Holds the relationships between the members and the interleave group.
861   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
862 
863   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
864 
865   /// Holds dependences among the memory accesses in the loop. It maps a source
866   /// access to a set of dependent sink accesses.
867   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
868 
869   /// The descriptor for a strided memory access.
870   struct StrideDescriptor {
871     StrideDescriptor() = default;
872     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
873                      Align Alignment)
874         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
875 
876     // The access's stride. It is negative for a reverse access.
877     int64_t Stride = 0;
878 
879     // The scalar expression of this access.
880     const SCEV *Scev = nullptr;
881 
882     // The size of the memory object.
883     uint64_t Size = 0;
884 
885     // The alignment of this access.
886     Align Alignment;
887   };
888 
889   /// A type for holding instructions and their stride descriptors.
890   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
891 
892   /// Create a new interleave group with the given instruction \p Instr,
893   /// stride \p Stride and alignment \p Align.
894   ///
895   /// \returns the newly created interleave group.
896   InterleaveGroup<Instruction> *
897   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
898     assert(!InterleaveGroupMap.count(Instr) &&
899            "Already in an interleaved access group");
900     InterleaveGroupMap[Instr] =
901         new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
902     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
903     return InterleaveGroupMap[Instr];
904   }
905 
906   /// Release the group and remove all the relationships.
907   void releaseGroup(InterleaveGroup<Instruction> *Group) {
908     for (unsigned i = 0; i < Group->getFactor(); i++)
909       if (Instruction *Member = Group->getMember(i))
910         InterleaveGroupMap.erase(Member);
911 
912     InterleaveGroups.erase(Group);
913     delete Group;
914   }
915 
916   /// Collect all the accesses with a constant stride in program order.
917   void collectConstStrideAccesses(
918       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
919       const DenseMap<Value *, const SCEV *> &Strides);
920 
921   /// Returns true if \p Stride is allowed in an interleaved group.
922   static bool isStrided(int Stride);
923 
924   /// Returns true if \p BB is a predicated block.
925   bool isPredicated(BasicBlock *BB) const {
926     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
927   }
928 
929   /// Returns true if LoopAccessInfo can be used for dependence queries.
930   bool areDependencesValid() const {
931     return LAI && LAI->getDepChecker().getDependences();
932   }
933 
934   /// Returns true if memory accesses \p A and \p B can be reordered, if
935   /// necessary, when constructing interleaved groups.
936   ///
937   /// \p A must precede \p B in program order. We return false if reordering is
938   /// not necessary or is prevented because \p A and \p B may be dependent.
939   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
940                                                  StrideEntry *B) const {
941     // Code motion for interleaved accesses can potentially hoist strided loads
942     // and sink strided stores. The code below checks the legality of the
943     // following two conditions:
944     //
945     // 1. Potentially moving a strided load (B) before any store (A) that
946     //    precedes B, or
947     //
948     // 2. Potentially moving a strided store (A) after any load or store (B)
949     //    that A precedes.
950     //
951     // It's legal to reorder A and B if we know there isn't a dependence from A
952     // to B. Note that this determination is conservative since some
953     // dependences could potentially be reordered safely.
954 
955     // A is potentially the source of a dependence.
956     auto *Src = A->first;
957     auto SrcDes = A->second;
958 
959     // B is potentially the sink of a dependence.
960     auto *Sink = B->first;
961     auto SinkDes = B->second;
962 
963     // Code motion for interleaved accesses can't violate WAR dependences.
964     // Thus, reordering is legal if the source isn't a write.
965     if (!Src->mayWriteToMemory())
966       return true;
967 
968     // At least one of the accesses must be strided.
969     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
970       return true;
971 
972     // If dependence information is not available from LoopAccessInfo,
973     // conservatively assume the instructions can't be reordered.
974     if (!areDependencesValid())
975       return false;
976 
977     // If we know there is a dependence from source to sink, assume the
978     // instructions can't be reordered. Otherwise, reordering is legal.
979     return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
980   }
981 
982   /// Collect the dependences from LoopAccessInfo.
983   ///
984   /// We process the dependences once during the interleaved access analysis to
985   /// enable constant-time dependence queries.
986   void collectDependences() {
987     if (!areDependencesValid())
988       return;
989     auto *Deps = LAI->getDepChecker().getDependences();
990     for (auto Dep : *Deps)
991       Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
992   }
993 };
994 
995 } // llvm namespace
996 
997 #endif
998