//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics that are not supported by the
// target with a chain of basic blocks that handles the elements one by one
// when the corresponding mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

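// Return true if \p Mask is a fixed vector whose elements are all
// ConstantInt, i.e. every lane of the mask is known at compile time.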
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ %passthru, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %val, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> %Src, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %Src, %0 ]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

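// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload(i32* %ptr, <16 x i1> %mask,
//                                    <16 x i32> %passthru)
// to a chain of basic blocks that load consecutive elements from %ptr and
// insert them into the result vector at the positions whose mask bit is set;
// the pointer is advanced only when an element is actually loaded. A sketch
// of the emitted IR, analogous to the masked-load lowering above:
//
//  %Mask0 = extractelement <16 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Load0 = load i32, i32* %ptr, align 1
//  %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
//  %NextPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
//  %ptr.phi.else = phi i32* [ %NextPtr, %cond.load ], [ %ptr, %0 ]
//  . . .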
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

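// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore(<16 x i32> %value, i32* %ptr,
//                                 <16 x i1> %mask)
// to a chain of basic blocks that store the selected elements of %value
// contiguously to %ptr, advancing the pointer only when the corresponding
// mask bit is set. A sketch of the emitted IR, analogous to the masked-store
// lowering above:
//
//  %Mask0 = extractelement <16 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %value, i32 0
//  store i32 %Elt0, i32* %ptr, align 1
//  %NextPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %ptr.phi.else = phi i32* [ %NextPtr, %cond.store ], [ %ptr, %0 ]
//  . . .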
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr = nullptr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

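// Scalarize CI if it is a masked memory intrinsic the target does not support
// natively. Returns true if CI was replaced; ModifiedDT is set when the
// scalarization introduced new basic blocks and thus changed the dominator
// tree.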
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    unsigned Alignment;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_store: {
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
                                  MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_gather:
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedGather(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType(),
                                    MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}