1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 /*========================== begin_copyright_notice ============================
10 
11 This file is distributed under the University of Illinois Open Source License.
12 See LICENSE.TXT for details.
13 
14 ============================= end_copyright_notice ===========================*/
15 
/*=========================== CustomSafeOptPass.cpp ===========================
17 
18 This file contains IGC custom optimizations that are arithmetically safe.
19 The passes are
20     CustomSafeOptPass
21     GenSpecificPattern
22     IGCConstProp
23     IGCIndirectICBPropagaion
24     CustomLoopInfo
25     VectorBitCastOpt
26     GenStrengthReduction
27     FlattenSmallSwitch
28     SplitIndirectEEtoSel
29 
30 CustomSafeOptPass does peephole optimizations
31 For example, reduce the alloca size so there is a chance to promote indexed temp.
32 
GenSpecificPattern reverts llvm changes back to what is needed.
For example, llvm changes AND to OR, and GenSpecificPattern changes it back to
allow more optimizations.
36 
37 IGCConstProp was originated from llvm copy-prop code with one addition for
38 shader-constant replacement. So llvm copyright is added above.
39 
IGCIndirectICBPropagaion reads the immediate constant buffer from metadata and
uses its values as immediates instead of using send messages to read from the buffer.
42 
43 CustomLoopInfo returns true if there is any sampleL in a loop for the driver.
44 
45 VectorBitCastOpt preprocesses vector bitcasts to be after extractelement
46 instructions.
47 
48 GenStrengthReduction performs a fdiv optimization.
49 
FlattenSmallSwitch flattens the if/else or switch structure and uses cmp+sel
instead if the structure is small.
52 
53 SplitIndirectEEtoSel splits extractelements with very small vec to a series of
54 cmp+sel to avoid expensive VxH mov.
55 
56 =============================================================================*/
57 
58 #include "Compiler/CustomSafeOptPass.hpp"
59 #include "Compiler/CISACodeGen/helper.h"
60 #include "Compiler/CodeGenPublic.h"
61 #include "Compiler/IGCPassSupport.h"
62 #include "GenISAIntrinsics/GenIntrinsics.h"
63 #include "GenISAIntrinsics/GenIntrinsicInst.h"
64 #include "common/IGCConstantFolder.h"
65 #include "common/LLVMWarningsPush.hpp"
66 #include "llvm/Config/llvm-config.h"
67 #include "WrapperLLVM/Utils.h"
68 #include <llvmWrapper/IR/DerivedTypes.h>
69 #include <llvmWrapper/IR/IRBuilder.h>
70 #include <llvmWrapper/IR/PatternMatch.h>
71 #include <llvmWrapper/Analysis/TargetLibraryInfo.h>
72 #include <llvm/ADT/Statistic.h>
73 #include <llvm/ADT/SetVector.h>
74 #include <llvm/Analysis/ConstantFolding.h>
75 #include <llvm/IR/Constants.h>
76 #include <llvm/IR/Function.h>
77 #include <llvm/IR/Instructions.h>
78 #include <llvm/IR/Intrinsics.h>
79 #include <llvm/IR/InstIterator.h>
80 #include <llvm/Transforms/Utils/Local.h>
81 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
82 #include <llvm/Analysis/ValueTracking.h>
83 #include "common/LLVMWarningsPop.hpp"
84 #include <set>
85 #include "../inc/common/secure_mem.h"
86 #include "Probe/Assertion.h"
87 
88 using namespace llvm;
89 using namespace IGC;
90 using namespace GenISAIntrinsic;
91 
// Register pass to igc-opt
// The four macros below hold the command-line flag, the description and the
// CFG-only / analysis booleans that are forwarded to the IGC_INITIALIZE_PASS_*
// registration macros.
#define PASS_FLAG1 "igc-custom-safe-opt"
#define PASS_DESCRIPTION1 "Custom Pass Optimization"
#define PASS_CFG_ONLY1 false
#define PASS_ANALYSIS1 false
IGC_INITIALIZE_PASS_BEGIN(CustomSafeOptPass, PASS_FLAG1, PASS_DESCRIPTION1, PASS_CFG_ONLY1, PASS_ANALYSIS1)
IGC_INITIALIZE_PASS_END(CustomSafeOptPass, PASS_FLAG1, PASS_DESCRIPTION1, PASS_CFG_ONLY1, PASS_ANALYSIS1)

// Pass identifier; the constructor hands it to the FunctionPass base class.
char CustomSafeOptPass::ID = 0;
101 
CustomSafeOptPass()102 CustomSafeOptPass::CustomSafeOptPass() : FunctionPass(ID)
103 {
104     initializeCustomSafeOptPassPass(*PassRegistry::getPassRegistry());
105 }
106 
// LLVM's STATISTIC macro requires DEBUG_TYPE to be defined before it is used.
#define DEBUG_TYPE "CustomSafeOptPass"

