1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2017-2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 /*========================== begin_copyright_notice ============================
10
11 This file is distributed under the University of Illinois Open Source License.
12 See LICENSE.TXT for details.
13
14 ============================= end_copyright_notice ===========================*/
15
/*=========================== CustomSafeOptPass.cpp ==========================
17
18 This file contains IGC custom optimizations that are arithmetically safe.
19 The passes are
20 CustomSafeOptPass
21 GenSpecificPattern
22 IGCConstProp
23 IGCIndirectICBPropagaion
24 CustomLoopInfo
25 VectorBitCastOpt
26 GenStrengthReduction
27 FlattenSmallSwitch
28 SplitIndirectEEtoSel
29
30 CustomSafeOptPass does peephole optimizations
31 For example, reduce the alloca size so there is a chance to promote indexed temp.
32
33 GenSpecificPattern reverts llvm changes back to what is needed.
For example, llvm changes AND to OR, and GenSpecificPattern changes it back to
35 allow more optimizations
36
37 IGCConstProp was originated from llvm copy-prop code with one addition for
38 shader-constant replacement. So llvm copyright is added above.
39
40 IGCIndirectICBPropagaion reads the immediate constant buffer from meta data and
41 use them as immediates instead of using send messages to read from buffer.
42
43 CustomLoopInfo returns true if there is any sampleL in a loop for the driver.
44
45 VectorBitCastOpt preprocesses vector bitcasts to be after extractelement
46 instructions.
47
48 GenStrengthReduction performs a fdiv optimization.
49
50 FlattenSmallSwitch flatten the if/else or switch structure and use cmp+sel
51 instead if the structure is small.
52
53 SplitIndirectEEtoSel splits extractelements with very small vec to a series of
54 cmp+sel to avoid expensive VxH mov.
55
56 =============================================================================*/
57
58 #include "Compiler/CustomSafeOptPass.hpp"
59 #include "Compiler/CISACodeGen/helper.h"
60 #include "Compiler/CodeGenPublic.h"
61 #include "Compiler/IGCPassSupport.h"
62 #include "GenISAIntrinsics/GenIntrinsics.h"
63 #include "GenISAIntrinsics/GenIntrinsicInst.h"
64 #include "common/IGCConstantFolder.h"
65 #include "common/LLVMWarningsPush.hpp"
66 #include "llvm/Config/llvm-config.h"
67 #include "WrapperLLVM/Utils.h"
68 #include <llvmWrapper/IR/DerivedTypes.h>
69 #include <llvmWrapper/IR/IRBuilder.h>
70 #include <llvmWrapper/IR/PatternMatch.h>
71 #include <llvmWrapper/Analysis/TargetLibraryInfo.h>
72 #include <llvm/ADT/Statistic.h>
73 #include <llvm/ADT/SetVector.h>
74 #include <llvm/Analysis/ConstantFolding.h>
75 #include <llvm/IR/Constants.h>
76 #include <llvm/IR/Function.h>
77 #include <llvm/IR/Instructions.h>
78 #include <llvm/IR/Intrinsics.h>
79 #include <llvm/IR/InstIterator.h>
80 #include <llvm/Transforms/Utils/Local.h>
81 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
82 #include <llvm/Analysis/ValueTracking.h>
83 #include "common/LLVMWarningsPop.hpp"
84 #include <set>
85 #include "../inc/common/secure_mem.h"
86 #include "Probe/Assertion.h"
87
88 using namespace llvm;
89 using namespace IGC;
90 using namespace GenISAIntrinsic;
91
92 // Register pass to igc-opt
93 #define PASS_FLAG1 "igc-custom-safe-opt"
94 #define PASS_DESCRIPTION1 "Custom Pass Optimization"
95 #define PASS_CFG_ONLY1 false
96 #define PASS_ANALYSIS1 false
97 IGC_INITIALIZE_PASS_BEGIN(CustomSafeOptPass, PASS_FLAG1, PASS_DESCRIPTION1, PASS_CFG_ONLY1, PASS_ANALYSIS1)
98 IGC_INITIALIZE_PASS_END(CustomSafeOptPass, PASS_FLAG1, PASS_DESCRIPTION1, PASS_CFG_ONLY1, PASS_ANALYSIS1)
99
100 char CustomSafeOptPass::ID = 0;
101
// Constructor: registers this pass with the LLVM PassRegistry so it can be
// created by name (see the "igc-custom-safe-opt" flag registered above).
CustomSafeOptPass::CustomSafeOptPass() : FunctionPass(ID)
{
    initializeCustomSafeOptPassPass(*PassRegistry::getPassRegistry());
}
106
107 #define DEBUG_TYPE "CustomSafeOptPass"
108
109 STATISTIC(Stat_DiscardRemoved, "Number of insts removed in Discard Opt");
110
runOnFunction(Function & F)111 bool CustomSafeOptPass::runOnFunction(Function& F)
112 {
113 psHasSideEffect = getAnalysis<CodeGenContextWrapper>().getCodeGenContext()->m_instrTypes.psHasSideEffect;
114 visit(F);
115 return true;
116 }
117
// Fallback visitor required by InstVisitor: instructions without a dedicated
// handler are intentionally left untouched.
void CustomSafeOptPass::visitInstruction(Instruction& I)
{
    // nothing
}
122
123 // Searches the following pattern
124 // %1 = icmp eq i32 %cmpop1, %cmpop2
125 // %2 = xor i1 %1, true
126 // ...
127 // %3 = select i1 %1, i8 0, i8 1
128 //
129 // And changes it to:
130 // %1 = icmp ne i32 %cmpop1, %cmpop2
131 // ...
132 // %3 = select i1 %1, i8 1, i8 0
133 //
134 // and
135 //
136 // Searches the following pattern
137 // %1 = icmp ule i32 %cmpop1, %cmpop2
138 // %2 = xor i1 %1, true
139 // br i1 %1, label %3, label %4
140 //
141 // And changes it to:
142 // %1 = icmp ugt i32 %cmpop1, %cmpop2
143 // br i1 %1, label %4, label %3
144 //
145 // This optimization combines statements regardless of the predicate.
146 // It will also work if the icmp instruction does not have users, except for the xor, select or branch instruction.
visitXor(Instruction & XorInstr)147 void CustomSafeOptPass::visitXor(Instruction& XorInstr) {
148 using namespace llvm::PatternMatch;
149
150 CmpInst::Predicate Pred;
151 auto XorPattern = m_c_Xor(m_ICmp(Pred, m_Value(), m_Value()), m_SpecificInt(1));
152 if (!match(&XorInstr, XorPattern)) {
153 return;
154 }
155
156 Value* XorOp0 = XorInstr.getOperand(0);
157 Value* XorOp1 = XorInstr.getOperand(1);
158 auto ICmpInstr = cast<Instruction>(isa<ICmpInst>(XorOp0) ? XorOp0 : XorOp1);
159
160 llvm::SmallVector<Instruction*, 4> UsersList;
161
162 for (auto U : ICmpInstr->users()) {
163 if (isa<BranchInst>(U)) {
164 UsersList.push_back(cast<Instruction>(U));
165 }
166 else if (SelectInst* S = dyn_cast<SelectInst>(U)) {
167 if (S->getCondition() == ICmpInstr) {
168 UsersList.push_back(cast<Instruction>(S));
169 }
170 else {
171 return;
172 }
173 }
174 else if (U != &XorInstr) {
175 return;
176 }
177 }
178
179 IRBuilder<> builder(ICmpInstr);
180 auto NegatedCmpPred = cast<ICmpInst>(ICmpInstr)->getInversePredicate();
181 auto NewCmp = cast<ICmpInst>(builder.CreateICmp(NegatedCmpPred, ICmpInstr->getOperand(0), ICmpInstr->getOperand(1)));
182
183 for (auto I : UsersList) {
184 if (SelectInst* S = dyn_cast<SelectInst>(I)) {
185 S->swapProfMetadata();
186 Value* TrueVal = S->getTrueValue();
187 Value* FalseVal = S->getFalseValue();
188 S->setTrueValue(FalseVal);
189 S->setFalseValue(TrueVal);
190 }
191 else {
192 IGC_ASSERT(isa<BranchInst>(I));
193 BranchInst* B = cast<BranchInst>(I);
194 B->swapSuccessors();
195 }
196 }
197
198 XorInstr.replaceAllUsesWith(NewCmp);
199 ICmpInstr->replaceAllUsesWith(NewCmp);
200 XorInstr.eraseFromParent();
201 ICmpInstr->eraseFromParent();
202 }
203
204 // Searches for following pattern:
205 // %cmp = icmp slt i32 %x, %y
206 // %cond.not = xor i1 %cond, true
207 // %and.cond = and i1 %cmp, %cond.not
208 // br i1 %or.cond, label %bb1, label %bb2
209 //
210 // And changes it to:
211 // %0 = icmp sge i32 %x, %y
212 // %1 = or i1 %cond, %0
213 // br i1 %1, label %bb2, label %bb1
visitAnd(BinaryOperator & I)214 void CustomSafeOptPass::visitAnd(BinaryOperator& I) {
215 using namespace llvm::PatternMatch;
216
217 if (!I.hasOneUse() ||
218 !isa<BranchInst>(*I.user_begin()) ||
219 !I.getType()->isIntegerTy(1)) {
220 return;
221 }
222
223 Value* XorArgValue;
224 CmpInst::Predicate Pred;
225 auto AndPattern = m_c_And(m_c_Xor(m_Value(XorArgValue), m_SpecificInt(1)), m_ICmp(Pred, m_Value(), m_Value()));
226 if (!match(&I, AndPattern)) return;
227
228 IRBuilder<> builder(&I);
229 auto CompareInst = cast<ICmpInst>(isa<ICmpInst>(I.getOperand(0)) ? I.getOperand(0) : I.getOperand(1));
230 auto NegatedCompareInst = builder.CreateICmp(CompareInst->getInversePredicate(), CompareInst->getOperand(0), CompareInst->getOperand(1));
231 auto OrInst = builder.CreateOr(XorArgValue, NegatedCompareInst);
232
233 auto BrInst = cast<BranchInst>(*I.user_begin());
234 BrInst->setCondition(OrInst);
235 BrInst->swapSuccessors();
236
237 I.eraseFromParent();
238 }
239
240 /*
241 Optimizing from
242 % 377 = call i32 @llvm.genx.GenISA.simdSize()
243 % .rhs.trunc = trunc i32 % 377 to i8
244 % .lhs.trunc = trunc i32 % 26 to i8
245 % 383 = udiv i8 % .lhs.trunc, % .rhs.trunc
246 to
247 % 377 = call i32 @llvm.genx.GenISA.simdSize()
248 % .rhs.trunc = trunc i32 % 377 to i8
249 % .lhs.trunc = trunc i32 % 26 to i8
250 % a = shr i8 % rhs.trunc, 4
251 % b = shr i8 % .lhs.trunc, 3
252 % 382 = shr i8 % b, % a
253 or
254 % 377 = call i32 @llvm.genx.GenISA.simdSize()
255 % 383 = udiv i32 %382, %377
256 to
257 % 377 = call i32 @llvm.genx.GenISA.simdSize()
258 % a = shr i32 %377, 4
259 % b = shr i32 %382, 3
260 % 382 = shr i32 % b, % a
261 */
visitUDiv(llvm::BinaryOperator & I)262 void CustomSafeOptPass::visitUDiv(llvm::BinaryOperator& I)
263 {
264 bool isPatternfound = false;
265
266 if (TruncInst* trunc = dyn_cast<TruncInst>(I.getOperand(1)))
267 {
268 if (CallInst* Inst = dyn_cast<CallInst>(trunc->getOperand(0)))
269 {
270 if (GenIntrinsicInst* SimdInst = dyn_cast<GenIntrinsicInst>(Inst))
271 if (SimdInst->getIntrinsicID() == GenISAIntrinsic::GenISA_simdSize)
272 isPatternfound = true;
273 }
274 }
275 else
276 {
277 if (CallInst* Inst = dyn_cast<CallInst>(I.getOperand(1)))
278 {
279 if (GenIntrinsicInst* SimdInst = dyn_cast<GenIntrinsicInst>(Inst))
280 if (SimdInst->getIntrinsicID() == GenISAIntrinsic::GenISA_simdSize)
281 isPatternfound = true;
282 }
283 }
284 if (isPatternfound)
285 {
286 IRBuilder<> builder(&I);
287 Value* Shift1 = builder.CreateLShr(I.getOperand(1), ConstantInt::get(I.getOperand(1)->getType(), 4));
288 Value* Shift2 = builder.CreateLShr(I.getOperand(0), ConstantInt::get(I.getOperand(0)->getType(), 3));
289 Value* Shift3 = builder.CreateLShr(Shift2, Shift1);
290 I.replaceAllUsesWith(Shift3);
291 I.eraseFromParent();
292 }
293 }
294
// Shrink a private-array alloca to the index range actually written, so the
// smaller array has a chance to be promoted to registers (indexed temps).
// Pass 1 scans all users and computes [index_lb, index_ub] from constant
// store indices, bailing out on any unsupported shape. Pass 2 creates the
// smaller alloca and rebases every GEP index by -index_lb.
void CustomSafeOptPass::visitAllocaInst(AllocaInst& I)
{
    // reduce the alloca size so there is a chance to promote indexed temp.

    // ex: to:
    // dcl_indexable_temp x1[356], 4 dcl_indexable_temp x1[2], 4
    // mov x[1][354].xyzw, l(1f, 0f, -1f, 0f) mov x[1][0].xyzw, l(1f, 0f, -1f, 0f)
    // mov x[1][355].xyzw, l(1f, 1f, 0f, -1f) mov x[1][1].xyzw, l(1f, 1f, 0f, -1f)
    // mov r[1].xy, x[1][r[1].x + 354].xyxx mov r[1].xy, x[1][r[1].x].xyxx

    // in llvm: to:
    // %outarray_x1 = alloca[356 x float], align 4 %31 = alloca[2 x float]
    // %outarray_y2 = alloca[356 x float], align 4 %32 = alloca[2 x float]
    // %27 = getelementptr[356 x float] * %outarray_x1, i32 0, i32 352 %35 = getelementptr[2 x float] * %31, i32 0, i32 0
    // store float 0.000000e+00, float* %27, align 4 store float 0.000000e+00, float* %35, align 4
    // %28 = getelementptr[356 x float] * %outarray_y2, i32 0, i32 352 %36 = getelementptr[2 x float] * %32, i32 0, i32 0
    // store float 0.000000e+00, float* %28, align 4 store float 0.000000e+00, float* %36, align 4
    // %43 = add nsw i32 %selRes_s, 354
    // %44 = getelementptr[356 x float] * %outarray_x1, i32 0, i32 %43 %51 = getelementptr[2 x float] * %31, i32 0, i32 %selRes_s
    // %45 = load float* %44, align 4 %52 = load float* %51, align 4

    // Only private-address-space arrays of scalar elements are candidates.
    llvm::Type* pType = I.getType()->getPointerElementType();
    if (!pType->isArrayTy() ||
        static_cast<ADDRESS_SPACE>(I.getType()->getAddressSpace()) != ADDRESS_SPACE_PRIVATE)
    {
        return;
    }

    if (!(pType->getArrayElementType()->isFloatingPointTy() ||
        pType->getArrayElementType()->isIntegerTy() ||
        pType->getArrayElementType()->isPointerTy()))
    {
        return;
    }

    // Running bounds of the constant store indices seen so far; start with
    // an empty range (lb = array size, ub = 0).
    int index_lb = int_cast<int>(pType->getArrayNumElements());
    int index_ub = 0;

    // Find all uses of this alloca
    for (Value::user_iterator it = I.user_begin(), e = I.user_end(); it != e; ++it)
    {
        if (GetElementPtrInst * pGEP = llvm::dyn_cast<GetElementPtrInst>(*it))
        {
            // Only the canonical "gep %alloca, 0, idx" two-index form is handled.
            ConstantInt* C0 = dyn_cast<ConstantInt>(pGEP->getOperand(1));
            if (!C0 || !C0->isZero() || pGEP->getNumOperands() != 3)
            {
                return;
            }
            for (Value::user_iterator use_it = pGEP->user_begin(), use_e = pGEP->user_end(); use_it != use_e; ++use_it)
            {
                if (llvm::dyn_cast<llvm::LoadInst>(*use_it))
                {
                    // NOTE(review): loads do not widen the index range; a load
                    // from a never-stored slot reads undefined data either way,
                    // so ignoring them here is assumed safe — confirm.
                }
                else if (llvm::StoreInst * pStore = llvm::dyn_cast<llvm::StoreInst>(*use_it))
                {
                    llvm::Value* pValueOp = pStore->getValueOperand();
                    if (pValueOp == *it)
                    {
                        // GEP instruction is the stored value of the StoreInst (not supported case)
                        return;
                    }
                    // Only constant i32 store indices can narrow the range.
                    if (dyn_cast<ConstantInt>(pGEP->getOperand(2)) && pGEP->getOperand(2)->getType()->isIntegerTy(32))
                    {
                        int currentIndex = int_cast<int>(
                            dyn_cast<ConstantInt>(pGEP->getOperand(2))->getZExtValue());
                        index_lb = (currentIndex < index_lb) ? currentIndex : index_lb;
                        index_ub = (currentIndex > index_ub) ? currentIndex : index_ub;

                    }
                    else
                    {
                        return;
                    }
                }
                else
                {
                    // This is some other instruction. Right now we don't want to handle these
                    return;
                }
            }
        }
        else
        {
            // Non-GEP users are only tolerated if they are lifetime-marker bitcasts.
            if (!IsBitCastForLifetimeMark(dyn_cast<Value>(*it)))
            {
                return;
            }
        }
    }

    // Bail if the observed store range does not actually shrink the array.
    unsigned int newSize = index_ub + 1 - index_lb;
    if (newSize >= pType->getArrayNumElements())
    {
        return;
    }
    // found a case to optimize
    IGCLLVM::IRBuilder<> IRB(&I);
    llvm::ArrayType* allocaArraySize = llvm::ArrayType::get(pType->getArrayElementType(), newSize);
    llvm::Value* newAlloca = IRB.CreateAlloca(allocaArraySize, nullptr);
    llvm::Value* gepArg1;

    // Rewrite every GEP onto the smaller alloca, rebasing indices by -index_lb.
    for (Value::user_iterator it = I.user_begin(), e = I.user_end(); it != e; ++it)
    {
        if (GetElementPtrInst * pGEP = llvm::dyn_cast<GetElementPtrInst>(*it))
        {
            if (dyn_cast<ConstantInt>(pGEP->getOperand(2)))
            {
                // pGEP->getOperand(2) is constant. Reduce the constant value directly
                int newIndex = int_cast<int>(dyn_cast<ConstantInt>(pGEP->getOperand(2))->getZExtValue())
                    - index_lb;
                gepArg1 = IRB.getInt32(newIndex);
            }
            else
            {
                // pGEP->getOperand(2) is not constant. create a sub instruction to reduce it
                gepArg1 = BinaryOperator::CreateSub(pGEP->getOperand(2), IRB.getInt32(index_lb), "reducedIndex", pGEP);
            }
            llvm::Value* gepArg[] = { pGEP->getOperand(1), gepArg1 };
            llvm::Value* pGEPnew = GetElementPtrInst::Create(nullptr, newAlloca, gepArg, "", pGEP);
            pGEP->replaceAllUsesWith(pGEPnew);
        }
    }
}
418
// Turn an indirect private-array load whose index is a select of two
// constants into two direct loads plus a select of the loaded values,
// improving the chance that the alloca is later promoted to registers.
void CustomSafeOptPass::visitLoadInst(LoadInst& load)
{
    // Optimize indirect access to private arrays. Handle cases where
    // array index is a select between two immediate constant values.
    // After the optimization there is fair chance the alloca will be
    // promoted to registers.
    //
    // E.g. change the following:
    // %PrivareArray = alloca[4 x <3 x float>], align 16
    // %IndirectIndex = select i1 %SomeCondition, i32 1, i32 2
    // %IndirectAccessPtr= getelementptr[4 x <3 x float>], [4 x <3 x float>] * %PrivareArray, i32 0, i32 %IndirectIndex
    // %LoadedValue = load <3 x float>, <3 x float>* %IndirectAccess, align 16

    // %PrivareArray = alloca[4 x <3 x float>], align 16
    // %DirectAccessPtr1 = getelementptr[4 x <3 x float>], [4 x <3 x float>] * %PrivareArray, i32 0, i32 1
    // %DirectAccessPtr2 = getelementptr[4 x <3 x float>], [4 x <3 x float>] * %PrivareArray, i32 0, i32 2
    // %LoadedValue1 = load <3 x float>, <3 x float>* %DirectAccessPtr1, align 16
    // %LoadedValue2 = load <3 x float>, <3 x float>* %DirectAccessPtr2, align 16
    // %LoadedValue = select i1 %SomeCondition, <3 x float> %LoadedValue1, <3 x float> %LoadedValue2

    Value* ptr = load.getPointerOperand();
    if (ptr->getType()->getPointerAddressSpace() != 0)
    {
        // only private arrays are handled
        return;
    }
    if (GetElementPtrInst * gep = dyn_cast<GetElementPtrInst>(ptr))
    {
        bool found = false;
        uint selIdx = 0;
        // Check if this gep is a good candidate for optimization.
        // The instruction has to have exactly one non-constant index value.
        // The index value has to be a select instruction with immediate
        // constant values.
        for (uint i = 0; i < gep->getNumIndices(); ++i)
        {
            Value* gepIdx = gep->getOperand(i + 1);
            if (!isa<ConstantInt>(gepIdx))
            {
                SelectInst* sel = dyn_cast<SelectInst>(gepIdx);
                if (!found &&
                    sel &&
                    isa<ConstantInt>(sel->getOperand(1)) &&
                    isa<ConstantInt>(sel->getOperand(2)))
                {
                    found = true;
                    selIdx = i;
                }
                else
                {
                    found = false; // optimize cases with only a single non-constant index.
                    break;
                }
            }

        }
        if (found)
        {
            // Clone the GEP twice, substituting each arm of the select as the
            // index, then clone the load for each GEP and select between the
            // two loaded values using the original select's condition.
            SelectInst* sel = cast<SelectInst>(gep->getOperand(selIdx + 1));
            SmallVector<Value*, 8> indices;
            indices.append(gep->idx_begin(), gep->idx_end());
            indices[selIdx] = sel->getOperand(1);
            GetElementPtrInst* gep1 = GetElementPtrInst::Create(nullptr, gep->getPointerOperand(), indices, gep->getName(), gep);
            gep1->setDebugLoc(gep->getDebugLoc());
            indices[selIdx] = sel->getOperand(2);
            GetElementPtrInst* gep2 = GetElementPtrInst::Create(nullptr, gep->getPointerOperand(), indices, gep->getName(), gep);
            gep2->setDebugLoc(gep->getDebugLoc());
            LoadInst* load1 = cast<LoadInst>(load.clone());
            load1->insertBefore(&load);
            load1->setOperand(0, gep1);
            LoadInst* load2 = cast<LoadInst>(load.clone());
            load2->insertBefore(&load);
            load2->setOperand(0, gep2);
            SelectInst* result = SelectInst::Create(sel->getCondition(), load1, load2, load.getName(), &load);
            result->setDebugLoc(load.getDebugLoc());
            load.replaceAllUsesWith(result);
            load.eraseFromParent();
            // Clean up the original gep/select if they became dead.
            if (gep->use_empty())
            {
                gep->eraseFromParent();
            }
            if (sel->use_empty())
            {
                sel->eraseFromParent();
            }
        }

    }

}
509
// Dispatch GenISA intrinsic calls to their specialized peephole handlers,
// and fold/hoist discard intrinsics in place.
void CustomSafeOptPass::visitCallInst(CallInst& C)
{
    // discard optimizations
    if (llvm::GenIntrinsicInst * inst = llvm::dyn_cast<GenIntrinsicInst>(&C))
    {
        GenISAIntrinsic::ID id = inst->getIntrinsicID();
        // try to prune the destination size
        switch (id)
        {
        case GenISAIntrinsic::GenISA_discard:
        {
            Value* srcVal0 = C.getOperand(0);
            if (ConstantInt * CI = dyn_cast<ConstantInt>(srcVal0))
            {
                if (CI->isZero()) { // i1 is false
                    // discard(false) is a no-op: delete it.
                    C.eraseFromParent();
                    ++Stat_DiscardRemoved;
                }
                else if (!psHasSideEffect)
                {
                    // discard(true) guarded by a conditional branch: hoist the
                    // discard above the branch, predicated on the branch
                    // condition (negated when we sit on the false successor).
                    // NOTE(review): blk is dereferenced before the "if (blk && pred)"
                    // null check below; blk can never be null for an instruction
                    // inside a function, so the check is redundant rather than
                    // protective — confirm.
                    BasicBlock* blk = C.getParent();
                    BasicBlock* pred = blk->getSinglePredecessor();
                    if (blk && pred)
                    {
                        BranchInst* cbr = dyn_cast<BranchInst>(pred->getTerminator());
                        if (cbr && cbr->isConditional())
                        {
                            if (blk == cbr->getSuccessor(0))
                            {
                                C.setOperand(0, cbr->getCondition());
                                C.removeFromParent();
                                C.insertBefore(cbr);
                            }
                            else if (blk == cbr->getSuccessor(1))
                            {
                                Value* flipCond = llvm::BinaryOperator::CreateNot(cbr->getCondition(), "", cbr);
                                C.setOperand(0, flipCond);
                                C.removeFromParent();
                                C.insertBefore(cbr);
                            }
                        }
                    }
                }
            }
            break;
        }

        case GenISAIntrinsic::GenISA_bfi:
        {
            visitBfi(inst);
            break;
        }

        case GenISAIntrinsic::GenISA_f32tof16_rtz:
        {
            visitf32tof16(inst);
            break;
        }

        case GenISAIntrinsic::GenISA_sampleBptr:
        {
            visitSampleBptr(llvm::cast<llvm::SampleIntrinsic>(inst));
            break;
        }

        case GenISAIntrinsic::GenISA_umulH:
        {
            visitMulH(inst, false);
            break;
        }

        case GenISAIntrinsic::GenISA_imulH:
        {
            visitMulH(inst, true);
            break;
        }

        case GenISAIntrinsic::GenISA_ldptr:
        {
            visitLdptr(llvm::cast<llvm::SamplerLoadIntrinsic>(inst));
            break;
        }

        case GenISAIntrinsic::GenISA_ldrawvector_indexed:
        {
            visitLdRawVec(inst);
            break;
        }
        default:
            break;
        }
    }
}
603
604 //
605 // pattern match packing of two half float from f32tof16:
606 //
607 // % 43 = call float @llvm.genx.GenISA.f32tof16.rtz(float %res_s55.i)
608 // % 44 = call float @llvm.genx.GenISA.f32tof16.rtz(float %res_s59.i)
609 // % 47 = bitcast float %44 to i32
610 // % 49 = bitcast float %43 to i32
611 // %addres_s68.i = shl i32 % 47, 16
612 // % mulres_s69.i = add i32 %addres_s68.i, % 49
613 // % 51 = bitcast i32 %mulres_s69.i to float
614 // into
615 // %43 = call half @llvm.genx.GenISA_ftof_rtz(float %res_s55.i)
616 // %44 = call half @llvm.genx.GenISA_ftof_rtz(float %res_s59.i)
617 // %45 = insertelement <2 x half>undef, %43, 0
618 // %46 = insertelement <2 x half>%45, %44, 1
619 // %51 = bitcast <2 x half> %46 to float
620 //
621 // or if the f32tof16 are from fpext half:
622 //
623 // % src0_s = fpext half %res_s to float
624 // % src0_s2 = fpext half %res_s1 to float
625 // % 2 = call fast float @llvm.genx.GenISA.f32tof16.rtz(float %src0_s)
626 // % 3 = call fast float @llvm.genx.GenISA.f32tof16.rtz(float %src0_s2)
627 // % 4 = bitcast float %2 to i32
628 // % 5 = bitcast float %3 to i32
629 // % addres_s = shl i32 % 4, 16
630 // % mulres_s = add i32 %addres_s, % 5
631 // % 6 = bitcast i32 %mulres_s to float
632 // into
633 // % 2 = insertelement <2 x half> undef, half %res_s1, i32 0, !dbg !113
634 // % 3 = insertelement <2 x half> % 2, half %res_s, i32 1, !dbg !113
635 // % 4 = bitcast <2 x half> % 3 to float, !dbg !113
636
// Match the "pack two f32tof16 results into one i32/float" idiom described
// in the comment above, and rewrite it as an insert of two halves into a
// <2 x half> vector followed by a single bitcast. If both f32tof16 inputs
// are fpext from half, the conversions cancel and the original half values
// are packed directly.
void CustomSafeOptPass::visitf32tof16(llvm::CallInst* inst)
{
    // 'inst' provides the low 16 bits; it must feed exactly one bitcast-to-i32.
    if (!inst->hasOneUse())
    {
        return;
    }

    BitCastInst* bitcast = dyn_cast<BitCastInst>(*(inst->user_begin()));
    if (!bitcast || !bitcast->hasOneUse() || !bitcast->getType()->isIntegerTy(32))
    {
        return;
    }
    // The bitcast must be combined with the shifted high half via an add.
    Instruction* addInst = dyn_cast<BinaryOperator>(*(bitcast->user_begin()));
    if (!addInst || addInst->getOpcode() != Instruction::Add || !addInst->hasOneUse())
    {
        return;
    }
    Instruction* lastValue = addInst;

    // If the packed i32 is bitcast back to float, replace that instead.
    if (BitCastInst * finalBitCast = dyn_cast<BitCastInst>(*(addInst->user_begin())))
    {
        lastValue = finalBitCast;
    }

    // check the other half
    Value* otherOpnd = addInst->getOperand(0) == bitcast ? addInst->getOperand(1) : addInst->getOperand(0);
    Instruction* shiftOrMul = dyn_cast<BinaryOperator>(otherOpnd);

    // The high half must be shifted left by 16 (or multiplied by 65536).
    if (!shiftOrMul ||
        (shiftOrMul->getOpcode() != Instruction::Shl && shiftOrMul->getOpcode() != Instruction::Mul))
    {
        return;
    }
    bool isShift = shiftOrMul->getOpcode() == Instruction::Shl;
    ConstantInt* constVal = dyn_cast<ConstantInt>(shiftOrMul->getOperand(1));
    if (!constVal || !constVal->equalsInt(isShift ? 16 : 65536))
    {
        return;
    }
    BitCastInst* bitcast2 = dyn_cast<BitCastInst>(shiftOrMul->getOperand(0));
    if (!bitcast2)
    {
        return;
    }
    llvm::GenIntrinsicInst* valueHi = dyn_cast<GenIntrinsicInst>(bitcast2->getOperand(0));
    if (!valueHi || valueHi->getIntrinsicID() != GenISA_f32tof16_rtz)
    {
        return;
    }

    Value* loVal = nullptr;
    Value* hiVal = nullptr;

    FPExtInst* extInstLo = dyn_cast<FPExtInst>(inst->getOperand(0));
    FPExtInst* extInstHi = dyn_cast<FPExtInst>(valueHi->getOperand(0));

    IRBuilder<> builder(lastValue);
    Type* funcType[] = { Type::getHalfTy(builder.getContext()), Type::getFloatTy(builder.getContext()) };
    Type* halfx2 = IGCLLVM::FixedVectorType::get(Type::getHalfTy(builder.getContext()), 2);

    if (extInstLo && extInstHi &&
        extInstLo->getOperand(0)->getType()->isHalfTy() &&
        extInstHi->getOperand(0)->getType()->isHalfTy())
    {
        // Both inputs came from half: pack the original halves directly.
        loVal = extInstLo->getOperand(0);
        hiVal = extInstHi->getOperand(0);
    }
    else
    {
        // Otherwise convert each float to half with an RTZ float-to-float cast.
        Function* f32tof16_rtz = GenISAIntrinsic::getDeclaration(inst->getParent()->getParent()->getParent(),
            GenISAIntrinsic::GenISA_ftof_rtz, funcType);
        loVal = builder.CreateCall(f32tof16_rtz, inst->getArgOperand(0));
        hiVal = builder.CreateCall(f32tof16_rtz, valueHi->getArgOperand(0));
    }
    // Pack lo into lane 0 and hi into lane 1, then bitcast to the final type.
    Value* vector = builder.CreateInsertElement(UndefValue::get(halfx2), loVal, builder.getInt32(0));
    vector = builder.CreateInsertElement(vector, hiVal, builder.getInt32(1));
    vector = builder.CreateBitCast(vector, lastValue->getType());
    lastValue->replaceAllUsesWith(vector);
    lastValue->eraseFromParent();
}
717
visitBfi(llvm::CallInst * inst)718 void CustomSafeOptPass::visitBfi(llvm::CallInst* inst)
719 {
720 IGC_ASSERT(inst->getType()->isIntegerTy(32));
721 ConstantInt* widthV = dyn_cast<ConstantInt>(inst->getOperand(0));
722 ConstantInt* offsetV = dyn_cast<ConstantInt>(inst->getOperand(1));
723 if (widthV && offsetV)
724 {
725 // transformation is beneficial if src3 is constant or if the offset is zero
726 if (isa<ConstantInt>(inst->getOperand(3)) || offsetV->isZero())
727 {
728 unsigned int width = static_cast<unsigned int>(widthV->getZExtValue());
729 unsigned int offset = static_cast<unsigned int>(offsetV->getZExtValue());
730 unsigned int bitMask = ((1 << width) - 1) << offset;
731 IRBuilder<> builder(inst);
732 // dst = ((src2 << offset) & bitmask) | (src3 & ~bitmask)
733 Value* firstTerm = builder.CreateShl(inst->getOperand(2), offsetV);
734 firstTerm = builder.CreateAnd(firstTerm, builder.getInt32(bitMask));
735 Value* secondTerm = builder.CreateAnd(inst->getOperand(3), builder.getInt32(~bitMask));
736 Value* dst = builder.CreateOr(firstTerm, secondTerm);
737 inst->replaceAllUsesWith(dst);
738 inst->eraseFromParent();
739 }
740 }
741 else if (widthV && widthV->isZeroValue())
742 {
743 inst->replaceAllUsesWith(inst->getOperand(3));
744 inst->eraseFromParent();
745 }
746 }
747
visitMulH(CallInst * inst,bool isSigned)748 void CustomSafeOptPass::visitMulH(CallInst* inst, bool isSigned)
749 {
750 ConstantInt* src0 = dyn_cast<ConstantInt>(inst->getOperand(0));
751 ConstantInt* src1 = dyn_cast<ConstantInt>(inst->getOperand(1));
752 if (src0 && src1)
753 {
754 unsigned nbits = inst->getType()->getIntegerBitWidth();
755 IGC_ASSERT(nbits < 64);
756
757 if (isSigned)
758 {
759 uint64_t ui0 = src0->getZExtValue();
760 uint64_t ui1 = src1->getZExtValue();
761 uint64_t r = ((ui0 * ui1) >> nbits);
762 inst->replaceAllUsesWith(ConstantInt::get(inst->getType(), r));
763 }
764 else
765 {
766 int64_t si0 = src0->getSExtValue();
767 int64_t si1 = src1->getSExtValue();
768 int64_t r = ((si0 * si1) >> nbits);
769 inst->replaceAllUsesWith(ConstantInt::get(inst->getType(), r, true));
770 }
771 inst->eraseFromParent();
772 }
773 }
774
775 // if phi is used in a FPTrunc and the sources all come from fpext we can skip the conversions
visitFPTruncInst(FPTruncInst & I)776 void CustomSafeOptPass::visitFPTruncInst(FPTruncInst& I)
777 {
778 if (PHINode * phi = dyn_cast<PHINode>(I.getOperand(0)))
779 {
780 bool foundPattern = true;
781 unsigned int numSrc = phi->getNumIncomingValues();
782 SmallVector<Value*, 6> newSources(numSrc);
783 for (unsigned int i = 0; i < numSrc; i++)
784 {
785 FPExtInst* source = dyn_cast<FPExtInst>(phi->getIncomingValue(i));
786 if (source && source->getOperand(0)->getType() == I.getType())
787 {
788 newSources[i] = source->getOperand(0);
789 }
790 else
791 {
792 foundPattern = false;
793 break;
794 }
795 }
796 if (foundPattern)
797 {
798 PHINode* newPhi = PHINode::Create(I.getType(), numSrc, "", phi);
799 for (unsigned int i = 0; i < numSrc; i++)
800 {
801 newPhi->addIncoming(newSources[i], phi->getIncomingBlock(i));
802 }
803
804 I.replaceAllUsesWith(newPhi);
805 I.eraseFromParent();
806 // if phi has other uses we add a fpext to avoid having two phi
807 if (!phi->use_empty())
808 {
809 IRBuilder<> builder(&(*phi->getParent()->getFirstInsertionPt()));
810 Value* extV = builder.CreateFPExt(newPhi, phi->getType());
811 phi->replaceAllUsesWith(extV);
812 }
813 }
814 }
815
816 }
817
visitFPToUIInst(llvm::FPToUIInst & FPUII)818 void CustomSafeOptPass::visitFPToUIInst(llvm::FPToUIInst& FPUII)
819 {
820 if (llvm::IntrinsicInst * intrinsicInst = llvm::dyn_cast<llvm::IntrinsicInst>(FPUII.getOperand(0)))
821 {
822 if (intrinsicInst->getIntrinsicID() == Intrinsic::floor)
823 {
824 FPUII.setOperand(0, intrinsicInst->getOperand(0));
825 if (intrinsicInst->use_empty())
826 {
827 intrinsicInst->eraseFromParent();
828 }
829 }
830 }
831 }
832
833 /// This remove simplify bitcast across phi and select instruction
834 /// LLVM doesn't catch those case and it is common in DX10+ as the input language is not typed
835 /// TODO: support cases where some sources are constant
/// This remove simplify bitcast across phi and select instruction
/// LLVM doesn't catch those case and it is common in DX10+ as the input language is not typed
/// TODO: support cases where some sources are constant
void CustomSafeOptPass::visitBitCast(BitCastInst& BC)
{
    // Case 1: bitcast(select(c, bitcast(a), bitcast(b))) where a and b already
    // have BC's type -> select(c, a, b), eliminating all three bitcasts.
    if (SelectInst * sel = dyn_cast<SelectInst>(BC.getOperand(0)))
    {
        BitCastInst* trueVal = dyn_cast<BitCastInst>(sel->getTrueValue());
        BitCastInst* falseVal = dyn_cast<BitCastInst>(sel->getFalseValue());
        if (trueVal && falseVal)
        {
            Value* trueValOrignalType = trueVal->getOperand(0);
            Value* falseValOrignalType = falseVal->getOperand(0);
            if (trueValOrignalType->getType() == BC.getType() &&
                falseValOrignalType->getType() == BC.getType())
            {
                Value* cond = sel->getCondition();
                Value* newVal = SelectInst::Create(cond, trueValOrignalType, falseValOrignalType, "", sel);
                BC.replaceAllUsesWith(newVal);
                BC.eraseFromParent();
            }
        }
    }
    // Case 2: bitcast of a single-use phi whose incoming values are all either
    // bitcasts from BC's type or constants -> phi directly over the original
    // (or constant-folded) values.
    else if (PHINode * phi = dyn_cast<PHINode>(BC.getOperand(0)))
    {
        if (phi->hasOneUse())
        {
            bool foundPattern = true;
            unsigned int numSrc = phi->getNumIncomingValues();
            SmallVector<Value*, 6> newSources(numSrc);
            for (unsigned int i = 0; i < numSrc; i++)
            {
                BitCastInst* source = dyn_cast<BitCastInst>(phi->getIncomingValue(i));
                if (source && source->getOperand(0)->getType() == BC.getType())
                {
                    newSources[i] = source->getOperand(0);
                }
                else if (Constant * C = dyn_cast<Constant>(phi->getIncomingValue(i)))
                {
                    // Constants can be re-cast at compile time.
                    newSources[i] = ConstantExpr::getCast(Instruction::BitCast, C, BC.getType());
                }
                else
                {
                    foundPattern = false;
                    break;
                }
            }
            if (foundPattern)
            {
                PHINode* newPhi = PHINode::Create(BC.getType(), numSrc, "", phi);
                for (unsigned int i = 0; i < numSrc; i++)
                {
                    newPhi->addIncoming(newSources[i], phi->getIncomingBlock(i));
                }
                BC.replaceAllUsesWith(newPhi);
                BC.eraseFromParent();
            }
        }
    }
}
893
894 /*
895 Search for DP4A pattern, e.g.:
896
897 %14 = load i32, i32 addrspace(1)* %13, align 4 <---- c variable (accumulator)
898 // Instructions to be matched:
899 %conv.i.i4 = sext i8 %scalar35 to i32
900 %conv.i7.i = sext i8 %scalar39 to i32
901 %mul.i = mul nsw i32 %conv.i.i4, %conv.i7.i
902 %conv.i6.i = sext i8 %scalar36 to i32
903 %conv.i5.i = sext i8 %scalar40 to i32
904 %mul4.i = mul nsw i32 %conv.i6.i, %conv.i5.i
905 %add.i = add nsw i32 %mul.i, %mul4.i
906 %conv.i4.i = sext i8 %scalar37 to i32
907 %conv.i3.i = sext i8 %scalar41 to i32
908 %mul7.i = mul nsw i32 %conv.i4.i, %conv.i3.i
909 %add8.i = add nsw i32 %add.i, %mul7.i
910 %conv.i2.i = sext i8 %scalar38 to i32
911 %conv.i1.i = sext i8 %scalar42 to i32
912 %mul11.i = mul nsw i32 %conv.i2.i, %conv.i1.i
913 %add12.i = add nsw i32 %add8.i, %mul11.i
914 %add13.i = add nsw i32 %14, %add12.i
915 // end matched instructions
916 store i32 %add13.i, i32 addrspace(1)* %18, align 4
917
918 =>
919
920 %14 = load i32, i32 addrspace(1)* %13, align 4
921 %15 = insertelement <4 x i8> undef, i8 %scalar35, i64 0
922 %16 = insertelement <4 x i8> undef, i8 %scalar39, i64 0
923 %17 = insertelement <4 x i8> %15, i8 %scalar36, i64 1
924 %18 = insertelement <4 x i8> %16, i8 %scalar40, i64 1
925 %19 = insertelement <4 x i8> %17, i8 %scalar37, i64 2
926 %20 = insertelement <4 x i8> %18, i8 %scalar41, i64 2
927 %21 = insertelement <4 x i8> %19, i8 %scalar38, i64 3
928 %22 = insertelement <4 x i8> %20, i8 %scalar42, i64 3
929 %23 = bitcast <4 x i8> %21 to i32
930 %24 = bitcast <4 x i8> %22 to i32
931 %25 = call i32 @llvm.genx.GenISA.dp4a.ss.i32(i32 %14, i32 %23, i32 %24)
932 ...
933 store i32 %25, i32 addrspace(1)* %29, align 4
934 */
// Fold a tree of four i8*i8 products summed into an i32 (optionally with an
// i32 accumulator) into a single GenISA dp4a intrinsic call. The expected IR
// shape is documented in the block comment above. The original add tree is
// left in place (only its uses are redirected); later DCE is expected to
// remove it.
void CustomSafeOptPass::matchDp4a(BinaryOperator &I) {
    using namespace llvm::PatternMatch;

    // Only the topmost i32 add of the reduction tree can be the match root.
    if (I.getOpcode() != Instruction::Add) return;
    CodeGenContext* Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
    if (!Ctx->platform.hasHWDp4AddSupport()) return;

    static constexpr int NUM_DP4A_COMPONENTS = 4;

    // Holds sext/zext instructions.
    std::array<Instruction *, NUM_DP4A_COMPONENTS> ExtA{}, ExtB{};
    // Holds i8 values that were multiplied.
    std::array<Value *, NUM_DP4A_COMPONENTS> ArrA{}, ArrB{};
    // Holds the accumulator.
    Value* AccVal = nullptr;
    IRBuilder<> Builder(&I);

    // Enum to check if given branch is signed or unsigned, e.g.
    // comes from SExt or ZExt value.
    enum class OriginSignedness { originSigned, originUnsigned, originUnknown };

    // Note: the returned pattern from this lambda doesn't use m_ZExtOrSExt pattern on purpose -
    // There is no way to bind to both instruction and it's arguments, so we bind to instruction and
    // check the opcode later.
    auto getMulPat = [&](int i) { return m_c_Mul(m_Instruction(ExtA[i]), m_Instruction(ExtB[i])); };
    // add4 matches ((p0 + p1) + p2) + p3; add5 additionally folds in one more
    // operand (the accumulator) at the outermost add.
    auto getAdd4Pat = [&](const auto& pat0, const auto& pat1, const auto& pat2, const auto& pat3) {
        return m_c_Add(m_c_Add(m_c_Add(pat0, pat1), pat2), pat3);
    };
    auto getAdd5Pat = [&](const auto& pat0, const auto& pat1, const auto& pat2, const auto& pat3, const auto& pat4) {
        return m_c_Add(getAdd4Pat(pat0, pat1, pat2, pat3), pat4);
    };
    auto PatternAccAtBeginning = getAdd5Pat(m_Value(AccVal), getMulPat(0), getMulPat(1), getMulPat(2), getMulPat(3));
    auto PatternAccAtEnd = getAdd5Pat(getMulPat(0), getMulPat(1), getMulPat(2), getMulPat(3), m_Value(AccVal));
    auto PatternNoAcc = getAdd4Pat(getMulPat(0), getMulPat(1), getMulPat(2), getMulPat(3));

    // With no explicit accumulator in the IR, synthesize a zero accumulator.
    if (!match(&I, PatternAccAtEnd) && !match(&I, PatternAccAtBeginning)) {
        if (match(&I, PatternNoAcc)) {
            AccVal = Builder.getInt32(0);
        } else {
            return;
        }
    }


    // Check if values in A and B all have the same extension type (sext/zext)
    // and that they come from i8 type. A and B extension types can be
    // different.
    auto checkIfValuesComeFromCharType = [](auto &range,
                                            OriginSignedness &retSign) {
        IGC_ASSERT_MESSAGE((range.begin() != range.end()),
                           "Cannot check empty collection.");

        // A byte can also be produced by (x >> 24) or (x & 255) on an i32.
        auto shr24Pat = m_Shr(m_Value(), m_SpecificInt(24));
        auto and255Pat = m_And(m_Value(), m_SpecificInt(255));

        IGC_ASSERT_MESSAGE(range.size() == NUM_DP4A_COMPONENTS,
                           "Range too big in dp4a pattern match!");
        std::array<OriginSignedness, NUM_DP4A_COMPONENTS> signs;
        int counter = 0;
        for (auto I : range) {  // note: shadows the outer BinaryOperator &I
            if (!I->getType()->isIntegerTy(32)) {
                return false;
            }

            switch (I->getOpcode()) {
            case Instruction::SExt:
            case Instruction::ZExt:
                // Must widen a genuine i8 value.
                if (!I->getOperand(0)->getType()->isIntegerTy(8)) {
                    return false;
                }
                signs[counter] = (I->getOpcode() == Instruction::SExt)
                                     ? OriginSignedness::originSigned
                                     : OriginSignedness::originUnsigned;
                break;
            case Instruction::AShr:
            case Instruction::LShr:
                // Top byte extracted via shift-by-24; AShr keeps the sign.
                if (!match(I, shr24Pat)) {
                    return false;
                }
                signs[counter] = (I->getOpcode() == Instruction::AShr)
                                     ? OriginSignedness::originSigned
                                     : OriginSignedness::originUnsigned;
                break;
            case Instruction::And:
                // Low byte masked out; always an unsigned origin.
                if (!match(I, and255Pat)) {
                    return false;
                }
                signs[counter] = OriginSignedness::originUnsigned;
                break;
            default:
                return false;
            }

            counter++;
        }

        // check if all have the same sign.
        retSign = signs[0];
        return std::all_of(
            signs.begin(), signs.end(),
            [&signs](OriginSignedness v) { return v == signs[0]; });
    };

    OriginSignedness ABranch = OriginSignedness::originUnknown, BBranch = OriginSignedness::originUnknown;
    bool canMatch = AccVal->getType()->isIntegerTy(32) && checkIfValuesComeFromCharType(ExtA, ABranch) && checkIfValuesComeFromCharType(ExtB, BBranch);
    if (!canMatch) return;

    GenISAIntrinsic::ID IntrinsicID{};

    static_assert(ExtA.size() == ArrA.size() && ExtB.size() == ArrB.size() && ExtA.size() == NUM_DP4A_COMPONENTS, "array sizes must match!");

    // Recover the i8 source of each component: either the operand of the
    // ext instruction, or a trunc of the shifted/masked i32 value (the byte
    // of interest sits in the low 8 bits in both shr24 and and255 cases).
    auto getInt8Origin = [&](Instruction *I) {
        if (I->getOpcode() == Instruction::SExt ||
            I->getOpcode() == Instruction::ZExt) {
            return I->getOperand(0);
        }
        Builder.SetInsertPoint(I->getNextNode());
        return Builder.CreateTrunc(I, Builder.getInt8Ty());
    };

    std::transform(ExtA.begin(), ExtA.end(), ArrA.begin(), getInt8Origin);
    std::transform(ExtB.begin(), ExtB.end(), ArrB.begin(), getInt8Origin);

    // Pick the dp4a flavor matching the signedness of each operand bundle.
    switch (ABranch) {
    case OriginSignedness::originSigned:
        IntrinsicID = (BBranch == OriginSignedness::originSigned)
                          ? GenISAIntrinsic::GenISA_dp4a_ss
                          : GenISAIntrinsic::GenISA_dp4a_su;
        break;
    case OriginSignedness::originUnsigned:
        IntrinsicID = (BBranch == OriginSignedness::originSigned)
                          ? GenISAIntrinsic::GenISA_dp4a_us
                          : GenISAIntrinsic::GenISA_dp4a_uu;
        break;
    default:
        IGC_ASSERT(0);
    }

    // Additional optimisation: check if the values come from an ExtractElement instruction.
    // If this is the case, reorder the elements in the array to match the ExtractElement pattern.
    // This way we avoid potential shufflevector instructions which cause additional mov instructions in final asm.
    // Note: indices in ExtractElement do not have to be in range [0-3], they can be greater. We just want to have them ordered ascending.
    auto extractElementOrderOpt = [&](std::array<Value*, NUM_DP4A_COMPONENTS> & Arr) {
        bool CanOptOrder = true;
        llvm::SmallPtrSet<Value*, NUM_DP4A_COMPONENTS> OriginValues;
        std::map<int64_t, Value*> IndexMap;  // ordered map sorts by index
        for (int i = 0; i < NUM_DP4A_COMPONENTS; ++i) {
            ConstantInt* IndexVal = nullptr;
            Value* OriginVal = nullptr;
            auto P = m_ExtractElt(m_Value(OriginVal), m_ConstantInt(IndexVal));
            if (!match(Arr[i], P)) {
                CanOptOrder = false;
                break;
            }
            OriginValues.insert(OriginVal);
            IndexMap.insert({ IndexVal->getSExtValue(), Arr[i] });
        }

        // Reorder only if all four come from the same vector with four
        // distinct indices.
        if (CanOptOrder && OriginValues.size() == 1 && IndexMap.size() == NUM_DP4A_COMPONENTS) {
            int i = 0;
            for(auto &El : IndexMap) {
                Arr[i++] = El.second;
            }
        }
    };
    extractElementOrderOpt(ArrA);
    extractElementOrderOpt(ArrB);

    // Pack each 4-byte bundle into a <4 x i8> vector and bitcast to i32,
    // which is the operand form the dp4a intrinsic expects.
    Value* VectorA = UndefValue::get(IGCLLVM::FixedVectorType::get(Builder.getInt8Ty(), NUM_DP4A_COMPONENTS));
    Value* VectorB = UndefValue::get(IGCLLVM::FixedVectorType::get(Builder.getInt8Ty(), NUM_DP4A_COMPONENTS));
    for (int i = 0; i < NUM_DP4A_COMPONENTS; ++i) {
        VectorA = Builder.CreateInsertElement(VectorA, ArrA[i], i);
        VectorB = Builder.CreateInsertElement(VectorB, ArrB[i], i);
    }
    Value* ValA = Builder.CreateBitCast(VectorA, Builder.getInt32Ty());
    Value* ValB = Builder.CreateBitCast(VectorB, Builder.getInt32Ty());

    Function* Dp4aFun = GenISAIntrinsic::getDeclaration(I.getModule(), IntrinsicID, Builder.getInt32Ty());
    Value* Res = Builder.CreateCall(Dp4aFun, { AccVal, ValA, ValB });
    I.replaceAllUsesWith(Res);
}
1116
isEmulatedAdd(BinaryOperator & I)1117 bool CustomSafeOptPass::isEmulatedAdd(BinaryOperator& I)
1118 {
1119 if (I.getOpcode() == Instruction::Or)
1120 {
1121 if (BinaryOperator * OrOp0 = dyn_cast<BinaryOperator>(I.getOperand(0)))
1122 {
1123 if (OrOp0->getOpcode() == Instruction::Shl)
1124 {
1125 // Check the SHl. If we have a constant Shift Left val then we can check
1126 // it to see if it is emulating an add.
1127 if (ConstantInt * pConstShiftLeft = dyn_cast<ConstantInt>(OrOp0->getOperand(1)))
1128 {
1129 if (ConstantInt * pConstOrVal = dyn_cast<ConstantInt>(I.getOperand(1)))
1130 {
1131 int const_shift = int_cast<int>(pConstShiftLeft->getZExtValue());
1132 int const_or_val = int_cast<int>(pConstOrVal->getSExtValue());
1133 if ((1 << const_shift) > abs(const_or_val))
1134 {
1135 // The value fits in the shl. So this is an emulated add.
1136 return true;
1137 }
1138 }
1139 }
1140 }
1141 else if (OrOp0->getOpcode() == Instruction::Mul)
1142 {
1143 // Check to see if the Or is emulating and add.
1144 // If we have a constant Mul and a constant Or. The Mul constant needs to be divisible by the rounded up 2^n of Or value.
1145 if (ConstantInt * pConstMul = dyn_cast<ConstantInt>(OrOp0->getOperand(1)))
1146 {
1147 if (ConstantInt * pConstOrVal = dyn_cast<ConstantInt>(I.getOperand(1)))
1148 {
1149 if (pConstOrVal->isNegative() == false)
1150 {
1151 DWORD const_or_val = int_cast<DWORD>(pConstOrVal->getZExtValue());
1152 DWORD nextPowerOfTwo = iSTD::RoundPower2(const_or_val + 1);
1153 if (nextPowerOfTwo && (pConstMul->getZExtValue() % nextPowerOfTwo == 0))
1154 {
1155 return true;
1156 }
1157 }
1158 }
1159 }
1160 }
1161 }
1162 }
1163 return false;
1164 }
1165
1166 // Attempt to create new float instruction if both operands are from FPTruncInst instructions.
1167 // Example with fadd:
1168 // %Temp-31.prec.i = fptrunc float %34 to half
1169 // %Temp-30.prec.i = fptrunc float %33 to half
1170 // %41 = fadd fast half %Temp-31.prec.i, %Temp-30.prec.i
1171 // %Temp-32.i = fpext half %41 to float
1172 //
// This fadd is used as a float, and doesn't need the operands to be cast to half.
1174 // We can remove the extra casts in this case.
1175 // This becomes:
1176 // %41 = fadd fast float %34, %33
1177 // Can also do matches with fadd/fmul that will later become an mad instruction.
1178 // mad example:
1179 // %.prec70.i = fptrunc float %273 to half
1180 // %.prec78.i = fptrunc float %276 to half
1181 // %279 = fmul fast half %233, %.prec70.i
1182 // %282 = fadd fast half %279, %.prec78.i
1183 // %.prec84.i = fpext half %282 to float
1184 // This becomes:
1185 // %279 = fpext half %233 to float
1186 // %280 = fmul fast float %273, %279
1187 // %281 = fadd fast float %280, %276
removeHftoFCast(Instruction & I)1188 void CustomSafeOptPass::removeHftoFCast(Instruction& I)
1189 {
1190 if (!I.getType()->isFloatingPointTy())
1191 return;
1192
1193 // Check if the only user is a FPExtInst
1194 if (!I.hasOneUse())
1195 return;
1196
1197 // Check if this instruction is used in a single FPExtInst
1198 FPExtInst* castInst = NULL;
1199 User* U = *I.user_begin();
1200 if (FPExtInst* inst = dyn_cast<FPExtInst>(U))
1201 {
1202 if (inst->getType()->isFloatTy())
1203 {
1204 castInst = inst;
1205 }
1206 }
1207 if (!castInst)
1208 return;
1209
1210 // Check for fmad pattern
1211 if (I.getOpcode() == Instruction::FAdd)
1212 {
1213 Value* src0 = nullptr, * src1 = nullptr, * src2 = nullptr;
1214
1215 // CodeGenPatternMatch::MatchMad matches the first fmul.
1216 Instruction* fmulInst = nullptr;
1217 for (uint i = 0; i < 2; i++)
1218 {
1219 fmulInst = dyn_cast<Instruction>(I.getOperand(i));
1220 if (fmulInst && fmulInst->getOpcode() == Instruction::FMul)
1221 {
1222 src0 = fmulInst->getOperand(0);
1223 src1 = fmulInst->getOperand(1);
1224 src2 = I.getOperand(1 - i);
1225 break;
1226 }
1227 else
1228 {
1229 // Prevent other non-fmul instructions from getting used
1230 fmulInst = nullptr;
1231 }
1232 }
1233 if (fmulInst)
1234 {
1235 // Used to get the new float operands for the new instructions
1236 auto getFloatValue = [](Value* operand, Instruction* I, Type* type)
1237 {
1238 if (FPTruncInst* inst = dyn_cast<FPTruncInst>(operand))
1239 {
1240 // Use the float input of the FPTrunc
1241 if (inst->getOperand(0)->getType()->isFloatTy())
1242 {
1243 return inst->getOperand(0);
1244 }
1245 else
1246 {
1247 return (Value*)NULL;
1248 }
1249 }
1250 else if (Instruction* inst = dyn_cast<Instruction>(operand))
1251 {
1252 // Cast the result of this operand to a float
1253 return dyn_cast<Value>(new FPExtInst(inst, type, "", I));
1254 }
1255 return (Value*)NULL;
1256 };
1257
1258 int convertCount = 0;
1259 if (dyn_cast<FPTruncInst>(src0))
1260 convertCount++;
1261 if (dyn_cast<FPTruncInst>(src1))
1262 convertCount++;
1263 if (dyn_cast<FPTruncInst>(src2))
1264 convertCount++;
1265 if (convertCount >= 2)
1266 {
1267 // Conversion for the hf values
1268 auto floatTy = castInst->getType();
1269 src0 = getFloatValue(src0, fmulInst, floatTy);
1270 src1 = getFloatValue(src1, fmulInst, floatTy);
1271 src2 = getFloatValue(src2, &I, floatTy);
1272
1273 if (!src0 || !src1 || !src2)
1274 return;
1275
1276 // Create new float fmul and fadd instructions
1277 Value* newFmul = BinaryOperator::Create(Instruction::FMul, src0, src1, "", &I);
1278 Value* newFadd = BinaryOperator::Create(Instruction::FAdd, newFmul, src2, "", &I);
1279
1280 // Copy fast math flags
1281 Instruction* fmulInst = dyn_cast<Instruction>(newFmul);
1282 Instruction* faddInst = dyn_cast<Instruction>(newFadd);
1283 fmulInst->copyFastMathFlags(fmulInst);
1284 faddInst->copyFastMathFlags(&I);
1285 castInst->replaceAllUsesWith(faddInst);
1286 return;
1287 }
1288 }
1289 }
1290
1291 // Check if operands come from a Float to HF Cast
1292 Value *S1 = NULL, *S2 = NULL;
1293 if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(0)))
1294 {
1295 if (!inst->getType()->isHalfTy())
1296 return;
1297 S1 = inst->getOperand(0);
1298 }
1299 if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(1)))
1300 {
1301 if (!inst->getType()->isHalfTy())
1302 return;
1303 S2 = inst->getOperand(0);
1304 }
1305 if (!S1 || !S2)
1306 {
1307 return;
1308 }
1309
1310 Value* newInst = NULL;
1311 if (BinaryOperator* bo = dyn_cast<BinaryOperator>(&I))
1312 {
1313 newInst = BinaryOperator::Create(bo->getOpcode(), S1, S2, "", &I);
1314 Instruction* inst = dyn_cast<Instruction>(newInst);
1315 inst->copyFastMathFlags(&I);
1316 castInst->replaceAllUsesWith(inst);
1317 }
1318 }
1319
// Peephole entry point for binary operators: tries dp4a fusion, bit-reverse
// matching on `or`, immediate reassociation on integer adds, and
// half->float cast removal on FP ops.
void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
{
    matchDp4a(I);

    // move immediate value in consecutive integer adds to the last added value.
    // this can allow more chance of doing CSE and memopt.
    //    a = b + 8
    //    d = a + c
    // to
    //    a = b + c
    //    d = a + 8

    CodeGenContext* pContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();

    // Before WA if() as it's validated behavior.
    // Bit-reverse patterns are built from shifts and `or`s, so only `or`
    // roots are candidates; dispatch on the integer width.
    if (I.getType()->isIntegerTy() && I.getOpcode() == Instruction::Or)
    {
        unsigned int bitWidth = cast<IntegerType>(I.getType())->getBitWidth();
        switch (bitWidth)
        {
        case 16:
            matchReverse<unsigned short>(I);
            break;
        case 32:
            matchReverse<unsigned int>(I);
            break;
        case 64:
            matchReverse<unsigned long long>(I);
            break;
        }
    }

    // WA for remaining bug in custom pass
    if (pContext->m_DriverInfo.WADisableCustomPass())
        return;
    if (I.getType()->isIntegerTy())
    {
        // `I` is either a real add or an `or` proven equivalent to an add
        // (isEmulatedAdd), with a constant on one side and a single user.
        if ((I.getOpcode() == Instruction::Add || isEmulatedAdd(I)) &&
            I.hasOneUse())
        {
            ConstantInt* src0imm = dyn_cast<ConstantInt>(I.getOperand(0));
            ConstantInt* src1imm = dyn_cast<ConstantInt>(I.getOperand(1));
            if (src0imm || src1imm)
            {
                llvm::Instruction* nextInst = llvm::dyn_cast<llvm::Instruction>(*(I.user_begin()));
                if (nextInst && nextInst->getOpcode() == Instruction::Add)
                {
                    // do not apply if any add instruction has NSW flag since we can't save it
                    if ((isa<OverflowingBinaryOperator>(I) && I.hasNoSignedWrap()) || nextInst->hasNoSignedWrap())
                        return;
                    ConstantInt* secondSrc0imm = dyn_cast<ConstantInt>(nextInst->getOperand(0));
                    ConstantInt* secondSrc1imm = dyn_cast<ConstantInt>(nextInst->getOperand(1));
                    // found 2 add instructions to swap srcs
                    // (the second add must have no immediates and distinct operands)
                    if (!secondSrc0imm && !secondSrc1imm && nextInst->getOperand(0) != nextInst->getOperand(1))
                    {
                        for (int i = 0; i < 2; i++)
                        {
                            if (nextInst->getOperand(i) == &I)
                            {
                                // newAdd = (non-imm operand of I) + (other operand of nextInst)
                                Value* newAdd = BinaryOperator::CreateAdd(src0imm ? I.getOperand(1) : I.getOperand(0), nextInst->getOperand(1 - i), "", nextInst);

                                // Conservatively clear the NSW NUW flags, since they may not be
                                // preserved by the reassociation.
                                bool IsNUW = isa<OverflowingBinaryOperator>(I) && I.hasNoUnsignedWrap() && nextInst->hasNoUnsignedWrap();
                                cast<BinaryOperator>(newAdd)->setHasNoUnsignedWrap(IsNUW);
                                nextInst->setHasNoUnsignedWrap(IsNUW);
                                nextInst->setHasNoSignedWrap(false);

                                // nextInst becomes: newAdd + (immediate from I)
                                nextInst->setOperand(0, newAdd);
                                nextInst->setOperand(1, I.getOperand(src0imm ? 0 : 1));
                                break;
                            }
                        }
                    }
                }
            }
        }
    } else if (I.getType()->isFloatingPointTy()) {
        removeHftoFCast(I);
    }
}
1401
// Convert a GenISA ldptr sampler load into a GenISA typedread when the extra
// coordinate/lod operands are zero, so the HDC typed-read path can be used
// (gated by the UseHDCTypedReadForAllTextures / ...AllTypedBuffers flags).
void IGC::CustomSafeOptPass::visitLdptr(llvm::SamplerLoadIntrinsic* inst)
{
    if (!IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTextures) &&
        !IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTypedBuffers))
    {
        return;
    }

    // change
    // % 10 = call fast <4 x float> @llvm.genx.GenISA.ldptr.v4f32.p196608v4f32(i32 %_s1.i, i32 %_s14.i, i32 0, i32 0, <4 x float> addrspace(196608)* null, i32 0, i32 0, i32 0), !dbg !123
    // to
    // % 10 = call fast <4 x float> @llvm.genx.GenISA.typedread.p196608v4f32(<4 x float> addrspace(196608)* null, i32 %_s1.i, i32 %_s14.i, i32 0, i32 0), !dbg !123
    // when the index comes directly from threadid

    Constant* src1 = dyn_cast<Constant>(inst->getOperand(1));
    Constant* src2 = dyn_cast<Constant>(inst->getOperand(2));
    Constant* src3 = dyn_cast<Constant>(inst->getOperand(3));

    // src2 and src3 has to be zero
    if (!src2 || !src3 || !src2->isZeroValue() || !src3->isZeroValue())
    {
        return;
    }

    // if only doing the opt on buffers, make sure src1 is zero too
    if (!IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTextures) &&
        IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTypedBuffers))
    {
        if (!src1 || !src1->isZeroValue())
            return;
    }

    // do the transformation
    llvm::IRBuilder<> builder(inst);
    Module* M = inst->getParent()->getParent()->getParent();

    Function* pLdIntrinsic = llvm::GenISAIntrinsic::getDeclaration(
        M,
        GenISAIntrinsic::GenISA_typedread,
        inst->getOperand(4)->getType());

    // typedread args: (resource, u, v, r, lod).
    // NOTE(review): each coordinate is built as coord + the operand following
    // the texture operand at the matching position — presumably the ldptr
    // immediate offsets; confirm against the GenISA ldptr operand layout.
    SmallVector<Value*, 5> ld_FunctionArgList(5);
    ld_FunctionArgList[0] = inst->getTextureValue();
    ld_FunctionArgList[1] = builder.CreateAdd(inst->getOperand(0), inst->getOperand(inst->getTextureIndex() + 1));
    ld_FunctionArgList[2] = builder.CreateAdd(inst->getOperand(1), inst->getOperand(inst->getTextureIndex() + 2));
    ld_FunctionArgList[3] = builder.CreateAdd(inst->getOperand(3), inst->getOperand(inst->getTextureIndex() + 3));
    ld_FunctionArgList[4] = inst->getOperand(2); // lod=zero

    llvm::CallInst* pNewCallInst = builder.CreateCall(
        pLdIntrinsic, ld_FunctionArgList);

    // as typedread returns float4 by default, bitcast it
    // to int4 if necessary
    // FIXME: is it better to make typedRead return ty a anyvector?
    if (inst->getType() != pNewCallInst->getType())
    {
        IGC_ASSERT_MESSAGE(inst->getType()->isVectorTy(), "expect int4 here");
        IGC_ASSERT_MESSAGE(cast<IGCLLVM::FixedVectorType>(inst->getType())
                               ->getElementType()
                               ->isIntegerTy(32),
                           "expect int4 here");
        IGC_ASSERT_MESSAGE(
            cast<IGCLLVM::FixedVectorType>(inst->getType())->getNumElements() ==
                4,
            "expect int4 here");
        auto bitCastInst = builder.CreateBitCast(pNewCallInst, inst->getType());
        inst->replaceAllUsesWith(bitCastInst);
    }
    else
    {
        inst->replaceAllUsesWith(pNewCallInst);
    }
}
1475
1476
visitLdRawVec(llvm::CallInst * inst)1477 void IGC::CustomSafeOptPass::visitLdRawVec(llvm::CallInst* inst)
1478 {
1479 //Try to optimize and remove vector ld raw and change to scalar ld raw
1480
1481 //%a = call <4 x float> @llvm.genx.GenISA.ldrawvector.indexed.v4f32.p1441792f32(
1482 //.....float addrspace(1441792) * %243, i32 %offset, i32 4, i1 false), !dbg !216
1483 //%b = extractelement <4 x float> % 245, i32 0, !dbg !216
1484
1485 //into
1486
1487 //%new_offset = add i32 %offset, 0, !dbg !216
1488 //%b = call float @llvm.genx.GenISA.ldraw.indexed.f32.p1441792f32.i32.i32.i1(
1489 //.....float addrspace(1441792) * %251, i32 %new_offset, i32 4, i1 false)
1490
1491 if (inst->hasOneUse() &&
1492 isa<ExtractElementInst>(inst->user_back()))
1493 {
1494 auto EE = cast<ExtractElementInst>(inst->user_back());
1495 if (auto constIndex = dyn_cast<ConstantInt>(EE->getIndexOperand()))
1496 {
1497 llvm::IRBuilder<> builder(inst);
1498
1499 llvm::SmallVector<llvm::Type*, 2> ovldtypes{
1500 EE->getType(), //float type
1501 inst->getOperand(0)->getType(),
1502 };
1503
1504 // For new_offset we need to take into acount the index of the Extract
1505 // and convert it to bytes and add it to the existing offset
1506 auto new_offset = constIndex->getZExtValue() * 4;
1507
1508 llvm::SmallVector<llvm::Value*, 4> new_args{
1509 inst->getOperand(0),
1510 builder.CreateAdd(inst->getOperand(1),builder.getInt32((unsigned)new_offset)),
1511 inst->getOperand(2),
1512 inst->getOperand(3)
1513 };
1514
1515 Function* pLdraw_indexed_intrinsic = llvm::GenISAIntrinsic::getDeclaration(
1516 inst->getModule(),
1517 GenISAIntrinsic::GenISA_ldraw_indexed,
1518 ovldtypes);
1519
1520 llvm::Value* ldraw_indexed = builder.CreateCall(pLdraw_indexed_intrinsic, new_args, "");
1521 EE->replaceAllUsesWith(ldraw_indexed);
1522 }
1523 }
1524 }
1525
visitSampleBptr(llvm::SampleIntrinsic * sampleInst)1526 void IGC::CustomSafeOptPass::visitSampleBptr(llvm::SampleIntrinsic* sampleInst)
1527 {
1528 // sampleB with bias_value==0 -> sample
1529 llvm::ConstantFP* constBias = llvm::dyn_cast<llvm::ConstantFP>(sampleInst->getOperand(0));
1530 if (constBias && constBias->isZero())
1531 {
1532 // Copy args skipping bias operand:
1533 llvm::SmallVector<llvm::Value*, 10> args;
1534 for (unsigned int i = 1; i < sampleInst->getNumArgOperands(); i++)
1535 {
1536 args.push_back(sampleInst->getArgOperand(i));
1537 }
1538
1539 // Copy overloaded types unchanged:
1540 llvm::SmallVector<llvm::Type*, 4> overloadedTys;
1541 overloadedTys.push_back(sampleInst->getCalledFunction()->getReturnType());
1542 overloadedTys.push_back(sampleInst->getOperand(0)->getType());
1543 overloadedTys.push_back(sampleInst->getTextureValue()->getType());
1544 overloadedTys.push_back(sampleInst->getSamplerValue()->getType());
1545
1546 llvm::Function* sampleIntr = llvm::GenISAIntrinsic::getDeclaration(
1547 sampleInst->getParent()->getParent()->getParent(),
1548 GenISAIntrinsic::GenISA_sampleptr,
1549 overloadedTys);
1550
1551 llvm::Value* newSample = llvm::CallInst::Create(sampleIntr, args, "", sampleInst);
1552 sampleInst->replaceAllUsesWith(newSample);
1553 }
1554 }
1555
isIdentityMatrix(ExtractElementInst & I)1556 bool CustomSafeOptPass::isIdentityMatrix(ExtractElementInst& I)
1557 {
1558 bool found = false;
1559 auto extractType = cast<IGCLLVM::FixedVectorType>(I.getVectorOperandType());
1560 auto extractTypeVecSize = (uint32_t)extractType->getNumElements();
1561 if (extractTypeVecSize == 20 ||
1562 extractTypeVecSize == 16)
1563 {
1564 if (Constant * C = dyn_cast<Constant>(I.getVectorOperand()))
1565 {
1566 found = true;
1567
1568 // found = true if the extractelement is like this:
1569 // %189 = extractelement <20 x float>
1570 // <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00,
1571 // float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00,
1572 // float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00,
1573 // float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00,
1574 // float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %188
1575 for (unsigned int i = 0; i < extractTypeVecSize; i++)
1576 {
1577 if (i == 0 || i == 5 || i == 10 || i == 15)
1578 {
1579 ConstantFP* fpC = dyn_cast<ConstantFP>(C->getAggregateElement(i));
1580 ConstantInt* intC = dyn_cast<ConstantInt>(C->getAggregateElement(i));
1581 if((fpC && !fpC->isExactlyValue(1.f)) || (intC && !intC->isAllOnesValue()))
1582 {
1583 found = false;
1584 break;
1585 }
1586 }
1587 else if (!C->getAggregateElement(i)->isZeroValue())
1588 {
1589 found = false;
1590 break;
1591 }
1592 }
1593 }
1594 }
1595 return found;
1596 }
1597
// Replace a dp4 against a dynamically-indexed identity-matrix constant with a
// chain of icmp+select that simply picks the matching input component.
void CustomSafeOptPass::dp4WithIdentityMatrix(ExtractElementInst& I)
{
    /*
    convert dp4 with identity matrix icb ( ex: "dp4 r[6].x, cb[2][8].xyzw, icb[ r[4].w].xyzw") from
    %189 = shl nuw nsw i32 %188, 2, !dbg !326
    %190 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %189, !dbg !326
    %191 = or i32 %189, 1, !dbg !326
    %192 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %191, !dbg !326
    %193 = or i32 %189, 2, !dbg !326
    %194 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %193, !dbg !326
    %195 = or i32 %189, 3, !dbg !326
    %196 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %195, !dbg !326
    %s01_s.chan0141 = fmul fast float %181, %190, !dbg !326
    %197 = fmul fast float %182, %192, !dbg !326
    %198 = fadd fast float %197, %s01_s.chan0141, !dbg !326
    %199 = fmul fast float %183, %194, !dbg !326
    %200 = fadd fast float %199, %198, !dbg !326
    %201 = fmul fast float %184, %196, !dbg !326
    %202 = fadd fast float %201, %200, !dbg !326
    to
    %533 = icmp eq i32 %532, 0, !dbg !434
    %534 = icmp eq i32 %532, 1, !dbg !434
    %535 = icmp eq i32 %532, 2, !dbg !434
    %536 = icmp eq i32 %532, 3, !dbg !434
    %537 = select i1 %533, float %525, float 0.000000e+00, !dbg !434
    %538 = select i1 %534, float %526, float %537, !dbg !434
    %539 = select i1 %535, float %527, float %538, !dbg !434
    %540 = select i1 %536, float %528, float %539, !dbg !434
    */

    // check %190 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00 ...
    if (!I.hasOneUse() || !isIdentityMatrix(I))
        return;

    // offset[k] holds the index expression for matrix row k;
    // eeInst[k] the corresponding extractelement from the identity constant.
    Instruction* offset[4] = {nullptr, nullptr, nullptr, nullptr};
    ExtractElementInst* eeInst[4] = { &I, nullptr, nullptr, nullptr };

    // check %189 = shl nuw nsw i32 %188, 2, !dbg !326
    // put it in offset[0]
    offset[0] = dyn_cast<BinaryOperator>(I.getOperand(1));
    if (!offset[0] ||
        offset[0]->getOpcode() != Instruction::Shl ||
        offset[0]->getOperand(1) != ConstantInt::get(offset[0]->getOperand(1)->getType(), 2))
    {
        return;
    }

    // check %191 = or i32 %189, 1, !dbg !326
    //       %193 = or i32 % 189, 2, !dbg !326
    //       %195 = or i32 % 189, 3, !dbg !326
    // put them in offset[1], offset[2], offset[3]
    // (base is a multiple of 4, so or-ing 1..3 is the row base plus column).
    for (auto iter = offset[0]->user_begin(); iter != offset[0]->user_end(); iter++)
    {
        // skip checking for the %190 = extractelement <20 x float> <float 1.000000e+00, ....
        if (*iter == &I)
            continue;

        if (BinaryOperator * orInst = dyn_cast<BinaryOperator>(*iter))
        {
            if (orInst->getOpcode() == BinaryOperator::Or && orInst->hasOneUse())
            {
                if (ConstantInt * orSrc1 = dyn_cast<ConstantInt>(orInst->getOperand(1)))
                {
                    if (orSrc1->getZExtValue() < 4)
                    {
                        offset[orSrc1->getZExtValue()] = orInst;
                    }
                }
            }
        }
    }

    // All four offsets must have been found.
    for (int i = 0; i < 4; i++)
    {
        if (offset[i] == nullptr)
            return;
    }

    // check %192 = extractelement <20 x float> ...
    //       %194 = extractelement <20 x float> ...
    //       %196 = extractelement <20 x float> ...
    // put them in eeInst[i]
    for (int i = 1; i < 4; i++)
    {
        eeInst[i] = dyn_cast<ExtractElementInst>(*offset[i]->user_begin());
        if (!eeInst[i] || !isIdentityMatrix(*eeInst[i]))
        {
            return;
        }
    }

    // check dp4 and put them in mulI[] and addI[]
    // Each extractelement's single user must be the fmul of the dp4.
    Instruction* mulI[4] = { nullptr, nullptr, nullptr, nullptr };
    for (int i = 0; i < 4; i++)
    {
        mulI[i] = dyn_cast<Instruction>(*eeInst[i]->user_begin());
        if (mulI[i] == nullptr || !mulI[i]->hasOneUse())
        {
            return;
        }
    }
    // Which fmul operand carries the shader input (the other is the matrix
    // element). NOTE(review): determined from mulI[0] only and assumed the
    // same for all four fmuls.
    int inputInSrcIndex = 0;
    if (mulI[0]->getOperand(0) == eeInst[0])
        inputInSrcIndex = 1;

    // the 1st and 2nd mul are the srcs for add
    if (*mulI[0]->user_begin() != *mulI[1]->user_begin())
    {
        return;
    }
    // Verify the fadd chain: add0 = mul0+mul1, add1 = add0+mul2, add2 = add1+mul3.
    Instruction* addI[3] = { nullptr, nullptr, nullptr };
    addI[0] = dyn_cast<Instruction>(*mulI[0]->user_begin());
    addI[1] = dyn_cast<Instruction>(*mulI[2]->user_begin());
    addI[2] = dyn_cast<Instruction>(*mulI[3]->user_begin());

    if( addI[0] == nullptr ||
        addI[1] == nullptr ||
        addI[2] == nullptr ||
        !addI[0]->hasOneUse() ||
        !addI[1]->hasOneUse() ||
        *addI[0]->user_begin() != *mulI[2]->user_begin() ||
        *addI[1]->user_begin() != *mulI[3]->user_begin())
    {
        return;
    }

    // start the conversion: compare the pre-shift index against 0..3 and
    // select the corresponding input component.
    IRBuilder<> builder(addI[2]);

    Value* cond0 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 0));
    Value* cond1 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 1));
    Value* cond2 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 2));
    Value* cond3 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 3));

    Value* zero = ConstantFP::get(Type::getFloatTy(I.getContext()), 0);
    Value* sel0 = builder.CreateSelect(cond0, mulI[0]->getOperand(inputInSrcIndex), zero);
    Value* sel1 = builder.CreateSelect(cond1, mulI[1]->getOperand(inputInSrcIndex), sel0);
    Value* sel2 = builder.CreateSelect(cond2, mulI[2]->getOperand(inputInSrcIndex), sel1);
    Value* sel3 = builder.CreateSelect(cond3, mulI[3]->getOperand(inputInSrcIndex), sel2);

    addI[2]->replaceAllUsesWith(sel3);
}
1740
1741
// Fold a shift feeding a bitcast feeding an extractelement into a direct
// extract at a shifted lane index, then try the dp4-with-identity-matrix
// rewrite on the instruction.
void CustomSafeOptPass::visitExtractElementInst(ExtractElementInst& I)
{
    // convert:
    // %1 = lshr i32 %0, 16,
    // %2 = bitcast i32 %1 to <2 x half>
    // %3 = extractelement <2 x half> %2, i32 0
    // to ->
    // %2 = bitcast i32 %0 to <2 x half>
    // %3 = extractelement <2 x half> %2, i32 1
    BitCastInst* bitCast = dyn_cast<BitCastInst>(I.getVectorOperand());
    ConstantInt* cstIndex = dyn_cast<ConstantInt>(I.getIndexOperand());
    if (bitCast && cstIndex)
    {
        // skip intermediate bitcast
        // (bitcasts preserve bit width, so the shift still lines up with
        // the outer vector type)
        while (isa<BitCastInst>(bitCast->getOperand(0)))
        {
            bitCast = cast<BitCastInst>(bitCast->getOperand(0));
        }
        if (BinaryOperator * binOp = dyn_cast<BinaryOperator>(bitCast->getOperand(0)))
        {
            unsigned int bitShift = 0;
            bool rightShift = false;
            if (binOp->getOpcode() == Instruction::LShr)
            {
                if (ConstantInt * cstShift = dyn_cast<ConstantInt>(binOp->getOperand(1)))
                {
                    bitShift = (unsigned int)cstShift->getZExtValue();
                    rightShift = true;
                }
            }
            else if (binOp->getOpcode() == Instruction::Shl)
            {
                if (ConstantInt * cstShift = dyn_cast<ConstantInt>(binOp->getOperand(1)))
                {
                    bitShift = (unsigned int)cstShift->getZExtValue();
                }
            }
            if (bitShift != 0)
            {
                Type* vecType = I.getVectorOperand()->getType();
                unsigned int eltSize = (unsigned int)cast<VectorType>(vecType)->getElementType()->getPrimitiveSizeInBits();
                // Only fold when the shift moves data by whole lanes.
                if (bitShift % eltSize == 0)
                {
                    // A right shift moves higher lanes down, so the wanted
                    // lane index goes up; a left shift is the opposite.
                    int elOffset = (int)(bitShift / eltSize);
                    elOffset = rightShift ? elOffset : -elOffset;
                    unsigned int newIndex = (unsigned int)((int)cstIndex->getZExtValue() + elOffset);
                    if (newIndex < cast<IGCLLVM::FixedVectorType>(vecType)->getNumElements())
                    {
                        IRBuilder<> builder(&I);
                        Value* newBitCast = builder.CreateBitCast(binOp->getOperand(0), vecType);
                        Value* newScalar = builder.CreateExtractElement(newBitCast, newIndex);
                        I.replaceAllUsesWith(newScalar);
                        I.eraseFromParent();
                        return;
                    }
                }
            }
        }
    }

    dp4WithIdentityMatrix(I);
}
1804
#if LLVM_VERSION_MAJOR >= 7
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// This pass removes dead local memory loads and stores. If we remove all such loads and stores, we also
// remove all local memory fences together with barriers that follow.
//
IGC_INITIALIZE_PASS_BEGIN(TrivialLocalMemoryOpsElimination, "TrivialLocalMemoryOpsElimination", "TrivialLocalMemoryOpsElimination", false, false)
IGC_INITIALIZE_PASS_END(TrivialLocalMemoryOpsElimination, "TrivialLocalMemoryOpsElimination", "TrivialLocalMemoryOpsElimination", false, false)

