1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines some vectorizer utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H 14 #define LLVM_ANALYSIS_VECTORUTILS_H 15 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/SmallSet.h" 18 #include "llvm/Analysis/LoopAccessAnalysis.h" 19 #include "llvm/Analysis/TargetLibraryInfo.h" 20 #include "llvm/Support/CheckedArithmetic.h" 21 22 namespace llvm { 23 24 /// Describes the type of Parameters 25 enum class VFParamKind { 26 Vector, // No semantic information. 27 OMP_Linear, // declare simd linear(i) 28 OMP_LinearRef, // declare simd linear(ref(i)) 29 OMP_LinearVal, // declare simd linear(val(i)) 30 OMP_LinearUVal, // declare simd linear(uval(i)) 31 OMP_LinearPos, // declare simd linear(i:c) uniform(c) 32 OMP_LinearValPos, // declare simd linear(val(i:c)) uniform(c) 33 OMP_LinearRefPos, // declare simd linear(ref(i:c)) uniform(c) 34 OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c 35 OMP_Uniform, // declare simd uniform(i) 36 GlobalPredicate, // Global logical predicate that acts on all lanes 37 // of the input and output mask concurrently. For 38 // example, it is implied by the `M` token in the 39 // Vector Function ABI mangled name. 40 Unknown 41 }; 42 43 /// Describes the type of Instruction Set Architecture 44 enum class VFISAKind { 45 AdvancedSIMD, // AArch64 Advanced SIMD (NEON) 46 SVE, // AArch64 Scalable Vector Extension 47 SSE, // x86 SSE 48 AVX, // x86 AVX 49 AVX2, // x86 AVX2 50 AVX512, // x86 AVX512 51 LLVM, // LLVM internal ISA for functions that are not 52 // attached to an existing ABI via name mangling. 53 Unknown // Unknown ISA 54 }; 55 56 /// Encapsulates information needed to describe a parameter. 57 /// 58 /// The description of the parameter is not linked directly to 59 /// OpenMP or any other vector function description. This structure 60 /// is extendible to handle other paradigms that describe vector 61 /// functions and their parameters. 62 struct VFParameter { 63 unsigned ParamPos; // Parameter Position in Scalar Function. 64 VFParamKind ParamKind; // Kind of Parameter. 65 int LinearStepOrPos = 0; // Step or Position of the Parameter. 66 Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1. 67 68 // Comparison operator. 69 bool operator==(const VFParameter &Other) const { 70 return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) == 71 std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos, 72 Other.Alignment); 73 } 74 }; 75 76 /// Contains the information about the kind of vectorization 77 /// available. 78 /// 79 /// This object in independent on the paradigm used to 80 /// represent vector functions. in particular, it is not attached to 81 /// any target-specific ABI. 82 struct VFShape { 83 unsigned VF; // Vectorization factor. 84 bool IsScalable; // True if the function is a scalable function. 85 SmallVector<VFParameter, 8> Parameters; // List of parameter information. 86 // Comparison operator. 87 bool operator==(const VFShape &Other) const { 88 return std::tie(VF, IsScalable, Parameters) == 89 std::tie(Other.VF, Other.IsScalable, Other.Parameters); 90 } 91 92 /// Update the parameter in position P.ParamPos to P. updateParamVFShape93 void updateParam(VFParameter P) { 94 assert(P.ParamPos < Parameters.size() && "Invalid parameter position."); 95 Parameters[P.ParamPos] = P; 96 assert(hasValidParameterList() && "Invalid parameter list"); 97 } 98 99 // Retrieve the VFShape that can be used to map a (scalar) function to itself, 100 // with VF = 1. getScalarShapeVFShape101 static VFShape getScalarShape(const CallInst &CI) { 102 return VFShape::get(CI, /*EC*/ {1, false}, /*HasGlobalPredicate*/ false); 103 } 104 105 // Retrieve the basic vectorization shape of the function, where all 106 // parameters are mapped to VFParamKind::Vector with \p EC 107 // lanes. Specifies whether the function has a Global Predicate 108 // argument via \p HasGlobalPred. getVFShape109 static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) { 110 SmallVector<VFParameter, 8> Parameters; 111 for (unsigned I = 0; I < CI.arg_size(); ++I) 112 Parameters.push_back(VFParameter({I, VFParamKind::Vector})); 113 if (HasGlobalPred) 114 Parameters.push_back( 115 VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate})); 116 117 return {EC.Min, EC.Scalable, Parameters}; 118 } 119 /// Sanity check on the Parameters in the VFShape. 120 bool hasValidParameterList() const; 121 }; 122 123 /// Holds the VFShape for a specific scalar to vector function mapping. 124 struct VFInfo { 125 VFShape Shape; /// Classification of the vector function. 126 std::string ScalarName; /// Scalar Function Name. 127 std::string VectorName; /// Vector Function Name associated to this VFInfo. 128 VFISAKind ISA; /// Instruction Set Architecture. 129 130 // Comparison operator. 131 bool operator==(const VFInfo &Other) const { 132 return std::tie(Shape, ScalarName, VectorName, ISA) == 133 std::tie(Shape, Other.ScalarName, Other.VectorName, Other.ISA); 134 } 135 }; 136 137 namespace VFABI { 138 /// LLVM Internal VFABI ISA token for vector functions. 139 static constexpr char const *_LLVM_ = "_LLVM_"; 140 /// Prefix for internal name redirection for vector function that 141 /// tells the compiler to scalarize the call using the scalar name 142 /// of the function. For example, a mangled name like 143 /// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the 144 /// vectorizer to vectorize the scalar call `foo`, and to scalarize 145 /// it once vectorization is done. 146 static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_"; 147 148 /// Function to construct a VFInfo out of a mangled names in the 149 /// following format: 150 /// 151 /// <VFABI_name>{(<redirection>)} 152 /// 153 /// where <VFABI_name> is the name of the vector function, mangled according 154 /// to the rules described in the Vector Function ABI of the target vector 155 /// extension (or <isa> from now on). The <VFABI_name> is in the following 156 /// format: 157 /// 158 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)] 159 /// 160 /// This methods support demangling rules for the following <isa>: 161 /// 162 /// * AArch64: https://developer.arm.com/docs/101129/latest 163 /// 164 /// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and 165 /// https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt 166 /// 167 /// \param MangledName -> input string in the format 168 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]. 169 /// \param M -> Module used to retrieve informations about the vector 170 /// function that are not possible to retrieve from the mangled 171 /// name. At the moment, this parameter is needed only to retrieve the 172 /// Vectorization Factor of scalable vector functions from their 173 /// respective IR declarations. 174 Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName, const Module &M); 175 176 /// This routine mangles the given VectorName according to the LangRef 177 /// specification for vector-function-abi-variant attribute and is specific to 178 /// the TLI mappings. It is the responsibility of the caller to make sure that 179 /// this is only used if all parameters in the vector function are vector type. 180 /// This returned string holds scalar-to-vector mapping: 181 /// _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>) 182 /// 183 /// where: 184 /// 185 /// <isa> = "_LLVM_" 186 /// <mask> = "N". Note: TLI does not support masked interfaces. 187 /// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor` 188 /// field of the `VecDesc` struct. 189 /// <vparams> = "v", as many as are the numArgs. 190 /// <scalarname> = the name of the scalar function. 191 /// <vectorname> = the name of the vector function. 192 std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName, 193 unsigned numArgs, unsigned VF); 194 195 /// Retrieve the `VFParamKind` from a string token. 196 VFParamKind getVFParamKindFromString(const StringRef Token); 197 198 // Name of the attribute where the variant mappings are stored. 199 static constexpr char const *MappingsAttrName = "vector-function-abi-variant"; 200 201 /// Populates a set of strings representing the Vector Function ABI variants 202 /// associated to the CallInst CI. If the CI does not contain the 203 /// vector-function-abi-variant attribute, we return without populating 204 /// VariantMappings, i.e. callers of getVectorVariantNames need not check for 205 /// the presence of the attribute (see InjectTLIMappings). 206 void getVectorVariantNames(const CallInst &CI, 207 SmallVectorImpl<std::string> &VariantMappings); 208 } // end namespace VFABI 209 210 /// The Vector Function Database. 211 /// 212 /// Helper class used to find the vector functions associated to a 213 /// scalar CallInst. 214 class VFDatabase { 215 /// The Module of the CallInst CI. 216 const Module *M; 217 /// The CallInst instance being queried for scalar to vector mappings. 218 const CallInst &CI; 219 /// List of vector functions descriptors associated to the call 220 /// instruction. 221 const SmallVector<VFInfo, 8> ScalarToVectorMappings; 222 223 /// Retrieve the scalar-to-vector mappings associated to the rule of 224 /// a vector Function ABI. getVFABIMappings(const CallInst & CI,SmallVectorImpl<VFInfo> & Mappings)225 static void getVFABIMappings(const CallInst &CI, 226 SmallVectorImpl<VFInfo> &Mappings) { 227 if (!CI.getCalledFunction()) 228 return; 229 230 const StringRef ScalarName = CI.getCalledFunction()->getName(); 231 232 SmallVector<std::string, 8> ListOfStrings; 233 // The check for the vector-function-abi-variant attribute is done when 234 // retrieving the vector variant names here. 235 VFABI::getVectorVariantNames(CI, ListOfStrings); 236 if (ListOfStrings.empty()) 237 return; 238 for (const auto &MangledName : ListOfStrings) { 239 const Optional<VFInfo> Shape = 240 VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule())); 241 // A match is found via scalar and vector names, and also by 242 // ensuring that the variant described in the attribute has a 243 // corresponding definition or declaration of the vector 244 // function in the Module M. 245 if (Shape.hasValue() && (Shape.getValue().ScalarName == ScalarName)) { 246 assert(CI.getModule()->getFunction(Shape.getValue().VectorName) && 247 "Vector function is missing."); 248 Mappings.push_back(Shape.getValue()); 249 } 250 } 251 } 252 253 public: 254 /// Retrieve all the VFInfo instances associated to the CallInst CI. getMappings(const CallInst & CI)255 static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) { 256 SmallVector<VFInfo, 8> Ret; 257 258 // Get mappings from the Vector Function ABI variants. 259 getVFABIMappings(CI, Ret); 260 261 // Other non-VFABI variants should be retrieved here. 262 263 return Ret; 264 } 265 266 /// Constructor, requires a CallInst instance. VFDatabase(CallInst & CI)267 VFDatabase(CallInst &CI) 268 : M(CI.getModule()), CI(CI), 269 ScalarToVectorMappings(VFDatabase::getMappings(CI)) {} 270 /// \defgroup VFDatabase query interface. 271 /// 272 /// @{ 273 /// Retrieve the Function with VFShape \p Shape. getVectorizedFunction(const VFShape & Shape)274 Function *getVectorizedFunction(const VFShape &Shape) const { 275 if (Shape == VFShape::getScalarShape(CI)) 276 return CI.getCalledFunction(); 277 278 for (const auto &Info : ScalarToVectorMappings) 279 if (Info.Shape == Shape) 280 return M->getFunction(Info.VectorName); 281 282 return nullptr; 283 } 284 /// @} 285 }; 286 287 template <typename T> class ArrayRef; 288 class DemandedBits; 289 class GetElementPtrInst; 290 template <typename InstTy> class InterleaveGroup; 291 class IRBuilderBase; 292 class Loop; 293 class ScalarEvolution; 294 class TargetTransformInfo; 295 class Type; 296 class Value; 297 298 namespace Intrinsic { 299 typedef unsigned ID; 300 } 301 302 /// A helper function for converting Scalar types to vector types. 303 /// If the incoming type is void, we return void. If the VF is 1, we return 304 /// the scalar type. 305 inline Type *ToVectorTy(Type *Scalar, unsigned VF, bool isScalable = false) { 306 if (Scalar->isVoidTy() || VF == 1) 307 return Scalar; 308 return VectorType::get(Scalar, {VF, isScalable}); 309 } 310 311 /// Identify if the intrinsic is trivially vectorizable. 312 /// This method returns true if the intrinsic's argument types are all scalars 313 /// for the scalar form of the intrinsic and all vectors (or scalars handled by 314 /// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic. 315 bool isTriviallyVectorizable(Intrinsic::ID ID); 316 317 /// Identifies if the vector form of the intrinsic has a scalar operand. 318 bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx); 319 320 /// Returns intrinsic ID for call. 321 /// For the input call instruction it finds mapping intrinsic and returns 322 /// its intrinsic ID, in case it does not found it return not_intrinsic. 323 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI, 324 const TargetLibraryInfo *TLI); 325 326 /// Find the operand of the GEP that should be checked for consecutive 327 /// stores. This ignores trailing indices that have no effect on the final 328 /// pointer. 329 unsigned getGEPInductionOperand(const GetElementPtrInst *Gep); 330 331 /// If the argument is a GEP, then returns the operand identified by 332 /// getGEPInductionOperand. However, if there is some other non-loop-invariant 333 /// operand, it returns that instead. 334 Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp); 335 336 /// If a value has only one user that is a CastInst, return it. 337 Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty); 338 339 /// Get the stride of a pointer access in a loop. Looks for symbolic 340 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. 341 Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp); 342 343 /// Given a vector and an element number, see if the scalar value is 344 /// already around as a register, for example if it were inserted then extracted 345 /// from the vector. 346 Value *findScalarElement(Value *V, unsigned EltNo); 347 348 /// If all non-negative \p Mask elements are the same value, return that value. 349 /// If all elements are negative (undefined) or \p Mask contains different 350 /// non-negative values, return -1. 351 int getSplatIndex(ArrayRef<int> Mask); 352 353 /// Get splat value if the input is a splat vector or return nullptr. 354 /// The value may be extracted from a splat constants vector or from 355 /// a sequence of instructions that broadcast a single value into a vector. 356 const Value *getSplatValue(const Value *V); 357 358 /// Return true if each element of the vector value \p V is poisoned or equal to 359 /// every other non-poisoned element. If an index element is specified, either 360 /// every element of the vector is poisoned or the element at that index is not 361 /// poisoned and equal to every other non-poisoned element. 362 /// This may be more powerful than the related getSplatValue() because it is 363 /// not limited by finding a scalar source value to a splatted vector. 364 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0); 365 366 /// Replace each shuffle mask index with the scaled sequential indices for an 367 /// equivalent mask of narrowed elements. Mask elements that are less than 0 368 /// (sentinel values) are repeated in the output mask. 369 /// 370 /// Example with Scale = 4: 371 /// <4 x i32> <3, 2, 0, -1> --> 372 /// <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> 373 /// 374 /// This is the reverse process of widening shuffle mask elements, but it always 375 /// succeeds because the indexes can always be multiplied (scaled up) to map to 376 /// narrower vector elements. 377 void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask, 378 SmallVectorImpl<int> &ScaledMask); 379 380 /// Try to transform a shuffle mask by replacing elements with the scaled index 381 /// for an equivalent mask of widened elements. If all mask elements that would 382 /// map to a wider element of the new mask are the same negative number 383 /// (sentinel value), that element of the new mask is the same value. If any 384 /// element in a given slice is negative and some other element in that slice is 385 /// not the same value, return false (partial matches with sentinel values are 386 /// not allowed). 387 /// 388 /// Example with Scale = 4: 389 /// <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> --> 390 /// <4 x i32> <3, 2, 0, -1> 391 /// 392 /// This is the reverse process of narrowing shuffle mask elements if it 393 /// succeeds. This transform is not always possible because indexes may not 394 /// divide evenly (scale down) to map to wider vector elements. 395 bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask, 396 SmallVectorImpl<int> &ScaledMask); 397 398 /// Compute a map of integer instructions to their minimum legal type 399 /// size. 400 /// 401 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int 402 /// type (e.g. i32) whenever arithmetic is performed on them. 403 /// 404 /// For targets with native i8 or i16 operations, usually InstCombine can shrink 405 /// the arithmetic type down again. However InstCombine refuses to create 406 /// illegal types, so for targets without i8 or i16 registers, the lengthening 407 /// and shrinking remains. 408 /// 409 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when 410 /// their scalar equivalents do not, so during vectorization it is important to 411 /// remove these lengthens and truncates when deciding the profitability of 412 /// vectorization. 413 /// 414 /// This function analyzes the given range of instructions and determines the 415 /// minimum type size each can be converted to. It attempts to remove or 416 /// minimize type size changes across each def-use chain, so for example in the 417 /// following code: 418 /// 419 /// %1 = load i8, i8* 420 /// %2 = add i8 %1, 2 421 /// %3 = load i16, i16* 422 /// %4 = zext i8 %2 to i32 423 /// %5 = zext i16 %3 to i32 424 /// %6 = add i32 %4, %5 425 /// %7 = trunc i32 %6 to i16 426 /// 427 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes 428 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}. 429 /// 430 /// If the optional TargetTransformInfo is provided, this function tries harder 431 /// to do less work by only looking at illegal types. 432 MapVector<Instruction*, uint64_t> 433 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks, 434 DemandedBits &DB, 435 const TargetTransformInfo *TTI=nullptr); 436 437 /// Compute the union of two access-group lists. 438 /// 439 /// If the list contains just one access group, it is returned directly. If the 440 /// list is empty, returns nullptr. 441 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2); 442 443 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2 444 /// are both in. If either instruction does not access memory at all, it is 445 /// considered to be in every list. 446 /// 447 /// If the list contains just one access group, it is returned directly. If the 448 /// list is empty, returns nullptr. 449 MDNode *intersectAccessGroups(const Instruction *Inst1, 450 const Instruction *Inst2); 451 452 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath, 453 /// MD_nontemporal, MD_access_group]. 454 /// For K in Kinds, we get the MDNode for K from each of the 455 /// elements of VL, compute their "intersection" (i.e., the most generic 456 /// metadata value that covers all of the individual values), and set I's 457 /// metadata for M equal to the intersection value. 458 /// 459 /// This function always sets a (possibly null) value for each K in Kinds. 460 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL); 461 462 /// Create a mask that filters the members of an interleave group where there 463 /// are gaps. 464 /// 465 /// For example, the mask for \p Group with interleave-factor 3 466 /// and \p VF 4, that has only its first member present is: 467 /// 468 /// <1,0,0,1,0,0,1,0,0,1,0,0> 469 /// 470 /// Note: The result is a mask of 0's and 1's, as opposed to the other 471 /// create[*]Mask() utilities which create a shuffle mask (mask that 472 /// consists of indices). 473 Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF, 474 const InterleaveGroup<Instruction> &Group); 475 476 /// Create a mask with replicated elements. 477 /// 478 /// This function creates a shuffle mask for replicating each of the \p VF 479 /// elements in a vector \p ReplicationFactor times. It can be used to 480 /// transform a mask of \p VF elements into a mask of 481 /// \p VF * \p ReplicationFactor elements used by a predicated 482 /// interleaved-group of loads/stores whose Interleaved-factor == 483 /// \p ReplicationFactor. 484 /// 485 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is: 486 /// 487 /// <0,0,0,1,1,1,2,2,2,3,3,3> 488 llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor, 489 unsigned VF); 490 491 /// Create an interleave shuffle mask. 492 /// 493 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of 494 /// vectorization factor \p VF into a single wide vector. The mask is of the 495 /// form: 496 /// 497 /// <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...> 498 /// 499 /// For example, the mask for VF = 4 and NumVecs = 2 is: 500 /// 501 /// <0, 4, 1, 5, 2, 6, 3, 7>. 502 llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs); 503 504 /// Create a stride shuffle mask. 505 /// 506 /// This function creates a shuffle mask whose elements begin at \p Start and 507 /// are incremented by \p Stride. The mask can be used to deinterleave an 508 /// interleaved vector into separate vectors of vectorization factor \p VF. The 509 /// mask is of the form: 510 /// 511 /// <Start, Start + Stride, ..., Start + Stride * (VF - 1)> 512 /// 513 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is: 514 /// 515 /// <0, 2, 4, 6> 516 llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride, 517 unsigned VF); 518 519 /// Create a sequential shuffle mask. 520 /// 521 /// This function creates shuffle mask whose elements are sequential and begin 522 /// at \p Start. The mask contains \p NumInts integers and is padded with \p 523 /// NumUndefs undef values. The mask is of the form: 524 /// 525 /// <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs> 526 /// 527 /// For example, the mask for Start = 0, NumInsts = 4, and NumUndefs = 4 is: 528 /// 529 /// <0, 1, 2, 3, undef, undef, undef, undef> 530 llvm::SmallVector<int, 16> 531 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs); 532 533 /// Concatenate a list of vectors. 534 /// 535 /// This function generates code that concatenate the vectors in \p Vecs into a 536 /// single large vector. The number of vectors should be greater than one, and 537 /// their element types should be the same. The number of elements in the 538 /// vectors should also be the same; however, if the last vector has fewer 539 /// elements, it will be padded with undefs. 540 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs); 541 542 /// Given a mask vector of the form <Y x i1>, Return true if all of the 543 /// elements of this predicate mask are false or undef. That is, return true 544 /// if all lanes can be assumed inactive. 545 bool maskIsAllZeroOrUndef(Value *Mask); 546 547 /// Given a mask vector of the form <Y x i1>, Return true if all of the 548 /// elements of this predicate mask are true or undef. That is, return true 549 /// if all lanes can be assumed active. 550 bool maskIsAllOneOrUndef(Value *Mask); 551 552 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y) 553 /// for each lane which may be active. 554 APInt possiblyDemandedEltsInMask(Value *Mask); 555 556 /// The group of interleaved loads/stores sharing the same stride and 557 /// close to each other. 558 /// 559 /// Each member in this group has an index starting from 0, and the largest 560 /// index should be less than interleaved factor, which is equal to the absolute 561 /// value of the access's stride. 562 /// 563 /// E.g. An interleaved load group of factor 4: 564 /// for (unsigned i = 0; i < 1024; i+=4) { 565 /// a = A[i]; // Member of index 0 566 /// b = A[i+1]; // Member of index 1 567 /// d = A[i+3]; // Member of index 3 568 /// ... 569 /// } 570 /// 571 /// An interleaved store group of factor 4: 572 /// for (unsigned i = 0; i < 1024; i+=4) { 573 /// ... 574 /// A[i] = a; // Member of index 0 575 /// A[i+1] = b; // Member of index 1 576 /// A[i+2] = c; // Member of index 2 577 /// A[i+3] = d; // Member of index 3 578 /// } 579 /// 580 /// Note: the interleaved load group could have gaps (missing members), but 581 /// the interleaved store group doesn't allow gaps. 582 template <typename InstTy> class InterleaveGroup { 583 public: InterleaveGroup(uint32_t Factor,bool Reverse,Align Alignment)584 InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment) 585 : Factor(Factor), Reverse(Reverse), Alignment(Alignment), 586 InsertPos(nullptr) {} 587 InterleaveGroup(InstTy * Instr,int32_t Stride,Align Alignment)588 InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment) 589 : Alignment(Alignment), InsertPos(Instr) { 590 Factor = std::abs(Stride); 591 assert(Factor > 1 && "Invalid interleave factor"); 592 593 Reverse = Stride < 0; 594 Members[0] = Instr; 595 } 596 isReverse()597 bool isReverse() const { return Reverse; } getFactor()598 uint32_t getFactor() const { return Factor; } 599 LLVM_ATTRIBUTE_DEPRECATED(uint32_t getAlignment() const, 600 "Use getAlign instead.") { 601 return Alignment.value(); 602 } getAlign()603 Align getAlign() const { return Alignment; } getNumMembers()604 uint32_t getNumMembers() const { return Members.size(); } 605 606 /// Try to insert a new member \p Instr with index \p Index and 607 /// alignment \p NewAlign. The index is related to the leader and it could be 608 /// negative if it is the new leader. 609 /// 610 /// \returns false if the instruction doesn't belong to the group. insertMember(InstTy * Instr,int32_t Index,Align NewAlign)611 bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) { 612 // Make sure the key fits in an int32_t. 613 Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey); 614 if (!MaybeKey) 615 return false; 616 int32_t Key = *MaybeKey; 617 618 // Skip if there is already a member with the same index. 619 if (Members.find(Key) != Members.end()) 620 return false; 621 622 if (Key > LargestKey) { 623 // The largest index is always less than the interleave factor. 624 if (Index >= static_cast<int32_t>(Factor)) 625 return false; 626 627 LargestKey = Key; 628 } else if (Key < SmallestKey) { 629 630 // Make sure the largest index fits in an int32_t. 631 Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key); 632 if (!MaybeLargestIndex) 633 return false; 634 635 // The largest index is always less than the interleave factor. 636 if (*MaybeLargestIndex >= static_cast<int64_t>(Factor)) 637 return false; 638 639 SmallestKey = Key; 640 } 641 642 // It's always safe to select the minimum alignment. 643 Alignment = std::min(Alignment, NewAlign); 644 Members[Key] = Instr; 645 return true; 646 } 647 648 /// Get the member with the given index \p Index 649 /// 650 /// \returns nullptr if contains no such member. getMember(uint32_t Index)651 InstTy *getMember(uint32_t Index) const { 652 int32_t Key = SmallestKey + Index; 653 auto Member = Members.find(Key); 654 if (Member == Members.end()) 655 return nullptr; 656 657 return Member->second; 658 } 659 660 /// Get the index for the given member. Unlike the key in the member 661 /// map, the index starts from 0. getIndex(const InstTy * Instr)662 uint32_t getIndex(const InstTy *Instr) const { 663 for (auto I : Members) { 664 if (I.second == Instr) 665 return I.first - SmallestKey; 666 } 667 668 llvm_unreachable("InterleaveGroup contains no such member"); 669 } 670 getInsertPos()671 InstTy *getInsertPos() const { return InsertPos; } setInsertPos(InstTy * Inst)672 void setInsertPos(InstTy *Inst) { InsertPos = Inst; } 673 674 /// Add metadata (e.g. alias info) from the instructions in this group to \p 675 /// NewInst. 676 /// 677 /// FIXME: this function currently does not add noalias metadata a'la 678 /// addNewMedata. To do that we need to compute the intersection of the 679 /// noalias info from all members. 680 void addMetadata(InstTy *NewInst) const; 681 682 /// Returns true if this Group requires a scalar iteration to handle gaps. requiresScalarEpilogue()683 bool requiresScalarEpilogue() const { 684 // If the last member of the Group exists, then a scalar epilog is not 685 // needed for this group. 686 if (getMember(getFactor() - 1)) 687 return false; 688 689 // We have a group with gaps. It therefore cannot be a group of stores, 690 // and it can't be a reversed access, because such groups get invalidated. 691 assert(!getMember(0)->mayWriteToMemory() && 692 "Group should have been invalidated"); 693 assert(!isReverse() && "Group should have been invalidated"); 694 695 // This is a group of loads, with gaps, and without a last-member 696 return true; 697 } 698 699 private: 700 uint32_t Factor; // Interleave Factor. 701 bool Reverse; 702 Align Alignment; 703 DenseMap<int32_t, InstTy *> Members; 704 int32_t SmallestKey = 0; 705 int32_t LargestKey = 0; 706 707 // To avoid breaking dependences, vectorized instructions of an interleave 708 // group should be inserted at either the first load or the last store in 709 // program order. 710 // 711 // E.g. %even = load i32 // Insert Position 712 // %add = add i32 %even // Use of %even 713 // %odd = load i32 714 // 715 // store i32 %even 716 // %odd = add i32 // Def of %odd 717 // store i32 %odd // Insert Position 718 InstTy *InsertPos; 719 }; 720 721 /// Drive the analysis of interleaved memory accesses in the loop. 722 /// 723 /// Use this class to analyze interleaved accesses only when we can vectorize 724 /// a loop. Otherwise it's meaningless to do analysis as the vectorization 725 /// on interleaved accesses is unsafe. 726 /// 727 /// The analysis collects interleave groups and records the relationships 728 /// between the member and the group in a map. 729 class InterleavedAccessInfo { 730 public: InterleavedAccessInfo(PredicatedScalarEvolution & PSE,Loop * L,DominatorTree * DT,LoopInfo * LI,const LoopAccessInfo * LAI)731 InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, 732 DominatorTree *DT, LoopInfo *LI, 733 const LoopAccessInfo *LAI) 734 : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {} 735 ~InterleavedAccessInfo()736 ~InterleavedAccessInfo() { invalidateGroups(); } 737 738 /// Analyze the interleaved accesses and collect them in interleave 739 /// groups. Substitute symbolic strides using \p Strides. 740 /// Consider also predicated loads/stores in the analysis if 741 /// \p EnableMaskedInterleavedGroup is true. 742 void analyzeInterleaving(bool EnableMaskedInterleavedGroup); 743 744 /// Invalidate groups, e.g., in case all blocks in loop will be predicated 745 /// contrary to original assumption. Although we currently prevent group 746 /// formation for predicated accesses, we may be able to relax this limitation 747 /// in the future once we handle more complicated blocks. Returns true if any 748 /// groups were invalidated. invalidateGroups()749 bool invalidateGroups() { 750 if (InterleaveGroups.empty()) { 751 assert( 752 !RequiresScalarEpilogue && 753 "RequiresScalarEpilog should not be set without interleave groups"); 754 return false; 755 } 756 757 InterleaveGroupMap.clear(); 758 for (auto *Ptr : InterleaveGroups) 759 delete Ptr; 760 InterleaveGroups.clear(); 761 RequiresScalarEpilogue = false; 762 return true; 763 } 764 765 /// Check if \p Instr belongs to any interleave group. isInterleaved(Instruction * Instr)766 bool isInterleaved(Instruction *Instr) const { 767 return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end(); 768 } 769 770 /// Get the interleave group that \p Instr belongs to. 771 /// 772 /// \returns nullptr if doesn't have such group. 773 InterleaveGroup<Instruction> * getInterleaveGroup(const Instruction * Instr)774 getInterleaveGroup(const Instruction *Instr) const { 775 if (InterleaveGroupMap.count(Instr)) 776 return InterleaveGroupMap.find(Instr)->second; 777 return nullptr; 778 } 779 780 iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>> getInterleaveGroups()781 getInterleaveGroups() { 782 return make_range(InterleaveGroups.begin(), InterleaveGroups.end()); 783 } 784 785 /// Returns true if an interleaved group that may access memory 786 /// out-of-bounds requires a scalar epilogue iteration for correctness. requiresScalarEpilogue()787 bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } 788 789 /// Invalidate groups that require a scalar epilogue (due to gaps). This can 790 /// happen when optimizing for size forbids a scalar epilogue, and the gap 791 /// cannot be filtered by masking the load/store. 792 void invalidateGroupsRequiringScalarEpilogue(); 793 794 private: 795 /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. 796 /// Simplifies SCEV expressions in the context of existing SCEV assumptions. 797 /// The interleaved access analysis can also add new predicates (for example 798 /// by versioning strides of pointers). 799 PredicatedScalarEvolution &PSE; 800 801 Loop *TheLoop; 802 DominatorTree *DT; 803 LoopInfo *LI; 804 const LoopAccessInfo *LAI; 805 806 /// True if the loop may contain non-reversed interleaved groups with 807 /// out-of-bounds accesses. We ensure we don't speculatively access memory 808 /// out-of-bounds by executing at least one scalar epilogue iteration. 809 bool RequiresScalarEpilogue = false; 810 811 /// Holds the relationships between the members and the interleave group. 812 DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap; 813 814 SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups; 815 816 /// Holds dependences among the memory accesses in the loop. It maps a source 817 /// access to a set of dependent sink accesses. 818 DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; 819 820 /// The descriptor for a strided memory access. 821 struct StrideDescriptor { 822 StrideDescriptor() = default; StrideDescriptorStrideDescriptor823 StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, 824 Align Alignment) 825 : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {} 826 827 // The access's stride. It is negative for a reverse access. 828 int64_t Stride = 0; 829 830 // The scalar expression of this access. 831 const SCEV *Scev = nullptr; 832 833 // The size of the memory object. 834 uint64_t Size = 0; 835 836 // The alignment of this access. 837 Align Alignment; 838 }; 839 840 /// A type for holding instructions and their stride descriptors. 841 using StrideEntry = std::pair<Instruction *, StrideDescriptor>; 842 843 /// Create a new interleave group with the given instruction \p Instr, 844 /// stride \p Stride and alignment \p Align. 845 /// 846 /// \returns the newly created interleave group. 847 InterleaveGroup<Instruction> * createInterleaveGroup(Instruction * Instr,int Stride,Align Alignment)848 createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) { 849 assert(!InterleaveGroupMap.count(Instr) && 850 "Already in an interleaved access group"); 851 InterleaveGroupMap[Instr] = 852 new InterleaveGroup<Instruction>(Instr, Stride, Alignment); 853 InterleaveGroups.insert(InterleaveGroupMap[Instr]); 854 return InterleaveGroupMap[Instr]; 855 } 856 857 /// Release the group and remove all the relationships. releaseGroup(InterleaveGroup<Instruction> * Group)858 void releaseGroup(InterleaveGroup<Instruction> *Group) { 859 for (unsigned i = 0; i < Group->getFactor(); i++) 860 if (Instruction *Member = Group->getMember(i)) 861 InterleaveGroupMap.erase(Member); 862 863 InterleaveGroups.erase(Group); 864 delete Group; 865 } 866 867 /// Collect all the accesses with a constant stride in program order. 868 void collectConstStrideAccesses( 869 MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, 870 const ValueToValueMap &Strides); 871 872 /// Returns true if \p Stride is allowed in an interleaved group. 873 static bool isStrided(int Stride); 874 875 /// Returns true if \p BB is a predicated block. isPredicated(BasicBlock * BB)876 bool isPredicated(BasicBlock *BB) const { 877 return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); 878 } 879 880 /// Returns true if LoopAccessInfo can be used for dependence queries. areDependencesValid()881 bool areDependencesValid() const { 882 return LAI && LAI->getDepChecker().getDependences(); 883 } 884 885 /// Returns true if memory accesses \p A and \p B can be reordered, if 886 /// necessary, when constructing interleaved groups. 887 /// 888 /// \p A must precede \p B in program order. We return false if reordering is 889 /// not necessary or is prevented because \p A and \p B may be dependent. canReorderMemAccessesForInterleavedGroups(StrideEntry * A,StrideEntry * B)890 bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, 891 StrideEntry *B) const { 892 // Code motion for interleaved accesses can potentially hoist strided loads 893 // and sink strided stores. The code below checks the legality of the 894 // following two conditions: 895 // 896 // 1. Potentially moving a strided load (B) before any store (A) that 897 // precedes B, or 898 // 899 // 2. Potentially moving a strided store (A) after any load or store (B) 900 // that A precedes. 901 // 902 // It's legal to reorder A and B if we know there isn't a dependence from A 903 // to B. Note that this determination is conservative since some 904 // dependences could potentially be reordered safely. 905 906 // A is potentially the source of a dependence. 907 auto *Src = A->first; 908 auto SrcDes = A->second; 909 910 // B is potentially the sink of a dependence. 911 auto *Sink = B->first; 912 auto SinkDes = B->second; 913 914 // Code motion for interleaved accesses can't violate WAR dependences. 915 // Thus, reordering is legal if the source isn't a write. 916 if (!Src->mayWriteToMemory()) 917 return true; 918 919 // At least one of the accesses must be strided. 920 if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) 921 return true; 922 923 // If dependence information is not available from LoopAccessInfo, 924 // conservatively assume the instructions can't be reordered. 925 if (!areDependencesValid()) 926 return false; 927 928 // If we know there is a dependence from source to sink, assume the 929 // instructions can't be reordered. Otherwise, reordering is legal. 930 return Dependences.find(Src) == Dependences.end() || 931 !Dependences.lookup(Src).count(Sink); 932 } 933 934 /// Collect the dependences from LoopAccessInfo. 935 /// 936 /// We process the dependences once during the interleaved access analysis to 937 /// enable constant-time dependence queries. collectDependences()938 void collectDependences() { 939 if (!areDependencesValid()) 940 return; 941 auto *Deps = LAI->getDepChecker().getDependences(); 942 for (auto Dep : *Deps) 943 Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI)); 944 } 945 }; 946 947 } // llvm namespace 948 949 #endif 950