//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics that the target does not support
// with a chain of basic blocks that deals with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

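// Return true if \p Mask is a fixed-width vector whose elements are all
// ConstantInt values, i.e. every lane of the mask is known at compile time.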
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

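// When the i1 mask vector is bitcast to an integer, the lane-to-bit mapping
// depends on endianness; flip the index on big-endian targets so that lane
// Idx tests the correct bit of the scalar mask.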
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that loads the elements one-by-one if
// the appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that stores the elements one-by-one if
// the appropriate mask bit is set.
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that loads the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that stores the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

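// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload(i32* %ptr, <16 x i1> %mask,
//                                    <16 x i32> %passthru)
// to a chain of basic blocks that load from consecutive memory locations,
// advancing the pointer only for the lanes whose mask bit is set and leaving
// the remaining lanes equal to %passthru.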
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

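// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore(<16 x i32> %src, i32* %ptr, <16 x i1> %mask)
// to a chain of basic blocks that store the enabled lanes of %src to
// consecutive memory locations, advancing the pointer only when the
// corresponding mask bit is set.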
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

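// Scan the function and scalarize every unsupported masked memory intrinsic,
// restarting the block walk whenever a scalarization split the CFG, and
// repeating until no more changes are made.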
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? DTU.getPointer() : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

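// Scalarize all unsupported masked memory intrinsics in a single basic block.
// Returns true if any call was replaced; ModifiedDT is set when new basic
// blocks were introduced, so the caller can restart its block iteration.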
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

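// If CI is a masked load/store/gather/scatter/expandload/compressstore that
// the target cannot handle natively (per TTI), replace it with an equivalent
// scalarized sequence. Scalable vectors are left untouched.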
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}