// LLVM pass identification token; LLVM keys passes off the ADDRESS of this
// variable, not its value.
char TrivialLocalMemoryOpsElimination::ID = 0;
1814
// Registers this pass with LLVM's global pass registry on construction so it
// is visible to igc-opt and the pass manager.
TrivialLocalMemoryOpsElimination::TrivialLocalMemoryOpsElimination() : FunctionPass(ID)
{
    initializeTrivialLocalMemoryOpsEliminationPass(*PassRegistry::getPassRegistry());
}
1819
runOnFunction(Function & F)1820 bool TrivialLocalMemoryOpsElimination::runOnFunction(Function& F)
1821 {
1822 bool change = false;
1823
1824 IGCMD::MetaDataUtils* pMdUtil = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
1825 if (!isEntryFunc(pMdUtil, &F))
1826 {
1827 // Skip if it is non-entry function. For example, a subroutine
1828 // foo ( local int* p) { ...... store v, p; ......}
1829 // in which no localMemoptimization will be performed.
1830 return change;
1831 }
1832
1833 visit(F);
1834 if (!abortPass && (m_LocalLoadsToRemove.empty() ^ m_LocalStoresToRemove.empty()))
1835 {
1836 for (StoreInst* Inst : m_LocalStoresToRemove)
1837 {
1838 Inst->eraseFromParent();
1839 change = true;
1840 }
1841
1842 for (LoadInst* Inst : m_LocalLoadsToRemove)
1843 {
1844 if (Inst->use_empty())
1845 {
1846 Inst->eraseFromParent();
1847 change = true;
1848 }
1849 }
1850
1851 for (CallInst* Inst : m_LocalFencesBariersToRemove)
1852 {
1853 Inst->eraseFromParent();
1854 change = true;
1855 }
1856 }
1857 m_LocalStoresToRemove.clear();
1858 m_LocalLoadsToRemove.clear();
1859 m_LocalFencesBariersToRemove.clear();
1860
1861
1862 return change;
1863 }
1864
/*
The OCL builtin barrier(CLK_LOCAL_MEM_FENCE); is translated into two instructions:
call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true)
call void @llvm.genx.GenISA.threadgroupbarrier()

If we remove call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true),
we must also remove the instruction that immediately follows it when it is call void @llvm.genx.GenISA.threadgroupbarrier().
*/
findNextThreadGroupBarrierInst(Instruction & I)1873 void TrivialLocalMemoryOpsElimination::findNextThreadGroupBarrierInst(Instruction& I)
1874 {
1875 auto nextInst = I.getNextNonDebugInstruction();
1876 if (isa<GenIntrinsicInst>(nextInst))
1877 {
1878 GenIntrinsicInst* II = dyn_cast<GenIntrinsicInst>(nextInst);
1879 if (II->getIntrinsicID() == GenISAIntrinsic::GenISA_threadgroupbarrier)
1880 {
1881 m_LocalFencesBariersToRemove.push_back(dyn_cast<CallInst>(nextInst));
1882 }
1883 }
1884 }
1885
visitLoadInst(LoadInst & I)1886 void TrivialLocalMemoryOpsElimination::visitLoadInst(LoadInst& I)
1887 {
1888 if (I.getPointerAddressSpace() == ADDRESS_SPACE_LOCAL)
1889 {
1890 m_LocalLoadsToRemove.push_back(&I);
1891 }
1892 else if (I.getPointerAddressSpace() == ADDRESS_SPACE_GENERIC)
1893 {
1894 abortPass = true;
1895 }
1896 }
1897
visitStoreInst(StoreInst & I)1898 void TrivialLocalMemoryOpsElimination::visitStoreInst(StoreInst& I)
1899 {
1900 if (I.getPointerAddressSpace() == ADDRESS_SPACE_LOCAL)
1901 {
1902 m_LocalStoresToRemove.push_back(&I);
1903 }
1904 else if (I.getPointerAddressSpace() == ADDRESS_SPACE_GENERIC)
1905 {
1906 abortPass = true;
1907 }
1908 }
1909
isLocalBarrier(CallInst & I)1910 bool TrivialLocalMemoryOpsElimination::isLocalBarrier(CallInst& I)
1911 {
1912 //check arguments in call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true) if match to
1913 // (i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true) it is local barrier
1914 std::vector<bool> argumentsOfMemoryBarrier;
1915
1916 for (auto arg = I.arg_begin(); arg != I.arg_end(); ++arg)
1917 {
1918 ConstantInt* ci = dyn_cast<ConstantInt>(arg);
1919 if (ci) {
1920 argumentsOfMemoryBarrier.push_back(ci->getValue().getBoolValue());
1921 }
1922 else {
1923 // argument is not a constant, so we can't tell.
1924 return false;
1925 }
1926 }
1927
1928 return argumentsOfMemoryBarrier == m_argumentsOfLocalMemoryBarrier;
1929 }
1930
1931 // If any call instruction use pointer to local memory abort pass execution
anyCallInstUseLocalMemory(CallInst & I)1932 void TrivialLocalMemoryOpsElimination::anyCallInstUseLocalMemory(CallInst& I)
1933 {
1934 Function* fn = I.getCalledFunction();
1935
1936 if (fn != NULL)
1937 {
1938 for (auto arg = fn->arg_begin(); arg != fn->arg_end(); ++arg)
1939 {
1940 if (arg->getType()->isPointerTy())
1941 {
1942 if (arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_LOCAL || arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GENERIC) abortPass = true;
1943 }
1944 }
1945 }
1946 }
1947
visitCallInst(CallInst & I)1948 void TrivialLocalMemoryOpsElimination::visitCallInst(CallInst& I)
1949 {
1950 // detect only: llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true)
1951 // (note: the first and last arguments are true)
1952 // and add them with immediately following barriers to m_LocalFencesBariersToRemove
1953 anyCallInstUseLocalMemory(I);
1954
1955 if (isa<GenIntrinsicInst>(I))
1956 {
1957 GenIntrinsicInst* II = dyn_cast<GenIntrinsicInst>(&I);
1958 if (II->getIntrinsicID() == GenISAIntrinsic::GenISA_memoryfence)
1959 {
1960 if (isLocalBarrier(I))
1961 {
1962 m_LocalFencesBariersToRemove.push_back(&I);
1963 findNextThreadGroupBarrierInst(I);
1964 }
1965 }
1966 }
1967
1968 }
1969 #endif
1970
// Register pass to igc-opt
#define PASS_FLAG2 "igc-gen-specific-pattern"
#define PASS_DESCRIPTION2 "LastPatternMatch Pass"
#define PASS_CFG_ONLY2 false
#define PASS_ANALYSIS2 false
IGC_INITIALIZE_PASS_BEGIN(GenSpecificPattern, PASS_FLAG2, PASS_DESCRIPTION2, PASS_CFG_ONLY2, PASS_ANALYSIS2)
IGC_INITIALIZE_PASS_END(GenSpecificPattern, PASS_FLAG2, PASS_DESCRIPTION2, PASS_CFG_ONLY2, PASS_ANALYSIS2)

