//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that handle the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

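// Return true if \p Mask is a fixed-width vector whose lanes are all
// compile-time constant integers, i.e. every mask bit is statically known.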
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set.
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.load, label %else
//
// cond.load: ; preds = %0
// %3 = getelementptr i32* %1, i32 0
// %4 = load i32* %3
// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
// br label %else
//
// else: ; preds = %0, %cond.load
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
// %6 = extractelement <16 x i1> %mask, i32 1
// br i1 %6, label %cond.load1, label %else2
//
// cond.load1: ; preds = %else
// %7 = getelementptr i32* %1, i32 1
// %8 = load i32* %7
// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
// br label %else2
//
// else2: ; preds = %else, %cond.load1
// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %mask_1, 0
    // br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.store, label %else
//
// cond.store: ; preds = %0
// %3 = extractelement <16 x i32> %val, i32 0
// %4 = getelementptr i32* %1, i32 0
// store i32 %3, i32* %4
// br label %else
//
// else: ; preds = %0, %cond.store
// %5 = extractelement <16 x i1> %mask, i32 1
// br i1 %5, label %cond.store1, label %else2
//
// cond.store1: ; preds = %else
// %6 = extractelement <16 x i32> %val, i32 1
// %7 = getelementptr i32* %1, i32 1
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %mask_1, 0
    // br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %Mask1, 0
    // br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %Mask1, 0
    // br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %Elt1 = extractelement <16 x i32> %Src, i32 1
    // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

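// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks that read from consecutive memory locations,
// advancing the pointer only past lanes whose mask bit is set; lanes with a
// clear mask bit take their value from %passthru. Because the pointer only
// advances after an active lane, the scalarized loop below also carries a
// pointer PHI between iterations. Roughly (sketch only):
//
// cond.load:
// %Load0 = load i32, i32* %ptr, align 1
// %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
// %NextPtr = getelementptr i32, i32* %ptr, i32 1
// br label %else
// . . .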
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

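// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %src, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks that write the elements whose mask bit is set
// to consecutive memory locations starting at %ptr. As with expandload, the
// pointer only advances after an active lane, so the loop carries a pointer
// PHI between iterations. Roughly (sketch only):
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %src, i32 0
// store i32 %Elt0, i32* %ptr, align 1
// %NextPtr = getelementptr i32, i32* %ptr, i32 1
// br label %else
// . . .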
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

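// Walk the function repeatedly, scalarizing any masked memory intrinsic the
// target reports as unsupported. Scalarization splits basic blocks, so the
// block iteration is restarted whenever the dominator tree changes.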
static bool runImpl(Function &F, const TargetTransformInfo &TTI) {
  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  return runImpl(F, TTI);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  if (!runImpl(F, TTI))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI,
                          const DataLayout &DL) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}