//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that handle the elements one by one when the
// corresponding mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

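// Return true if every lane of the mask is a compile-time constant
// (ConstantInt), i.e. the set of enabled lanes is statically known.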
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

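// When an <N x i1> mask is bitcast to an iN integer, the mapping from lanes to
// bits depends on endianness: lane 0 is the least significant bit on
// little-endian targets and the most significant bit on big-endian targets,
// so the bit index has to be mirrored for big-endian layouts.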
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load each element individually if
// the appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
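  // For example, a <16 x i1> mask is first converted to a single integer:
  //   %scalar_mask = bitcast <16 x i1> %mask to i16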
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that store each element individually if
// the appropriate mask bit is set.
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load each element individually if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store each element individually if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

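// Translate a masked expandload intrinsic, e.g.
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks in which each enabled lane loads the next
// consecutive element from %ptr, and the pointer is advanced only when the
// corresponding mask bit is set. (The types above are only an illustrative
// example.)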
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary, and then
  // blend it with the pass-through value using a shufflevector.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

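// Translate a masked compressstore intrinsic, e.g.
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks in which each enabled lane stores its element to
// the next consecutive slot at %ptr, and the pointer is advanced only when
// the corresponding mask bit is set. (The types above are only an
// illustrative example.)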
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

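// Walk the function and scalarize every unsupported masked memory intrinsic.
// Whenever a scalarization splits blocks (and thus changes the dominator
// tree), block iteration is restarted on the updated function.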
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

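// Scan a single basic block and scalarize any masked memory intrinsic call
// that the target cannot handle natively. Returns true if the block was
// changed.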
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

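// Dispatch on the masked memory intrinsic ID and scalarize the call if the
// target reports the operation as illegal (or, for gathers and scatters,
// requests forced scalarization). Scalable vectors are left untouched.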
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}