// LLVM pass identification token; LLVM keys passes off the ADDRESS of this
// variable, not its value.
char GenSpecificPattern::ID = 0;
1980
// Registers this pass with LLVM's global pass registry on construction so it
// is visible to igc-opt and the pass manager.
GenSpecificPattern::GenSpecificPattern() : FunctionPass(ID)
{
    initializeGenSpecificPatternPass(*PassRegistry::getPassRegistry());
}
1985
// Runs every GenSpecificPattern peephole visitor over the function.
// NOTE(review): this unconditionally reports the IR as modified, even when no
// visitor changed anything — confirm whether a precise return value matters
// to the surrounding pass pipeline before tightening it.
bool GenSpecificPattern::runOnFunction(Function& F)
{
    visit(F);
    return true;
}
1991
1992 // Lower SDiv to better code sequence if possible
visitSDiv(llvm::BinaryOperator & I)1993 void GenSpecificPattern::visitSDiv(llvm::BinaryOperator& I)
1994 {
1995 if (ConstantInt * divisor = dyn_cast<ConstantInt>(I.getOperand(1)))
1996 {
1997 // signed division of power of 2 can be transformed to asr
1998 // For negative values we need to make sure we round correctly
1999 int log2Div = divisor->getValue().exactLogBase2();
2000 if (log2Div > 0)
2001 {
2002 unsigned int intWidth = divisor->getBitWidth();
2003 IRBuilder<> builder(&I);
2004 Value* signedBitOnly = I.getOperand(0);
2005 if (log2Div > 1)
2006 {
2007 signedBitOnly = builder.CreateAShr(signedBitOnly, builder.getIntN(intWidth, intWidth - 1));
2008 }
2009 Value* offset = builder.CreateLShr(signedBitOnly, builder.getIntN(intWidth, intWidth - log2Div));
2010 Value* offsetedValue = builder.CreateAdd(I.getOperand(0), offset);
2011 Value* newValue = builder.CreateAShr(offsetedValue, builder.getIntN(intWidth, log2Div));
2012 I.replaceAllUsesWith(newValue);
2013 I.eraseFromParent();
2014 }
2015 }
2016 }
2017
2018 /*
2019 Optimizes bit reversing pattern:
2020
2021 %and = shl i32 %0, 1
2022 %shl = and i32 %and, 0xAAAAAAAA
2023 %and2 = lshr i32 %0, 1
2024 %shr = and i32 %and2, 0x55555555
2025 %or = or i32 %shl, %shr
2026 %and3 = shl i32 %or, 2
2027 %shl4 = and i32 %and3, 0xCCCCCCCC
2028 %and5 = lshr i32 %or, 2
2029 %shr6 = and i32 %and5, 0x33333333
2030 %or7 = or i32 %shl4, %shr6
2031 %and8 = shl i32 %or7, 4
2032 %shl9 = and i32 %and8, 0xF0F0F0F0
2033 %and10 = lshr i32 %or7, 4
2034 %shr11 = and i32 %and10, 0x0F0F0F0F
2035 %or12 = or i32 %shl9, %shr11
2036 %and13 = shl i32 %or12, 8
2037 %shl14 = and i32 %and13, 0xFF00FF00
2038 %and15 = lshr i32 %or12, 8
2039 %shr16 = and i32 %and15, 0x00FF00FF
2040 %or17 = or i32 %shl14, %shr16
2041 %shl19 = shl i32 %or17, 16
2042 %shr21 = lshr i32 %or17, 16
2043 %or22 = or i32 %shl19, %shr21
2044
2045 into:
2046
2047 %or22 = call i32 @llvm.genx.GenISA.bfrev.i32(i32 %0)
2048
2049 And similarly for patterns reversing 16 and 64 bit type values.
2050 */
// Matches the shift/mask/or chain that implements a full bit reversal of a
// 16/32/64-bit integer (see the comment block above) and replaces it with the
// GenISA_bfrev intrinsic. MaskType must be the unsigned type of the value
// being reversed; its digit count selects the 16/32/64-bit lowering.
template <typename MaskType>
void CustomSafeOptPass::matchReverse(BinaryOperator& I)
{
    using namespace llvm::PatternMatch;
    IGC_ASSERT(I.getType()->isIntegerTy());
    Value* nextOrShl = nullptr, * nextOrShr = nullptr;
    uint64_t currentShiftShl = 0, currentShiftShr = 0;
    uint64_t currentMaskShl = 0, currentMaskShr = 0;
    // The outermost (half-width) swap needs no masks:
    //   or (shl x, w/2), (lshr x, w/2)
    auto patternBfrevFirst =
        m_Or(
            m_Shl(m_Value(nextOrShl), m_ConstantInt(currentShiftShl)),
            m_LShr(m_Value(nextOrShr), m_ConstantInt(currentShiftShr)));

    // Every inner step masks both shifted halves:
    //   or (and (shl x, s), maskHi), (and (lshr x, s), maskLo)
    auto patternBfrev =
        m_Or(
            m_And(
                m_Shl(m_Value(nextOrShl), m_ConstantInt(currentShiftShl)),
                m_ConstantInt(currentMaskShl)),
            m_And(
                m_LShr(m_Value(nextOrShr), m_ConstantInt(currentShiftShr)),
                m_ConstantInt(currentMaskShr)));

    unsigned int bitWidth = std::numeric_limits<MaskType>::digits;
    IGC_ASSERT(bitWidth == 16 || bitWidth == 32 || bitWidth == 64);

    unsigned int currentShift = bitWidth / 2;
    // First mask is a value with all upper half bits present.
    MaskType mask = std::numeric_limits<MaskType>::max() << currentShift;

    bool isBfrevMatchFound = false;
    nextOrShl = &I;
    // Peel the unmasked outermost step first: both shift amounts must equal
    // half the bit width and both shifts must read the same value.
    if (match(nextOrShl, patternBfrevFirst) &&
        nextOrShl == nextOrShr &&
        currentShiftShl == currentShift &&
        currentShiftShr == currentShift)
    {
        // NextOrShl is assigned to next one by match().
        currentShift /= 2;
        // Constructing next mask to match.
        mask ^= mask >> currentShift;
    }

    // Walk down the chain, halving the shift and refining the mask at each
    // step; reaching the single-bit swap (shift == 1) completes the pattern.
    while (currentShift > 0)
    {
        if (match(nextOrShl, patternBfrev) &&
            nextOrShl == nextOrShr &&
            currentShiftShl == currentShift &&
            currentShiftShr == currentShift &&
            currentMaskShl == mask &&
            currentMaskShr == (MaskType)~mask)
        {
            // NextOrShl is assigned to next one by match().
            if (currentShift == 1)
            {
                isBfrevMatchFound = true;
                break;
            }

            currentShift /= 2;
            // Constructing next mask to match.
            mask ^= mask >> currentShift;
        }
        else
        {
            break;
        }
    }

    if (isBfrevMatchFound)
    {
        llvm::IRBuilder<> builder(&I);
        // The bfrev declaration here is created for i32; narrower and wider
        // widths are adapted around it below.
        Function* bfrevFunc = GenISAIntrinsic::getDeclaration(
            I.getParent()->getParent()->getParent(), GenISAIntrinsic::GenISA_bfrev, builder.getInt32Ty());
        if (bitWidth == 16)
        {
            // Reverse in 32 bits, then shift the result back into the low 16.
            Value* zext = builder.CreateZExt(nextOrShl, builder.getInt32Ty());
            Value* bfrev = builder.CreateCall(bfrevFunc, zext);
            Value* lshr = builder.CreateLShr(bfrev, 16);
            Value* trunc = builder.CreateTrunc(lshr, I.getType());
            I.replaceAllUsesWith(trunc);
        }
        else if (bitWidth == 32)
        {
            Value* bfrev = builder.CreateCall(bfrevFunc, nextOrShl);
            I.replaceAllUsesWith(bfrev);
        }
        else
        { // bitWidth == 64
            // Reverse each 32-bit half independently, then swap the halves.
            Value* int32Source = builder.CreateBitCast(nextOrShl, IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2));
            Value* extractElement0 = builder.CreateExtractElement(int32Source, builder.getInt32(0));
            Value* extractElement1 = builder.CreateExtractElement(int32Source, builder.getInt32(1));
            Value* bfrevLow = builder.CreateCall(bfrevFunc, extractElement0);
            Value* bfrevHigh = builder.CreateCall(bfrevFunc, extractElement1);
            Value* bfrev64Result = llvm::UndefValue::get(int32Source->getType());
            bfrev64Result = builder.CreateInsertElement(bfrev64Result, bfrevHigh, builder.getInt32(0));
            bfrev64Result = builder.CreateInsertElement(bfrev64Result, bfrevLow, builder.getInt32(1));
            Value* bfrevBitcast = builder.CreateBitCast(bfrev64Result, I.getType());
            I.replaceAllUsesWith(bfrevBitcast);
        }
    }
}
2152
2153 /* Transforms pattern1 or pattern2 to a bitcast,extract,insert,insert,bitcast
2154
2155 From:
2156 %5 = zext i32 %a to i64 <--- optional
2157 %6 = shl i64 %5, 32
2158 or
2159 %6 = and i64 %5, 0xFFFFFFFF00000000
2160 To:
2161 %BC = bitcast i64 %5 to <2 x i32> <---- not needed when %5 is zext
2162 %6 = extractelement <2 x i32> %BC, i32 0/1 <---- not needed when %5 is zext
2163 %7 = insertelement <2 x i32> %vec, i32 0, i32 0
2164 %8 = insertelement <2 x i32> %vec, %6, i32 1
2165 %9 = bitcast <2 x i32> %8 to i64
2166 */
// Rewrites I (an i64-producing binary op, see the comment block above) as a
// <2 x i32> built from OpLow (element 0) and OpHi (element 1), bitcast back to
// i64. A nullptr operand contributes the constant i32 0 instead; extractNum1/2
// select which dword to pull from OpLow/OpHi when a bitcast+extract is needed.
void GenSpecificPattern::createBitcastExtractInsertPattern(BinaryOperator& I, Value* OpLow, Value* OpHi, unsigned extractNum1, unsigned extractNum2)
{
    if (IGC_IS_FLAG_DISABLED(EnableBitcastExtractInsertPattern)) {
        return;
    }

    llvm::IRBuilder<> builder(&I);
    auto vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
    Value* vec = UndefValue::get(vec2);
    Value* elemLow = nullptr;
    Value* elemHi = nullptr;

    // Reduce an operand to a single i32 dword. Three accepted shapes:
    //  - zext i32 -> i64: the i32 source is the dword directly;
    //  - insertelement into a <2 x i32>: the inserted scalar is the dword;
    //  - anything else: bitcast to <2 x i32> and extract dword `num`.
    // Returns nullptr when a zext/insertelement has the wrong types.
    auto zeroextorNot = [&](Value* Op, unsigned num) -> Value *
    {
        Value* elem = nullptr;
        if (auto ZextInst = dyn_cast<ZExtInst>(Op))
        {
            if (ZextInst->getDestTy() == builder.getInt64Ty() && ZextInst->getSrcTy() == builder.getInt32Ty())
            {
                elem = ZextInst->getOperand(0);
            }
        }
        else if (auto IEIInst = dyn_cast<InsertElementInst>(Op))
        {
            auto opType = IEIInst->getType();
            if (opType->isVectorTy() && cast<VectorType>(opType)->getElementType()->isIntegerTy(32) && cast<IGCLLVM::FixedVectorType>(opType)->getNumElements() == 2)
            {
                elem = IEIInst->getOperand(1);
            }
        }
        else
        {
            Value* BC = builder.CreateBitCast(Op, vec2);
            elem = builder.CreateExtractElement(BC, builder.getInt32(num));
        }
        return elem;
    };

    elemLow = (OpLow == nullptr) ? builder.getInt32(0) : zeroextorNot(OpLow, extractNum1);
    elemHi = (OpHi == nullptr) ? builder.getInt32(0) : zeroextorNot(OpHi, extractNum2);

    // Bail out if either operand could not be reduced; any bitcast/extract
    // already emitted for the other operand is left as dead code for DCE.
    if (elemHi == nullptr || elemLow == nullptr)
        return;

    vec = builder.CreateInsertElement(vec, elemLow, builder.getInt32(0));
    vec = builder.CreateInsertElement(vec, elemHi, builder.getInt32(1));
    vec = builder.CreateBitCast(vec, builder.getInt64Ty());
    I.replaceAllUsesWith(vec);
    I.eraseFromParent();
}
2217
visitBinaryOperator(BinaryOperator & I)2218 void GenSpecificPattern::visitBinaryOperator(BinaryOperator& I)
2219 {
2220 if (I.getOpcode() == Instruction::Or)
2221 {
2222 using namespace llvm::PatternMatch;
2223
2224 /*
2225 llvm changes ADD to OR when possible, and this optimization changes it back and allow 2 ADDs to merge.
2226 This can avoid scattered read for constant buffer when the index is calculated by shl + or + add.
2227
2228 ex:
2229 from
2230 %22 = shl i32 %14, 2
2231 %23 = or i32 %22, 3
2232 %24 = add i32 %23, 16
2233 to
2234 %22 = shl i32 %14, 2
2235 %23 = add i32 %22, 19
2236 */
2237 Value* AndOp1 = nullptr, * EltOp1 = nullptr;
2238 auto pattern1 = m_Or(
2239 m_And(m_Value(AndOp1), m_SpecificInt(0xFFFFFFFF)),
2240 m_Shl(m_Value(EltOp1), m_SpecificInt(32)));
2241 #if LLVM_VERSION_MAJOR >= 7
2242 Value * AndOp2 = nullptr, *EltOp2 = nullptr, *VecOp = nullptr;
2243 auto pattern2 = m_Or(
2244 m_And(m_Value(AndOp2), m_SpecificInt(0xFFFFFFFF)),
2245 m_BitCast(m_InsertElt(m_Value(VecOp), m_Value(EltOp2), m_SpecificInt(1))));
2246 #endif // LLVM_VERSION_MAJOR >= 7
2247 if (match(&I, pattern1) && AndOp1->getType()->isIntegerTy(64))
2248 {
2249 createBitcastExtractInsertPattern(I, AndOp1, EltOp1, 0, 1);
2250 }
2251 #if LLVM_VERSION_MAJOR >= 7
2252 else if (match(&I, pattern2) && AndOp2->getType()->isIntegerTy(64))
2253 {
2254 ConstantVector* cVec = dyn_cast<ConstantVector>(VecOp);
2255 IGCLLVM::FixedVectorType* vector_type = dyn_cast<IGCLLVM::FixedVectorType>(VecOp->getType());
2256 if (cVec && vector_type &&
2257 isa<ConstantInt>(cVec->getOperand(0)) &&
2258 cast<ConstantInt>(cVec->getOperand(0))->isZero() &&
2259 vector_type->getElementType()->isIntegerTy(32) &&
2260 vector_type->getNumElements() == 2)
2261 {
2262 auto InsertOp = cast<BitCastInst>(I.getOperand(1))->getOperand(0);
2263 createBitcastExtractInsertPattern(I, AndOp2, InsertOp, 0, 1);
2264 }
2265 }
2266 #endif // LLVM_VERSION_MAJOR >= 7
2267 else
2268 {
2269 /*
2270 from
2271 % 22 = shl i32 % 14, 2
2272 % 23 = or i32 % 22, 3
2273 to
2274 % 22 = shl i32 % 14, 2
2275 % 23 = add i32 % 22, 3
2276 */
2277 ConstantInt* OrConstant = dyn_cast<ConstantInt>(I.getOperand(1));
2278 if (OrConstant)
2279 {
2280 llvm::Instruction* ShlInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
2281 if (ShlInst && ShlInst->getOpcode() == Instruction::Shl)
2282 {
2283 ConstantInt* ShlConstant = dyn_cast<ConstantInt>(ShlInst->getOperand(1));
2284 if (ShlConstant)
2285 {
2286 // if the constant bit width is larger than 64, we cannot store ShlIntValue and OrIntValue rawdata as uint64_t.
2287 // will need a fix then
2288 IGC_ASSERT(ShlConstant->getBitWidth() <= 64);
2289 IGC_ASSERT(OrConstant->getBitWidth() <= 64);
2290
2291 uint64_t ShlIntValue = *(ShlConstant->getValue()).getRawData();
2292 uint64_t OrIntValue = *(OrConstant->getValue()).getRawData();
2293
2294 if (OrIntValue < pow(2, ShlIntValue))
2295 {
2296 Value* newAdd = BinaryOperator::CreateAdd(I.getOperand(0), I.getOperand(1), "", &I);
2297 I.replaceAllUsesWith(newAdd);
2298 }
2299 }
2300 }
2301 }
2302 }
2303 }
2304 else if (I.getOpcode() == Instruction::Add)
2305 {
2306 /*
2307 from
2308 %23 = add i32 %22, 3
2309 %24 = add i32 %23, 16
2310 to
2311 %24 = add i32 %22, 19
2312 */
2313 for (int ImmSrcId1 = 0; ImmSrcId1 < 2; ImmSrcId1++)
2314 {
2315 ConstantInt* IConstant = dyn_cast<ConstantInt>(I.getOperand(ImmSrcId1));
2316 if (IConstant)
2317 {
2318 llvm::Instruction* AddInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(1 - ImmSrcId1));
2319 if (AddInst && AddInst->getOpcode() == Instruction::Add)
2320 {
2321 for (int ImmSrcId2 = 0; ImmSrcId2 < 2; ImmSrcId2++)
2322 {
2323 ConstantInt* AddConstant = dyn_cast<ConstantInt>(AddInst->getOperand(ImmSrcId2));
2324 if (AddConstant)
2325 {
2326 llvm::APInt CombineAddValue = AddConstant->getValue() + IConstant->getValue();
2327 I.setOperand(0, AddInst->getOperand(1 - ImmSrcId2));
2328 I.setOperand(1, ConstantInt::get(I.getType(), CombineAddValue));
2329 }
2330 }
2331 }
2332 }
2333 }
2334 }
2335 else if (I.getOpcode() == Instruction::Shl)
2336 {
2337 /*
2338 From:
2339 %5 = zext i32 %a to i64 <--- optional
2340 %6 = shl i64 %5, 32
2341
2342 To:
2343 %BC = bitcast i64 %5 to <2 x i32> <---- not needed when %5 is zext
2344 %6 = extractelement <2 x i32> %BC, i32 0 <---- not needed when %5 is zext
2345 %7 = insertelement <2 x i32> %vec, i32 0, i32 0
2346 %8 = insertelement <2 x i32> %vec, %6, i32 1
2347 %9 = bitcast <2 x i32> %8 to i64
2348 */
2349
2350 using namespace llvm::PatternMatch;
2351 Instruction* inst = nullptr;
2352
2353 auto pattern1 = m_Shl(m_Instruction(inst), m_SpecificInt(32));
2354
2355 if (match(&I, pattern1) && I.getType()->isIntegerTy(64))
2356 {
2357 createBitcastExtractInsertPattern(I, nullptr, I.getOperand(0), 0, 0);
2358 }
2359 }
2360 else if (I.getOpcode() == Instruction::And)
2361 {
2362 /* This `and` is basically fabs() done on high part of int representation.
2363 For float instructions minus operand can end as SrcMod, but since we cast it
2364 from double to int it will end as additional mov, and we can ignore this m_FNeg
2365 anyway.
2366
2367 From :
2368 %sub = fsub double -0.000000e+00, %res.039
2369 %25 = bitcast double %sub to i64
2370 %26 = bitcast i64 %25 to <2 x i32> // or directly double to <2xi32>
2371 %27 = extractelement <2 x i32> %26, i32 1
2372 %and31.i = and i32 %27, 2147483647
2373
2374 To:
2375 %25 = bitcast double %res.039 to <2 x i32>
2376 %27 = extractelement <2 x i32> %26, i32 1
2377 %and31.i = and i32 %27, 2147483647
2378
2379 Or on Int64 without extract:
2380 From:
2381 %sub = fsub double -0.000000e+00, %res.039
2382 %astype.i112.i.i = bitcast double %sub to i64
2383 %and107.i.i = and i64 %astype.i112.i.i, 9223372032559808512 // 0x7FFFFFFF00000000
2384 To:
2385 %bit_cast = bitcast double %res.039 to i64
2386 %and107.i.i = and i64 %bit_cast, 9223372032559808512 // 0x7FFFFFFF00000000
2387
2388 */
2389
2390 /* Get src of either 2 bitcast chain: double -> i64, i64 -> 2xi32
2391 or from single direct: double -> 2xi32
2392 */
2393 auto getValidBitcastSrc = [](Instruction* op) -> llvm::Value *
2394 {
2395 if (!(isa<BitCastInst>(op)))
2396 return nullptr;
2397
2398 BitCastInst* opBC = cast<BitCastInst>(op);
2399
2400 auto opType = opBC->getType();
2401 if (!(opType->isVectorTy() && cast<VectorType>(opType)->getElementType()->isIntegerTy(32) && cast<IGCLLVM::FixedVectorType>(opType)->getNumElements() == 2))
2402 return nullptr;
2403
2404 if (opBC->getSrcTy()->isDoubleTy())
2405 return opBC->getOperand(0); // double -> 2xi32
2406
2407 BitCastInst* bitCastSrc = dyn_cast<BitCastInst>(opBC->getOperand(0));
2408
2409 if (bitCastSrc && bitCastSrc->getDestTy()->isIntegerTy(64) && bitCastSrc->getSrcTy()->isDoubleTy())
2410 return bitCastSrc->getOperand(0); // double -> i64, i64 -> 2xi32
2411
2412 return nullptr;
2413 };
2414
2415 using namespace llvm::PatternMatch;
2416 Value* src_of_FNeg = nullptr;
2417 Instruction* inst = nullptr;
2418
2419 auto fabs_on_int_pattern1 = m_And(m_ExtractElt(m_Instruction(inst), m_SpecificInt(1)), m_SpecificInt(0x7FFFFFFF));
2420 auto fabs_on_int_pattern2 = m_And(m_Instruction(inst), m_SpecificInt(0x7FFFFFFF00000000));
2421 auto fneg_pattern = m_FNeg(m_Value(src_of_FNeg));
2422
2423 /*
2424 From:
2425 %5 = zext i32 %a to i64 <--- optional
2426 %6 = and i64 %5, 0xFFFFFFFF00000000 (-4294967296)
2427 To:
2428 %BC = bitcast i64 %5 to <2 x i32> <---- not needed when %5 is zext
2429 %6 = extractelement <2 x i32> %BC, i32 1 <---- not needed when %5 is zext
2430 %7 = insertelement <2 x i32> %vec, i32 0, i32 0
2431 %8 = insertelement <2 x i32> %vec, %6, i32 1
2432 %9 = bitcast <2 x i32> %8 to i64
2433 */
2434
2435 auto pattern1 = m_And(m_Instruction(inst), m_SpecificInt(0xFFFFFFFF00000000));
2436
2437 if (match(&I, fabs_on_int_pattern1))
2438 {
2439 Value* src = getValidBitcastSrc(inst);
2440 if (src && match(src, fneg_pattern) && src_of_FNeg->getType()->isDoubleTy())
2441 {
2442 llvm::IRBuilder<> builder(&I);
2443 VectorType* vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
2444 Value* BC = builder.CreateBitCast(src_of_FNeg, vec2);
2445 Value* EE = builder.CreateExtractElement(BC, builder.getInt32(1));
2446 Value* AI = builder.CreateAnd(EE, builder.getInt32(0x7FFFFFFF));
2447 I.replaceAllUsesWith(AI);
2448 I.eraseFromParent();
2449 }
2450 }
2451 else if (match(&I, fabs_on_int_pattern2))
2452 {
2453 BitCastInst* bitcast = dyn_cast<BitCastInst>(inst);
2454 bool bitcastValid = bitcast && bitcast->getDestTy()->isIntegerTy(64) && bitcast->getSrcTy()->isDoubleTy();
2455
2456 if (bitcastValid && match(bitcast->getOperand(0), fneg_pattern) && src_of_FNeg->getType()->isDoubleTy())
2457 {
2458 llvm::IRBuilder<> builder(&I);
2459 Value* BC = builder.CreateBitCast(src_of_FNeg, I.getType());
2460 Value* AI = builder.CreateAnd(BC, builder.getInt64(0x7FFFFFFF00000000));
2461 I.replaceAllUsesWith(AI);
2462 I.eraseFromParent();
2463 }
2464 }
2465 else if (match(&I, pattern1) && I.getType()->isIntegerTy(64))
2466 {
2467 createBitcastExtractInsertPattern(I, nullptr, I.getOperand(0), 0, 1);
2468 }
2469 else
2470 {
2471
2472 Instruction* AndSrc = nullptr;
2473 ConstantInt* CI;
2474
2475 /*
2476 From:
2477 %28 = and i32 %24, 255
2478 %29 = lshr i32 %24, 8
2479 %30 = and i32 %29, 255
2480 %31 = lshr i32 %24, 16
2481 %32 = and i32 %31, 255
2482 To:
2483 %temp = bitcast i32 %24 to <4 x i8>
2484 %ee1 = extractelement <4 x i8> %temp, i32 0
2485 %ee2 = extractelement <4 x i8> %temp, i32 1
2486 %ee3 = extractelement <4 x i8> %temp, i32 2
2487 %28 = zext i8 %ee1 to i32
2488 %30 = zext i8 %ee2 to i32
2489 %32 = zext i8 %ee3 to i32
2490 */
2491 auto pattern_And_0xFF = m_And(m_Instruction(AndSrc), m_SpecificInt(0xFF));
2492
2493 CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
2494 bool bytesAllowed = ctx->platform.supportByteALUOperation();
2495
2496 if (bytesAllowed && match(&I, pattern_And_0xFF) && I.getType()->isIntegerTy(32) && AndSrc->getType()->isIntegerTy(32))
2497 {
2498 Instruction* LhsSrc = nullptr;
2499
2500 auto LShr_Pattern = m_LShr(m_Instruction(LhsSrc), m_ConstantInt(CI));
2501 bool LShrMatch = match(AndSrc, LShr_Pattern) && LhsSrc->getType()->isIntegerTy(32) && (CI->getZExtValue() % 8 == 0);
2502
2503 // in case there's no shr, it will be 0
2504 uint32_t newIndex = 0;
2505
2506 if (LShrMatch) // extract inner
2507 {
2508 AndSrc = LhsSrc;
2509 newIndex = (uint32_t)CI->getZExtValue() / 8;
2510 }
2511
2512 llvm::IRBuilder<> builder(&I);
2513 VectorType* vec4 = VectorType::get(builder.getInt8Ty(), 4, false);
2514 Value* BC = builder.CreateBitCast(AndSrc, vec4);
2515 Value* EE = builder.CreateExtractElement(BC, builder.getInt32(newIndex));
2516 Value* Zext = builder.CreateZExt(EE, builder.getInt32Ty());
2517 I.replaceAllUsesWith(Zext);
2518 I.eraseFromParent();
2519
2520 }
2521
2522 }
2523 }
2524 }
2525
// Narrows a 64-bit "cmp (and X, C1), C2" to a 32-bit compare on the high
// dword when both constants have all-zero low 32 bits (then only the high
// dword can influence the result). Separately, under NoNaNs, folds
// "fcmp ord" (the no-NaN check) to true.
void GenSpecificPattern::visitCmpInst(CmpInst& I)
{
    using namespace llvm::PatternMatch;
    CmpInst::Predicate Pred = CmpInst::Predicate::BAD_ICMP_PREDICATE;
    Value* Val1 = nullptr;
    uint64_t const_int1 = 0, const_int2 = 0;
    auto cmp_pattern = m_Cmp(Pred,
        m_And(m_Value(Val1), m_ConstantInt(const_int1)), m_ConstantInt(const_int2));

    // (c << 32) == 0 tests that the LOW 32 bits of the 64-bit constant are zero.
    if (match(&I, cmp_pattern) &&
        (const_int1 << 32) == 0 &&
        (const_int2 << 32) == 0 &&
        Val1->getType()->isIntegerTy(64))
    {
        // Redo the and/cmp on the high dword only, with the constants' high
        // halves (const >> 32).
        llvm::IRBuilder<> builder(&I);
        VectorType* vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
        Value* BC = builder.CreateBitCast(Val1, vec2);
        Value* EE = builder.CreateExtractElement(BC, builder.getInt32(1));
        Value* AI = builder.CreateAnd(EE, builder.getInt32(const_int1 >> 32));
        Value* new_Val = builder.CreateICmp(Pred, AI, builder.getInt32(const_int2 >> 32));
        I.replaceAllUsesWith(new_Val);
        I.eraseFromParent();
    }
    else
    {
        CodeGenContext* pCtx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
        if (pCtx->getCompilerOption().NoNaNs)
        {
            // With NaNs disallowed, "ordered" (neither operand is NaN) is
            // trivially true. The dead fcmp is left for later DCE.
            if (I.getPredicate() == CmpInst::FCMP_ORD)
            {
                I.replaceAllUsesWith(ConstantInt::getTrue(I.getType()));
            }
        }
    }
}
2561
visitSelectInst(SelectInst & I)2562 void GenSpecificPattern::visitSelectInst(SelectInst& I)
2563 {
2564 /*
2565 from
2566 %res_s42 = icmp eq i32 %src1_s41, 0
2567 %src1_s81 = select i1 %res_s42, i32 15, i32 0
2568 to
2569 %res_s42 = icmp eq i32 %src1_s41, 0
2570 %17 = sext i1 %res_s42 to i32
2571 %18 = and i32 15, %17
2572
2573 or
2574
2575 from
2576 %res_s73 = fcmp oge float %res_s61, %42
2577 %res_s187 = select i1 %res_s73, float 1.000000e+00, float 0.000000e+00
2578 to
2579 %res_s73 = fcmp oge float %res_s61, %42
2580 %46 = sext i1 %res_s73 to i32
2581 %47 = and i32 %46, 1065353216
2582 %48 = bitcast i32 %47 to float
2583 */
2584
2585 IGC_ASSERT(I.getOpcode() == Instruction::Select);
2586
2587 bool skipOpt = false;
2588
2589 ConstantInt* Cint = dyn_cast<ConstantInt>(I.getOperand(2));
2590 if (Cint && Cint->isZero())
2591 {
2592 llvm::Instruction* cmpInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
2593 if (cmpInst &&
2594 cmpInst->getOpcode() == Instruction::ICmp &&
2595 I.getOperand(1) != cmpInst->getOperand(0))
2596 {
2597 // disable the cases for csel or where we can optimize the instructions to such as add.ge.* later in vISA
2598 ConstantInt* C = dyn_cast<ConstantInt>(cmpInst->getOperand(1));
2599 if (C && C->isZero())
2600 {
2601 skipOpt = true;
2602 }
2603
2604 if (!skipOpt)
2605 {
2606 // temporary disable the case where cmp is used in multiple sel, and not all of them have src2=0
2607 // we should remove this if we can allow both flag and grf dst for the cmp to be used.
2608 for (auto selI = cmpInst->user_begin(), E = cmpInst->user_end(); selI != E; ++selI)
2609 {
2610 if (llvm::SelectInst * selInst = llvm::dyn_cast<llvm::SelectInst>(*selI))
2611 {
2612 ConstantInt* C = dyn_cast<ConstantInt>(selInst->getOperand(2));
2613 if (!(C && C->isZero()))
2614 {
2615 skipOpt = true;
2616 break;
2617 }
2618 }
2619 }
2620 }
2621
2622 if (!skipOpt)
2623 {
2624 Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), I.getType(), "", &I);
2625 Value* newValueAnd = BinaryOperator::CreateAnd(I.getOperand(1), newValueSext, "", &I);
2626 I.replaceAllUsesWith(newValueAnd);
2627 }
2628 }
2629 }
2630 else
2631 {
2632 ConstantFP* Cfp = dyn_cast<ConstantFP>(I.getOperand(2));
2633 if (Cfp && Cfp->isZero())
2634 {
2635 llvm::Instruction* cmpInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
2636 if (cmpInst &&
2637 cmpInst->getOpcode() == Instruction::FCmp &&
2638 I.getOperand(1) != cmpInst->getOperand(0))
2639 {
2640 // disable the cases for csel or where we can optimize the instructions to such as add.ge.* later in vISA
2641 ConstantFP* C = dyn_cast<ConstantFP>(cmpInst->getOperand(1));
2642 if (C && C->isZero())
2643 {
2644 skipOpt = true;
2645 }
2646
2647 if (!skipOpt)
2648 {
2649 for (auto selI = cmpInst->user_begin(), E = cmpInst->user_end(); selI != E; ++selI)
2650 {
2651 if (llvm::SelectInst * selInst = llvm::dyn_cast<llvm::SelectInst>(*selI))
2652 {
2653 // temporary disable the case where cmp is used in multiple sel, and not all of them have src2=0
2654 // we should remove this if we can allow both flag and grf dst for the cmp to be used.
2655 ConstantFP* C2 = dyn_cast<ConstantFP>(selInst->getOperand(2));
2656 if (!(C2 && C2->isZero()))
2657 {
2658 skipOpt = true;
2659 break;
2660 }
2661
// if it is cmp-sel(1.0 / 0.0)-mul, we can pattern match it better later in codeGen.
2663 ConstantFP* C1 = dyn_cast<ConstantFP>(selInst->getOperand(1));
2664 if (C1 && C2 && selInst->hasOneUse())
2665 {
2666 if ((C2->isZero() && C1->isExactlyValue(1.f)) || (C1->isZero() && C2->isExactlyValue(1.f)))
2667 {
2668 Instruction* mulInst = dyn_cast<Instruction>(*selInst->user_begin());
2669 if (mulInst && mulInst->getOpcode() == Instruction::FMul)
2670 {
2671 skipOpt = true;
2672 break;
2673 }
2674 }
2675 }
2676 }
2677 }
2678 }
2679
2680 if (!skipOpt)
2681 {
2682 ConstantFP* C1 = dyn_cast<ConstantFP>(I.getOperand(1));
2683 if (C1)
2684 {
2685 if (C1->getType()->isHalfTy())
2686 {
2687 Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt16Ty(I.getContext()), "", &I);
2688 Value* newConstant = ConstantInt::get(I.getContext(), C1->getValueAPF().bitcastToAPInt());
2689 Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newConstant, "", &I);
2690 Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getHalfTy(I.getContext()), "", &I);
2691 I.replaceAllUsesWith(newValueCastFP);
2692 }
2693 else if (C1->getType()->isFloatTy())
2694 {
2695 Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt32Ty(I.getContext()), "", &I);
2696 Value* newConstant = ConstantInt::get(I.getContext(), C1->getValueAPF().bitcastToAPInt());
2697 Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newConstant, "", &I);
2698 Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getFloatTy(I.getContext()), "", &I);
2699 I.replaceAllUsesWith(newValueCastFP);
2700 }
2701 }
2702 else
2703 {
2704 if (I.getOperand(1)->getType()->isHalfTy())
2705 {
2706 Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt16Ty(I.getContext()), "", &I);
2707 Value* newValueBitcast = CastInst::CreateZExtOrBitCast(I.getOperand(1), Type::getInt16Ty(I.getContext()), "", &I);
2708 Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newValueBitcast, "", &I);
2709 Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getHalfTy(I.getContext()), "", &I); \
2710 I.replaceAllUsesWith(newValueCastFP);
2711 }
2712 else if (I.getOperand(1)->getType()->isFloatTy())
2713 {
2714 Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt32Ty(I.getContext()), "", &I);
2715 Value* newValueBitcast = CastInst::CreateZExtOrBitCast(I.getOperand(1), Type::getInt32Ty(I.getContext()), "", &I);
2716 Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newValueBitcast, "", &I);
2717 Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getFloatTy(I.getContext()), "", &I); \
2718 I.replaceAllUsesWith(newValueCastFP);
2719 }
2720 }
2721 }
2722 }
2723 }
2724 }
2725
2726 /*
2727 from
2728 %230 = sdiv i32 %214, %scale
2729 %276 = trunc i32 %230 to i8
2730 %277 = icmp slt i32 %230, 255
2731 %278 = select i1 %277, i8 %276, i8 -1
2732 to
2733 %230 = sdiv i32 %214, %scale
2734 %277 = icmp slt i32 %230, 255
2735 %278 = select i1 %277, i32 %230, i32 255
2736 %279 = trunc i32 %278 to i8
2737
This transform allows min/max instructions to be
picked up by the IsMinOrMax matching in PatternMatchPass.cpp
2740 */
2741 if (auto * compInst = dyn_cast<ICmpInst>(I.getOperand(0)))
2742 {
2743 auto selOp1 = I.getOperand(1);
2744 auto selOp2 = I.getOperand(2);
2745 auto cmpOp0 = compInst->getOperand(0);
2746 auto cmpOp1 = compInst->getOperand(1);
2747 auto trunc1 = dyn_cast<TruncInst>(selOp1);
2748 auto trunc2 = dyn_cast<TruncInst>(selOp2);
2749 auto icmpType = compInst->getOperand(0)->getType();
2750
2751 if (selOp1->getType()->isIntegerTy() &&
2752 icmpType->isIntegerTy() &&
2753 selOp1->getType()->getIntegerBitWidth() < icmpType->getIntegerBitWidth())
2754 {
2755 Value* newSelOp1 = NULL;
2756 Value* newSelOp2 = NULL;
2757 if (trunc1 &&
2758 (trunc1->getOperand(0) == cmpOp0 ||
2759 trunc1->getOperand(0) == cmpOp1))
2760 {
2761 newSelOp1 = (trunc1->getOperand(0) == cmpOp0) ? cmpOp0 : cmpOp1;
2762 }
2763
2764 if (trunc2 &&
2765 (trunc2->getOperand(0) == cmpOp0 ||
2766 trunc2->getOperand(0) == cmpOp1))
2767 {
2768 newSelOp2 = (trunc2->getOperand(0) == cmpOp0) ? cmpOp0 : cmpOp1;
2769 }
2770
2771 if (isa<llvm::ConstantInt>(selOp1) &&
2772 isa<llvm::ConstantInt>(cmpOp0) &&
2773 (cast<llvm::ConstantInt>(selOp1)->getZExtValue() ==
2774 cast<llvm::ConstantInt>(cmpOp0)->getZExtValue()))
2775 {
2776 IGC_ASSERT(newSelOp1 == NULL);
2777 newSelOp1 = cmpOp0;
2778 }
2779
2780 if (isa<llvm::ConstantInt>(selOp1) &&
2781 isa<llvm::ConstantInt>(cmpOp1) &&
2782 (cast<llvm::ConstantInt>(selOp1)->getZExtValue() ==
2783 cast<llvm::ConstantInt>(cmpOp1)->getZExtValue()))
2784 {
2785 IGC_ASSERT(newSelOp1 == NULL);
2786 newSelOp1 = cmpOp1;
2787 }
2788
2789 if (isa<llvm::ConstantInt>(selOp2) &&
2790 isa<llvm::ConstantInt>(cmpOp0) &&
2791 (cast<llvm::ConstantInt>(selOp2)->getZExtValue() ==
2792 cast<llvm::ConstantInt>(cmpOp0)->getZExtValue()))
2793 {
2794 IGC_ASSERT(newSelOp2 == NULL);
2795 newSelOp2 = cmpOp0;
2796 }
2797
2798 if (isa<llvm::ConstantInt>(selOp2) &&
2799 isa<llvm::ConstantInt>(cmpOp1) &&
2800 (cast<llvm::ConstantInt>(selOp2)->getZExtValue() ==
2801 cast<llvm::ConstantInt>(cmpOp1)->getZExtValue()))
2802 {
2803 IGC_ASSERT(newSelOp2 == NULL);
2804 newSelOp2 = cmpOp1;
2805 }
2806
2807 if (newSelOp1 && newSelOp2)
2808 {
2809 auto newSelInst = SelectInst::Create(I.getCondition(), newSelOp1, newSelOp2, "", &I);
2810 auto newTruncInst = TruncInst::CreateTruncOrBitCast(newSelInst, selOp1->getType(), "", &I);
2811 I.replaceAllUsesWith(newTruncInst);
2812 I.eraseFromParent();
2813 }
2814 }
2815
2816 }
2817
2818 }
2819
visitCastInst(CastInst & I)2820 void GenSpecificPattern::visitCastInst(CastInst& I)
2821 {
2822 Instruction* srcVal = nullptr;
2823 // Intrinsic::trunc call is handled by 'rndz' hardware instruction which
2824 // does not support double precision floating point type
2825 if (I.getType()->isDoubleTy())
2826 {
2827 return;
2828 }
2829 if (isa<SIToFPInst>(&I))
2830 {
2831 srcVal = dyn_cast<FPToSIInst>(I.getOperand(0));
2832 }
2833 if (srcVal && srcVal->getOperand(0)->getType() == I.getType())
2834 {
2835 if ((srcVal = dyn_cast<Instruction>(srcVal->getOperand(0))))
2836 {
2837 // need fast math to know that we can ignore Nan
2838 if (isa<FPMathOperator>(srcVal) && srcVal->isFast())
2839 {
2840 IRBuilder<> builder(&I);
2841 Function* func = Intrinsic::getDeclaration(
2842 I.getParent()->getParent()->getParent(),
2843 Intrinsic::trunc,
2844 I.getType());
2845 Value* newVal = builder.CreateCall(func, srcVal);
2846 I.replaceAllUsesWith(newVal);
2847 I.eraseFromParent();
2848 }
2849 }
2850 }
2851 }
2852
2853 /*
2854 from:
2855 %HighBits.Vec = insertelement <2 x i32> <i32 0, i32 undef>, i32 %HighBits.32, i32 1
2856 %HighBits.64 = bitcast <2 x i32> %HighBits.Vec to i64
2857 %LowBits.64 = zext i32 %LowBits.32 to i64
2858 %LowPlusHighBits = or i64 %HighBits.64, %LowBits.64
2859 %19 = bitcast i64 %LowPlusHighBits to double
2860 to:
2861 %17 = insertelement <2 x i32> undef, i32 %LowBits.32, i32 0
2862 %18 = insertelement <2 x i32> %17, i32 %HighBits.32, i32 1
2863 %19 = bitcast <2 x i32> %18 to double
2864 */
2865
visitBitCastInst(BitCastInst & I)2866 void GenSpecificPattern::visitBitCastInst(BitCastInst& I)
2867 {
2868 if (I.getType()->isDoubleTy() && I.getOperand(0)->getType()->isIntegerTy(64))
2869 {
2870 BinaryOperator* binOperator = nullptr;
2871 if ((binOperator = dyn_cast<BinaryOperator>(I.getOperand(0))) && binOperator->getOpcode() == Instruction::Or)
2872 {
2873 if (isa<BitCastInst>(binOperator->getOperand(0)) && isa<ZExtInst>(binOperator->getOperand(1)))
2874 {
2875 BitCastInst* bitCastInst = cast<BitCastInst>(binOperator->getOperand(0));
2876 ZExtInst* zExtInst = cast<ZExtInst>(binOperator->getOperand(1));
2877
2878 if (zExtInst->getOperand(0)->getType()->isIntegerTy(32) &&
2879 isa<InsertElementInst>(bitCastInst->getOperand(0)) &&
2880 bitCastInst->getOperand(0)->getType()->isVectorTy() &&
2881 cast<IGCLLVM::FixedVectorType>(bitCastInst->getOperand(0)->getType())->getElementType()->isIntegerTy(32) &&
2882 cast<IGCLLVM::FixedVectorType>(bitCastInst->getOperand(0)->getType())->getNumElements() == 2)
2883 {
2884 InsertElementInst* insertElementInst = cast<InsertElementInst>(bitCastInst->getOperand(0));
2885
2886 if (isa<Constant>(insertElementInst->getOperand(0)) &&
2887 cast<Constant>(insertElementInst->getOperand(0))->getAggregateElement((unsigned int)0)->isZeroValue() &&
2888 cast<ConstantInt>(insertElementInst->getOperand(2))->getZExtValue() == 1)
2889 {
2890 IRBuilder<> builder(&I);
2891 Value* vectorValue = UndefValue::get(bitCastInst->getOperand(0)->getType());
2892 vectorValue = builder.CreateInsertElement(vectorValue, zExtInst->getOperand(0), builder.getInt32(0));
2893 vectorValue = builder.CreateInsertElement(vectorValue, insertElementInst->getOperand(1), builder.getInt32(1));
2894 Value* newBitCast = builder.CreateBitCast(vectorValue, builder.getDoubleTy());
2895 I.replaceAllUsesWith(newBitCast);
2896 I.eraseFromParent();
2897 }
2898 }
2899 }
2900 }
2901 }
2902 }
2903
2904 /*
Matches a pattern where the pointer operand of a load instruction is itself produced by another load.
On targets that do not support 64-bit operations, Emu64OpsPass inserts a pair_to_ptr intrinsic
between the two loads, and InstructionCombining will not optimize this case.
2908
2909 This function changes following pattern:
2910 %3 = load <2 x i32>, <2 x i32> addrspace(1)* %2, align 64
2911 %4 = extractelement <2 x i32> %3, i32 0
2912 %5 = extractelement <2 x i32> %3, i32 1
2913 %6 = call %union._XReq addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* @llvm.genx.GenISA.pair.to.ptr.p1p1p1p1p1p1p1p1union._XReq(i32 %4, i32 %5)
2914 %7 = bitcast %union._XReq addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* %6 to i64 addrspace(1)*
2915 %8 = bitcast i64 addrspace(1)* %7 to <2 x i32> addrspace(1)*
2916 %9 = load <2 x i32>, <2 x i32> addrspace(1)* %8, align 64
2917
2918 to:
2919 %3 = bitcast <2 x i32> addrspace(1)* %2 to <2 x i32> addrspace(1)* addrspace(1)*
2920 %4 = load <2 x i32> addrspace(1)*, <2 x i32> addrspace(1)* addrspace(1)* %3, align 64
2921 ... dead code
2922 %11 = load <2 x i32>, <2 x i32> addrspace(1)* %4, align 64
2923 */
visitLoadInst(LoadInst & LI)2924 void GenSpecificPattern::visitLoadInst(LoadInst &LI) {
2925 Value* PO = LI.getPointerOperand();
2926 std::vector<Value*> OneUseValues = { PO };
2927 while (isa<BitCastInst>(PO)) {
2928 PO = cast<BitCastInst>(PO)->getOperand(0);
2929 OneUseValues.push_back(PO);
2930 }
2931
2932 bool IsPairToPtrInst = (isa<GenIntrinsicInst>(PO) &&
2933 cast<GenIntrinsicInst>(PO)->getIntrinsicID() ==
2934 GenISAIntrinsic::GenISA_pair_to_ptr);
2935
2936 if (!IsPairToPtrInst)
2937 return;
2938
2939 // check if this pointer comes from a load.
2940 auto CallInst = cast<GenIntrinsicInst>(PO);
2941 auto Op0 = dyn_cast<ExtractElementInst>(CallInst->getArgOperand(0));
2942 auto Op1 = dyn_cast<ExtractElementInst>(CallInst->getArgOperand(1));
2943 bool PointerComesFromALoad = (Op0 && Op1 && isa<ConstantInt>(Op0->getIndexOperand()) &&
2944 isa<ConstantInt>(Op1->getIndexOperand()) &&
2945 cast<ConstantInt>(Op0->getIndexOperand())->getZExtValue() == 0 &&
2946 cast<ConstantInt>(Op1->getIndexOperand())->getZExtValue() == 1 &&
2947 isa<LoadInst>(Op0->getVectorOperand()) &&
2948 isa<LoadInst>(Op1->getVectorOperand()) &&
2949 Op0->getVectorOperand() == Op1->getVectorOperand());
2950
2951 if (!PointerComesFromALoad)
2952 return;
2953
2954 OneUseValues.insert(OneUseValues.end(), { Op0, Op1 });
2955
2956 if (!std::all_of(OneUseValues.begin(), OneUseValues.end(), [](auto v) { return v->hasOneUse(); }))
2957 return;
2958
2959 auto VectorLoadInst = cast<LoadInst>(Op0->getVectorOperand());
2960 if (VectorLoadInst->getNumUses() != 2)
2961 return;
2962
2963 auto PointerOperand = VectorLoadInst->getPointerOperand();
2964 PointerType* newLoadPointerType = PointerType::get(
2965 LI.getPointerOperand()->getType(), PointerOperand->getType()->getPointerAddressSpace());
2966 IRBuilder<> builder(VectorLoadInst);
2967 auto CastedPointer =
2968 builder.CreateBitCast(PointerOperand, newLoadPointerType);
2969 auto NewLoadInst = IGC::cloneLoad(VectorLoadInst, CastedPointer);
2970
2971 LI.setOperand(0, NewLoadInst);
2972 }
2973
visitZExtInst(ZExtInst & ZEI)2974 void GenSpecificPattern::visitZExtInst(ZExtInst& ZEI)
2975 {
2976 CmpInst* Cmp = dyn_cast<CmpInst>(ZEI.getOperand(0));
2977 if (!Cmp)
2978 return;
2979
2980 IRBuilder<> Builder(&ZEI);
2981
2982 Value* S = Builder.CreateSExt(Cmp, ZEI.getType());
2983 Value* N = Builder.CreateNeg(S);
2984 ZEI.replaceAllUsesWith(N);
2985 ZEI.eraseFromParent();
2986 }
2987
visitIntToPtr(llvm::IntToPtrInst & I)2988 void GenSpecificPattern::visitIntToPtr(llvm::IntToPtrInst& I)
2989 {
2990 if (ZExtInst * zext = dyn_cast<ZExtInst>(I.getOperand(0)))
2991 {
2992 IRBuilder<> builder(&I);
2993 Value* newV = builder.CreateIntToPtr(zext->getOperand(0), I.getType());
2994 I.replaceAllUsesWith(newV);
2995 I.eraseFromParent();
2996 }
2997 }
2998
visitTruncInst(llvm::TruncInst & I)2999 void GenSpecificPattern::visitTruncInst(llvm::TruncInst& I)
3000 {
3001 /*
3002 from
3003 %22 = lshr i64 %a, 52
3004 %23 = trunc i64 %22 to i32
3005 to
3006 %22 = extractelement <2 x i32> %a, 1
3007 %23 = lshr i32 %22, 20 //52-32
3008 */
3009
3010 using namespace llvm::PatternMatch;
3011 Value* LHS = nullptr;
3012 ConstantInt* CI;
3013 if (match(&I, m_Trunc(m_LShr(m_Value(LHS), m_ConstantInt(CI)))) &&
3014 I.getType()->isIntegerTy(32) &&
3015 LHS->getType()->isIntegerTy(64) &&
3016 CI->getZExtValue() >= 32)
3017 {
3018 auto new_shift_size = (unsigned)CI->getZExtValue() - 32;
3019 llvm::IRBuilder<> builder(&I);
3020 VectorType* vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
3021 Value* new_Val = builder.CreateBitCast(LHS, vec2);
3022 new_Val = builder.CreateExtractElement(new_Val, builder.getInt32(1));
3023 if (new_shift_size > 0)
3024 {
3025 new_Val = builder.CreateLShr(new_Val, builder.getInt32(new_shift_size));
3026 }
3027 I.replaceAllUsesWith(new_Val);
3028 I.eraseFromParent();
3029 }
3030 }
3031
3032 #if LLVM_VERSION_MAJOR >= 10
void GenSpecificPattern::visitFNeg(llvm::UnaryOperator& I)
{
    // from
    // %neg = fneg double %1
    // to
    // %neg = fsub double -0.000000e+00, %1
    // vISA have parser which looks for such operation pattern "0 - x"
    // and adds source modifier for this region/value.
    //
    // NOTE(review): the code below builds the fsub with +0.0
    // (ConstantFP::get(..., 0.0f)), not the -0.0 shown in the example above;
    // the two differ for x == +0.0 (sign of zero). Confirm this matches the
    // pattern the vISA parser expects.
    // NOTE(review): the original fneg is only RAUW'd, not erased here;
    // presumably a later dead-code elimination removes it.

    IRBuilder<> builder(&I);

    Value* fsub = nullptr;

    if (!I.getType()->isVectorTy())
    {
        fsub = builder.CreateFSub(ConstantFP::get(I.getType(), 0.0f), I.getOperand(0));
    }
    else
    {
        // Vector fneg: negate each lane individually and reassemble the vector.
        uint32_t vectorSize = cast<IGCLLVM::FixedVectorType>(I.getType())->getNumElements();
        fsub = llvm::UndefValue::get(I.getType());

        for (uint32_t i = 0; i < vectorSize; ++i)
        {
            Value* extract = builder.CreateExtractElement(I.getOperand(0), i);
            Value* extract_fsub = builder.CreateFSub(ConstantFP::get(extract->getType(), 0.0f), extract);
            fsub = builder.CreateInsertElement(fsub, extract_fsub, i);
        }
    }

    I.replaceAllUsesWith(fsub);
}
3065 #endif
3066
// Register pass to igc-opt
// Boilerplate registration for IGCConstProp under the flag "igc-const-prop",
// declaring its dependency on TargetLibraryInfoWrapperPass.
#define PASS_FLAG3 "igc-const-prop"
#define PASS_DESCRIPTION3 "Custom Const-prop Pass"
#define PASS_CFG_ONLY3 false
#define PASS_ANALYSIS3 false
IGC_INITIALIZE_PASS_BEGIN(IGCConstProp, PASS_FLAG3, PASS_DESCRIPTION3, PASS_CFG_ONLY3, PASS_ANALYSIS3)
IGC_INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
IGC_INITIALIZE_PASS_END(IGCConstProp, PASS_FLAG3, PASS_DESCRIPTION3, PASS_CFG_ONLY3, PASS_ANALYSIS3)