// Incremented by visitCallInst each time a discard with a constant-false
// condition is erased.
STATISTIC(Stat_DiscardRemoved, "Number of insts removed in Discard Opt");
110 
runOnFunction(Function & F)111 bool CustomSafeOptPass::runOnFunction(Function& F)
112 {
113     psHasSideEffect = getAnalysis<CodeGenContextWrapper>().getCodeGenContext()->m_instrTypes.psHasSideEffect;
114     visit(F);
115     return true;
116 }
117 
visitInstruction(Instruction & I)118 void CustomSafeOptPass::visitInstruction(Instruction& I)
119 {
120     // nothing
121 }
122 
123 //  Searches the following pattern
124 //      %1 = icmp eq i32 %cmpop1, %cmpop2
125 //      %2 = xor i1 %1, true
126 //      ...
127 //      %3 = select i1 %1, i8 0, i8 1
128 //
129 //  And changes it to:
130 //      %1 = icmp ne i32 %cmpop1, %cmpop2
131 //      ...
132 //      %3 = select i1 %1, i8 1, i8 0
133 //
134 //  and
135 //
136 //  Searches the following pattern
137 //      %1 = icmp ule i32 %cmpop1, %cmpop2
138 //      %2 = xor i1 %1, true
139 //      br i1 %1, label %3, label %4
140 //
141 //  And changes it to:
142 //      %1 = icmp ugt i32 %cmpop1, %cmpop2
143 //      br i1 %1, label %4, label %3
144 //
145 //  This optimization combines statements regardless of the predicate.
146 //  It will also work if the icmp instruction does not have users, except for the xor, select or branch instruction.
visitXor(Instruction & XorInstr)147 void CustomSafeOptPass::visitXor(Instruction& XorInstr) {
148     using namespace llvm::PatternMatch;
149 
150     CmpInst::Predicate Pred;
151     auto XorPattern = m_c_Xor(m_ICmp(Pred, m_Value(), m_Value()), m_SpecificInt(1));
152     if (!match(&XorInstr, XorPattern)) {
153         return;
154     }
155 
156     Value* XorOp0 = XorInstr.getOperand(0);
157     Value* XorOp1 = XorInstr.getOperand(1);
158     auto ICmpInstr = cast<Instruction>(isa<ICmpInst>(XorOp0) ? XorOp0 : XorOp1);
159 
160     llvm::SmallVector<Instruction*, 4> UsersList;
161 
162     for (auto U : ICmpInstr->users()) {
163         if (isa<BranchInst>(U)) {
164             UsersList.push_back(cast<Instruction>(U));
165         }
166         else if (SelectInst* S = dyn_cast<SelectInst>(U)) {
167             if (S->getCondition() == ICmpInstr) {
168                 UsersList.push_back(cast<Instruction>(S));
169             }
170             else {
171                 return;
172             }
173         }
174         else if (U != &XorInstr) {
175             return;
176         }
177     }
178 
179     IRBuilder<> builder(ICmpInstr);
180     auto NegatedCmpPred = cast<ICmpInst>(ICmpInstr)->getInversePredicate();
181     auto NewCmp = cast<ICmpInst>(builder.CreateICmp(NegatedCmpPred, ICmpInstr->getOperand(0), ICmpInstr->getOperand(1)));
182 
183     for (auto I : UsersList) {
184         if (SelectInst* S = dyn_cast<SelectInst>(I)) {
185             S->swapProfMetadata();
186             Value* TrueVal = S->getTrueValue();
187             Value* FalseVal = S->getFalseValue();
188             S->setTrueValue(FalseVal);
189             S->setFalseValue(TrueVal);
190         }
191         else {
192             IGC_ASSERT(isa<BranchInst>(I));
193             BranchInst* B = cast<BranchInst>(I);
194             B->swapSuccessors();
195         }
196     }
197 
198     XorInstr.replaceAllUsesWith(NewCmp);
199     ICmpInstr->replaceAllUsesWith(NewCmp);
200     XorInstr.eraseFromParent();
201     ICmpInstr->eraseFromParent();
202 }
203 
204 //  Searches for following pattern:
205 //    %cmp = icmp slt i32 %x, %y
206 //    %cond.not = xor i1 %cond, true
207 //    %and.cond = and i1 %cmp, %cond.not
208 //    br i1 %or.cond, label %bb1, label %bb2
209 //
210 //  And changes it to:
211 //    %0 = icmp sge i32 %x, %y
212 //    %1 = or i1 %cond, %0
213 //    br i1 %1, label %bb2, label %bb1
visitAnd(BinaryOperator & I)214 void CustomSafeOptPass::visitAnd(BinaryOperator& I) {
215     using namespace llvm::PatternMatch;
216 
217     if (!I.hasOneUse() ||
218         !isa<BranchInst>(*I.user_begin()) ||
219         !I.getType()->isIntegerTy(1)) {
220         return;
221     }
222 
223     Value* XorArgValue;
224     CmpInst::Predicate Pred;
225     auto AndPattern = m_c_And(m_c_Xor(m_Value(XorArgValue), m_SpecificInt(1)), m_ICmp(Pred, m_Value(), m_Value()));
226     if (!match(&I, AndPattern)) return;
227 
228     IRBuilder<> builder(&I);
229     auto CompareInst = cast<ICmpInst>(isa<ICmpInst>(I.getOperand(0)) ? I.getOperand(0) : I.getOperand(1));
230     auto NegatedCompareInst = builder.CreateICmp(CompareInst->getInversePredicate(), CompareInst->getOperand(0), CompareInst->getOperand(1));
231     auto OrInst = builder.CreateOr(XorArgValue, NegatedCompareInst);
232 
233     auto BrInst = cast<BranchInst>(*I.user_begin());
234     BrInst->setCondition(OrInst);
235     BrInst->swapSuccessors();
236 
237     I.eraseFromParent();
238 }
239 
240 /*
241 Optimizing from
242 % 377 = call i32 @llvm.genx.GenISA.simdSize()
243 % .rhs.trunc = trunc i32 % 377 to i8
244 % .lhs.trunc = trunc i32 % 26 to i8
245 % 383 = udiv i8 % .lhs.trunc, % .rhs.trunc
246 to
247 % 377 = call i32 @llvm.genx.GenISA.simdSize()
248 % .rhs.trunc = trunc i32 % 377 to i8
249 % .lhs.trunc = trunc i32 % 26 to i8
250 % a = shr i8 % rhs.trunc, 4
251 % b = shr i8 % .lhs.trunc, 3
252 % 382 = shr i8 % b, % a
253 or
254 % 377 = call i32 @llvm.genx.GenISA.simdSize()
255 % 383 = udiv i32 %382, %377
256 to
257 % 377 = call i32 @llvm.genx.GenISA.simdSize()
258 % a = shr i32 %377, 4
259 % b = shr i32 %382, 3
260 % 382 = shr i32 % b, % a
261 */
visitUDiv(llvm::BinaryOperator & I)262 void CustomSafeOptPass::visitUDiv(llvm::BinaryOperator& I)
263 {
264     bool isPatternfound = false;
265 
266     if (TruncInst* trunc = dyn_cast<TruncInst>(I.getOperand(1)))
267     {
268         if (CallInst* Inst = dyn_cast<CallInst>(trunc->getOperand(0)))
269         {
270             if (GenIntrinsicInst* SimdInst = dyn_cast<GenIntrinsicInst>(Inst))
271                 if (SimdInst->getIntrinsicID() == GenISAIntrinsic::GenISA_simdSize)
272                     isPatternfound = true;
273         }
274     }
275     else
276     {
277         if (CallInst* Inst = dyn_cast<CallInst>(I.getOperand(1)))
278         {
279             if (GenIntrinsicInst* SimdInst = dyn_cast<GenIntrinsicInst>(Inst))
280                 if (SimdInst->getIntrinsicID() == GenISAIntrinsic::GenISA_simdSize)
281                     isPatternfound = true;
282         }
283     }
284     if (isPatternfound)
285     {
286         IRBuilder<> builder(&I);
287         Value* Shift1 = builder.CreateLShr(I.getOperand(1), ConstantInt::get(I.getOperand(1)->getType(), 4));
288         Value* Shift2 = builder.CreateLShr(I.getOperand(0), ConstantInt::get(I.getOperand(0)->getType(), 3));
289         Value* Shift3 = builder.CreateLShr(Shift2, Shift1);
290         I.replaceAllUsesWith(Shift3);
291         I.eraseFromParent();
292     }
293 }
294 
visitAllocaInst(AllocaInst & I)295 void CustomSafeOptPass::visitAllocaInst(AllocaInst& I)
296 {
297     // reduce the alloca size so there is a chance to promote indexed temp.
298 
299     // ex:                                                                  to:
300     // dcl_indexable_temp x1[356], 4                                        dcl_indexable_temp x1[2], 4
301     // mov x[1][354].xyzw, l(1f, 0f, -1f, 0f)                               mov x[1][0].xyzw, l(1f, 0f, -1f, 0f)
302     // mov x[1][355].xyzw, l(1f, 1f, 0f, -1f)                               mov x[1][1].xyzw, l(1f, 1f, 0f, -1f)
303     // mov r[1].xy, x[1][r[1].x + 354].xyxx                                 mov r[1].xy, x[1][r[1].x].xyxx
304 
305     // in llvm:                                                             to:
306     // %outarray_x1 = alloca[356 x float], align 4                          %31 = alloca[2 x float]
307     // %outarray_y2 = alloca[356 x float], align 4                          %32 = alloca[2 x float]
308     // %27 = getelementptr[356 x float] * %outarray_x1, i32 0, i32 352      %35 = getelementptr[2 x float] * %31, i32 0, i32 0
309     // store float 0.000000e+00, float* %27, align 4                        store float 0.000000e+00, float* %35, align 4
310     // %28 = getelementptr[356 x float] * %outarray_y2, i32 0, i32 352      %36 = getelementptr[2 x float] * %32, i32 0, i32 0
311     // store float 0.000000e+00, float* %28, align 4                        store float 0.000000e+00, float* %36, align 4
312     // %43 = add nsw i32 %selRes_s, 354
313     // %44 = getelementptr[356 x float] * %outarray_x1, i32 0, i32 %43      %51 = getelementptr[2 x float] * %31, i32 0, i32 %selRes_s
314     // %45 = load float* %44, align 4                                       %52 = load float* %51, align 4
315 
316     llvm::Type* pType = I.getType()->getPointerElementType();
317     if (!pType->isArrayTy() ||
318         static_cast<ADDRESS_SPACE>(I.getType()->getAddressSpace()) != ADDRESS_SPACE_PRIVATE)
319     {
320         return;
321     }
322 
323     if (!(pType->getArrayElementType()->isFloatingPointTy() ||
324         pType->getArrayElementType()->isIntegerTy() ||
325         pType->getArrayElementType()->isPointerTy()))
326     {
327         return;
328     }
329 
330     int index_lb = int_cast<int>(pType->getArrayNumElements());
331     int index_ub = 0;
332 
333     // Find all uses of this alloca
334     for (Value::user_iterator it = I.user_begin(), e = I.user_end(); it != e; ++it)
335     {
336         if (GetElementPtrInst * pGEP = llvm::dyn_cast<GetElementPtrInst>(*it))
337         {
338             ConstantInt* C0 = dyn_cast<ConstantInt>(pGEP->getOperand(1));
339             if (!C0 || !C0->isZero() || pGEP->getNumOperands() != 3)
340             {
341                 return;
342             }
343             for (Value::user_iterator use_it = pGEP->user_begin(), use_e = pGEP->user_end(); use_it != use_e; ++use_it)
344             {
345                 if (llvm::dyn_cast<llvm::LoadInst>(*use_it))
346                 {
347                 }
348                 else if (llvm::StoreInst * pStore = llvm::dyn_cast<llvm::StoreInst>(*use_it))
349                 {
350                     llvm::Value* pValueOp = pStore->getValueOperand();
351                     if (pValueOp == *it)
352                     {
353                         // GEP instruction is the stored value of the StoreInst (not supported case)
354                         return;
355                     }
356                     if (dyn_cast<ConstantInt>(pGEP->getOperand(2)) && pGEP->getOperand(2)->getType()->isIntegerTy(32))
357                     {
358                         int currentIndex = int_cast<int>(
359                             dyn_cast<ConstantInt>(pGEP->getOperand(2))->getZExtValue());
360                         index_lb = (currentIndex < index_lb) ? currentIndex : index_lb;
361                         index_ub = (currentIndex > index_ub) ? currentIndex : index_ub;
362 
363                     }
364                     else
365                     {
366                         return;
367                     }
368                 }
369                 else
370                 {
371                     // This is some other instruction. Right now we don't want to handle these
372                     return;
373                 }
374             }
375         }
376         else
377         {
378             if (!IsBitCastForLifetimeMark(dyn_cast<Value>(*it)))
379             {
380                 return;
381             }
382         }
383     }
384 
385     unsigned int newSize = index_ub + 1 - index_lb;
386     if (newSize >= pType->getArrayNumElements())
387     {
388         return;
389     }
390     // found a case to optimize
391     IGCLLVM::IRBuilder<> IRB(&I);
392     llvm::ArrayType* allocaArraySize = llvm::ArrayType::get(pType->getArrayElementType(), newSize);
393     llvm::Value* newAlloca = IRB.CreateAlloca(allocaArraySize, nullptr);
394     llvm::Value* gepArg1;
395 
396     for (Value::user_iterator it = I.user_begin(), e = I.user_end(); it != e; ++it)
397     {
398         if (GetElementPtrInst * pGEP = llvm::dyn_cast<GetElementPtrInst>(*it))
399         {
400             if (dyn_cast<ConstantInt>(pGEP->getOperand(2)))
401             {
402                 // pGEP->getOperand(2) is constant. Reduce the constant value directly
403                 int newIndex = int_cast<int>(dyn_cast<ConstantInt>(pGEP->getOperand(2))->getZExtValue())
404                     - index_lb;
405                 gepArg1 = IRB.getInt32(newIndex);
406             }
407             else
408             {
409                 // pGEP->getOperand(2) is not constant. create a sub instruction to reduce it
410                 gepArg1 = BinaryOperator::CreateSub(pGEP->getOperand(2), IRB.getInt32(index_lb), "reducedIndex", pGEP);
411             }
412             llvm::Value* gepArg[] = { pGEP->getOperand(1), gepArg1 };
413             llvm::Value* pGEPnew = GetElementPtrInst::Create(nullptr, newAlloca, gepArg, "", pGEP);
414             pGEP->replaceAllUsesWith(pGEPnew);
415         }
416     }
417 }
418 
visitLoadInst(LoadInst & load)419 void CustomSafeOptPass::visitLoadInst(LoadInst& load)
420 {
421     // Optimize indirect access to private arrays. Handle cases where
422     // array index is a select between two immediate constant values.
423     // After the optimization there is fair chance the alloca will be
424     // promoted to registers.
425     //
426     // E.g. change the following:
427     // %PrivareArray = alloca[4 x <3 x float>], align 16
428     // %IndirectIndex = select i1 %SomeCondition, i32 1, i32 2
429     // %IndirectAccessPtr= getelementptr[4 x <3 x float>], [4 x <3 x float>] * %PrivareArray, i32 0, i32 %IndirectIndex
430     // %LoadedValue = load <3 x float>, <3 x float>* %IndirectAccess, align 16
431 
432     // %PrivareArray = alloca[4 x <3 x float>], align 16
433     // %DirectAccessPtr1 = getelementptr[4 x <3 x float>], [4 x <3 x float>] * %PrivareArray, i32 0, i32 1
434     // %DirectAccessPtr2 = getelementptr[4 x <3 x float>], [4 x <3 x float>] * %PrivareArray, i32 0, i32 2
435     // %LoadedValue1 = load <3 x float>, <3 x float>* %DirectAccessPtr1, align 16
436     // %LoadedValue2 = load <3 x float>, <3 x float>* %DirectAccessPtr2, align 16
437     // %LoadedValue = select i1 %SomeCondition, <3 x float> %LoadedValue1, <3 x float> %LoadedValue2
438 
439     Value* ptr = load.getPointerOperand();
440     if (ptr->getType()->getPointerAddressSpace() != 0)
441     {
442         // only private arrays are handled
443         return;
444     }
445     if (GetElementPtrInst * gep = dyn_cast<GetElementPtrInst>(ptr))
446     {
447         bool found = false;
448         uint selIdx = 0;
449         // Check if this gep is a good candidate for optimization.
450         // The instruction has to have exactly one non-constant index value.
451         // The index value has to be a select instruction with immediate
452         // constant values.
453         for (uint i = 0; i < gep->getNumIndices(); ++i)
454         {
455             Value* gepIdx = gep->getOperand(i + 1);
456             if (!isa<ConstantInt>(gepIdx))
457             {
458                 SelectInst* sel = dyn_cast<SelectInst>(gepIdx);
459                 if (!found &&
460                     sel &&
461                     isa<ConstantInt>(sel->getOperand(1)) &&
462                     isa<ConstantInt>(sel->getOperand(2)))
463                 {
464                     found = true;
465                     selIdx = i;
466                 }
467                 else
468                 {
469                     found = false; // optimize cases with only a single non-constant index.
470                     break;
471                 }
472             }
473 
474         }
475         if (found)
476         {
477             SelectInst* sel = cast<SelectInst>(gep->getOperand(selIdx + 1));
478             SmallVector<Value*, 8> indices;
479             indices.append(gep->idx_begin(), gep->idx_end());
480             indices[selIdx] = sel->getOperand(1);
481             GetElementPtrInst* gep1 = GetElementPtrInst::Create(nullptr, gep->getPointerOperand(), indices, gep->getName(), gep);
482             gep1->setDebugLoc(gep->getDebugLoc());
483             indices[selIdx] = sel->getOperand(2);
484             GetElementPtrInst* gep2 = GetElementPtrInst::Create(nullptr, gep->getPointerOperand(), indices, gep->getName(), gep);
485             gep2->setDebugLoc(gep->getDebugLoc());
486             LoadInst* load1 = cast<LoadInst>(load.clone());
487             load1->insertBefore(&load);
488             load1->setOperand(0, gep1);
489             LoadInst* load2 = cast<LoadInst>(load.clone());
490             load2->insertBefore(&load);
491             load2->setOperand(0, gep2);
492             SelectInst* result = SelectInst::Create(sel->getCondition(), load1, load2, load.getName(), &load);
493             result->setDebugLoc(load.getDebugLoc());
494             load.replaceAllUsesWith(result);
495             load.eraseFromParent();
496             if (gep->use_empty())
497             {
498                 gep->eraseFromParent();
499             }
500             if (sel->use_empty())
501             {
502                 sel->eraseFromParent();
503             }
504         }
505 
506     }
507 
508 }
509 
visitCallInst(CallInst & C)510 void CustomSafeOptPass::visitCallInst(CallInst& C)
511 {
512     // discard optimizations
513     if (llvm::GenIntrinsicInst * inst = llvm::dyn_cast<GenIntrinsicInst>(&C))
514     {
515         GenISAIntrinsic::ID id = inst->getIntrinsicID();
516         // try to prune the destination size
517         switch (id)
518         {
519         case GenISAIntrinsic::GenISA_discard:
520         {
521             Value* srcVal0 = C.getOperand(0);
522             if (ConstantInt * CI = dyn_cast<ConstantInt>(srcVal0))
523             {
524                 if (CI->isZero()) { // i1 is false
525                     C.eraseFromParent();
526                     ++Stat_DiscardRemoved;
527                 }
528                 else if (!psHasSideEffect)
529                 {
530                     BasicBlock* blk = C.getParent();
531                     BasicBlock* pred = blk->getSinglePredecessor();
532                     if (blk && pred)
533                     {
534                         BranchInst* cbr = dyn_cast<BranchInst>(pred->getTerminator());
535                         if (cbr && cbr->isConditional())
536                         {
537                             if (blk == cbr->getSuccessor(0))
538                             {
539                                 C.setOperand(0, cbr->getCondition());
540                                 C.removeFromParent();
541                                 C.insertBefore(cbr);
542                             }
543                             else if (blk == cbr->getSuccessor(1))
544                             {
545                                 Value* flipCond = llvm::BinaryOperator::CreateNot(cbr->getCondition(), "", cbr);
546                                 C.setOperand(0, flipCond);
547                                 C.removeFromParent();
548                                 C.insertBefore(cbr);
549                             }
550                         }
551                     }
552                 }
553             }
554             break;
555         }
556 
557         case GenISAIntrinsic::GenISA_bfi:
558         {
559             visitBfi(inst);
560             break;
561         }
562 
563         case GenISAIntrinsic::GenISA_f32tof16_rtz:
564         {
565             visitf32tof16(inst);
566             break;
567         }
568 
569         case GenISAIntrinsic::GenISA_sampleBptr:
570         {
571             visitSampleBptr(llvm::cast<llvm::SampleIntrinsic>(inst));
572             break;
573         }
574 
575         case GenISAIntrinsic::GenISA_umulH:
576         {
577             visitMulH(inst, false);
578             break;
579         }
580 
581         case GenISAIntrinsic::GenISA_imulH:
582         {
583             visitMulH(inst, true);
584             break;
585         }
586 
587         case GenISAIntrinsic::GenISA_ldptr:
588         {
589             visitLdptr(llvm::cast<llvm::SamplerLoadIntrinsic>(inst));
590             break;
591         }
592 
593         case GenISAIntrinsic::GenISA_ldrawvector_indexed:
594         {
595             visitLdRawVec(inst);
596             break;
597         }
598         default:
599             break;
600         }
601     }
602 }
603 
604 //
605 // pattern match packing of two half float from f32tof16:
606 //
607 // % 43 = call float @llvm.genx.GenISA.f32tof16.rtz(float %res_s55.i)
608 // % 44 = call float @llvm.genx.GenISA.f32tof16.rtz(float %res_s59.i)
609 // % 47 = bitcast float %44 to i32
610 // % 49 = bitcast float %43 to i32
611 // %addres_s68.i = shl i32 % 47, 16
612 // % mulres_s69.i = add i32 %addres_s68.i, % 49
613 // % 51 = bitcast i32 %mulres_s69.i to float
614 // into
615 // %43 = call half @llvm.genx.GenISA_ftof_rtz(float %res_s55.i)
616 // %44 = call half @llvm.genx.GenISA_ftof_rtz(float %res_s59.i)
617 // %45 = insertelement <2 x half>undef, %43, 0
618 // %46 = insertelement <2 x half>%45, %44, 1
619 // %51 = bitcast <2 x half> %46 to float
620 //
621 // or if the f32tof16 are from fpext half:
622 //
623 // % src0_s = fpext half %res_s to float
624 // % src0_s2 = fpext half %res_s1 to float
625 // % 2 = call fast float @llvm.genx.GenISA.f32tof16.rtz(float %src0_s)
626 // % 3 = call fast float @llvm.genx.GenISA.f32tof16.rtz(float %src0_s2)
627 // % 4 = bitcast float %2 to i32
628 // % 5 = bitcast float %3 to i32
629 // % addres_s = shl i32 % 4, 16
630 // % mulres_s = add i32 %addres_s, % 5
631 // % 6 = bitcast i32 %mulres_s to float
632 // into
633 // % 2 = insertelement <2 x half> undef, half %res_s1, i32 0, !dbg !113
634 // % 3 = insertelement <2 x half> % 2, half %res_s, i32 1, !dbg !113
635 // % 4 = bitcast <2 x half> % 3 to float, !dbg !113
636 
visitf32tof16(llvm::CallInst * inst)637 void CustomSafeOptPass::visitf32tof16(llvm::CallInst* inst)
638 {
639     if (!inst->hasOneUse())
640     {
641         return;
642     }
643 
644     BitCastInst* bitcast = dyn_cast<BitCastInst>(*(inst->user_begin()));
645     if (!bitcast || !bitcast->hasOneUse() || !bitcast->getType()->isIntegerTy(32))
646     {
647         return;
648     }
649     Instruction* addInst = dyn_cast<BinaryOperator>(*(bitcast->user_begin()));
650     if (!addInst || addInst->getOpcode() != Instruction::Add || !addInst->hasOneUse())
651     {
652         return;
653     }
654     Instruction* lastValue = addInst;
655 
656     if (BitCastInst * finalBitCast = dyn_cast<BitCastInst>(*(addInst->user_begin())))
657     {
658         lastValue = finalBitCast;
659     }
660 
661     // check the other half
662     Value* otherOpnd = addInst->getOperand(0) == bitcast ? addInst->getOperand(1) : addInst->getOperand(0);
663     Instruction* shiftOrMul = dyn_cast<BinaryOperator>(otherOpnd);
664 
665     if (!shiftOrMul ||
666         (shiftOrMul->getOpcode() != Instruction::Shl && shiftOrMul->getOpcode() != Instruction::Mul))
667     {
668         return;
669     }
670     bool isShift = shiftOrMul->getOpcode() == Instruction::Shl;
671     ConstantInt* constVal = dyn_cast<ConstantInt>(shiftOrMul->getOperand(1));
672     if (!constVal || !constVal->equalsInt(isShift ? 16 : 65536))
673     {
674         return;
675     }
676     BitCastInst* bitcast2 = dyn_cast<BitCastInst>(shiftOrMul->getOperand(0));
677     if (!bitcast2)
678     {
679         return;
680     }
681     llvm::GenIntrinsicInst* valueHi = dyn_cast<GenIntrinsicInst>(bitcast2->getOperand(0));
682     if (!valueHi || valueHi->getIntrinsicID() != GenISA_f32tof16_rtz)
683     {
684         return;
685     }
686 
687     Value* loVal = nullptr;
688     Value* hiVal = nullptr;
689 
690     FPExtInst* extInstLo = dyn_cast<FPExtInst>(inst->getOperand(0));
691     FPExtInst* extInstHi = dyn_cast<FPExtInst>(valueHi->getOperand(0));
692 
693     IRBuilder<> builder(lastValue);
694     Type* funcType[] = { Type::getHalfTy(builder.getContext()), Type::getFloatTy(builder.getContext()) };
695     Type* halfx2 = IGCLLVM::FixedVectorType::get(Type::getHalfTy(builder.getContext()), 2);
696 
697     if (extInstLo && extInstHi &&
698         extInstLo->getOperand(0)->getType()->isHalfTy() &&
699         extInstHi->getOperand(0)->getType()->isHalfTy())
700     {
701         loVal = extInstLo->getOperand(0);
702         hiVal = extInstHi->getOperand(0);
703     }
704     else
705     {
706         Function* f32tof16_rtz = GenISAIntrinsic::getDeclaration(inst->getParent()->getParent()->getParent(),
707             GenISAIntrinsic::GenISA_ftof_rtz, funcType);
708         loVal = builder.CreateCall(f32tof16_rtz, inst->getArgOperand(0));
709         hiVal = builder.CreateCall(f32tof16_rtz, valueHi->getArgOperand(0));
710     }
711     Value* vector = builder.CreateInsertElement(UndefValue::get(halfx2), loVal, builder.getInt32(0));
712     vector = builder.CreateInsertElement(vector, hiVal, builder.getInt32(1));
713     vector = builder.CreateBitCast(vector, lastValue->getType());
714     lastValue->replaceAllUsesWith(vector);
715     lastValue->eraseFromParent();
716 }
717 
visitBfi(llvm::CallInst * inst)718 void CustomSafeOptPass::visitBfi(llvm::CallInst* inst)
719 {
720     IGC_ASSERT(inst->getType()->isIntegerTy(32));
721     ConstantInt* widthV = dyn_cast<ConstantInt>(inst->getOperand(0));
722     ConstantInt* offsetV = dyn_cast<ConstantInt>(inst->getOperand(1));
723     if (widthV && offsetV)
724     {
725         // transformation is beneficial if src3 is constant or if the offset is zero
726         if (isa<ConstantInt>(inst->getOperand(3)) || offsetV->isZero())
727         {
728             unsigned int width = static_cast<unsigned int>(widthV->getZExtValue());
729             unsigned int offset = static_cast<unsigned int>(offsetV->getZExtValue());
730             unsigned int bitMask = ((1 << width) - 1) << offset;
731             IRBuilder<> builder(inst);
732             // dst = ((src2 << offset) & bitmask) | (src3 & ~bitmask)
733             Value* firstTerm = builder.CreateShl(inst->getOperand(2), offsetV);
734             firstTerm = builder.CreateAnd(firstTerm, builder.getInt32(bitMask));
735             Value* secondTerm = builder.CreateAnd(inst->getOperand(3), builder.getInt32(~bitMask));
736             Value* dst = builder.CreateOr(firstTerm, secondTerm);
737             inst->replaceAllUsesWith(dst);
738             inst->eraseFromParent();
739         }
740     }
741     else if (widthV && widthV->isZeroValue())
742     {
743         inst->replaceAllUsesWith(inst->getOperand(3));
744         inst->eraseFromParent();
745     }
746 }
747 
visitMulH(CallInst * inst,bool isSigned)748 void CustomSafeOptPass::visitMulH(CallInst* inst, bool isSigned)
749 {
750     ConstantInt* src0 = dyn_cast<ConstantInt>(inst->getOperand(0));
751     ConstantInt* src1 = dyn_cast<ConstantInt>(inst->getOperand(1));
752     if (src0 && src1)
753     {
754         unsigned nbits = inst->getType()->getIntegerBitWidth();
755         IGC_ASSERT(nbits < 64);
756 
757         if (isSigned)
758         {
759             uint64_t ui0 = src0->getZExtValue();
760             uint64_t ui1 = src1->getZExtValue();
761             uint64_t r = ((ui0 * ui1) >> nbits);
762             inst->replaceAllUsesWith(ConstantInt::get(inst->getType(), r));
763         }
764         else
765         {
766             int64_t si0 = src0->getSExtValue();
767             int64_t si1 = src1->getSExtValue();
768             int64_t r = ((si0 * si1) >> nbits);
769             inst->replaceAllUsesWith(ConstantInt::get(inst->getType(), r, true));
770         }
771         inst->eraseFromParent();
772     }
773 }
774 
775 // if phi is used in a FPTrunc and the sources all come from fpext we can skip the conversions
visitFPTruncInst(FPTruncInst & I)776 void CustomSafeOptPass::visitFPTruncInst(FPTruncInst& I)
777 {
778     if (PHINode * phi = dyn_cast<PHINode>(I.getOperand(0)))
779     {
780         bool foundPattern = true;
781         unsigned int numSrc = phi->getNumIncomingValues();
782         SmallVector<Value*, 6> newSources(numSrc);
783         for (unsigned int i = 0; i < numSrc; i++)
784         {
785             FPExtInst* source = dyn_cast<FPExtInst>(phi->getIncomingValue(i));
786             if (source && source->getOperand(0)->getType() == I.getType())
787             {
788                 newSources[i] = source->getOperand(0);
789             }
790             else
791             {
792                 foundPattern = false;
793                 break;
794             }
795         }
796         if (foundPattern)
797         {
798             PHINode* newPhi = PHINode::Create(I.getType(), numSrc, "", phi);
799             for (unsigned int i = 0; i < numSrc; i++)
800             {
801                 newPhi->addIncoming(newSources[i], phi->getIncomingBlock(i));
802             }
803 
804             I.replaceAllUsesWith(newPhi);
805             I.eraseFromParent();
806             // if phi has other uses we add a fpext to avoid having two phi
807             if (!phi->use_empty())
808             {
809                 IRBuilder<> builder(&(*phi->getParent()->getFirstInsertionPt()));
810                 Value* extV = builder.CreateFPExt(newPhi, phi->getType());
811                 phi->replaceAllUsesWith(extV);
812             }
813         }
814     }
815 
816 }
817 
visitFPToUIInst(llvm::FPToUIInst & FPUII)818 void CustomSafeOptPass::visitFPToUIInst(llvm::FPToUIInst& FPUII)
819 {
820     if (llvm::IntrinsicInst * intrinsicInst = llvm::dyn_cast<llvm::IntrinsicInst>(FPUII.getOperand(0)))
821     {
822         if (intrinsicInst->getIntrinsicID() == Intrinsic::floor)
823         {
824             FPUII.setOperand(0, intrinsicInst->getOperand(0));
825             if (intrinsicInst->use_empty())
826             {
827                 intrinsicInst->eraseFromParent();
828             }
829         }
830     }
831 }
832 
833 /// This remove simplify bitcast across phi and select instruction
834 /// LLVM doesn't catch those case and it is common in DX10+ as the input language is not typed
835 /// TODO: support cases where some sources are constant
visitBitCast(BitCastInst & BC)836 void CustomSafeOptPass::visitBitCast(BitCastInst& BC)
837 {
838     if (SelectInst * sel = dyn_cast<SelectInst>(BC.getOperand(0)))
839     {
840         BitCastInst* trueVal = dyn_cast<BitCastInst>(sel->getTrueValue());
841         BitCastInst* falseVal = dyn_cast<BitCastInst>(sel->getFalseValue());
842         if (trueVal && falseVal)
843         {
844             Value* trueValOrignalType = trueVal->getOperand(0);
845             Value* falseValOrignalType = falseVal->getOperand(0);
846             if (trueValOrignalType->getType() == BC.getType() &&
847                 falseValOrignalType->getType() == BC.getType())
848             {
849                 Value* cond = sel->getCondition();
850                 Value* newVal = SelectInst::Create(cond, trueValOrignalType, falseValOrignalType, "", sel);
851                 BC.replaceAllUsesWith(newVal);
852                 BC.eraseFromParent();
853             }
854         }
855     }
856     else if (PHINode * phi = dyn_cast<PHINode>(BC.getOperand(0)))
857     {
858         if (phi->hasOneUse())
859         {
860             bool foundPattern = true;
861             unsigned int numSrc = phi->getNumIncomingValues();
862             SmallVector<Value*, 6> newSources(numSrc);
863             for (unsigned int i = 0; i < numSrc; i++)
864             {
865                 BitCastInst* source = dyn_cast<BitCastInst>(phi->getIncomingValue(i));
866                 if (source && source->getOperand(0)->getType() == BC.getType())
867                 {
868                     newSources[i] = source->getOperand(0);
869                 }
870                 else if (Constant * C = dyn_cast<Constant>(phi->getIncomingValue(i)))
871                 {
872                     newSources[i] = ConstantExpr::getCast(Instruction::BitCast, C, BC.getType());
873                 }
874                 else
875                 {
876                     foundPattern = false;
877                     break;
878                 }
879             }
880             if (foundPattern)
881             {
882                 PHINode* newPhi = PHINode::Create(BC.getType(), numSrc, "", phi);
883                 for (unsigned int i = 0; i < numSrc; i++)
884                 {
885                     newPhi->addIncoming(newSources[i], phi->getIncomingBlock(i));
886                 }
887                 BC.replaceAllUsesWith(newPhi);
888                 BC.eraseFromParent();
889             }
890         }
891     }
892 }
893 
894 /*
895 Search for DP4A pattern, e.g.:
896 
897 %14 = load i32, i32 addrspace(1)* %13, align 4 <---- c variable (accumulator)
898 // Instructions to be matched:
899 %conv.i.i4 = sext i8 %scalar35 to i32
900 %conv.i7.i = sext i8 %scalar39 to i32
901 %mul.i = mul nsw i32 %conv.i.i4, %conv.i7.i
902 %conv.i6.i = sext i8 %scalar36 to i32
903 %conv.i5.i = sext i8 %scalar40 to i32
904 %mul4.i = mul nsw i32 %conv.i6.i, %conv.i5.i
905 %add.i = add nsw i32 %mul.i, %mul4.i
906 %conv.i4.i = sext i8 %scalar37 to i32
907 %conv.i3.i = sext i8 %scalar41 to i32
908 %mul7.i = mul nsw i32 %conv.i4.i, %conv.i3.i
909 %add8.i = add nsw i32 %add.i, %mul7.i
910 %conv.i2.i = sext i8 %scalar38 to i32
911 %conv.i1.i = sext i8 %scalar42 to i32
912 %mul11.i = mul nsw i32 %conv.i2.i, %conv.i1.i
913 %add12.i = add nsw i32 %add8.i, %mul11.i
914 %add13.i = add nsw i32 %14, %add12.i
915 // end matched instructions
916 store i32 %add13.i, i32 addrspace(1)* %18, align 4
917 
918 =>
919 
920 %14 = load i32, i32 addrspace(1)* %13, align 4
921 %15 = insertelement <4 x i8> undef, i8 %scalar35, i64 0
922 %16 = insertelement <4 x i8> undef, i8 %scalar39, i64 0
923 %17 = insertelement <4 x i8> %15, i8 %scalar36, i64 1
924 %18 = insertelement <4 x i8> %16, i8 %scalar40, i64 1
925 %19 = insertelement <4 x i8> %17, i8 %scalar37, i64 2
926 %20 = insertelement <4 x i8> %18, i8 %scalar41, i64 2
927 %21 = insertelement <4 x i8> %19, i8 %scalar38, i64 3
928 %22 = insertelement <4 x i8> %20, i8 %scalar42, i64 3
929 %23 = bitcast <4 x i8> %21 to i32
930 %24 = bitcast <4 x i8> %22 to i32
931 %25 = call i32 @llvm.genx.GenISA.dp4a.ss.i32(i32 %14, i32 %23, i32 %24)
932 ...
933 store i32 %25, i32 addrspace(1)* %29, align 4
934 */
matchDp4a(BinaryOperator & I)935 void CustomSafeOptPass::matchDp4a(BinaryOperator &I) {
936     using namespace llvm::PatternMatch;
937 
938     if (I.getOpcode() != Instruction::Add) return;
939     CodeGenContext* Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
940     if (!Ctx->platform.hasHWDp4AddSupport()) return;
941 
942     static constexpr int NUM_DP4A_COMPONENTS = 4;
943 
944     // Holds sext/zext instructions.
945     std::array<Instruction *, NUM_DP4A_COMPONENTS> ExtA{}, ExtB{};
946     // Holds i8 values that were multiplied.
947     std::array<Value *, NUM_DP4A_COMPONENTS> ArrA{}, ArrB{};
948     // Holds the accumulator.
949     Value* AccVal = nullptr;
950     IRBuilder<> Builder(&I);
951 
952     // Enum to check if given branch is signed or unsigned, e.g.
953     // comes from SExt or ZExt value.
954     enum class OriginSignedness { originSigned, originUnsigned, originUnknown };
955 
956     // Note: the returned pattern from this lambda doesn't use m_ZExtOrSExt pattern on purpose -
957     // There is no way to bind to both instruction and it's arguments, so we bind to instruction and
958     // check the opcode later.
959     auto getMulPat = [&](int i) { return m_c_Mul(m_Instruction(ExtA[i]), m_Instruction(ExtB[i])); };
960     auto getAdd4Pat = [&](const auto& pat0, const auto& pat1, const auto& pat2, const auto& pat3) {
961       return m_c_Add(m_c_Add(m_c_Add(pat0, pat1), pat2), pat3);
962     };
963     auto getAdd5Pat = [&](const auto& pat0, const auto& pat1, const auto& pat2, const auto& pat3, const auto& pat4) {
964       return m_c_Add(getAdd4Pat(pat0, pat1, pat2, pat3), pat4);
965     };
966     auto PatternAccAtBeginning = getAdd5Pat(m_Value(AccVal), getMulPat(0), getMulPat(1), getMulPat(2), getMulPat(3));
967     auto PatternAccAtEnd = getAdd5Pat(getMulPat(0), getMulPat(1), getMulPat(2), getMulPat(3), m_Value(AccVal));
968     auto PatternNoAcc = getAdd4Pat(getMulPat(0), getMulPat(1), getMulPat(2), getMulPat(3));
969 
970     if (!match(&I, PatternAccAtEnd) && !match(&I, PatternAccAtBeginning)) {
971       if (match(&I, PatternNoAcc)) {
972         AccVal = Builder.getInt32(0);
973       } else {
974         return;
975       }
976     }
977 
978 
979     // Check if values in A and B all have the same extension type (sext/zext)
980     // and that they come from i8 type. A and B extension types can be
981     // different.
982     auto checkIfValuesComeFromCharType = [](auto &range,
983                                             OriginSignedness &retSign) {
984       IGC_ASSERT_MESSAGE((range.begin() != range.end()),
985                          "Cannot check empty collection.");
986 
987       auto shr24Pat = m_Shr(m_Value(), m_SpecificInt(24));
988       auto and255Pat = m_And(m_Value(), m_SpecificInt(255));
989 
990       IGC_ASSERT_MESSAGE(range.size() == NUM_DP4A_COMPONENTS,
991                          "Range too big in dp4a pattern match!");
992       std::array<OriginSignedness, NUM_DP4A_COMPONENTS> signs;
993       int counter = 0;
994       for (auto I : range) {
995         if (!I->getType()->isIntegerTy(32)) {
996           return false;
997         }
998 
999         switch (I->getOpcode()) {
1000         case Instruction::SExt:
1001         case Instruction::ZExt:
1002           if (!I->getOperand(0)->getType()->isIntegerTy(8)) {
1003             return false;
1004           }
1005           signs[counter] = (I->getOpcode() == Instruction::SExt)
1006                                ? OriginSignedness::originSigned
1007                                : OriginSignedness::originUnsigned;
1008           break;
1009         case Instruction::AShr:
1010         case Instruction::LShr:
1011           if (!match(I, shr24Pat)) {
1012             return false;
1013           }
1014           signs[counter] = (I->getOpcode() == Instruction::AShr)
1015                                ? OriginSignedness::originSigned
1016                                : OriginSignedness::originUnsigned;
1017           break;
1018         case Instruction::And:
1019           if (!match(I, and255Pat)) {
1020             return false;
1021           }
1022           signs[counter] = OriginSignedness::originUnsigned;
1023           break;
1024         default:
1025           return false;
1026         }
1027 
1028         counter++;
1029       }
1030 
1031       // check if all have the same sign.
1032       retSign = signs[0];
1033       return std::all_of(
1034           signs.begin(), signs.end(),
1035           [&signs](OriginSignedness v) { return v == signs[0]; });
1036     };
1037 
1038     OriginSignedness ABranch = OriginSignedness::originUnknown, BBranch = OriginSignedness::originUnknown;
1039     bool canMatch = AccVal->getType()->isIntegerTy(32) && checkIfValuesComeFromCharType(ExtA, ABranch) && checkIfValuesComeFromCharType(ExtB, BBranch);
1040     if (!canMatch) return;
1041 
1042     GenISAIntrinsic::ID IntrinsicID{};
1043 
1044     static_assert(ExtA.size() == ArrA.size() && ExtB.size() == ArrB.size() && ExtA.size() == NUM_DP4A_COMPONENTS, "array sizes must match!");
1045 
1046     auto getInt8Origin = [&](Instruction *I) {
1047       if (I->getOpcode() == Instruction::SExt ||
1048           I->getOpcode() == Instruction::ZExt) {
1049         return I->getOperand(0);
1050       }
1051       Builder.SetInsertPoint(I->getNextNode());
1052       return Builder.CreateTrunc(I, Builder.getInt8Ty());
1053     };
1054 
1055     std::transform(ExtA.begin(), ExtA.end(), ArrA.begin(), getInt8Origin);
1056     std::transform(ExtB.begin(), ExtB.end(), ArrB.begin(), getInt8Origin);
1057 
1058     switch (ABranch) {
1059     case OriginSignedness::originSigned:
1060       IntrinsicID = (BBranch == OriginSignedness::originSigned)
1061                         ? GenISAIntrinsic::GenISA_dp4a_ss
1062                         : GenISAIntrinsic::GenISA_dp4a_su;
1063       break;
1064     case OriginSignedness::originUnsigned:
1065       IntrinsicID = (BBranch == OriginSignedness::originSigned)
1066                         ? GenISAIntrinsic::GenISA_dp4a_us
1067                         : GenISAIntrinsic::GenISA_dp4a_uu;
1068       break;
1069     default:
1070       IGC_ASSERT(0);
1071     }
1072 
1073     // Additional optimisation: check if the values come from an ExtractElement instruction.
1074     // If this is the case, reorder the elements in the array to match the ExtractElement pattern.
1075     // This way we avoid potential shufflevector instructions which cause additional mov instructions in final asm.
1076     // Note: indices in ExtractElement do not have to be in range [0-3], they can be greater. We just want to have them ordered ascending.
1077     auto extractElementOrderOpt = [&](std::array<Value*, NUM_DP4A_COMPONENTS> & Arr) {
1078       bool CanOptOrder = true;
1079       llvm::SmallPtrSet<Value*, NUM_DP4A_COMPONENTS> OriginValues;
1080       std::map<int64_t, Value*> IndexMap;
1081       for (int i = 0; i < NUM_DP4A_COMPONENTS; ++i) {
1082         ConstantInt* IndexVal = nullptr;
1083         Value* OriginVal = nullptr;
1084         auto P = m_ExtractElt(m_Value(OriginVal), m_ConstantInt(IndexVal));
1085         if (!match(Arr[i], P)) {
1086           CanOptOrder = false;
1087           break;
1088         }
1089         OriginValues.insert(OriginVal);
1090         IndexMap.insert({ IndexVal->getSExtValue(), Arr[i] });
1091       }
1092 
1093       if (CanOptOrder && OriginValues.size() == 1 && IndexMap.size() == NUM_DP4A_COMPONENTS) {
1094         int i = 0;
1095         for(auto &El : IndexMap) {
1096           Arr[i++] = El.second;
1097         }
1098       }
1099     };
1100     extractElementOrderOpt(ArrA);
1101     extractElementOrderOpt(ArrB);
1102 
1103     Value* VectorA = UndefValue::get(IGCLLVM::FixedVectorType::get(Builder.getInt8Ty(), NUM_DP4A_COMPONENTS));
1104     Value* VectorB = UndefValue::get(IGCLLVM::FixedVectorType::get(Builder.getInt8Ty(), NUM_DP4A_COMPONENTS));
1105     for (int i = 0; i < NUM_DP4A_COMPONENTS; ++i) {
1106       VectorA = Builder.CreateInsertElement(VectorA, ArrA[i], i);
1107       VectorB = Builder.CreateInsertElement(VectorB, ArrB[i], i);
1108     }
1109     Value* ValA = Builder.CreateBitCast(VectorA, Builder.getInt32Ty());
1110     Value* ValB = Builder.CreateBitCast(VectorB, Builder.getInt32Ty());
1111 
1112     Function* Dp4aFun = GenISAIntrinsic::getDeclaration(I.getModule(), IntrinsicID, Builder.getInt32Ty());
1113     Value* Res = Builder.CreateCall(Dp4aFun, { AccVal, ValA, ValB });
1114     I.replaceAllUsesWith(Res);
1115 }
1116 
isEmulatedAdd(BinaryOperator & I)1117 bool CustomSafeOptPass::isEmulatedAdd(BinaryOperator& I)
1118 {
1119     if (I.getOpcode() == Instruction::Or)
1120     {
1121         if (BinaryOperator * OrOp0 = dyn_cast<BinaryOperator>(I.getOperand(0)))
1122         {
1123             if (OrOp0->getOpcode() == Instruction::Shl)
1124             {
1125                 // Check the SHl. If we have a constant Shift Left val then we can check
1126                 // it to see if it is emulating an add.
1127                 if (ConstantInt * pConstShiftLeft = dyn_cast<ConstantInt>(OrOp0->getOperand(1)))
1128                 {
1129                     if (ConstantInt * pConstOrVal = dyn_cast<ConstantInt>(I.getOperand(1)))
1130                     {
1131                         int const_shift = int_cast<int>(pConstShiftLeft->getZExtValue());
1132                         int const_or_val = int_cast<int>(pConstOrVal->getSExtValue());
1133                         if ((1 << const_shift) > abs(const_or_val))
1134                         {
1135                             // The value fits in the shl. So this is an emulated add.
1136                             return true;
1137                         }
1138                     }
1139                 }
1140             }
1141             else if (OrOp0->getOpcode() == Instruction::Mul)
1142             {
1143                 // Check to see if the Or is emulating and add.
1144                 // If we have a constant Mul and a constant Or. The Mul constant needs to be divisible by the rounded up 2^n of Or value.
1145                 if (ConstantInt * pConstMul = dyn_cast<ConstantInt>(OrOp0->getOperand(1)))
1146                 {
1147                     if (ConstantInt * pConstOrVal = dyn_cast<ConstantInt>(I.getOperand(1)))
1148                     {
1149                         if (pConstOrVal->isNegative() == false)
1150                         {
1151                             DWORD const_or_val = int_cast<DWORD>(pConstOrVal->getZExtValue());
1152                             DWORD nextPowerOfTwo = iSTD::RoundPower2(const_or_val + 1);
1153                             if (nextPowerOfTwo && (pConstMul->getZExtValue() % nextPowerOfTwo == 0))
1154                             {
1155                                 return true;
1156                             }
1157                         }
1158                     }
1159                 }
1160             }
1161         }
1162     }
1163     return false;
1164 }
1165 
1166 // Attempt to create new float instruction if both operands are from FPTruncInst instructions.
1167 // Example with fadd:
1168 //  %Temp-31.prec.i = fptrunc float %34 to half
1169 //  %Temp-30.prec.i = fptrunc float %33 to half
1170 //  %41 = fadd fast half %Temp-31.prec.i, %Temp-30.prec.i
1171 //  %Temp-32.i = fpext half %41 to float
1172 //
1173 //  This fadd is used as a float, and doesn't need the operands to be cased to half.
1174 //  We can remove the extra casts in this case.
1175 //  This becomes:
1176 //  %41 = fadd fast float %34, %33
1177 // Can also do matches with fadd/fmul that will later become an mad instruction.
1178 // mad example:
1179 //  %.prec70.i = fptrunc float %273 to half
1180 //  %.prec78.i = fptrunc float %276 to half
1181 //  %279 = fmul fast half %233, %.prec70.i
1182 //  %282 = fadd fast half %279, %.prec78.i
1183 //  %.prec84.i = fpext half %282 to float
1184 // This becomes:
1185 //  %279 = fpext half %233 to float
1186 //  %280 = fmul fast float %273, %279
1187 //  %281 = fadd fast float %280, %276
removeHftoFCast(Instruction & I)1188 void CustomSafeOptPass::removeHftoFCast(Instruction& I)
1189 {
1190     if (!I.getType()->isFloatingPointTy())
1191         return;
1192 
1193     // Check if the only user is a FPExtInst
1194     if (!I.hasOneUse())
1195         return;
1196 
1197     // Check if this instruction is used in a single FPExtInst
1198     FPExtInst* castInst = NULL;
1199     User* U = *I.user_begin();
1200     if (FPExtInst* inst = dyn_cast<FPExtInst>(U))
1201     {
1202         if (inst->getType()->isFloatTy())
1203         {
1204             castInst = inst;
1205         }
1206     }
1207     if (!castInst)
1208       return;
1209 
1210     // Check for fmad pattern
1211     if (I.getOpcode() == Instruction::FAdd)
1212     {
1213         Value* src0 = nullptr, * src1 = nullptr, * src2 = nullptr;
1214 
1215         // CodeGenPatternMatch::MatchMad matches the first fmul.
1216         Instruction* fmulInst = nullptr;
1217         for (uint i = 0; i < 2; i++)
1218         {
1219             fmulInst = dyn_cast<Instruction>(I.getOperand(i));
1220             if (fmulInst && fmulInst->getOpcode() == Instruction::FMul)
1221             {
1222                 src0 = fmulInst->getOperand(0);
1223                 src1 = fmulInst->getOperand(1);
1224                 src2 = I.getOperand(1 - i);
1225                 break;
1226             }
1227             else
1228             {
1229                 // Prevent other non-fmul instructions from getting used
1230                 fmulInst = nullptr;
1231             }
1232         }
1233         if (fmulInst)
1234         {
1235             // Used to get the new float operands for the new instructions
1236             auto getFloatValue = [](Value* operand, Instruction* I, Type* type)
1237             {
1238                 if (FPTruncInst* inst = dyn_cast<FPTruncInst>(operand))
1239                 {
1240                     // Use the float input of the FPTrunc
1241                     if (inst->getOperand(0)->getType()->isFloatTy())
1242                     {
1243                         return inst->getOperand(0);
1244                     }
1245                     else
1246                     {
1247                         return (Value*)NULL;
1248                     }
1249                 }
1250                 else if (Instruction* inst = dyn_cast<Instruction>(operand))
1251                 {
1252                     // Cast the result of this operand to a float
1253                     return dyn_cast<Value>(new FPExtInst(inst, type, "", I));
1254                 }
1255                 return (Value*)NULL;
1256             };
1257 
1258             int convertCount = 0;
1259             if (dyn_cast<FPTruncInst>(src0))
1260                 convertCount++;
1261             if (dyn_cast<FPTruncInst>(src1))
1262                 convertCount++;
1263             if (dyn_cast<FPTruncInst>(src2))
1264                 convertCount++;
1265             if (convertCount >= 2)
1266             {
1267                 // Conversion for the hf values
1268                 auto floatTy = castInst->getType();
1269                 src0 = getFloatValue(src0, fmulInst, floatTy);
1270                 src1 = getFloatValue(src1, fmulInst, floatTy);
1271                 src2 = getFloatValue(src2, &I, floatTy);
1272 
1273                 if (!src0 || !src1 || !src2)
1274                     return;
1275 
1276                 // Create new float fmul and fadd instructions
1277                 Value* newFmul = BinaryOperator::Create(Instruction::FMul, src0, src1, "", &I);
1278                 Value* newFadd = BinaryOperator::Create(Instruction::FAdd, newFmul, src2, "", &I);
1279 
1280                 // Copy fast math flags
1281                 Instruction* fmulInst = dyn_cast<Instruction>(newFmul);
1282                 Instruction* faddInst = dyn_cast<Instruction>(newFadd);
1283                 fmulInst->copyFastMathFlags(fmulInst);
1284                 faddInst->copyFastMathFlags(&I);
1285                 castInst->replaceAllUsesWith(faddInst);
1286                 return;
1287             }
1288         }
1289     }
1290 
1291     // Check if operands come from a Float to HF Cast
1292     Value *S1 = NULL, *S2 = NULL;
1293     if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(0)))
1294     {
1295         if (!inst->getType()->isHalfTy())
1296           return;
1297         S1 = inst->getOperand(0);
1298     }
1299     if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(1)))
1300     {
1301         if (!inst->getType()->isHalfTy())
1302           return;
1303         S2 = inst->getOperand(0);
1304     }
1305     if (!S1 || !S2)
1306     {
1307         return;
1308     }
1309 
1310     Value* newInst = NULL;
1311     if (BinaryOperator* bo = dyn_cast<BinaryOperator>(&I))
1312     {
1313         newInst = BinaryOperator::Create(bo->getOpcode(), S1, S2, "", &I);
1314         Instruction* inst = dyn_cast<Instruction>(newInst);
1315         inst->copyFastMathFlags(&I);
1316         castInst->replaceAllUsesWith(inst);
1317     }
1318 }
1319 
visitBinaryOperator(BinaryOperator & I)1320 void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
1321 {
1322     matchDp4a(I);
1323 
1324     // move immediate value in consecutive integer adds to the last added value.
1325     // this can allow more chance of doing CSE and memopt.
1326     //    a = b + 8
1327     //    d = a + c
1328     //        to
1329     //    a = b + c
1330     //    d = a + 8
1331 
1332     CodeGenContext* pContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
1333 
1334     // Before WA if() as it's validated behavior.
1335     if (I.getType()->isIntegerTy() && I.getOpcode() == Instruction::Or)
1336     {
1337         unsigned int bitWidth = cast<IntegerType>(I.getType())->getBitWidth();
1338         switch (bitWidth)
1339         {
1340         case 16:
1341             matchReverse<unsigned short>(I);
1342             break;
1343         case 32:
1344             matchReverse<unsigned int>(I);
1345             break;
1346         case 64:
1347             matchReverse<unsigned long long>(I);
1348             break;
1349         }
1350     }
1351 
1352     // WA for remaining bug in custom pass
1353     if (pContext->m_DriverInfo.WADisableCustomPass())
1354         return;
1355     if (I.getType()->isIntegerTy())
1356     {
1357         if ((I.getOpcode() == Instruction::Add || isEmulatedAdd(I)) &&
1358             I.hasOneUse())
1359         {
1360             ConstantInt* src0imm = dyn_cast<ConstantInt>(I.getOperand(0));
1361             ConstantInt* src1imm = dyn_cast<ConstantInt>(I.getOperand(1));
1362             if (src0imm || src1imm)
1363             {
1364                 llvm::Instruction* nextInst = llvm::dyn_cast<llvm::Instruction>(*(I.user_begin()));
1365                 if (nextInst && nextInst->getOpcode() == Instruction::Add)
1366                 {
1367                     // do not apply if any add instruction has NSW flag since we can't save it
1368                     if ((isa<OverflowingBinaryOperator>(I) && I.hasNoSignedWrap()) || nextInst->hasNoSignedWrap())
1369                         return;
1370                     ConstantInt* secondSrc0imm = dyn_cast<ConstantInt>(nextInst->getOperand(0));
1371                     ConstantInt* secondSrc1imm = dyn_cast<ConstantInt>(nextInst->getOperand(1));
1372                     // found 2 add instructions to swap srcs
1373                     if (!secondSrc0imm && !secondSrc1imm && nextInst->getOperand(0) != nextInst->getOperand(1))
1374                     {
1375                         for (int i = 0; i < 2; i++)
1376                         {
1377                             if (nextInst->getOperand(i) == &I)
1378                             {
1379                                 Value* newAdd = BinaryOperator::CreateAdd(src0imm ? I.getOperand(1) : I.getOperand(0), nextInst->getOperand(1 - i), "", nextInst);
1380 
1381                                 // Conservatively clear the NSW NUW flags, since they may not be
1382                                 // preserved by the reassociation.
1383                                 bool IsNUW = isa<OverflowingBinaryOperator>(I) && I.hasNoUnsignedWrap() && nextInst->hasNoUnsignedWrap();
1384                                 cast<BinaryOperator>(newAdd)->setHasNoUnsignedWrap(IsNUW);
1385                                 nextInst->setHasNoUnsignedWrap(IsNUW);
1386                                 nextInst->setHasNoSignedWrap(false);
1387 
1388                                 nextInst->setOperand(0, newAdd);
1389                                 nextInst->setOperand(1, I.getOperand(src0imm ? 0 : 1));
1390                                 break;
1391                             }
1392                         }
1393                     }
1394                 }
1395             }
1396         }
1397     } else if (I.getType()->isFloatingPointTy()) {
1398         removeHftoFCast(I);
1399     }
1400 }
1401 
visitLdptr(llvm::SamplerLoadIntrinsic * inst)1402 void IGC::CustomSafeOptPass::visitLdptr(llvm::SamplerLoadIntrinsic* inst)
1403 {
1404     if (!IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTextures) &&
1405         !IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTypedBuffers))
1406     {
1407         return;
1408     }
1409 
1410     // change
1411     // % 10 = call fast <4 x float> @llvm.genx.GenISA.ldptr.v4f32.p196608v4f32(i32 %_s1.i, i32 %_s14.i, i32 0, i32 0, <4 x float> addrspace(196608)* null, i32 0, i32 0, i32 0), !dbg !123
1412     // to
1413     // % 10 = call fast <4 x float> @llvm.genx.GenISA.typedread.p196608v4f32(<4 x float> addrspace(196608)* null, i32 %_s1.i, i32 %_s14.i, i32 0, i32 0), !dbg !123
1414     // when the index comes directly from threadid
1415 
1416     Constant* src1 = dyn_cast<Constant>(inst->getOperand(1));
1417     Constant* src2 = dyn_cast<Constant>(inst->getOperand(2));
1418     Constant* src3 = dyn_cast<Constant>(inst->getOperand(3));
1419 
1420     // src2 and src3 has to be zero
1421     if (!src2 || !src3 || !src2->isZeroValue() || !src3->isZeroValue())
1422     {
1423         return;
1424     }
1425 
1426     // if only doing the opt on buffers, make sure src1 is zero too
1427     if (!IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTextures) &&
1428         IGC_IS_FLAG_ENABLED(UseHDCTypedReadForAllTypedBuffers))
1429     {
1430         if (!src1 || !src1->isZeroValue())
1431             return;
1432     }
1433 
1434     // do the transformation
1435     llvm::IRBuilder<> builder(inst);
1436     Module* M = inst->getParent()->getParent()->getParent();
1437 
1438     Function* pLdIntrinsic = llvm::GenISAIntrinsic::getDeclaration(
1439         M,
1440         GenISAIntrinsic::GenISA_typedread,
1441         inst->getOperand(4)->getType());
1442 
1443     SmallVector<Value*, 5> ld_FunctionArgList(5);
1444     ld_FunctionArgList[0] = inst->getTextureValue();
1445     ld_FunctionArgList[1] = builder.CreateAdd(inst->getOperand(0), inst->getOperand(inst->getTextureIndex() + 1));
1446     ld_FunctionArgList[2] = builder.CreateAdd(inst->getOperand(1), inst->getOperand(inst->getTextureIndex() + 2));
1447     ld_FunctionArgList[3] = builder.CreateAdd(inst->getOperand(3), inst->getOperand(inst->getTextureIndex() + 3));
1448     ld_FunctionArgList[4] = inst->getOperand(2);  // lod=zero
1449 
1450     llvm::CallInst* pNewCallInst = builder.CreateCall(
1451         pLdIntrinsic, ld_FunctionArgList);
1452 
1453     // as typedread returns float4 by default, bitcast it
1454     // to int4 if necessary
1455     // FIXME: is it better to make typedRead return ty a anyvector?
1456     if (inst->getType() != pNewCallInst->getType())
1457     {
1458         IGC_ASSERT_MESSAGE(inst->getType()->isVectorTy(), "expect int4 here");
1459         IGC_ASSERT_MESSAGE(cast<IGCLLVM::FixedVectorType>(inst->getType())
1460                                ->getElementType()
1461                                ->isIntegerTy(32),
1462                            "expect int4 here");
1463         IGC_ASSERT_MESSAGE(
1464             cast<IGCLLVM::FixedVectorType>(inst->getType())->getNumElements() ==
1465                 4,
1466             "expect int4 here");
1467         auto bitCastInst = builder.CreateBitCast(pNewCallInst, inst->getType());
1468         inst->replaceAllUsesWith(bitCastInst);
1469     }
1470     else
1471     {
1472         inst->replaceAllUsesWith(pNewCallInst);
1473     }
1474 }
1475 
1476 
visitLdRawVec(llvm::CallInst * inst)1477 void IGC::CustomSafeOptPass::visitLdRawVec(llvm::CallInst* inst)
1478 {
1479     //Try to optimize and remove vector ld raw and change to scalar ld raw
1480 
1481     //%a = call <4 x float> @llvm.genx.GenISA.ldrawvector.indexed.v4f32.p1441792f32(
1482     //.....float addrspace(1441792) * %243, i32 %offset, i32 4, i1 false), !dbg !216
1483     //%b = extractelement <4 x float> % 245, i32 0, !dbg !216
1484 
1485     //into
1486 
1487     //%new_offset = add i32 %offset, 0, !dbg !216
1488     //%b = call float @llvm.genx.GenISA.ldraw.indexed.f32.p1441792f32.i32.i32.i1(
1489     //.....float addrspace(1441792) * %251, i32 %new_offset, i32 4, i1 false)
1490 
1491     if (inst->hasOneUse() &&
1492         isa<ExtractElementInst>(inst->user_back()))
1493     {
1494         auto EE = cast<ExtractElementInst>(inst->user_back());
1495         if (auto constIndex = dyn_cast<ConstantInt>(EE->getIndexOperand()))
1496         {
1497             llvm::IRBuilder<> builder(inst);
1498 
1499             llvm::SmallVector<llvm::Type*, 2> ovldtypes{
1500                 EE->getType(), //float type
1501                 inst->getOperand(0)->getType(),
1502             };
1503 
1504             // For new_offset we need to take into acount the index of the Extract
1505             // and convert it to bytes and add it to the existing offset
1506             auto new_offset = constIndex->getZExtValue() * 4;
1507 
1508             llvm::SmallVector<llvm::Value*, 4> new_args{
1509                 inst->getOperand(0),
1510                 builder.CreateAdd(inst->getOperand(1),builder.getInt32((unsigned)new_offset)),
1511                 inst->getOperand(2),
1512                 inst->getOperand(3)
1513             };
1514 
1515             Function* pLdraw_indexed_intrinsic = llvm::GenISAIntrinsic::getDeclaration(
1516                 inst->getModule(),
1517                 GenISAIntrinsic::GenISA_ldraw_indexed,
1518                 ovldtypes);
1519 
1520             llvm::Value* ldraw_indexed = builder.CreateCall(pLdraw_indexed_intrinsic, new_args, "");
1521             EE->replaceAllUsesWith(ldraw_indexed);
1522         }
1523     }
1524 }
1525 
visitSampleBptr(llvm::SampleIntrinsic * sampleInst)1526 void IGC::CustomSafeOptPass::visitSampleBptr(llvm::SampleIntrinsic* sampleInst)
1527 {
1528     // sampleB with bias_value==0 -> sample
1529     llvm::ConstantFP* constBias = llvm::dyn_cast<llvm::ConstantFP>(sampleInst->getOperand(0));
1530     if (constBias && constBias->isZero())
1531     {
1532         // Copy args skipping bias operand:
1533         llvm::SmallVector<llvm::Value*, 10> args;
1534         for (unsigned int i = 1; i < sampleInst->getNumArgOperands(); i++)
1535         {
1536             args.push_back(sampleInst->getArgOperand(i));
1537         }
1538 
1539         // Copy overloaded types unchanged:
1540         llvm::SmallVector<llvm::Type*, 4> overloadedTys;
1541         overloadedTys.push_back(sampleInst->getCalledFunction()->getReturnType());
1542         overloadedTys.push_back(sampleInst->getOperand(0)->getType());
1543         overloadedTys.push_back(sampleInst->getTextureValue()->getType());
1544         overloadedTys.push_back(sampleInst->getSamplerValue()->getType());
1545 
1546         llvm::Function* sampleIntr = llvm::GenISAIntrinsic::getDeclaration(
1547             sampleInst->getParent()->getParent()->getParent(),
1548             GenISAIntrinsic::GenISA_sampleptr,
1549             overloadedTys);
1550 
1551         llvm::Value* newSample = llvm::CallInst::Create(sampleIntr, args, "", sampleInst);
1552         sampleInst->replaceAllUsesWith(newSample);
1553     }
1554 }
1555 
isIdentityMatrix(ExtractElementInst & I)1556 bool CustomSafeOptPass::isIdentityMatrix(ExtractElementInst& I)
1557 {
1558     bool found = false;
1559     auto extractType = cast<IGCLLVM::FixedVectorType>(I.getVectorOperandType());
1560     auto extractTypeVecSize = (uint32_t)extractType->getNumElements();
1561     if (extractTypeVecSize == 20 ||
1562         extractTypeVecSize == 16)
1563     {
1564         if (Constant * C = dyn_cast<Constant>(I.getVectorOperand()))
1565         {
1566             found = true;
1567 
1568             // found = true if the extractelement is like this:
1569             // %189 = extractelement <20 x float>
1570             //    <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00,
1571             //     float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00,
1572             //     float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00,
1573             //     float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00,
1574             //     float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %188
1575             for (unsigned int i = 0; i < extractTypeVecSize; i++)
1576             {
1577                 if (i == 0 || i == 5 || i == 10 || i == 15)
1578                 {
1579                     ConstantFP* fpC = dyn_cast<ConstantFP>(C->getAggregateElement(i));
1580                     ConstantInt* intC = dyn_cast<ConstantInt>(C->getAggregateElement(i));
1581                     if((fpC && !fpC->isExactlyValue(1.f)) || (intC && !intC->isAllOnesValue()))
1582                     {
1583                         found = false;
1584                         break;
1585                     }
1586                 }
1587                 else if (!C->getAggregateElement(i)->isZeroValue())
1588                 {
1589                     found = false;
1590                     break;
1591                 }
1592             }
1593         }
1594     }
1595     return found;
1596 }
1597 
dp4WithIdentityMatrix(ExtractElementInst & I)1598 void CustomSafeOptPass::dp4WithIdentityMatrix(ExtractElementInst& I)
1599 {
1600     /*
1601     convert dp4 with identity matrix icb ( ex: "dp4 r[6].x, cb[2][8].xyzw, icb[ r[4].w].xyzw") from
1602         %189 = shl nuw nsw i32 %188, 2, !dbg !326
1603         %190 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %189, !dbg !326
1604         %191 = or i32 %189, 1, !dbg !326
1605         %192 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %191, !dbg !326
1606         %193 = or i32 %189, 2, !dbg !326
1607         %194 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %193, !dbg !326
1608         %195 = or i32 %189, 3, !dbg !326
1609         %196 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>, i32 %195, !dbg !326
1610         %s01_s.chan0141 = fmul fast float %181, %190, !dbg !326
1611         %197 = fmul fast float %182, %192, !dbg !326
1612         %198 = fadd fast float %197, %s01_s.chan0141, !dbg !326
1613         %199 = fmul fast float %183, %194, !dbg !326
1614         %200 = fadd fast float %199, %198, !dbg !326
1615         %201 = fmul fast float %184, %196, !dbg !326
1616         %202 = fadd fast float %201, %200, !dbg !326
1617     to
1618         %533 = icmp eq i32 %532, 0, !dbg !434
1619         %534 = icmp eq i32 %532, 1, !dbg !434
1620         %535 = icmp eq i32 %532, 2, !dbg !434
1621         %536 = icmp eq i32 %532, 3, !dbg !434
1622         %537 = select i1 %533, float %525, float 0.000000e+00, !dbg !434
1623         %538 = select i1 %534, float %526, float %537, !dbg !434
1624         %539 = select i1 %535, float %527, float %538, !dbg !434
1625         %540 = select i1 %536, float %528, float %539, !dbg !434
1626     */
1627 
1628     // check %190 = extractelement <20 x float> <float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00 ...
1629     if (!I.hasOneUse() || !isIdentityMatrix(I))
1630         return;
1631 
1632     Instruction* offset[4] = {nullptr, nullptr, nullptr, nullptr};
1633     ExtractElementInst* eeInst[4] = { &I, nullptr, nullptr, nullptr };
1634 
1635     // check %189 = shl nuw nsw i32 %188, 2, !dbg !326
1636     // put it in offset[0]
1637     offset[0] = dyn_cast<BinaryOperator>(I.getOperand(1));
1638     if (!offset[0] ||
1639         offset[0]->getOpcode() != Instruction::Shl ||
1640         offset[0]->getOperand(1) != ConstantInt::get(offset[0]->getOperand(1)->getType(), 2))
1641     {
1642         return;
1643     }
1644 
1645     // check %191 = or i32 %189, 1, !dbg !326
1646     //       %193 = or i32 % 189, 2, !dbg !326
1647     //       %195 = or i32 % 189, 3, !dbg !326
1648     // put them in offset[1], offset[2], offset[3]
1649     for (auto iter = offset[0]->user_begin(); iter != offset[0]->user_end(); iter++)
1650     {
1651         // skip checking for the %190 = extractelement <20 x float> <float 1.000000e+00, ....
1652         if (*iter == &I)
1653             continue;
1654 
1655         if (BinaryOperator * orInst = dyn_cast<BinaryOperator>(*iter))
1656         {
1657             if (orInst->getOpcode() == BinaryOperator::Or && orInst->hasOneUse())
1658             {
1659                 if (ConstantInt * orSrc1 = dyn_cast<ConstantInt>(orInst->getOperand(1)))
1660                 {
1661                     if (orSrc1->getZExtValue() < 4)
1662                     {
1663                         offset[orSrc1->getZExtValue()] = orInst;
1664                     }
1665                 }
1666             }
1667         }
1668     }
1669 
1670     for (int i = 0; i < 4; i++)
1671     {
1672         if (offset[i] == nullptr)
1673             return;
1674     }
1675 
1676     // check %192 = extractelement <20 x float> ...
1677     //       %194 = extractelement <20 x float> ...
1678     //       %196 = extractelement <20 x float> ...
1679     // put them in eeInst[i]
1680     for (int i = 1; i < 4; i++)
1681     {
1682         eeInst[i] = dyn_cast<ExtractElementInst>(*offset[i]->user_begin());
1683         if (!eeInst[i] || !isIdentityMatrix(*eeInst[i]))
1684         {
1685             return;
1686         }
1687     }
1688 
1689     // check dp4 and put them in mulI[] and addI[]
1690     Instruction* mulI[4] = { nullptr, nullptr, nullptr, nullptr };
1691     for (int i = 0; i < 4; i++)
1692     {
1693         mulI[i] = dyn_cast<Instruction>(*eeInst[i]->user_begin());
1694         if (mulI[i] == nullptr || !mulI[i]->hasOneUse())
1695         {
1696             return;
1697         }
1698     }
1699     int inputInSrcIndex = 0;
1700     if (mulI[0]->getOperand(0) == eeInst[0])
1701         inputInSrcIndex = 1;
1702 
1703     // the 1st and 2nd mul are the srcs for add
1704     if (*mulI[0]->user_begin() != *mulI[1]->user_begin())
1705     {
1706         return;
1707     }
1708     Instruction* addI[3] = { nullptr, nullptr, nullptr };
1709     addI[0] = dyn_cast<Instruction>(*mulI[0]->user_begin());
1710     addI[1] = dyn_cast<Instruction>(*mulI[2]->user_begin());
1711     addI[2] = dyn_cast<Instruction>(*mulI[3]->user_begin());
1712 
1713     if( addI[0] == nullptr ||
1714         addI[1] == nullptr ||
1715         addI[2] == nullptr ||
1716         !addI[0]->hasOneUse() ||
1717         !addI[1]->hasOneUse() ||
1718         *addI[0]->user_begin() != *mulI[2]->user_begin() ||
1719         *addI[1]->user_begin() != *mulI[3]->user_begin())
1720     {
1721         return;
1722     }
1723 
1724     // start the conversion
1725     IRBuilder<> builder(addI[2]);
1726 
1727     Value* cond0 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 0));
1728     Value* cond1 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 1));
1729     Value* cond2 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 2));
1730     Value* cond3 = builder.CreateICmp(ICmpInst::ICMP_EQ, offset[0]->getOperand(0), ConstantInt::get(offset[0]->getOperand(0)->getType(), 3));
1731 
1732     Value* zero = ConstantFP::get(Type::getFloatTy(I.getContext()), 0);
1733     Value* sel0 = builder.CreateSelect(cond0, mulI[0]->getOperand(inputInSrcIndex), zero);
1734     Value* sel1 = builder.CreateSelect(cond1, mulI[1]->getOperand(inputInSrcIndex), sel0);
1735     Value* sel2 = builder.CreateSelect(cond2, mulI[2]->getOperand(inputInSrcIndex), sel1);
1736     Value* sel3 = builder.CreateSelect(cond3, mulI[3]->getOperand(inputInSrcIndex), sel2);
1737 
1738     addI[2]->replaceAllUsesWith(sel3);
1739 }
1740 
1741 
visitExtractElementInst(ExtractElementInst & I)1742 void CustomSafeOptPass::visitExtractElementInst(ExtractElementInst& I)
1743 {
1744     // convert:
1745     // %1 = lshr i32 %0, 16,
1746     // %2 = bitcast i32 %1 to <2 x half>
1747     // %3 = extractelement <2 x half> %2, i32 0
1748     // to ->
1749     // %2 = bitcast i32 %0 to <2 x half>
1750     // %3 = extractelement <2 x half> %2, i32 1
1751     BitCastInst* bitCast = dyn_cast<BitCastInst>(I.getVectorOperand());
1752     ConstantInt* cstIndex = dyn_cast<ConstantInt>(I.getIndexOperand());
1753     if (bitCast && cstIndex)
1754     {
1755         // skip intermediate bitcast
1756         while (isa<BitCastInst>(bitCast->getOperand(0)))
1757         {
1758             bitCast = cast<BitCastInst>(bitCast->getOperand(0));
1759         }
1760         if (BinaryOperator * binOp = dyn_cast<BinaryOperator>(bitCast->getOperand(0)))
1761         {
1762             unsigned int bitShift = 0;
1763             bool rightShift = false;
1764             if (binOp->getOpcode() == Instruction::LShr)
1765             {
1766                 if (ConstantInt * cstShift = dyn_cast<ConstantInt>(binOp->getOperand(1)))
1767                 {
1768                     bitShift = (unsigned int)cstShift->getZExtValue();
1769                     rightShift = true;
1770                 }
1771             }
1772             else if (binOp->getOpcode() == Instruction::Shl)
1773             {
1774                 if (ConstantInt * cstShift = dyn_cast<ConstantInt>(binOp->getOperand(1)))
1775                 {
1776                     bitShift = (unsigned int)cstShift->getZExtValue();
1777                 }
1778             }
1779             if (bitShift != 0)
1780             {
1781                 Type* vecType = I.getVectorOperand()->getType();
1782                 unsigned int eltSize = (unsigned int)cast<VectorType>(vecType)->getElementType()->getPrimitiveSizeInBits();
1783                 if (bitShift % eltSize == 0)
1784                 {
1785                     int elOffset = (int)(bitShift / eltSize);
1786                     elOffset = rightShift ? elOffset : -elOffset;
1787                     unsigned int newIndex = (unsigned int)((int)cstIndex->getZExtValue() + elOffset);
1788                     if (newIndex < cast<IGCLLVM::FixedVectorType>(vecType)->getNumElements())
1789                     {
1790                         IRBuilder<> builder(&I);
1791                         Value* newBitCast = builder.CreateBitCast(binOp->getOperand(0), vecType);
1792                         Value* newScalar = builder.CreateExtractElement(newBitCast, newIndex);
1793                         I.replaceAllUsesWith(newScalar);
1794                         I.eraseFromParent();
1795                         return;
1796                     }
1797                 }
1798             }
1799         }
1800     }
1801 
1802     dp4WithIdentityMatrix(I);
1803 }
1804 
1805 #if LLVM_VERSION_MAJOR >= 7
1806 ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
1807 // This pass removes dead local memory loads and stores. If we remove all such loads and stores, we also
1808 // remove all local memory fences together with barriers that follow.
1809 //
1810 IGC_INITIALIZE_PASS_BEGIN(TrivialLocalMemoryOpsElimination, "TrivialLocalMemoryOpsElimination", "TrivialLocalMemoryOpsElimination", false, false)
1811 IGC_INITIALIZE_PASS_END(TrivialLocalMemoryOpsElimination, "TrivialLocalMemoryOpsElimination", "TrivialLocalMemoryOpsElimination", false, false)
1812 
1813 char TrivialLocalMemoryOpsElimination::ID = 0;
1814 
TrivialLocalMemoryOpsElimination()1815 TrivialLocalMemoryOpsElimination::TrivialLocalMemoryOpsElimination() : FunctionPass(ID)
1816 {
1817     initializeTrivialLocalMemoryOpsEliminationPass(*PassRegistry::getPassRegistry());
1818 }
1819 
runOnFunction(Function & F)1820 bool TrivialLocalMemoryOpsElimination::runOnFunction(Function& F)
1821 {
1822     bool change = false;
1823 
1824     IGCMD::MetaDataUtils* pMdUtil = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
1825     if (!isEntryFunc(pMdUtil, &F))
1826     {
1827         // Skip if it is non-entry function.  For example, a subroutine
1828         //   foo ( local int* p) { ...... store v, p; ......}
1829         // in which no localMemoptimization will be performed.
1830         return change;
1831     }
1832 
1833     visit(F);
1834     if (!abortPass && (m_LocalLoadsToRemove.empty() ^ m_LocalStoresToRemove.empty()))
1835     {
1836         for (StoreInst* Inst : m_LocalStoresToRemove)
1837         {
1838             Inst->eraseFromParent();
1839             change = true;
1840         }
1841 
1842         for (LoadInst* Inst : m_LocalLoadsToRemove)
1843         {
1844             if (Inst->use_empty())
1845             {
1846                 Inst->eraseFromParent();
1847                 change = true;
1848             }
1849         }
1850 
1851         for (CallInst* Inst : m_LocalFencesBariersToRemove)
1852         {
1853             Inst->eraseFromParent();
1854             change = true;
1855         }
1856     }
1857     m_LocalStoresToRemove.clear();
1858     m_LocalLoadsToRemove.clear();
1859     m_LocalFencesBariersToRemove.clear();
1860 
1861 
1862     return change;
1863 }
1864 
/*
The OCL instruction barrier(CLK_LOCAL_MEM_FENCE); is translated into two instructions:
call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true)
call void @llvm.genx.GenISA.threadgroupbarrier()

If we remove call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true),
we must also remove the next instruction if it is call void @llvm.genx.GenISA.threadgroupbarrier().
*/
findNextThreadGroupBarrierInst(Instruction & I)1873 void TrivialLocalMemoryOpsElimination::findNextThreadGroupBarrierInst(Instruction& I)
1874 {
1875     auto nextInst = I.getNextNonDebugInstruction();
1876     if (isa<GenIntrinsicInst>(nextInst))
1877     {
1878         GenIntrinsicInst* II = dyn_cast<GenIntrinsicInst>(nextInst);
1879         if (II->getIntrinsicID() == GenISAIntrinsic::GenISA_threadgroupbarrier)
1880         {
1881             m_LocalFencesBariersToRemove.push_back(dyn_cast<CallInst>(nextInst));
1882         }
1883     }
1884 }
1885 
visitLoadInst(LoadInst & I)1886 void TrivialLocalMemoryOpsElimination::visitLoadInst(LoadInst& I)
1887 {
1888     if (I.getPointerAddressSpace() == ADDRESS_SPACE_LOCAL)
1889     {
1890         m_LocalLoadsToRemove.push_back(&I);
1891     }
1892     else if (I.getPointerAddressSpace() == ADDRESS_SPACE_GENERIC)
1893     {
1894         abortPass = true;
1895     }
1896 }
1897 
visitStoreInst(StoreInst & I)1898 void TrivialLocalMemoryOpsElimination::visitStoreInst(StoreInst& I)
1899 {
1900     if (I.getPointerAddressSpace() == ADDRESS_SPACE_LOCAL)
1901     {
1902         m_LocalStoresToRemove.push_back(&I);
1903     }
1904     else if (I.getPointerAddressSpace() == ADDRESS_SPACE_GENERIC)
1905     {
1906         abortPass = true;
1907     }
1908 }
1909 
isLocalBarrier(CallInst & I)1910 bool TrivialLocalMemoryOpsElimination::isLocalBarrier(CallInst& I)
1911 {
1912     //check arguments in call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true) if match to
1913     // (i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true) it is local barrier
1914     std::vector<bool> argumentsOfMemoryBarrier;
1915 
1916     for (auto arg = I.arg_begin(); arg != I.arg_end(); ++arg)
1917     {
1918         ConstantInt* ci = dyn_cast<ConstantInt>(arg);
1919         if (ci) {
1920             argumentsOfMemoryBarrier.push_back(ci->getValue().getBoolValue());
1921         }
1922         else {
1923             // argument is not a constant, so we can't tell.
1924             return false;
1925         }
1926     }
1927 
1928     return argumentsOfMemoryBarrier == m_argumentsOfLocalMemoryBarrier;
1929 }
1930 
1931 // If any call instruction use pointer to local memory abort pass execution
anyCallInstUseLocalMemory(CallInst & I)1932 void TrivialLocalMemoryOpsElimination::anyCallInstUseLocalMemory(CallInst& I)
1933 {
1934     Function* fn = I.getCalledFunction();
1935 
1936     if (fn != NULL)
1937     {
1938         for (auto arg = fn->arg_begin(); arg != fn->arg_end(); ++arg)
1939         {
1940             if (arg->getType()->isPointerTy())
1941             {
1942                 if (arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_LOCAL || arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GENERIC) abortPass = true;
1943             }
1944         }
1945     }
1946 }
1947 
visitCallInst(CallInst & I)1948 void TrivialLocalMemoryOpsElimination::visitCallInst(CallInst& I)
1949 {
1950     // detect only: llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true)
1951     // (note: the first and last arguments are true)
1952     // and add them with immediately following barriers to m_LocalFencesBariersToRemove
1953     anyCallInstUseLocalMemory(I);
1954 
1955     if (isa<GenIntrinsicInst>(I))
1956     {
1957         GenIntrinsicInst* II = dyn_cast<GenIntrinsicInst>(&I);
1958         if (II->getIntrinsicID() == GenISAIntrinsic::GenISA_memoryfence)
1959         {
1960             if (isLocalBarrier(I))
1961             {
1962                 m_LocalFencesBariersToRemove.push_back(&I);
1963                 findNextThreadGroupBarrierInst(I);
1964             }
1965         }
1966     }
1967 
1968 }
1969 #endif
1970 
1971 // Register pass to igc-opt
1972 #define PASS_FLAG2 "igc-gen-specific-pattern"
1973 #define PASS_DESCRIPTION2 "LastPatternMatch Pass"
1974 #define PASS_CFG_ONLY2 false
1975 #define PASS_ANALYSIS2 false
1976 IGC_INITIALIZE_PASS_BEGIN(GenSpecificPattern, PASS_FLAG2, PASS_DESCRIPTION2, PASS_CFG_ONLY2, PASS_ANALYSIS2)
1977 IGC_INITIALIZE_PASS_END(GenSpecificPattern, PASS_FLAG2, PASS_DESCRIPTION2, PASS_CFG_ONLY2, PASS_ANALYSIS2)
1978 
1979 char GenSpecificPattern::ID = 0;
1980 
GenSpecificPattern()1981 GenSpecificPattern::GenSpecificPattern() : FunctionPass(ID)
1982 {
1983     initializeGenSpecificPatternPass(*PassRegistry::getPassRegistry());
1984 }
1985 
runOnFunction(Function & F)1986 bool GenSpecificPattern::runOnFunction(Function& F)
1987 {
1988     visit(F);
1989     return true;
1990 }
1991 
1992 // Lower SDiv to better code sequence if possible
visitSDiv(llvm::BinaryOperator & I)1993 void GenSpecificPattern::visitSDiv(llvm::BinaryOperator& I)
1994 {
1995     if (ConstantInt * divisor = dyn_cast<ConstantInt>(I.getOperand(1)))
1996     {
1997         // signed division of power of 2 can be transformed to asr
1998         // For negative values we need to make sure we round correctly
1999         int log2Div = divisor->getValue().exactLogBase2();
2000         if (log2Div > 0)
2001         {
2002             unsigned int intWidth = divisor->getBitWidth();
2003             IRBuilder<> builder(&I);
2004             Value* signedBitOnly = I.getOperand(0);
2005             if (log2Div > 1)
2006             {
2007                 signedBitOnly = builder.CreateAShr(signedBitOnly, builder.getIntN(intWidth, intWidth - 1));
2008             }
2009             Value* offset = builder.CreateLShr(signedBitOnly, builder.getIntN(intWidth, intWidth - log2Div));
2010             Value* offsetedValue = builder.CreateAdd(I.getOperand(0), offset);
2011             Value* newValue = builder.CreateAShr(offsetedValue, builder.getIntN(intWidth, log2Div));
2012             I.replaceAllUsesWith(newValue);
2013             I.eraseFromParent();
2014         }
2015     }
2016 }
2017 
2018 /*
2019 Optimizes bit reversing pattern:
2020 
2021 %and = shl i32 %0, 1
2022 %shl = and i32 %and, 0xAAAAAAAA
2023 %and2 = lshr i32 %0, 1
2024 %shr = and i32 %and2, 0x55555555
2025 %or = or i32 %shl, %shr
2026 %and3 = shl i32 %or, 2
2027 %shl4 = and i32 %and3, 0xCCCCCCCC
2028 %and5 = lshr i32 %or, 2
2029 %shr6 = and i32 %and5, 0x33333333
2030 %or7 = or i32 %shl4, %shr6
2031 %and8 = shl i32 %or7, 4
2032 %shl9 = and i32 %and8, 0xF0F0F0F0
2033 %and10 = lshr i32 %or7, 4
2034 %shr11 = and i32 %and10, 0x0F0F0F0F
2035 %or12 = or i32 %shl9, %shr11
2036 %and13 = shl i32 %or12, 8
2037 %shl14 = and i32 %and13, 0xFF00FF00
2038 %and15 = lshr i32 %or12, 8
2039 %shr16 = and i32 %and15, 0x00FF00FF
2040 %or17 = or i32 %shl14, %shr16
2041 %shl19 = shl i32 %or17, 16
2042 %shr21 = lshr i32 %or17, 16
2043 %or22 = or i32 %shl19, %shr21
2044 
2045 into:
2046 
2047 %or22 = call i32 @llvm.genx.GenISA.bfrev.i32(i32 %0)
2048 
2049 And similarly for patterns reversing 16 and 64 bit type values.
2050 */
2051 template <typename MaskType>
matchReverse(BinaryOperator & I)2052 void CustomSafeOptPass::matchReverse(BinaryOperator& I)
2053 {
2054     using namespace llvm::PatternMatch;
2055     IGC_ASSERT(I.getType()->isIntegerTy());
2056     Value* nextOrShl = nullptr, * nextOrShr = nullptr;
2057     uint64_t currentShiftShl = 0, currentShiftShr = 0;
2058     uint64_t currentMaskShl = 0, currentMaskShr = 0;
2059     auto patternBfrevFirst =
2060         m_Or(
2061             m_Shl(m_Value(nextOrShl), m_ConstantInt(currentShiftShl)),
2062             m_LShr(m_Value(nextOrShr), m_ConstantInt(currentShiftShr)));
2063 
2064     auto patternBfrev =
2065         m_Or(
2066             m_And(
2067                 m_Shl(m_Value(nextOrShl), m_ConstantInt(currentShiftShl)),
2068                 m_ConstantInt(currentMaskShl)),
2069             m_And(
2070                 m_LShr(m_Value(nextOrShr), m_ConstantInt(currentShiftShr)),
2071                 m_ConstantInt(currentMaskShr)));
2072 
2073     unsigned int bitWidth = std::numeric_limits<MaskType>::digits;
2074     IGC_ASSERT(bitWidth == 16 || bitWidth == 32 || bitWidth == 64);
2075 
2076     unsigned int currentShift = bitWidth / 2;
2077     // First mask is a value with all upper half bits present.
2078     MaskType mask = std::numeric_limits<MaskType>::max() << currentShift;
2079 
2080     bool isBfrevMatchFound = false;
2081     nextOrShl = &I;
2082     if (match(nextOrShl, patternBfrevFirst) &&
2083         nextOrShl == nextOrShr &&
2084         currentShiftShl == currentShift &&
2085         currentShiftShr == currentShift)
2086     {
2087         // NextOrShl is assigned to next one by match().
2088         currentShift /= 2;
2089         // Constructing next mask to match.
2090         mask ^= mask >> currentShift;
2091     }
2092 
2093     while (currentShift > 0)
2094     {
2095         if (match(nextOrShl, patternBfrev) &&
2096             nextOrShl == nextOrShr &&
2097             currentShiftShl == currentShift &&
2098             currentShiftShr == currentShift &&
2099             currentMaskShl == mask &&
2100             currentMaskShr == (MaskType)~mask)
2101         {
2102             // NextOrShl is assigned to next one by match().
2103             if (currentShift == 1)
2104             {
2105                 isBfrevMatchFound = true;
2106                 break;
2107             }
2108 
2109             currentShift /= 2;
2110             // Constructing next mask to match.
2111             mask ^= mask >> currentShift;
2112         }
2113         else
2114         {
2115             break;
2116         }
2117     }
2118 
2119     if (isBfrevMatchFound)
2120     {
2121         llvm::IRBuilder<> builder(&I);
2122         Function* bfrevFunc = GenISAIntrinsic::getDeclaration(
2123             I.getParent()->getParent()->getParent(), GenISAIntrinsic::GenISA_bfrev, builder.getInt32Ty());
2124         if (bitWidth == 16)
2125         {
2126             Value* zext = builder.CreateZExt(nextOrShl, builder.getInt32Ty());
2127             Value* bfrev = builder.CreateCall(bfrevFunc, zext);
2128             Value* lshr = builder.CreateLShr(bfrev, 16);
2129             Value* trunc = builder.CreateTrunc(lshr, I.getType());
2130             I.replaceAllUsesWith(trunc);
2131         }
2132         else if (bitWidth == 32)
2133         {
2134             Value* bfrev = builder.CreateCall(bfrevFunc, nextOrShl);
2135             I.replaceAllUsesWith(bfrev);
2136         }
2137         else
2138         { // bitWidth == 64
2139             Value* int32Source = builder.CreateBitCast(nextOrShl, IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2));
2140             Value* extractElement0 = builder.CreateExtractElement(int32Source, builder.getInt32(0));
2141             Value* extractElement1 = builder.CreateExtractElement(int32Source, builder.getInt32(1));
2142             Value* bfrevLow = builder.CreateCall(bfrevFunc, extractElement0);
2143             Value* bfrevHigh = builder.CreateCall(bfrevFunc, extractElement1);
2144             Value* bfrev64Result = llvm::UndefValue::get(int32Source->getType());
2145             bfrev64Result = builder.CreateInsertElement(bfrev64Result, bfrevHigh, builder.getInt32(0));
2146             bfrev64Result = builder.CreateInsertElement(bfrev64Result, bfrevLow, builder.getInt32(1));
2147             Value* bfrevBitcast = builder.CreateBitCast(bfrev64Result, I.getType());
2148             I.replaceAllUsesWith(bfrevBitcast);
2149         }
2150     }
2151 }
2152 
2153 /* Transforms pattern1 or pattern2 to a bitcast,extract,insert,insert,bitcast
2154 
2155     From:
2156         %5 = zext i32 %a to i64 <--- optional
2157         %6 = shl i64 %5, 32
2158         or
2159         %6 = and i64 %5, 0xFFFFFFFF00000000
2160     To:
2161         %BC = bitcast i64 %5 to <2 x i32> <---- not needed when %5 is zext
2162         %6 = extractelement <2 x i32> %BC, i32 0/1 <---- not needed when %5 is zext
2163         %7 = insertelement <2 x i32> %vec, i32 0, i32 0
2164         %8 = insertelement <2 x i32> %vec, %6, i32 1
2165         %9 = bitcast <2 x i32> %8 to i64
2166  */
createBitcastExtractInsertPattern(BinaryOperator & I,Value * OpLow,Value * OpHi,unsigned extractNum1,unsigned extractNum2)2167 void GenSpecificPattern::createBitcastExtractInsertPattern(BinaryOperator& I, Value* OpLow, Value* OpHi, unsigned extractNum1, unsigned extractNum2)
2168 {
2169     if (IGC_IS_FLAG_DISABLED(EnableBitcastExtractInsertPattern)) {
2170         return;
2171     }
2172 
2173     llvm::IRBuilder<> builder(&I);
2174     auto vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
2175     Value* vec = UndefValue::get(vec2);
2176     Value* elemLow = nullptr;
2177     Value* elemHi = nullptr;
2178 
2179     auto zeroextorNot = [&](Value* Op, unsigned num) -> Value *
2180     {
2181         Value* elem = nullptr;
2182         if (auto ZextInst = dyn_cast<ZExtInst>(Op))
2183         {
2184             if (ZextInst->getDestTy() == builder.getInt64Ty() && ZextInst->getSrcTy() == builder.getInt32Ty())
2185             {
2186                 elem = ZextInst->getOperand(0);
2187             }
2188         }
2189         else if (auto IEIInst = dyn_cast<InsertElementInst>(Op))
2190         {
2191             auto opType = IEIInst->getType();
2192             if (opType->isVectorTy() && cast<VectorType>(opType)->getElementType()->isIntegerTy(32) && cast<IGCLLVM::FixedVectorType>(opType)->getNumElements() == 2)
2193             {
2194                 elem = IEIInst->getOperand(1);
2195             }
2196         }
2197         else
2198         {
2199             Value* BC = builder.CreateBitCast(Op, vec2);
2200             elem = builder.CreateExtractElement(BC, builder.getInt32(num));
2201         }
2202         return elem;
2203     };
2204 
2205     elemLow = (OpLow == nullptr) ? builder.getInt32(0) : zeroextorNot(OpLow, extractNum1);
2206     elemHi = (OpHi == nullptr) ? builder.getInt32(0) : zeroextorNot(OpHi, extractNum2);
2207 
2208     if (elemHi == nullptr || elemLow == nullptr)
2209         return;
2210 
2211     vec = builder.CreateInsertElement(vec, elemLow, builder.getInt32(0));
2212     vec = builder.CreateInsertElement(vec, elemHi, builder.getInt32(1));
2213     vec = builder.CreateBitCast(vec, builder.getInt64Ty());
2214     I.replaceAllUsesWith(vec);
2215     I.eraseFromParent();
2216 }
2217 
visitBinaryOperator(BinaryOperator & I)2218 void GenSpecificPattern::visitBinaryOperator(BinaryOperator& I)
2219 {
2220     if (I.getOpcode() == Instruction::Or)
2221     {
2222         using namespace llvm::PatternMatch;
2223 
2224         /*
2225         llvm changes ADD to OR when possible, and this optimization changes it back and allow 2 ADDs to merge.
2226         This can avoid scattered read for constant buffer when the index is calculated by shl + or + add.
2227 
2228         ex:
2229         from
2230         %22 = shl i32 %14, 2
2231         %23 = or i32 %22, 3
2232         %24 = add i32 %23, 16
2233         to
2234         %22 = shl i32 %14, 2
2235         %23 = add i32 %22, 19
2236         */
2237         Value* AndOp1 = nullptr, * EltOp1 = nullptr;
2238         auto pattern1 = m_Or(
2239             m_And(m_Value(AndOp1), m_SpecificInt(0xFFFFFFFF)),
2240             m_Shl(m_Value(EltOp1), m_SpecificInt(32)));
2241 #if LLVM_VERSION_MAJOR >= 7
2242         Value * AndOp2 = nullptr, *EltOp2 = nullptr, *VecOp = nullptr;
2243         auto pattern2 = m_Or(
2244             m_And(m_Value(AndOp2), m_SpecificInt(0xFFFFFFFF)),
2245             m_BitCast(m_InsertElt(m_Value(VecOp), m_Value(EltOp2), m_SpecificInt(1))));
2246 #endif // LLVM_VERSION_MAJOR >= 7
2247         if (match(&I, pattern1) && AndOp1->getType()->isIntegerTy(64))
2248         {
2249             createBitcastExtractInsertPattern(I, AndOp1, EltOp1, 0, 1);
2250         }
2251 #if LLVM_VERSION_MAJOR >= 7
2252         else if (match(&I, pattern2) && AndOp2->getType()->isIntegerTy(64))
2253         {
2254             ConstantVector* cVec = dyn_cast<ConstantVector>(VecOp);
2255             IGCLLVM::FixedVectorType* vector_type = dyn_cast<IGCLLVM::FixedVectorType>(VecOp->getType());
2256             if (cVec && vector_type &&
2257                 isa<ConstantInt>(cVec->getOperand(0)) &&
2258                 cast<ConstantInt>(cVec->getOperand(0))->isZero() &&
2259                 vector_type->getElementType()->isIntegerTy(32) &&
2260                 vector_type->getNumElements() == 2)
2261             {
2262                 auto InsertOp = cast<BitCastInst>(I.getOperand(1))->getOperand(0);
2263                 createBitcastExtractInsertPattern(I, AndOp2, InsertOp, 0, 1);
2264             }
2265         }
2266 #endif // LLVM_VERSION_MAJOR >= 7
2267         else
2268         {
2269             /*
2270             from
2271                 % 22 = shl i32 % 14, 2
2272                 % 23 = or i32 % 22, 3
2273             to
2274                 % 22 = shl i32 % 14, 2
2275                 % 23 = add i32 % 22, 3
2276             */
2277             ConstantInt* OrConstant = dyn_cast<ConstantInt>(I.getOperand(1));
2278             if (OrConstant)
2279             {
2280                 llvm::Instruction* ShlInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
2281                 if (ShlInst && ShlInst->getOpcode() == Instruction::Shl)
2282                 {
2283                     ConstantInt* ShlConstant = dyn_cast<ConstantInt>(ShlInst->getOperand(1));
2284                     if (ShlConstant)
2285                     {
2286                         // if the constant bit width is larger than 64, we cannot store ShlIntValue and OrIntValue rawdata as uint64_t.
2287                         // will need a fix then
2288                         IGC_ASSERT(ShlConstant->getBitWidth() <= 64);
2289                         IGC_ASSERT(OrConstant->getBitWidth() <= 64);
2290 
2291                         uint64_t ShlIntValue = *(ShlConstant->getValue()).getRawData();
2292                         uint64_t OrIntValue = *(OrConstant->getValue()).getRawData();
2293 
2294                         if (OrIntValue < pow(2, ShlIntValue))
2295                         {
2296                             Value* newAdd = BinaryOperator::CreateAdd(I.getOperand(0), I.getOperand(1), "", &I);
2297                             I.replaceAllUsesWith(newAdd);
2298                         }
2299                     }
2300                 }
2301             }
2302         }
2303     }
2304     else if (I.getOpcode() == Instruction::Add)
2305     {
2306         /*
2307         from
2308             %23 = add i32 %22, 3
2309             %24 = add i32 %23, 16
2310         to
2311             %24 = add i32 %22, 19
2312         */
2313         for (int ImmSrcId1 = 0; ImmSrcId1 < 2; ImmSrcId1++)
2314         {
2315             ConstantInt* IConstant = dyn_cast<ConstantInt>(I.getOperand(ImmSrcId1));
2316             if (IConstant)
2317             {
2318                 llvm::Instruction* AddInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(1 - ImmSrcId1));
2319                 if (AddInst && AddInst->getOpcode() == Instruction::Add)
2320                 {
2321                     for (int ImmSrcId2 = 0; ImmSrcId2 < 2; ImmSrcId2++)
2322                     {
2323                         ConstantInt* AddConstant = dyn_cast<ConstantInt>(AddInst->getOperand(ImmSrcId2));
2324                         if (AddConstant)
2325                         {
2326                             llvm::APInt CombineAddValue = AddConstant->getValue() + IConstant->getValue();
2327                             I.setOperand(0, AddInst->getOperand(1 - ImmSrcId2));
2328                             I.setOperand(1, ConstantInt::get(I.getType(), CombineAddValue));
2329                         }
2330                     }
2331                 }
2332             }
2333         }
2334     }
2335     else if (I.getOpcode() == Instruction::Shl)
2336     {
2337     /*
2338       From:
2339           %5 = zext i32 %a to i64 <--- optional
2340           %6 = shl i64 %5, 32
2341 
2342       To:
2343           %BC = bitcast i64 %5 to <2 x i32> <---- not needed when %5 is zext
2344           %6 = extractelement <2 x i32> %BC, i32 0 <---- not needed when %5 is zext
2345           %7 = insertelement <2 x i32> %vec, i32 0, i32 0
2346           %8 = insertelement <2 x i32> %vec, %6, i32 1
2347           %9 = bitcast <2 x i32> %8 to i64
2348     */
2349 
2350         using namespace llvm::PatternMatch;
2351         Instruction* inst = nullptr;
2352 
2353         auto pattern1 = m_Shl(m_Instruction(inst), m_SpecificInt(32));
2354 
2355         if (match(&I, pattern1) && I.getType()->isIntegerTy(64))
2356         {
2357             createBitcastExtractInsertPattern(I, nullptr, I.getOperand(0), 0, 0);
2358         }
2359     }
2360     else if (I.getOpcode() == Instruction::And)
2361     {
2362         /*  This `and` is basically fabs() done on high part of int representation.
2363             For float instructions minus operand can end as SrcMod, but since we cast it
2364             from double to int it will end as additional mov, and we can ignore this m_FNeg
2365             anyway.
2366 
2367             From :
2368                 %sub = fsub double -0.000000e+00, %res.039
2369                 %25 = bitcast double %sub to i64
2370                 %26 = bitcast i64 %25 to <2 x i32> // or directly double to <2xi32>
2371                 %27 = extractelement <2 x i32> %26, i32 1
2372                 %and31.i = and i32 %27, 2147483647
2373 
2374             To:
2375                 %25 = bitcast double %res.039 to <2 x i32>
2376                 %27 = extractelement <2 x i32> %26, i32 1
2377                 %and31.i = and i32 %27, 2147483647
2378 
2379             Or on Int64 without extract:
2380             From:
2381                 %sub = fsub double -0.000000e+00, %res.039
2382                 %astype.i112.i.i = bitcast double %sub to i64
2383                 %and107.i.i = and i64 %astype.i112.i.i, 9223372032559808512 // 0x7FFFFFFF00000000
2384             To:
2385                 %bit_cast = bitcast double %res.039 to i64
2386                 %and107.i.i = and i64 %bit_cast, 9223372032559808512 // 0x7FFFFFFF00000000
2387 
2388         */
2389 
2390         /*  Get src of either 2 bitcast chain: double -> i64, i64 -> 2xi32
2391             or from single direct: double -> 2xi32
2392         */
2393         auto getValidBitcastSrc = [](Instruction* op) -> llvm::Value *
2394         {
2395             if (!(isa<BitCastInst>(op)))
2396                 return nullptr;
2397 
2398             BitCastInst* opBC = cast<BitCastInst>(op);
2399 
2400             auto opType = opBC->getType();
2401             if (!(opType->isVectorTy() && cast<VectorType>(opType)->getElementType()->isIntegerTy(32) && cast<IGCLLVM::FixedVectorType>(opType)->getNumElements() == 2))
2402                 return nullptr;
2403 
2404             if (opBC->getSrcTy()->isDoubleTy())
2405                 return opBC->getOperand(0); // double -> 2xi32
2406 
2407             BitCastInst* bitCastSrc = dyn_cast<BitCastInst>(opBC->getOperand(0));
2408 
2409             if (bitCastSrc && bitCastSrc->getDestTy()->isIntegerTy(64) && bitCastSrc->getSrcTy()->isDoubleTy())
2410                 return bitCastSrc->getOperand(0); // double -> i64, i64 -> 2xi32
2411 
2412             return nullptr;
2413         };
2414 
2415         using namespace llvm::PatternMatch;
2416         Value* src_of_FNeg = nullptr;
2417         Instruction* inst = nullptr;
2418 
2419         auto fabs_on_int_pattern1 = m_And(m_ExtractElt(m_Instruction(inst), m_SpecificInt(1)), m_SpecificInt(0x7FFFFFFF));
2420         auto fabs_on_int_pattern2 = m_And(m_Instruction(inst), m_SpecificInt(0x7FFFFFFF00000000));
2421         auto fneg_pattern = m_FNeg(m_Value(src_of_FNeg));
2422 
2423         /*
2424         From:
2425             %5 = zext i32 %a to i64 <--- optional
2426             %6 = and i64 %5, 0xFFFFFFFF00000000 (-4294967296)
2427         To:
2428             %BC = bitcast i64 %5 to <2 x i32> <---- not needed when %5 is zext
2429             %6 = extractelement <2 x i32> %BC, i32 1 <---- not needed when %5 is zext
2430             %7 = insertelement <2 x i32> %vec, i32 0, i32 0
2431             %8 = insertelement <2 x i32> %vec, %6, i32 1
2432             %9 = bitcast <2 x i32> %8 to i64
2433         */
2434 
2435         auto pattern1 = m_And(m_Instruction(inst), m_SpecificInt(0xFFFFFFFF00000000));
2436 
2437         if (match(&I, fabs_on_int_pattern1))
2438         {
2439             Value* src = getValidBitcastSrc(inst);
2440             if (src && match(src, fneg_pattern) && src_of_FNeg->getType()->isDoubleTy())
2441             {
2442                 llvm::IRBuilder<> builder(&I);
2443                 VectorType* vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
2444                 Value* BC = builder.CreateBitCast(src_of_FNeg, vec2);
2445                 Value* EE = builder.CreateExtractElement(BC, builder.getInt32(1));
2446                 Value* AI = builder.CreateAnd(EE, builder.getInt32(0x7FFFFFFF));
2447                 I.replaceAllUsesWith(AI);
2448                 I.eraseFromParent();
2449             }
2450         }
2451         else if (match(&I, fabs_on_int_pattern2))
2452         {
2453             BitCastInst* bitcast = dyn_cast<BitCastInst>(inst);
2454             bool bitcastValid = bitcast && bitcast->getDestTy()->isIntegerTy(64) && bitcast->getSrcTy()->isDoubleTy();
2455 
2456             if (bitcastValid && match(bitcast->getOperand(0), fneg_pattern) && src_of_FNeg->getType()->isDoubleTy())
2457             {
2458                 llvm::IRBuilder<> builder(&I);
2459                 Value* BC = builder.CreateBitCast(src_of_FNeg, I.getType());
2460                 Value* AI = builder.CreateAnd(BC, builder.getInt64(0x7FFFFFFF00000000));
2461                 I.replaceAllUsesWith(AI);
2462                 I.eraseFromParent();
2463             }
2464         }
2465         else if (match(&I, pattern1) && I.getType()->isIntegerTy(64))
2466         {
2467             createBitcastExtractInsertPattern(I, nullptr, I.getOperand(0), 0, 1);
2468         }
2469         else
2470         {
2471 
2472             Instruction* AndSrc = nullptr;
2473             ConstantInt* CI;
2474 
2475             /*
2476             From:
2477               %28 = and i32 %24, 255
2478               %29 = lshr i32 %24, 8
2479               %30 = and i32 %29, 255
2480               %31 = lshr i32 %24, 16
2481               %32 = and i32 %31, 255
2482             To:
2483               %temp = bitcast i32 %24 to <4 x i8>
2484               %ee1 = extractelement <4 x i8> %temp, i32 0
2485               %ee2 = extractelement <4 x i8> %temp, i32 1
2486               %ee3 = extractelement <4 x i8> %temp, i32 2
2487               %28 = zext i8 %ee1 to i32
2488               %30 = zext i8 %ee2 to i32
2489               %32 = zext i8 %ee3 to i32
2490             */
2491             auto pattern_And_0xFF = m_And(m_Instruction(AndSrc), m_SpecificInt(0xFF));
2492 
2493             CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
2494             bool bytesAllowed = ctx->platform.supportByteALUOperation();
2495 
2496             if (bytesAllowed && match(&I, pattern_And_0xFF) && I.getType()->isIntegerTy(32) && AndSrc->getType()->isIntegerTy(32))
2497             {
2498                 Instruction* LhsSrc = nullptr;
2499 
2500                 auto LShr_Pattern = m_LShr(m_Instruction(LhsSrc), m_ConstantInt(CI));
2501                 bool LShrMatch = match(AndSrc, LShr_Pattern) && LhsSrc->getType()->isIntegerTy(32) && (CI->getZExtValue() % 8 == 0);
2502 
2503                 // in case there's no shr, it will be 0
2504                 uint32_t newIndex = 0;
2505 
2506                 if (LShrMatch) // extract inner
2507                 {
2508                     AndSrc = LhsSrc;
2509                     newIndex = (uint32_t)CI->getZExtValue() / 8;
2510                 }
2511 
2512                 llvm::IRBuilder<> builder(&I);
2513                 VectorType* vec4 = VectorType::get(builder.getInt8Ty(), 4, false);
2514                 Value* BC = builder.CreateBitCast(AndSrc, vec4);
2515                 Value* EE = builder.CreateExtractElement(BC, builder.getInt32(newIndex));
2516                 Value* Zext = builder.CreateZExt(EE, builder.getInt32Ty());
2517                 I.replaceAllUsesWith(Zext);
2518                 I.eraseFromParent();
2519 
2520             }
2521 
2522         }
2523     }
2524 }
2525 
visitCmpInst(CmpInst & I)2526 void GenSpecificPattern::visitCmpInst(CmpInst& I)
2527 {
2528     using namespace llvm::PatternMatch;
2529     CmpInst::Predicate Pred = CmpInst::Predicate::BAD_ICMP_PREDICATE;
2530     Value* Val1 = nullptr;
2531     uint64_t const_int1 = 0, const_int2 = 0;
2532     auto cmp_pattern = m_Cmp(Pred,
2533         m_And(m_Value(Val1), m_ConstantInt(const_int1)), m_ConstantInt(const_int2));
2534 
2535     if (match(&I, cmp_pattern) &&
2536         (const_int1 << 32) == 0 &&
2537         (const_int2 << 32) == 0 &&
2538         Val1->getType()->isIntegerTy(64))
2539     {
2540         llvm::IRBuilder<> builder(&I);
2541         VectorType* vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
2542         Value* BC = builder.CreateBitCast(Val1, vec2);
2543         Value* EE = builder.CreateExtractElement(BC, builder.getInt32(1));
2544         Value* AI = builder.CreateAnd(EE, builder.getInt32(const_int1 >> 32));
2545         Value* new_Val = builder.CreateICmp(Pred, AI, builder.getInt32(const_int2 >> 32));
2546         I.replaceAllUsesWith(new_Val);
2547         I.eraseFromParent();
2548     }
2549     else
2550     {
2551         CodeGenContext* pCtx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
2552         if (pCtx->getCompilerOption().NoNaNs)
2553         {
2554             if (I.getPredicate() == CmpInst::FCMP_ORD)
2555             {
2556                 I.replaceAllUsesWith(ConstantInt::getTrue(I.getType()));
2557             }
2558         }
2559     }
2560 }
2561 
visitSelectInst(SelectInst & I)2562 void GenSpecificPattern::visitSelectInst(SelectInst& I)
2563 {
2564     /*
2565     from
2566         %res_s42 = icmp eq i32 %src1_s41, 0
2567         %src1_s81 = select i1 %res_s42, i32 15, i32 0
2568     to
2569         %res_s42 = icmp eq i32 %src1_s41, 0
2570         %17 = sext i1 %res_s42 to i32
2571         %18 = and i32 15, %17
2572 
2573                or
2574 
2575     from
2576         %res_s73 = fcmp oge float %res_s61, %42
2577         %res_s187 = select i1 %res_s73, float 1.000000e+00, float 0.000000e+00
2578     to
2579         %res_s73 = fcmp oge float %res_s61, %42
2580         %46 = sext i1 %res_s73 to i32
2581         %47 = and i32 %46, 1065353216
2582         %48 = bitcast i32 %47 to float
2583     */
2584 
2585     IGC_ASSERT(I.getOpcode() == Instruction::Select);
2586 
2587     bool skipOpt = false;
2588 
2589     ConstantInt* Cint = dyn_cast<ConstantInt>(I.getOperand(2));
2590     if (Cint && Cint->isZero())
2591     {
2592         llvm::Instruction* cmpInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
2593         if (cmpInst &&
2594             cmpInst->getOpcode() == Instruction::ICmp &&
2595             I.getOperand(1) != cmpInst->getOperand(0))
2596         {
2597             // disable the cases for csel or where we can optimize the instructions to such as add.ge.* later in vISA
2598             ConstantInt* C = dyn_cast<ConstantInt>(cmpInst->getOperand(1));
2599             if (C && C->isZero())
2600             {
2601                 skipOpt = true;
2602             }
2603 
2604             if (!skipOpt)
2605             {
2606                 // temporary disable the case where cmp is used in multiple sel, and not all of them have src2=0
2607                 // we should remove this if we can allow both flag and grf dst for the cmp to be used.
2608                 for (auto selI = cmpInst->user_begin(), E = cmpInst->user_end(); selI != E; ++selI)
2609                 {
2610                     if (llvm::SelectInst * selInst = llvm::dyn_cast<llvm::SelectInst>(*selI))
2611                     {
2612                         ConstantInt* C = dyn_cast<ConstantInt>(selInst->getOperand(2));
2613                         if (!(C && C->isZero()))
2614                         {
2615                             skipOpt = true;
2616                             break;
2617                         }
2618                     }
2619                 }
2620             }
2621 
2622             if (!skipOpt)
2623             {
2624                 Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), I.getType(), "", &I);
2625                 Value* newValueAnd = BinaryOperator::CreateAnd(I.getOperand(1), newValueSext, "", &I);
2626                 I.replaceAllUsesWith(newValueAnd);
2627             }
2628         }
2629     }
2630     else
2631     {
2632         ConstantFP* Cfp = dyn_cast<ConstantFP>(I.getOperand(2));
2633         if (Cfp && Cfp->isZero())
2634         {
2635             llvm::Instruction* cmpInst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
2636             if (cmpInst &&
2637                 cmpInst->getOpcode() == Instruction::FCmp &&
2638                 I.getOperand(1) != cmpInst->getOperand(0))
2639             {
2640                 // disable the cases for csel or where we can optimize the instructions to such as add.ge.* later in vISA
2641                 ConstantFP* C = dyn_cast<ConstantFP>(cmpInst->getOperand(1));
2642                 if (C && C->isZero())
2643                 {
2644                     skipOpt = true;
2645                 }
2646 
2647                 if (!skipOpt)
2648                 {
2649                     for (auto selI = cmpInst->user_begin(), E = cmpInst->user_end(); selI != E; ++selI)
2650                     {
2651                         if (llvm::SelectInst * selInst = llvm::dyn_cast<llvm::SelectInst>(*selI))
2652                         {
2653                             // temporary disable the case where cmp is used in multiple sel, and not all of them have src2=0
2654                             // we should remove this if we can allow both flag and grf dst for the cmp to be used.
2655                             ConstantFP* C2 = dyn_cast<ConstantFP>(selInst->getOperand(2));
2656                             if (!(C2 && C2->isZero()))
2657                             {
2658                                 skipOpt = true;
2659                                 break;
2660                             }
2661 
2662                             // if it is cmp-sel(1.0 / 0.0)-mul, we could better patten match it later in codeGen.
2663                             ConstantFP* C1 = dyn_cast<ConstantFP>(selInst->getOperand(1));
2664                             if (C1 && C2 && selInst->hasOneUse())
2665                             {
2666                                 if ((C2->isZero() && C1->isExactlyValue(1.f)) || (C1->isZero() && C2->isExactlyValue(1.f)))
2667                                 {
2668                                     Instruction* mulInst = dyn_cast<Instruction>(*selInst->user_begin());
2669                                     if (mulInst && mulInst->getOpcode() == Instruction::FMul)
2670                                     {
2671                                         skipOpt = true;
2672                                         break;
2673                                     }
2674                                 }
2675                             }
2676                         }
2677                     }
2678                 }
2679 
2680                 if (!skipOpt)
2681                 {
2682                     ConstantFP* C1 = dyn_cast<ConstantFP>(I.getOperand(1));
2683                     if (C1)
2684                     {
2685                         if (C1->getType()->isHalfTy())
2686                         {
2687                             Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt16Ty(I.getContext()), "", &I);
2688                             Value* newConstant = ConstantInt::get(I.getContext(), C1->getValueAPF().bitcastToAPInt());
2689                             Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newConstant, "", &I);
2690                             Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getHalfTy(I.getContext()), "", &I);
2691                             I.replaceAllUsesWith(newValueCastFP);
2692                         }
2693                         else if (C1->getType()->isFloatTy())
2694                         {
2695                             Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt32Ty(I.getContext()), "", &I);
2696                             Value* newConstant = ConstantInt::get(I.getContext(), C1->getValueAPF().bitcastToAPInt());
2697                             Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newConstant, "", &I);
2698                             Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getFloatTy(I.getContext()), "", &I);
2699                             I.replaceAllUsesWith(newValueCastFP);
2700                         }
2701                     }
2702                     else
2703                     {
2704                         if (I.getOperand(1)->getType()->isHalfTy())
2705                         {
2706                             Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt16Ty(I.getContext()), "", &I);
2707                             Value* newValueBitcast = CastInst::CreateZExtOrBitCast(I.getOperand(1), Type::getInt16Ty(I.getContext()), "", &I);
2708                             Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newValueBitcast, "", &I);
2709                             Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getHalfTy(I.getContext()), "", &I); \
2710                                 I.replaceAllUsesWith(newValueCastFP);
2711                         }
2712                         else if (I.getOperand(1)->getType()->isFloatTy())
2713                         {
2714                             Value* newValueSext = CastInst::CreateSExtOrBitCast(I.getOperand(0), Type::getInt32Ty(I.getContext()), "", &I);
2715                             Value* newValueBitcast = CastInst::CreateZExtOrBitCast(I.getOperand(1), Type::getInt32Ty(I.getContext()), "", &I);
2716                             Value* newValueAnd = BinaryOperator::CreateAnd(newValueSext, newValueBitcast, "", &I);
2717                             Value* newValueCastFP = CastInst::CreateZExtOrBitCast(newValueAnd, Type::getFloatTy(I.getContext()), "", &I); \
2718                                 I.replaceAllUsesWith(newValueCastFP);
2719                         }
2720                     }
2721                 }
2722             }
2723         }
2724     }
2725 
2726     /*
2727     from
2728         %230 = sdiv i32 %214, %scale
2729         %276 = trunc i32 %230 to i8
2730         %277 = icmp slt i32 %230, 255
2731         %278 = select i1 %277, i8 %276, i8 -1
2732     to
2733         %230 = sdiv i32 %214, %scale
2734         %277 = icmp slt i32 %230, 255
2735         %278 = select i1 %277, i32 %230, i32 255
2736         %279 = trunc i32 %278 to i8
2737 
2738         This tranform allows for min/max instructions to be
2739         picked up in the IsMinOrMax instruction in PatternMatchPass.cpp
2740     */
2741     if (auto * compInst = dyn_cast<ICmpInst>(I.getOperand(0)))
2742     {
2743         auto selOp1 = I.getOperand(1);
2744         auto selOp2 = I.getOperand(2);
2745         auto cmpOp0 = compInst->getOperand(0);
2746         auto cmpOp1 = compInst->getOperand(1);
2747         auto trunc1 = dyn_cast<TruncInst>(selOp1);
2748         auto trunc2 = dyn_cast<TruncInst>(selOp2);
2749         auto icmpType = compInst->getOperand(0)->getType();
2750 
2751         if (selOp1->getType()->isIntegerTy() &&
2752             icmpType->isIntegerTy() &&
2753             selOp1->getType()->getIntegerBitWidth() < icmpType->getIntegerBitWidth())
2754         {
2755             Value* newSelOp1 = NULL;
2756             Value* newSelOp2 = NULL;
2757             if (trunc1 &&
2758                 (trunc1->getOperand(0) == cmpOp0 ||
2759                     trunc1->getOperand(0) == cmpOp1))
2760             {
2761                 newSelOp1 = (trunc1->getOperand(0) == cmpOp0) ? cmpOp0 : cmpOp1;
2762             }
2763 
2764             if (trunc2 &&
2765                 (trunc2->getOperand(0) == cmpOp0 ||
2766                     trunc2->getOperand(0) == cmpOp1))
2767             {
2768                 newSelOp2 = (trunc2->getOperand(0) == cmpOp0) ? cmpOp0 : cmpOp1;
2769             }
2770 
2771             if (isa<llvm::ConstantInt>(selOp1) &&
2772                 isa<llvm::ConstantInt>(cmpOp0) &&
2773                 (cast<llvm::ConstantInt>(selOp1)->getZExtValue() ==
2774                     cast<llvm::ConstantInt>(cmpOp0)->getZExtValue()))
2775             {
2776                 IGC_ASSERT(newSelOp1 == NULL);
2777                 newSelOp1 = cmpOp0;
2778             }
2779 
2780             if (isa<llvm::ConstantInt>(selOp1) &&
2781                 isa<llvm::ConstantInt>(cmpOp1) &&
2782                 (cast<llvm::ConstantInt>(selOp1)->getZExtValue() ==
2783                     cast<llvm::ConstantInt>(cmpOp1)->getZExtValue()))
2784             {
2785                 IGC_ASSERT(newSelOp1 == NULL);
2786                 newSelOp1 = cmpOp1;
2787             }
2788 
2789             if (isa<llvm::ConstantInt>(selOp2) &&
2790                 isa<llvm::ConstantInt>(cmpOp0) &&
2791                 (cast<llvm::ConstantInt>(selOp2)->getZExtValue() ==
2792                     cast<llvm::ConstantInt>(cmpOp0)->getZExtValue()))
2793             {
2794                 IGC_ASSERT(newSelOp2 == NULL);
2795                 newSelOp2 = cmpOp0;
2796             }
2797 
2798             if (isa<llvm::ConstantInt>(selOp2) &&
2799                 isa<llvm::ConstantInt>(cmpOp1) &&
2800                 (cast<llvm::ConstantInt>(selOp2)->getZExtValue() ==
2801                     cast<llvm::ConstantInt>(cmpOp1)->getZExtValue()))
2802             {
2803                 IGC_ASSERT(newSelOp2 == NULL);
2804                 newSelOp2 = cmpOp1;
2805             }
2806 
2807             if (newSelOp1 && newSelOp2)
2808             {
2809                 auto newSelInst = SelectInst::Create(I.getCondition(), newSelOp1, newSelOp2, "", &I);
2810                 auto newTruncInst = TruncInst::CreateTruncOrBitCast(newSelInst, selOp1->getType(), "", &I);
2811                 I.replaceAllUsesWith(newTruncInst);
2812                 I.eraseFromParent();
2813             }
2814         }
2815 
2816     }
2817 
2818 }
2819 
visitCastInst(CastInst & I)2820 void GenSpecificPattern::visitCastInst(CastInst& I)
2821 {
2822     Instruction* srcVal = nullptr;
2823     // Intrinsic::trunc call is handled by 'rndz' hardware instruction which
2824     // does not support double precision floating point type
2825     if (I.getType()->isDoubleTy())
2826     {
2827         return;
2828     }
2829     if (isa<SIToFPInst>(&I))
2830     {
2831         srcVal = dyn_cast<FPToSIInst>(I.getOperand(0));
2832     }
2833     if (srcVal && srcVal->getOperand(0)->getType() == I.getType())
2834     {
2835         if ((srcVal = dyn_cast<Instruction>(srcVal->getOperand(0))))
2836         {
2837             // need fast math to know that we can ignore Nan
2838             if (isa<FPMathOperator>(srcVal) && srcVal->isFast())
2839             {
2840                 IRBuilder<> builder(&I);
2841                 Function* func = Intrinsic::getDeclaration(
2842                     I.getParent()->getParent()->getParent(),
2843                     Intrinsic::trunc,
2844                     I.getType());
2845                 Value* newVal = builder.CreateCall(func, srcVal);
2846                 I.replaceAllUsesWith(newVal);
2847                 I.eraseFromParent();
2848             }
2849         }
2850     }
2851 }
2852 
2853 /*
2854 from:
2855     %HighBits.Vec = insertelement <2 x i32> <i32 0, i32 undef>, i32 %HighBits.32, i32 1
2856     %HighBits.64 = bitcast <2 x i32> %HighBits.Vec to i64
2857     %LowBits.64 = zext i32 %LowBits.32 to i64
2858     %LowPlusHighBits = or i64 %HighBits.64, %LowBits.64
2859     %19 = bitcast i64 %LowPlusHighBits to double
2860 to:
2861     %17 = insertelement <2 x i32> undef, i32 %LowBits.32, i32 0
2862     %18 = insertelement <2 x i32> %17, i32 %HighBits.32, i32 1
2863     %19 = bitcast <2 x i32> %18 to double
2864 */
2865 
visitBitCastInst(BitCastInst & I)2866 void GenSpecificPattern::visitBitCastInst(BitCastInst& I)
2867 {
2868     if (I.getType()->isDoubleTy() && I.getOperand(0)->getType()->isIntegerTy(64))
2869     {
2870         BinaryOperator* binOperator = nullptr;
2871         if ((binOperator = dyn_cast<BinaryOperator>(I.getOperand(0))) && binOperator->getOpcode() == Instruction::Or)
2872         {
2873             if (isa<BitCastInst>(binOperator->getOperand(0)) && isa<ZExtInst>(binOperator->getOperand(1)))
2874             {
2875                 BitCastInst* bitCastInst = cast<BitCastInst>(binOperator->getOperand(0));
2876                 ZExtInst* zExtInst = cast<ZExtInst>(binOperator->getOperand(1));
2877 
2878                 if (zExtInst->getOperand(0)->getType()->isIntegerTy(32) &&
2879                     isa<InsertElementInst>(bitCastInst->getOperand(0)) &&
2880                     bitCastInst->getOperand(0)->getType()->isVectorTy() &&
2881                     cast<IGCLLVM::FixedVectorType>(bitCastInst->getOperand(0)->getType())->getElementType()->isIntegerTy(32) &&
2882                     cast<IGCLLVM::FixedVectorType>(bitCastInst->getOperand(0)->getType())->getNumElements() == 2)
2883                 {
2884                     InsertElementInst* insertElementInst = cast<InsertElementInst>(bitCastInst->getOperand(0));
2885 
2886                     if (isa<Constant>(insertElementInst->getOperand(0)) &&
2887                         cast<Constant>(insertElementInst->getOperand(0))->getAggregateElement((unsigned int)0)->isZeroValue() &&
2888                         cast<ConstantInt>(insertElementInst->getOperand(2))->getZExtValue() == 1)
2889                     {
2890                         IRBuilder<> builder(&I);
2891                         Value* vectorValue = UndefValue::get(bitCastInst->getOperand(0)->getType());
2892                         vectorValue = builder.CreateInsertElement(vectorValue, zExtInst->getOperand(0), builder.getInt32(0));
2893                         vectorValue = builder.CreateInsertElement(vectorValue, insertElementInst->getOperand(1), builder.getInt32(1));
2894                         Value* newBitCast = builder.CreateBitCast(vectorValue, builder.getDoubleTy());
2895                         I.replaceAllUsesWith(newBitCast);
2896                         I.eraseFromParent();
2897                     }
2898                 }
2899             }
2900         }
2901     }
2902 }
2903 
2904 /*
2905     Matches a pattern where pointer to load instruction is fetched by other load instruction.
2906     On targets that do not support 64 bit operations, Emu64OpsPass will insert pair_to_ptr intrinsic
2907     between the loads and InstructionCombining will not optimize this case.
2908 
2909     This function changes following pattern:
2910     %3 = load <2 x i32>, <2 x i32> addrspace(1)* %2, align 64
2911     %4 = extractelement <2 x i32> %3, i32 0
2912     %5 = extractelement <2 x i32> %3, i32 1
2913     %6 = call %union._XReq addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* @llvm.genx.GenISA.pair.to.ptr.p1p1p1p1p1p1p1p1union._XReq(i32 %4, i32 %5)
2914     %7 = bitcast %union._XReq addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* addrspace(1)* %6 to i64 addrspace(1)*
2915     %8 = bitcast i64 addrspace(1)* %7 to <2 x i32> addrspace(1)*
2916     %9 = load <2 x i32>, <2 x i32> addrspace(1)* %8, align 64
2917 
2918     to:
2919     %3 = bitcast <2 x i32> addrspace(1)* %2 to <2 x i32> addrspace(1)* addrspace(1)*
2920     %4 = load <2 x i32> addrspace(1)*, <2 x i32> addrspace(1)* addrspace(1)* %3, align 64
2921     ... dead code
2922     %11 = load <2 x i32>, <2 x i32> addrspace(1)* %4, align 64
2923 */
visitLoadInst(LoadInst & LI)2924 void GenSpecificPattern::visitLoadInst(LoadInst &LI) {
2925     Value* PO = LI.getPointerOperand();
2926     std::vector<Value*> OneUseValues = { PO };
2927     while (isa<BitCastInst>(PO)) {
2928         PO = cast<BitCastInst>(PO)->getOperand(0);
2929         OneUseValues.push_back(PO);
2930     }
2931 
2932     bool IsPairToPtrInst = (isa<GenIntrinsicInst>(PO) &&
2933         cast<GenIntrinsicInst>(PO)->getIntrinsicID() ==
2934         GenISAIntrinsic::GenISA_pair_to_ptr);
2935 
2936     if (!IsPairToPtrInst)
2937         return;
2938 
2939     // check if this pointer comes from a load.
2940     auto CallInst = cast<GenIntrinsicInst>(PO);
2941     auto Op0 = dyn_cast<ExtractElementInst>(CallInst->getArgOperand(0));
2942     auto Op1 = dyn_cast<ExtractElementInst>(CallInst->getArgOperand(1));
2943     bool PointerComesFromALoad = (Op0 && Op1 && isa<ConstantInt>(Op0->getIndexOperand()) &&
2944         isa<ConstantInt>(Op1->getIndexOperand()) &&
2945         cast<ConstantInt>(Op0->getIndexOperand())->getZExtValue() == 0 &&
2946         cast<ConstantInt>(Op1->getIndexOperand())->getZExtValue() == 1 &&
2947         isa<LoadInst>(Op0->getVectorOperand()) &&
2948         isa<LoadInst>(Op1->getVectorOperand()) &&
2949         Op0->getVectorOperand() == Op1->getVectorOperand());
2950 
2951     if (!PointerComesFromALoad)
2952         return;
2953 
2954     OneUseValues.insert(OneUseValues.end(), { Op0, Op1 });
2955 
2956     if (!std::all_of(OneUseValues.begin(), OneUseValues.end(), [](auto v) { return v->hasOneUse(); }))
2957         return;
2958 
2959     auto VectorLoadInst = cast<LoadInst>(Op0->getVectorOperand());
2960     if (VectorLoadInst->getNumUses() != 2)
2961         return;
2962 
2963     auto PointerOperand = VectorLoadInst->getPointerOperand();
2964     PointerType* newLoadPointerType = PointerType::get(
2965         LI.getPointerOperand()->getType(), PointerOperand->getType()->getPointerAddressSpace());
2966     IRBuilder<> builder(VectorLoadInst);
2967     auto CastedPointer =
2968         builder.CreateBitCast(PointerOperand, newLoadPointerType);
2969     auto NewLoadInst = IGC::cloneLoad(VectorLoadInst, CastedPointer);
2970 
2971     LI.setOperand(0, NewLoadInst);
2972 }
2973 
visitZExtInst(ZExtInst & ZEI)2974 void GenSpecificPattern::visitZExtInst(ZExtInst& ZEI)
2975 {
2976     CmpInst* Cmp = dyn_cast<CmpInst>(ZEI.getOperand(0));
2977     if (!Cmp)
2978         return;
2979 
2980     IRBuilder<> Builder(&ZEI);
2981 
2982     Value* S = Builder.CreateSExt(Cmp, ZEI.getType());
2983     Value* N = Builder.CreateNeg(S);
2984     ZEI.replaceAllUsesWith(N);
2985     ZEI.eraseFromParent();
2986 }
2987 
visitIntToPtr(llvm::IntToPtrInst & I)2988 void GenSpecificPattern::visitIntToPtr(llvm::IntToPtrInst& I)
2989 {
2990     if (ZExtInst * zext = dyn_cast<ZExtInst>(I.getOperand(0)))
2991     {
2992         IRBuilder<> builder(&I);
2993         Value* newV = builder.CreateIntToPtr(zext->getOperand(0), I.getType());
2994         I.replaceAllUsesWith(newV);
2995         I.eraseFromParent();
2996     }
2997 }
2998 
visitTruncInst(llvm::TruncInst & I)2999 void GenSpecificPattern::visitTruncInst(llvm::TruncInst& I)
3000 {
3001     /*
3002     from
3003     %22 = lshr i64 %a, 52
3004     %23 = trunc i64  %22 to i32
3005     to
3006     %22 = extractelement <2 x i32> %a, 1
3007     %23 = lshr i32 %22, 20 //52-32
3008     */
3009 
3010     using namespace llvm::PatternMatch;
3011     Value* LHS = nullptr;
3012     ConstantInt* CI;
3013     if (match(&I, m_Trunc(m_LShr(m_Value(LHS), m_ConstantInt(CI)))) &&
3014         I.getType()->isIntegerTy(32) &&
3015         LHS->getType()->isIntegerTy(64) &&
3016         CI->getZExtValue() >= 32)
3017     {
3018         auto new_shift_size = (unsigned)CI->getZExtValue() - 32;
3019         llvm::IRBuilder<> builder(&I);
3020         VectorType* vec2 = IGCLLVM::FixedVectorType::get(builder.getInt32Ty(), 2);
3021         Value* new_Val = builder.CreateBitCast(LHS, vec2);
3022         new_Val = builder.CreateExtractElement(new_Val, builder.getInt32(1));
3023         if (new_shift_size > 0)
3024         {
3025             new_Val = builder.CreateLShr(new_Val, builder.getInt32(new_shift_size));
3026         }
3027         I.replaceAllUsesWith(new_Val);
3028         I.eraseFromParent();
3029     }
3030 }
3031 
3032 #if LLVM_VERSION_MAJOR >= 10
visitFNeg(llvm::UnaryOperator & I)3033 void GenSpecificPattern::visitFNeg(llvm::UnaryOperator& I)
3034 {
3035     // from
3036     // %neg = fneg double %1
3037     // to
3038     // %neg = fsub double -0.000000e+00, %1
3039     // vISA have parser which looks for such operation pattern "0 - x"
3040     // and adds source modifier for this region/value.
3041 
3042     IRBuilder<> builder(&I);
3043 
3044     Value* fsub = nullptr;
3045 
3046     if (!I.getType()->isVectorTy())
3047     {
3048         fsub = builder.CreateFSub(ConstantFP::get(I.getType(), 0.0f), I.getOperand(0));
3049     }
3050     else
3051     {
3052         uint32_t vectorSize = cast<IGCLLVM::FixedVectorType>(I.getType())->getNumElements();
3053         fsub = llvm::UndefValue::get(I.getType());
3054 
3055         for (uint32_t i = 0; i < vectorSize; ++i)
3056         {
3057             Value* extract = builder.CreateExtractElement(I.getOperand(0), i);
3058             Value* extract_fsub = builder.CreateFSub(ConstantFP::get(extract->getType(), 0.0f), extract);
3059             fsub = builder.CreateInsertElement(fsub, extract_fsub, i);
3060         }
3061     }
3062 
3063     I.replaceAllUsesWith(fsub);
3064 }
3065 #endif
3066 
// Register pass to igc-opt
// The PASS_*3 macros are the arguments given to the
// IGC_INITIALIZE_PASS_BEGIN / IGC_INITIALIZE_PASS_END invocations below.
#define PASS_FLAG3 "igc-const-prop"
#define PASS_DESCRIPTION3 "Custom Const-prop Pass"
#define PASS_CFG_ONLY3 false
#define PASS_ANALYSIS3 false
IGC_INITIALIZE_PASS_BEGIN(IGCConstProp, PASS_FLAG3, PASS_DESCRIPTION3, PASS_CFG_ONLY3, PASS_ANALYSIS3)
// IGCConstProp::runOnFunction obtains TargetLibraryInfoWrapperPass through getAnalysis.
IGC_INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
IGC_INITIALIZE_PASS_END(IGCConstProp, PASS_FLAG3, PASS_DESCRIPTION3, PASS_CFG_ONLY3, PASS_ANALYSIS3)

