//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// Describes the type of Parameters
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};

/// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};

/// Encapsulates information needed to describe a parameter.
///
/// The description of the parameter is not linked directly to
/// OpenMP or any other vector function description. This structure
/// is extendible to handle other paradigms that describe vector
/// functions and their parameters.
struct VFParameter {
  unsigned ParamPos;         // Parameter Position in Scalar Function.
  VFParamKind ParamKind;     // Kind of Parameter.
  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.

  // Comparison operator.
  bool operator==(const VFParameter &Other) const {
    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
                    Other.Alignment);
  }
};

/// Contains the information about the kind of vectorization
/// available.
///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
struct VFShape {
  ElementCount VF;                        // Vectorization factor.
  SmallVector<VFParameter, 8> Parameters; // List of parameter information.
  // Comparison operator.
  bool operator==(const VFShape &Other) const {
    return std::tie(VF, Parameters) == std::tie(Other.VF, Other.Parameters);
  }
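
  // A usage sketch, illustrative only (`CI` is assumed to be the scalar
  // CallInst being considered for widening): describe a 2-lane, unmasked
  // shape and then mark the second parameter as uniform.
  //
  //   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(2),
  //                                /*HasGlobalPred=*/false);
  //   Shape.updateParam({/*ParamPos=*/1, VFParamKind::OMP_Uniform});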

  /// Update the parameter in position P.ParamPos to P.
  void updateParam(VFParameter P) {
    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
    Parameters[P.ParamPos] = P;
    assert(hasValidParameterList() && "Invalid parameter list");
  }

  // Retrieve the VFShape that can be used to map a (scalar) function to
  // itself, with VF = 1.
  static VFShape getScalarShape(const CallInst &CI) {
    return VFShape::get(CI, ElementCount::getFixed(1),
                        /*HasGlobalPredicate*/ false);
  }

  // Retrieve the basic vectorization shape of the function, where all
  // parameters are mapped to VFParamKind::Vector with \p EC
  // lanes. Specifies whether the function has a Global Predicate
  // argument via \p HasGlobalPred.
  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    SmallVector<VFParameter, 8> Parameters;
    for (unsigned I = 0; I < CI.arg_size(); ++I)
      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    if (HasGlobalPred)
      Parameters.push_back(
          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));

    return {EC, Parameters};
  }
  /// Validation check on the Parameters in the VFShape.
  bool hasValidParameterList() const;
};

/// Holds the VFShape for a specific scalar to vector function mapping.
struct VFInfo {
  VFShape Shape;          /// Classification of the vector function.
  std::string ScalarName; /// Scalar Function Name.
  std::string VectorName; /// Vector Function Name associated to this VFInfo.
  VFISAKind ISA;          /// Instruction Set Architecture.

  /// Returns the index of the first parameter with the kind 'GlobalPredicate',
  /// if any exist.
  std::optional<unsigned> getParamIndexForOptionalMask() const {
    unsigned ParamCount = Shape.Parameters.size();
    for (unsigned i = 0; i < ParamCount; ++i)
      if (Shape.Parameters[i].ParamKind == VFParamKind::GlobalPredicate)
        return i;

    return std::nullopt;
  }

  /// Returns true if at least one of the operands to the vectorized function
  /// has the kind 'GlobalPredicate'.
  bool isMasked() const { return getParamIndexForOptionalMask().has_value(); }
};

namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";
/// Prefix for internal name redirection for vector functions that
/// tells the compiler to scalarize the call using the scalar name
/// of the function. For example, a mangled name like
/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
/// vectorizer to vectorize the scalar call `foo`, and to scalarize
/// it once vectorization is done.
static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
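
// For illustration only (the names are hypothetical): a scalar call to `foo`
// carrying the attribute
//
//   "vector-function-abi-variant"="_ZGV_LLVM_N2v_foo(vector_foo)"
//
// advertises an unmasked ('N'), 2-lane variant named `vector_foo` that takes
// its single argument as a vector; the demangler below turns such strings
// into VFInfo instances.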

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
/// <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
/// https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param M -> Module used to retrieve information about the vector
/// function that is not possible to retrieve from the mangled
/// name. At the moment, this parameter is needed only to retrieve the
/// Vectorization Factor of scalable vector functions from their
/// respective IR declarations.
std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
                                          const Module &M);

/// This routine mangles the given VectorName according to the LangRef
/// specification for the vector-function-abi-variant attribute and is specific
/// to the TLI mappings. It is the responsibility of the caller to make sure
/// that this is only used if all parameters in the vector function are of
/// vector type. The returned string holds the scalar-to-vector mapping:
/// _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
///
/// where:
///
/// <isa> = "_LLVM_"
/// <mask> = "M" if masked, "N" if no mask.
/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
///          field of the `VecDesc` struct. If the number of lanes is scalable
///          then 'x' is printed instead.
/// <vparams> = "v", repeated numArgs times.
/// <scalarname> = the name of the scalar function.
/// <vectorname> = the name of the vector function.
std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
                                unsigned numArgs, ElementCount VF,
                                bool Masked = false);

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. If the CI does not contain the
/// vector-function-abi-variant attribute, we return without populating
/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
/// the presence of the attribute (see InjectTLIMappings).
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated to a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar to vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated to the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated to the rules of
  /// a Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const std::optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape && (Shape->ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape->VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(*Shape);
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated to the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  static bool hasMaskedVariant(const CallInst &CI,
                               std::optional<ElementCount> VF = std::nullopt) {
    // Check whether we have at least one masked vector version of a scalar
    // function. If no VF is specified then we check for any masked variant,
    // otherwise we look for one that matches the supplied VF.
    auto Mappings = VFDatabase::getMappings(CI);
    for (const VFInfo &Info : Mappings)
      if (!VF || Info.Shape.VF == *VF)
        if (Info.isMasked())
          return true;

    return false;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};

template <typename T> class ArrayRef;
class DemandedBits;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}
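
// A minimal usage sketch, illustrative only (`CI` is assumed to be a CallInst
// whose callee carries VFABI variant attributes):
//
//   VFDatabase DB(CI);
//   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/false);
//   if (Function *VecF = DB.getVectorizedFunction(Shape)) {
//     // A 4-lane unmasked variant of the callee exists in the module; its
//     // return type can be widened accordingly:
//     Type *WideRetTy = ToVectorTy(CI.getType(), ElementCount::getFixed(4));
//   }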

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Returns the intrinsic ID for the call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if it does not find one, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constant vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal
/// to every other non-poisoned element. If an index element is specified,
/// either every element of the vector is poisoned or the element at that index
/// is not poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);
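
// Quick illustrations of the splat helpers above (values are illustrative
// only):
//
//   getSplatIndex({2, 2, 2, 2});  // returns 2
//   getSplatIndex({0, 1, -1, 3}); // returns -1 (mixed non-negative values)
//
// For getShuffleDemandedElts with SrcWidth = 2 and Mask = <1, 3>, demanding
// both output lanes demands lane 1 of the LHS and lane 1 of the RHS.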

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it
/// always succeeds because the indexes can always be multiplied (scaled up) to
/// map to narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice
/// is not the same value, return false (partial matches with sentinel values
/// are not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with the widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes a shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) does the mask analysis
/// to identify which real registers are permuted. Then the function processes
/// the resulting registers mask using the provided action items. If no input
/// register is defined, \p NoInputAction is used. If only 1 input register is
/// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used
/// to process 2 or more input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
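
// A sketch of the narrow/widen round trip documented above (illustrative;
// the values mirror the Scale = 4 examples):
//
//   SmallVector<int, 16> Narrow;
//   narrowShuffleMaskElts(4, {3, 2, 0, -1}, Narrow);
//   // Narrow == {12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1}
//
//   SmallVector<int, 4> Wide;
//   bool Succeeded = widenShuffleMaskElts(4, Narrow, Wide);
//   // Succeeded == true, Wide == {3, 2, 0, -1}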

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can
/// shrink the arithmetic type down again. However InstCombine refuses to
/// create illegal types, so for targets without i8 or i16 registers, the
/// lengthening and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);
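
// Illustrative check of the example above (ReplicationFactor = 3, VF = 4):
//
//   SmallVector<int, 16> RepMask = createReplicatedMask(3, 4);
//   // RepMask == {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}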

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start. The mask contains \p NumInts integers and is padded
/// with \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
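
// Illustrative only: the mask-creation helpers above return plain index
// vectors, e.g.
//
//   SmallVector<int, 16> IL = createInterleaveMask(4, 2);  // {0,4,1,5,2,6,3,7}
//   SmallVector<int, 16> Even = createStrideMask(0, 2, 4); // {0,2,4,6}
//   SmallVector<int, 16> Seq = createSequentialMask(0, 4, 4);
//   // Seq == {0, 1, 2, 3, -1, -1, -1, -1} (the padding lanes are undef)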

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is relative to the leader and it could
  /// be negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special
    // values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.find(Key) != Members.end())
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index
  ///
  /// \returns nullptr if it contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }
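
  // A small construction sketch (illustrative; `LoadA` and `LoadB` are assumed
  // to be the accesses of A[i] and A[i+1] from the example above):
  //
  //   InterleaveGroup<Instruction> G(LoadA, /*Stride=*/4, Align(4));
  //   G.insertMember(LoadB, /*Index=*/1, Align(4)); // member at index 1
  //   G.getMember(1);                               // LoadB
  //   G.getMember(2);                               // nullptr: a gap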

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a'la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore can't be a reversed access,
    // because such groups get invalidated (TODO).
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32         // Insert Position
  //      %add = add i32 %even     // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32           // Def of %odd
  //      store i32 %odd           // Insert Position
  InstTy *InsertPos;
};

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using the stride information
  /// collected by LoopAccessInfo.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
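
  // Typical driving sequence, as a sketch (illustrative; `PSE`, `L`, `DT`,
  // `LI`, `LAI` and `MemberAccess` are assumed to exist in the caller):
  //
  //   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
  //   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
  //   if (auto *Group = IAI.getInterleaveGroup(MemberAccess)) {
  //     unsigned Factor = Group->getFactor(); // MemberAccess is in a group.
  //   }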

  /// Invalidate groups, e.g., in case all blocks in loop will be predicated
  /// contrary to original assumption. Although we currently prevent group
  /// formation for predicated accesses, we may be able to relax this
  /// limitation in the future once we handle more complicated blocks. Returns
  /// true if any groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.contains(Instr);
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if it doesn't belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

  /// Returns true if we have any interleave groups.
  bool hasGroups() const { return !InterleaveGroups.empty(); }

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Alignment.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    InterleaveGroups.erase(Group);
    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const DenseMap<Value *, const SCEV *> &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering
  /// might violate a dependence between \p A and \p B, or if the dependence
  /// information needed to decide is unavailable.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    auto *Deps = LAI->getDepChecker().getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
  }
};

} // llvm namespace

#endif