1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
// intrinsics
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass replaces masked memory intrinsics - when unsupported by the target
11 // - with a chain of basic blocks, that deal with the elements one-by-one if the
12 // appropriate mask bit is set.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/CodeGen/TargetSubtargetInfo.h"
19 #include "llvm/IR/BasicBlock.h"
20 #include "llvm/IR/Constant.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DerivedTypes.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/InstrTypes.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Casting.h"
34 #include <algorithm>
35 #include <cassert>
36
37 using namespace llvm;
38
39 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
40
41 namespace {
42
/// Legacy (opt) pass that replaces masked memory intrinsics the target cannot
/// execute natively with equivalent scalarized IR (per-element loads/stores
/// guarded by the mask bits).
class ScalarizeMaskedMemIntrin : public FunctionPass {
  // Queried per function to decide which masked intrinsics are legal for the
  // target; set in runOnFunction.
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  // Scan one block for masked intrinsic calls; ModifiedDT is set when the
  // scalarization split basic blocks (the CFG changed).
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};
67
68 } // end anonymous namespace
69
char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

// Factory used by the legacy pass manager to instantiate this pass.
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
78
isConstantIntVector(Value * Mask)79 static bool isConstantIntVector(Value *Mask) {
80 Constant *C = dyn_cast<Constant>(Mask);
81 if (!C)
82 return false;
83
84 unsigned NumElts = Mask->getType()->getVectorNumElements();
85 for (unsigned i = 0; i != NumElts; ++i) {
86 Constant *CElt = C->getAggregateElement(i);
87 if (!CElt || !isa<ConstantInt>(CElt))
88 return false;
89 }
90
91 return true;
92 }
93
94 // Translate a masked load intrinsic like
95 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
96 // <16 x i1> %mask, <16 x i32> %passthru)
97 // to a chain of basic blocks, with loading element one-by-one if
98 // the appropriate mask bit is set
99 //
100 // %1 = bitcast i8* %addr to i32*
101 // %2 = extractelement <16 x i1> %mask, i32 0
102 // br i1 %2, label %cond.load, label %else
103 //
104 // cond.load: ; preds = %0
105 // %3 = getelementptr i32* %1, i32 0
106 // %4 = load i32* %3
107 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
108 // br label %else
109 //
110 // else: ; preds = %0, %cond.load
111 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
112 // %6 = extractelement <16 x i1> %mask, i32 1
113 // br i1 %6, label %cond.load1, label %else2
114 //
115 // cond.load1: ; preds = %else
116 // %7 = getelementptr i32* %1, i32 1
117 // %8 = load i32* %7
118 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
119 // br label %else2
120 //
121 // else2: ; preds = %else, %cond.load1
122 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
123 // %10 = extractelement <16 x i1> %mask, i32 2
124 // br i1 %10, label %cond.load4, label %else5
125 //
// Expand a @llvm.masked.load call (args: ptr, alignment, mask, passthru).
// An all-true mask becomes one plain vector load; a constant mask becomes
// straight-line scalar loads for the enabled lanes; otherwise the
// cond.load/else block chain shown in the diagram above is emitted and
// ModifiedDT is set because the CFG changed.
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  // Compile-time-known mask: no control flow needed, just load the enabled
  // lanes and insert them into the passthru value.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch; replace it with a branch
    // on the mask bit so CondBlock only runs when the lane is enabled.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
222
223 // Translate a masked store intrinsic, like
224 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
225 // <16 x i1> %mask)
226 // to a chain of basic blocks, that stores element one-by-one if
227 // the appropriate mask bit is set
228 //
229 // %1 = bitcast i8* %addr to i32*
230 // %2 = extractelement <16 x i1> %mask, i32 0
231 // br i1 %2, label %cond.store, label %else
232 //
233 // cond.store: ; preds = %0
234 // %3 = extractelement <16 x i32> %val, i32 0
235 // %4 = getelementptr i32* %1, i32 0
236 // store i32 %3, i32* %4
237 // br label %else
238 //
239 // else: ; preds = %0, %cond.store
240 // %5 = extractelement <16 x i1> %mask, i32 1
241 // br i1 %5, label %cond.store1, label %else2
242 //
243 // cond.store1: ; preds = %else
244 // %6 = extractelement <16 x i32> %val, i32 1
245 // %7 = getelementptr i32* %1, i32 1
246 // store i32 %6, i32* %7
247 // br label %else2
248 // . . .
// Expand a @llvm.masked.store call (args: value, ptr, alignment, mask).
// An all-true mask becomes one plain vector store; a constant mask becomes
// straight-line scalar stores for the enabled lanes; otherwise the
// cond.store/else block chain shown in the diagram above is emitted and
// ModifiedDT is set because the CFG changed.
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // Compile-time-known mask: emit unconditional stores for enabled lanes.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch from splitBasicBlock with a branch on
    // the mask bit; no PHI is needed since stores produce no value.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