// Pass identifier; the IGCConstProp constructor hands it to the FunctionPass base class.
char IGCConstProp::ID = 0;
3077 
IGCConstProp(bool enableSimplifyGEP)3078 IGCConstProp::IGCConstProp(
3079     bool enableSimplifyGEP) :
3080     FunctionPass(ID),
3081     m_enableSimplifyGEP(enableSimplifyGEP),
3082     m_TD(nullptr), m_TLI(nullptr)
3083 {
3084     initializeIGCConstPropPass(*PassRegistry::getPassRegistry());
3085 }
3086 
GetConstantValue(Type * type,char * rawData)3087 static Constant* GetConstantValue(Type* type, char* rawData)
3088 {
3089     unsigned int size_in_bytes = (unsigned int)type->getPrimitiveSizeInBits() / 8;
3090     uint64_t returnConstant = 0;
3091     memcpy_s(&returnConstant, size_in_bytes, rawData, size_in_bytes);
3092     if (type->isIntegerTy())
3093     {
3094         return ConstantInt::get(type, returnConstant);
3095     }
3096     else if (type->isFloatingPointTy())
3097     {
3098         return  ConstantFP::get(type->getContext(),
3099             APFloat(type->getFltSemantics(), APInt((unsigned int)type->getPrimitiveSizeInBits(), returnConstant)));
3100     }
3101     return nullptr;
3102 }
3103 
3104 
replaceShaderConstant(Instruction * inst)3105 Constant* IGCConstProp::replaceShaderConstant(Instruction* inst)
3106 {
3107     CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
3108     ModuleMetaData* modMD = ctx->getModuleMetaData();
3109     ConstantAddress cl;
3110     unsigned int& bufIdOrGRFOffset = cl.bufId;
3111     unsigned int& eltId = cl.eltId;
3112     unsigned int& size_in_bytes = cl.size;
3113     bool directBuf = false;
3114     bool statelessBuf = false;
3115     bool bindlessBuf = false;
3116 
3117     if (getConstantAddress(*inst, cl, ctx, directBuf, statelessBuf, bindlessBuf))
3118     {
3119         if (size_in_bytes)
3120         {
3121             if (modMD->immConstant.data.size() &&
3122                 ((statelessBuf && (bufIdOrGRFOffset == modMD->pushInfo.inlineConstantBufferGRFOffset))||
3123                 (directBuf && (bufIdOrGRFOffset == modMD->pushInfo.inlineConstantBufferSlot))))
3124             {
3125                 char* offset = &(modMD->immConstant.data[0]);
3126                 if (inst->getType()->isVectorTy())
3127                 {
3128                     Type* srcEltTy = cast<VectorType>(inst->getType())->getElementType();
3129                     uint32_t srcNElts = (uint32_t)cast<IGCLLVM::FixedVectorType>(inst->getType())->getNumElements();
3130                     uint32_t eltSize_in_bytes = (unsigned int)srcEltTy->getPrimitiveSizeInBits() / 8;
3131                     IRBuilder<> builder(inst);
3132                     Value* vectorValue = UndefValue::get(inst->getType());
3133                     char* pEltValue;        // Pointer to element value
3134                     for (uint i = 0; i < srcNElts; i++)
3135                     {
3136                         if (eltId < 0 || eltId >= (int)modMD->immConstant.data.size())
3137                         {
3138                             int OOBvalue = 0;       // OOB access to immediate constant buffer should return 0
3139                             char* pOOBvalue = (char*)& OOBvalue;    // Pointer to value 0 which is a OOB access value
3140                             pEltValue = pOOBvalue;
3141                         }
3142                         else
3143                             pEltValue = offset + eltId + (i * eltSize_in_bytes);
3144                         vectorValue = builder.CreateInsertElement(
3145                             vectorValue,
3146                             GetConstantValue(srcEltTy, pEltValue),
3147                             builder.getInt32(i));
3148                     }
3149                     return dyn_cast<Constant>(vectorValue);
3150                 }
3151                 else
3152                 {
3153                     char* pEltValue;        // Pointer to element value
3154                     if (eltId < 0 || eltId >= (int)modMD->immConstant.data.size())
3155                     {
3156                         int OOBvalue = 0;       // OOB access to immediate constant buffer should return 0
3157                         char* pOOBvalue = (char*)& OOBvalue;    // Pointer to value 0 which is a OOB access value
3158                         pEltValue = pOOBvalue;
3159                     }
3160                     else
3161                         pEltValue = offset + eltId;
3162                     return GetConstantValue(inst->getType(), pEltValue);
3163                 }
3164             }
3165         }
3166     }
3167     return nullptr;
3168 }
3169 
ConstantFoldCallInstruction(CallInst * inst)3170 Constant* IGCConstProp::ConstantFoldCallInstruction(CallInst* inst)
3171 {
3172     IGCConstantFolder constantFolder;
3173     Constant* C = nullptr;
3174     if (inst)
3175     {
3176         Constant* C0 = dyn_cast<Constant>(inst->getOperand(0));
3177         EOPCODE igcop = GetOpCode(inst);
3178 
3179         switch (igcop)
3180         {
3181         case llvm_gradientXfine:
3182         {
3183             if (C0)
3184             {
3185                 C = constantFolder.CreateGradientXFine(C0);
3186             }
3187         }
3188         break;
3189         case llvm_gradientYfine:
3190         {
3191             if (C0)
3192             {
3193                 C = constantFolder.CreateGradientYFine(C0);
3194             }
3195         }
3196         break;
3197         case llvm_gradientX:
3198         {
3199             if (C0)
3200             {
3201                 C = constantFolder.CreateGradientX(C0);
3202             }
3203         }
3204         break;
3205         case llvm_gradientY:
3206         {
3207             if (C0)
3208             {
3209                 C = constantFolder.CreateGradientY(C0);
3210             }
3211         }
3212         break;
3213         case llvm_rsq:
3214         {
3215             if (C0)
3216             {
3217                 C = constantFolder.CreateRsq(C0);
3218             }
3219         }
3220         break;
3221         case llvm_roundne:
3222         {
3223             if (C0)
3224             {
3225                 C = constantFolder.CreateRoundNE(C0);
3226             }
3227         }
3228         break;
3229         case llvm_fsat:
3230         {
3231             if (C0)
3232             {
3233                 C = constantFolder.CreateFSat(C0);
3234             }
3235         }
3236         break;
3237         case llvm_fptrunc_rte:
3238         {
3239             if (C0)
3240             {
3241                 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmNearestTiesToEven);
3242             }
3243         }
3244         break;
3245         case llvm_fptrunc_rtz:
3246         {
3247             if (C0)
3248             {
3249                 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmTowardZero);
3250             }
3251         }
3252         break;
3253         case llvm_fptrunc_rtp:
3254         {
3255             if (C0)
3256             {
3257                 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmTowardPositive);
3258             }
3259         }
3260         break;
3261         case llvm_fptrunc_rtn:
3262         {
3263             if (C0)
3264             {
3265                 C = constantFolder.CreateFPTrunc(C0, inst->getType(), llvm::APFloatBase::rmTowardNegative);
3266             }
3267         }
3268         break;
3269         case llvm_f32tof16_rtz:
3270         {
3271             if (C0)
3272             {
3273                 C = constantFolder.CreateFPTrunc(C0, Type::getHalfTy(inst->getContext()), llvm::APFloatBase::rmTowardZero);
3274                 C = constantFolder.CreateBitCast(C, Type::getInt16Ty(inst->getContext()));
3275                 C = constantFolder.CreateZExtOrBitCast(C, Type::getInt32Ty(inst->getContext()));
3276                 C = constantFolder.CreateBitCast(C, inst->getType());
3277             }
3278         }
3279         break;
3280         case llvm_fadd_rtz:
3281         {
3282             Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3283             if (C0 && C1)
3284             {
3285                 C = constantFolder.CreateFAdd(C0, C1, llvm::APFloatBase::rmTowardZero);
3286             }
3287         }
3288         break;
3289         case llvm_fmul_rtz:
3290         {
3291             Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3292             if (C0 && C1)
3293             {
3294                 C = constantFolder.CreateFMul(C0, C1, llvm::APFloatBase::rmTowardZero);
3295             }
3296         }
3297         break;
3298         case llvm_ubfe:
3299         {
3300             Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3301             Constant* C2 = dyn_cast<Constant>(inst->getOperand(2));
3302             if (C0 && C0->isZeroValue())
3303             {
3304                 C = llvm::ConstantInt::get(inst->getType(), 0);
3305             }
3306             else if (C0 && C1 && C2)
3307             {
3308                 C = constantFolder.CreateUbfe(C0, C1, C2);
3309             }
3310         }
3311         break;
3312         case llvm_ibfe:
3313         {
3314             Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3315             Constant* C2 = dyn_cast<Constant>(inst->getOperand(2));
3316             if (C0 && C0->isZeroValue())
3317             {
3318                 C = llvm::ConstantInt::get(inst->getType(), 0);
3319             }
3320             else if (C0 && C1 && C2)
3321             {
3322                 C = constantFolder.CreateIbfe(C0, C1, C2);
3323             }
3324         }
3325         break;
3326         case llvm_canonicalize:
3327         {
3328             // If the instruction should be emitted anyway, then remove the condition.
3329             // Please, be aware of the fact that clients can understand the term canonical FP value in other way.
3330             if (C0)
3331             {
3332                 CodeGenContext* pCodeGenContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
3333                 bool flushVal = pCodeGenContext->m_floatDenormMode16 == ::IGC::FLOAT_DENORM_FLUSH_TO_ZERO && inst->getType()->isHalfTy();
3334                 flushVal = flushVal || (pCodeGenContext->m_floatDenormMode32 == ::IGC::FLOAT_DENORM_FLUSH_TO_ZERO && inst->getType()->isFloatTy());
3335                 flushVal = flushVal || (pCodeGenContext->m_floatDenormMode64 == ::IGC::FLOAT_DENORM_FLUSH_TO_ZERO && inst->getType()->isDoubleTy());
3336                 C = constantFolder.CreateCanonicalize(C0, flushVal);
3337             }
3338         }
3339         break;
3340         case llvm_fbh:
3341         {
3342             if (C0)
3343             {
3344                 C = constantFolder.CreateFirstBitHi(C0);
3345             }
3346         }
3347         break;
3348         case llvm_fbh_shi:
3349         {
3350             if (C0)
3351             {
3352                 C = constantFolder.CreateFirstBitShi(C0);
3353             }
3354         }
3355         break;
3356         case llvm_fbl:
3357         {
3358             if (C0)
3359             {
3360                 C = constantFolder.CreateFirstBitLo(C0);
3361             }
3362         }
3363         break;
3364         case llvm_bfrev:
3365         {
3366             if (C0)
3367             {
3368                 C = constantFolder.CreateBfrev(C0);
3369             }
3370         }
3371         break;
3372         case llvm_bfi:
3373         {
3374             Constant* C1 = dyn_cast<Constant>(inst->getOperand(1));
3375             Constant* C2 = dyn_cast<Constant>(inst->getOperand(2));
3376             Constant* C3 = dyn_cast<Constant>(inst->getOperand(3));
3377             if (C0 && C0->isZeroValue() && C3)
3378             {
3379                 C = C3;
3380             }
3381             else if (C0 && C1 && C2 && C3)
3382             {
3383                 C = constantFolder.CreateBfi(C0, C1, C2, C3);
3384             }
3385         }
3386         break;
3387         default:
3388             break;
3389         }
3390     }
3391     return C;
3392 }
3393 
3394 // constant fold the following code for any index:
3395 //
3396 // %95 = extractelement <4 x i16> <i16 3, i16 16, i16 21, i16 39>, i32 %94
3397 // %96 = icmp eq i16 %95, 0
3398 //
ConstantFoldCmpInst(CmpInst * CI)3399 Constant* IGCConstProp::ConstantFoldCmpInst(CmpInst* CI)
3400 {
3401     // Only handle scalar result.
3402     if (CI->getType()->isVectorTy())
3403         return nullptr;
3404 
3405     Value* LHS = CI->getOperand(0);
3406     Value* RHS = CI->getOperand(1);
3407     if (!isa<Constant>(RHS) && CI->isCommutative())
3408         std::swap(LHS, RHS);
3409     if (!isa<ConstantInt>(RHS) && !isa<ConstantFP>(RHS))
3410         return nullptr;
3411 
3412     auto EEI = dyn_cast<ExtractElementInst>(LHS);
3413     if (EEI && isa<Constant>(EEI->getVectorOperand()))
3414     {
3415         bool AllTrue = true, AllFalse = true;
3416         auto VecOpnd = cast<Constant>(EEI->getVectorOperand());
3417         unsigned N = (unsigned)cast<IGCLLVM::FixedVectorType>(VecOpnd->getType())->getNumElements();
3418         for (unsigned i = 0; i < N; ++i)
3419         {
3420             Constant* const Opnd = VecOpnd->getAggregateElement(i);
3421             IGC_ASSERT_MESSAGE(nullptr != Opnd, "null entry");
3422 
3423             if (isa<UndefValue>(Opnd))
3424                 continue;
3425             Constant* Result = ConstantFoldCompareInstOperands(
3426                 CI->getPredicate(), Opnd, cast<Constant>(RHS), CI->getFunction()->getParent()->getDataLayout());
3427             if (!Result->isAllOnesValue())
3428                 AllTrue = false;
3429             if (!Result->isNullValue())
3430                 AllFalse = false;
3431         }
3432 
3433         if (AllTrue)
3434         {
3435             return ConstantInt::getTrue(CI->getType());
3436         }
3437         else if (AllFalse)
3438         {
3439             return ConstantInt::getFalse(CI->getType());
3440         }
3441     }
3442 
3443     return nullptr;
3444 }
3445 
3446 // constant fold the following code for any index:
3447 //
3448 // %93 = insertelement  <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, float %v7.w_, i32 0
3449 // %95 = extractelement <4 x float> %93, i32 1
3450 //
3451 // constant fold the selection of the same value in a vector component, e.g.:
3452 // %Temp - 119.i.i.v.v = select i1 %Temp - 118.i.i, <2 x i32> <i32 0, i32 17>, <2 x i32> <i32 4, i32 17>
3453 // %scalar9 = extractelement <2 x i32> %Temp - 119.i.i.v.v, i32 1
3454 //
ConstantFoldExtractElement(ExtractElementInst * EEI)3455 Constant* IGCConstProp::ConstantFoldExtractElement(ExtractElementInst* EEI)
3456 {
3457 
3458     Constant* EltIdx = dyn_cast<Constant>(EEI->getIndexOperand());
3459     if (EltIdx)
3460     {
3461         if (InsertElementInst * IEI = dyn_cast<InsertElementInst>(EEI->getVectorOperand()))
3462         {
3463             Constant* InsertIdx = dyn_cast<Constant>(IEI->getOperand(2));
3464             // try to find the constant from a chain of InsertElement
3465             while (IEI && InsertIdx)
3466             {
3467                 if (InsertIdx == EltIdx)
3468                 {
3469                     Constant* EltVal = dyn_cast<Constant>(IEI->getOperand(1));
3470                     return EltVal;
3471                 }
3472                 else
3473                 {
3474                     Value* Vec = IEI->getOperand(0);
3475                     if (isa<ConstantDataVector>(Vec))
3476                     {
3477                         ConstantDataVector* CVec = cast<ConstantDataVector>(Vec);
3478                         return CVec->getAggregateElement(EltIdx);
3479                     }
3480                     else if (isa<InsertElementInst>(Vec))
3481                     {
3482                         IEI = cast<InsertElementInst>(Vec);
3483                         InsertIdx = dyn_cast<Constant>(IEI->getOperand(2));
3484                     }
3485                     else
3486                     {
3487                         break;
3488                     }
3489                 }
3490             }
3491         }
3492         else if (SelectInst * sel = dyn_cast<SelectInst>(EEI->getVectorOperand()))
3493         {
3494             Value* vec0 = sel->getOperand(1);
3495             Value* vec1 = sel->getOperand(2);
3496 
3497             IGC_ASSERT(vec0->getType() == vec1->getType());
3498 
3499             if (isa<ConstantDataVector>(vec0) && isa<ConstantDataVector>(vec1))
3500             {
3501                 ConstantDataVector* cvec0 = cast<ConstantDataVector>(vec0);
3502                 ConstantDataVector* cvec1 = cast<ConstantDataVector>(vec1);
3503                 Constant* cval0 = cvec0->getAggregateElement(EltIdx);
3504                 Constant* cval1 = cvec1->getAggregateElement(EltIdx);
3505 
3506                 if (cval0 == cval1)
3507                 {
3508                     return cval0;
3509                 }
3510             }
3511         }
3512     }
3513     return nullptr;
3514 }
3515 
3516 //  simplifyAdd() push any constants to the top of a sequence of Add instructions,
3517 //  which makes CSE/GVN to do a better job.  For example,
3518 //      (((A + 8) + B) + C) + 15
3519 //  will become
3520 //      ((A + B) + C) + 23
3521 //  Note that the order of non-constant values remain unchanged throughout this
3522 //  transformation.
//  (This was added to remove redundant loads. If a future LLVM does a better
//  job on this (reassociation), we should use LLVM's instead.)
simplifyAdd(BinaryOperator * BO)3526 bool IGCConstProp::simplifyAdd(BinaryOperator* BO)
3527 {
3528     // Only handle Add
3529     if (BO->getOpcode() != Instruction::Add)
3530     {
3531         return false;
3532     }
3533 
3534     Value* LHS = BO->getOperand(0);
3535     Value* RHS = BO->getOperand(1);
3536     bool changed = false;
3537     if (BinaryOperator * LBO = dyn_cast<BinaryOperator>(LHS))
3538     {
3539         if (simplifyAdd(LBO))
3540         {
3541             changed = true;
3542         }
3543     }
3544     if (BinaryOperator * RBO = dyn_cast<BinaryOperator>(RHS))
3545     {
3546         if (simplifyAdd(RBO))
3547         {
3548             changed = true;
3549         }
3550     }
3551 
3552     // Refresh LHS and RHS
3553     LHS = BO->getOperand(0);
3554     RHS = BO->getOperand(1);
3555     BinaryOperator* LHSbo = dyn_cast<BinaryOperator>(LHS);
3556     BinaryOperator* RHSbo = dyn_cast<BinaryOperator>(RHS);
3557     bool isLHSAdd = LHSbo && LHSbo->getOpcode() == Instruction::Add;
3558     bool isRHSAdd = RHSbo && RHSbo->getOpcode() == Instruction::Add;
3559     IRBuilder<> Builder(BO);
3560     if (isLHSAdd && isRHSAdd)
3561     {
3562         Value* A = LHSbo->getOperand(0);
3563         Value* B = LHSbo->getOperand(1);
3564         Value* C = RHSbo->getOperand(0);
3565         Value* D = RHSbo->getOperand(1);
3566 
3567         ConstantInt* C0 = dyn_cast<ConstantInt>(B);
3568         ConstantInt* C1 = dyn_cast<ConstantInt>(D);
3569 
3570         if (C0 || C1)
3571         {
3572             Value* R = nullptr;
3573             if (C0 && C1)
3574             {
3575                 Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
3576                     C0, C1, *m_TD);
3577                 R = Builder.CreateAdd(A, C);
3578                 R = Builder.CreateAdd(R, newC);
3579             }
3580             else if (C0)
3581             {
3582                 R = Builder.CreateAdd(A, RHS);
3583                 R = Builder.CreateAdd(R, B);
3584             }
3585             else
3586             {   // C1 is not nullptr
3587                 R = Builder.CreateAdd(LHS, C);
3588                 R = Builder.CreateAdd(R, D);
3589             }
3590             BO->replaceAllUsesWith(R);
3591             return true;
3592         }
3593     }
3594     else if (isLHSAdd)
3595     {
3596         Value* A = LHSbo->getOperand(0);
3597         Value* B = LHSbo->getOperand(1);
3598         Value* C = RHS;
3599 
3600         ConstantInt* C0 = dyn_cast<ConstantInt>(B);
3601         ConstantInt* C1 = dyn_cast<ConstantInt>(C);
3602 
3603         if (C0 && C1)
3604         {
3605             Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
3606                 C0, C1, *m_TD);
3607             Value* R = Builder.CreateAdd(A, newC);
3608             BO->replaceAllUsesWith(R);
3609             return true;
3610         }
3611         if (C0)
3612         {
3613             Value* R = Builder.CreateAdd(A, C);
3614             R = Builder.CreateAdd(R, B);
3615             BO->replaceAllUsesWith(R);
3616             return true;
3617         }
3618     }
3619     else if (isRHSAdd)
3620     {
3621         Value* A = LHS;
3622         Value* B = RHSbo->getOperand(0);
3623         Value* C = RHSbo->getOperand(1);
3624 
3625         ConstantInt* C0 = dyn_cast<ConstantInt>(A);
3626         ConstantInt* C1 = dyn_cast<ConstantInt>(C);
3627 
3628         if (C0 && C1)
3629         {
3630             Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
3631                 C0, C1, *m_TD);
3632             Value* R = Builder.CreateAdd(B, newC);
3633             BO->replaceAllUsesWith(R);
3634             return true;
3635         }
3636         if (C0)
3637         {
3638             Value* R = Builder.CreateAdd(RHS, A);
3639             BO->replaceAllUsesWith(R);
3640             return true;
3641         }
3642         if (C1)
3643         {
3644             Value* R = Builder.CreateAdd(A, B);
3645             R = Builder.CreateAdd(R, C);
3646             BO->replaceAllUsesWith(R);
3647             return true;
3648         }
3649     }
3650     else
3651     {
3652         if (ConstantInt * CLHS = dyn_cast<ConstantInt>(LHS))
3653         {
3654             if (ConstantInt * CRHS = dyn_cast<ConstantInt>(RHS))
3655             {
3656                 Value* newC = ConstantFoldBinaryOpOperands(Instruction::Add,
3657                     CLHS, CRHS, *m_TD);
3658                 BO->replaceAllUsesWith(newC);
3659                 return true;
3660             }
3661 
3662             // Constant is kept as RHS
3663             Value* R = Builder.CreateAdd(RHS, LHS);
3664             BO->replaceAllUsesWith(R);
3665             return true;
3666         }
3667     }
3668     return changed;
3669 }
3670 
simplifyGEP(GetElementPtrInst * GEP)3671 bool IGCConstProp::simplifyGEP(GetElementPtrInst* GEP)
3672 {
3673     bool changed = false;
3674     for (int i = 0; i < (int)GEP->getNumIndices(); ++i)
3675     {
3676         Value* Index = GEP->getOperand(i + 1);
3677         BinaryOperator* BO = dyn_cast<BinaryOperator>(Index);
3678         if (!BO || BO->getOpcode() != Instruction::Add)
3679         {
3680             continue;
3681         }
3682         if (simplifyAdd(BO))
3683         {
3684             changed = true;
3685         }
3686     }
3687     return changed;
3688 }
3689 
3690 /**
3691 * the following code is essentially a copy of llvm copy-prop code with one little
3692 * addition for shader-constant replacement.
3693 *
3694 * we don't have to do this if llvm version uses a virtual function in place of calling
3695 * ConstantFoldInstruction.
3696 */
runOnFunction(Function & F)3697 bool IGCConstProp::runOnFunction(Function& F)
3698 {
3699     module = F.getParent();
3700     // Initialize the worklist to all of the instructions ready to process...
3701     llvm::SetVector<Instruction*> WorkList;
3702     for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i)
3703     {
3704         WorkList.insert(&*i);
3705     }
3706     bool Changed = false;
3707     m_TD = &F.getParent()->getDataLayout();
3708     m_TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
3709     while (!WorkList.empty())
3710     {
3711         Instruction* I = *WorkList.rbegin();
3712         WorkList.remove(I);    // Get an element from the worklist...
3713         if (I->use_empty())                  // Don't muck with dead instructions...
3714         {
3715             continue;
3716         }
3717         Constant* C = nullptr;
3718         C = ConstantFoldInstruction(I, *m_TD, m_TLI);
3719 
3720         if (!C && isa<CallInst>(I))
3721         {
3722             C = ConstantFoldCallInstruction(cast<CallInst>(I));
3723         }
3724 
3725         // replace shader-constant load with the known value
3726         if (!C && isa<LoadInst>(I))
3727         {
3728             C = replaceShaderConstant(cast<LoadInst>(I));
3729         }
3730         if (!C && isa<CmpInst>(I))
3731         {
3732             C = ConstantFoldCmpInst(cast<CmpInst>(I));
3733         }
3734         if (!C && isa<ExtractElementInst>(I))
3735         {
3736             C = ConstantFoldExtractElement(cast<ExtractElementInst>(I));
3737         }
3738         if (C)
3739         {
3740             // Add all of the users of this instruction to the worklist, they might
3741             // be constant propagatable now...
3742             for (Value::user_iterator UI = I->user_begin(), UE = I->user_end();
3743                 UI != UE; ++UI)
3744             {
3745                 WorkList.insert(cast<Instruction>(*UI));
3746             }
3747 
3748             // Replace all of the uses of a variable with uses of the constant.
3749             I->replaceAllUsesWith(C);
3750 
3751             if (0 /* isa<ConstantPointerNull>(C)*/) // disable optimization generating invalid IR until it gets re-written
3752             {
3753                 // if we are changing function calls/ genisa intrinsics, then we need
3754                 // to fix the function declarations to account for the change in pointer address type
3755                 for (Value::user_iterator UI = C->user_begin(), UE = C->user_end();
3756                     UI != UE; ++UI)
3757                 {
3758                     if (GenIntrinsicInst * genIntr = dyn_cast<GenIntrinsicInst>(*UI))
3759                     {
3760                         GenISAIntrinsic::ID ID = genIntr->getIntrinsicID();
3761                         if (ID == GenISAIntrinsic::GenISA_storerawvector_indexed)
3762                         {
3763                             llvm::Type* tys[2];
3764                             tys[0] = genIntr->getOperand(0)->getType();
3765                             tys[1] = genIntr->getOperand(2)->getType();
3766                             GenISAIntrinsic::getDeclaration(F.getParent(),
3767                                 llvm::GenISAIntrinsic::GenISA_storerawvector_indexed,
3768                                 tys);
3769                         }
3770                         else if (ID == GenISAIntrinsic::GenISA_storeraw_indexed)
3771                         {
3772                             llvm::Type* types[2] = {
3773                                 genIntr->getOperand(0)->getType(),
3774                                 genIntr->getOperand(1)->getType() };
3775 
3776                             GenISAIntrinsic::getDeclaration(F.getParent(),
3777                                 llvm::GenISAIntrinsic::GenISA_storeraw_indexed,
3778                                 types);
3779                         }
3780                         else if (ID == GenISAIntrinsic::GenISA_ldrawvector_indexed || ID == GenISAIntrinsic::GenISA_ldraw_indexed)
3781                         {
3782                             llvm::Type* tys[2];
3783                             tys[0] = genIntr->getType();
3784                             tys[1] = genIntr->getOperand(0)->getType();
3785                             GenISAIntrinsic::getDeclaration(F.getParent(),
3786                                 ID,
3787                                 tys);
3788                         }
3789                     }
3790                 }
3791             }
3792 
3793             // Remove the dead instruction.
3794             I->eraseFromParent();
3795 
3796             // We made a change to the function...
3797             Changed = true;
3798 
3799             // I is erased, continue to the next one.
3800             continue;
3801         }
3802 
3803         if (GetElementPtrInst * GEP = dyn_cast<GetElementPtrInst>(I))
3804         {
3805             if (m_enableSimplifyGEP && simplifyGEP(GEP))
3806             {
3807                 Changed = true;
3808             }
3809         }
3810     }
3811     return Changed;
3812 }
3813 
3814 namespace {
3815 
3816     class IGCIndirectICBPropagaion : public FunctionPass
3817     {
3818     public:
3819         static char ID;
IGCIndirectICBPropagaion()3820         IGCIndirectICBPropagaion() : FunctionPass(ID)
3821         {
3822             initializeIGCIndirectICBPropagaionPass(*PassRegistry::getPassRegistry());
3823         }
getPassName() const3824         virtual llvm::StringRef getPassName() const { return "Indirect ICB Propagation"; }
3825         virtual bool runOnFunction(Function& F);
getAnalysisUsage(llvm::AnalysisUsage & AU) const3826         virtual void getAnalysisUsage(llvm::AnalysisUsage& AU) const
3827         {
3828             AU.setPreservesCFG();
3829             AU.addRequired<CodeGenContextWrapper>();
3830         }
3831     private:
3832         bool isICBOffseted(llvm::LoadInst* inst, uint offset);
3833     };
3834 
3835 } // namespace
3836 
3837 char IGCIndirectICBPropagaion::ID = 0;
createIGCIndirectICBPropagaionPass()3838 FunctionPass* IGC::createIGCIndirectICBPropagaionPass() { return new IGCIndirectICBPropagaion(); }
3839 
runOnFunction(Function & F)3840 bool IGCIndirectICBPropagaion::runOnFunction(Function& F)
3841 {
3842     CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
3843     ModuleMetaData* modMD = ctx->getModuleMetaData();
3844 
3845     //MaxImmConstantSizePushed = 256 by default. For float values, it will contains 64 numbers, and stored in 8 GRF
3846     if (modMD &&
3847         modMD->immConstant.data.size() &&
3848         modMD->immConstant.data.size() <= IGC_GET_FLAG_VALUE(MaxImmConstantSizePushed))
3849     {
3850         uint maxImmConstantSizePushed = modMD->immConstant.data.size();
3851         char* offset = &(modMD->immConstant.data[0]);
3852         IRBuilder<> m_builder(F.getContext());
3853 
3854         for (auto& BB : F)
3855         {
3856             for (auto BI = BB.begin(), BE = BB.end(); BI != BE;)
3857             {
3858                 if (llvm::LoadInst * inst = llvm::dyn_cast<llvm::LoadInst>(&(*BI++)))
3859                 {
3860                     unsigned as = inst->getPointerAddressSpace();
3861                     bool directBuf;
3862                     unsigned bufId;
3863                     BufferType bufType = IGC::DecodeAS4GFXResource(as, directBuf, bufId);
3864                     bool bICBNoOffset =
3865                         (IGC::INVALID_CONSTANT_BUFFER_INVALID_ADDR == modMD->pushInfo.inlineConstantBufferOffset && bufType == CONSTANT_BUFFER && directBuf && bufId == modMD->pushInfo.inlineConstantBufferSlot);
3866                     bool bICBOffseted =
3867                         (IGC::INVALID_CONSTANT_BUFFER_INVALID_ADDR != modMD->pushInfo.inlineConstantBufferOffset && ADDRESS_SPACE_CONSTANT == as && isICBOffseted(inst, modMD->pushInfo.inlineConstantBufferOffset));
3868                     if (bICBNoOffset || bICBOffseted)
3869                     {
3870                         Value* ptrVal = inst->getPointerOperand();
3871                         Value* eltPtr = nullptr;
3872                         Value* eltIdx = nullptr;
3873                         if (IntToPtrInst * i2p = dyn_cast<IntToPtrInst>(ptrVal))
3874                         {
3875                             eltPtr = i2p->getOperand(0);
3876                         }
3877                         else if (GetElementPtrInst * gep = dyn_cast<GetElementPtrInst>(ptrVal))
3878                         {
3879                             if (gep->getNumOperands() != 3)
3880                             {
3881                                 continue;
3882                             }
3883 
3884                             Type* eleType = gep->getPointerOperandType()->getPointerElementType();
3885                             if (!eleType->isArrayTy() ||
3886                                 !(eleType->getArrayElementType()->isFloatTy() || eleType->getArrayElementType()->isIntegerTy(32)))
3887                             {
3888                                 continue;
3889                             }
3890 
3891                             eltIdx = gep->getOperand(2);
3892                         }
3893                         else
3894                         {
3895                             continue;
3896                         }
3897 
3898                         m_builder.SetInsertPoint(inst);
3899 
3900                         unsigned int size_in_bytes = (unsigned int)inst->getType()->getPrimitiveSizeInBits() / 8;
3901                         if (size_in_bytes)
3902                         {
3903                             Value* ICBbuffer = UndefValue::get(IGCLLVM::FixedVectorType::get(inst->getType(), maxImmConstantSizePushed / size_in_bytes));
3904                             if (inst->getType()->isFloatTy())
3905                             {
3906                                 float returnConstant = 0;
3907                                 for (unsigned int i = 0; i < maxImmConstantSizePushed; i += size_in_bytes)
3908                                 {
3909                                     memcpy_s(&returnConstant, size_in_bytes, offset + i, size_in_bytes);
3910                                     Value* fp = ConstantFP::get(inst->getType(), returnConstant);
3911                                     ICBbuffer = m_builder.CreateInsertElement(ICBbuffer, fp, m_builder.getInt32(i / size_in_bytes));
3912                                 }
3913 
3914                                 if (eltPtr)
3915                                 {
3916                                     eltIdx = m_builder.CreateLShr(eltPtr, m_builder.getInt32(2));
3917                                 }
3918                                 Value* ICBvalue = m_builder.CreateExtractElement(ICBbuffer, eltIdx);
3919                                 inst->replaceAllUsesWith(ICBvalue);
3920                             }
3921                             else if (inst->getType()->isIntegerTy(32))
3922                             {
3923                                 int returnConstant = 0;
3924                                 for (unsigned int i = 0; i < maxImmConstantSizePushed; i += size_in_bytes)
3925                                 {
3926                                     memcpy_s(&returnConstant, size_in_bytes, offset + i, size_in_bytes);
3927                                     Value* fp = ConstantInt::get(inst->getType(), returnConstant);
3928                                     ICBbuffer = m_builder.CreateInsertElement(ICBbuffer, fp, m_builder.getInt32(i / size_in_bytes));
3929                                 }
3930                                 if (eltPtr)
3931                                 {
3932                                     eltIdx = m_builder.CreateLShr(eltPtr, m_builder.getInt32(2));
3933                                 }
3934                                 Value* ICBvalue = m_builder.CreateExtractElement(ICBbuffer, eltIdx);
3935                                 inst->replaceAllUsesWith(ICBvalue);
3936                             }
3937                         }
3938                     }
3939                 }
3940             }
3941         }
3942     }
3943 
3944     return false;
3945 }
3946 
isICBOffseted(llvm::LoadInst * inst,uint offset)3947 bool IGCIndirectICBPropagaion::isICBOffseted(llvm::LoadInst* inst, uint offset) {
3948     Value* ptrVal = inst->getPointerOperand();
3949     std::vector<Value*> srcInstList;
3950     IGC::TracePointerSource(ptrVal, false, true, true, srcInstList);
3951     if (srcInstList.size())
3952     {
3953         CallInst* inst = dyn_cast<CallInst>(srcInstList.back());
3954         GenIntrinsicInst* genIntr = inst ? dyn_cast<GenIntrinsicInst>(inst) : nullptr;
3955         if (!genIntr || (genIntr->getIntrinsicID() != GenISAIntrinsic::GenISA_RuntimeValue))
3956             return false;
3957 
3958         llvm::ConstantInt* ci = dyn_cast<llvm::ConstantInt>(inst->getOperand(0));
3959         return ci && (uint)ci->getZExtValue() == offset;
3960     }
3961 
3962     return false;
3963 }
3964 
3965 IGC_INITIALIZE_PASS_BEGIN(IGCIndirectICBPropagaion, "IGCIndirectICBPropagaion",
3966     "IGCIndirectICBPropagaion", false, false)
3967     IGC_INITIALIZE_PASS_END(IGCIndirectICBPropagaion, "IGCIndirectICBPropagaion",
3968         "IGCIndirectICBPropagaion", false, false)
3969 
3970     namespace {
3971     class NanHandling : public FunctionPass, public llvm::InstVisitor<NanHandling>
3972     {
3973     public:
3974         static char ID;
NanHandling()3975         NanHandling() : FunctionPass(ID)
3976         {
3977             initializeNanHandlingPass(*PassRegistry::getPassRegistry());
3978         }
3979 
getAnalysisUsage(llvm::AnalysisUsage & AU) const3980         void getAnalysisUsage(llvm::AnalysisUsage& AU) const
3981         {
3982             AU.setPreservesCFG();
3983             AU.addRequired<LoopInfoWrapperPass>();
3984         }
3985 
getPassName() const3986         virtual llvm::StringRef getPassName() const { return "NAN handling"; }
3987         virtual bool runOnFunction(llvm::Function& F);
3988         void visitBranchInst(llvm::BranchInst& I);
3989         void loopNanCases(Function& F);
3990 
3991     private:
3992         int longestPathInstCount(llvm::BasicBlock* BB, int& depth);
3993         void swapBranch(llvm::Instruction* inst, llvm::BranchInst& BI);
3994         SmallVector<llvm::BranchInst*, 10> visitedInst;
3995     };
3996 } // namespace
3997 
3998 char NanHandling::ID = 0;
createNanHandlingPass()3999 FunctionPass* IGC::createNanHandlingPass() { return new NanHandling(); }
4000 
runOnFunction(Function & F)4001 bool NanHandling::runOnFunction(Function& F)
4002 {
4003     loopNanCases(F);
4004     visit(F);
4005     return true;
4006 }
4007 
loopNanCases(Function & F)4008 void NanHandling::loopNanCases(Function& F)
4009 {
4010     // take care of loop cases
4011     visitedInst.clear();
4012     llvm::LoopInfo* LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
4013     if (LI && !LI->empty())
4014     {
4015         FastMathFlags FMF;
4016         FMF.clear();
4017         for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
4018         {
4019             Loop* loop = *I;
4020             BranchInst* br = cast<BranchInst>(loop->getLoopLatch()->getTerminator());
4021             BasicBlock* header = loop->getHeader();
4022             if (br && br->isConditional() && header)
4023             {
4024                 visitedInst.push_back(br);
4025                 if (FCmpInst * brCmpInst = dyn_cast<FCmpInst>(br->getCondition()))
4026                 {
4027                     FPMathOperator* FPO = dyn_cast<FPMathOperator>(brCmpInst);
4028                     if (!FPO || !FPO->isFast())
4029                     {
4030                         continue;
4031                     }
4032                     if (br->getSuccessor(1) == header)
4033                     {
4034                         swapBranch(brCmpInst, *br);
4035                     }
4036                 }
4037                 else if (BinaryOperator * andOrInst = dyn_cast<BinaryOperator>(br->getCondition()))
4038                 {
4039                     if (andOrInst->getOpcode() != BinaryOperator::And &&
4040                         andOrInst->getOpcode() != BinaryOperator::Or)
4041                     {
4042                         continue;
4043                     }
4044                     FCmpInst* brCmpInst0 = dyn_cast<FCmpInst>(andOrInst->getOperand(0));
4045                     FCmpInst* brCmpInst1 = dyn_cast<FCmpInst>(andOrInst->getOperand(1));
4046                     if (!brCmpInst0 || !brCmpInst1)
4047                     {
4048                         continue;
4049                     }
4050                     if (br->getSuccessor(1) == header)
4051                     {
4052                         brCmpInst0->copyFastMathFlags(FMF);
4053                         brCmpInst1->copyFastMathFlags(FMF);
4054                     }
4055                 }
4056             }
4057         }
4058     }
4059 }
4060 
longestPathInstCount(llvm::BasicBlock * BB,int & depth)4061 int NanHandling::longestPathInstCount(llvm::BasicBlock* BB, int& depth)
4062 {
4063 #define MAX_SEARCH_DEPTH 10
4064 
4065     depth++;
4066     if (!BB || depth > MAX_SEARCH_DEPTH)
4067         return 0;
4068 
4069     int sumSuccInstCount = 0;
4070     for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI)
4071     {
4072         sumSuccInstCount += longestPathInstCount(*SI, depth);
4073     }
4074     return (int)(BB->getInstList().size()) + sumSuccInstCount;
4075 }
4076 
swapBranch(llvm::Instruction * inst,llvm::BranchInst & BI)4077 void NanHandling::swapBranch(llvm::Instruction* inst, llvm::BranchInst& BI)
4078 {
4079     if (FCmpInst * brCondition = dyn_cast<FCmpInst>(inst))
4080     {
4081         if (inst->hasOneUse())
4082         {
4083             brCondition->setPredicate(FCmpInst::getInversePredicate(brCondition->getPredicate()));
4084             BI.swapSuccessors();
4085         }
4086     }
4087     else
4088     {
4089         // inst not expected
4090         IGC_ASSERT(0);
4091     }
4092 }
4093 
visitBranchInst(llvm::BranchInst & I)4094 void NanHandling::visitBranchInst(llvm::BranchInst& I)
4095 {
4096     if (!I.isConditional())
4097         return;
4098 
4099     // if this branch is part of a loop, it is taken care of already in loopNanCases
4100     for (auto iter = visitedInst.begin(); iter != visitedInst.end(); iter++)
4101     {
4102         if (&I == *iter)
4103             return;
4104     }
4105 
4106     FCmpInst* brCmpInst = dyn_cast<FCmpInst>(I.getCondition());
4107     FCmpInst* src0 = nullptr;
4108     FCmpInst* src1 = nullptr;
4109 
4110     // if the branching is based on a cmp instruction
4111     if (brCmpInst)
4112     {
4113         FPMathOperator* FPO = dyn_cast<FPMathOperator>(brCmpInst);
4114         if (!FPO || !FPO->isFast())
4115             return;
4116 
4117         if (!brCmpInst->hasOneUse())
4118             return;
4119     }
4120     // if the branching is based on a and/or from multiple conditions.
4121     else if (BinaryOperator * andOrInst = dyn_cast<BinaryOperator>(I.getCondition()))
4122     {
4123         if (andOrInst->getOpcode() != BinaryOperator::And && andOrInst->getOpcode() != BinaryOperator::Or)
4124             return;
4125 
4126         src0 = dyn_cast<FCmpInst>(andOrInst->getOperand(0));
4127         src1 = dyn_cast<FCmpInst>(andOrInst->getOperand(1));
4128 
4129         if (!src0 || !src1)
4130             return;
4131     }
4132     else
4133     {
4134         return;
4135     }
4136 
4137     // Calculate the maximum instruction count when going down one branch.
4138     // Make the false case (including NaN) goes to the shorter path.
4139     int depth = 0;
4140     int trueBranchSize = longestPathInstCount(I.getSuccessor(0), depth);
4141     depth = 0;
4142     int falseBranchSize = longestPathInstCount(I.getSuccessor(1), depth);
4143 
4144     if (falseBranchSize - trueBranchSize > (int)(IGC_GET_FLAG_VALUE(SetBranchSwapThreshold)))
4145     {
4146         if (brCmpInst)
4147         {
4148             // swap the condition and the successor blocks
4149             swapBranch(brCmpInst, I);
4150         }
4151         else
4152         {
4153             FastMathFlags FMF;
4154             FMF.clear();
4155             src0->copyFastMathFlags(FMF);
4156             src1->copyFastMathFlags(FMF);
4157         }
4158         return;
4159     }
4160 }
4161 IGC_INITIALIZE_PASS_BEGIN(NanHandling, "NanHandling", "NanHandling", false, false)
4162 IGC_INITIALIZE_PASS_END(NanHandling, "NanHandling", "NanHandling", false, false)
4163 
4164 namespace IGC
4165 {
4166 
4167     class VectorBitCastOpt: public FunctionPass
4168     {
4169     public:
4170         static char ID;
4171 
VectorBitCastOpt()4172         VectorBitCastOpt() : FunctionPass(ID)
4173         {
4174             initializeVectorBitCastOptPass(*PassRegistry::getPassRegistry());
4175         };
~VectorBitCastOpt()4176         ~VectorBitCastOpt() {}
4177 
getPassName() const4178         virtual llvm::StringRef getPassName() const override
4179         {
4180             return "VectorBitCastOptPass";
4181         }
4182 
getAnalysisUsage(llvm::AnalysisUsage & AU) const4183         virtual void getAnalysisUsage(llvm::AnalysisUsage& AU) const override
4184         {
4185             AU.setPreservesCFG();
4186         }
4187 
4188         virtual bool runOnFunction(llvm::Function& F) override;
4189 
4190     private:
4191         // Transform (extractelement (bitcast %vector) ...) to
4192         // (bitcast (extractelement %vector) ...) in order to help coalescing
4193         // in DeSSA and enable memory operations simplification
4194         // in VectorPreProcess.
4195         bool optimizeVectorBitCast(Function& F) const;
4196     };
4197 
runOnFunction(Function & F)4198     bool VectorBitCastOpt::runOnFunction(Function& F)
4199     {
4200         bool Changed = optimizeVectorBitCast(F);
4201         return Changed;
4202     }
4203 
optimizeVectorBitCast(Function & F) const4204     bool VectorBitCastOpt::optimizeVectorBitCast(Function& F) const {
4205         IRBuilder<> Builder(F.getContext());
4206 
4207         bool Changed = false;
4208         for (auto& BB : F) {
4209             for (auto BI = BB.begin(), BE = BB.end(); BI != BE; /*EMPTY*/) {
4210                 BitCastInst* BC = dyn_cast<BitCastInst>(&*BI++);
4211                 if (!BC) continue;
4212                 // Skip non-element-wise bitcast.
4213                 IGCLLVM::FixedVectorType* DstVTy = dyn_cast<IGCLLVM::FixedVectorType>(BC->getType());
4214                 IGCLLVM::FixedVectorType* SrcVTy = dyn_cast<IGCLLVM::FixedVectorType>(BC->getOperand(0)->getType());
4215                 if (!DstVTy || !SrcVTy || DstVTy->getNumElements() != SrcVTy->getNumElements())
4216                     continue;
4217                 // Skip if it's not used only all extractelement.
4218                 bool ExactOnly = true;
4219                 for (auto User : BC->users()) {
4220                     if (isa<ExtractElementInst>(User)) continue;
4221                     ExactOnly = false;
4222                     break;
4223                 }
4224                 if (!ExactOnly)
4225                     continue;
4226                 // Autobots, transform and roll out!
4227                 Value* Src = BC->getOperand(0);
4228                 Type* DstEltTy = DstVTy->getElementType();
4229                 for (auto UI = BC->user_begin(), UE = BC->user_end(); UI != UE;
4230                     /*EMPTY*/) {
4231                     auto EEI = cast<ExtractElementInst>(*UI++);
4232                     Builder.SetInsertPoint(EEI);
4233                     auto NewVal = Builder.CreateExtractElement(Src, EEI->getIndexOperand());
4234                     NewVal = Builder.CreateBitCast(NewVal, DstEltTy);
4235                     EEI->replaceAllUsesWith(NewVal);
4236                     EEI->eraseFromParent();
4237                 }
4238                 BI = BC->eraseFromParent();
4239                 Changed = true;
4240             }
4241         }
4242 
4243         return Changed;
4244     }
4245 
4246     char VectorBitCastOpt::ID = 0;
createVectorBitCastOptPass()4247     FunctionPass* createVectorBitCastOptPass() { return new VectorBitCastOpt(); }
4248 
4249 } // namespace IGC
4250 
4251 #define VECTOR_BITCAST_OPT_PASS_FLAG "igc-vector-bitcast-opt"
4252 #define VECTOR_BITCAST_OPT_PASS_DESCRIPTION "Preprocess vector bitcasts to be after extractelement instructions."
4253 #define VECTOR_BITCAST_OPT_PASS_CFG_ONLY true
4254 #define VECTOR_BITCAST_OPT_PASS_ANALYSIS false
4255 IGC_INITIALIZE_PASS_BEGIN(VectorBitCastOpt, VECTOR_BITCAST_OPT_PASS_FLAG, VECTOR_BITCAST_OPT_PASS_DESCRIPTION, VECTOR_BITCAST_OPT_PASS_CFG_ONLY, VECTOR_BITCAST_OPT_PASS_ANALYSIS)
4256 IGC_INITIALIZE_PASS_END(VectorBitCastOpt, VECTOR_BITCAST_OPT_PASS_FLAG, VECTOR_BITCAST_OPT_PASS_DESCRIPTION, VECTOR_BITCAST_OPT_PASS_CFG_ONLY, VECTOR_BITCAST_OPT_PASS_ANALYSIS)
4257 
4258 namespace {
4259 
4260     class GenStrengthReduction : public FunctionPass
4261     {
4262     public:
4263         static char ID;
GenStrengthReduction()4264         GenStrengthReduction() : FunctionPass(ID)
4265         {
4266             initializeGenStrengthReductionPass(*PassRegistry::getPassRegistry());
4267         }
getPassName() const4268         virtual llvm::StringRef getPassName() const { return "Gen strength reduction"; }
4269         virtual bool runOnFunction(Function& F);
4270 
4271     private:
4272         bool processInst(Instruction* Inst);
4273     };
4274 
4275 } // namespace
4276 
4277 char GenStrengthReduction::ID = 0;
createGenStrengthReductionPass()4278 FunctionPass* IGC::createGenStrengthReductionPass() { return new GenStrengthReduction(); }
4279 
runOnFunction(Function & F)4280 bool GenStrengthReduction::runOnFunction(Function& F)
4281 {
4282     bool Changed = false;
4283     for (auto& BB : F)
4284     {
4285         for (auto BI = BB.begin(), BE = BB.end(); BI != BE;)
4286         {
4287             Instruction* Inst = &(*BI++);
4288             if (isInstructionTriviallyDead(Inst))
4289             {
4290                 Inst->eraseFromParent();
4291                 Changed = true;
4292                 continue;
4293             }
4294             Changed |= processInst(Inst);
4295         }
4296     }
4297     return Changed;
4298 }
4299 
4300 // Check if this is a fdiv that allows reciprocal, and its divident is not known
4301 // to be 1.0.
isCandidateFDiv(Instruction * Inst)4302 static bool isCandidateFDiv(Instruction* Inst)
4303 {
4304     // Only floating points, and no vectors.
4305     if (!Inst->getType()->isFloatingPointTy() || Inst->use_empty())
4306         return false;
4307 
4308     auto Op = dyn_cast<FPMathOperator>(Inst);
4309     if (Op && Op->getOpcode() == Instruction::FDiv && Op->hasAllowReciprocal())
4310     {
4311         Value* Src0 = Op->getOperand(0);
4312         if (auto CFP = dyn_cast<ConstantFP>(Src0))
4313             return !CFP->isExactlyValue(1.0);
4314         return true;
4315     }
4316     return false;
4317 }
4318 
processInst(Instruction * Inst)4319 bool GenStrengthReduction::processInst(Instruction* Inst)
4320 {
4321 
4322     unsigned opc = Inst->getOpcode();
4323     auto Op = dyn_cast<FPMathOperator>(Inst);
4324     if (opc == Instruction::Select)
4325     {
4326         Value* oprd1 = Inst->getOperand(1);
4327         Value* oprd2 = Inst->getOperand(2);
4328         ConstantFP* CF1 = dyn_cast<ConstantFP>(oprd1);
4329         ConstantFP* CF2 = dyn_cast<ConstantFP>(oprd2);
4330         if (oprd1 == oprd2 ||
4331             (CF1 && CF2 && CF1->isExactlyValue(CF2->getValueAPF())))
4332         {
4333             Inst->replaceAllUsesWith(oprd1);
4334             Inst->eraseFromParent();
4335             return true;
4336         }
4337     }
4338     if (Op &&
4339         Op->hasNoNaNs() &&
4340         Op->hasNoInfs() &&
4341         Op->hasNoSignedZeros())
4342     {
4343         switch (opc)
4344         {
4345         case Instruction::FDiv:
4346         {
4347             Value* Oprd0 = Inst->getOperand(0);
4348             if (ConstantFP * CF = dyn_cast<ConstantFP>(Oprd0))
4349             {
4350                 if (CF->isZero())
4351                 {
4352                     Inst->replaceAllUsesWith(Oprd0);
4353                     Inst->eraseFromParent();
4354                     return true;
4355                 }
4356             }
4357             break;
4358         }
4359         case Instruction::FMul:
4360         {
4361             for (int i = 0; i < 2; ++i)
4362             {
4363                 ConstantFP* CF = dyn_cast<ConstantFP>(Inst->getOperand(i));
4364                 if (CF && CF->isZero())
4365                 {
4366                     Inst->replaceAllUsesWith(CF);
4367                     Inst->eraseFromParent();
4368                     return true;
4369                 }
4370             }
4371             break;
4372         }
4373         case Instruction::FAdd:
4374         {
4375             for (int i = 0; i < 2; ++i)
4376             {
4377                 ConstantFP* CF = dyn_cast<ConstantFP>(Inst->getOperand(i));
4378                 if (CF && CF->isZero())
4379                 {
4380                     Value* otherOprd = Inst->getOperand(1 - i);
4381                     Inst->replaceAllUsesWith(otherOprd);
4382                     Inst->eraseFromParent();
4383                     return true;
4384                 }
4385             }
4386             break;
4387         }
4388         }
4389     }
4390 
4391     // fdiv -> inv + mul. On gen, fdiv seems always slower
4392     // than inv + mul. Do it if fdiv's fastMathFlag allows it.
4393     //
4394     // Rewrite
4395     // %1 = fdiv arcp float %x, %z
4396     // into
4397     // %1 = fdiv arcp float 1.0, %z
4398     // %2 = fmul arcp float %x, %1
4399     if (isCandidateFDiv(Inst))
4400     {
4401         Value* Src1 = Inst->getOperand(1);
4402         if (isa<Constant>(Src1))
4403         {
4404             // should not happen (but do see "fdiv  x / 0.0f"). Skip.
4405             return false;
4406         }
4407 
4408         Value* Src0 = ConstantFP::get(Inst->getType(), 1.0);
4409         Instruction* Inv = nullptr;
4410 
4411         // Check if there is any other (x / Src1). If so, commonize 1/Src1.
4412         for (auto UI = Src1->user_begin(), UE = Src1->user_end();
4413             UI != UE; ++UI)
4414         {
4415             Value* Val = *UI;
4416             Instruction* I = dyn_cast<Instruction>(Val);
4417             if (I && I != Inst && I->getOpcode() == Instruction::FDiv &&
4418                 I->getOperand(1) == Src1 && isCandidateFDiv(I))
4419             {
4420                 // special case
4421                 if (ConstantFP * CF = dyn_cast<ConstantFP>(I->getOperand(0)))
4422                 {
4423                     if (CF->isZero())
4424                     {
4425                         // Skip this one.
4426                         continue;
4427                     }
4428                 }
4429 
4430                 // Found another 1/Src1. Insert Inv right after the def of Src1
4431                 // or in the entry BB if Src1 is an argument.
4432                 if (!Inv)
4433                 {
4434                     Instruction* insertBefore = dyn_cast<Instruction>(Src1);
4435                     if (insertBefore)
4436                     {
4437                         if (isa<PHINode>(insertBefore))
4438                         {
4439                             BasicBlock* BB = insertBefore->getParent();
4440                             insertBefore = &(*BB->getFirstInsertionPt());
4441                         }
4442                         else
4443                         {
4444                             // Src1 is an instruction
4445                             BasicBlock::iterator iter(insertBefore);
4446                             ++iter;
4447                             insertBefore = &(*iter);
4448                         }
4449                     }
4450                     else
4451                     {
4452                         // Src1 is an argument and insert at the begin of entry BB
4453                         BasicBlock& entryBB = Inst->getParent()->getParent()->getEntryBlock();
4454                         insertBefore = &(*entryBB.getFirstInsertionPt());
4455                     }
4456                     Inv = BinaryOperator::CreateFDiv(Src0, Src1, "", insertBefore);
4457                     Inv->setFastMathFlags(Inst->getFastMathFlags());
4458                 }
4459 
4460                 Instruction* Mul = BinaryOperator::CreateFMul(I->getOperand(0), Inv, "", I);
4461                 Mul->setFastMathFlags(Inst->getFastMathFlags());
4462                 I->replaceAllUsesWith(Mul);
4463                 // Don't erase it as doing so would invalidate iterator in this func's caller
4464                 // Instead, erase it in the caller.
4465                 // I->eraseFromParent();
4466             }
4467         }
4468 
4469         if (!Inv)
4470         {
4471             // Only a single use of 1 / Src1. Create Inv right before the use.
4472             Inv = BinaryOperator::CreateFDiv(Src0, Src1, "", Inst);
4473             Inv->setFastMathFlags(Inst->getFastMathFlags());
4474         }
4475 
4476         auto Mul = BinaryOperator::CreateFMul(Inst->getOperand(0), Inv, "", Inst);
4477         Mul->setFastMathFlags(Inst->getFastMathFlags());
4478         Inst->replaceAllUsesWith(Mul);
4479         Inst->eraseFromParent();
4480         return true;
4481     }
4482 
4483     return false;
4484 }
4485 
// Pass-initialization boilerplate for GenStrengthReduction. Nothing is listed
// between BEGIN and END.
// NOTE(review): the two trailing `false` flags presumably follow LLVM's
// INITIALIZE_PASS convention (CFG-only, is-analysis) -- confirm against the
// IGC_INITIALIZE_PASS_* macro definitions.
IGC_INITIALIZE_PASS_BEGIN(GenStrengthReduction, "GenStrengthReduction",
    "GenStrengthReduction", false, false)
    IGC_INITIALIZE_PASS_END(GenStrengthReduction, "GenStrengthReduction",
        "GenStrengthReduction", false, false)