char IGCConstProp::ID = 0; // LLVM pass identification; the address is the unique ID
3077
// Constructor.
// enableSimplifyGEP: enables the additional GEP-index canonicalization done
// via simplifyGEP/simplifyAdd.
// m_TD is filled in per-function in runOnFunction; m_TLI presumably likewise
// (not visible in this chunk — verify).
IGCConstProp::IGCConstProp(
    bool enableSimplifyGEP) :
    FunctionPass(ID),
    m_enableSimplifyGEP(enableSimplifyGEP),
    m_TD(nullptr), m_TLI(nullptr)
{
    initializeIGCConstPropPass(*PassRegistry::getPassRegistry());
}
3086
GetConstantValue(Type * type,char * rawData)3087 static Constant* GetConstantValue(Type* type, char* rawData)
3088 {
3089 unsigned int size_in_bytes = (unsigned int)type->getPrimitiveSizeInBits() / 8;
3090 uint64_t returnConstant = 0;
3091 memcpy_s(&returnConstant, size_in_bytes, rawData, size_in_bytes);
3092 if (type->isIntegerTy())
3093 {
3094 return ConstantInt::get(type, returnConstant);
3095 }
3096 else if (type->isFloatingPointTy())
3097 {
3098 return ConstantFP::get(type->getContext(),
3099 APFloat(type->getFltSemantics(), APInt((unsigned int)type->getPrimitiveSizeInBits(), returnConstant)));
3100 }
3101 return nullptr;
3102 }
3103
3104
replaceShaderConstant(Instruction * inst)3105 Constant* IGCConstProp::replaceShaderConstant(Instruction* inst)
3106 {
3107 CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
3108 ModuleMetaData* modMD = ctx->getModuleMetaData();
3109 ConstantAddress cl;
3110 unsigned int& bufIdOrGRFOffset = cl.bufId;
3111 unsigned int& eltId = cl.eltId;
3112 unsigned int& size_in_bytes = cl.size;
3113 bool directBuf = false;
3114 bool statelessBuf = false;
3115 bool bindlessBuf = false;
3116
3117 if (getConstantAddress(*inst, cl, ctx, directBuf, statelessBuf, bindlessBuf))
3118 {
3119 if (size_in_bytes)
3120 {
3121 if (modMD->immConstant.data.size() &&
3122 ((statelessBuf && (bufIdOrGRFOffset == modMD->pushInfo.inlineConstantBufferGRFOffset))||
3123 (directBuf && (bufIdOrGRFOffset == modMD->pushInfo.inlineConstantBufferSlot))))
3124 {
3125 char* offset = &(modMD->immConstant.data[0]);
3126 if (inst->getType()->isVectorTy())
3127 {
3128 Type* srcEltTy = cast<VectorType>(inst->getType())->getElementType();
3129 uint32_t srcNElts = (uint32_t)cast<IGCLLVM::FixedVectorType>(inst->getType())->getNumElements();
3130 uint32_t eltSize_in_bytes = (unsigned int)srcEltTy->getPrimitiveSizeInBits() / 8;
3131 IRBuilder<> builder(inst);
3132 Value* vectorValue = UndefValue::get(inst->getType());
3133 char* pEltValue; // Pointer to element value
3134 for (uint i = 0; i < srcNElts; i++)
3135 {
3136 if (eltId < 0 || eltId >= (int)modMD->immConstant.data.size())
3137 {
3138 int OOBvalue = 0; // OOB access to immediate constant buffer should return 0
3139 char* pOOBvalue = (char*)& OOBvalue; // Pointer to value 0 which is a OOB access value
3140 pEltValue = pOOBvalue;
3141 }
3142 else
3143 pEltValue = offset + eltId + (i * eltSize_in_bytes);
3144 vectorValue = builder.CreateInsertElement(
3145 vectorValue,
3146 GetConstantValue(srcEltTy, pEltValue),
3147 builder.getInt32(i));
3148 }
3149 return dyn_cast<Constant>(vectorValue);
3150 }
3151 else
3152 {
3153 char* pEltValue; // Pointer to element value
3154 if (eltId < 0 || eltId >= (int)modMD->immConstant.data.size())
3155 {
3156 int OOBvalue = 0; // OOB access to immediate constant buffer should return 0
3157 char* pOOBvalue = (char*)& OOBvalue; // Pointer to value 0 which is a OOB access value
3158 pEltValue = pOOBvalue;
3159 }
3160 else
3161 pEltValue = offset + eltId;
3162 return GetConstantValue(inst->getType(), pEltValue);
3163 }
3164 }
3165 }
3166 }
3167 return nullptr;
3168 }
3169
ConstantFoldCallInstruction(CallInst * inst)3170 Constant* IGCConstProp::ConstantFoldCallInstruction(CallInst* inst)
3171 {
3172 IGCConstantFolder constantFolder;
3173 Constant* C = nullptr;
3174 if (inst)
3175 {
3176 Constant* C0 = dyn_cast<Constant>(inst->getOperand(0));
3177 EOPCODE igcop = GetOpCode(inst);
3178
3179 switch (igcop)
3180 {
3181 case llvm_gradientXfine:
3182 {
3183 if (C0)
3184 {
3185 C = constantFolder.CreateGradientXFine(C0);
3186 }
3187 }
3188 break;
3189 case llvm_gradientYfine:
3190 {
3191 if (C0)
3192 {
3193 C = constantFolder.CreateGradientYFine(C0);
3194 }
3195 }
3196 break;
3197 case llvm_gradientX:
3198 {
3199 if (C0)
3200 {
3201 C = constantFolder.CreateGradientX(C0);
3202 }
3203 }
3204 break;
3205 case llvm_gradientY:
3206 {
3207 if (C0)
3208 {
3209 C = constantFolder.CreateGradientY(C0);
3210 }
3211 }
3212 break;
3213 case llvm_rsq:
3214 {
3215 if (C0)
3216 {
3217 C = constantFolder.CreateRsq(C0);
3218 }
3219 }
3220 break;
3221 case llvm_roundne:
3222 {
3223 if (C0)
3224 {
3225 C = constantFolder.CreateRoundNE(C0);
3226 }
3227 }
3228 break;
3229 case llvm_fsat:
3230 {
3231 if (C0)
3232 {
3233 C = constantFolder.CreateFSat(C0);
3234 }
3235 }
3236 break;
3237 case llvm_fptrunc_rte:
3238 {
3239 if (C0)
3240 {
3241 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmNearestTiesToEven);
3242 }
3243 }
3244 break;
3245 case llvm_fptrunc_rtz:
3246 {
3247 if (C0)
3248 {
3249 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmTowardZero);
3250 }
3251 }
3252 break;
3253 case llvm_fptrunc_rtp:
3254 {
3255 if (C0)
3256 {
3257 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmTowardPositive);
3258 }
3259 }
3260 break;
3261 case llvm_fptrunc_rtn:
3262 {
3263 if (C0)
3264 {
3265 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmTowardNegative);
3266 }
3267 }
3268 break;
3269 case llvm_f32tof16_rtz:
3270 {
3271 if (C0)
3272 {
3273 C = constantFolder.CreateFPTrunc(C0, Type::getHalfTy(inst->getContext()), llvm::APFloatBase::rmTowardZero);
3274 C = constantFolder.CreateBitCast(C, Type::getInt16Ty(inst->getContext()));
3275 C = constantFolder.CreateZExtOrBitCast(C, Type::getInt32Ty(inst->getContext()));
3276 C = constantFolder.CreateBitCast(C, inst->getType());
3277 }
3278 }
3279 break;
3280 case llvm_fadd_rtz:
3281 {
3282 Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3283 if (C0 && C1)
3284 {
3285 C = constantFolder.CreateFAdd(C0, C1, llvm::APFloatBase::rmTowardZero);
3286 }
3287 }
3288 break;
3289 case llvm_fmul_rtz:
3290 {
3291 Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3292 if (C0 && C1)
3293 {
3294 C = constantFolder.CreateFMul(C0, C1, llvm::APFloatBase::rmTowardZero);
3295 }
3296 }
3297 break;
3298 case llvm_ubfe:
3299 {
3300 Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3301 Constant* C2 = dyn_cast<Constant>(inst->getOperand(2));
3302 if (C0 && C0->isZeroValue())
3303 {
3304 C = llvm::ConstantInt::get(inst->getType(), 0);
3305 }
3306 else if (C0 && C1 && C2)
3307 {
3308 C = constantFolder.CreateUbfe(C0, C1, C2);
3309 }
3310 }
3311 break;
3312 case llvm_ibfe:
3313 {
3314 Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3315 Constant* C2 = dyn_cast<Constant>(inst->getOperand(2));
3316 if (C0 && C0->isZeroValue())
3317 {
3318 C = llvm::ConstantInt::get(inst->getType(), 0);
3319 }
3320 else if (C0 && C1 && C2)
3321 {
3322 C = constantFolder.CreateIbfe(C0, C1, C2);
3323 }
3324 }
3325 break;
3326 case llvm_canonicalize:
3327 {
3328 // If the instruction should be emitted anyway, then remove the condition.
3329 // Please, be aware of the fact that clients can understand the term canonical FP value in other way.
3330 if (C0)
3331 {
3332 CodeGenContext* pCodeGenContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
3333 bool flushVal = pCodeGenContext->m_floatDenormMode16 == ::IGC::FLOAT_DENORM_FLUSH_TO_ZERO && inst->getType()->isHalfTy();
3334 flushVal = flushVal || (pCodeGenContext->m_floatDenormMode32 == ::IGC::FLOAT_DENORM_FLUSH_TO_ZERO && inst->getType()->isFloatTy());
3335 flushVal = flushVal || (pCodeGenContext->m_floatDenormMode64 == ::IGC::FLOAT_DENORM_FLUSH_TO_ZERO && inst->getType()->isDoubleTy());
3336 C = constantFolder.CreateCanonicalize(C0, flushVal);
3337 }
3338 }
3339 break;
3340 case llvm_fbh:
3341 {
3342 if (C0)
3343 {
3344 C = constantFolder.CreateFirstBitHi(C0);
3345 }
3346 }
3347 break;
3348 case llvm_fbh_shi:
3349 {
3350 if (C0)
3351 {
3352 C = constantFolder.CreateFirstBitShi(C0);
3353 }
3354 }
3355 break;
3356 case llvm_fbl:
3357 {
3358 if (C0)
3359 {
3360 C = constantFolder.CreateFirstBitLo(C0);
3361 }
3362 }
3363 break;
3364 case llvm_bfrev:
3365 {
3366 if (C0)
3367 {
3368 C = constantFolder.CreateBfrev(C0);
3369 }
3370 }
3371 break;
3372 case llvm_bfi:
3373 {
3374 Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3375 Constant* C2 = dyn_cast<Constant>(inst->getOperand(2));
3376 Constant* C3 = dyn_cast<Constant>(inst->getOperand(3));
3377 if (C0 && C0->isZeroValue() && C3)
3378 {
3379 C = C3;
3380 }
3381 else if (C0 && C1 && C2 && C3)
3382 {
3383 C = constantFolder.CreateBfi(C0, C1, C2, C3);
3384 }
3385 }
3386 break;
3387 default:
3388 break;
3389 }
3390 }
3391 return C;
3392 }
3393
3394 // constant fold the following code for any index:
3395 //
3396 // %95 = extractelement <4 x i16> <i16 3, i16 16, i16 21, i16 39>, i32 %94
3397 // %96 = icmp eq i16 %95, 0
3398 //
ConstantFoldCmpInst(CmpInst * CI)3399 Constant* IGCConstProp::ConstantFoldCmpInst(CmpInst* CI)
3400 {
3401 // Only handle scalar result.
3402 if (CI->getType()->isVectorTy())
3403 return nullptr;
3404
3405 Value* LHS = CI->getOperand(0);
3406 Value* RHS = CI->getOperand(1);
3407 if (!isa<Constant>(RHS) && CI->isCommutative())
3408 std::swap(LHS, RHS);
3409 if (!isa<ConstantInt>(RHS) && !isa<ConstantFP>(RHS))
3410 return nullptr;
3411
3412 auto EEI = dyn_cast<ExtractElementInst>(LHS);
3413 if (EEI && isa<Constant>(EEI->getVectorOperand()))
3414 {
3415 bool AllTrue = true, AllFalse = true;
3416 auto VecOpnd = cast<Constant>(EEI->getVectorOperand());
3417 unsigned N = (unsigned)cast<IGCLLVM::FixedVectorType>(VecOpnd->getType())->getNumElements();
3418 for (unsigned i = 0; i < N; ++i)
3419 {
3420 Constant* const Opnd = VecOpnd->getAggregateElement(i);
3421 IGC_ASSERT_MESSAGE(nullptr != Opnd, "null entry");
3422
3423 if (isa<UndefValue>(Opnd))
3424 continue;
3425 Constant* Result = ConstantFoldCompareInstOperands(
3426 CI->getPredicate(), Opnd, cast<Constant>(RHS), CI->getFunction()->getParent()->getDataLayout());
3427 if (!Result->isAllOnesValue())
3428 AllTrue = false;
3429 if (!Result->isNullValue())
3430 AllFalse = false;
3431 }
3432
3433 if (AllTrue)
3434 {
3435 return ConstantInt::getTrue(CI->getType());
3436 }
3437 else if (AllFalse)
3438 {
3439 return ConstantInt::getFalse(CI->getType());
3440 }
3441 }
3442
3443 return nullptr;
3444 }
3445
3446 // constant fold the following code for any index:
3447 //
3448 // %93 = insertelement <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, float %v7.w_, i32 0
3449 // %95 = extractelement <4 x float> %93, i32 1
3450 //
3451 // constant fold the selection of the same value in a vector component, e.g.:
3452 // %Temp - 119.i.i.v.v = select i1 %Temp - 118.i.i, <2 x i32> <i32 0, i32 17>, <2 x i32> <i32 4, i32 17>
3453 // %scalar9 = extractelement <2 x i32> %Temp - 119.i.i.v.v, i32 1
3454 //
ConstantFoldExtractElement(ExtractElementInst * EEI)3455 Constant* IGCConstProp::ConstantFoldExtractElement(ExtractElementInst* EEI)
3456 {
3457
3458 Constant* EltIdx = dyn_cast<Constant>(EEI->getIndexOperand());
3459 if (EltIdx)
3460 {
3461 if (InsertElementInst * IEI = dyn_cast<InsertElementInst>(EEI->getVectorOperand()))
3462 {
3463 Constant* InsertIdx = dyn_cast<Constant>(IEI->getOperand(2));
3464 // try to find the constant from a chain of InsertElement
3465 while (IEI && InsertIdx)
3466 {
3467 if (InsertIdx == EltIdx)
3468 {
3469 Constant* EltVal = dyn_cast<Constant>(IEI->getOperand(1));
3470 return EltVal;
3471 }
3472 else
3473 {
3474 Value* Vec = IEI->getOperand(0);
3475 if (isa<ConstantDataVector>(Vec))
3476 {
3477 ConstantDataVector* CVec = cast<ConstantDataVector>(Vec);
3478 return CVec->getAggregateElement(EltIdx);
3479 }
3480 else if (isa<InsertElementInst>(Vec))
3481 {
3482 IEI = cast<InsertElementInst>(Vec);
3483 InsertIdx = dyn_cast<Constant>(IEI->getOperand(2));
3484 }
3485 else
3486 {
3487 break;
3488 }
3489 }
3490 }
3491 }
3492 else if (SelectInst * sel = dyn_cast<SelectInst>(EEI->getVectorOperand()))
3493 {
3494 Value* vec0 = sel->getOperand(1);
3495 Value* vec1 = sel->getOperand(2);
3496
3497 IGC_ASSERT(vec0->getType() == vec1->getType());
3498
3499 if (isa<ConstantDataVector>(vec0) && isa<ConstantDataVector>(vec1))
3500 {
3501 ConstantDataVector* cvec0 = cast<ConstantDataVector>(vec0);
3502 ConstantDataVector* cvec1 = cast<ConstantDataVector>(vec1);
3503 Constant* cval0 = cvec0->getAggregateElement(EltIdx);
3504 Constant* cval1 = cvec1->getAggregateElement(EltIdx);
3505
3506 if (cval0 == cval1)
3507 {
3508 return cval0;
3509 }
3510 }
3511 }
3512 }
3513 return nullptr;
3514 }
3515
// simplifyAdd() pushes any constants to the top of a sequence of Add
// instructions, which helps CSE/GVN do a better job. For example,
//     (((A + 8) + B) + C) + 15
// will become
//     ((A + B) + C) + 23
// Note that the order of non-constant values remains unchanged throughout
// this transformation.
// (This was added to remove redundant loads. If a future LLVM does a better
// job here (reassociation), we should use LLVM's implementation instead.)
// Reassociates a tree of integer Add instructions so that constant operands
// are folded together and pushed to the outermost (top) position, e.g.
//   (((A + 8) + B) + C) + 15  ==>  ((A + B) + C) + 23
// The relative order of the non-constant operands is preserved. Returns true
// if any change was made. The replaced Add instructions are RAUW'd but not
// erased here; presumably later dead-code elimination cleans them up.
bool IGCConstProp::simplifyAdd(BinaryOperator* BO)
{
    // Only handle Add
    if (BO->getOpcode() != Instruction::Add)
    {
        return false;
    }

    Value* LHS = BO->getOperand(0);
    Value* RHS = BO->getOperand(1);
    bool changed = false;
    // Canonicalize both sub-trees first (bottom-up traversal).
    if (BinaryOperator * LBO = dyn_cast<BinaryOperator>(LHS))
    {
        if (simplifyAdd(LBO))
        {
            changed = true;
        }
    }
    if (BinaryOperator * RBO = dyn_cast<BinaryOperator>(RHS))
    {
        if (simplifyAdd(RBO))
        {
            changed = true;
        }
    }

    // Refresh LHS and RHS
    // (the recursive calls above may have replaced our operands via RAUW).
    LHS = BO->getOperand(0);
    RHS = BO->getOperand(1);
    BinaryOperator* LHSbo = dyn_cast<BinaryOperator>(LHS);
    BinaryOperator* RHSbo = dyn_cast<BinaryOperator>(RHS);
    bool isLHSAdd = LHSbo && LHSbo->getOpcode() == Instruction::Add;
    bool isRHSAdd = RHSbo && RHSbo->getOpcode() == Instruction::Add;
    IRBuilder<> Builder(BO);
    if (isLHSAdd && isRHSAdd)
    {
        // Shape: (A + B) + (C + D). After the recursive canonicalization a
        // constant, if present, sits in the B/D position of each sub-add.
        Value* A = LHSbo->getOperand(0);
        Value* B = LHSbo->getOperand(1);
        Value* C = RHSbo->getOperand(0);
        Value* D = RHSbo->getOperand(1);

        ConstantInt* C0 = dyn_cast<ConstantInt>(B);
        ConstantInt* C1 = dyn_cast<ConstantInt>(D);

        if (C0 || C1)
        {
            Value* R = nullptr;
            if (C0 && C1)
            {
                // ((A + C0) + (C + C1)) ==> (A + C) + (C0 + C1)
                Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
                    C0, C1, *m_TD);
                R = Builder.CreateAdd(A, C);
                R = Builder.CreateAdd(R, newC);
            }
            else if (C0)
            {
                // ((A + C0) + RHS) ==> (A + RHS) + C0
                R = Builder.CreateAdd(A, RHS);
                R = Builder.CreateAdd(R, B);
            }
            else
            { // C1 is not nullptr
                // (LHS + (C + C1)) ==> (LHS + C) + C1
                R = Builder.CreateAdd(LHS, C);
                R = Builder.CreateAdd(R, D);
            }
            BO->replaceAllUsesWith(R);
            return true;
        }
    }
    else if (isLHSAdd)
    {
        // Shape: (A + B) + C with only the LHS being an Add.
        Value* A = LHSbo->getOperand(0);
        Value* B = LHSbo->getOperand(1);
        Value* C = RHS;

        ConstantInt* C0 = dyn_cast<ConstantInt>(B);
        ConstantInt* C1 = dyn_cast<ConstantInt>(C);

        if (C0 && C1)
        {
            // (A + C0) + C1 ==> A + (C0 + C1)
            Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
                C0, C1, *m_TD);
            Value* R = Builder.CreateAdd(A, newC);
            BO->replaceAllUsesWith(R);
            return true;
        }
        if (C0)
        {
            // (A + C0) + C ==> (A + C) + C0
            Value* R = Builder.CreateAdd(A, C);
            R = Builder.CreateAdd(R, B);
            BO->replaceAllUsesWith(R);
            return true;
        }
    }
    else if (isRHSAdd)
    {
        // Shape: A + (B + C) with only the RHS being an Add.
        Value* A = LHS;
        Value* B = RHSbo->getOperand(0);
        Value* C = RHSbo->getOperand(1);

        ConstantInt* C0 = dyn_cast<ConstantInt>(A);
        ConstantInt* C1 = dyn_cast<ConstantInt>(C);

        if (C0 && C1)
        {
            // C0 + (B + C1) ==> B + (C0 + C1)
            Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
                C0, C1, *m_TD);
            Value* R = Builder.CreateAdd(B, newC);
            BO->replaceAllUsesWith(R);
            return true;
        }
        if (C0)
        {
            // C0 + RHS ==> RHS + C0 (keep the constant on the right)
            Value* R = Builder.CreateAdd(RHS, A);
            BO->replaceAllUsesWith(R);
            return true;
        }
        if (C1)
        {
            // A + (B + C1) ==> (A + B) + C1
            Value* R = Builder.CreateAdd(A, B);
            R = Builder.CreateAdd(R, C);
            BO->replaceAllUsesWith(R);
            return true;
        }
    }
    else
    {
        // Neither operand is an Add: fold constant+constant, or move a
        // constant LHS to the canonical RHS position.
        if (ConstantInt * CLHS = dyn_cast<ConstantInt>(LHS))
        {
            if (ConstantInt * CRHS = dyn_cast<ConstantInt>(RHS))
            {
                Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
                    CLHS, CRHS, *m_TD);
                BO->replaceAllUsesWith(newC);
                return true;
            }

            // Constant is kept as RHS
            Value* R = Builder.CreateAdd(RHS, LHS);
            BO->replaceAllUsesWith(R);
            return true;
        }
    }
    return changed;
}
3670
simplifyGEP(GetElementPtrInst * GEP)3671 bool IGCConstProp::simplifyGEP(GetElementPtrInst* GEP)
3672 {
3673 bool changed = false;
3674 for (int i = 0; i < (int)GEP->getNumIndices(); ++i)
3675 {
3676 Value* Index = GEP->getOperand(i + 1);
3677 BinaryOperator* BO = dyn_cast<BinaryOperator>(Index);
3678 if (!BO || BO->getOpcode() != Instruction::Add)
3679 {
3680 continue;
3681 }
3682 if (simplifyAdd(BO))
3683 {
3684 changed = true;
3685 }
3686 }
3687 return changed;
3688 }
3689
3690 /**
3691 * the following code is essentially a copy of llvm copy-prop code with one little
3692 * addition for shader-constant replacement.
3693 *
3694 * we don't have to do this if llvm version uses a virtual function in place of calling
3695 * ConstantFoldInstruction.
3696 */
bool IGCConstProp::runOnFunction(Function& F)
{
    module = F.getParent();
    // Initialize the worklist to all of the instructions ready to process...
    llvm::SetVector<Instruction*> WorkList;
    for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i)
    {
        WorkList.insert(&*i);
    }
    bool Changed = false;
    m_TD = &F.getParent()->getDataLayout();
    m_TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    while (!WorkList.empty())
    {
        // Pop from the back of the SetVector (LIFO order).
        Instruction* I = *WorkList.rbegin();
        WorkList.remove(I); // Get an element from the worklist...
        if (I->use_empty()) // Don't muck with dead instructions...
        {
            continue;
        }
        // Try the generic LLVM folder first, then the IGC-specific folders
        // for calls, loads, compares, and extractelements, in that order.
        Constant* C = nullptr;
        C = ConstantFoldInstruction(I, *m_TD, m_TLI);

        if (!C && isa<CallInst>(I))
        {
            C = ConstantFoldCallInstruction(cast<CallInst>(I));
        }

        // replace shader-constant load with the known value
        if (!C && isa<LoadInst>(I))
        {
            C = replaceShaderConstant(cast<LoadInst>(I));
        }
        if (!C && isa<CmpInst>(I))
        {
            C = ConstantFoldCmpInst(cast<CmpInst>(I));
        }
        if (!C && isa<ExtractElementInst>(I))
        {
            C = ConstantFoldExtractElement(cast<ExtractElementInst>(I));
        }
        if (C)
        {
            // Add all of the users of this instruction to the worklist, they might
            // be constant propagatable now...
            for (Value::user_iterator UI = I->user_begin(), UE = I->user_end();
                UI != UE; ++UI)
            {
                WorkList.insert(cast<Instruction>(*UI));
            }

            // Replace all of the uses of a variable with uses of the constant.
            I->replaceAllUsesWith(C);

            if (0 /* isa<ConstantPointerNull>(C)*/) // disable optimization generating invalid IR until it gets re-written
            {
                // if we are changing function calls/ genisa intrinsics, then we need
                // to fix the function declarations to account for the change in pointer address type
                for (Value::user_iterator UI = C->user_begin(), UE = C->user_end();
                    UI != UE; ++UI)
                {
                    if (GenIntrinsicInst * genIntr = dyn_cast<GenIntrinsicInst>(*UI))
                    {
                        GenISAIntrinsic::ID ID = genIntr->getIntrinsicID();
                        if (ID == GenISAIntrinsic::GenISA_storerawvector_indexed)
                        {
                            llvm::Type* tys[2];
                            tys[0] = genIntr->getOperand(0)->getType();
                            tys[1] = genIntr->getOperand(2)->getType();
                            GenISAIntrinsic::getDeclaration(F.getParent(),
                                llvm::GenISAIntrinsic::GenISA_storerawvector_indexed,
                                tys);
                        }
                        else if (ID == GenISAIntrinsic::GenISA_storeraw_indexed)
                        {
                            llvm::Type* types[2] = {
                                genIntr->getOperand(0)->getType(),
                                genIntr->getOperand(1)->getType() };

                            GenISAIntrinsic::getDeclaration(F.getParent(),
                                llvm::GenISAIntrinsic::GenISA_storeraw_indexed,
                                types);
                        }
                        else if (ID == GenISAIntrinsic::GenISA_ldrawvector_indexed || ID == GenISAIntrinsic::GenISA_ldraw_indexed)
                        {
                            llvm::Type* tys[2];
                            tys[0] = genIntr->getType();
                            tys[1] = genIntr->getOperand(0)->getType();
                            GenISAIntrinsic::getDeclaration(F.getParent(),
                                ID,
                                tys);
                        }
                    }
                }
            }

            // Remove the dead instruction.
            I->eraseFromParent();

            // We made a change to the function...
            Changed = true;

            // I is erased, continue to the next one.
            continue;
        }

        if (GetElementPtrInst * GEP = dyn_cast<GetElementPtrInst>(I))
        {
            // Not foldable to a constant; still try to re-associate the GEP
            // indices so later iterations may fold more.
            if (m_enableSimplifyGEP && simplifyGEP(GEP))
            {
                Changed = true;
            }
        }
    }
    return Changed;
}
3813
namespace {

// Replaces loads from the inline immediate constant buffer (ICB) with the
// constant values recorded in the module metadata, avoiding buffer reads.
class IGCIndirectICBPropagaion : public FunctionPass
{
public:
    static char ID;
    IGCIndirectICBPropagaion() : FunctionPass(ID)
    {
        initializeIGCIndirectICBPropagaionPass(*PassRegistry::getPassRegistry());
    }
    virtual llvm::StringRef getPassName() const { return "Indirect ICB Propagation"; }
    virtual bool runOnFunction(Function& F);
    virtual void getAnalysisUsage(llvm::AnalysisUsage& AU) const
    {
        AU.setPreservesCFG();
        AU.addRequired<CodeGenContextWrapper>();
    }
private:
    // Returns true if 'inst' loads through a GenISA_RuntimeValue whose
    // constant operand equals 'offset' (i.e. it addresses the offseted ICB).
    bool isICBOffseted(llvm::LoadInst* inst, uint offset);
};

} // namespace
3836
// Pass registration ID and factory for IGCIndirectICBPropagaion.
char IGCIndirectICBPropagaion::ID = 0;
FunctionPass* IGC::createIGCIndirectICBPropagaionPass() { return new IGCIndirectICBPropagaion(); }
3839
bool IGCIndirectICBPropagaion::runOnFunction(Function& F)
{
    CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
    ModuleMetaData* modMD = ctx->getModuleMetaData();

    //MaxImmConstantSizePushed = 256 by default. For float values, it will contains 64 numbers, and stored in 8 GRF
    if (modMD &&
        modMD->immConstant.data.size() &&
        modMD->immConstant.data.size() <= IGC_GET_FLAG_VALUE(MaxImmConstantSizePushed))
    {
        uint maxImmConstantSizePushed = modMD->immConstant.data.size();
        char* offset = &(modMD->immConstant.data[0]);
        IRBuilder<> m_builder(F.getContext());

        for (auto& BB : F)
        {
            for (auto BI = BB.begin(), BE = BB.end(); BI != BE;)
            {
                if (llvm::LoadInst * inst = llvm::dyn_cast<llvm::LoadInst>(&(*BI++)))
                {
                    // Decode the address space to see whether this load reads
                    // the ICB, either directly by buffer slot or through a
                    // runtime-value offset.
                    unsigned as = inst->getPointerAddressSpace();
                    bool directBuf;
                    unsigned bufId;
                    BufferType bufType = IGC::DecodeAS4GFXResource(as, directBuf, bufId);
                    bool bICBNoOffset =
                        (IGC::INVALID_CONSTANT_BUFFER_INVALID_ADDR == modMD->pushInfo.inlineConstantBufferOffset && bufType == CONSTANT_BUFFER && directBuf && bufId == modMD->pushInfo.inlineConstantBufferSlot);
                    bool bICBOffseted =
                        (IGC::INVALID_CONSTANT_BUFFER_INVALID_ADDR != modMD->pushInfo.inlineConstantBufferOffset && ADDRESS_SPACE_CONSTANT == as && isICBOffseted(inst, modMD->pushInfo.inlineConstantBufferOffset));
                    if (bICBNoOffset || bICBOffseted)
                    {
                        // Recover the element address: either an int2ptr byte
                        // offset (eltPtr) or a 2-index GEP into an array of
                        // float/i32 elements (eltIdx).
                        Value* ptrVal = inst->getPointerOperand();
                        Value* eltPtr = nullptr;
                        Value* eltIdx = nullptr;
                        if (IntToPtrInst * i2p = dyn_cast<IntToPtrInst>(ptrVal))
                        {
                            eltPtr = i2p->getOperand(0);
                        }
                        else if (GetElementPtrInst * gep = dyn_cast<GetElementPtrInst>(ptrVal))
                        {
                            if (gep->getNumOperands() != 3)
                            {
                                continue;
                            }

                            Type* eleType = gep->getPointerOperandType()->getPointerElementType();
                            if (!eleType->isArrayTy() ||
                                !(eleType->getArrayElementType()->isFloatTy() || eleType->getArrayElementType()->isIntegerTy(32)))
                            {
                                continue;
                            }

                            eltIdx = gep->getOperand(2);
                        }
                        else
                        {
                            continue;
                        }

                        m_builder.SetInsertPoint(inst);

                        unsigned int size_in_bytes = (unsigned int)inst->getType()->getPrimitiveSizeInBits() / 8;
                        if (size_in_bytes)
                        {
                            // Materialize the whole ICB as an immediate vector of
                            // the load's element type, then extract the element.
                            Value* ICBbuffer = UndefValue::get(IGCLLVM::FixedVectorType::get(inst->getType(), maxImmConstantSizePushed / size_in_bytes));
                            if (inst->getType()->isFloatTy())
                            {
                                float returnConstant = 0;
                                for (unsigned int i = 0; i < maxImmConstantSizePushed; i += size_in_bytes)
                                {
                                    memcpy_s(&returnConstant, size_in_bytes, offset + i, size_in_bytes);
                                    Value* fp = ConstantFP::get(inst->getType(), returnConstant);
                                    ICBbuffer = m_builder.CreateInsertElement(ICBbuffer, fp, m_builder.getInt32(i / size_in_bytes));
                                }

                                if (eltPtr)
                                {
                                    // Convert the byte offset to a dword index.
                                    eltIdx = m_builder.CreateLShr(eltPtr, m_builder.getInt32(2));
                                }
                                Value* ICBvalue = m_builder.CreateExtractElement(ICBbuffer, eltIdx);
                                inst->replaceAllUsesWith(ICBvalue);
                            }
                            else if (inst->getType()->isIntegerTy(32))
                            {
                                int returnConstant = 0;
                                for (unsigned int i = 0; i < maxImmConstantSizePushed; i += size_in_bytes)
                                {
                                    memcpy_s(&returnConstant, size_in_bytes, offset + i, size_in_bytes);
                                    Value* fp = ConstantInt::get(inst->getType(), returnConstant);
                                    ICBbuffer = m_builder.CreateInsertElement(ICBbuffer, fp, m_builder.getInt32(i / size_in_bytes));
                                }
                                if (eltPtr)
                                {
                                    // Convert the byte offset to a dword index.
                                    eltIdx = m_builder.CreateLShr(eltPtr, m_builder.getInt32(2));
                                }
                                Value* ICBvalue = m_builder.CreateExtractElement(ICBbuffer, eltIdx);
                                inst->replaceAllUsesWith(ICBvalue);
                            }
                        }
                    }
                }
            }
        }
    }

    // NOTE(review): the pass reports "no change" even after replacing loads
    // above — confirm whether returning true on modification is required for
    // correct pass-manager invalidation.
    return false;
}
3946
isICBOffseted(llvm::LoadInst * inst,uint offset)3947 bool IGCIndirectICBPropagaion::isICBOffseted(llvm::LoadInst* inst, uint offset) {
3948 Value* ptrVal = inst->getPointerOperand();
3949 std::vector<Value*> srcInstList;
3950 IGC::TracePointerSource(ptrVal, false, true, true, srcInstList);
3951 if (srcInstList.size())
3952 {
3953 CallInst* inst = dyn_cast<CallInst>(srcInstList.back());
3954 GenIntrinsicInst* genIntr = inst ? dyn_cast<GenIntrinsicInst>(inst) : nullptr;
3955 if (!genIntr || (genIntr->getIntrinsicID() != GenISAIntrinsic::GenISA_RuntimeValue))
3956 return false;
3957
3958 llvm::ConstantInt* ci = dyn_cast<llvm::ConstantInt>(inst->getOperand(0));
3959 return ci && (uint)ci->getZExtValue() == offset;
3960 }
3961
3962 return false;
3963 }
3964
3965 IGC_INITIALIZE_PASS_BEGIN(IGCIndirectICBPropagaion, "IGCIndirectICBPropagaion",
3966 "IGCIndirectICBPropagaion", false, false)
3967 IGC_INITIALIZE_PASS_END(IGCIndirectICBPropagaion, "IGCIndirectICBPropagaion",
3968 "IGCIndirectICBPropagaion", false, false)
3969
namespace {
// Reorders conditional branches (or strips fast-math flags) so that the
// false path — which the code's comments note is where NaN comparison
// results land — is the cheaper side, for loop back-edges and plain branches.
class NanHandling : public FunctionPass, public llvm::InstVisitor<NanHandling>
{
public:
    static char ID;
    NanHandling() : FunctionPass(ID)
    {
        initializeNanHandlingPass(*PassRegistry::getPassRegistry());
    }