328
329 // Translate a masked gather intrinsic like
330 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
331 // <16 x i1> %Mask, <16 x i32> %Src)
332 // to a chain of basic blocks, with loading element one-by-one if
333 // the appropriate mask bit is set
334 //
335 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
336 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
337 // br i1 %Mask0, label %cond.load, label %else
338 //
339 // cond.load:
340 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
341 // %Load0 = load i32, i32* %Ptr0, align 4
342 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
343 // br label %else
344 //
345 // else:
346 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
347 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
348 // br i1 %Mask1, label %cond.load1, label %else2
349 //
350 // cond.load1:
351 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
352 // %Load1 = load i32, i32* %Ptr1, align 4
353 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
354 // br label %else2
355 // . . .
356 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
357 // ret <16 x i32> %Result
// Expand a @llvm.masked.gather call (args: vector-of-pointers, alignment,
// mask, passthru).  A constant mask becomes straight-line scalar loads from
// the enabled lanes' pointers; otherwise the cond.load/else block chain from
// the diagram above is emitted and ModifiedDT is set.
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    // br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the fall-through branch with one conditioned on the mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Join the freshly-loaded lane with the running result value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
440
441 // Translate a masked scatter intrinsic, like
442 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
443 // <16 x i1> %Mask)
444 // to a chain of basic blocks, that stores element one-by-one if
445 // the appropriate mask bit is set.
446 //
447 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
448 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
449 // br i1 %Mask0, label %cond.store, label %else
450 //
451 // cond.store:
452 // %Elt0 = extractelement <16 x i32> %Src, i32 0
453 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
454 // store i32 %Elt0, i32* %Ptr0, align 4
455 // br label %else
456 //
457 // else:
458 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
459 // br i1 %Mask1, label %cond.store1, label %else2
460 //
461 // cond.store1:
462 // %Elt1 = extractelement <16 x i32> %Src, i32 1
463 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
464 // store i32 %Elt1, i32* %Ptr1, align 4
465 // br label %else2
466 // . . .
// Expand a @llvm.masked.scatter call (args: value, vector-of-pointers,
// alignment, mask).  A constant mask becomes straight-line scalar stores to
// the enabled lanes' pointers; otherwise the cond.store/else block chain from
// the diagram above is emitted and ModifiedDT is set.
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    // br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    // %Elt1 = extractelement <16 x i32> %Src, i32 1
    // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the fall-through branch with one conditioned on the mask bit;
    // no PHI is needed since stores produce no value.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