4490 
4491 
4492     /*========================== FlattenSmallSwitch ==============================
4493 
    This class flattens small switches into straight-line compare + select code. For example,
4495 
4496     before optimization:
4497         then153:
4498         switch i32 %115, label %else229 [
4499         i32 1, label %then214
4500         i32 2, label %then222
4501         i32 3, label %then222 ; duplicate blocks are fine
4502         ]
4503 
4504         then214:                                          ; preds = %then153
4505         %150 = fdiv float 1.000000e+00, %res_s208
4506         %151 = fmul float %147, %150
4507         br label %ifcont237
4508 
4509         then222:                                          ; preds = %then153
4510         %152 = fsub float 1.000000e+00, %141
4511         br label %ifcont237
4512 
4513         else229:                                          ; preds = %then153
4514         %res_s230 = icmp eq i32 %115, 3
4515         %. = select i1 %res_s230, float 1.000000e+00, float 0.000000e+00
4516         br label %ifcont237
4517 
4518         ifcont237:                                        ; preds = %else229, %then222, %then214
4519         %"r[9][0].x.0" = phi float [ %151, %then214 ], [ %152, %then222 ], [ %., %else229 ]
4520 
4521     after optimization:
4522         %res_s230 = icmp eq i32 %115, 3
4523         %. = select i1 %res_s230, float 1.000000e+00, float 0.000000e+00
4524         %150 = fdiv float 1.000000e+00, %res_s208
4525         %151 = fmul float %147, %150
4526         %152 = icmp eq i32 %115, 1
4527         %153 = select i1 %152, float %151, float %.
4528         %154 = fsub float 1.000000e+00, %141
4529         %155 = icmp eq i32 %115, 2
4530         %156 = select i1 %155, float %154, float %153
4531 
4532     =============================================================================*/
4533     namespace {
4534     class FlattenSmallSwitch : public FunctionPass
4535     {
4536     public:
4537         static char ID;
FlattenSmallSwitch()4538         FlattenSmallSwitch() : FunctionPass(ID)
4539         {
4540             initializeFlattenSmallSwitchPass(*PassRegistry::getPassRegistry());
4541         }
getPassName() const4542         virtual llvm::StringRef getPassName() const { return "Flatten Small Switch"; }
4543         virtual bool runOnFunction(Function& F);
4544         bool processSwitchInst(SwitchInst* SI);
4545     };
4546 
4547 } // namespace
4548 
// Storage for the pass identifier that the FlattenSmallSwitch constructor
// passes to FunctionPass.
char FlattenSmallSwitch::ID = 0;
// Factory: returns a newly allocated FlattenSmallSwitch pass.
// NOTE(review): ownership presumably transfers to the pass manager the result
// is added to -- confirm at the call site.
FunctionPass* IGC::createFlattenSmallSwitchPass() { return new FlattenSmallSwitch(); }
4551 
processSwitchInst(SwitchInst * SI)4552 bool FlattenSmallSwitch::processSwitchInst(SwitchInst* SI)
4553 {
4554     const unsigned maxSwitchCases = 3;  // only apply to switch with 3 cases or less
4555     const unsigned maxCaseInsts = 3;    // only apply optimization when each case has 3 instructions or less.
4556 
4557     BasicBlock* Default = SI->getDefaultDest();
4558     Value* Val = SI->getCondition();  // The value we are switching on...
4559     IRBuilder<> builder(SI);
4560 
4561     if (SI->getNumCases() > maxSwitchCases || SI->getNumCases() == 0)
4562     {
4563         return false;
4564     }
4565 
4566     // Dest will be the block that the control flow from the switch merges to.
4567     // Currently, there are two options:
4568     // 1. The Dest block is the default block from the switch
4569     // 2. The Dest block is jumped to by all of the switch cases (and the default)
4570     BasicBlock* Dest = nullptr;
4571     {
4572         const auto* CaseSucc =
4573 #if LLVM_VERSION_MAJOR == 4
4574             SI->case_begin().getCaseSuccessor();
4575 #elif LLVM_VERSION_MAJOR >= 7
4576             SI->case_begin()->getCaseSuccessor();
4577 #endif
4578         auto* BI = dyn_cast<BranchInst>(CaseSucc->getTerminator());
4579 
4580         if (BI == nullptr)
4581             return false;
4582 
4583         if (BI->isConditional())
4584             return false;
4585 
4586         // We know the first case jumps to this block.  Now let's
4587         // see below whether all the cases jump to this same block.
4588         Dest = BI->getSuccessor(0);
4589     }
4590 
4591     // Does BB unconditionally branch to MergeBlock?
4592     auto branchPattern = [](const BasicBlock* BB, const BasicBlock* MergeBlock)
4593     {
4594         auto* br = dyn_cast<BranchInst>(BB->getTerminator());
4595 
4596         if (br == nullptr)
4597             return false;
4598 
4599         if (br->isConditional())
4600             return false;
4601 
4602         if (br->getSuccessor(0) != MergeBlock)
4603             return false;
4604 
4605         if (!BB->getUniquePredecessor())
4606             return false;
4607 
4608         return true;
4609     };
4610 
4611     // We can speculatively execute a basic block if it
4612     // is small, unconditionally branches to Dest, and doesn't
4613     // have high latency or unsafe to speculate instructions.
4614     auto canSpeculateBlock = [&](BasicBlock* BB)
4615     {
4616         if (BB->size() > maxCaseInsts)
4617             return false;
4618 
4619         if (!branchPattern(BB, Dest))
4620             return false;
4621 
4622         for (auto& I : *BB)
4623         {
4624             auto* inst = &I;
4625 
4626             if (isa<BranchInst>(inst))
4627                 continue;
4628 
4629             // if there is any high-latency instruction in the switch,
4630             // don't flatten it
4631             if (isSampleInstruction(inst) ||
4632                 isGather4Instruction(inst) ||
4633                 isInfoInstruction(inst) ||
4634                 isLdInstruction(inst) ||
4635                 // If the instruction can't be speculated (e.g., phi node),
4636                 // punt.
4637                 !isSafeToSpeculativelyExecute(inst))
4638             {
4639                 return false;
4640             }
4641         }
4642 
4643         return true;
4644     };
4645 
4646     // Are all Phi incomming blocks from SI switch?
4647     auto checkPhiPredecessorBlocks = [](const SwitchInst* SI, const PHINode* Phi, bool DefaultMergeBlock)
4648     {
4649         for (auto* BB : Phi->blocks())
4650         {
4651             if (BB == SI->getDefaultDest())
4652                 continue;
4653             bool successorFound = false;
4654             for (auto Case : SI->cases()) {
4655                 if (Case.getCaseSuccessor() == BB)
4656                     successorFound = true;
4657             }
4658             if (successorFound)
4659                 continue;
4660             return false;
4661         }
4662         return true;
4663     };
4664 
4665     for (auto& I : SI->cases())
4666     {
4667         BasicBlock* CaseDest = I.getCaseSuccessor();
4668 
4669         if (!canSpeculateBlock(CaseDest))
4670             return false;
4671     }
4672 
4673     // Is the default case of the switch the block
4674     // where all other cases meet?
4675     const bool DefaultMergeBlock = (Dest == Default);
4676 
4677     // If we merge to the default block, there is no block
4678     // we jump to beforehand so there is nothing to
4679     // speculate.
4680     if (!DefaultMergeBlock && !canSpeculateBlock(Default))
4681         return false;
4682 
4683     // Get all PHI nodes that needs to be replaced
4684     SmallVector<PHINode*, 4> PhiNodes;
4685     for (auto& I : *Dest)
4686     {
4687         auto* Phi = dyn_cast<PHINode>(&I);
4688 
4689         if (!Phi)
4690             break;
4691 
4692         if (!checkPhiPredecessorBlocks(SI, Phi, DefaultMergeBlock))
4693             return false;
4694 
4695         PhiNodes.push_back(Phi);
4696     }
4697 
4698     if (PhiNodes.empty())
4699         return false;
4700 
4701     // Move all instructions except the last (i.e., the branch)
4702     // from BB to the InsertPoint.
4703     auto splice = [](BasicBlock* BB, Instruction* InsertPoint)
4704     {
4705         for (auto II = BB->begin(), IE = BB->end(); II != IE; /* empty */)
4706         {
4707             auto* I = &*II++;
4708             if (I->isTerminator())
4709                 return;
4710 
4711             I->moveBefore(InsertPoint);
4712         }
4713     };
4714 
4715     // move default block out
4716     if (!DefaultMergeBlock)
4717         splice(Default, SI);
4718 
4719     // move case blocks out
4720     for (auto& I : SI->cases())
4721     {
4722         BasicBlock* CaseDest = I.getCaseSuccessor();
4723         splice(CaseDest, SI);
4724     }
4725 
4726     // replaces PHI with select
4727     for (auto* Phi : PhiNodes)
4728     {
4729         Value* vTemp = Phi->getIncomingValueForBlock(
4730             DefaultMergeBlock ? SI->getParent() : Default);
4731 
4732         for (auto& I : SI->cases())
4733         {
4734             BasicBlock* CaseDest = I.getCaseSuccessor();
4735             ConstantInt* CaseValue = I.getCaseValue();
4736 
4737             Value* selTrueValue = Phi->getIncomingValueForBlock(CaseDest);
4738             builder.SetInsertPoint(SI);
4739             Value* cmp = builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, Val, CaseValue);
4740             Value* sel = builder.CreateSelect(cmp, selTrueValue, vTemp);
4741             vTemp = sel;
4742         }
4743 
4744         Phi->replaceAllUsesWith(vTemp);
4745         Phi->removeFromParent();
4746     }
4747 
4748     // connect the original block and the phi node block with a pass through branch
4749     builder.CreateBr(Dest);
4750 
4751     SmallPtrSet<BasicBlock*, 4> Succs;
4752 
4753     // Remove the switch.
4754     BasicBlock* SelectBB = SI->getParent();
4755     for (unsigned i = 0, e = SI->getNumSuccessors(); i < e; ++i)
4756     {
4757         BasicBlock* Succ = SI->getSuccessor(i);
4758         if (Succ == Dest)
4759         {
4760             continue;
4761         }
4762 
4763         if (Succs.insert(Succ).second)
4764             Succ->removePredecessor(SelectBB);
4765     }
4766     SI->eraseFromParent();
4767 
4768     return true;
4769 }
4770 
runOnFunction(Function & F)4771 bool FlattenSmallSwitch::runOnFunction(Function& F)
4772 {
4773     bool Changed = false;
4774     for (Function::iterator I = F.begin(), E = F.end(); I != E; )
4775     {
4776         BasicBlock* Cur = &*I++; // Advance over block so we don't traverse new blocks
4777         if (SwitchInst * SI = dyn_cast<SwitchInst>(Cur->getTerminator()))
4778         {
4779             Changed |= processSwitchInst(SI);
4780         }
4781     }
4782     return Changed;
4783 }
4784 
// Pass-initialization boilerplate for FCmpPaternMatch. Nothing is listed
// between BEGIN and END.
IGC_INITIALIZE_PASS_BEGIN(FCmpPaternMatch, "FCmpPaternMatch", "FCmpPaternMatch", false, false)
IGC_INITIALIZE_PASS_END(FCmpPaternMatch, "FCmpPaternMatch", "FCmpPaternMatch", false, false)