    void getAnalysisUsage(llvm::AnalysisUsage& AU) const
    {
        AU.setPreservesCFG();
        AU.addRequired<LoopInfoWrapperPass>();
    }

    virtual llvm::StringRef getPassName() const { return "NAN handling"; }
    virtual bool runOnFunction(llvm::Function& F);
    void visitBranchInst(llvm::BranchInst& I);
    // Handles branches that are loop back-edges; records them in visitedInst.
    void loopNanCases(Function& F);

private:
    // Sums instruction counts over successors of BB within a small visit budget.
    int longestPathInstCount(llvm::BasicBlock* BB, int& depth);
    // Inverts an fcmp condition and swaps the branch's successors.
    void swapBranch(llvm::Instruction* inst, llvm::BranchInst& BI);
    // Branches already processed by loopNanCases; skipped by visitBranchInst.
    SmallVector<llvm::BranchInst*, 10> visitedInst;
};
} // namespace
3997
// Pass registration ID and factory for NanHandling.
char NanHandling::ID = 0;
FunctionPass* IGC::createNanHandlingPass() { return new NanHandling(); }
4000
bool NanHandling::runOnFunction(Function& F)
{
    // Handle loop back-edge branches first (fills visitedInst), then visit
    // the remaining branches via InstVisitor.
    loopNanCases(F);
    visit(F);
    return true;
}
4007
loopNanCases(Function & F)4008 void NanHandling::loopNanCases(Function& F)
4009 {
4010 // take care of loop cases
4011 visitedInst.clear();
4012 llvm::LoopInfo* LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
4013 if (LI && !LI->empty())
4014 {
4015 FastMathFlags FMF;
4016 FMF.clear();
4017 for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
4018 {
4019 Loop* loop = *I;
4020 BranchInst* br = cast<BranchInst>(loop->getLoopLatch()->getTerminator());
4021 BasicBlock* header = loop->getHeader();
4022 if (br && br->isConditional() && header)
4023 {
4024 visitedInst.push_back(br);
4025 if (FCmpInst * brCmpInst = dyn_cast<FCmpInst>(br->getCondition()))
4026 {
4027 FPMathOperator* FPO = dyn_cast<FPMathOperator>(brCmpInst);
4028 if (!FPO || !FPO->isFast())
4029 {
4030 continue;
4031 }
4032 if (br->getSuccessor(1) == header)
4033 {
4034 swapBranch(brCmpInst, *br);
4035 }
4036 }
4037 else if (BinaryOperator * andOrInst = dyn_cast<BinaryOperator>(br->getCondition()))
4038 {
4039 if (andOrInst->getOpcode() != BinaryOperator::And &&
4040 andOrInst->getOpcode() != BinaryOperator::Or)
4041 {
4042 continue;
4043 }
4044 FCmpInst* brCmpInst0 = dyn_cast<FCmpInst>(andOrInst->getOperand(0));
4045 FCmpInst* brCmpInst1 = dyn_cast<FCmpInst>(andOrInst->getOperand(1));
4046 if (!brCmpInst0 || !brCmpInst1)
4047 {
4048 continue;
4049 }
4050 if (br->getSuccessor(1) == header)
4051 {
4052 brCmpInst0->copyFastMathFlags(FMF);
4053 brCmpInst1->copyFastMathFlags(FMF);
4054 }
4055 }
4056 }
4057 }
4058 }
4059 }
4060
// Estimates the amount of code reachable from BB by summing basic-block
// instruction counts over a bounded recursive walk of successors.
int NanHandling::longestPathInstCount(llvm::BasicBlock* BB, int& depth)
{
#define MAX_SEARCH_DEPTH 10

    // NOTE(review): 'depth' is passed by reference and only incremented, so
    // it acts as a total visited-block budget shared across all recursive
    // calls (siblings included), not a per-path depth — confirm this is the
    // intended bound given the function's name.
    depth++;
    if (!BB || depth > MAX_SEARCH_DEPTH)
        return 0;

    // Sum the instruction counts of all successors reached within the budget.
    int sumSuccInstCount = 0;
    for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI)
    {
        sumSuccInstCount += longestPathInstCount(*SI, depth);
    }
    return (int)(BB->getInstList().size()) + sumSuccInstCount;
}
4076
swapBranch(llvm::Instruction * inst,llvm::BranchInst & BI)4077 void NanHandling::swapBranch(llvm::Instruction* inst, llvm::BranchInst& BI)
4078 {
4079 if (FCmpInst * brCondition = dyn_cast<FCmpInst>(inst))
4080 {
4081 if (inst->hasOneUse())
4082 {
4083 brCondition->setPredicate(FCmpInst::getInversePredicate(brCondition->getPredicate()));
4084 BI.swapSuccessors();
4085 }
4086 }
4087 else
4088 {
4089 // inst not expected
4090 IGC_ASSERT(0);
4091 }
4092 }
4093
// For a conditional branch on a fast-math fcmp (or an and/or of two fcmps),
// steer the false case — per the comment below, the one NaN results take —
// toward the shorter path by swapping successors or clearing fast-math flags.
void NanHandling::visitBranchInst(llvm::BranchInst& I)
{
    if (!I.isConditional())
        return;

    // if this branch is part of a loop, it is taken care of already in loopNanCases
    for (auto iter = visitedInst.begin(); iter != visitedInst.end(); iter++)
    {
        if (&I == *iter)
            return;
    }

    FCmpInst* brCmpInst = dyn_cast<FCmpInst>(I.getCondition());
    FCmpInst* src0 = nullptr;
    FCmpInst* src1 = nullptr;

    // if the branching is based on a cmp instruction
    if (brCmpInst)
    {
        // Only fast-math single-use compares may be inverted safely.
        FPMathOperator* FPO = dyn_cast<FPMathOperator>(brCmpInst);
        if (!FPO || !FPO->isFast())
            return;

        if (!brCmpInst->hasOneUse())
            return;
    }
    // if the branching is based on a and/or from multiple conditions.
    else if (BinaryOperator * andOrInst = dyn_cast<BinaryOperator>(I.getCondition()))
    {
        if (andOrInst->getOpcode() != BinaryOperator::And && andOrInst->getOpcode() != BinaryOperator::Or)
            return;

        src0 = dyn_cast<FCmpInst>(andOrInst->getOperand(0));
        src1 = dyn_cast<FCmpInst>(andOrInst->getOperand(1));

        if (!src0 || !src1)
            return;
    }
    else
    {
        return;
    }

    // Calculate the maximum instruction count when going down one branch.
    // Make the false case (including NaN) goes to the shorter path.
    int depth = 0;
    int trueBranchSize = longestPathInstCount(I.getSuccessor(0), depth);
    depth = 0;
    int falseBranchSize = longestPathInstCount(I.getSuccessor(1), depth);

    if (falseBranchSize - trueBranchSize > (int)(IGC_GET_FLAG_VALUE(SetBranchSwapThreshold)))
    {
        if (brCmpInst)
        {
            // swap the condition and the successor blocks
            swapBranch(brCmpInst, I);
        }
        else
        {
            // A combined and/or condition cannot simply be inverted;
            // clear the fast-math flags of its operands instead.
            FastMathFlags FMF;
            FMF.clear();
            src0->copyFastMathFlags(FMF);
            src1->copyFastMathFlags(FMF);
        }
        return;
    }
}
4161 IGC_INITIALIZE_PASS_BEGIN(NanHandling, "NanHandling", "NanHandling", false, false)
4162 IGC_INITIALIZE_PASS_END(NanHandling, "NanHandling", "NanHandling", false, false)
4163
4164 namespace IGC
4165 {
4166
// Moves vector bitcasts past their extractelement users so each scalar is
// extracted first and bitcast afterwards (see optimizeVectorBitCast).
class VectorBitCastOpt: public FunctionPass
{
public:
    static char ID;

    VectorBitCastOpt() : FunctionPass(ID)
    {
        initializeVectorBitCastOptPass(*PassRegistry::getPassRegistry());
    };
    ~VectorBitCastOpt() {}

    virtual llvm::StringRef getPassName() const override
    {
        return "VectorBitCastOptPass";
    }

    virtual void getAnalysisUsage(llvm::AnalysisUsage& AU) const override
    {
        AU.setPreservesCFG();
    }

    virtual bool runOnFunction(llvm::Function& F) override;

private:
    // Transform (extractelement (bitcast %vector) ...) to
    // (bitcast (extractelement %vector) ...) in order to help coalescing
    // in DeSSA and enable memory operations simplification
    // in VectorPreProcess.
    bool optimizeVectorBitCast(Function& F) const;
};
4197
runOnFunction(Function & F)4198 bool VectorBitCastOpt::runOnFunction(Function& F)
4199 {
4200 bool Changed = optimizeVectorBitCast(F);
4201 return Changed;
4202 }
4203
optimizeVectorBitCast(Function & F) const4204 bool VectorBitCastOpt::optimizeVectorBitCast(Function& F) const {
4205 IRBuilder<> Builder(F.getContext());
4206
4207 bool Changed = false;
4208 for (auto& BB : F) {
4209 for (auto BI = BB.begin(), BE = BB.end(); BI != BE; /*EMPTY*/) {
4210 BitCastInst* BC = dyn_cast<BitCastInst>(&*BI++);
4211 if (!BC) continue;
4212 // Skip non-element-wise bitcast.
4213 IGCLLVM::FixedVectorType* DstVTy = dyn_cast<IGCLLVM::FixedVectorType>(BC->getType());
4214 IGCLLVM::FixedVectorType* SrcVTy = dyn_cast<IGCLLVM::FixedVectorType>(BC->getOperand(0)->getType());
4215 if (!DstVTy || !SrcVTy || DstVTy->getNumElements() != SrcVTy->getNumElements())
4216 continue;
4217 // Skip if it's not used only all extractelement.
4218 bool ExactOnly = true;
4219 for (auto User : BC->users()) {
4220 if (isa<ExtractElementInst>(User)) continue;
4221 ExactOnly = false;
4222 break;
4223 }
4224 if (!ExactOnly)
4225 continue;
4226 // Autobots, transform and roll out!
4227 Value* Src = BC->getOperand(0);
4228 Type* DstEltTy = DstVTy->getElementType();
4229 for (auto UI = BC->user_begin(), UE = BC->user_end(); UI != UE;
4230 /*EMPTY*/) {
4231 auto EEI = cast<ExtractElementInst>(*UI++);
4232 Builder.SetInsertPoint(EEI);
4233 auto NewVal = Builder.CreateExtractElement(Src, EEI->getIndexOperand());
4234 NewVal = Builder.CreateBitCast(NewVal, DstEltTy);
4235 EEI->replaceAllUsesWith(NewVal);
4236 EEI->eraseFromParent();
4237 }
4238 BI = BC->eraseFromParent();
4239 Changed = true;
4240 }
4241 }
4242
4243 return Changed;
4244 }
4245
// Pass registration ID and factory for VectorBitCastOpt.
char VectorBitCastOpt::ID = 0;
FunctionPass* createVectorBitCastOptPass() { return new VectorBitCastOpt(); }
4248
4249 } // namespace IGC
4250
4251 #define VECTOR_BITCAST_OPT_PASS_FLAG "igc-vector-bitcast-opt"
4252 #define VECTOR_BITCAST_OPT_PASS_DESCRIPTION "Preprocess vector bitcasts to be after extractelement instructions."
4253 #define VECTOR_BITCAST_OPT_PASS_CFG_ONLY true
4254 #define VECTOR_BITCAST_OPT_PASS_ANALYSIS false
4255 IGC_INITIALIZE_PASS_BEGIN(VectorBitCastOpt, VECTOR_BITCAST_OPT_PASS_FLAG, VECTOR_BITCAST_OPT_PASS_DESCRIPTION, VECTOR_BITCAST_OPT_PASS_CFG_ONLY, VECTOR_BITCAST_OPT_PASS_ANALYSIS)
4256 IGC_INITIALIZE_PASS_END(VectorBitCastOpt, VECTOR_BITCAST_OPT_PASS_FLAG, VECTOR_BITCAST_OPT_PASS_DESCRIPTION, VECTOR_BITCAST_OPT_PASS_CFG_ONLY, VECTOR_BITCAST_OPT_PASS_ANALYSIS)
4257
namespace {

// Peephole strength reduction: removes trivially dead instructions, folds
// trivial select/fdiv/fmul/fadd cases, and rewrites fdiv into inv + mul.
class GenStrengthReduction : public FunctionPass
{
public:
    static char ID;
    GenStrengthReduction() : FunctionPass(ID)
    {
        initializeGenStrengthReductionPass(*PassRegistry::getPassRegistry());
    }
    virtual llvm::StringRef getPassName() const { return "Gen strength reduction"; }
    virtual bool runOnFunction(Function& F);

private:
    // Attempts one simplification on 'Inst'; returns true if the IR changed.
    bool processInst(Instruction* Inst);
};

} // namespace
4276
// Pass registration ID and factory for GenStrengthReduction.
char GenStrengthReduction::ID = 0;
FunctionPass* IGC::createGenStrengthReductionPass() { return new GenStrengthReduction(); }
4279
runOnFunction(Function & F)4280 bool GenStrengthReduction::runOnFunction(Function& F)
4281 {
4282 bool Changed = false;
4283 for (auto& BB : F)
4284 {
4285 for (auto BI = BB.begin(), BE = BB.end(); BI != BE;)
4286 {
4287 Instruction* Inst = &(*BI++);
4288 if (isInstructionTriviallyDead(Inst))
4289 {
4290 Inst->eraseFromParent();
4291 Changed = true;
4292 continue;
4293 }
4294 Changed |= processInst(Inst);
4295 }
4296 }
4297 return Changed;
4298 }
4299
4300 // Check if this is a fdiv that allows reciprocal, and its divident is not known
4301 // to be 1.0.
isCandidateFDiv(Instruction * Inst)4302 static bool isCandidateFDiv(Instruction* Inst)
4303 {
4304 // Only floating points, and no vectors.
4305 if (!Inst->getType()->isFloatingPointTy() || Inst->use_empty())
4306 return false;
4307
4308 auto Op = dyn_cast<FPMathOperator>(Inst);
4309 if (Op && Op->getOpcode() == Instruction::FDiv && Op->hasAllowReciprocal())
4310 {
4311 Value* Src0 = Op->getOperand(0);
4312 if (auto CFP = dyn_cast<ConstantFP>(Src0))
4313 return !CFP->isExactlyValue(1.0);
4314 return true;
4315 }
4316 return false;
4317 }
4318
// Tries, in order: trivial select folding, zero folds for fdiv/fmul/fadd
// (legal only under nnan+ninf+nsz), and the fdiv -> inv + mul rewrite with
// commonization of 1/Src1 across sibling fdivs. Returns true if IR changed.
bool GenStrengthReduction::processInst(Instruction* Inst)
{

    unsigned opc = Inst->getOpcode();
    auto Op = dyn_cast<FPMathOperator>(Inst);
    if (opc == Instruction::Select)
    {
        // select c, x, x (or two FP constants with the same value) ==> x
        Value* oprd1 = Inst->getOperand(1);
        Value* oprd2 = Inst->getOperand(2);
        ConstantFP* CF1 = dyn_cast<ConstantFP>(oprd1);
        ConstantFP* CF2 = dyn_cast<ConstantFP>(oprd2);
        if (oprd1 == oprd2 ||
            (CF1 && CF2 && CF1->isExactlyValue(CF2->getValueAPF())))
        {
            Inst->replaceAllUsesWith(oprd1);
            Inst->eraseFromParent();
            return true;
        }
    }
    // The zero folds below assume no NaNs, infinities, or signed zeros.
    if (Op &&
        Op->hasNoNaNs() &&
        Op->hasNoInfs() &&
        Op->hasNoSignedZeros())
    {
        switch (opc)
        {
        case Instruction::FDiv:
        {
            // 0 / x ==> 0
            Value* Oprd0 = Inst->getOperand(0);
            if (ConstantFP * CF = dyn_cast<ConstantFP>(Oprd0))
            {
                if (CF->isZero())
                {
                    Inst->replaceAllUsesWith(Oprd0);
                    Inst->eraseFromParent();
                    return true;
                }
            }
            break;
        }
        case Instruction::FMul:
        {
            // x * 0 ==> 0 (either operand)
            for (int i = 0; i < 2; ++i)
            {
                ConstantFP* CF = dyn_cast<ConstantFP>(Inst->getOperand(i));
                if (CF && CF->isZero())
                {
                    Inst->replaceAllUsesWith(CF);
                    Inst->eraseFromParent();
                    return true;
                }
            }
            break;
        }
        case Instruction::FAdd:
        {
            // x + 0 ==> x (either operand)
            for (int i = 0; i < 2; ++i)
            {
                ConstantFP* CF = dyn_cast<ConstantFP>(Inst->getOperand(i));
                if (CF && CF->isZero())
                {
                    Value* otherOprd = Inst->getOperand(1 - i);
                    Inst->replaceAllUsesWith(otherOprd);
                    Inst->eraseFromParent();
                    return true;
                }
            }
            break;
        }
        }
    }

    // fdiv -> inv + mul. On gen, fdiv seems always slower
    // than inv + mul. Do it if fdiv's fastMathFlag allows it.
    //
    // Rewrite
    // %1 = fdiv arcp float %x, %z
    // into
    // %1 = fdiv arcp float 1.0, %z
    // %2 = fmul arcp float %x, %1
    if (isCandidateFDiv(Inst))
    {
        Value* Src1 = Inst->getOperand(1);
        if (isa<Constant>(Src1))
        {
            // should not happen (but do see "fdiv x / 0.0f"). Skip.
            return false;
        }

        Value* Src0 = ConstantFP::get(Inst->getType(), 1.0);
        Instruction* Inv = nullptr;

        // Check if there is any other (x / Src1). If so, commonize 1/Src1.
        // NOTE(review): creating Inv and the fmuls below adds new users to
        // Src1's use-list while it is being iterated — confirm the
        // iteration remains well-defined for the use-list implementation.
        for (auto UI = Src1->user_begin(), UE = Src1->user_end();
            UI != UE; ++UI)
        {
            Value* Val = *UI;
            Instruction* I = dyn_cast<Instruction>(Val);
            if (I && I != Inst && I->getOpcode() == Instruction::FDiv &&
                I->getOperand(1) == Src1 && isCandidateFDiv(I))
            {
                // special case
                if (ConstantFP * CF = dyn_cast<ConstantFP>(I->getOperand(0)))
                {
                    if (CF->isZero())
                    {
                        // Skip this one.
                        continue;
                    }
                }

                // Found another 1/Src1. Insert Inv right after the def of Src1
                // or in the entry BB if Src1 is an argument.
                if (!Inv)
                {
                    Instruction* insertBefore = dyn_cast<Instruction>(Src1);
                    if (insertBefore)
                    {
                        if (isa<PHINode>(insertBefore))
                        {
                            // Cannot insert among PHIs; drop to the first
                            // non-PHI insertion point of the block.
                            BasicBlock* BB = insertBefore->getParent();
                            insertBefore = &(*BB->getFirstInsertionPt());
                        }
                        else
                        {
                            // Src1 is an instruction
                            BasicBlock::iterator iter(insertBefore);
                            ++iter;
                            insertBefore = &(*iter);
                        }
                    }
                    else
                    {
                        // Src1 is an argument and insert at the begin of entry BB
                        BasicBlock& entryBB = Inst->getParent()->getParent()->getEntryBlock();
                        insertBefore = &(*entryBB.getFirstInsertionPt());
                    }
                    Inv = BinaryOperator::CreateFDiv(Src0, Src1, "", insertBefore);
                    Inv->setFastMathFlags(Inst->getFastMathFlags());
                }

                Instruction* Mul = BinaryOperator::CreateFMul(I->getOperand(0), Inv, "", I);
                Mul->setFastMathFlags(Inst->getFastMathFlags());
                I->replaceAllUsesWith(Mul);
                // Don't erase it as doing so would invalidate iterator in this func's caller
                // Instead, erase it in the caller.
                // I->eraseFromParent();
            }
        }

        if (!Inv)
        {
            // Only a single use of 1 / Src1. Create Inv right before the use.
            Inv = BinaryOperator::CreateFDiv(Src0, Src1, "", Inst);
            Inv->setFastMathFlags(Inst->getFastMathFlags());
        }

        auto Mul = BinaryOperator::CreateFMul(Inst->getOperand(0), Inv, "", Inst);
        Mul->setFastMathFlags(Inst->getFastMathFlags());
        Inst->replaceAllUsesWith(Mul);
        Inst->eraseFromParent();
        return true;
    }

    return false;
}
4485
4486 IGC_INITIALIZE_PASS_BEGIN(GenStrengthReduction, "GenStrengthReduction",
4487 "GenStrengthReduction", false, false)
4488 IGC_INITIALIZE_PASS_END(GenStrengthReduction, "GenStrengthReduction",
4489 "GenStrengthReduction", false, false)
4490
4491
4492 /*========================== FlattenSmallSwitch ==============================
4493
4494 This class flatten small switch. For example,
4495
4496 before optimization:
4497 then153:
4498 switch i32 %115, label %else229 [
4499 i32 1, label %then214
4500 i32 2, label %then222
4501 i32 3, label %then222 ; duplicate blocks are fine
4502 ]
4503
4504 then214: ; preds = %then153
4505 %150 = fdiv float 1.000000e+00, %res_s208
4506 %151 = fmul float %147, %150
4507 br label %ifcont237
4508
4509 then222: ; preds = %then153
4510 %152 = fsub float 1.000000e+00, %141
4511 br label %ifcont237
4512
4513 else229: ; preds = %then153
4514 %res_s230 = icmp eq i32 %115, 3
4515 %. = select i1 %res_s230, float 1.000000e+00, float 0.000000e+00
4516 br label %ifcont237
4517
4518 ifcont237: ; preds = %else229, %then222, %then214
4519 %"r[9][0].x.0" = phi float [ %151, %then214 ], [ %152, %then222 ], [ %., %else229 ]
4520
4521 after optimization:
4522 %res_s230 = icmp eq i32 %115, 3
4523 %. = select i1 %res_s230, float 1.000000e+00, float 0.000000e+00
4524 %150 = fdiv float 1.000000e+00, %res_s208
4525 %151 = fmul float %147, %150
4526 %152 = icmp eq i32 %115, 1
4527 %153 = select i1 %152, float %151, float %.
4528 %154 = fsub float 1.000000e+00, %141
4529 %155 = icmp eq i32 %115, 2
4530 %156 = select i1 %155, float %154, float %153
4531
4532 =============================================================================*/
namespace {
// Flattens small switches (see the example in the comment block above) by
// speculating the case blocks and replacing merge-point PHIs with cmp+select.
class FlattenSmallSwitch : public FunctionPass
{
public:
    static char ID;
    FlattenSmallSwitch() : FunctionPass(ID)
    {
        initializeFlattenSmallSwitchPass(*PassRegistry::getPassRegistry());
    }
    virtual llvm::StringRef getPassName() const { return "Flatten Small Switch"; }
    virtual bool runOnFunction(Function& F);
    // Attempts to flatten one switch; returns true if the IR was changed.
    bool processSwitchInst(SwitchInst* SI);
};

} // namespace
4548
// Pass registration ID and factory for FlattenSmallSwitch.
char FlattenSmallSwitch::ID = 0;
FunctionPass* IGC::createFlattenSmallSwitchPass() { return new FlattenSmallSwitch(); }
4551
processSwitchInst(SwitchInst * SI)4552 bool FlattenSmallSwitch::processSwitchInst(SwitchInst* SI)
4553 {
4554 const unsigned maxSwitchCases = 3; // only apply to switch with 3 cases or less
4555 const unsigned maxCaseInsts = 3; // only apply optimization when each case has 3 instructions or less.
4556
4557 BasicBlock* Default = SI->getDefaultDest();
4558 Value* Val = SI->getCondition(); // The value we are switching on...
4559 IRBuilder<> builder(SI);
4560
4561 if (SI->getNumCases() > maxSwitchCases || SI->getNumCases() == 0)
4562 {
4563 return false;
4564 }
4565
4566 // Dest will be the block that the control flow from the switch merges to.
4567 // Currently, there are two options:
4568 // 1. The Dest block is the default block from the switch
4569 // 2. The Dest block is jumped to by all of the switch cases (and the default)
4570 BasicBlock* Dest = nullptr;
4571 {
4572 const auto* CaseSucc =
4573 #if LLVM_VERSION_MAJOR == 4
4574 SI->case_begin().getCaseSuccessor();
4575 #elif LLVM_VERSION_MAJOR >= 7
4576 SI->case_begin()->getCaseSuccessor();
4577 #endif
4578 auto* BI = dyn_cast<BranchInst>(CaseSucc->getTerminator());
4579
4580 if (BI == nullptr)
4581 return false;
4582
4583 if (BI->isConditional())
4584 return false;
4585
4586 // We know the first case jumps to this block. Now let's
4587 // see below whether all the cases jump to this same block.
4588 Dest = BI->getSuccessor(0);
4589 }
4590
4591 // Does BB unconditionally branch to MergeBlock?
4592 auto branchPattern = [](const BasicBlock* BB, const BasicBlock* MergeBlock)
4593 {
4594 auto* br = dyn_cast<BranchInst>(BB->getTerminator());
4595
4596 if (br == nullptr)
4597 return false;
4598
4599 if (br->isConditional())
4600 return false;
4601
4602 if (br->getSuccessor(0) != MergeBlock)
4603 return false;
4604
4605 if (!BB->getUniquePredecessor())
4606 return false;
4607
4608 return true;
4609 };
4610
4611 // We can speculatively execute a basic block if it
4612 // is small, unconditionally branches to Dest, and doesn't
4613 // have high latency or unsafe to speculate instructions.
4614 auto canSpeculateBlock = [&](BasicBlock* BB)
4615 {
4616 if (BB->size() > maxCaseInsts)
4617 return false;
4618
4619 if (!branchPattern(BB, Dest))
4620 return false;
4621
4622 for (auto& I : *BB)
4623 {
4624 auto* inst = &I;
4625
4626 if (isa<BranchInst>(inst))
4627 continue;
4628
4629 // if there is any high-latency instruction in the switch,
4630 // don't flatten it
4631 if (isSampleInstruction(inst) ||
4632 isGather4Instruction(inst) ||
4633 isInfoInstruction(inst) ||
4634 isLdInstruction(inst) ||
4635 // If the instruction can't be speculated (e.g., phi node),
4636 // punt.
4637 !isSafeToSpeculativelyExecute(inst))
4638 {
4639 return false;
4640 }
4641 }
4642
4643 return true;
4644 };
4645
4646 // Are all Phi incomming blocks from SI switch?
4647 auto checkPhiPredecessorBlocks = [](const SwitchInst* SI, const PHINode* Phi, bool DefaultMergeBlock)
4648 {
4649 for (auto* BB : Phi->blocks())
4650 {
4651 if (BB == SI->getDefaultDest())
4652 continue;
4653 bool successorFound = false;
4654 for (auto Case : SI->cases()) {
4655 if (Case.getCaseSuccessor() == BB)
4656 successorFound = true;
4657 }
4658 if (successorFound)
4659 continue;
4660 return false;
4661 }
4662 return true;
4663 };
4664
4665 for (auto& I : SI->cases())
4666 {
4667 BasicBlock* CaseDest = I.getCaseSuccessor();
4668
4669 if (!canSpeculateBlock(CaseDest))
4670 return false;
4671 }
4672
4673 // Is the default case of the switch the block
4674 // where all other cases meet?
4675 const bool DefaultMergeBlock = (Dest == Default);
4676
4677 // If we merge to the default block, there is no block
4678 // we jump to beforehand so there is nothing to
4679 // speculate.
4680 if (!DefaultMergeBlock && !canSpeculateBlock(Default))
4681 return false;
4682
4683 // Get all PHI nodes that needs to be replaced
4684 SmallVector<PHINode*, 4> PhiNodes;
4685 for (auto& I : *Dest)
4686 {
4687 auto* Phi = dyn_cast<PHINode>(&I);
4688
4689 if (!Phi)
4690 break;
4691
4692 if (!checkPhiPredecessorBlocks(SI, Phi, DefaultMergeBlock))
4693 return false;
4694
4695 PhiNodes.push_back(Phi);
4696 }
4697
4698 if (PhiNodes.empty())
4699 return false;
4700
4701 // Move all instructions except the last (i.e., the branch)
4702 // from BB to the InsertPoint.
4703 auto splice = [](BasicBlock* BB, Instruction* InsertPoint)
4704 {
4705 for (auto II = BB->begin(), IE = BB->end(); II != IE; /* empty */)
4706 {
4707 auto* I = &*II++;
4708 if (I->isTerminator())
4709 return;
4710
4711 I->moveBefore(InsertPoint);
4712 }
4713 };
4714
4715 // move default block out
4716 if (!DefaultMergeBlock)
4717 splice(Default, SI);
4718
4719 // move case blocks out
4720 for (auto& I : SI->cases())
4721 {
4722 BasicBlock* CaseDest = I.getCaseSuccessor();
4723 splice(CaseDest, SI);
4724 }
4725
4726 // replaces PHI with select
4727 for (auto* Phi : PhiNodes)
4728 {
4729 Value* vTemp = Phi->getIncomingValueForBlock(
4730 DefaultMergeBlock ? SI->getParent() : Default);
4731
4732 for (auto& I : SI->cases())
4733 {
4734 BasicBlock* CaseDest = I.getCaseSuccessor();
4735 ConstantInt* CaseValue = I.getCaseValue();
4736
4737 Value* selTrueValue = Phi->getIncomingValueForBlock(CaseDest);
4738 builder.SetInsertPoint(SI);
4739 Value* cmp = builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, Val, CaseValue);
4740 Value* sel = builder.CreateSelect(cmp, selTrueValue, vTemp);
4741 vTemp = sel;
4742 }
4743
4744 Phi->replaceAllUsesWith(vTemp);
4745 Phi->removeFromParent();
4746 }
4747
4748 // connect the original block and the phi node block with a pass through branch
4749 builder.CreateBr(Dest);
4750
4751 SmallPtrSet<BasicBlock*, 4> Succs;
4752
4753 // Remove the switch.
4754 BasicBlock* SelectBB = SI->getParent();
4755 for (unsigned i = 0, e = SI->getNumSuccessors(); i < e; ++i)
4756 {
4757 BasicBlock* Succ = SI->getSuccessor(i);
4758 if (Succ == Dest)
4759 {
4760 continue;
4761 }
4762
4763 if (Succs.insert(Succ).second)
4764 Succ->removePredecessor(SelectBB);
4765 }
4766 SI->eraseFromParent();
4767
4768 return true;
4769 }
4770
runOnFunction(Function & F)4771 bool FlattenSmallSwitch::runOnFunction(Function& F)
4772 {
4773 bool Changed = false;
4774 for (Function::iterator I = F.begin(), E = F.end(); I != E; )
4775 {
4776 BasicBlock* Cur = &*I++; // Advance over block so we don't traverse new blocks
4777 if (SwitchInst * SI = dyn_cast<SwitchInst>(Cur->getTerminator()))
4778 {
4779 Changed |= processSwitchInst(SI);
4780 }
4781 }
4782 return Changed;
4783 }
4784
// Boilerplate registration of the FCmpPaternMatch pass with the IGC/LLVM pass registry.
IGC_INITIALIZE_PASS_BEGIN(FCmpPaternMatch, "FCmpPaternMatch", "FCmpPaternMatch", false, false)
IGC_INITIALIZE_PASS_END(FCmpPaternMatch, "FCmpPaternMatch", "FCmpPaternMatch", false, false)

