//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that handle the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;
  const DataLayout *DL = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

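// With the legacy pass manager this pass can be exercised in isolation, e.g.:
//   opt -scalarize-masked-mem-intrin -S input.ll
// (the command-line name is DEBUG_TYPE, registered by INITIALIZE_PASS above).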
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

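// Return true if every lane of the mask is a compile-time constant (i.e. the
// mask is a vector of ConstantInts). In that case the expansions below can
// avoid emitting control flow and simply skip the known-inactive lanes.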
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                       ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                            ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                      ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                           ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                      ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                            ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                     ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

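// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// into the same kind of block chain as the masked-load expansion above. The
// key difference is that the pointer advances by one element only after an
// active lane, so consecutive memory locations are expanded into the active
// lanes. A sketch of the emitted IR (value names are illustrative):
//
//  %Mask0 = extractelement <16 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Load0 = load i32, i32* %ptr, align 1
//  %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
//  %NewPtr0 = getelementptr inbounds i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
//  %ptr.phi.else = phi i32* [ %NewPtr0, %cond.load ], [ %ptr, %0 ]
//  . . .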
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary, and then
  // shuffle-blend it with the pass-through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

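// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// into a block chain that stores the active elements to consecutive memory
// locations, advancing the pointer only after an active lane. A sketch of the
// emitted IR (value names are illustrative):
//
//  %Mask0 = extractelement <16 x i1> %mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %value, i32 0
//  store i32 %Elt0, i32* %ptr, align 1
//  %NewPtr0 = getelementptr inbounds i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %ptr.phi.else = phi i32* [ %NewPtr0, %cond.store ], [ %ptr, %0 ]
//  . . .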
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DL = &F.getParent()->getDataLayout();

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    // Advance the iterator before visiting, since optimizeCallInst may erase
    // the instruction and split the block.
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL->getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI->isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL->getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI->isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}