// Storage for the pass identifier that the constructor passes to FunctionPass.
char FCmpPaternMatch::ID = 0;

// Constructs the pass and runs its initializer against the global pass registry.
FCmpPaternMatch::FCmpPaternMatch() : FunctionPass(ID)
{
    initializeFCmpPaternMatchPass(*PassRegistry::getPassRegistry());
}
4794 
runOnFunction(Function & F)4795 bool FCmpPaternMatch::runOnFunction(Function& F)
4796 {
4797     bool change = true;
4798     visit(F);
4799     return change;
4800 }
4801 
visitSelectInst(SelectInst & I)4802 void FCmpPaternMatch::visitSelectInst(SelectInst& I)
4803 {
4804     /*
4805     from
4806     %7 = fcmp olt float %5, %6
4807     %8 = select i1 %7, float 1.000000e+00, float 0.000000e+00
4808     %9 = bitcast float %8 to i32
4809     %10 = icmp eq i32 %9, 0
4810     %.01 = select i1 %10, float %4, float %11
4811     br i1 %10, label %endif, label %then
4812     to
4813     %7 = fcmp olt float %5, %6
4814     %.01 = select i1 %7, float %11, float %4
4815     br i1 %7, label %then, label %endif
4816     */
4817     {
4818         bool swapNodesFromSel = false;
4819         bool isSelWithConstants = false;
4820         ConstantFP* Cfp1 = dyn_cast<ConstantFP>(I.getOperand(1));
4821         ConstantFP* Cfp2 = dyn_cast<ConstantFP>(I.getOperand(2));
4822         if (Cfp1 && Cfp1->getValueAPF().isFiniteNonZero() &&
4823             Cfp2 && Cfp2->isZero())
4824         {
4825             isSelWithConstants = true;
4826         }
4827         if (Cfp1 && Cfp1->isZero() &&
4828             Cfp2 && Cfp2->getValueAPF().isFiniteNonZero())
4829         {
4830             isSelWithConstants = true;
4831             swapNodesFromSel = true;
4832         }
4833         if (isSelWithConstants &&
4834             dyn_cast<FCmpInst>(I.getOperand(0)))
4835         {
4836             for (auto bitCastI : I.users())
4837             {
4838                 if (BitCastInst * bitcastInst = dyn_cast<BitCastInst>(bitCastI))
4839                 {
4840                     for (auto cmpI : bitcastInst->users())
4841                     {
4842                         ICmpInst* iCmpInst = dyn_cast<ICmpInst>(cmpI);
4843                         if (iCmpInst &&
4844                             iCmpInst->isEquality())
4845                         {
4846                             ConstantInt* icmpC = dyn_cast<ConstantInt>(iCmpInst->getOperand(1));
4847                             if (!icmpC || !icmpC->isZero())
4848                             {
4849                                 continue;
4850                             }
4851 
4852                             bool swapNodes = swapNodesFromSel;
4853                             if (iCmpInst->getPredicate() == CmpInst::Predicate::ICMP_EQ)
4854                             {
4855                                 swapNodes = (swapNodes != true);
4856                             }
4857 
4858                             SmallVector<Instruction*, 4> matchedBrSelInsts;
4859                             for (auto brOrSelI : iCmpInst->users())
4860                             {
4861                                 BranchInst* brInst = dyn_cast<BranchInst>(brOrSelI);
4862                                 if (brInst &&
4863                                     brInst->isConditional())
4864                                 {
4865                                     //match
4866                                     matchedBrSelInsts.push_back(brInst);
4867                                     if (swapNodes)
4868                                     {
4869                                         brInst->swapSuccessors();
4870                                     }
4871                                 }
4872 
4873                                 if (SelectInst * selInst = dyn_cast<SelectInst>(brOrSelI))
4874                                 {
4875                                     //match
4876                                     matchedBrSelInsts.push_back(selInst);
4877                                     if (swapNodes)
4878                                     {
4879                                         Value* selTrue = selInst->getTrueValue();
4880                                         Value* selFalse = selInst->getFalseValue();
4881                                         selInst->setTrueValue(selFalse);
4882                                         selInst->setFalseValue(selTrue);
4883                                         selInst->swapProfMetadata();
4884                                     }
4885                                 }
4886                             }
4887                             for (Instruction* inst : matchedBrSelInsts)
4888                             {
4889                                 inst->setOperand(0, I.getOperand(0));
4890                             }
4891                         }
4892                     }
4893                 }
4894             }
4895         }
4896     }
4897 }
4898 
// Pass-initialization boilerplate for FlattenSmallSwitch. Nothing is listed
// between BEGIN and END.
IGC_INITIALIZE_PASS_BEGIN(FlattenSmallSwitch, "flattenSmallSwitch", "flattenSmallSwitch", false, false)
IGC_INITIALIZE_PASS_END(FlattenSmallSwitch, "flattenSmallSwitch", "flattenSmallSwitch", false, false)
4901 
4902 
4903 
4904 /*======================== SplitIndirectEEtoSel =============================
4905 
This class rewrites an extractelement with a non-constant index on a small vector into a series of cmp+sel, to avoid a VxH mov.
4907 before:
4908   %268 = mul nuw i32 %res.i2.i, 3
4909   %269 = extractelement <12 x float> %234, i32 %268
4910   %270 = add i32 %268, 1
4911   %271 = extractelement <12 x float> %234, i32 %270
4912   %272 = add i32 %268, 2
4913   %273 = extractelement <12 x float> %234, i32 %272
4914   %274 = extractelement <12 x float> %198, i32 %268
4915   %275 = extractelement <12 x float> %198, i32 %270
4916   %276 = extractelement <12 x float> %198, i32 %272
4917 
4918 after:
4919   %250 = icmp eq i32 %res.i2.i, i16 1
4920   %251 = select i1 %250, float %206, float %200
4921   %252 = select i1 %250, float %208, float %202
4922   %253 = select i1 %250, float %210, float %204
4923   %254 = select i1 %250, float %48, float %32
4924   %255 = select i1 %250, float %49, float %33
4925   %256 = select i1 %250, float %50, float %34
4926   %257 = icmp eq i32 %res.i2.i, i16 2
4927   %258 = select i1 %257, float %214, float %251
4928   %259 = select i1 %257, float %215, float %252
4929   %260 = select i1 %257, float %216, float %253
4930   %261 = select i1 %257, float %64, float %254
4931   %262 = select i1 %257, float %65, float %255
4932   %263 = select i1 %257, float %66, float %256
4933 
4934   It is a bit similar to SimplifyConstant::isCmpSelProfitable for OCL, but not restricted to api.
4935   And to GenSimplification::visitExtractElement() but not restricted to vec of 2, and later.
4936   TODO: for known vectors check how many unique items there are.
4937 ===========================================================================*/
4938 namespace {
4939     class SplitIndirectEEtoSel : public FunctionPass, public llvm::InstVisitor<SplitIndirectEEtoSel>
4940     {
4941     public:
4942         static char ID;
SplitIndirectEEtoSel()4943         SplitIndirectEEtoSel() : FunctionPass(ID)
4944         {
4945             initializeSplitIndirectEEtoSelPass(*PassRegistry::getPassRegistry());
4946         }
getPassName() const4947         virtual llvm::StringRef getPassName() const { return "Split Indirect EE to ICmp Plus Sel"; }
4948         virtual bool runOnFunction(Function& F);
4949         void visitExtractElementInst(llvm::ExtractElementInst& I);
4950     private:
4951         bool isProfitableToSplit(uint64_t num, int64_t mul, int64_t add);
4952         bool didSomething;
4953     };
4954 
4955 } // namespace
4956 
4957 
// Storage for the pass identifier that the SplitIndirectEEtoSel constructor
// passes to FunctionPass.
char SplitIndirectEEtoSel::ID = 0;
// Factory: returns a newly allocated SplitIndirectEEtoSel pass.
// NOTE(review): ownership presumably transfers to the pass manager the result
// is added to -- confirm at the call site.
FunctionPass* IGC::createSplitIndirectEEtoSelPass() { return new SplitIndirectEEtoSel(); }
4960 
runOnFunction(Function & F)4961 bool SplitIndirectEEtoSel::runOnFunction(Function& F)
4962 {
4963     didSomething = false;
4964     visit(F);
4965     return didSomething;
4966 }
4967 
isProfitableToSplit(uint64_t num,int64_t mul,int64_t add)4968 bool SplitIndirectEEtoSel::isProfitableToSplit(uint64_t num, int64_t mul, int64_t add)
4969 {
4970     /* Assumption:
4971        Pass is profitable when: (X * cmp + Y * sel) < (ExecSize * mov VxH).
4972     */
4973 
4974     const int64_t assumedVXHCost = IGC_GET_FLAG_VALUE(SplitIndirectEEtoSelThreshold);
4975     int64_t possibleCost = 0;
4976 
4977     /* for: extractelement <4 x float> , %index
4978        cost is (4 - 1)  * (icmp + sel) = 6;
4979     */
4980     possibleCost = ((int64_t)num -1) * 2;
4981     if (possibleCost < assumedVXHCost)
4982         return true;
4983 
4984     /* for: extractelement <12 x float> , (mul %real_index, 3)
4985        cost is ((12/3) - 1) * (icmp + sel) = 6;
4986     */
4987 
4988     if (mul > 0) // not tested negative options
4989     {
4990         int64_t differentOptions = 1 + ((int64_t)num - 1) / mul; // ceil(num/mul)
4991         possibleCost = (differentOptions - 1) * 2;
4992 
4993         if (possibleCost < assumedVXHCost)
4994             return true;
4995     }
4996 
4997     return false;
4998 }
4999 
visitExtractElementInst(llvm::ExtractElementInst & I)5000 void SplitIndirectEEtoSel::visitExtractElementInst(llvm::ExtractElementInst& I)
5001 {
5002     using namespace llvm::PatternMatch;
5003 
5004     IGCLLVM::FixedVectorType* vecTy = dyn_cast<IGCLLVM::FixedVectorType>(I.getVectorOperandType());
5005     IGC_ASSERT( vecTy != nullptr );
5006     uint64_t num = vecTy->getNumElements();
5007     Type* eleType = vecTy->getElementType();
5008 
5009     Value* vec = I.getVectorOperand();
5010     Value* index = I.getIndexOperand();
5011 
5012     // ignore constant index
5013     if (dyn_cast<ConstantInt>(index))
5014     {
5015         return;
5016     }
5017 
5018     // ignore others for now (did not yet evaluate perf. impact)
5019     if (!(eleType->isIntegerTy(32) || eleType->isFloatTy()))
5020     {
5021         return;
5022     }
5023 
5024     // used to calculate offsets
5025     int64_t add = 0;
5026     int64_t mul = 1;
5027 
5028 
5029     /* strip mul/add from index calculation and remember it for later:
5030        %268 = mul nuw i32 %res.i2.i, 3
5031        %270 = add i32 %268, 1
5032        %271 = extractelement <12 x float> %234, i32 %270
5033     */
5034     Value* Val1 = nullptr;
5035     ConstantInt* ci_add = nullptr;
5036     ConstantInt* ci_mul = nullptr;
5037 
5038     auto pat1 = m_Add(m_Mul(m_Value(Val1), m_ConstantInt(ci_mul)), m_ConstantInt(ci_add));
5039     auto pat2 = m_Mul(m_Value(Val1), m_ConstantInt(ci_mul));
5040     // Some code shows `shl+or` instead of mul+add.
5041     auto pat21 = m_Or(m_Shl(m_Value(Val1), m_ConstantInt(ci_mul)), m_ConstantInt(ci_add));
5042     auto pat22 = m_Shl(m_Value(Val1), m_ConstantInt(ci_mul));
5043 
5044     if (match(index, pat1) || match(index, pat2))
5045     {
5046         add = ci_add ? ci_add->getSExtValue() : 0;
5047         mul = ci_mul ? ci_mul->getSExtValue() : 1;
5048         index = Val1;
5049     }
5050     else if (match(index, pat21) || match(index, pat22))
5051     {
5052         add = ci_add ? ci_add->getSExtValue() : 0;
5053         mul = ci_mul ? (1LL << ci_mul->getSExtValue()) : 1LL;
5054         index = Val1;
5055     }
5056 
5057     if (!isProfitableToSplit(num, mul, add))
5058         return;
5059 
5060     Value* vTemp = llvm::UndefValue::get(eleType);
5061     IRBuilder<> builder(I.getNextNode());
5062 
5063     // returns true if we can skip this icmp, such as:
5064     // icmp eq (add (mul %index, 3), 2), 1
5065     // icmp eq (mul %index, 3), 1
5066     auto canSafelySkipThis = [&](int64_t add, int64_t mul, int64_t & newIndex) {
5067         if (mul)
5068         {
5069             newIndex -= add;
5070             if ((newIndex % mul) != 0)
5071                 return true;
5072             newIndex = newIndex / mul;
5073         }
5074         return false;
5075     };
5076 
5077     // Generate combinations
5078     for (uint64_t elemIndex = 0; elemIndex < num; elemIndex++)
5079     {
5080         int64_t cmpIndex = elemIndex;
5081 
5082         if (canSafelySkipThis(add, mul, cmpIndex))
5083             continue;
5084 
5085         // Those 2 might be different, when cmp will get altered by it's operands, but EE index stays the same
5086         ConstantInt* cmpIndexCI = llvm::ConstantInt::get(builder.getInt32Ty(), (uint64_t)cmpIndex);
5087         ConstantInt* eeiIndexCI = llvm::ConstantInt::get(builder.getInt32Ty(), (uint64_t)elemIndex);
5088 
5089         Value* cmp = builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, index, cmpIndexCI);
5090         Value* subcaseEE = builder.CreateExtractElement(vec, eeiIndexCI);
5091         Value* sel = builder.CreateSelect(cmp, subcaseEE, vTemp);
5092         vTemp = sel;
5093         didSomething = true;
5094     }
5095 
5096     // In theory there's no situation where we don't do something up to this point.
5097     if (didSomething)
5098     {
5099         I.replaceAllUsesWith(vTemp);
5100     }
5101 }
5102 
5103 
// Pass-initialization boilerplate for SplitIndirectEEtoSel. Nothing is listed
// between BEGIN and END.
IGC_INITIALIZE_PASS_BEGIN(SplitIndirectEEtoSel, "SplitIndirectEEtoSel", "SplitIndirectEEtoSel", false, false)
IGC_INITIALIZE_PASS_END(SplitIndirectEEtoSel, "SplitIndirectEEtoSel", "SplitIndirectEEtoSel", false, false)
5106 
5107 ////////////////////////////////////////////////////////////////////////
// LogicalAndToBranch tries to find a logical AND like the one below:
5109 //    res = simpleCond0 && complexCond1
5110 // and convert it to:
5111 //    if simpleCond0
5112 //        res = complexCond1
5113 //    else
5114 //        res = false
namespace {
    // Function pass that turns `and i1 %cond0, %cond1` into control flow so
    // that the instructions feeding %cond1 only run when %cond0 is true.
    class LogicalAndToBranch : public FunctionPass
    {
    public:
        // Pass identifier; handed to FunctionPass by the constructor.
        static char ID;
        // NOTE(review): presumably an upper bound on the number of
        // instructions considered for the conversion; its use is outside this
        // declaration -- confirm in runOnFunction().
        const int NUM_INST_THRESHOLD = 32;
        LogicalAndToBranch();