// Pass identification token; its address uniquely identifies the pass.
char FCmpPaternMatch::ID = 0;
4789
// Constructor: registers the pass instance with the global pass registry.
FCmpPaternMatch::FCmpPaternMatch() : FunctionPass(ID)
{
    initializeFCmpPaternMatchPass(*PassRegistry::getPassRegistry());
}
4794
// Run the fcmp/select pattern matcher over the whole function via InstVisitor.
// NOTE: this conservatively reports "changed" unconditionally; the visitor
// does not track whether it actually rewrote anything.
bool FCmpPaternMatch::runOnFunction(Function& F)
{
    bool change = true;
    visit(F);
    return change;
}
4801
// Fold a (fcmp -> select K/0.0 -> bitcast -> icmp eq/ne 0) chain so that the
// branches/selects fed by the icmp consume the original fcmp result directly,
// swapping branch successors / select arms as needed to preserve semantics.
void FCmpPaternMatch::visitSelectInst(SelectInst& I)
{
    /*
    from
    %7 = fcmp olt float %5, %6
    %8 = select i1 %7, float 1.000000e+00, float 0.000000e+00
    %9 = bitcast float %8 to i32
    %10 = icmp eq i32 %9, 0
    %.01 = select i1 %10, float %4, float %11
    br i1 %10, label %endif, label %then
    to
    %7 = fcmp olt float %5, %6
    %.01 = select i1 %7, float %11, float %4
    br i1 %7, label %then, label %endif
    */
    {
        bool swapNodesFromSel = false;
        bool isSelWithConstants = false;
        // Match a select of two FP constants where exactly one arm is zero:
        // (select %c, K, 0.0) or, with the arms flipped, (select %c, 0.0, K).
        ConstantFP* Cfp1 = dyn_cast<ConstantFP>(I.getOperand(1));
        ConstantFP* Cfp2 = dyn_cast<ConstantFP>(I.getOperand(2));
        if (Cfp1 && Cfp1->getValueAPF().isFiniteNonZero() &&
            Cfp2 && Cfp2->isZero())
        {
            isSelWithConstants = true;
        }
        if (Cfp1 && Cfp1->isZero() &&
            Cfp2 && Cfp2->getValueAPF().isFiniteNonZero())
        {
            isSelWithConstants = true;
            // Arms flipped relative to the canonical form: remember to swap.
            swapNodesFromSel = true;
        }
        // Only proceed when the select condition is itself an fcmp.
        if (isSelWithConstants &&
            dyn_cast<FCmpInst>(I.getOperand(0)))
        {
            // Look for users of the select that bitcast the float to i32 ...
            for (auto bitCastI : I.users())
            {
                if (BitCastInst * bitcastInst = dyn_cast<BitCastInst>(bitCastI))
                {
                    // ... and then compare that i32 against zero for (in)equality.
                    for (auto cmpI : bitcastInst->users())
                    {
                        ICmpInst* iCmpInst = dyn_cast<ICmpInst>(cmpI);
                        if (iCmpInst &&
                            iCmpInst->isEquality())
                        {
                            ConstantInt* icmpC = dyn_cast<ConstantInt>(iCmpInst->getOperand(1));
                            if (!icmpC || !icmpC->isZero())
                            {
                                continue;
                            }

                            // (icmp eq X, 0) inverts the condition: K maps to
                            // false, 0.0 maps to true. XOR that with the
                            // select-arm flip computed above.
                            bool swapNodes = swapNodesFromSel;
                            if (iCmpInst->getPredicate() == CmpInst::Predicate::ICMP_EQ)
                            {
                                swapNodes = (swapNodes != true);
                            }

                            // Collect every conditional branch / select fed by
                            // the icmp, swapping their outcomes when required.
                            SmallVector<Instruction*, 4> matchedBrSelInsts;
                            for (auto brOrSelI : iCmpInst->users())
                            {
                                BranchInst* brInst = dyn_cast<BranchInst>(brOrSelI);
                                if (brInst &&
                                    brInst->isConditional())
                                {
                                    //match
                                    matchedBrSelInsts.push_back(brInst);
                                    if (swapNodes)
                                    {
                                        brInst->swapSuccessors();
                                    }
                                }

                                if (SelectInst * selInst = dyn_cast<SelectInst>(brOrSelI))
                                {
                                    //match
                                    matchedBrSelInsts.push_back(selInst);
                                    if (swapNodes)
                                    {
                                        // Swap the select arms and keep branch
                                        // weight metadata consistent.
                                        Value* selTrue = selInst->getTrueValue();
                                        Value* selFalse = selInst->getFalseValue();
                                        selInst->setTrueValue(selFalse);
                                        selInst->setFalseValue(selTrue);
                                        selInst->swapProfMetadata();
                                    }
                                }
                            }
                            // Rewire all matched users to consume the original
                            // fcmp (operand 0 of this select) directly.
                            for (Instruction* inst : matchedBrSelInsts)
                            {
                                inst->setOperand(0, I.getOperand(0));
                            }
                        }
                    }
                }
            }
        }
    }
}
4898
// Boilerplate registration of the FlattenSmallSwitch pass.
IGC_INITIALIZE_PASS_BEGIN(FlattenSmallSwitch, "flattenSmallSwitch", "flattenSmallSwitch", false, false)
IGC_INITIALIZE_PASS_END(FlattenSmallSwitch, "flattenSmallSwitch", "flattenSmallSwitch", false, false)
4901
4902
4903
4904 /*======================== SplitIndirectEEtoSel =============================
4905
4906 This class changes extract element for small vectors to series of cmp+sel to avoid VxH mov.
4907 before:
4908 %268 = mul nuw i32 %res.i2.i, 3
4909 %269 = extractelement <12 x float> %234, i32 %268
4910 %270 = add i32 %268, 1
4911 %271 = extractelement <12 x float> %234, i32 %270
4912 %272 = add i32 %268, 2
4913 %273 = extractelement <12 x float> %234, i32 %272
4914 %274 = extractelement <12 x float> %198, i32 %268
4915 %275 = extractelement <12 x float> %198, i32 %270
4916 %276 = extractelement <12 x float> %198, i32 %272
4917
4918 after:
4919 %250 = icmp eq i32 %res.i2.i, i16 1
4920 %251 = select i1 %250, float %206, float %200
4921 %252 = select i1 %250, float %208, float %202
4922 %253 = select i1 %250, float %210, float %204
4923 %254 = select i1 %250, float %48, float %32
4924 %255 = select i1 %250, float %49, float %33
4925 %256 = select i1 %250, float %50, float %34
4926 %257 = icmp eq i32 %res.i2.i, i16 2
4927 %258 = select i1 %257, float %214, float %251
4928 %259 = select i1 %257, float %215, float %252
4929 %260 = select i1 %257, float %216, float %253
4930 %261 = select i1 %257, float %64, float %254
4931 %262 = select i1 %257, float %65, float %255
4932 %263 = select i1 %257, float %66, float %256
4933
4934 It is a bit similar to SimplifyConstant::isCmpSelProfitable for OCL, but not restricted to api.
4935 And to GenSimplification::visitExtractElement() but not restricted to vec of 2, and later.
4936 TODO: for known vectors check how many unique items there are.
4937 ===========================================================================*/
namespace {
// See the overview comment above: rewrites dynamically-indexed
// extractelement on small vectors into icmp+select chains to avoid VxH movs.
class SplitIndirectEEtoSel : public FunctionPass, public llvm::InstVisitor<SplitIndirectEEtoSel>
{
public:
    static char ID;
    SplitIndirectEEtoSel() : FunctionPass(ID)
    {
        initializeSplitIndirectEEtoSelPass(*PassRegistry::getPassRegistry());
    }
    virtual llvm::StringRef getPassName() const { return "Split Indirect EE to ICmp Plus Sel"; }
    virtual bool runOnFunction(Function& F);
    void visitExtractElementInst(llvm::ExtractElementInst& I);
private:
    // Cost heuristic: compares the icmp+sel chain cost against the
    // SplitIndirectEEtoSelThreshold flag (the assumed VxH mov cost).
    bool isProfitableToSplit(uint64_t num, int64_t mul, int64_t add);
    // Set by visitExtractElementInst when it rewrites an instruction.
    bool didSomething;
};

} // namespace
4956
4957
// Pass identification token.
char SplitIndirectEEtoSel::ID = 0;
// Factory used by the IGC pass pipeline.
FunctionPass* IGC::createSplitIndirectEEtoSelPass() { return new SplitIndirectEEtoSel(); }
4960
// Visit every instruction in F; visitExtractElementInst performs the actual
// rewriting and records any change in didSomething.
bool SplitIndirectEEtoSel::runOnFunction(Function& F)
{
    didSomething = false;
    visit(F);
    return didSomething;
}
4967
isProfitableToSplit(uint64_t num,int64_t mul,int64_t add)4968 bool SplitIndirectEEtoSel::isProfitableToSplit(uint64_t num, int64_t mul, int64_t add)
4969 {
4970 /* Assumption:
4971 Pass is profitable when: (X * cmp + Y * sel) < (ExecSize * mov VxH).
4972 */
4973
4974 const int64_t assumedVXHCost = IGC_GET_FLAG_VALUE(SplitIndirectEEtoSelThreshold);
4975 int64_t possibleCost = 0;
4976
4977 /* for: extractelement <4 x float> , %index
4978 cost is (4 - 1) * (icmp + sel) = 6;
4979 */
4980 possibleCost = ((int64_t)num -1) * 2;
4981 if (possibleCost < assumedVXHCost)
4982 return true;
4983
4984 /* for: extractelement <12 x float> , (mul %real_index, 3)
4985 cost is ((12/3) - 1) * (icmp + sel) = 6;
4986 */
4987
4988 if (mul > 0) // not tested negative options
4989 {
4990 int64_t differentOptions = 1 + ((int64_t)num - 1) / mul; // ceil(num/mul)
4991 possibleCost = (differentOptions - 1) * 2;
4992
4993 if (possibleCost < assumedVXHCost)
4994 return true;
4995 }
4996
4997 return false;
4998 }
4999
visitExtractElementInst(llvm::ExtractElementInst & I)5000 void SplitIndirectEEtoSel::visitExtractElementInst(llvm::ExtractElementInst& I)
5001 {
5002 using namespace llvm::PatternMatch;
5003
5004 IGCLLVM::FixedVectorType* vecTy = dyn_cast<IGCLLVM::FixedVectorType>(I.getVectorOperandType());
5005 IGC_ASSERT( vecTy != nullptr );
5006 uint64_t num = vecTy->getNumElements();
5007 Type* eleType = vecTy->getElementType();
5008
5009 Value* vec = I.getVectorOperand();
5010 Value* index = I.getIndexOperand();
5011
5012 // ignore constant index
5013 if (dyn_cast<ConstantInt>(index))
5014 {
5015 return;
5016 }
5017
5018 // ignore others for now (did not yet evaluate perf. impact)
5019 if (!(eleType->isIntegerTy(32) || eleType->isFloatTy()))
5020 {
5021 return;
5022 }
5023
5024 // used to calculate offsets
5025 int64_t add = 0;
5026 int64_t mul = 1;
5027
5028
5029 /* strip mul/add from index calculation and remember it for later:
5030 %268 = mul nuw i32 %res.i2.i, 3
5031 %270 = add i32 %268, 1
5032 %271 = extractelement <12 x float> %234, i32 %270
5033 */
5034 Value* Val1 = nullptr;
5035 ConstantInt* ci_add = nullptr;
5036 ConstantInt* ci_mul = nullptr;
5037
5038 auto pat1 = m_Add(m_Mul(m_Value(Val1), m_ConstantInt(ci_mul)), m_ConstantInt(ci_add));
5039 auto pat2 = m_Mul(m_Value(Val1), m_ConstantInt(ci_mul));
5040 // Some code shows `shl+or` instead of mul+add.
5041 auto pat21 = m_Or(m_Shl(m_Value(Val1), m_ConstantInt(ci_mul)), m_ConstantInt(ci_add));
5042 auto pat22 = m_Shl(m_Value(Val1), m_ConstantInt(ci_mul));
5043
5044 if (match(index, pat1) || match(index, pat2))
5045 {
5046 add = ci_add ? ci_add->getSExtValue() : 0;
5047 mul = ci_mul ? ci_mul->getSExtValue() : 1;
5048 index = Val1;
5049 }
5050 else if (match(index, pat21) || match(index, pat22))
5051 {
5052 add = ci_add ? ci_add->getSExtValue() : 0;
5053 mul = ci_mul ? (1LL << ci_mul->getSExtValue()) : 1LL;
5054 index = Val1;
5055 }
5056
5057 if (!isProfitableToSplit(num, mul, add))
5058 return;
5059
5060 Value* vTemp = llvm::UndefValue::get(eleType);
5061 IRBuilder<> builder(I.getNextNode());
5062
5063 // returns true if we can skip this icmp, such as:
5064 // icmp eq (add (mul %index, 3), 2), 1
5065 // icmp eq (mul %index, 3), 1
5066 auto canSafelySkipThis = [&](int64_t add, int64_t mul, int64_t & newIndex) {
5067 if (mul)
5068 {
5069 newIndex -= add;
5070 if ((newIndex % mul) != 0)
5071 return true;
5072 newIndex = newIndex / mul;
5073 }
5074 return false;
5075 };
5076
5077 // Generate combinations
5078 for (uint64_t elemIndex = 0; elemIndex < num; elemIndex++)
5079 {
5080 int64_t cmpIndex = elemIndex;
5081
5082 if (canSafelySkipThis(add, mul, cmpIndex))
5083 continue;
5084
5085 // Those 2 might be different, when cmp will get altered by it's operands, but EE index stays the same
5086 ConstantInt* cmpIndexCI = llvm::ConstantInt::get(builder.getInt32Ty(), (uint64_t)cmpIndex);
5087 ConstantInt* eeiIndexCI = llvm::ConstantInt::get(builder.getInt32Ty(), (uint64_t)elemIndex);
5088
5089 Value* cmp = builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, index, cmpIndexCI);
5090 Value* subcaseEE = builder.CreateExtractElement(vec, eeiIndexCI);
5091 Value* sel = builder.CreateSelect(cmp, subcaseEE, vTemp);
5092 vTemp = sel;
5093 didSomething = true;
5094 }
5095
5096 // In theory there's no situation where we don't do something up to this point.
5097 if (didSomething)
5098 {
5099 I.replaceAllUsesWith(vTemp);
5100 }
5101 }
5102
5103
// Boilerplate registration of the SplitIndirectEEtoSel pass.
IGC_INITIALIZE_PASS_BEGIN(SplitIndirectEEtoSel, "SplitIndirectEEtoSel", "SplitIndirectEEtoSel", false, false)
IGC_INITIALIZE_PASS_END(SplitIndirectEEtoSel, "SplitIndirectEEtoSel", "SplitIndirectEEtoSel", false, false)
5106
5107 ////////////////////////////////////////////////////////////////////////
5108 // LogicalAndToBranch trying to find logical AND like below:
5109 // res = simpleCond0 && complexCond1
5110 // and convert it to:
5111 // if simpleCond0
5112 // res = complexCond1
5113 // else
5114 // res = false
namespace {
// Converts `res = cond0 && cond1` into an explicit branch (see the overview
// comment above) so the work feeding cond1 is skipped when cond0 is false.
class LogicalAndToBranch : public FunctionPass
{
public:
    static char ID;
    // minimum number of instructions between cond0 and cond1 for the
    // transformation to be considered worthwhile
    const int NUM_INST_THRESHOLD = 32;
    LogicalAndToBranch();

