//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
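//
// A minimal usage sketch (assuming the standard legacy pass-manager setup):
// the pass can be added programmatically with
//   PM.add(createScalarizeMaskedMemIntrinPass());
// or requested from `opt` via the flag derived from DEBUG_TYPE below, e.g.
//   opt -scalarize-masked-mem-intrin -S in.ll -o out.ll
//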

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;
  const DataLayout *DL = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

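// Return true if \p Mask is a vector whose lanes are all compile-time
// constants, i.e. every element is a ConstantInt. Such masks let the
// scalarization below emit straight-line code for the enabled lanes instead
// of a chain of conditional blocks.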
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                               <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

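// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks that load from consecutive memory locations,
// consuming one location per enabled mask bit, roughly:
//
// %Mask0 = extractelement <16 x i1> %mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Load0 = load i32, i32* %ptr, align 1
// %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
// %NextPtr = getelementptr i32, i32* %ptr, i32 1
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
// %ptr.phi.else = phi i32* [ %NextPtr, %cond.load ], [ %ptr, %0 ]
// %Mask1 = extractelement <16 x i1> %mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
// . . .
// When the mask is a compile-time constant, the enabled lanes are instead
// loaded unconditionally and blended with %passthru via a shufflevector.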
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

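// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %src, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks that store the enabled elements to consecutive
// memory locations, advancing the destination pointer once per enabled mask
// bit, roughly:
//
// %Mask0 = extractelement <16 x i1> %mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %src, i32 0
// store i32 %Elt0, i32* %ptr, align 1
// %NextPtr = getelementptr i32, i32* %ptr, i32 1
// br label %else
//
// else:
// %ptr.phi.else = phi i32* [ %NextPtr, %cond.store ], [ %ptr, %0 ]
// %Mask1 = extractelement <16 x i1> %mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
// . . .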
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DL = &F.getParent()->getDataLayout();

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL->getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI->isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL->getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI->isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}