536
scalarizeMaskedExpandLoad(CallInst * CI,bool & ModifiedDT)537 static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
538 Value *Ptr = CI->getArgOperand(0);
539 Value *Mask = CI->getArgOperand(1);
540 Value *PassThru = CI->getArgOperand(2);
541
542 VectorType *VecType = cast<VectorType>(CI->getType());
543
544 Type *EltTy = VecType->getElementType();
545
546 IRBuilder<> Builder(CI->getContext());
547 Instruction *InsertPt = CI;
548 BasicBlock *IfBlock = CI->getParent();
549
550 Builder.SetInsertPoint(InsertPt);
551 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
552
553 unsigned VectorWidth = VecType->getNumElements();
554
555 // The result vector
556 Value *VResult = PassThru;
557
558 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
559 // Fill the "else" block, created in the previous iteration
560 //
561 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
562 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
563 // br i1 %mask_1, label %cond.load, label %else
564 //
565
566 Value *Predicate =
567 Builder.CreateExtractElement(Mask, Idx);
568
569 // Create "cond" block
570 //
571 // %EltAddr = getelementptr i32* %1, i32 0
572 // %Elt = load i32* %EltAddr
573 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
574 //
575 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
576 "cond.load");
577 Builder.SetInsertPoint(InsertPt);
578
579 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
580 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
581
582 // Move the pointer if there are more blocks to come.
583 Value *NewPtr;
584 if ((Idx + 1) != VectorWidth)
585 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
586
587 // Create "else" block, fill it in the next iteration
588 BasicBlock *NewIfBlock =
589 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
590 Builder.SetInsertPoint(InsertPt);
591 Instruction *OldBr = IfBlock->getTerminator();
592 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
593 OldBr->eraseFromParent();
594 BasicBlock *PrevIfBlock = IfBlock;
595 IfBlock = NewIfBlock;
596
597 // Create the phi to join the new and previous value.
598 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
599 ResultPhi->addIncoming(NewVResult, CondBlock);
600 ResultPhi->addIncoming(VResult, PrevIfBlock);
601 VResult = ResultPhi;
602
603 // Add a PHI for the pointer if this isn't the last iteration.
604 if ((Idx + 1) != VectorWidth) {
605 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
606 PtrPhi->addIncoming(NewPtr, CondBlock);
607 PtrPhi->addIncoming(Ptr, PrevIfBlock);
608 Ptr = PtrPhi;
609 }
610 }
611
612 CI->replaceAllUsesWith(VResult);
613 CI->eraseFromParent();
614
615 ModifiedDT = true;
616 }
617
scalarizeMaskedCompressStore(CallInst * CI,bool & ModifiedDT)618 static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
619 Value *Src = CI->getArgOperand(0);
620 Value *Ptr = CI->getArgOperand(1);
621 Value *Mask = CI->getArgOperand(2);
622
623 VectorType *VecType = cast<VectorType>(Src->getType());
624
625 IRBuilder<> Builder(CI->getContext());
626 Instruction *InsertPt = CI;
627 BasicBlock *IfBlock = CI->getParent();
628
629 Builder.SetInsertPoint(InsertPt);
630 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
631
632 Type *EltTy = VecType->getVectorElementType();
633
634 unsigned VectorWidth = VecType->getNumElements();
635
636 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
637 // Fill the "else" block, created in the previous iteration
638 //
639 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
640 // br i1 %mask_1, label %cond.store, label %else
641 //
642 Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
643
644 // Create "cond" block
645 //
646 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
647 // %EltAddr = getelementptr i32* %1, i32 0
648 // %store i32 %OneElt, i32* %EltAddr
649 //
650 BasicBlock *CondBlock =
651 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
652 Builder.SetInsertPoint(InsertPt);
653
654 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
655 Builder.CreateAlignedStore(OneElt, Ptr, 1);
656
657 // Move the pointer if there are more blocks to come.
658 Value *NewPtr;
659 if ((Idx + 1) != VectorWidth)
660 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
661
662 // Create "else" block, fill it in the next iteration
663 BasicBlock *NewIfBlock =
664 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
665 Builder.SetInsertPoint(InsertPt);
666 Instruction *OldBr = IfBlock->getTerminator();
667 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
668 OldBr->eraseFromParent();
669 BasicBlock *PrevIfBlock = IfBlock;
670 IfBlock = NewIfBlock;
671
672 // Add a PHI for the pointer if this isn't the last iteration.
673 if ((Idx + 1) != VectorWidth) {
674 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
675 PtrPhi->addIncoming(NewPtr, CondBlock);
676 PtrPhi->addIncoming(Ptr, PrevIfBlock);
677 Ptr = PtrPhi;
678 }
679 }
680 CI->eraseFromParent();
681
682 ModifiedDT = true;
683 }
684
runOnFunction(Function & F)685 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
686 bool EverMadeChange = false;
687
688 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
689
690 bool MadeChange = true;
691 while (MadeChange) {
692 MadeChange = false;
693 for (Function::iterator I = F.begin(); I != F.end();) {
694 BasicBlock *BB = &*I++;
695 bool ModifiedDTOnIteration = false;
696 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
697
698 // Restart BB iteration if the dominator tree of the Function was changed
699 if (ModifiedDTOnIteration)
700 break;
701 }
702
703 EverMadeChange |= MadeChange;
704 }
705
706 return EverMadeChange;
707 }
708
optimizeBlock(BasicBlock & BB,bool & ModifiedDT)709 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
710 bool MadeChange = false;
711
712 BasicBlock::iterator CurInstIterator = BB.begin();
713 while (CurInstIterator != BB.end()) {
714 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
715 MadeChange |= optimizeCallInst(CI, ModifiedDT);
716 if (ModifiedDT)
717 return true;
718 }
719
720 return MadeChange;
721 }
722
optimizeCallInst(CallInst * CI,bool & ModifiedDT)723 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
724 bool &ModifiedDT) {
725 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
726 if (II) {
727 switch (II->getIntrinsicID()) {
728 default:
729 break;
730 case Intrinsic::masked_load:
731 // Scalarize unsupported vector masked load
732 if (TTI->isLegalMaskedLoad(CI->getType()))
733 return false;
734 scalarizeMaskedLoad(CI, ModifiedDT);
735 return true;
736 case Intrinsic::masked_store:
737 if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
738 return false;
739 scalarizeMaskedStore(CI, ModifiedDT);
740 return true;
741 case Intrinsic::masked_gather:
742 if (TTI->isLegalMaskedGather(CI->getType()))
743 return false;
744 scalarizeMaskedGather(CI, ModifiedDT);
745 return true;
746 case Intrinsic::masked_scatter:
747 if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
748 return false;
749 scalarizeMaskedScatter(CI, ModifiedDT);
750 return true;
751 case Intrinsic::masked_expandload:
752 if (TTI->isLegalMaskedExpandLoad(CI->getType()))
753 return false;
754 scalarizeMaskedExpandLoad(CI, ModifiedDT);
755 return true;
756 case Intrinsic::masked_compressstore:
757 if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
758 return false;
759 scalarizeMaskedCompressStore(CI, ModifiedDT);
760 return true;
761 }
762 }
763
764 return false;
765 }
766