    StringRef getPassName() const override { return "LogicalAndToBranch"; }

    bool runOnFunction(Function& F) override;

protected:
    // memoization set used by scheduleUp to avoid revisiting instructions
    SmallPtrSet<Instruction*, 8> m_sched;

    // schedule instruction up before insertPos
    bool scheduleUp(BasicBlock* bb, Value* V, Instruction*& insertPos);

    // check if it's safe to convert instructions between cond0 & cond1,
    // moveInsts are the values referenced out of (cond0, cond1), we need to
    // move them before cond0
    bool isSafeToConvert(Instruction* cond0, Instruction* cond1,
        smallvector<Instruction*, 8> & moveInsts);

    void convertAndToBranch(Instruction* opAnd,
        Instruction* cond0, Instruction* cond1, BasicBlock*& newBB);
};

}
5144
// Boilerplate registration of the LogicalAndToBranch pass.
IGC_INITIALIZE_PASS_BEGIN(LogicalAndToBranch, "logicalAndToBranch", "logicalAndToBranch", false, false)
IGC_INITIALIZE_PASS_END(LogicalAndToBranch, "logicalAndToBranch", "logicalAndToBranch", false, false)

// Pass identification token.
char LogicalAndToBranch::ID = 0;
// Factory used by the IGC pass pipeline.
FunctionPass* IGC::createLogicalAndToBranchPass() { return new LogicalAndToBranch(); }
5150
// Constructor: registers the pass instance with the global pass registry.
LogicalAndToBranch::LogicalAndToBranch() : FunctionPass(ID)
{
    initializeLogicalAndToBranchPass(*PassRegistry::getPassRegistry());
}
5155
// Recursively hoist the def-chain of V (within bb) so the value is defined
// before insertPos. Operands are scheduled first, so a moved instruction
// still dominates its uses. Returns true if any instruction was moved.
// m_sched memoizes instructions already handled across calls; insertPos may
// be adjusted downward when an already-scheduled instruction sits below it.
bool LogicalAndToBranch::scheduleUp(BasicBlock* bb, Value* V,
    Instruction*& insertPos)
{
    Instruction* inst = dyn_cast<Instruction>(V);
    if (!inst)
        return false;

    // Only instructions in this block can be (and need to be) moved;
    // PHI nodes must stay at the block head.
    if (inst->getParent() != bb || isa<PHINode>(inst))
        return false;

    if (m_sched.count(inst))
    {
        // Already scheduled; just ensure the insertion point stays below it.
        if (insertPos && !isInstPrecede(inst, insertPos))
            insertPos = inst;
        return false;
    }

    bool changed = false;

    // Schedule the operands first.
    for (auto OI = inst->op_begin(), OE = inst->op_end(); OI != OE; ++OI)
    {
        changed |= scheduleUp(bb, OI->get(), insertPos);
    }
    m_sched.insert(inst);

    // Nothing to do if inst already precedes the insertion point.
    if (insertPos && isInstPrecede(inst, insertPos))
        return changed;

    if (insertPos) {
        inst->removeFromParent();
        inst->insertBefore(insertPos);
    }

    return true;
}
5191
5192 // split original basic block from:
5193 // original BB:
5194 // %cond0 =
5195 // ...
5196 // %cond1 =
5197 // %andRes = and i1 %cond0, %cond1
5198 // ...
5199 // to:
5200 // original BB:
5201 // %cond0 =
5202 // if %cond0, if.then, if.else
5203 //
5204 // if.then:
5205 // ...
5206 // %cond1 =
5207 // br if.end
5208 //
5209 // if.else:
5210 // br if.end
5211 //
5212 // if.end:
5213 // %andRes = phi [%cond1, if.then], [false, if.else]
5214 // ...
// Perform the CFG rewrite described in the comment above: split the block so
// the instructions between cond0 and the `and` become a guarded `if.then`
// region, and replace the `and` with a phi of (cond1, false).
void LogicalAndToBranch::convertAndToBranch(Instruction* opAnd,
    Instruction* cond0, Instruction* cond1, BasicBlock*& newBB)
{
    BasicBlock* bb = opAnd->getParent();
    BasicBlock* bbThen, * bbElse, * bbEnd;

    // After the three splits: if.then holds (cond0, opAnd); if.else is empty
    // (just a fall-through branch); if.end starts at opAnd.
    bbThen = bb->splitBasicBlock(cond0->getNextNode(), "if.then");
    bbElse = bbThen->splitBasicBlock(opAnd, "if.else");
    bbEnd = bbElse->splitBasicBlock(opAnd, "if.end");

    // Replace the fall-through terminator with a conditional branch on cond0.
    bb->getTerminator()->eraseFromParent();
    BranchInst* br = BranchInst::Create(bbThen, bbElse, cond0, bb);

    // if.then jumps straight to if.end, skipping if.else.
    bbThen->getTerminator()->eraseFromParent();
    br = BranchInst::Create(bbEnd, bbThen);

    // %andRes = phi [%cond1, if.then], [false, if.else]
    PHINode* phi = PHINode::Create(opAnd->getType(), 2, "", opAnd);
    phi->addIncoming(cond1, bbThen);
    phi->addIncoming(ConstantInt::getFalse(opAnd->getType()), bbElse);
    opAnd->replaceAllUsesWith(phi);
    opAnd->eraseFromParent();

    newBB = bbEnd;
}
5239
isSafeToConvert(Instruction * cond0,Instruction * cond1,smallvector<Instruction *,8> & moveInsts)5240 bool LogicalAndToBranch::isSafeToConvert(
5241 Instruction* cond0, Instruction* cond1,
5242 smallvector<Instruction*, 8> & moveInsts)
5243 {
5244 BasicBlock::iterator is0(cond0);
5245 BasicBlock::iterator is1(cond1);
5246
5247 bool isSafe = true;
5248 SmallPtrSet<Value*, 32> iset;
5249
5250 iset.insert(cond1);
5251 for (auto i = ++is0; i != is1; ++i)
5252 {
5253 if ((*i).mayHaveSideEffects())
5254 {
5255 isSafe = false;
5256 break;
5257 }
5258 iset.insert(&(*i));
5259 }
5260
5261 if (!isSafe)
5262 {
5263 return false;
5264 }
5265
5266 is0 = cond0->getIterator();
5267 // check if the values in between are used elsewhere
5268 for (auto i = ++is0; i != is1; ++i)
5269 {
5270 Instruction* inst = &*i;
5271 for (auto ui : inst->users())
5272 {
5273 if (iset.count(ui) == 0)
5274 {
5275 moveInsts.push_back(inst);
5276 break;
5277 }
5278 }
5279 }
5280 return isSafe;
5281 }
5282
// Scan each block for an `and i1` whose two operands are single-use,
// non-phi instructions defined in the same block and far enough apart
// (NUM_INST_THRESHOLD); when safe, rewrite the `and` into a branch so the
// second condition's computation is skipped when the first is false.
bool LogicalAndToBranch::runOnFunction(Function& F)
{
    bool changed = false;
    if (IGC_IS_FLAG_DISABLED(EnableLogicalAndToBranch))
    {
        return changed;
    }

    for (auto BI = F.begin(), BE = F.end(); BI != BE; )
    {
        // advance iterator before handling current BB
        BasicBlock* bb = &*BI++;

        for (auto II = bb->begin(), IE = bb->end(); II != IE; )
        {
            Instruction* inst = &(*II++);

            // search for "and i1"
            if (inst->getOpcode() == BinaryOperator::And &&
                inst->getType()->isIntegerTy(1))
            {
                Instruction* s0 = dyn_cast<Instruction>(inst->getOperand(0));
                Instruction* s1 = dyn_cast<Instruction>(inst->getOperand(1));
                if (s0 && s1 &&
                    !isa<PHINode>(s0) && !isa<PHINode>(s1) &&
                    s0->hasOneUse() && s1->hasOneUse() &&
                    s0->getParent() == bb && s1->getParent() == bb)
                {
                    // normalize so s0 comes first in the block
                    if (isInstPrecede(s1, s0))
                    {
                        std::swap(s0, s1);
                    }
                    BasicBlock::iterator is0(s0);
                    BasicBlock::iterator is1(s1);

                    // only worthwhile when enough instructions can be skipped
                    if (std::distance(is0, is1) < NUM_INST_THRESHOLD)
                    {
                        continue;
                    }

                    smallvector<Instruction*, 8> moveInsts;
                    if (isSafeToConvert(s0, inst, moveInsts))
                    {
                        // if values defined between s0 & inst(branch) are referenced
                        // outside of (s0, inst), they need to be moved before
                        // s0 to keep SSA form.
                        for (auto inst : moveInsts)
                            scheduleUp(bb, inst, s0);
                        m_sched.clear();

                        // IE needs to be updated since the original BB is split
                        convertAndToBranch(inst, s0, s1, bb);
                        IE = bb->end();
                        changed = true;
                    }
                }
            }
        }
    }

    return changed;
}
5345
5346 // clean PHINode does the following:
5347 // given the following:
5348 // a = phi (x, b0), (x, b1)
5349 // this pass will replace 'a' with 'x', and as result, phi is removed.
5350 //
5351 // Special note:
5352 // LCSSA PHINode has a single incoming value. Make sure it is not removed
// as WIA uses lcssa phi as a separator between a uniform value inside a loop
5354 // and non-uniform value outside a loop. For example:
5355 // B0:
5356 // i = 0;
5357 // Loop:
5358 // i_0 = phi (0, B0) (t, Bn)
5359 // .... <use i_0>
5360 // if (divergent cond)
5361 // Bi:
5362 // goto out;
5363 // Bn:
5364 // t = i_0 + 1;
5365 // if (t < N) goto Loop;
5366 // goto output;
5367 // out:
5368 // i_1 = phi (i_0, Bi) <-- lcssa phi node
5369 // ....
5370 // output:
5371 // Here, i_0 is uniform within the loop, but it is not outside loop as each WI will
5372 // exit with different i, thus i_1 is non-uniform. (Note that removing lcssa might be
5373 // bad in performance, but it should not cause any functional issue.)
5374 //
5375 // This is needed to avoid generating the following code for which vISA cannot generate
5376 // the correct code:
5377 // i = 0; // i is uniform
5378 // Loop:
5379 // x = i + 1 // x is uniform
5380 // B0 if (divergent-condition)
5381 // <code1>
5382 // B1 else
5383 // z = array[i]
5384 // B2 endif
5385 // i = phi (x, B0), (x, B1)
5386 // ......
5387 // if (i < n) goto Loop
5388 //
5389 // Generated code (visa) (phi becomes mov in its incoming BBs).
5390 //
5391 // i = 0; // i is uniform
5392 // Loop:
5393 // (W) x = i + 1 // x is uniform, NoMask
5394 // B0 if (divergent-condition)
5395 // <code1>
5396 // (W) i = x // noMask
5397 // B1 else
5398 // z = array[i]
5399 // (W) i = x // noMask
5400 // B2 endif
5401 // ......
5402 // if (i < n) goto Loop
5403 //
5404 // In the 1st iteration, 'z' should be array[0]. Assume 'if' is divergent, thus both B0 and B1
5405 // blocks will be executed. As result, the value of 'i' after B0 will be x, which is 1. And 'z'
5406 // will take array[1], which is wrong (correct one is array[0]).
5407 //
5408 // This case happens if phi is uniform, which means all phis' incoming values are identical
// and uniform (current behavior of WIAnalysis). Note that identical incoming values mean this phi
5410 // is no longer needed. Once such a phi is removed, we will never generate code like one shown
5411 // above and thus, no wrong code will be generated from visa.
5412 //
5413 // This pass will be invoked at place close to the Emit pass, where WIAnalysis will be invoked,
5414 // so that IR between this pass and WIAnalysis stays the same, at least no new PHINodes like this
5415 // will be generated.
5416 //
5417 namespace {
5418 class CleanPHINode : public FunctionPass
5419 {
5420 public:
5421 static char ID;
5422 CleanPHINode();
5423
getPassName() const5424 StringRef getPassName() const override { return "CleanPhINode"; }
5425
5426 bool runOnFunction(Function& F) override;
5427 };
5428 }
5429
// Boilerplate pass registration for CleanPHINode.
#undef PASS_FLAG
#undef PASS_DESCRIPTION
#undef PASS_CFG_ONLY
#undef PASS_ANALYSIS
#define PASS_FLAG "igc-cleanphinode"
#define PASS_DESCRIPTION "Clean up PHINode"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
IGC_INITIALIZE_PASS_BEGIN(CleanPHINode, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_END(CleanPHINode, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
5440
5441
// Pass identification token.
char CleanPHINode::ID = 0;
// Factory used by the IGC pass pipeline.
FunctionPass* IGC::createCleanPHINodePass()
{
    return new CleanPHINode();
}
5447
// Constructor: registers the pass instance with the global pass registry.
CleanPHINode::CleanPHINode() : FunctionPass(ID)
{
    initializeCleanPHINodePass(*PassRegistry::getPassRegistry());
}
5452
runOnFunction(Function & F)5453 bool CleanPHINode::runOnFunction(Function& F)
5454 {
5455 auto isLCSSAPHINode = [](PHINode* PHI) { return (PHI->getNumIncomingValues() == 1); };
5456
5457 bool changed = false;
5458 for (auto BI = F.begin(), BE = F.end(); BI != BE; ++BI)
5459 {
5460 BasicBlock* BB = &*BI;
5461 auto II = BB->begin();
5462 auto IE = BB->end();
5463 while (II != IE)
5464 {
5465 auto currII = II;
5466 ++II;
5467 PHINode* PHI = dyn_cast<PHINode>(currII);
5468 if (PHI == nullptr)
5469 {
5470 // proceed to the next BB
5471 break;
5472 }
5473 if (isLCSSAPHINode(PHI))
5474 {
5475 // Keep LCSSA PHI as uniform analysis needs it.
5476 continue;
5477 }
5478
5479 if (PHI->getNumIncomingValues() > 0) // sanity
5480 {
5481 Value* sameVal = PHI->getIncomingValue(0);
5482 bool isAllSame = true;
5483 for (int i = 1, sz = (int)PHI->getNumIncomingValues(); i < sz; ++i)
5484 {
5485 if (sameVal != PHI->getIncomingValue(i))
5486 {
5487 isAllSame = false;
5488 break;
5489 }
5490 }
5491 if (isAllSame)
5492 {
5493 PHI->replaceAllUsesWith(sameVal);
5494 PHI->eraseFromParent();
5495 changed = true;
5496 }
5497 }
5498 }
5499 }
5500 return changed;
5501 }