//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// Describes the type of Parameters
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};

/// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};

/// Encapsulates information needed to describe a parameter.
///
/// The description of the parameter is not linked directly to
/// OpenMP or any other vector function description. This structure
/// is extendible to handle other paradigms that describe vector
/// functions and their parameters.
struct VFParameter {
  unsigned ParamPos;         // Parameter Position in Scalar Function.
  VFParamKind ParamKind;     // Kind of Parameter.
  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.

  // Comparison operator.
  bool operator==(const VFParameter &Other) const {
    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
                    Other.Alignment);
  }
};

/// Contains the information about the kind of vectorization
/// available.
///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
struct VFShape {
  unsigned VF;     // Vectorization factor.
  bool IsScalable; // True if the function is a scalable function.
  SmallVector<VFParameter, 8> Parameters; // List of parameter information.
  // Comparison operator.
  bool operator==(const VFShape &Other) const {
    return std::tie(VF, IsScalable, Parameters) ==
           std::tie(Other.VF, Other.IsScalable, Other.Parameters);
  }

  /// Update the parameter in position P.ParamPos to P.
  void updateParam(VFParameter P) {
    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
    Parameters[P.ParamPos] = P;
    assert(hasValidParameterList() && "Invalid parameter list");
  }

  // Retrieve the VFShape that can be used to map a (scalar) function to itself,
  // with VF = 1.
  static VFShape getScalarShape(const CallInst &CI) {
    return VFShape::get(CI, ElementCount::getFixed(1),
                        /*HasGlobalPredicate*/ false);
  }

  // Retrieve the basic vectorization shape of the function, where all
  // parameters are mapped to VFParamKind::Vector with \p EC
  // lanes. Specifies whether the function has a Global Predicate
  // argument via \p HasGlobalPred.
  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    SmallVector<VFParameter, 8> Parameters;
    for (unsigned I = 0; I < CI.arg_size(); ++I)
      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    if (HasGlobalPred)
      Parameters.push_back(
          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));

    return {EC.getKnownMinValue(), EC.isScalable(), Parameters};
  }
  /// Sanity check on the Parameters in the VFShape.
  bool hasValidParameterList() const;
};
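
// Example: building and adjusting the shape of a call site. This is an
// illustrative sketch, assuming `CI` is a CallInst to a scalar function that
// takes two arguments:
//
//   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/true);
//   // Shape.VF == 4, Shape.IsScalable == false, and Shape.Parameters holds
//   // two VFParamKind::Vector entries followed by one GlobalPredicate entry.
//   VFParameter Uniform{/*ParamPos=*/1, VFParamKind::OMP_Uniform};
//   Shape.updateParam(Uniform); // Mark the second parameter as uniform.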

/// Holds the VFShape for a specific scalar to vector function mapping.
struct VFInfo {
  VFShape Shape;          /// Classification of the vector function.
  std::string ScalarName; /// Scalar Function Name.
  std::string VectorName; /// Vector Function Name associated to this VFInfo.
  VFISAKind ISA;          /// Instruction Set Architecture.
};

namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";
/// Prefix for internal name redirection for vector function that
/// tells the compiler to scalarize the call using the scalar name
/// of the function. For example, a mangled name like
/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
/// vectorizer to vectorize the scalar call `foo`, and to scalarize
/// it once vectorization is done.
static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
/// <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This function supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///   https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param M -> Module used to retrieve information about the vector
/// function that cannot be retrieved from the mangled
/// name. At the moment, this parameter is needed only to retrieve the
/// Vectorization Factor of scalable vector functions from their
/// respective IR declarations.
Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName, const Module &M);
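
// Example: demangling a Vector Function ABI name. An illustrative sketch,
// assuming module `M` declares both `foo` and the AdvancedSIMD variant
// `_ZGVnN2v_foo`:
//
//   Optional<VFInfo> Info = VFABI::tryDemangleForVFABI("_ZGVnN2v_foo", M);
//   if (Info) {
//     // Info->ScalarName == "foo", Info->VectorName == "_ZGVnN2v_foo",
//     // Info->Shape.VF == 2, Info->ISA == VFISAKind::AdvancedSIMD.
//   }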

/// This routine mangles the given VectorName according to the LangRef
/// specification for the vector-function-abi-variant attribute and is specific
/// to the TLI mappings. It is the responsibility of the caller to make sure
/// that this is only used if all parameters in the vector function are of
/// vector type. The returned string holds a scalar-to-vector mapping:
///
///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
///
/// where:
///
/// <isa> = "_LLVM_"
/// <mask> = "N". Note: TLI does not support masked interfaces.
/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
///          field of the `VecDesc` struct. If the number of lanes is scalable
///          then 'x' is printed instead.
/// <vparams> = "v", one per argument (numArgs in total).
/// <scalarname> = the name of the scalar function.
/// <vectorname> = the name of the vector function.
std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
                                unsigned numArgs, ElementCount VF);
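
// Example: the TLI mapping string produced for a 2-lane variant. Illustrative
// only; the names `sinf` and `vsinf_v2` stand in for an arbitrary TLI entry:
//
//   std::string S = VFABI::mangleTLIVectorName("vsinf_v2", "sinf",
//                                              /*numArgs=*/1,
//                                              ElementCount::getFixed(2));
//   // S == "_ZGV_LLVM_N2v_sinf(vsinf_v2)"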

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. If the CI does not contain the
/// vector-function-abi-variant attribute, we return without populating
/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
/// the presence of the attribute (see InjectTLIMappings).
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated to a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar to vector mappings.
  const CallInst &CI;
  /// List of vector functions descriptors associated to the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated to the rules of
  /// a Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const Optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape.hasValue() && (Shape.getValue().ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape.getValue().VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(Shape.getValue());
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated to the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};
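
// Example: querying the database for a vectorized callee. An illustrative
// sketch, assuming `CI` is a CallInst whose callee has a 4-lane VFABI variant
// recorded in its "vector-function-abi-variant" attribute:
//
//   VFDatabase DB(CI);
//   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/false);
//   if (Function *VecF = DB.getVectorizedFunction(Shape)) {
//     // VecF is the declaration/definition of the 4-lane vector variant.
//   }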

template <typename T> class ArrayRef;
class DemandedBits;
class GetElementPtrInst;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}
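
// Example: widening a scalar type to a vector type. Illustrative, assuming
// `Ctx` is an LLVMContext:
//
//   Type *F32 = Type::getFloatTy(Ctx);
//   Type *V8F32 = ToVectorTy(F32, 8);                             // <8 x float>
//   Type *SVF32 = ToVectorTy(F32, ElementCount::getScalable(4));  // <vscale x 4 x float>
//   Type *Same  = ToVectorTy(F32, 1);                             // float (EC is scalar)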

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if it does not find one, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);

/// If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// If a value has only one user that is a CastInst, return it.
Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);

/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal
/// to every other non-poisoned element. If an index element is specified,
/// either every element of the vector is poisoned or the element at that index
/// is not poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
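
// Example: recognizing a splat built with an IRBuilder. Illustrative sketch,
// assuming `Builder` is an IRBuilderBase positioned in some block and `X` is a
// scalar i32 Value:
//
//   Value *Splat = Builder.CreateVectorSplat(4, X);
//   // getSplatValue(Splat) returns X, and isSplatValue(Splat) returns true,
//   // since the broadcast is an insertelement/shufflevector sequence that
//   // both helpers look through.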

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);
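
// Example: round-tripping a mask through narrow and widen. Illustrative
// sketch:
//
//   SmallVector<int, 16> Narrow;
//   narrowShuffleMaskElts(4, {3, 2, 0, -1}, Narrow);
//   // Narrow == {12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1}
//
//   SmallVector<int, 4> Wide;
//   bool Ok = widenShuffleMaskElts(4, Narrow, Wide);
//   // Ok == true, Wide == {3, 2, 0, -1}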

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);
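
// Example: shrinking widened arithmetic before costing a loop. Illustrative
// sketch, assuming `L` is a Loop being considered for vectorization, `DB` is a
// DemandedBits analysis result, and `TTI` is the TargetTransformInfo:
//
//   MapVector<Instruction *, uint64_t> MinBWs =
//       computeMinimumValueSizes(L->getBlocks(), DB, &TTI);
//   for (const auto &Entry : MinBWs) {
//     // Entry.first can be costed as if it were Entry.second bits wide.
//   }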

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start. The mask contains \p NumInts integers and is padded with
/// \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
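
// Example: what the mask helpers return. Illustrative sketch:
//
//   auto Rep = createReplicatedMask(3, 4);    // {0,0,0,1,1,1,2,2,2,3,3,3}
//   auto Ilv = createInterleaveMask(4, 2);    // {0,4,1,5,2,6,3,7}
//   auto Str = createStrideMask(0, 2, 4);     // {0,2,4,6}
//   auto Seq = createSequentialMask(0, 4, 4); // {0,1,2,3,-1,-1,-1,-1}
//
// Each result can be fed directly to IRBuilderBase::CreateShuffleVector, which
// takes an ArrayRef<int> mask and treats negative entries as undef lanes.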

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// with a bit set for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is related to the leader and it could be
  /// negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.find(Key) != Members.end())
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index
  ///
  /// \returns nullptr if it contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore cannot be a group of stores,
    // and it can't be a reversed access, because such groups get invalidated.
    assert(!getMember(0)->mayWriteToMemory() &&
           "Group should have been invalidated");
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};
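
// Example: how member indices relate to the leader. An illustrative sketch,
// assuming `LoadB` was found first (stride 4) and `LoadA` accesses the element
// one position before it:
//
//   InterleaveGroup<Instruction> Group(LoadB, /*Stride=*/4, Align(4));
//   Group.insertMember(LoadA, /*Index=*/-1, Align(4)); // LoadA is the new leader.
//   // Group.getFactor() == 4, Group.getNumMembers() == 2,
//   // Group.getMember(0) == LoadA, Group.getMember(1) == LoadB,
//   // Group.getIndex(LoadB) == 1.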

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in loop will be predicated
  /// contrary to original assumption. Although we currently prevent group
  /// formation for predicated accesses, we may be able to relax this limitation
  /// in the future once we handle more complicated blocks. Returns true if any
  /// groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr does not belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Align.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    InterleaveGroups.erase(Group);
    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const ValueToValueMap &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return Dependences.find(Src) == Dependences.end() ||
           !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    auto *Deps = LAI->getDepChecker().getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
  }
};
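
// Example: driving the analysis and inspecting a group. An illustrative
// sketch, assuming `PSE`, `L`, `DT`, `LI` and `LAI` come from the surrounding
// loop-vectorization analyses and `I` is a load or store inside `L`:
//
//   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
//   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
//   if (auto *Group = IAI.getInterleaveGroup(I)) {
//     // I is a member of an interleave group of Group->getFactor() accesses.
//   }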

} // llvm namespace

#endif