        StringRef getPassName() const override { return "LogicalAndToBranch"; }

        bool runOnFunction(Function& F) override;

    protected:
        // Instructions already handled by scheduleUp(); used to avoid
        // processing the same instruction twice.
        SmallPtrSet<Instruction*, 8> m_sched;

        // schedule instruction up before insertPos
        bool scheduleUp(BasicBlock* bb, Value* V, Instruction*& insertPos);

        // check if it's safe to convert instructions between cond0 & cond1,
        // moveInsts are the values referenced outside of (cond0, cond1), we
        // need to move them before cond0
        bool isSafeToConvert(Instruction* cond0, Instruction* cond1,
            smallvector<Instruction*, 8> & moveInsts);

        // Splits the block around opAnd into if.then / if.else / if.end and
        // replaces opAnd with a PHI; newBB receives the if.end block.
        void convertAndToBranch(Instruction* opAnd,
            Instruction* cond0, Instruction* cond1, BasicBlock*& newBB);
    };

}
5144 
// Pass-initialization boilerplate for LogicalAndToBranch. Nothing is listed
// between BEGIN and END.
IGC_INITIALIZE_PASS_BEGIN(LogicalAndToBranch, "logicalAndToBranch", "logicalAndToBranch", false, false)
IGC_INITIALIZE_PASS_END(LogicalAndToBranch, "logicalAndToBranch", "logicalAndToBranch", false, false)

