1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that handle the elements one by one, acting on
// an element only when the corresponding mask bit is set.
13 //
14 //===----------------------------------------------------------------------===//
15 
#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include <algorithm>
#include <cassert>
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
42 
43 namespace {
44 
45 class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
46 public:
47   static char ID; // Pass identification, replacement for typeid
48 
ScalarizeMaskedMemIntrinLegacyPass()49   explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
50     initializeScalarizeMaskedMemIntrinLegacyPassPass(
51         *PassRegistry::getPassRegistry());
52   }
53 
54   bool runOnFunction(Function &F) override;
55 
getPassName() const56   StringRef getPassName() const override {
57     return "Scalarize Masked Memory Intrinsics";
58   }
59 
getAnalysisUsage(AnalysisUsage & AU) const60   void getAnalysisUsage(AnalysisUsage &AU) const override {
61     AU.addRequired<TargetTransformInfoWrapperPass>();
62   }
63 };
64 
65 } // end anonymous namespace
66 
67 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
68                           const TargetTransformInfo &TTI, const DataLayout &DL);
69 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
70                              const TargetTransformInfo &TTI,
71                              const DataLayout &DL);
72 
73 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
74 
75 INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
76                       "Scalarize unsupported masked memory intrinsics", false,
77                       false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)78 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
79 INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
80                     "Scalarize unsupported masked memory intrinsics", false,
81                     false)
82 
83 FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
84   return new ScalarizeMaskedMemIntrinLegacyPass();
85 }
86 
isConstantIntVector(Value * Mask)87 static bool isConstantIntVector(Value *Mask) {
88   Constant *C = dyn_cast<Constant>(Mask);
89   if (!C)
90     return false;
91 
92   unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
93   for (unsigned i = 0; i != NumElts; ++i) {
94     Constant *CElt = C->getAggregateElement(i);
95     if (!CElt || !isa<ConstantInt>(CElt))
96       return false;
97   }
98 
99   return true;
100 }
101 
102 // Translate a masked load intrinsic like
103 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
104 //                               <16 x i1> %mask, <16 x i32> %passthru)
105 // to a chain of basic blocks, with loading element one-by-one if
106 // the appropriate mask bit is set
107 //
108 //  %1 = bitcast i8* %addr to i32*
109 //  %2 = extractelement <16 x i1> %mask, i32 0
110 //  br i1 %2, label %cond.load, label %else
111 //
112 // cond.load:                                        ; preds = %0
113 //  %3 = getelementptr i32* %1, i32 0
114 //  %4 = load i32* %3
115 //  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
116 //  br label %else
117 //
118 // else:                                             ; preds = %0, %cond.load
119 //  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
120 //  %6 = extractelement <16 x i1> %mask, i32 1
121 //  br i1 %6, label %cond.load1, label %else2
122 //
123 // cond.load1:                                       ; preds = %else
124 //  %7 = getelementptr i32* %1, i32 1
125 //  %8 = load i32* %7
126 //  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
127 //  br label %else2
128 //
129 // else2:                                          ; preds = %else, %cond.load1
130 //  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
131 //  %10 = extractelement <16 x i1> %mask, i32 2
132 //  br i1 %10, label %cond.load4, label %else5
133 //
scalarizeMaskedLoad(CallInst * CI,bool & ModifiedDT)134 static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
135   Value *Ptr = CI->getArgOperand(0);
136   Value *Alignment = CI->getArgOperand(1);
137   Value *Mask = CI->getArgOperand(2);
138   Value *Src0 = CI->getArgOperand(3);
139 
140   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
141   VectorType *VecType = cast<FixedVectorType>(CI->getType());
142 
143   Type *EltTy = VecType->getElementType();
144 
145   IRBuilder<> Builder(CI->getContext());
146   Instruction *InsertPt = CI;
147   BasicBlock *IfBlock = CI->getParent();
148 
149   Builder.SetInsertPoint(InsertPt);
150   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
151 
152   // Short-cut if the mask is all-true.
153   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
154     Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
155     CI->replaceAllUsesWith(NewI);
156     CI->eraseFromParent();
157     return;
158   }
159 
160   // Adjust alignment for the scalar instruction.
161   const Align AdjustedAlignVal =
162       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
163   // Bitcast %addr from i8* to EltTy*
164   Type *NewPtrType =
165       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
166   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
167   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
168 
169   // The result vector
170   Value *VResult = Src0;
171 
172   if (isConstantIntVector(Mask)) {
173     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
174       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
175         continue;
176       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
177       LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
178       VResult = Builder.CreateInsertElement(VResult, Load, Idx);
179     }
180     CI->replaceAllUsesWith(VResult);
181     CI->eraseFromParent();
182     return;
183   }
184 
185   // If the mask is not v1i1, use scalar bit test operations. This generates
186   // better results on X86 at least.
187   Value *SclrMask;
188   if (VectorWidth != 1) {
189     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
190     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
191   }
192 
193   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
194     // Fill the "else" block, created in the previous iteration
195     //
196     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
197     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
198     //  %cond = icmp ne i16 %mask_1, 0
199     //  br i1 %mask_1, label %cond.load, label %else
200     //
201     Value *Predicate;
202     if (VectorWidth != 1) {
203       Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
204       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
205                                        Builder.getIntN(VectorWidth, 0));
206     } else {
207       Predicate = Builder.CreateExtractElement(Mask, Idx);
208     }
209 
210     // Create "cond" block
211     //
212     //  %EltAddr = getelementptr i32* %1, i32 0
213     //  %Elt = load i32* %EltAddr
214     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
215     //
216     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
217                                                      "cond.load");
218     Builder.SetInsertPoint(InsertPt);
219 
220     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
221     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
222     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
223 
224     // Create "else" block, fill it in the next iteration
225     BasicBlock *NewIfBlock =
226         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
227     Builder.SetInsertPoint(InsertPt);
228     Instruction *OldBr = IfBlock->getTerminator();
229     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
230     OldBr->eraseFromParent();
231     BasicBlock *PrevIfBlock = IfBlock;
232     IfBlock = NewIfBlock;
233 
234     // Create the phi to join the new and previous value.
235     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
236     Phi->addIncoming(NewVResult, CondBlock);
237     Phi->addIncoming(VResult, PrevIfBlock);
238     VResult = Phi;
239   }
240 
241   CI->replaceAllUsesWith(VResult);
242   CI->eraseFromParent();
243 
244   ModifiedDT = true;
245 }
246 
247 // Translate a masked store intrinsic, like
248 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
249 //                               <16 x i1> %mask)
250 // to a chain of basic blocks, that stores element one-by-one if
251 // the appropriate mask bit is set
252 //
253 //   %1 = bitcast i8* %addr to i32*
254 //   %2 = extractelement <16 x i1> %mask, i32 0
255 //   br i1 %2, label %cond.store, label %else
256 //
257 // cond.store:                                       ; preds = %0
258 //   %3 = extractelement <16 x i32> %val, i32 0
259 //   %4 = getelementptr i32* %1, i32 0
260 //   store i32 %3, i32* %4
261 //   br label %else
262 //
263 // else:                                             ; preds = %0, %cond.store
264 //   %5 = extractelement <16 x i1> %mask, i32 1
265 //   br i1 %5, label %cond.store1, label %else2
266 //
267 // cond.store1:                                      ; preds = %else
268 //   %6 = extractelement <16 x i32> %val, i32 1
269 //   %7 = getelementptr i32* %1, i32 1
270 //   store i32 %6, i32* %7
271 //   br label %else2
272 //   . . .
scalarizeMaskedStore(CallInst * CI,bool & ModifiedDT)273 static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
274   Value *Src = CI->getArgOperand(0);
275   Value *Ptr = CI->getArgOperand(1);
276   Value *Alignment = CI->getArgOperand(2);
277   Value *Mask = CI->getArgOperand(3);
278 
279   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
280   auto *VecType = cast<VectorType>(Src->getType());
281 
282   Type *EltTy = VecType->getElementType();
283 
284   IRBuilder<> Builder(CI->getContext());
285   Instruction *InsertPt = CI;
286   BasicBlock *IfBlock = CI->getParent();
287   Builder.SetInsertPoint(InsertPt);
288   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
289 
290   // Short-cut if the mask is all-true.
291   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
292     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
293     CI->eraseFromParent();
294     return;
295   }
296 
297   // Adjust alignment for the scalar instruction.
298   const Align AdjustedAlignVal =
299       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
300   // Bitcast %addr from i8* to EltTy*
301   Type *NewPtrType =
302       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
303   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
304   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
305 
306   if (isConstantIntVector(Mask)) {
307     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
308       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
309         continue;
310       Value *OneElt = Builder.CreateExtractElement(Src, Idx);
311       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
312       Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
313     }
314     CI->eraseFromParent();
315     return;
316   }
317 
318   // If the mask is not v1i1, use scalar bit test operations. This generates
319   // better results on X86 at least.
320   Value *SclrMask;
321   if (VectorWidth != 1) {
322     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
323     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
324   }
325 
326   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
327     // Fill the "else" block, created in the previous iteration
328     //
329     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
330     //  %cond = icmp ne i16 %mask_1, 0
331     //  br i1 %mask_1, label %cond.store, label %else
332     //
333     Value *Predicate;
334     if (VectorWidth != 1) {
335       Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
336       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
337                                        Builder.getIntN(VectorWidth, 0));
338     } else {
339       Predicate = Builder.CreateExtractElement(Mask, Idx);
340     }
341 
342     // Create "cond" block
343     //
344     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
345     //  %EltAddr = getelementptr i32* %1, i32 0
346     //  %store i32 %OneElt, i32* %EltAddr
347     //
348     BasicBlock *CondBlock =
349         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
350     Builder.SetInsertPoint(InsertPt);
351 
352     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
353     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
354     Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
355 
356     // Create "else" block, fill it in the next iteration
357     BasicBlock *NewIfBlock =
358         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
359     Builder.SetInsertPoint(InsertPt);
360     Instruction *OldBr = IfBlock->getTerminator();
361     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
362     OldBr->eraseFromParent();
363     IfBlock = NewIfBlock;
364   }
365   CI->eraseFromParent();
366 
367   ModifiedDT = true;
368 }
369 
370 // Translate a masked gather intrinsic like
371 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
372 //                               <16 x i1> %Mask, <16 x i32> %Src)
373 // to a chain of basic blocks, with loading element one-by-one if
374 // the appropriate mask bit is set
375 //
376 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
377 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
378 // br i1 %Mask0, label %cond.load, label %else
379 //
380 // cond.load:
381 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
382 // %Load0 = load i32, i32* %Ptr0, align 4
383 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
384 // br label %else
385 //
386 // else:
387 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
388 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
389 // br i1 %Mask1, label %cond.load1, label %else2
390 //
391 // cond.load1:
392 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
393 // %Load1 = load i32, i32* %Ptr1, align 4
394 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
395 // br label %else2
396 // . . .
397 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
398 // ret <16 x i32> %Result
scalarizeMaskedGather(CallInst * CI,bool & ModifiedDT)399 static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
400   Value *Ptrs = CI->getArgOperand(0);
401   Value *Alignment = CI->getArgOperand(1);
402   Value *Mask = CI->getArgOperand(2);
403   Value *Src0 = CI->getArgOperand(3);
404 
405   auto *VecType = cast<FixedVectorType>(CI->getType());
406   Type *EltTy = VecType->getElementType();
407 
408   IRBuilder<> Builder(CI->getContext());
409   Instruction *InsertPt = CI;
410   BasicBlock *IfBlock = CI->getParent();
411   Builder.SetInsertPoint(InsertPt);
412   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
413 
414   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
415 
416   // The result vector
417   Value *VResult = Src0;
418   unsigned VectorWidth = VecType->getNumElements();
419 
420   // Shorten the way if the mask is a vector of constants.
421   if (isConstantIntVector(Mask)) {
422     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
423       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
424         continue;
425       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
426       LoadInst *Load =
427           Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
428       VResult =
429           Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
430     }
431     CI->replaceAllUsesWith(VResult);
432     CI->eraseFromParent();
433     return;
434   }
435 
436   // If the mask is not v1i1, use scalar bit test operations. This generates
437   // better results on X86 at least.
438   Value *SclrMask;
439   if (VectorWidth != 1) {
440     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
441     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
442   }
443 
444   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
445     // Fill the "else" block, created in the previous iteration
446     //
447     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
448     //  %cond = icmp ne i16 %mask_1, 0
449     //  br i1 %Mask1, label %cond.load, label %else
450     //
451 
452     Value *Predicate;
453     if (VectorWidth != 1) {
454       Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
455       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
456                                        Builder.getIntN(VectorWidth, 0));
457     } else {
458       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
459     }
460 
461     // Create "cond" block
462     //
463     //  %EltAddr = getelementptr i32* %1, i32 0
464     //  %Elt = load i32* %EltAddr
465     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
466     //
467     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
468     Builder.SetInsertPoint(InsertPt);
469 
470     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
471     LoadInst *Load =
472         Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
473     Value *NewVResult =
474         Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
475 
476     // Create "else" block, fill it in the next iteration
477     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
478     Builder.SetInsertPoint(InsertPt);
479     Instruction *OldBr = IfBlock->getTerminator();
480     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
481     OldBr->eraseFromParent();
482     BasicBlock *PrevIfBlock = IfBlock;
483     IfBlock = NewIfBlock;
484 
485     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
486     Phi->addIncoming(NewVResult, CondBlock);
487     Phi->addIncoming(VResult, PrevIfBlock);
488     VResult = Phi;
489   }
490 
491   CI->replaceAllUsesWith(VResult);
492   CI->eraseFromParent();
493 
494   ModifiedDT = true;
495 }
496 
497 // Translate a masked scatter intrinsic, like
498 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
499 //                                  <16 x i1> %Mask)
500 // to a chain of basic blocks, that stores element one-by-one if
501 // the appropriate mask bit is set.
502 //
503 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
504 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
505 // br i1 %Mask0, label %cond.store, label %else
506 //
507 // cond.store:
508 // %Elt0 = extractelement <16 x i32> %Src, i32 0
509 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
510 // store i32 %Elt0, i32* %Ptr0, align 4
511 // br label %else
512 //
513 // else:
514 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
515 // br i1 %Mask1, label %cond.store1, label %else2
516 //
517 // cond.store1:
518 // %Elt1 = extractelement <16 x i32> %Src, i32 1
519 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
520 // store i32 %Elt1, i32* %Ptr1, align 4
521 // br label %else2
522 //   . . .
scalarizeMaskedScatter(CallInst * CI,bool & ModifiedDT)523 static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
524   Value *Src = CI->getArgOperand(0);
525   Value *Ptrs = CI->getArgOperand(1);
526   Value *Alignment = CI->getArgOperand(2);
527   Value *Mask = CI->getArgOperand(3);
528 
529   auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
530 
531   assert(
532       isa<VectorType>(Ptrs->getType()) &&
533       isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
534       "Vector of pointers is expected in masked scatter intrinsic");
535 
536   IRBuilder<> Builder(CI->getContext());
537   Instruction *InsertPt = CI;
538   BasicBlock *IfBlock = CI->getParent();
539   Builder.SetInsertPoint(InsertPt);
540   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
541 
542   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
543   unsigned VectorWidth = SrcFVTy->getNumElements();
544 
545   // Shorten the way if the mask is a vector of constants.
546   if (isConstantIntVector(Mask)) {
547     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
548       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
549         continue;
550       Value *OneElt =
551           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
552       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
553       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
554     }
555     CI->eraseFromParent();
556     return;
557   }
558 
559   // If the mask is not v1i1, use scalar bit test operations. This generates
560   // better results on X86 at least.
561   Value *SclrMask;
562   if (VectorWidth != 1) {
563     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
564     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
565   }
566 
567   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
568     // Fill the "else" block, created in the previous iteration
569     //
570     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
571     //  %cond = icmp ne i16 %mask_1, 0
572     //  br i1 %Mask1, label %cond.store, label %else
573     //
574     Value *Predicate;
575     if (VectorWidth != 1) {
576       Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
577       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
578                                        Builder.getIntN(VectorWidth, 0));
579     } else {
580       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
581     }
582 
583     // Create "cond" block
584     //
585     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
586     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
587     //  %store i32 %Elt1, i32* %Ptr1
588     //
589     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
590     Builder.SetInsertPoint(InsertPt);
591 
592     Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
593     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
594     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
595 
596     // Create "else" block, fill it in the next iteration
597     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
598     Builder.SetInsertPoint(InsertPt);
599     Instruction *OldBr = IfBlock->getTerminator();
600     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
601     OldBr->eraseFromParent();
602     IfBlock = NewIfBlock;
603   }
604   CI->eraseFromParent();
605 
606   ModifiedDT = true;
607 }
608 
scalarizeMaskedExpandLoad(CallInst * CI,bool & ModifiedDT)609 static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
610   Value *Ptr = CI->getArgOperand(0);
611   Value *Mask = CI->getArgOperand(1);
612   Value *PassThru = CI->getArgOperand(2);
613 
614   auto *VecType = cast<FixedVectorType>(CI->getType());
615 
616   Type *EltTy = VecType->getElementType();
617 
618   IRBuilder<> Builder(CI->getContext());
619   Instruction *InsertPt = CI;
620   BasicBlock *IfBlock = CI->getParent();
621 
622   Builder.SetInsertPoint(InsertPt);
623   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
624 
625   unsigned VectorWidth = VecType->getNumElements();
626 
627   // The result vector
628   Value *VResult = PassThru;
629 
630   // Shorten the way if the mask is a vector of constants.
631   // Create a build_vector pattern, with loads/undefs as necessary and then
632   // shuffle blend with the pass through value.
633   if (isConstantIntVector(Mask)) {
634     unsigned MemIndex = 0;
635     VResult = UndefValue::get(VecType);
636     SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
637     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
638       Value *InsertElt;
639       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
640         InsertElt = UndefValue::get(EltTy);
641         ShuffleMask[Idx] = Idx + VectorWidth;
642       } else {
643         Value *NewPtr =
644             Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
645         InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
646                                               "Load" + Twine(Idx));
647         ShuffleMask[Idx] = Idx;
648         ++MemIndex;
649       }
650       VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
651                                             "Res" + Twine(Idx));
652     }
653     VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
654     CI->replaceAllUsesWith(VResult);
655     CI->eraseFromParent();
656     return;
657   }
658 
659   // If the mask is not v1i1, use scalar bit test operations. This generates
660   // better results on X86 at least.
661   Value *SclrMask;
662   if (VectorWidth != 1) {
663     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
664     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
665   }
666 
667   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
668     // Fill the "else" block, created in the previous iteration
669     //
670     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
671     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
672     //  br i1 %mask_1, label %cond.load, label %else
673     //
674 
675     Value *Predicate;
676     if (VectorWidth != 1) {
677       Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
678       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
679                                        Builder.getIntN(VectorWidth, 0));
680     } else {
681       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
682     }
683 
684     // Create "cond" block
685     //
686     //  %EltAddr = getelementptr i32* %1, i32 0
687     //  %Elt = load i32* %EltAddr
688     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
689     //
690     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
691                                                      "cond.load");
692     Builder.SetInsertPoint(InsertPt);
693 
694     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
695     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
696 
697     // Move the pointer if there are more blocks to come.
698     Value *NewPtr;
699     if ((Idx + 1) != VectorWidth)
700       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
701 
702     // Create "else" block, fill it in the next iteration
703     BasicBlock *NewIfBlock =
704         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
705     Builder.SetInsertPoint(InsertPt);
706     Instruction *OldBr = IfBlock->getTerminator();
707     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
708     OldBr->eraseFromParent();
709     BasicBlock *PrevIfBlock = IfBlock;
710     IfBlock = NewIfBlock;
711 
712     // Create the phi to join the new and previous value.
713     PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
714     ResultPhi->addIncoming(NewVResult, CondBlock);
715     ResultPhi->addIncoming(VResult, PrevIfBlock);
716     VResult = ResultPhi;
717 
718     // Add a PHI for the pointer if this isn't the last iteration.
719     if ((Idx + 1) != VectorWidth) {
720       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
721       PtrPhi->addIncoming(NewPtr, CondBlock);
722       PtrPhi->addIncoming(Ptr, PrevIfBlock);
723       Ptr = PtrPhi;
724     }
725   }
726 
727   CI->replaceAllUsesWith(VResult);
728   CI->eraseFromParent();
729 
730   ModifiedDT = true;
731 }
732 
scalarizeMaskedCompressStore(CallInst * CI,bool & ModifiedDT)733 static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
734   Value *Src = CI->getArgOperand(0);
735   Value *Ptr = CI->getArgOperand(1);
736   Value *Mask = CI->getArgOperand(2);
737 
738   auto *VecType = cast<FixedVectorType>(Src->getType());
739 
740   IRBuilder<> Builder(CI->getContext());
741   Instruction *InsertPt = CI;
742   BasicBlock *IfBlock = CI->getParent();
743 
744   Builder.SetInsertPoint(InsertPt);
745   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
746 
747   Type *EltTy = VecType->getElementType();
748 
749   unsigned VectorWidth = VecType->getNumElements();
750 
751   // Shorten the way if the mask is a vector of constants.
752   if (isConstantIntVector(Mask)) {
753     unsigned MemIndex = 0;
754     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
755       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
756         continue;
757       Value *OneElt =
758           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
759       Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
760       Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
761       ++MemIndex;
762     }
763     CI->eraseFromParent();
764     return;
765   }
766 
767   // If the mask is not v1i1, use scalar bit test operations. This generates
768   // better results on X86 at least.
769   Value *SclrMask;
770   if (VectorWidth != 1) {
771     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
772     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
773   }
774 
775   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
776     // Fill the "else" block, created in the previous iteration
777     //
778     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
779     //  br i1 %mask_1, label %cond.store, label %else
780     //
781     Value *Predicate;
782     if (VectorWidth != 1) {
783       Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
784       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
785                                        Builder.getIntN(VectorWidth, 0));
786     } else {
787       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
788     }
789 
790     // Create "cond" block
791     //
792     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
793     //  %EltAddr = getelementptr i32* %1, i32 0
794     //  %store i32 %OneElt, i32* %EltAddr
795     //
796     BasicBlock *CondBlock =
797         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
798     Builder.SetInsertPoint(InsertPt);
799 
800     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
801     Builder.CreateAlignedStore(OneElt, Ptr, Align(1));
802 
803     // Move the pointer if there are more blocks to come.
804     Value *NewPtr;
805     if ((Idx + 1) != VectorWidth)
806       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
807 
808     // Create "else" block, fill it in the next iteration
809     BasicBlock *NewIfBlock =
810         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
811     Builder.SetInsertPoint(InsertPt);
812     Instruction *OldBr = IfBlock->getTerminator();
813     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
814     OldBr->eraseFromParent();
815     BasicBlock *PrevIfBlock = IfBlock;
816     IfBlock = NewIfBlock;
817 
818     // Add a PHI for the pointer if this isn't the last iteration.
819     if ((Idx + 1) != VectorWidth) {
820       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
821       PtrPhi->addIncoming(NewPtr, CondBlock);
822       PtrPhi->addIncoming(Ptr, PrevIfBlock);
823       Ptr = PtrPhi;
824     }
825   }
826   CI->eraseFromParent();
827 
828   ModifiedDT = true;
829 }
830 
runImpl(Function & F,const TargetTransformInfo & TTI)831 static bool runImpl(Function &F, const TargetTransformInfo &TTI) {
832   bool EverMadeChange = false;
833   bool MadeChange = true;
834   auto &DL = F.getParent()->getDataLayout();
835   while (MadeChange) {
836     MadeChange = false;
837     for (Function::iterator I = F.begin(); I != F.end();) {
838       BasicBlock *BB = &*I++;
839       bool ModifiedDTOnIteration = false;
840       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL);
841 
842       // Restart BB iteration if the dominator tree of the Function was changed
843       if (ModifiedDTOnIteration)
844         break;
845     }
846 
847     EverMadeChange |= MadeChange;
848   }
849   return EverMadeChange;
850 }
851 
runOnFunction(Function & F)852 bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
853   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
854   return runImpl(F, TTI);
855 }
856 
857 PreservedAnalyses
run(Function & F,FunctionAnalysisManager & AM)858 ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
859   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
860   if (!runImpl(F, TTI))
861     return PreservedAnalyses::all();
862   PreservedAnalyses PA;
863   PA.preserve<TargetIRAnalysis>();
864   return PA;
865 }
866 
optimizeBlock(BasicBlock & BB,bool & ModifiedDT,const TargetTransformInfo & TTI,const DataLayout & DL)867 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
868                           const TargetTransformInfo &TTI,
869                           const DataLayout &DL) {
870   bool MadeChange = false;
871 
872   BasicBlock::iterator CurInstIterator = BB.begin();
873   while (CurInstIterator != BB.end()) {
874     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
875       MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL);
876     if (ModifiedDT)
877       return true;
878   }
879 
880   return MadeChange;
881 }
882 
optimizeCallInst(CallInst * CI,bool & ModifiedDT,const TargetTransformInfo & TTI,const DataLayout & DL)883 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
884                              const TargetTransformInfo &TTI,
885                              const DataLayout &DL) {
886   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
887   if (II) {
888     // The scalarization code below does not work for scalable vectors.
889     if (isa<ScalableVectorType>(II->getType()) ||
890         any_of(II->arg_operands(),
891                [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
892       return false;
893 
894     switch (II->getIntrinsicID()) {
895     default:
896       break;
897     case Intrinsic::masked_load:
898       // Scalarize unsupported vector masked load
899       if (TTI.isLegalMaskedLoad(
900               CI->getType(),
901               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
902         return false;
903       scalarizeMaskedLoad(CI, ModifiedDT);
904       return true;
905     case Intrinsic::masked_store:
906       if (TTI.isLegalMaskedStore(
907               CI->getArgOperand(0)->getType(),
908               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
909         return false;
910       scalarizeMaskedStore(CI, ModifiedDT);
911       return true;
912     case Intrinsic::masked_gather: {
913       unsigned AlignmentInt =
914           cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
915       Type *LoadTy = CI->getType();
916       Align Alignment =
917           DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
918       if (TTI.isLegalMaskedGather(LoadTy, Alignment))
919         return false;
920       scalarizeMaskedGather(CI, ModifiedDT);
921       return true;
922     }
923     case Intrinsic::masked_scatter: {
924       unsigned AlignmentInt =
925           cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
926       Type *StoreTy = CI->getArgOperand(0)->getType();
927       Align Alignment =
928           DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
929       if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
930         return false;
931       scalarizeMaskedScatter(CI, ModifiedDT);
932       return true;
933     }
934     case Intrinsic::masked_expandload:
935       if (TTI.isLegalMaskedExpandLoad(CI->getType()))
936         return false;
937       scalarizeMaskedExpandLoad(CI, ModifiedDT);
938       return true;
939     case Intrinsic::masked_compressstore:
940       if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
941         return false;
942       scalarizeMaskedCompressStore(CI, ModifiedDT);
943       return true;
944     }
945   }
946 
947   return false;
948 }
949