// Storage for the pass identifier that the constructor passes to FunctionPass.
char LogicalAndToBranch::ID = 0;
// Factory: returns a newly allocated LogicalAndToBranch pass.
// NOTE(review): ownership presumably transfers to the pass manager the result
// is added to -- confirm at the call site.
FunctionPass* IGC::createLogicalAndToBranchPass() { return new LogicalAndToBranch(); }

// Constructs the pass and runs its initializer against the global pass registry.
LogicalAndToBranch::LogicalAndToBranch() : FunctionPass(ID)
{
    initializeLogicalAndToBranchPass(*PassRegistry::getPassRegistry());
}
5155 
scheduleUp(BasicBlock * bb,Value * V,Instruction * & insertPos)5156 bool LogicalAndToBranch::scheduleUp(BasicBlock* bb, Value* V,
5157     Instruction*& insertPos)
5158 {
5159     Instruction* inst = dyn_cast<Instruction>(V);
5160     if (!inst)
5161         return false;
5162 
5163     if (inst->getParent() != bb || isa<PHINode>(inst))
5164         return false;
5165 
5166     if (m_sched.count(inst))
5167     {
5168         if (insertPos && !isInstPrecede(inst, insertPos))
5169             insertPos = inst;
5170         return false;
5171     }
5172 
5173     bool changed = false;
5174 
5175     for (auto OI = inst->op_begin(), OE = inst->op_end(); OI != OE; ++OI)
5176     {
5177         changed |= scheduleUp(bb, OI->get(), insertPos);
5178     }
5179     m_sched.insert(inst);
5180 
5181     if (insertPos && isInstPrecede(inst, insertPos))
5182         return changed;
5183 
5184     if (insertPos) {
5185         inst->removeFromParent();
5186         inst->insertBefore(insertPos);
5187     }
5188 
5189     return true;
5190 }
5191 
5192 // split original basic block from:
5193 //   original BB:
5194 //     %cond0 =
5195 //     ...
5196 //     %cond1 =
5197 //     %andRes = and i1 %cond0, %cond1
5198 //     ...
5199 // to:
5200 //    original BB:
5201 //      %cond0 =
5202 //      if %cond0, if.then, if.else
5203 //
5204 //    if.then:
5205 //      ...
5206 //      %cond1 =
5207 //      br if.end
5208 //
5209 //    if.else:
5210 //      br if.end
5211 //
5212 //    if.end:
5213 //       %andRes = phi [%cond1, if.then], [false, if.else]
5214 //       ...
convertAndToBranch(Instruction * opAnd,Instruction * cond0,Instruction * cond1,BasicBlock * & newBB)5215 void LogicalAndToBranch::convertAndToBranch(Instruction* opAnd,
5216     Instruction* cond0, Instruction* cond1, BasicBlock*& newBB)
5217 {
5218     BasicBlock* bb = opAnd->getParent();
5219     BasicBlock* bbThen, * bbElse, * bbEnd;
5220 
5221     bbThen = bb->splitBasicBlock(cond0->getNextNode(), "if.then");
5222     bbElse = bbThen->splitBasicBlock(opAnd, "if.else");
5223     bbEnd = bbElse->splitBasicBlock(opAnd, "if.end");
5224 
5225     bb->getTerminator()->eraseFromParent();
5226     BranchInst* br = BranchInst::Create(bbThen, bbElse, cond0, bb);
5227 
5228     bbThen->getTerminator()->eraseFromParent();
5229     br = BranchInst::Create(bbEnd, bbThen);
5230 
5231     PHINode* phi = PHINode::Create(opAnd->getType(), 2, "", opAnd);
5232     phi->addIncoming(cond1, bbThen);
5233     phi->addIncoming(ConstantInt::getFalse(opAnd->getType()), bbElse);
5234     opAnd->replaceAllUsesWith(phi);
5235     opAnd->eraseFromParent();
5236 
5237     newBB = bbEnd;
5238 }
5239 
isSafeToConvert(Instruction * cond0,Instruction * cond1,smallvector<Instruction *,8> & moveInsts)5240 bool LogicalAndToBranch::isSafeToConvert(
5241     Instruction* cond0, Instruction* cond1,
5242     smallvector<Instruction*, 8> & moveInsts)
5243 {
5244     BasicBlock::iterator is0(cond0);
5245     BasicBlock::iterator is1(cond1);
5246 
5247     bool isSafe = true;
5248     SmallPtrSet<Value*, 32> iset;
5249 
5250     iset.insert(cond1);
5251     for (auto i = ++is0; i != is1; ++i)
5252     {
5253         if ((*i).mayHaveSideEffects())
5254         {
5255             isSafe = false;
5256             break;
5257         }
5258         iset.insert(&(*i));
5259     }
5260 
5261     if (!isSafe)
5262     {
5263         return false;
5264     }
5265 
5266     is0 = cond0->getIterator();
5267     // check if the values in between are used elsewhere
5268     for (auto i = ++is0; i != is1; ++i)
5269     {
5270         Instruction* inst = &*i;
5271         for (auto ui : inst->users())
5272         {
5273             if (iset.count(ui) == 0)
5274             {
5275                 moveInsts.push_back(inst);
5276                 break;
5277             }
5278         }
5279     }
5280     return isSafe;
5281 }
5282 
runOnFunction(Function & F)5283 bool LogicalAndToBranch::runOnFunction(Function& F)
5284 {
5285     bool changed = false;
5286     if (IGC_IS_FLAG_DISABLED(EnableLogicalAndToBranch))
5287     {
5288         return changed;
5289     }
5290 
5291     for (auto BI = F.begin(), BE = F.end(); BI != BE; )
5292     {
5293         // advance iterator before handling current BB
5294         BasicBlock* bb = &*BI++;
5295 
5296         for (auto II = bb->begin(), IE = bb->end(); II != IE; )
5297         {
5298             Instruction* inst = &(*II++);
5299 
5300             // search for "and i1"
5301             if (inst->getOpcode() == BinaryOperator::And &&
5302                 inst->getType()->isIntegerTy(1))
5303             {
5304                 Instruction* s0 = dyn_cast<Instruction>(inst->getOperand(0));
5305                 Instruction* s1 = dyn_cast<Instruction>(inst->getOperand(1));
5306                 if (s0 && s1 &&
5307                     !isa<PHINode>(s0) && !isa<PHINode>(s1) &&
5308                     s0->hasOneUse() && s1->hasOneUse() &&
5309                     s0->getParent() == bb && s1->getParent() == bb)
5310                 {
5311                     if (isInstPrecede(s1, s0))
5312                     {
5313                         std::swap(s0, s1);
5314                     }
5315                     BasicBlock::iterator is0(s0);
5316                     BasicBlock::iterator is1(s1);
5317 
5318                     if (std::distance(is0, is1) < NUM_INST_THRESHOLD)
5319                     {
5320                         continue;
5321                     }
5322 
5323                     smallvector<Instruction*, 8> moveInsts;
5324                     if (isSafeToConvert(s0, inst, moveInsts))
5325                     {
5326                         // if values defined between s0 & inst(branch) are referenced
5327                         // outside of (s0, inst), they need to be moved before
5328                         // s0 to keep SSA form.
5329                         for (auto inst : moveInsts)
5330                             scheduleUp(bb, inst, s0);
5331                         m_sched.clear();
5332 
5333                         // IE need to be updated since original BB is splited
5334                         convertAndToBranch(inst, s0, s1, bb);
5335                         IE = bb->end();
5336                         changed = true;
5337                     }
5338                 }
5339             }
5340         }
5341     }
5342 
5343     return changed;
5344 }
5345 
5346 // clean PHINode does the following:
5347 //   given the following:
5348 //     a = phi (x, b0), (x, b1)
5349 //   this pass will replace 'a' with 'x', and as result, phi is removed.
5350 //
5351 // Special note:
//   LCSSA PHINode has a single incoming value. Make sure it is not removed
//   as WIA uses lcssa phi as a separator between a uniform value inside a loop
//   and a non-uniform value outside the loop.  For example:
5355 //      B0:
5356 //             i = 0;
5357 //      Loop:
5358 //             i_0 = phi (0, B0)  (t, Bn)
5359 //             .... <use i_0>
5360 //             if (divergent cond)
5361 //      Bi:
5362 //                goto out;
5363 //      Bn:
5364 //             t = i_0 + 1;
5365 //             if (t < N) goto Loop;
5366 //             goto output;
5367 //      out:
5368 //             i_1 = phi (i_0, Bi)    <-- lcssa phi node
5369 //      ....
5370 //      output:
5371 //   Here, i_0 is uniform within the loop,  but it is not outside loop as each WI will
5372 //   exit with different i, thus i_1 is non-uniform. (Note that removing lcssa might be
5373 //   bad in performance, but it should not cause any functional issue.)
5374 //
5375 // This is needed to avoid generating the following code for which vISA cannot generate
5376 // the correct code:
5377 //   i = 0;             // i is uniform
5378 //   Loop:
5379 //         x = i + 1      // x is uniform
5380 //     B0  if (divergent-condition)
5381 //            <code1>
5382 //     B1  else
5383 //            z = array[i]
5384 //     B2  endif
5385 //         i = phi (x, B0), (x, B1)
5386 //         ......
5387 //     if (i < n) goto Loop
5388 //
5389 // Generated code (visa) (phi becomes mov in its incoming BBs).
5390 //
5391 //   i = 0;             // i is uniform
5392 //   Loop:
5393 //         (W) x = i + 1      // x is uniform, NoMask
5394 //     B0  if (divergent-condition)
5395 //            <code1>
5396 //         (W) i = x         // noMask
5397 //     B1  else
5398 //            z = array[i]
5399 //         (W) i = x         // noMask
5400 //     B2  endif
5401 //         ......
5402 //         if (i < n) goto Loop
5403 //
5404 // In the 1st iteration, 'z' should be array[0].  Assume 'if' is divergent, thus both B0 and B1
5405 // blocks will be executed. As result, the value of 'i' after B0 will be x, which is 1. And 'z'
5406 // will take array[1], which is wrong (correct one is array[0]).
5407 //
// This case happens if phi is uniform, which means all phis' incoming values are identical
// and uniform (current behavior of WIAnalysis). Note that identical incoming values mean this phi
// is no longer needed.  Once such a phi is removed, we will never generate code like the one shown
// above and thus, no wrong code will be generated from visa.
5412 //
5413 // This pass will be invoked at place close to the Emit pass, where WIAnalysis will be invoked,
5414 // so that IR between this pass and WIAnalysis stays the same, at least no new PHINodes like this
5415 // will be generated.
5416 //
5417 namespace {
5418     class CleanPHINode : public FunctionPass
5419     {
5420     public:
5421         static char ID;
5422         CleanPHINode();
5423 
getPassName() const5424         StringRef getPassName() const override { return "CleanPhINode"; }
5425 
5426         bool runOnFunction(Function& F) override;
5427     };
5428 }
5429 
5430 #undef PASS_FLAG
5431 #undef PASS_DESCRIPTION
5432 #undef PASS_CFG_ONLY
5433 #undef PASS_ANALYSIS
5434 #define PASS_FLAG "igc-cleanphinode"
5435 #define PASS_DESCRIPTION "Clean up PHINode"
5436 #define PASS_CFG_ONLY false
5437 #define PASS_ANALYSIS false
5438 IGC_INITIALIZE_PASS_BEGIN(CleanPHINode, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
5439 IGC_INITIALIZE_PASS_END(CleanPHINode,   PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
5440 
5441 
5442 char CleanPHINode::ID = 0;
createCleanPHINodePass()5443 FunctionPass* IGC::createCleanPHINodePass()
5444 {
5445     return new CleanPHINode();
5446 }
5447 
CleanPHINode()5448 CleanPHINode::CleanPHINode() : FunctionPass(ID)
5449 {
5450     initializeCleanPHINodePass(*PassRegistry::getPassRegistry());
5451 }
5452 
runOnFunction(Function & F)5453 bool CleanPHINode::runOnFunction(Function& F)
5454 {
5455     auto isLCSSAPHINode = [](PHINode* PHI) { return (PHI->getNumIncomingValues() == 1); };
5456 
5457     bool changed = false;
5458     for (auto BI = F.begin(), BE = F.end(); BI != BE; ++BI)
5459     {
5460         BasicBlock* BB = &*BI;
5461         auto II = BB->begin();
5462         auto IE = BB->end();
5463         while (II != IE)
5464         {
5465             auto currII = II;
5466             ++II;
5467             PHINode* PHI = dyn_cast<PHINode>(currII);
5468             if (PHI == nullptr)
5469             {
5470                 // proceed to the next BB
5471                 break;
5472             }
5473             if (isLCSSAPHINode(PHI))
5474             {
5475                 // Keep LCSSA PHI as uniform analysis needs it.
5476                 continue;
5477             }
5478 
5479             if (PHI->getNumIncomingValues() > 0) // sanity
5480             {
5481                 Value* sameVal = PHI->getIncomingValue(0);
5482                 bool isAllSame = true;
5483                 for (int i = 1, sz = (int)PHI->getNumIncomingValues(); i < sz; ++i)
5484                 {
5485                     if (sameVal != PHI->getIncomingValue(i))
5486                     {
5487                         isAllSame = false;
5488                         break;
5489                     }
5490                 }
5491                 if (isAllSame)
5492                 {
5493                     PHI->replaceAllUsesWith(sameVal);
5494                     PHI->eraseFromParent();
5495                     changed = true;
5496                 }
5497             }
5498         }
5499     }
5500     return changed;
5501 }