1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 #include "Compiler/CISACodeGen/PatternMatchPass.hpp"
10 #include "Compiler/CISACodeGen/EmitVISAPass.hpp"
11 #include "Compiler/CISACodeGen/DeSSA.hpp"
12 #include "Compiler/MetaDataApi/IGCMetaDataHelper.h"
13 #include "common/igc_regkeys.hpp"
14 #include "common/LLVMWarningsPush.hpp"
15 #include <llvm/IR/InlineAsm.h>
16 #include <llvm/IR/Dominators.h>
17 #include <llvm/IR/Constants.h>
18 #include <llvm/IR/Instruction.h>
19 #include <llvm/IR/PatternMatch.h>
20 #include <llvmWrapper/IR/Instructions.h>
21 #include <llvm/IR/IntrinsicInst.h>
22 #include "common/LLVMWarningsPop.hpp"
23 #include "GenISAIntrinsics/GenIntrinsicInst.h"
24 #include "Compiler/IGCPassSupport.h"
25 #include "Compiler/InitializePasses.h"
26 #include "Compiler/DebugInfo/ScalarVISAModule.h"
27 #include "Probe/Assertion.h"
28 
29 using namespace llvm;
30 using namespace IGC;
31 using namespace IGC::IGCMD;
32 
// Pass identifier handed to the FunctionPass base class by the constructor.
char CodeGenPatternMatch::ID = 0;
// Arguments consumed by the IGC_INITIALIZE_PASS_* registration macros below.
#define PASS_FLAG "CodeGenPatternMatch"
#define PASS_DESCRIPTION "Does pattern matching"
#define PASS_CFG_ONLY true
#define PASS_ANALYSIS true
// Pass registration together with the analyses declared as dependencies.
// NOTE(review): runOnFunction() also queries CodeGenContextWrapper, which is
// not listed below - confirm whether it needs a dependency entry as well.
IGC_INITIALIZE_PASS_BEGIN(CodeGenPatternMatch, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_DEPENDENCY(WIAnalysis)
IGC_INITIALIZE_PASS_DEPENDENCY(LiveVarsAnalysis)
IGC_INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
IGC_INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
IGC_INITIALIZE_PASS_DEPENDENCY(PositionDepAnalysis)
IGC_INITIALIZE_PASS_END(CodeGenPatternMatch, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
46 
47 namespace IGC
48 {
49 
CodeGenPatternMatch()50     CodeGenPatternMatch::CodeGenPatternMatch() : FunctionPass(ID),
51         m_rootIsSubspanUse(false),
52         m_blocks(nullptr),
53         m_numBlocks(0),
54         m_root(nullptr),
55         m_currentPattern(nullptr),
56         m_Platform(),
57         m_AllowContractions(true),
58         m_NeedVMask(false),
59         m_samplertoRenderTargetEnable(false),
60         m_ctx(nullptr),
61         DT(nullptr),
62         LI(nullptr),
63         m_DL(0),
64         m_WI(nullptr),
65         m_LivenessInfo(nullptr)
66     {
67         initializeCodeGenPatternMatchPass(*PassRegistry::getPassRegistry());
68     }
69 
~CodeGenPatternMatch()70     CodeGenPatternMatch::~CodeGenPatternMatch()
71     {
72         delete[] m_blocks;
73     }
74 
CodeGenNode(llvm::DomTreeNode * node)75     void CodeGenPatternMatch::CodeGenNode(llvm::DomTreeNode* node)
76     {
77         // Process blocks by processing the dominance tree depth first
78         for (auto child = node->begin(); child != node->end(); ++child)
79         {
80             CodeGenNode(*child);
81         }
82         llvm::BasicBlock* bb = node->getBlock();
83         CodeGenBlock(bb);
84     }
85 
runOnFunction(llvm::Function & F)86     bool CodeGenPatternMatch::runOnFunction(llvm::Function& F)
87     {
88         m_blockMap.clear();
89         ConstantPlacement.clear();
90         PairOutputMap.clear();
91         UniformBools.clear();
92 
93         delete[] m_blocks;
94         m_blocks = nullptr;
95 
96         m_ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
97 
98         MetaDataUtils* pMdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
99         ModuleMetaData* modMD = getAnalysis<MetaDataUtilsWrapper>().getModuleMetaData();
100         if (pMdUtils->findFunctionsInfoItem(&F) == pMdUtils->end_FunctionsInfo())
101         {
102             return false;
103         }
104 
105         m_AllowContractions = true;
106         if (m_ctx->m_DriverInfo.NeedCheckContractionAllowed())
107         {
108             m_AllowContractions =
109                 modMD->compOpt.FastRelaxedMath ||
110                 modMD->compOpt.MadEnable;
111         }
112         m_Platform = m_ctx->platform;
113 
114         DT = &getAnalysis<llvm::DominatorTreeWrapperPass>().getDomTree();
115         LI = &getAnalysis<llvm::LoopInfoWrapperPass>().getLoopInfo();
116         m_DL = &F.getParent()->getDataLayout();
117         m_WI = &getAnalysis<WIAnalysis>();
118         m_PosDep = &getAnalysis<PositionDepAnalysis>();
119         // pattern match will update liveness held by LiveVar, which needs
120         // WIAnalysis result for uniform variable
121         m_LivenessInfo = &getAnalysis<LiveVarsAnalysis>().getLiveVars();
122         CreateBasicBlocks(&F);
123         CodeGenNode(DT->getRootNode());
124         return false;
125     }
126 
HasSideEffect(llvm::Instruction & inst)127     inline bool HasSideEffect(llvm::Instruction& inst)
128     {
129         if (inst.mayWriteToMemory() || inst.isTerminator())
130         {
131             return true;
132         }
133         return false;
134     }
135 
136 
HasPhiUse(llvm::Value & inst)137     inline bool HasPhiUse(llvm::Value& inst)
138     {
139         for (auto UI = inst.user_begin(), E = inst.user_end(); UI != E; ++UI)
140         {
141             llvm::User* U = *UI;
142             if (llvm::isa<llvm::PHINode>(U))
143             {
144                 return true;
145             }
146         }
147         return false;
148     }
149 
IsDbgInst(llvm::Instruction & inst) const150     bool CodeGenPatternMatch::IsDbgInst(llvm::Instruction& inst) const
151     {
152         if (llvm::isa<llvm::DbgInfoIntrinsic>(&inst))
153         {
154             return true;
155         }
156         if (DebugMetadataInfo::hasDashGOption(m_ctx) &&
157             inst.getMetadata("perThreadOffset") != nullptr)
158         {
159             // debugging needs this
160             return true;
161         }
162         return false;
163     }
164 
IsConstOrSimdConstExpr(Value * C)165     bool CodeGenPatternMatch::IsConstOrSimdConstExpr(Value* C)
166     {
167         if (isa<ConstantInt>(C))
168         {
169             return true;
170         }
171         if (Instruction * inst = dyn_cast<Instruction>(C))
172         {
173             return SIMDConstExpr(inst);
174         }
175         return false;
176     }
177 
178     // this function need to be in sync with CShader::EvaluateSIMDConstExpr on what can be supported
SIMDConstExpr(Instruction * C)179     bool CodeGenPatternMatch::SIMDConstExpr(Instruction* C)
180     {
181         auto it = m_IsSIMDConstExpr.find(C);
182         if (it != m_IsSIMDConstExpr.end())
183         {
184             return it->second;
185         }
186         bool isConstExpr = false;
187         if (BinaryOperator * op = dyn_cast<BinaryOperator>(C))
188         {
189             switch (op->getOpcode())
190             {
191             case Instruction::Add:
192                 isConstExpr = IsConstOrSimdConstExpr(op->getOperand(0)) && IsConstOrSimdConstExpr(op->getOperand(1));
193                 break;
194             case Instruction::Mul:
195                 isConstExpr = IsConstOrSimdConstExpr(op->getOperand(0)) && IsConstOrSimdConstExpr(op->getOperand(1));
196                 break;
197             case Instruction::Shl:
198                 isConstExpr = IsConstOrSimdConstExpr(op->getOperand(0)) && IsConstOrSimdConstExpr(op->getOperand(1));
199                 break;
200             default:
201                 break;
202             }
203         }
204         else if (llvm::GenIntrinsicInst * genInst = dyn_cast<GenIntrinsicInst>(C))
205         {
206             if (genInst->getIntrinsicID() == GenISAIntrinsic::GenISA_simdSize)
207             {
208                 isConstExpr = true;
209             }
210         }
211         m_IsSIMDConstExpr.insert(std::make_pair(C, isConstExpr));
212         return isConstExpr;
213     }
214 
NeedInstruction(llvm::Instruction & I)215     bool CodeGenPatternMatch::NeedInstruction(llvm::Instruction& I)
216     {
217         if (SIMDConstExpr(&I))
218         {
219             return false;
220         }
221         if (HasPhiUse(I) || HasSideEffect(I) || IsDbgInst(I) ||
222             (m_usedInstructions.find(&I) != m_usedInstructions.end()))
223         {
224             return true;
225         }
226         return false;
227     }
228 
AddToConstantPool(llvm::BasicBlock * UseBlock,llvm::Value * Val)229     void CodeGenPatternMatch::AddToConstantPool(llvm::BasicBlock* UseBlock,
230         llvm::Value* Val) {
231         Constant* C = dyn_cast_or_null<Constant>(Val);
232         if (!C)
233             return;
234 
235         BasicBlock* LCA = UseBlock;
236         // Determine where we put the constant initialization.
237         // Choose loop pre-header as LICM.
238         // XXX: Further investigation/tuning is needed to see whether
239         // we need to hoist constant initialization out of the
240         // top-level loop within a nested loop. So far, we only hoist
241         // one level up.
242         if (Loop * L = LI->getLoopFor(LCA)) {
243             if (BasicBlock * Preheader = L->getLoopPreheader())
244                 LCA = Preheader;
245         }
246         // Find the common dominator as CSE.
247         if (BasicBlock * BB = ConstantPlacement.lookup(C))
248             LCA = DT->findNearestCommonDominator(LCA, BB);
249         IGC_ASSERT_MESSAGE(LCA, "LCA always exists for reachable BBs within a function!");
250         ConstantPlacement[C] = LCA;
251     }
252 
253     // Check bool values that can be emitted as a single element predicate.
gatherUniformBools(Value * Val)254     void CodeGenPatternMatch::gatherUniformBools(Value* Val)
255     {
256         if (!isUniform(Val) || Val->getType()->getScalarType()->isIntegerTy(1))
257             return;
258 
259         // Only starts from select instruction for now.
260         // It is more complicate for uses in terminators.
261         if (SelectInst * SI = dyn_cast<SelectInst>(Val)) {
262             Value* Cond = SI->getCondition();
263             if (Cond->getType()->isVectorTy() || !Cond->hasOneUse())
264                 return;
265 
266             // All users of bool values.
267             DenseSet<Value*> Vals;
268             Vals.insert(SI);
269 
270             // Grow the list of bool values to be checked.
271             std::vector<Value*> ValList;
272             ValList.push_back(Cond);
273 
274             bool IsLegal = true;
275             while (!ValList.empty()) {
276                 Value* V = ValList.back();
277                 ValList.pop_back();
278                 IGC_ASSERT(nullptr != V);
279                 IGC_ASSERT(nullptr != V->getType());
280                 IGC_ASSERT(isUniform(V));
281                 IGC_ASSERT(V->getType()->isIntegerTy(1));
282 
283                 // Check uses.
284                 for (auto UI = V->user_begin(), UE = V->user_end(); UI != UE; ++UI) {
285                     Value* U = *UI;
286                     if (!Vals.count(U))
287                         goto FAIL;
288                 }
289 
290                 // Check defs.
291                 Vals.insert(V);
292                 if (auto CI = dyn_cast<CmpInst>(V)) {
293                     IGC_ASSERT(isUniform(CI->getOperand(0)));
294                     IGC_ASSERT(isUniform(CI->getOperand(1)));
295                     if (CI->getOperand(0)->getType()->getScalarSizeInBits() == 1)
296                         goto FAIL;
297                     continue;
298                 }
299                 else if (auto BI = dyn_cast<BinaryOperator>(V)) {
300                     IGC_ASSERT(isUniform(BI->getOperand(0)));
301                     IGC_ASSERT(isUniform(BI->getOperand(1)));
302                     if (isa<Instruction>(BI->getOperand(0)))
303                         ValList.push_back(BI->getOperand(0));
304                     if (isa<Instruction>(BI->getOperand(1)))
305                         ValList.push_back(BI->getOperand(1));
306                     continue;
307                 }
308 
309             FAIL:
310                 IsLegal = false;
311                 break;
312             }
313 
314             // Populate all boolean values if legal.
315             if (IsLegal) {
316                 for (auto V : Vals) {
317                     if (V->getType()->isIntegerTy(1))
318                         UniformBools.insert(V);
319                 }
320             }
321         }
322     }
323 
CodeGenBlock(llvm::BasicBlock * bb)324     void CodeGenPatternMatch::CodeGenBlock(llvm::BasicBlock* bb)
325     {
326         llvm::BasicBlock::InstListType& instructionList = bb->getInstList();
327         llvm::BasicBlock::InstListType::reverse_iterator I, E;
328         auto it = m_blockMap.find(bb);
329         IGC_ASSERT(it != m_blockMap.end());
330         SBasicBlock* block = it->second;
331 
332         // loop through instructions bottom up
333         for (I = instructionList.rbegin(), E = instructionList.rend(); I != E; ++I)
334         {
335             llvm::Instruction& inst = (*I);
336 
337             if (NeedInstruction(inst))
338             {
339                 SetPatternRoot(inst);
340                 Pattern* pattern = Match(inst);
341                 if (pattern)
342                 {
343                     block->m_dags.push_back(SDAG(pattern, m_root));
344                     gatherUniformBools(m_root);
345                 }
346             }
347         }
348     }
349 
CreateBasicBlocks(llvm::Function * pLLVMFunc)350     void CodeGenPatternMatch::CreateBasicBlocks(llvm::Function* pLLVMFunc)
351     {
352         m_numBlocks = pLLVMFunc->size();
353         m_blocks = new SBasicBlock[m_numBlocks];
354         uint i = 0;
355         for (BasicBlock& bb : *pLLVMFunc)
356         {
357             m_blocks[i].id = i;
358             m_blocks[i].bb = &bb;
359             m_blockMap.insert(std::pair<llvm::BasicBlock*, SBasicBlock*>(&bb, &m_blocks[i]));
360             i++;
361         }
362     }
Match(llvm::Instruction & inst)363     Pattern* CodeGenPatternMatch::Match(llvm::Instruction& inst)
364     {
365         m_currentPattern = nullptr;
366         visit(inst);
367         return m_currentPattern;
368     }
369 
SetPatternRoot(llvm::Instruction & inst)370     void CodeGenPatternMatch::SetPatternRoot(llvm::Instruction& inst)
371     {
372         m_root = &inst;
373         m_rootIsSubspanUse = IsSubspanUse(m_root);
374     }
375 
376     template<typename Op_t, typename ConstTy>
377     struct ClampWithConstants_match {
378         typedef ConstTy* ConstPtrTy;
379 
380         Op_t Op;
381         ConstPtrTy& CMin, & CMax;
382 
ClampWithConstants_matchIGC::ClampWithConstants_match383         ClampWithConstants_match(const Op_t& OpMatch,
384             ConstPtrTy& Min, ConstPtrTy& Max)
385             : Op(OpMatch), CMin(Min), CMax(Max) {}
386 
387         template<typename OpTy>
matchIGC::ClampWithConstants_match388         bool match(OpTy* V) {
389             CallInst* GII = dyn_cast<CallInst>(V);
390             if (!GII)
391                 return false;
392 
393             EOPCODE op = GetOpCode(GII);
394 
395             if (op != llvm_max && op != llvm_min)
396                 return false;
397 
398             Value* X = GII->getOperand(0);
399             Value* C = GII->getOperand(1);
400             if (isa<ConstTy>(X))
401                 std::swap(X, C);
402 
403             ConstPtrTy C0 = dyn_cast<ConstTy>(C);
404             if (!C0)
405                 return false;
406 
407             CallInst* GII2 = dyn_cast<CallInst>(X);
408             if (!GII2)
409                 return false;
410 
411             EOPCODE op2 = GetOpCode(GII2);
412             if (!(op == llvm_min && op2 == llvm_max) &&
413                 !(op == llvm_max && op2 == llvm_min))
414                 return false;
415 
416             X = GII2->getOperand(0);
417             C = GII2->getOperand(1);
418             if (isa<ConstTy>(X))
419                 std::swap(X, C);
420 
421             ConstPtrTy C1 = dyn_cast<ConstTy>(C);
422             if (!C1)
423                 return false;
424 
425             if (!Op.match(X))
426                 return false;
427 
428             CMin = (op2 == llvm_min) ? C0 : C1;
429             CMax = (op2 == llvm_min) ? C1 : C0;
430             return true;
431         }
432     };
433 
434     template<typename OpTy, typename ConstTy>
435     inline ClampWithConstants_match<OpTy, ConstTy>
m_ClampWithConstants(const OpTy & Op,ConstTy * & Min,ConstTy * & Max)436         m_ClampWithConstants(const OpTy& Op, ConstTy*& Min, ConstTy*& Max) {
437         return ClampWithConstants_match<OpTy, ConstTy>(Op, Min, Max);
438     }
439 
440     template<typename Op_t>
441     struct IsNaN_match {
442         Op_t Op;
443 
IsNaN_matchIGC::IsNaN_match444         IsNaN_match(const Op_t& OpMatch) : Op(OpMatch) {}
445 
446         template<typename OpTy>
matchIGC::IsNaN_match447         bool match(OpTy* V) {
448             using namespace llvm::PatternMatch;
449 
450             FCmpInst* FCI = dyn_cast<FCmpInst>(V);
451             if (!FCI)
452                 return false;
453 
454             switch (FCI->getPredicate()) {
455             case FCmpInst::FCMP_UNE:
456                 return FCI->getOperand(0) == FCI->getOperand(1) &&
457                     Op.match(FCI->getOperand(0));
458             case FCmpInst::FCMP_UNO:
459                 return m_Zero().match(FCI->getOperand(1)) &&
460                     Op.match(FCI->getOperand(0));
461             default:
462                 break;
463             }
464 
465             return false;
466         }
467     };
468 
469     template<typename OpTy>
m_IsNaN(const OpTy & Op)470     inline IsNaN_match<OpTy> m_IsNaN(const OpTy& Op) {
471         return IsNaN_match<OpTy>(Op);
472     }
473 
474     std::tuple<Value*, unsigned, VISA_Type>
isFPToIntegerSatWithExactConstant(llvm::CastInst * I)475         CodeGenPatternMatch::isFPToIntegerSatWithExactConstant(llvm::CastInst* I) {
476         using namespace llvm::PatternMatch; // Scoped using declaration.
477 
478         unsigned Opcode = I->getOpcode();
479         IGC_ASSERT(Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI);
480 
481         unsigned BitWidth = I->getDestTy()->getIntegerBitWidth();
482         APFloat FMin(I->getSrcTy()->getFltSemantics());
483         APFloat FMax(I->getSrcTy()->getFltSemantics());
484         if (Opcode == Instruction::FPToSI) {
485             if (FMax.convertFromAPInt(APInt::getSignedMaxValue(BitWidth), true,
486                 APFloat::rmNearestTiesToEven) != APFloat::opOK)
487                 return std::make_tuple(nullptr, 0, ISA_TYPE_F);
488             if (FMin.convertFromAPInt(APInt::getSignedMinValue(BitWidth), true,
489                 APFloat::rmNearestTiesToEven) != APFloat::opOK)
490                 return std::make_tuple(nullptr, 0, ISA_TYPE_F);
491         }
492         else {
493             if (FMax.convertFromAPInt(APInt::getMaxValue(BitWidth), false,
494                 APFloat::rmNearestTiesToEven) != APFloat::opOK)
495                 return std::make_tuple(nullptr, 0, ISA_TYPE_F);
496             if (FMin.convertFromAPInt(APInt::getMinValue(BitWidth), false,
497                 APFloat::rmNearestTiesToEven) != APFloat::opOK)
498                 return std::make_tuple(nullptr, 0, ISA_TYPE_F);
499         }
500 
501         llvm::ConstantFP* CMin = nullptr;
502         llvm::ConstantFP* CMax = nullptr;
503         llvm::Value* X = nullptr;
504 
505         if (!match(I->getOperand(0), m_ClampWithConstants(m_Value(X), CMin, CMax)))
506             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
507 
508         if (!CMin->isExactlyValue(FMin) || !CMax->isExactlyValue(FMax))
509             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
510 
511         return std::make_tuple(X, Opcode, GetType(I->getType(), m_ctx));
512     }
513 
// The following pattern matching targets the conversion from FP values
// to INTEGER values with saturation where the MAX and/or MIN INTEGER values
// cannot be represented in FP values exactly. E.g., UINT_MAX (2**32-1) in
// 'unsigned' cannot be represented in 'float', where only 24 significant bits
// are available but UINT_MAX needs 32 significant bits. We cannot simply
// express that conversion with saturation as
//
//  o := fptoui(clamp(x, float(UINT_MIN), float(UINT_MAX)));
//
// because, in LLVM, fptoui is undefined when the 'float' source value cannot
// fit in 'unsigned' (float(UINT_MAX) rounds up to 2**32), where
// clamp(x, MIN, MAX) is defined as max(min(x, MAX), MIN).
525     //
// Hence, OCL uses the following sequence (over-simplified by excluding the
// NaN case).
528     //
529     //  o := select(fptoui(x), UINT_MIN, x < float(UINT_MIN));
530     //  o := select(o,         UINT_MAX, x > float(UINT_MAX));
531     //
// (We SHOULD use 'o := select(o, UINT_MAX, x >= float(UINT_MAX))' as
// 'float(UINT_MAX)' will be rounded to UINT_MAX+1, i.e. 2 ** 32, and the next
// smaller value than float(UINT_MAX) in 'float' is (2 ** 24 - 1) << 8. For
// 'int', that's also true for INT_MIN.)
536 
537     std::tuple<Value*, unsigned, VISA_Type>
isFPToSignedIntSatWithInexactConstant(llvm::SelectInst * SI)538         CodeGenPatternMatch::isFPToSignedIntSatWithInexactConstant(llvm::SelectInst* SI) {
539         using namespace llvm::PatternMatch; // Scoped using declaration.
540 
541         // TODO
542         return std::make_tuple(nullptr, 0, ISA_TYPE_F);
543     }
544 
545     std::tuple<Value*, unsigned, VISA_Type>
isFPToUnsignedIntSatWithInexactConstant(llvm::SelectInst * SI)546         CodeGenPatternMatch::isFPToUnsignedIntSatWithInexactConstant(llvm::SelectInst* SI)
547     {
548         using namespace llvm::PatternMatch; // Scoped using declaration.
549 
550         Constant* C0 = dyn_cast<Constant>(SI->getTrueValue());
551         if (!C0)
552             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
553         if (!isa<ConstantFP>(C0) && !isa<ConstantInt>(C0))
554             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
555         Value* Cond = SI->getCondition();
556 
557         SelectInst* SI2 = dyn_cast<SelectInst>(SI->getFalseValue());
558         if (!SI2)
559             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
560         Constant* C1 = dyn_cast<Constant>(SI2->getTrueValue());
561         if (!C1)
562             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
563         if (!isa<ConstantFP>(C1) && !isa<ConstantInt>(C1))
564             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
565         Value* Cond2 = SI2->getCondition();
566 
567         Value* X = SI2->getFalseValue();
568         Type* Ty = X->getType();
569         if (Ty->isFloatTy()) {
570             BitCastInst* BC = dyn_cast<BitCastInst>(X);
571             if (!BC)
572                 return std::make_tuple(nullptr, 0, ISA_TYPE_F);
573             X = BC->getOperand(0);
574             Ty = X->getType();
575             C1 = ConstantExpr::getBitCast(C1, Ty);
576             C0 = ConstantExpr::getBitCast(C0, Ty);
577         }
578         IntegerType* ITy = dyn_cast<IntegerType>(Ty);
579         if (!ITy)
580             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
581         unsigned BitWidth = ITy->getBitWidth();
582         FPToUIInst* CI = dyn_cast<FPToUIInst>(X);
583         if (!CI)
584             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
585         Ty = CI->getSrcTy();
586         if (!(Ty->isFloatTy() && BitWidth == 32) &&
587             !(Ty->isDoubleTy() && BitWidth == 64))
588             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
589         X = CI->getOperand(0);
590 
591         ConstantInt* CMin = dyn_cast<ConstantInt>(C0);
592         ConstantInt* CMax = dyn_cast<ConstantInt>(C1);
593         if (!CMax || !CMin || !CMax->isMaxValue(false) || !CMin->isMinValue(false))
594             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
595 
596         Constant* FMin = ConstantExpr::getUIToFP(CMin, Ty);
597         Constant* FMax = ConstantExpr::getUIToFP(CMax, Ty);
598 
599         FCmpInst::Predicate Pred = FCmpInst::FCMP_FALSE;
600         if (!match(Cond2, m_FCmp(Pred, m_Specific(X), m_Specific(FMax))))
601             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
602         if (Pred != FCmpInst::FCMP_OGT) // FIXME: We should use OGE instead of OGT.
603             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
604 
605         FCmpInst::Predicate Pred2 = FCmpInst::FCMP_FALSE;
606         if (!match(Cond,
607             m_Or(m_FCmp(Pred, m_Specific(X), m_Specific(FMin)),
608                 m_FCmp(Pred2, m_Specific(X), m_Specific(X))))) {
609             if (!match(Cond,
610                 m_Or(m_FCmp(Pred, m_Specific(X), m_Specific(FMin)),
611                     m_Zero()))) {
612                 return std::make_tuple(nullptr, 0, ISA_TYPE_F);
613             }
614             // Special case where the staturatured result is bitcasted into float
615             // again (due to typedwrite only accepts `float`. So the isNaN(X) is
616             // reduced to `false`.
617             Pred2 = FCmpInst::FCMP_UNE;
618         }
619         if (Pred != FCmpInst::FCMP_OLT || Pred2 != FCmpInst::FCMP_UNE)
620             return std::make_tuple(nullptr, 0, ISA_TYPE_F);
621 
622         VISA_Type type = GetType(CI->getType(), m_ctx);
623 
624         // Fold extra clamp.
625         Value* X2 = nullptr;
626         ConstantFP* CMin2 = nullptr;
627         ConstantFP* CMax2 = nullptr;
628         if (match(X, m_ClampWithConstants(m_Value(X2), CMin2, CMax2))) {
629             if (CMin2 == FMin) {
630                 if (CMax2->isExactlyValue(255.0)) {
631                     X = X2;
632                     type = ISA_TYPE_B;
633                 }
634                 else if (CMax2->isExactlyValue(65535.0)) {
635                     X = X2;
636                     type = ISA_TYPE_W;
637                 }
638             }
639         }
640 
641         return std::make_tuple(X, Instruction::FPToUI, type);
642     }
643 
MatchFPToIntegerWithSaturation(llvm::Instruction & I)644     bool CodeGenPatternMatch::MatchFPToIntegerWithSaturation(llvm::Instruction& I) {
645         Value* X = nullptr;
646         unsigned Opcode = 0;
647         VISA_Type type = ISA_TYPE_NUM;
648 
649         if (CastInst * CI = dyn_cast<CastInst>(&I)) {
650             std::tie(X, Opcode, type) = isFPToIntegerSatWithExactConstant(CI);
651             if (!X)
652                 return false;
653         }
654         else if (SelectInst * SI = dyn_cast<SelectInst>(&I)) {
655             std::tie(X, Opcode, type) = isFPToSignedIntSatWithInexactConstant(SI);
656             if (!X) {
657                 std::tie(X, Opcode, type) = isFPToUnsignedIntSatWithInexactConstant(SI);
658                 if (!X)
659                     return false;
660             }
661         }
662         else {
663             return false;
664         }
665 
666         // Match!
667         IGC_ASSERT(Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI);
668 
669         struct FPToIntegerWithSaturationPattern : public Pattern {
670             bool isUnsigned, needBitCast;
671             VISA_Type type;
672             SSource src;
673             virtual void Emit(EmitPass* pass, const DstModifier& dstMod) {
674                 pass->EmitFPToIntWithSat(isUnsigned, needBitCast, type, src, dstMod);
675             }
676         };
677 
678         bool isUnsigned = (Opcode == Instruction::FPToUI);
679         FPToIntegerWithSaturationPattern* pat
680             = new (m_allocator) FPToIntegerWithSaturationPattern();
681         pat->isUnsigned = isUnsigned;
682         pat->needBitCast = !I.getType()->isIntegerTy();
683         pat->type = type;
684         pat->src = GetSource(X, !isUnsigned, false);
685         AddPattern(pat);
686 
687         return true;
688     }
689 
690     std::tuple<Value*, bool, bool>
isIntegerSatTrunc(llvm::SelectInst * SI)691         CodeGenPatternMatch::isIntegerSatTrunc(llvm::SelectInst* SI) {
692         using namespace llvm::PatternMatch; // Scoped using declaration.
693 
694         ICmpInst* Cmp = dyn_cast<ICmpInst>(SI->getOperand(0));
695         if (!Cmp)
696             return std::make_tuple(nullptr, false, false);
697 
698         ICmpInst::Predicate Pred = Cmp->getPredicate();
699         if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_UGT)
700             return std::make_tuple(nullptr, false, false);
701 
702         ConstantInt* CI = dyn_cast<ConstantInt>(Cmp->getOperand(1));
703         if (!CI)
704             return std::make_tuple(nullptr, false, false);
705 
706         // Truncate into unsigned integer by default.
707         bool isSignedDst = false;
708         unsigned DstBitWidth = SI->getType()->getIntegerBitWidth();
709         unsigned SrcBitWidth = Cmp->getOperand(0)->getType()->getIntegerBitWidth();
710         APInt UMax = APInt::getMaxValue(DstBitWidth);
711         APInt UMin = APInt::getMinValue(DstBitWidth);
712         APInt SMax = APInt::getSignedMaxValue(DstBitWidth);
713         APInt SMin = APInt::getSignedMinValue(DstBitWidth);
714         if (SrcBitWidth > DstBitWidth) {
715             UMax = UMax.zext(SrcBitWidth);
716             UMin = UMin.zext(SrcBitWidth);
717             SMax = SMax.sext(SrcBitWidth);
718             SMin = SMin.sext(SrcBitWidth);
719         }
720         else
721         {
722             // SrcBitwidth should be always wider than DstBitwidth,
723             // since src is a source of a trunc instruction, and dst
724             // have the same width as its destination.
725             return std::make_tuple(nullptr, false, false);
726         }
727 
728         if (CI->getValue() != UMax && CI->getValue() != SMax)
729             return std::make_tuple(nullptr, false, false);
730         if (CI->getValue() == SMax) // Truncate into signed integer.
731             isSignedDst = true;
732 
733         APInt MinValue = isSignedDst ? SMin : UMin;
734         CI = dyn_cast<ConstantInt>(SI->getOperand(1));
735         if (!CI || !CI->isMaxValue(isSignedDst))
736             return std::make_tuple(nullptr, false, false);
737 
738         TruncInst* TI = dyn_cast<TruncInst>(SI->getOperand(2));
739         if (!TI)
740             return std::make_tuple(nullptr, false, false);
741 
742         Value* Val = TI->getOperand(0);
743         if (Val != Cmp->getOperand(0))
744             return std::make_tuple(nullptr, false, false);
745 
746         // Truncate from unsigned integer.
747         if (Pred == ICmpInst::ICMP_UGT)
748             return std::make_tuple(Val, isSignedDst, false);
749 
750         // Truncate from signed integer. Need to check further for lower bound.
751         Value* LHS = nullptr, * RHS = nullptr;
752         if (!match(Val, m_SMax(m_Value(LHS), m_Value(RHS))))
753             return std::make_tuple(nullptr, false, false);
754 
755         if (isa<ConstantInt>(LHS))
756             std::swap(LHS, RHS);
757 
758         CI = dyn_cast<ConstantInt>(RHS);
759         if (!CI || CI->getValue() != MinValue)
760             return std::make_tuple(nullptr, false, false);
761 
762         return std::make_tuple(LHS, isSignedDst, true);
763     }
764 
MatchIntegerTruncSatModifier(llvm::SelectInst & I)765     bool CodeGenPatternMatch::MatchIntegerTruncSatModifier(llvm::SelectInst& I) {
766         // Only match BYTE or WORD.
767         if (!I.getType()->isIntegerTy(8) && !I.getType()->isIntegerTy(16))
768             return false;
769         Value* Src = nullptr;
770         bool isSignedDst = false, isSignedSrc = false;
771         std::tie(Src, isSignedDst, isSignedSrc) = isIntegerSatTrunc(&I);
772         if (!Src)
773             return false;
774 
775         struct IntegerSatTruncPattern : public Pattern {
776             SSource src;
777             bool isSignedDst;
778             bool isSignedSrc;
779             virtual void Emit(EmitPass* pass, const DstModifier& dstMod) {
780                 pass->EmitIntegerTruncWithSat(isSignedDst, isSignedSrc, src, dstMod);
781             }
782         };
783 
784         IntegerSatTruncPattern* pat = new (m_allocator) IntegerSatTruncPattern();
785         pat->src = GetSource(Src, isSignedSrc, false);
786         pat->isSignedDst = isSignedDst;
787         pat->isSignedSrc = isSignedSrc;
788         AddPattern(pat);
789 
790         return true;
791     }
792 
    // Visitor for `fptosi`: tries the FP-to-integer-with-saturation pattern
    // first and falls back to MatchModifier() when that does not match.
    // Asserts that one of the two matchers accepted the instruction.
    void CodeGenPatternMatch::visitFPToSIInst(llvm::FPToSIInst& I) {
        bool match = MatchFPToIntegerWithSaturation(I) || MatchModifier(I);
        IGC_ASSERT_MESSAGE(match, "Pattern match Failed");
    }
797 
    // Visitor for `fptoui`: tries the FP-to-integer-with-saturation pattern
    // first and falls back to MatchModifier() when that does not match.
    // Asserts that one of the two matchers accepted the instruction.
    void CodeGenPatternMatch::visitFPToUIInst(llvm::FPToUIInst& I) {
        bool match = MatchFPToIntegerWithSaturation(I) || MatchModifier(I);
        IGC_ASSERT_MESSAGE(match, "Pattern match Failed");
    }
802 
MatchSIToFPZExt(llvm::SIToFPInst * S2FI)803     bool CodeGenPatternMatch::MatchSIToFPZExt(llvm::SIToFPInst* S2FI) {
804         ZExtInst* ZEI = dyn_cast<ZExtInst>(S2FI->getOperand(0));
805         if (!ZEI)
806             return false;
807         if (!ZEI->getSrcTy()->isIntegerTy(1))
808             return false;
809 
810         struct SIToFPExtPattern : public Pattern {
811             SSource src;
812             virtual void Emit(EmitPass* pass, const DstModifier& dstMod) {
813                 pass->EmitSIToFPZExt(src, dstMod);
814             }
815         };
816 
817         SIToFPExtPattern* pat = new (m_allocator) SIToFPExtPattern();
818         pat->src = GetSource(ZEI->getOperand(0), false, false);
819         AddPattern(pat);
820 
821         return true;
822     }
823 
visitCastInst(llvm::CastInst & I)824     void CodeGenPatternMatch::visitCastInst(llvm::CastInst& I)
825     {
826         bool match = 0;
827         if (I.getOpcode() == Instruction::SExt)
828         {
829             match = MatchCmpSext(I) ||
830                 MatchModifier(I);
831         }
832         else if (I.getOpcode() == Instruction::SIToFP)
833         {
834             match = MatchSIToFPZExt(cast<SIToFPInst>(&I)) || MatchModifier(I);
835         }
836         else if (I.getOpcode() == Instruction::Trunc)
837         {
838             match = MatchModifier(I);
839         }
840         else
841         {
842             match = MatchModifier(I);
843         }
844     }
845 
    // Returns the current value of the m_NeedVMask flag. Within this part of
    // the file the flag is set by HandleSubspanUse().
    bool CodeGenPatternMatch::NeedVMask()
    {
        return m_NeedVMask;
    }
850 
HasUseOutsideLoop(llvm::Value * v)851     bool CodeGenPatternMatch::HasUseOutsideLoop(llvm::Value* v)
852     {
853         if (Instruction * inst = dyn_cast<Instruction>(v))
854         {
855             if (Loop * L = LI->getLoopFor(inst->getParent()))
856             {
857                 for (auto UI = inst->user_begin(), E = inst->user_end(); UI != E; ++UI)
858                 {
859                     if (!L->contains(cast<Instruction>(*UI)))
860                     {
861                         return true;
862                     }
863                 }
864             }
865         }
866         return false;
867     }
868 
HandleSubspanUse(llvm::Value * v)869     void CodeGenPatternMatch::HandleSubspanUse(llvm::Value* v)
870     {
871         IGC_ASSERT(m_root != nullptr);
872         if (m_ctx->type != ShaderType::PIXEL_SHADER)
873         {
874             return;
875         }
876         if (!isa<Constant>(v) && !m_WI->isUniform(v))
877         {
878             if (isa<PHINode>(v) || HasUseOutsideLoop(v))
879             {
880                 // If a phi is used in a subspan we cannot propagate the subspan use and need to use VMask
881                 m_NeedVMask = true;
882             }
883             else
884             {
885                 m_subSpanUse.insert(v);
886                 if (LoadInst * load = dyn_cast<LoadInst>(v))
887                 {
888                     if (load->getPointerAddressSpace() == ADDRESS_SPACE_PRIVATE)
889                     {
890                         m_NeedVMask = true;
891                     }
892                 }
893                 if (HasPhiUse(*v) && m_WI->insideDivergentCF(m_root))
894                 {
895                     // \todo, more accurate condition for force-isolation
896                     ForceIsolate(v);
897                 }
898             }
899         }
900     }
901 
MatchMinMax(llvm::SelectInst & SI)902     bool CodeGenPatternMatch::MatchMinMax(llvm::SelectInst& SI) {
903         // Pattern to emit.
904         struct MinMaxPattern : public Pattern {
905             SSource srcs[2];
906             bool isMin, isUnsigned;
907             virtual void Emit(EmitPass* pass, const DstModifier& dstMod) {
908                 // FIXME: We should tell umax/umin from max/min as integers in LLVM
909                 // have no sign!
910                 pass->EmitMinMax(isMin, isUnsigned, srcs, dstMod);
911             }
912         };
913 
914         // Skip min/max pattern matching on FP, which needs to either explicitly
915         // use intrinsics or convert them into intrinsic in GenIRLower pass.
916         if (SI.getType()->isFloatingPointTy())
917             return false;
918 
919         bool isMin = false, isUnsigned = false;
920         llvm::Value* LHS = nullptr, * RHS = nullptr;
921 
922         if (!isMinOrMax(&SI, LHS, RHS, isMin, isUnsigned))
923             return false;
924 
925         MinMaxPattern* pat = new (m_allocator) MinMaxPattern();
926         // FIXME: We leave unsigned operand without source modifier so far. When
927         // its behavior is detailed and correcty modeled, consider to add source
928         // modifier support.
929         pat->srcs[0] = GetSource(LHS, !isUnsigned, false);
930         pat->srcs[1] = GetSource(RHS, !isUnsigned, false);
931         pat->isMin = isMin;
932         pat->isUnsigned = isUnsigned;
933         AddPattern(pat);
934 
935         return true;
936     }
937 
    // Visitor for `select`. The matchers are tried strictly in the order
    // written below and evaluation stops at the first one that succeeds, so
    // earlier entries take priority over later ones. MatchSelectModifier() is
    // the last resort. Asserts that some matcher accepted the instruction.
    void CodeGenPatternMatch::visitSelectInst(SelectInst& I)
    {
        bool match = MatchFloatingPointSatModifier(I) ||
            MatchIntegerTruncSatModifier(I) ||
            MatchAbsNeg(I) ||
            MatchFPToIntegerWithSaturation(I) ||
            MatchMinMax(I) ||
            /*MatchPredicate(I)   ||*/
            MatchCmpSelect(I) ||
            MatchSelectModifier(I);
        IGC_ASSERT_MESSAGE(match, "Pattern Match failed");
    }
950 
visitBinaryOperator(llvm::BinaryOperator & I)951     void CodeGenPatternMatch::visitBinaryOperator(llvm::BinaryOperator& I)
952     {
953 
954         bool match = false;
955         switch (I.getOpcode())
956         {
957         case Instruction::FSub:
958             match = MatchFloor(I) ||
959                 MatchFrc(I) ||
960                 MatchLrp(I) ||
961                 MatchPredAdd(I) ||
962                 MatchMad(I) ||
963                 MatchAbsNeg(I) ||
964                 MatchModifier(I);
965             break;
966         case Instruction::Sub:
967             match = MatchMad(I) ||
968                 MatchAbsNeg(I) ||
969                 MatchMulAdd16(I) ||
970                 MatchModifier(I);
971             break;
972         case Instruction::Mul:
973             match = MatchFullMul32(I) ||
974                     MatchMulAdd16(I) ||
975                     MatchModifier(I);
976             break;
977         case Instruction::Add:
978             match = MatchMad(I) ||
979                 MatchMulAdd16(I) ||
980                 MatchModifier(I);
981             break;
982         case Instruction::UDiv:
983         case Instruction::SDiv:
984         case Instruction::AShr:
985             match = MatchAvg(I) ||
986                 MatchModifier(I);
987             break;
988         case Instruction::FMul:
989         case Instruction::URem:
990         case Instruction::SRem:
991         case Instruction::FRem:
992         case Instruction::Shl:
993             match = MatchModifier(I);
994             break;
995         case Instruction::LShr:
996             match = MatchModifier(I, false);
997             break;
998         case Instruction::FDiv:
999             match = MatchRsqrt(I) ||
1000                 MatchModifier(I);
1001             break;
1002         case Instruction::FAdd:
1003             match =
1004                 MatchLrp(I) ||
1005                 MatchPredAdd(I) ||
1006                 MatchMad(I) ||
1007                 MatchSimpleAdd(I) ||
1008                 MatchModifier(I);
1009             break;
1010         case Instruction::And:
1011             match =
1012                 MatchBoolOp(I) ||
1013                 MatchLogicAlu(I);
1014             break;
1015         case Instruction::Or:
1016             match =
1017                 MatchBoolOp(I) ||
1018                 MatchLogicAlu(I);
1019             break;
1020         case Instruction::Xor:
1021             match =
1022                 MatchLogicAlu(I);
1023             break;
1024         default:
1025             IGC_ASSERT_MESSAGE(0, "unknown binary instruction");
1026             break;
1027         }
1028         IGC_ASSERT(match == true);
1029     }
1030 
    // Visitor for compare instructions: tries MatchCondModifier() first and
    // falls back to MatchModifier(). Asserts that one of them matched.
    void CodeGenPatternMatch::visitCmpInst(llvm::CmpInst& I)
    {
        bool match = MatchCondModifier(I) ||
            MatchModifier(I);
        IGC_ASSERT(match);
    }
1037 
    // Visitor for branch instructions: forwards to MatchBranch(). The result
    // of MatchBranch() is not checked here.
    void CodeGenPatternMatch::visitBranchInst(llvm::BranchInst& I)
    {
        MatchBranch(I);
    }
1042 
    // Visitor for call instructions.
    //
    // GenISA intrinsics are dispatched on their intrinsic ID to a dedicated
    // matcher (or an ordered chain of matchers); any ID without an explicit
    // case falls back to MatchSingleInstruction().
    //
    // Other calls are handled as follows:
    //  - inline asm: matched only when the driver reports support for inline
    //    assembly;
    //  - calls with no statically known callee, or whose callee carries the
    //    "referenced-indirectly" or "invoke_simd_target" attribute: matched as
    //    a single instruction (declarations included);
    //  - direct calls to functions that have a body: matched as a single
    //    instruction.
    // A direct call to a plain declaration, or inline asm without driver
    // support, leaves `match` false and trips the final assertion.
    void CodeGenPatternMatch::visitCallInst(CallInst& I)
    {
        bool match = false;
        using namespace GenISAIntrinsic;
        if (GenIntrinsicInst * GII = llvm::dyn_cast<GenIntrinsicInst>(&I))
        {
            switch (GII->getIntrinsicID())
            {
            // Intrinsics handled by the generic modifier matcher.
            case GenISAIntrinsic::GenISA_ROUNDNE:
            case GenISAIntrinsic::GenISA_imulH:
            case GenISAIntrinsic::GenISA_umulH:
            case GenISAIntrinsic::GenISA_uaddc:
            case GenISAIntrinsic::GenISA_usubb:
            case GenISAIntrinsic::GenISA_bfrev:
            case GenISAIntrinsic::GenISA_IEEE_Sqrt:
            case GenISAIntrinsic::GenISA_IEEE_Divide:
            case GenISAIntrinsic::GenISA_rsq:
                match = MatchModifier(I);
                break;
            // Intrinsics that are always matched as a single instruction.
            case GenISAIntrinsic::GenISA_intatomicraw:
            case GenISAIntrinsic::GenISA_floatatomicraw:
            case GenISAIntrinsic::GenISA_intatomicrawA64:
            case GenISAIntrinsic::GenISA_floatatomicrawA64:
            case GenISAIntrinsic::GenISA_icmpxchgatomicraw:
            case GenISAIntrinsic::GenISA_fcmpxchgatomicraw:
            case GenISAIntrinsic::GenISA_icmpxchgatomicrawA64:
            case GenISAIntrinsic::GenISA_fcmpxchgatomicrawA64:
            case GenISAIntrinsic::GenISA_dwordatomicstructured:
            case GenISAIntrinsic::GenISA_floatatomicstructured:
            case GenISAIntrinsic::GenISA_cmpxchgatomicstructured:
            case GenISAIntrinsic::GenISA_fcmpxchgatomicstructured:
            case GenISAIntrinsic::GenISA_intatomictyped:
            case GenISAIntrinsic::GenISA_icmpxchgatomictyped:
            case GenISAIntrinsic::GenISA_typedread:
            case GenISAIntrinsic::GenISA_typedwrite:
            case GenISAIntrinsic::GenISA_ldstructured:
            case GenISAIntrinsic::GenISA_storestructured1:
            case GenISAIntrinsic::GenISA_storestructured2:
            case GenISAIntrinsic::GenISA_storestructured3:
            case GenISAIntrinsic::GenISA_storestructured4:
            case GenISAIntrinsic::GenISA_atomiccounterinc:
            case GenISAIntrinsic::GenISA_atomiccounterpredec:
            case GenISAIntrinsic::GenISA_ldptr:
            case GenISAIntrinsic::GenISA_ldrawvector_indexed:
            case GenISAIntrinsic::GenISA_ldraw_indexed:
            case GenISAIntrinsic::GenISA_storerawvector_indexed:
            case GenISAIntrinsic::GenISA_storeraw_indexed:
                match = MatchSingleInstruction(I);
                break;
            case GenISAIntrinsic::GenISA_GradientX:
            case GenISAIntrinsic::GenISA_GradientY:
            case GenISAIntrinsic::GenISA_GradientXfine:
            case GenISAIntrinsic::GenISA_GradientYfine:
                match = MatchGradient(*GII);
                break;
            case GenISAIntrinsic::GenISA_sampleptr:
            case GenISAIntrinsic::GenISA_sampleBptr:
            case GenISAIntrinsic::GenISA_sampleBCptr:
            case GenISAIntrinsic::GenISA_sampleCptr:
            case GenISAIntrinsic::GenISA_lodptr:
            case GenISAIntrinsic::GenISA_sampleKillPix:
                match = MatchSampleDerivative(*GII);
                break;
            case GenISAIntrinsic::GenISA_fsat:
                match = MatchFloatingPointSatModifier(I);
                break;
            case GenISAIntrinsic::GenISA_usat:
            case GenISAIntrinsic::GenISA_isat:
                match = MatchIntegerSatModifier(I);
                break;
            // Shuffle: matchers are tried in the order written; the first
            // success wins.
            case GenISAIntrinsic::GenISA_WaveShuffleIndex:
                match = MatchRegisterRegion(*GII) ||
                    MatchShuffleBroadCast(*GII) ||
                    MatchWaveShuffleIndex(*GII);
                break;
            case GenISAIntrinsic::GenISA_simdBlockRead:
            case GenISAIntrinsic::GenISA_simdBlockWrite:
                match = MatchBlockReadWritePointer(*GII) ||
                    MatchSingleInstruction(*GII);
                break;
            case GenISAIntrinsic::GenISA_URBRead:
            case GenISAIntrinsic::GenISA_URBReadOutput:
                match = MatchURBRead(*GII) ||
                    MatchSingleInstruction(*GII);
                break;
            // The boolean tells the matcher whether this is the begin (true)
            // or end (false) marker of an unmasked region.
            case GenISAIntrinsic::GenISA_UnmaskedRegionBegin:
                match = MatchUnmaskedRegionBoundary(I, true);
                break;
            case GenISAIntrinsic::GenISA_UnmaskedRegionEnd:
                match = MatchUnmaskedRegionBoundary(I, false);
                break;
            case GenISAIntrinsic::GenISA_sub_group_dpas:
            case GenISAIntrinsic::GenISA_dpas:
                match = MatchDpas(*GII);
                break;
            case GenISAIntrinsic::GenISA_dp4a_ss:
            case GenISAIntrinsic::GenISA_dp4a_su:
            case GenISAIntrinsic::GenISA_dp4a_us:
            case GenISAIntrinsic::GenISA_dp4a_uu:
                match = MatchDp4a(*GII);
                break;
            default:
                match = MatchSingleInstruction(I);
                // no pattern for the rest of the intrinsics
                break;
            }
            IGC_ASSERT_MESSAGE(match, "no pattern found for GenISA intrinsic");
        }
        else
        {
            // Null when the call has no statically known callee.
            Function* Callee = I.getCalledFunction();

            // Match inline asm
            if (I.isInlineAsm())
            {
                if (getAnalysis<CodeGenContextWrapper>().getCodeGenContext()->m_DriverInfo.SupportInlineAssembly())
                {
                    match = MatchSingleInstruction(I);
                }
            }
            // Match indirect call, support declarations for indirect funcs
            else if (!Callee ||
                     Callee->hasFnAttribute("referenced-indirectly") ||
                     Callee->hasFnAttribute("invoke_simd_target"))
            {
                match = MatchSingleInstruction(I);
            }
            // Match direct call, skip declarations
            else if (!Callee->isDeclaration())
            {
                match = MatchSingleInstruction(I);
            }
        }
        // Reached by both branches; fires for any call left unmatched above.
        IGC_ASSERT_MESSAGE(match, "no match for this call");
    }
1178 
visitUnaryInstruction(llvm::UnaryInstruction & I)1179     void CodeGenPatternMatch::visitUnaryInstruction(llvm::UnaryInstruction& I)
1180     {
1181         bool match = false;
1182         switch (I.getOpcode())
1183         {
1184         case Instruction::Alloca:
1185         case Instruction::Load:
1186         case Instruction::ExtractValue:
1187             match = MatchSingleInstruction(I);
1188             break;
1189 #if LLVM_VERSION_MAJOR >= 10
1190         case Instruction::FNeg:
1191             match = MatchAbsNeg(I);
1192             break;
1193 #endif
1194         }
1195         IGC_ASSERT(match);
1196     }
1197 
    // Visitor for LLVM (non-GenISA) intrinsics. Dispatches on the intrinsic
    // ID to the matcher(s) listed per case; matcher chains are tried in the
    // order written and stop at the first success. Intrinsics without an
    // explicit case are matched as a single instruction. Asserts that a
    // matcher accepted the call.
    void CodeGenPatternMatch::visitIntrinsicInst(llvm::IntrinsicInst& I)
    {
        bool match = false;
        switch (I.getIntrinsicID())
        {
        // Intrinsics handled by the generic modifier matcher.
        case Intrinsic::sqrt:
        case Intrinsic::log2:
        case Intrinsic::cos:
        case Intrinsic::sin:
        case Intrinsic::pow:
        case Intrinsic::floor:
        case Intrinsic::ceil:
        case Intrinsic::trunc:
        case Intrinsic::ctpop:
        case Intrinsic::ctlz:
        case Intrinsic::cttz:
            match = MatchModifier(I);
            break;
        // exp2 gets a chance to match MatchPow() before the generic fallback.
        case Intrinsic::exp2:
            match = MatchPow(I) ||
                MatchModifier(I);
            break;
        case Intrinsic::fabs:
            match = MatchAbsNeg(I);
            break;
        case Intrinsic::fma:
            match = MatchFMA(I);
            break;
        // maxnum/minnum are first offered to the FP saturation matcher.
        case Intrinsic::maxnum:
        case Intrinsic::minnum:
            match = MatchFloatingPointSatModifier(I) ||
                MatchModifier(I);
            break;
        case Intrinsic::fshl:
        case Intrinsic::fshr:
            match = MatchFunnelShiftRotate(I);
            break;
        case Intrinsic::canonicalize:
            match = MatchCanonicalizeInstruction(I);
            break;
        default:
            match = MatchSingleInstruction(I);
            // no pattern for the rest of the intrinsics
            break;
        }
        IGC_ASSERT_MESSAGE(match, "no pattern found");
    }
1245 
    // Visitor for stores: matched as a single instruction; asserts success.
    void CodeGenPatternMatch::visitStoreInst(StoreInst& I)
    {
         bool match = MatchSingleInstruction(I);
        IGC_ASSERT(match);
    }
1251 
    // Visitor for loads: matched as a single instruction; asserts success.
    void CodeGenPatternMatch::visitLoadInst(LoadInst& I)
    {
        bool match = MatchSingleInstruction(I);
        IGC_ASSERT(match);
    }
1257 
    // Catch-all visitor for instructions that have no more specific visit
    // method: applies the default single-instruction pattern. The result of
    // MatchSingleInstruction() is not checked here.
    void CodeGenPatternMatch::visitInstruction(llvm::Instruction& I)
    {
        // use default pattern
        MatchSingleInstruction(I);
    }
1263 
visitExtractElementInst(llvm::ExtractElementInst & I)1264     void CodeGenPatternMatch::visitExtractElementInst(llvm::ExtractElementInst& I)
1265     {
1266         Value* VecOpnd = I.getVectorOperand();
1267         if (isa<Constant>(VecOpnd))
1268         {
1269             const Function* F = I.getParent()->getParent();
1270             unsigned NUse = 0;
1271             for (auto User : VecOpnd->users())
1272             {
1273                 if (auto Inst = dyn_cast<Instruction>(User))
1274                 {
1275                     NUse += (Inst->getParent()->getParent() == F);
1276                 }
1277             }
1278 
1279             // Only add it to pool when there are multiple uses within this
1280             // function; otherwise no benefit but to hurt RP.
1281             if (NUse > 1)
1282                 AddToConstantPool(I.getParent(), VecOpnd);
1283         }
1284         MatchSingleInstruction(I);
1285     }
1286 
    // Visitor for phi nodes: intentionally empty, no pattern is added here.
    void CodeGenPatternMatch::visitPHINode(PHINode& I)
    {
        // nothing to do
    }
1291 
visitBitCastInst(BitCastInst & I)1292     void CodeGenPatternMatch::visitBitCastInst(BitCastInst& I)
1293     {
1294         // detect
1295         // %66 = insertelement <2 x i32> <i32 0, i32 undef>, i32 %xor19.i, i32 1
1296         // %67 = bitcast <2 x i32> % 66 to i64
1297         // and replace it with a shl 32
1298         struct Shl32Pattern : public Pattern
1299         {
1300             SSource sources[2];
1301             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
1302             {
1303                 pass->Binary(EOPCODE_SHL, sources, modifier);
1304             }
1305         };
1306 
1307         if (I.getType()->isIntegerTy(64) && I.getOperand(0)->getType()->isVectorTy() &&
1308             cast<VectorType>(I.getOperand(0)->getType())->getElementType()->isIntegerTy(32))
1309         {
1310             if (auto IEI = dyn_cast<InsertElementInst>(I.getOperand(0)))
1311             {
1312                 auto vec = dyn_cast<ConstantVector>(IEI->getOperand(0));
1313                 bool isCandidate = vec && vec->getNumOperands() == 2 && IsZero(vec->getOperand(0)) &&
1314                     isa<UndefValue>(vec->getOperand(1));
1315                 auto index = dyn_cast<ConstantInt>(IEI->getOperand(2));
1316                 isCandidate &= index && index->getZExtValue() == 1;
1317                 if (isCandidate)
1318                 {
1319                     Shl32Pattern* Pat = new (m_allocator) Shl32Pattern();
1320                     Pat->sources[0] = GetSource(IEI->getOperand(1), false, false);
1321                     Pat->sources[1] = GetSource(ConstantInt::get(Type::getInt32Ty(I.getContext()), 32), false, false);
1322                     AddPattern(Pat);
1323                     return;
1324                 }
1325             }
1326         }
1327 
1328         MatchSingleInstruction(I);
1329     }
1330 
    // Visitor for `inttoptr`: matched as a single instruction (result is not
    // checked).
    void CodeGenPatternMatch::visitIntToPtrInst(IntToPtrInst& I) {
        MatchSingleInstruction(I);
    }
1334 
    // Visitor for `ptrtoint`: matched as a single instruction (result is not
    // checked).
    void CodeGenPatternMatch::visitPtrToIntInst(PtrToIntInst& I) {
        MatchSingleInstruction(I);
    }
1338 
    // Visitor for `addrspacecast`: matched as a single instruction (result is
    // not checked).
    void CodeGenPatternMatch::visitAddrSpaceCast(AddrSpaceCastInst& I)
    {
        MatchSingleInstruction(I);
    }
1343 
    // Visitor for debug-info intrinsics: forwards to MatchDbgInstruction().
    void CodeGenPatternMatch::visitDbgInfoIntrinsic(DbgInfoIntrinsic& I)
    {
        MatchDbgInstruction(I);
    }
1348 
visitInsertValueInst(InsertValueInst & I)1349     void CodeGenPatternMatch::visitInsertValueInst(InsertValueInst& I)
1350     {
1351         if (!MatchInsertToStruct(&I))
1352         {
1353             IGC_ASSERT_MESSAGE(0, "Unknown `insertvalue` instruction!");
1354         }
1355     }
1356 
visitExtractValueInst(ExtractValueInst & I)1357     void CodeGenPatternMatch::visitExtractValueInst(ExtractValueInst& I) {
1358         bool Match = false;
1359 
1360         // Ignore the extract value instruction. Handled in the call inst.
1361         if (CallInst * call = dyn_cast<CallInst>(I.getOperand(0)))
1362         {
1363             if (call->isInlineAsm() && call->getType()->isStructTy())
1364             {
1365                 MarkAsSource(call);
1366                 return;
1367             }
1368         }
1369 
1370         Match = matchAddPair(&I) ||
1371             matchSubPair(&I) ||
1372             matchMulPair(&I) ||
1373             matchPtrToPair(&I) ||
1374             MatchExtractFromStruct(&I);
1375 
1376         IGC_ASSERT_MESSAGE(Match, "Unknown `extractvalue` instruction!");
1377     }
1378 
MatchInsertToStruct(InsertValueInst * II)1379     bool CodeGenPatternMatch::MatchInsertToStruct(InsertValueInst* II)
1380     {
1381         if (II->getNumIndices() != 1)
1382             return false;
1383 
1384         // Match the following pattern(s):
1385         //
1386         // %2 = insertvalue % struct.t undef, i32 5, 0
1387         // %3 = insertvalue % struct.t %2, i32 5, 1
1388         //  OR
1389         // %2 = insertvalue % struct.t{ i32 10, i32 undef }, i32 5, 1
1390         //
1391         // In both cases, the first `insertvalue` should allocate the struct and initializes it if needed, and
1392         // subsequent `insertvalue`s should insert into the base struct allocation.
1393         // In EmitVISAPass, we only allocate the CVariable for the struct if it is the base value. For any other
1394         // `insertvalue` instructions, we walk up the calls until we get to the base.
1395 
1396         Value* structOperand = II->getOperand(0);
1397         bool isBaseStruct = (isa<Constant>(structOperand) || structOperand->getValueID() == Value::UndefValueVal);
1398         bool forceVecInit = false;
1399 
1400         // If the first insert is to a const struct value, we will initialize it as an uniform value, but
1401         // if a subsequent insert value is non-uniform, we need to let EmitVISA know to create a vector
1402         // variable to initialize the struct.
1403         // Check all "insertvalue" users to see if there are non-uniform values being inserted.
1404         if (isBaseStruct)
1405         {
1406             // Do DFS on all InsertValue users and check their values
1407             std::function<bool(InsertValueInst*, WIAnalysis*)> HasNonUniformInsertValue =
1408                 [&HasNonUniformInsertValue](InsertValueInst* II, WIAnalysis* WI)->bool
1409             {
1410                 if (!WI->isUniform(II->getOperand(1)))
1411                     return true;
1412 
1413                 for (auto user : II->users())
1414                 {
1415                     if (InsertValueInst* inst = dyn_cast<InsertValueInst>(user))
1416                     {
1417                         return HasNonUniformInsertValue(inst, WI);
1418                     }
1419                 }
1420                 return false;
1421             };
1422 
1423             forceVecInit = HasNonUniformInsertValue(II, m_WI);
1424         }
1425 
1426         struct AddCopyStructPattern : public Pattern {
1427             InsertValueInst* II;
1428             bool forceVectorInit; // force allocating a non-uniform CVar reguardless of uniform analysis
1429             virtual void Emit(EmitPass* Pass, const DstModifier& DstMod) {
1430                 Pass->EmitInsertValueToStruct(II, forceVectorInit, DstMod);
1431             }
1432         };
1433 
1434         AddCopyStructPattern* Pat = new (m_allocator) AddCopyStructPattern();
1435         Pat->II = II;
1436         Pat->forceVectorInit = forceVecInit;
1437         AddPattern(Pat);
1438 
1439         MarkAsSource(II->getOperand(0));
1440         MarkAsSource(II->getOperand(1));
1441         return true;
1442     }
1443 
MatchExtractFromStruct(ExtractValueInst * EI)1444     bool CodeGenPatternMatch::MatchExtractFromStruct(ExtractValueInst* EI)
1445     {
1446         if (EI->getNumIndices() != 1)
1447             return false;
1448         if (GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(EI->getOperand(0)))
1449             return false;
1450 
1451         struct AddReadStructPattern : public Pattern {
1452             ExtractValueInst* EI;
1453             virtual void Emit(EmitPass* Pass, const DstModifier& DstMod) {
1454                 Pass->EmitExtractValueFromStruct(EI, DstMod);
1455             }
1456         };
1457 
1458         AddReadStructPattern* Pat = new (m_allocator) AddReadStructPattern();
1459         Pat->EI = EI;
1460         AddPattern(Pat);
1461 
1462         MarkAsSource(EI->getOperand(0));
1463         return true;
1464     }
1465 
matchAddPair(ExtractValueInst * Ex)1466     bool CodeGenPatternMatch::matchAddPair(ExtractValueInst* Ex) {
1467         Value* V = Ex->getOperand(0);
1468         GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(V);
1469         if (!GII || GII->getIntrinsicID() != GenISAIntrinsic::GenISA_add_pair)
1470             return false;
1471 
1472         if (Ex->getNumIndices() != 1)
1473             return false;
1474         unsigned Idx = *Ex->idx_begin();
1475         if (Idx != 0 && Idx != 1)
1476             return false;
1477 
1478         struct AddPairPattern : public Pattern {
1479             GenIntrinsicInst* GII;
1480             SSource Sources[4]; // L0, H0, L1, H1
1481             virtual void Emit(EmitPass* Pass, const DstModifier& DstMod) {
1482                 Pass->EmitAddPair(GII, Sources, DstMod);
1483             }
1484         };
1485 
1486         struct AddPairSubPattern : public Pattern {
1487             virtual void Emit(EmitPass* Pass, const DstModifier& Mod) {
1488                 // DO NOTHING. Dummy pattern.
1489             }
1490         };
1491 
1492         PairOutputMapTy::iterator MI;
1493         bool New = false;
1494         std::tie(MI, New) = PairOutputMap.insert(std::make_pair(GII, PairOutputTy()));
1495         if (New) {
1496             AddPairPattern* Pat = new (m_allocator) AddPairPattern();
1497             Pat->GII = GII;
1498             Pat->Sources[0] = GetSource(GII->getOperand(0), false, false);
1499             Pat->Sources[1] = GetSource(GII->getOperand(1), false, false);
1500             Pat->Sources[2] = GetSource(GII->getOperand(2), false, false);
1501             Pat->Sources[3] = GetSource(GII->getOperand(3), false, false);
1502             AddPattern(Pat);
1503         }
1504         else {
1505             AddPairSubPattern* Pat = new (m_allocator) AddPairSubPattern();
1506             AddPattern(Pat);
1507         }
1508         if (Idx == 0)
1509             MI->second.first = Ex;
1510         else
1511             MI->second.second = Ex;
1512 
1513         return true;
1514     }
1515 
matchSubPair(ExtractValueInst * Ex)1516     bool CodeGenPatternMatch::matchSubPair(ExtractValueInst* Ex) {
1517         Value* V = Ex->getOperand(0);
1518         GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(V);
1519         if (!GII || GII->getIntrinsicID() != GenISAIntrinsic::GenISA_sub_pair)
1520             return false;
1521 
1522         if (Ex->getNumIndices() != 1)
1523             return false;
1524         unsigned Idx = *Ex->idx_begin();
1525         if (Idx != 0 && Idx != 1)
1526             return false;
1527 
1528         struct SubPairPattern : public Pattern {
1529             GenIntrinsicInst* GII;
1530             SSource Sources[4]; // L0, H0, L1, H1
1531             virtual void Emit(EmitPass* Pass, const DstModifier& DstMod) {
1532                 Pass->EmitSubPair(GII, Sources, DstMod);
1533             }
1534         };
1535 
1536         struct SubPairSubPattern : public Pattern {
1537             virtual void Emit(EmitPass* Pass, const DstModifier& Mod) {
1538                 // DO NOTHING. Dummy pattern.
1539             }
1540         };
1541 
1542         PairOutputMapTy::iterator MI;
1543         bool New = false;
1544         std::tie(MI, New) = PairOutputMap.insert(std::make_pair(GII, PairOutputTy()));
1545         if (New) {
1546             SubPairPattern* Pat = new (m_allocator) SubPairPattern();
1547             Pat->GII = GII;
1548             Pat->Sources[0] = GetSource(GII->getOperand(0), false, false);
1549             Pat->Sources[1] = GetSource(GII->getOperand(1), false, false);
1550             Pat->Sources[2] = GetSource(GII->getOperand(2), false, false);
1551             Pat->Sources[3] = GetSource(GII->getOperand(3), false, false);
1552             AddPattern(Pat);
1553         }
1554         else {
1555             SubPairSubPattern* Pat = new (m_allocator) SubPairSubPattern();
1556             AddPattern(Pat);
1557         }
1558         if (Idx == 0)
1559             MI->second.first = Ex;
1560         else
1561             MI->second.second = Ex;
1562 
1563         return true;
1564     }
1565 
matchMulPair(ExtractValueInst * Ex)1566     bool CodeGenPatternMatch::matchMulPair(ExtractValueInst* Ex) {
1567         Value* V = Ex->getOperand(0);
1568         GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(V);
1569         if (!GII || GII->getIntrinsicID() != GenISAIntrinsic::GenISA_mul_pair)
1570             return false;
1571 
1572         if (Ex->getNumIndices() != 1)
1573             return false;
1574         unsigned Idx = *Ex->idx_begin();
1575         if (Idx != 0 && Idx != 1)
1576             return false;
1577 
1578         struct MulPairPattern : public Pattern {
1579             GenIntrinsicInst* GII;
1580             SSource Sources[4]; // L0, H0, L1, H1
1581             virtual void Emit(EmitPass* Pass, const DstModifier& DstMod) {
1582                 Pass->EmitMulPair(GII, Sources, DstMod);
1583             }
1584         };
1585 
1586         struct MulPairSubPattern : public Pattern {
1587             virtual void Emit(EmitPass* Pass, const DstModifier& Mod) {
1588                 // DO NOTHING. Dummy pattern.
1589             }
1590         };
1591 
1592         PairOutputMapTy::iterator MI;
1593         bool New = false;
1594         std::tie(MI, New) = PairOutputMap.insert(std::make_pair(GII, PairOutputTy()));
1595         if (New) {
1596             MulPairPattern* Pat = new (m_allocator) MulPairPattern();
1597             Pat->GII = GII;
1598             Pat->Sources[0] = GetSource(GII->getOperand(0), false, false);
1599             Pat->Sources[1] = GetSource(GII->getOperand(1), false, false);
1600             Pat->Sources[2] = GetSource(GII->getOperand(2), false, false);
1601             Pat->Sources[3] = GetSource(GII->getOperand(3), false, false);
1602             AddPattern(Pat);
1603         }
1604         else {
1605             MulPairSubPattern* Pat = new (m_allocator) MulPairSubPattern();
1606             AddPattern(Pat);
1607         }
1608         if (Idx == 0)
1609             MI->second.first = Ex;
1610         else
1611             MI->second.second = Ex;
1612 
1613         return true;
1614     }
1615 
matchPtrToPair(ExtractValueInst * Ex)1616     bool CodeGenPatternMatch::matchPtrToPair(ExtractValueInst* Ex) {
1617         Value* V = Ex->getOperand(0);
1618         GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(V);
1619         if (!GII || GII->getIntrinsicID() != GenISAIntrinsic::GenISA_ptr_to_pair)
1620             return false;
1621 
1622         if (Ex->getNumIndices() != 1)
1623             return false;
1624         unsigned Idx = *Ex->idx_begin();
1625         if (Idx != 0 && Idx != 1)
1626             return false;
1627 
1628         struct PtrToPairPattern : public Pattern {
1629             GenIntrinsicInst* GII;
1630             SSource Sources[1]; // Ptr
1631             virtual void Emit(EmitPass* Pass, const DstModifier& DstMod) {
1632                 Pass->EmitPtrToPair(GII, Sources, DstMod);
1633             }
1634         };
1635 
1636         struct PtrToPairSubPattern : public Pattern {
1637             virtual void Emit(EmitPass* Pass, const DstModifier& Mod) {
1638                 // DO NOTHING. Dummy pattern.
1639             }
1640         };
1641 
1642         PairOutputMapTy::iterator MI;
1643         bool New = false;
1644         std::tie(MI, New) = PairOutputMap.insert(std::make_pair(GII, PairOutputTy()));
1645         if (New) {
1646             PtrToPairPattern* Pat = new (m_allocator) PtrToPairPattern();
1647             Pat->GII = GII;
1648             Pat->Sources[0] = GetSource(GII->getOperand(0), false, false);
1649             AddPattern(Pat);
1650         }
1651         else {
1652             PtrToPairSubPattern* Pat = new (m_allocator) PtrToPairSubPattern();
1653             AddPattern(Pat);
1654         }
1655         if (Idx == 0)
1656             MI->second.first = Ex;
1657         else
1658             MI->second.second = Ex;
1659 
1660         return true;
1661     }
1662 
MatchAbsNeg(llvm::Instruction & I)1663     bool CodeGenPatternMatch::MatchAbsNeg(llvm::Instruction& I)
1664     {
1665         struct MovModifierPattern : public Pattern
1666         {
1667             SSource source;
1668             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
1669             {
1670                 pass->Mov(source, modifier);
1671             }
1672         };
1673         bool match = false;
1674         e_modifier mod;
1675         Value* source = nullptr;
1676         if (GetModifier(I, mod, source))
1677         {
1678             MovModifierPattern* pattern = new (m_allocator) MovModifierPattern();
1679             pattern->source = GetSource(source, mod, false);
1680             match = true;
1681             AddPattern(pattern);
1682         }
1683         return match;
1684     }
1685 
MatchFrc(llvm::BinaryOperator & I)1686     bool CodeGenPatternMatch::MatchFrc(llvm::BinaryOperator& I)
1687     {
1688         if (m_ctx->m_DriverInfo.DisableMatchFrcPatternMatch())
1689         {
1690             return false;
1691         }
1692 
1693         struct FrcPattern : public Pattern
1694         {
1695             SSource source;
1696             void Emit(EmitPass* pass, const DstModifier& modifier) override
1697             {
1698                 pass->Frc(source, modifier);
1699             }
1700 
1701             bool supportsSaturate() override { return false; }
1702         };
1703         IGC_ASSERT(I.getOpcode() == Instruction::FSub);
1704         llvm::Value* source0 = I.getOperand(0);
1705         llvm::IntrinsicInst* source1 = llvm::dyn_cast<llvm::IntrinsicInst>(I.getOperand(1));
1706         bool found = false;
1707         if (source1 && source1->getIntrinsicID() == Intrinsic::floor)
1708         {
1709             if (source1->getOperand(0) == source0)
1710             {
1711                 found = true;
1712             }
1713         }
1714         if (found)
1715         {
1716             FrcPattern* pattern = new (m_allocator) FrcPattern();
1717             pattern->source = GetSource(source0, true, false);
1718             AddPattern(pattern);
1719         }
1720         return found;
1721     }
1722 
    /*
    The matcher below handles the x - frac(x) = floor(x) pattern. For reference:

    frc (8|M0) r20.0<1>:f r19.0<8;8,1>:f {Compacted, @1}
    add (8|M0) (ge)f0.1 r19.0<1>:f r19.0<8;8,1>:f -r20.0<8;8,1>:f {@1}
    rndd (8|M0) r21.0<1>:f (abs)r19.0<8;8,1>:f {Compacted, @1}
    */
1730 
MatchFloor(llvm::BinaryOperator & I)1731     bool CodeGenPatternMatch::MatchFloor(llvm::BinaryOperator& I)
1732     {
1733         if (IGC_IS_FLAG_ENABLED(DisableMatchFloor))
1734         {
1735             return false;
1736         }
1737         struct FloorPattern : public Pattern
1738         {
1739             SSource source;
1740             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
1741             {
1742                 pass->Floor(source, modifier);
1743             }
1744         };
1745         IGC_ASSERT(I.getOpcode() == Instruction::FSub);
1746         llvm::Value* source0 = I.getOperand(0);
1747         GenIntrinsicInst* source1 = dyn_cast<GenIntrinsicInst>(I.getOperand(1));
1748         bool found = false;
1749         if (source1 && source1->getIntrinsicID() == GenISAIntrinsic::GenISA_frc)
1750         {
1751             if (source1->getOperand(0) == source0)
1752             {
1753                 found = true;
1754             }
1755         }
1756         if (found)
1757         {
1758             FloorPattern* pattern = new (m_allocator) FloorPattern();
1759             pattern->source = GetSource(source0, true, false);
1760             AddPattern(pattern);
1761         }
1762         return found;
1763     }
1764 
GetSource(llvm::Value * value,bool modifier,bool regioning)1765     SSource CodeGenPatternMatch::GetSource(llvm::Value* value, bool modifier, bool regioning)
1766     {
1767         llvm::Value* sourceValue = value;
1768         e_modifier mod = EMOD_NONE;
1769         if (modifier)
1770         {
1771             GetModifier(*sourceValue, mod, sourceValue);
1772         }
1773         return GetSource(sourceValue, mod, regioning);
1774     }
1775 
GetSource(llvm::Value * value,e_modifier mod,bool regioning)1776     SSource CodeGenPatternMatch::GetSource(llvm::Value* value, e_modifier mod, bool regioning)
1777     {
1778         SSource source;
1779         GetRegionModifier(source, value, regioning);
1780         source.value = value;
1781         source.mod = mod;
1782         MarkAsSource(value);
1783         return source;
1784     }
1785 
MarkAsSource(llvm::Value * v)1786     void CodeGenPatternMatch::MarkAsSource(llvm::Value* v)
1787     {
1788         // update liveness of the sources
1789         if (IsConstOrSimdConstExpr(v))
1790         {
1791             // skip constant
1792             return;
1793         }
1794         if (isa<Instruction>(v) || isa<Argument>(v))
1795         {
1796             m_LivenessInfo->HandleVirtRegUse(v, m_root->getParent(), m_root);
1797         }
1798         // mark the source as used so that we know we need to generate this value
1799         if (llvm::Instruction * inst = llvm::dyn_cast<Instruction>(v))
1800         {
1801             m_usedInstructions.insert(inst);
1802         }
1803         if (m_rootIsSubspanUse)
1804         {
1805             HandleSubspanUse(v);
1806         }
1807     }
1808 
IsSubspanUse(llvm::Value * v)1809     bool CodeGenPatternMatch::IsSubspanUse(llvm::Value* v)
1810     {
1811         return m_subSpanUse.find(v) != m_subSpanUse.end();
1812     }
1813 
MatchFMA(llvm::IntrinsicInst & I)1814     bool CodeGenPatternMatch::MatchFMA(llvm::IntrinsicInst& I)
1815     {
1816         struct FMAPattern : Pattern
1817         {
1818             SSource sources[3];
1819             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
1820             {
1821                 pass->Mad(sources, modifier);
1822             }
1823         };
1824 
1825         FMAPattern* pattern = new (m_allocator)FMAPattern();
1826         for (int i = 0; i < 3; i++)
1827         {
1828             llvm::Value* V = I.getOperand(i);
1829             pattern->sources[i] = GetSource(V, true, false);
1830             if (isa<Constant>(V) &&
1831                 (!m_Platform.support16BitImmSrcForMad() ||
1832                     V->getType()->getTypeID() != llvm::Type::HalfTyID || i == 1))
1833             {
1834                 //CNL+ mad instruction allows 16 bit immediate for src0 and src2
1835                 AddToConstantPool(I.getParent(), V);
1836                 pattern->sources[i].fromConstantPool = true;
1837             }
1838         }
1839         AddPattern(pattern);
1840 
1841         return true;
1842     }
1843 
MatchPredAdd(llvm::BinaryOperator & I)1844     bool CodeGenPatternMatch::MatchPredAdd(llvm::BinaryOperator& I)
1845     {
1846         struct PredAddPattern : Pattern
1847         {
1848             SSource sources[2];
1849             SSource pred;
1850             e_predMode predMode;
1851             bool invertPred;
1852             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
1853             {
1854                 DstModifier modf = modifier;
1855                 modf.predMode = predMode;
1856                 pass->PredAdd(pred, invertPred, sources, modf);
1857             }
1858         };
1859 
1860         if (m_ctx->getModuleMetaData()->isPrecise)
1861         {
1862             return false;
1863         }
1864 
1865         if (m_ctx->type == ShaderType::VERTEX_SHADER ||
1866             !m_ctx->m_DriverInfo.SupportMatchPredAdd())
1867         {
1868             return false;
1869         }
1870 
1871         bool found = false;
1872 
1873         llvm::Value* sources[2] = {nullptr,nullptr};
1874         llvm::Value* pred = nullptr;
1875         e_modifier src_mod[2] = { e_modifier::EMOD_NONE, e_modifier::EMOD_NONE };
1876         e_modifier pred_mod = e_modifier::EMOD_NONE;
1877         bool invertPred = false;
1878         if (m_AllowContractions == false || IGC_IS_FLAG_ENABLED(DisableMatchPredAdd))
1879         {
1880             return false;
1881         }
1882 
1883         // Skip the pattern match if FPTrunc/FPEXt is used right after fadd
1884         if (I.hasOneUse())
1885         {
1886             FPTruncInst* FPTrunc = llvm::dyn_cast<llvm::FPTruncInst>(*I.user_begin());
1887             FPExtInst* FPExt = llvm::dyn_cast<llvm::FPExtInst>  (*I.user_begin());
1888 
1889             if (FPTrunc || FPExt)
1890             {
1891                 return false;
1892             }
1893         }
1894 
1895         IGC_ASSERT(I.getOpcode() == Instruction::FAdd || I.getOpcode() == Instruction::FSub);
1896         for (uint iAdd = 0; iAdd < 2 && !found; iAdd++)
1897         {
1898             Value* src = I.getOperand(iAdd);
1899             llvm::BinaryOperator* mul = llvm::dyn_cast<llvm::BinaryOperator>(src);
1900             if (mul && mul->getOpcode() == Instruction::FMul)
1901             {
1902                 if (!mul->hasOneUse())
1903                 {
1904                     continue;
1905                 }
1906 
1907                 for (uint iMul = 0; iMul < 2; iMul++)
1908                 {
1909                     if (llvm::SelectInst * selInst = dyn_cast<SelectInst>(mul->getOperand(iMul)))
1910                     {
1911                         ConstantFP* C1 = dyn_cast<ConstantFP>(selInst->getOperand(1));
1912                         ConstantFP* C2 = dyn_cast<ConstantFP>(selInst->getOperand(2));
1913                         if (C1 && C2 && selInst->hasOneUse())
1914                         {
1915                             // select i1 %res_s48, float 1.000000e+00, float 0.000000e+00
1916                             if ((C2->isZero() && C1->isExactlyValue(1.f)))
1917                             {
1918                                 invertPred = false;
1919                             }
1920                             // select i1 %res_s48, float 0.000000e+00, float 1.000000e+00
1921                             else if (C1->isZero() && C2->isExactlyValue(1.f))
1922                             {
1923                                 invertPred = true;
1924                             }
1925                             else
1926                             {
1927                                 continue;
1928                             }
1929                         } // if (C1 && C2 && selInst->hasOneUse())
1930                         else
1931                         {
1932                             continue;
1933                         }
1934 
1935                         //   % 97 = select i1 %res_s48, float 1.000000e+00, float 0.000000e+00
1936                         //   %102 = fmul fast float %97 %98
1937 
1938                         // case 1 (add)
1939                         // Before match
1940                         //   %105 = fadd %102 %103
1941                         // After match
1942                         //              %105 = %103
1943                         //   (%res_s48) %105 = fadd %105 %98
1944 
1945                         // case 2 (fsub match @ iAdd = 0)
1946                         // Before match
1947                         //   %105 = fsub %102 %103
1948                         // After match
1949                         //              %105 = -%103
1950                         //   (%res_s48) %105 = fadd %105 %98
1951 
1952                         // case 3 (fsub match @ iAdd = 1)
1953                         // Before match
1954                         //   %105 = fsub %103 %102
1955                         // After match
1956                         //              %105 = %103
1957                         //   (%res_s48) %98 = fadd %105 -%98
1958 
1959                         // sources[0]: store add operand (i.e. %103 above)
1960                         // sources[1]: store mul operand (i.e. %98 above)
1961 
1962                         sources[0] = I.getOperand(1 ^ iAdd);
1963                         sources[1] = mul->getOperand(1 ^ iMul);
1964                         pred = selInst->getOperand(0);
1965 
1966                         GetModifier(*sources[0], src_mod[0], sources[0]);
1967                         GetModifier(*sources[1], src_mod[1], sources[1]);
1968                         GetModifier(*pred, pred_mod, pred);
1969 
1970                         if (I.getOpcode() == Instruction::FSub)
1971                         {
1972                             src_mod[iAdd] = CombineModifier(EMOD_NEG, src_mod[iAdd]);
1973                         }
1974 
1975                         found = true;
1976                         break;
1977                     } //  if (llvm::SelectInst* selInst = dyn_cast<SelectInst>(mul->getOperand(iMul)))
1978                 } // for (uint iMul = 0; iMul < 2; iMul++)
1979             } // if (mul && mul->getOpcode() == Instruction::FMul)
1980         } // for (uint iAdd = 0; iAdd < 2; iAdd++)
1981 
1982         if (found)
1983         {
1984             PredAddPattern* pattern = new (m_allocator) PredAddPattern();
1985             pattern->predMode = EPRED_NORMAL;
1986             pattern->sources[0] = GetSource(sources[0], src_mod[0], false);
1987             pattern->sources[1] = GetSource(sources[1], src_mod[1], false);
1988             pattern->pred = GetSource(pred, pred_mod, false);
1989             pattern->invertPred = invertPred;
1990 
1991             AddPattern(pattern);
1992         }
1993         return found;
1994     }
1995 
1996     // we match the following pattern
1997     // %c = fcmp %1 %2
1998     // %g = sext %c to i32
1999     // %h = and i32 %g 1065353216
2000     // %m = bitcast i32 %h to float
2001     // %p = fadd %m %n
MatchSimpleAdd(llvm::BinaryOperator & I)2002     bool CodeGenPatternMatch::MatchSimpleAdd(llvm::BinaryOperator& I)
2003     {
2004         struct SimpleAddPattern : public Pattern
2005         {
2006             SSource sources[2];
2007             SSource pred;
2008 
2009             e_predMode predMode;
2010             bool invertPred;
2011             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
2012             {
2013                 DstModifier modf = modifier;
2014                 modf.predMode = predMode;
2015                 pass->PredAdd(pred, invertPred, sources, modf);
2016             }
2017         };
2018 
2019         IGC_ASSERT(I.getOpcode() == Instruction::FAdd);
2020 
2021         if (IGC_IS_FLAG_ENABLED(DisableMatchSimpleAdd))
2022         {
2023             return false;
2024         }
2025 
2026         unsigned int repAddOperand = 0;
2027         BitCastInst* bitCastInst0 = llvm::dyn_cast<llvm::BitCastInst>(I.getOperand(0));
2028         BitCastInst* bitCastInst1 = llvm::dyn_cast<llvm::BitCastInst>(I.getOperand(1));
2029         BitCastInst* bitCastInst = NULL;
2030 
2031         if (!bitCastInst0 && !bitCastInst1)
2032         {
2033             return false;
2034         }
2035 
2036         if (bitCastInst1)
2037         {
2038             bitCastInst = bitCastInst1;
2039             repAddOperand = 1;
2040         }
2041         else
2042         {
2043             bitCastInst = bitCastInst0;
2044             repAddOperand = 0;
2045         }
2046 
2047         if (!bitCastInst->hasOneUse())
2048         {
2049             return false;
2050         }
2051 
2052         Instruction* andInst = dyn_cast<Instruction>(bitCastInst->getOperand(0));
2053         if (!andInst || (andInst->getOpcode() != Instruction::And))
2054         {
2055             return false;
2056         }
2057 
2058         // check %h = and i32 %g 1065353216
2059         if (!andInst->getType()->isIntegerTy(32))
2060         {
2061             return false;
2062         }
2063 
2064         ConstantInt* CInt = dyn_cast<ConstantInt>(andInst->getOperand(1));
2065         if (!CInt || (CInt->getZExtValue() != 0x3f800000))
2066         {
2067             return false;
2068         }
2069 
2070         SExtInst* SExt = llvm::dyn_cast<llvm::SExtInst>(andInst->getOperand(0));
2071         if (!SExt)
2072         {
2073             return false;
2074         }
2075 
2076         CmpInst* cmp = llvm::dyn_cast<CmpInst>(SExt->getOperand(0));
2077         if (!cmp)
2078         {
2079             return false;
2080         }
2081 
2082         // match found
2083         SimpleAddPattern* pattern = new (m_allocator) SimpleAddPattern();
2084         llvm::Value* sources[2], * pred;
2085         e_modifier src_mod[2], pred_mod;
2086 
2087         sources[0] = I.getOperand(1 - repAddOperand);
2088         sources[1] = ConstantFP::get(I.getType(), 1.0);
2089         pred = cmp;
2090 
2091         GetModifier(*sources[0], src_mod[0], sources[0]);
2092         GetModifier(*sources[1], src_mod[1], sources[1]);
2093         GetModifier(*pred, pred_mod, pred);
2094 
2095         pattern->predMode = EPRED_NORMAL;
2096         pattern->sources[0] = GetSource(sources[0], src_mod[0], false);
2097         pattern->sources[1] = GetSource(sources[1], src_mod[1], false);
2098         pattern->pred = GetSource(pred, pred_mod, false);
2099         pattern->invertPred = false;
2100         AddPattern(pattern);
2101 
2102         return true;
2103     }
2104 
MatchMad(llvm::BinaryOperator & I)2105     bool CodeGenPatternMatch::MatchMad(llvm::BinaryOperator& I)
2106     {
2107         struct MadPattern : Pattern
2108         {
2109             SSource sources[3];
2110             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
2111             {
2112                 pass->Mad(sources, modifier);
2113             }
2114         };
2115 
2116         auto isFpMad = [](const Instruction& I)
2117         {
2118             return I.getType()->isFloatingPointTy();
2119         };
2120 
2121         if (isFpMad(I) && (m_ctx->getModuleMetaData()->isPrecise || m_ctx->getModuleMetaData()->compOpt.disableMathRefactoring))
2122         {
2123             return false;
2124         }
2125         if (m_ctx->type == ShaderType::VERTEX_SHADER &&
2126             m_ctx->m_DriverInfo.DisabeMatchMad())
2127         {
2128             return false;
2129         }
2130 
2131         bool allow = m_ctx->getModuleMetaData()->allowMatchMadOptimizationforVS || IGC_IS_FLAG_ENABLED(WaAllowMatchMadOptimizationforVS);
2132 
2133         using namespace llvm::PatternMatch;
2134         if (m_ctx->type == ShaderType::VERTEX_SHADER &&
2135             m_ctx->m_DriverInfo.PreventZFighting() && !allow)
2136         {
2137             if (m_PosDep->PositionDependsOnInst(&I))
2138                 return false;
2139         }
2140 
2141         if (IGC_IS_FLAG_ENABLED(DisableMatchMad))
2142         {
2143             return false;
2144         }
2145 
2146         bool isFpMadWithContractionOverride = false;
2147         if (isFpMad(I) && m_AllowContractions == false)
2148         {
2149             if (I.hasAllowContract() && m_ctx->m_DriverInfo.RespectPerInstructionContractFlag())
2150             {
2151                 isFpMadWithContractionOverride = true;
2152             }
2153             else
2154             {
2155                 return false;
2156             }
2157         }
2158         if (!isFpMad(I) && !(m_Platform.doIntegerMad() && m_ctx->m_DriverInfo.EnableIntegerMad()))
2159         {
2160             return false;
2161         }
2162 
2163         bool found = false;
2164         llvm::Value* sources[3];
2165         e_modifier src_mod[3];
2166 
2167         IGC_ASSERT(I.getOpcode() == Instruction::FAdd ||
2168             I.getOpcode() == Instruction::FSub ||
2169             I.getOpcode() == Instruction::Add ||
2170             I.getOpcode() == Instruction::Sub);
2171         if (I.getOperand(0) != I.getOperand(1))
2172         {
2173             for (uint i = 0; i < 2; i++)
2174             {
2175                 Value* src = I.getOperand(i);
2176                 if (FPExtInst * fpextInst = llvm::dyn_cast<llvm::FPExtInst>(src))
2177                 {
2178                     if (!m_Platform.supportMixMode() && fpextInst->getSrcTy()->getTypeID() == llvm::Type::HalfTyID)
2179                     {
2180                         // no mix mode instructions
2181                     }
2182                     else if (fpextInst->getSrcTy()->getTypeID() != llvm::Type::DoubleTyID &&
2183                         fpextInst->getDestTy()->getTypeID() != llvm::Type::DoubleTyID)
2184                     {
2185                         src = fpextInst->getOperand(0);
2186                     }
2187                 }
2188                 llvm::BinaryOperator* mul = llvm::dyn_cast<llvm::BinaryOperator>(src);
2189 
2190                 if (mul && (mul->getOpcode() == Instruction::FMul ||
2191                     mul->getOpcode() == Instruction::Mul))
2192                 {
2193                     // in case we know we won't be able to remove the mul we don't merge it
2194                     if (!m_PosDep->PositionDependsOnInst(mul) && NeedInstruction(*mul))
2195                         continue;
2196 
2197                     if (isFpMadWithContractionOverride && !mul->hasAllowContract())
2198                         continue;
2199 
2200                     sources[2] = I.getOperand(1 - i);
2201                     sources[1] = mul->getOperand(0);
2202                     sources[0] = mul->getOperand(1);
2203                     GetModifier(*sources[0], src_mod[0], sources[0]);
2204                     GetModifier(*sources[1], src_mod[1], sources[1]);
2205                     GetModifier(*sources[2], src_mod[2], sources[2]);
2206                     if (I.getOpcode() == Instruction::FSub ||
2207                         I.getOpcode() == Instruction::Sub)
2208                     {
2209                         if (i == 0)
2210                         {
2211                             src_mod[2] = CombineModifier(EMOD_NEG, src_mod[2]);
2212                         }
2213                         else
2214                         {
2215                             if (llvm::isa<llvm::Constant>(sources[0]))
2216                             {
2217                                 src_mod[1] = CombineModifier(EMOD_NEG, src_mod[1]);
2218                             }
2219                             else
2220                             {
2221                                 src_mod[0] = CombineModifier(EMOD_NEG, src_mod[0]);
2222                             }
2223                         }
2224                     }
2225                     found = true;
2226                     break;
2227                 }
2228             }
2229         }
2230 
2231         // Check integer mad profitability.
2232         if (found && !isFpMad(I))
2233         {
2234             uint8_t numConstant = 0;
2235             for (int i = 0; i < 3; i++)
2236             {
2237                 if (isa<Constant>(sources[i]))
2238                     numConstant++;
2239 
2240                 // Only one immediate is supported
2241                 if (numConstant > 1)
2242                     return false;
2243             }
2244 
2245             auto isByteOrWordValue = [](Value* V) -> bool {
2246                 if (isa<ConstantInt>(V))
2247                 {
2248                     // only 16-bit int immediate is supported
2249                     APInt val = dyn_cast<ConstantInt>(V)->getValue();
2250                     return val.sge(SHRT_MIN) && val.sle(SHRT_MAX);
2251                 }
2252                 // Trace the def-use chain and return the first non up-cast related value.
2253                 while (isa<ZExtInst>(V) || isa<SExtInst>(V) || isa<BitCastInst>(V))
2254                     V = cast<Instruction>(V)->getOperand(0);
2255                 const unsigned DWordSizeInBits = 32;
2256                 return V->getType()->getScalarSizeInBits() < DWordSizeInBits;
2257             };
2258 
2259             // One multiplicant should be *W or *B.
2260             if (!isByteOrWordValue(sources[0]) && !isByteOrWordValue(sources[1]))
2261                 return false;
2262 
2263             auto isQWordValue = [](Value* V) -> bool {
2264                 while (isa<ZExtInst>(V) || isa<SExtInst>(V) || isa<BitCastInst>(V))
2265                     V = cast<Instruction>(V)->getOperand(0);
2266                 Type* T = V->getType();
2267                 return (T->isIntegerTy() && T->getScalarSizeInBits() == 64);
2268             };
2269 
2270             // Mad instruction doesn't support QW type
2271             if (isQWordValue(sources[0]) || isQWordValue(sources[1]))
2272                 return false;
2273         }
2274 
2275         if (found)
2276         {
2277             MadPattern* pattern = new (m_allocator) MadPattern();
2278             for (int i = 0; i < 3; i++)
2279             {
2280                 pattern->sources[i] = GetSource(sources[i], src_mod[i], false);
2281                 if (isa<Constant>(sources[i]) &&
2282                     (!m_Platform.support16BitImmSrcForMad() ||
2283                     (!sources[i]->getType()->isHalfTy() && !sources[i]->getType()->isIntegerTy()) || i == 1))
2284                 {
2285                     //CNL+ mad instruction allows 16 bit immediate for src0 and src2
2286                     AddToConstantPool(I.getParent(), sources[i]);
2287                     pattern->sources[i].fromConstantPool = true;
2288                 }
2289             }
2290             AddPattern(pattern);
2291         }
2292         return found;
2293     }
2294 
2295     // match simdblockRead/Write with preceding inttoptr if possible
2296     // to save a copy move
MatchBlockReadWritePointer(llvm::GenIntrinsicInst & I)2297     bool CodeGenPatternMatch::MatchBlockReadWritePointer(llvm::GenIntrinsicInst& I)
2298     {
2299         struct BlockReadWritePointerPattern : public Pattern
2300         {
2301             GenIntrinsicInst* inst;
2302             Value* offset;
2303             explicit BlockReadWritePointerPattern(GenIntrinsicInst* I, Value* ofst) : inst(I), offset(ofst) {}
2304             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
2305             {
2306                 switch (inst->getIntrinsicID())
2307                 {
2308                 case GenISAIntrinsic::GenISA_simdBlockRead:
2309                     pass->emitSimdBlockRead(inst, offset);
2310                     break;
2311                 case GenISAIntrinsic::GenISA_simdBlockWrite:
2312                     pass->emitSimdBlockWrite(inst, offset);
2313                     break;
2314                 default:
2315                     IGC_ASSERT_MESSAGE(0, "unexpected intrinsic");
2316                     break;
2317                 }
2318             }
2319         };
2320 
2321         if (I.getIntrinsicID() != GenISAIntrinsic::GenISA_simdBlockRead &&
2322             I.getIntrinsicID() != GenISAIntrinsic::GenISA_simdBlockWrite)
2323         {
2324             return false;
2325         }
2326 
2327         // check the address is inttoptr with same dst and src type width
2328         auto ptrVal = I.getOperand(0);
2329         auto I2P = dyn_cast<IntToPtrInst>(ptrVal);
2330         if (I2P &&
2331             I2P->getOperand(0)->getType()->getIntegerBitWidth() ==
2332             m_ctx->getRegisterPointerSizeInBits(I2P->getAddressSpace()))
2333         {
2334             auto addrBase = I2P->getOperand(0);
2335             BlockReadWritePointerPattern* pattern = new (m_allocator) BlockReadWritePointerPattern(&I, addrBase);
2336             MarkAsSource(addrBase);
2337             // for write mark data ptr as well
2338             if (I.getIntrinsicID() == GenISAIntrinsic::GenISA_simdBlockWrite)
2339             {
2340                 MarkAsSource(I.getOperand(1));
2341             }
2342 
2343             AddPattern(pattern);
2344             return true;
2345         }
2346         return false;
2347     }
2348 
2349     // 1. Detect and handle immediate URB read offsets - these can be put in message descriptor.
2350     // 2. Detect offsets of the form "add dst, var, imm" - here we can remove the add, putting imm in message descriptor,
2351     // and var in message payload.
MatchURBRead(llvm::GenIntrinsicInst & I)2352     bool CodeGenPatternMatch::MatchURBRead(llvm::GenIntrinsicInst& I)
2353     {
2354         struct URBReadPattern : public Pattern
2355         {
2356             explicit URBReadPattern(GenIntrinsicInst* I, QuadEltUnit globalOffset, llvm::Value* const perSlotOffset) :
2357                 m_inst(I), m_globalOffset(globalOffset), m_perSlotOffset(perSlotOffset)
2358             {}
2359 
2360             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
2361             {
2362                 IGC_ASSERT(m_inst->getIntrinsicID() == GenISAIntrinsic::GenISA_URBRead ||
2363                     m_inst->getIntrinsicID() == GenISAIntrinsic::GenISA_URBReadOutput);
2364                 pass->emitURBReadCommon(m_inst, m_globalOffset, m_perSlotOffset);
2365             }
2366 
2367         private:
2368             GenIntrinsicInst* const m_inst;
2369             const QuadEltUnit m_globalOffset;
2370             llvm::Value* const m_perSlotOffset;
2371         };
2372 
2373         if (I.getIntrinsicID() != GenISAIntrinsic::GenISA_URBRead &&
2374             I.getIntrinsicID() != GenISAIntrinsic::GenISA_URBReadOutput)
2375         {
2376             return false;
2377         }
2378 
2379         const bool hasVertexIndexAsArg0 = I.getIntrinsicID() == GenISAIntrinsic::GenISA_URBRead;
2380         llvm::Value* const offset = I.getOperand(hasVertexIndexAsArg0 ? 1 : 0);
2381         if (const ConstantInt * const constOffset = dyn_cast<ConstantInt>(offset))
2382         {
2383             const QuadEltUnit globalOffset = QuadEltUnit(int_cast<unsigned>(constOffset->getZExtValue()));
2384             if (hasVertexIndexAsArg0)
2385             {
2386                 MarkAsSource(I.getOperand(0));
2387             }
2388             URBReadPattern* pattern = new (m_allocator) URBReadPattern(&I, globalOffset, nullptr);
2389             AddPattern(pattern);
2390             return true;
2391         }
2392         else if (llvm::Instruction * const inst = llvm::dyn_cast<llvm::Instruction>(offset))
2393         {
2394             if (inst->getOpcode() == llvm::Instruction::Add)
2395             {
2396                 const bool isConstant0 = llvm::isa<llvm::ConstantInt>(inst->getOperand(0));
2397                 const bool isConstant1 = llvm::isa<llvm::ConstantInt>(inst->getOperand(1));
2398                 if (isConstant0 || isConstant1)
2399                 {
2400                     IGC_ASSERT_MESSAGE(!(isConstant0 && isConstant1), "Both operands are immediate - constants should be folded elsewhere.");
2401 
2402                     if (hasVertexIndexAsArg0)
2403                     {
2404                         MarkAsSource(I.getOperand(0));
2405                     }
2406                     const QuadEltUnit globalOffset = QuadEltUnit(int_cast<unsigned>(cast<ConstantInt>(
2407                         isConstant0 ? inst->getOperand(0) : inst->getOperand(1))->getZExtValue()));
2408                     llvm::Value* const perSlotOffset = isConstant0 ? inst->getOperand(1) : inst->getOperand(0);
2409                     MarkAsSource(perSlotOffset);
2410                     URBReadPattern* pattern = new (m_allocator) URBReadPattern(&I, globalOffset, perSlotOffset);
2411                     AddPattern(pattern);
2412                     return true;
2413                 }
2414             }
2415         }
2416 
2417         return false;
2418     }
2419 
2420 
MatchLoadStorePointer(llvm::Instruction & I,llvm::Value & ptrVal)2421     bool CodeGenPatternMatch::MatchLoadStorePointer(llvm::Instruction& I, llvm::Value& ptrVal)
2422     {
2423         struct LoadStorePointerPattern : public Pattern
2424         {
2425             Instruction* inst;
2426             Value* offset;
2427             ConstantInt* immOffset;
2428             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
2429             {
2430                 if (isa<LoadInst>(inst))
2431                 {
2432                     pass->emitLoad(cast<LoadInst>(inst), offset, immOffset);
2433                 }
2434                 else if (isa<StoreInst>(inst))
2435                 {
2436                     pass->emitStore3D(cast<StoreInst>(inst), offset);
2437                 }
2438             }
2439         };
2440         if (ptrVal.getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL ||
2441             ptrVal.getType()->getPointerAddressSpace() == ADDRESS_SPACE_CONSTANT)
2442         {
2443             return false;
2444         }
2445 
2446         // Store3d supports only types equal or less than 128 bits.
2447         if (auto* storeInst = dyn_cast<StoreInst>(&I))
2448         {
2449             IGCLLVM::FixedVectorType* vectorToStore = dyn_cast<IGCLLVM::FixedVectorType>(storeInst->getValueOperand()->getType());
2450 
2451             // If stored value is a vector of pointers, the size must be calculated manually,
2452             // because getPrimitiveSizeInBits returns 0 for pointers.
2453             if ((storeInst->getValueOperand()->getType()->getPrimitiveSizeInBits() > 128) ||
2454                     (vectorToStore &&
2455                     (vectorToStore->getElementType()->isPointerTy()) &&
2456                     ((vectorToStore->getNumElements() * m_ctx->getModule()->getDataLayout().getPointerSizeInBits(cast<llvm::PointerType>(vectorToStore->getElementType())->getAddressSpace())) > 128)))
2457             {
2458                 return false;
2459             }
2460         }
2461 
2462         if (auto* i2p = dyn_cast<IntToPtrInst>(&ptrVal))
2463         {
2464             LoadStorePointerPattern* pattern = new (m_allocator) LoadStorePointerPattern();
2465             pattern->inst = &I;
2466             uint numSources = GetNbSources(I);
2467             for (uint i = 0; i < numSources; i++)
2468             {
2469                 if (I.getOperand(i) != i2p)
2470                 {
2471                     MarkAsSource(I.getOperand(i));
2472                 }
2473             }
2474             pattern->offset = i2p->getOperand(0);
2475             pattern->immOffset = ConstantInt::get(Type::getInt32Ty(I.getContext()), 0);
2476             MarkAsSource(pattern->offset);
2477             AddPattern(pattern);
2478             return true;
2479         }
2480         return false;
2481     }
2482 
    // Tries to match a float FAdd/FSub rooted at I as a three-source LRP pattern.
    //
    // Three expression shapes are recognized, tried in this order; all of them
    // are algebraic rearrangements of  src0 * src1 + (1.0 - src0) * src2 :
    //   1. src0 * (src1 - src2) + src2
    //   2. src0 * src1 + src2 * (1.0 - src0)
    //   3. src2 - (src0 * src2) + (src0 * src1)   (and two reorderings)
    // Variants that differ only by the sign of a term are handled by adding a
    // negate modifier to src1 and/or src2.
    //
    // Returns false without adding a pattern when I is not of float type, the
    // platform reports no LRP support, the module is marked precise,
    // contractions are not allowed, or no shape matches. On success an
    // LRPPattern holding the three sources is added and true is returned.
    bool CodeGenPatternMatch::MatchLrp(llvm::BinaryOperator& I)
    {
        // Forwards the three collected sources to EmitPass::Lrp.
        struct LRPPattern : public Pattern
        {
            SSource sources[3];
            virtual void Emit(EmitPass* pass, const DstModifier& modifier)
            {
                pass->Lrp(sources, modifier);
            }
        };

        if (!I.getType()->isFloatTy())
            return false;
        if (!m_Platform.supportLRPInstruction())
            return false;

        if (m_ctx->getModuleMetaData()->isPrecise)
        {
            return false;
        }

        bool found = false;
        // Filled in only when found becomes true; index n corresponds to srcN
        // in the shape descriptions.
        llvm::Value* sources[3];
        e_modifier   src_mod[3];

        if (m_AllowContractions == false)
        {
            return false;
        }

        IGC_ASSERT((I.getOpcode() == Instruction::FAdd) || (I.getOpcode() == Instruction::FSub));

        // true: root is FAdd; false: root is FSub.
        bool startPatternIsAdd = false;
        if (I.getOpcode() == Instruction::FAdd)
        {
            startPatternIsAdd = true;
        }

        // match the case: dst = src0 * (src1 - src2) + src2;
        // i selects which operand of I is the multiply, j selects which operand
        // of the multiply is the inner sub/add.
        for (uint i = 0; i < 2; i++)
        {
            llvm::BinaryOperator* mul = llvm::dyn_cast<llvm::BinaryOperator>(I.getOperand(i));
            if (mul && mul->getOpcode() == Instruction::FMul)
            {
                for (uint j = 0; j < 2; j++)
                {
                    llvm::BinaryOperator* sub = llvm::dyn_cast<llvm::BinaryOperator>(mul->getOperand(j));
                    if (sub)
                    {
                        llvm::ConstantFP* zero = llvm::dyn_cast<llvm::ConstantFP>(sub->getOperand(0));
                        if (zero && zero->isExactlyValue(0.f))
                        {
                            // in this case we can optimize the pattern into fmad and give better result
                            continue;
                        }

                        // For an FAdd root the inner op must be FSub. For an FSub
                        // root only "mul - src2" (mul as operand 0) with an inner
                        // FAdd is accepted.
                        if ((startPatternIsAdd && sub->getOpcode() == Instruction::FSub) ||
                            (!startPatternIsAdd && i == 0 && sub->getOpcode() == Instruction::FAdd))
                        {
                            // The second operand of the inner op must be the other
                            // operand of I (this is src2), and the multiply must not
                            // be a square of the same value.
                            if (sub->getOperand(1) == I.getOperand(1 - i) &&
                                mul->getOperand(0) != mul->getOperand(1))
                            {
                                sources[0] = mul->getOperand(1 - j);
                                sources[1] = sub->getOperand(0);
                                sources[2] = sub->getOperand(1);
                                GetModifier(*sources[0], src_mod[0], sources[0]);
                                GetModifier(*sources[1], src_mod[1], sources[1]);
                                GetModifier(*sources[2], src_mod[2], sources[2]);

                                if (!startPatternIsAdd && i == 0)
                                {
                                    // handle patterns like this:
                                    // dst = src0 * (src1 + src2) - src2;
                                    src_mod[2] = CombineModifier(EMOD_NEG, src_mod[2]);
                                }

                                found = true;
                                break;
                            }
                        }
                    }
                }
            }
            if (found)
            {
                break;
            }
        }

        // match the case: dst = src0 * src1 + src2 * (1.0 - src0);
        if (!found)
        {
            llvm::BinaryOperator* mul[2];
            mul[0] = llvm::dyn_cast<llvm::BinaryOperator>(I.getOperand(0));
            mul[1] = llvm::dyn_cast<llvm::BinaryOperator>(I.getOperand(1));
            // Both operands of I must be multiplies with no floating-point
            // constant operand.
            if (mul[0] && mul[0]->getOpcode() == Instruction::FMul &&
                mul[1] && mul[1]->getOpcode() == Instruction::FMul &&
                !llvm::isa<llvm::ConstantFP>(mul[0]->getOperand(0)) &&
                !llvm::isa<llvm::ConstantFP>(mul[0]->getOperand(1)) &&
                !llvm::isa<llvm::ConstantFP>(mul[1]->getOperand(0)) &&
                !llvm::isa<llvm::ConstantFP>(mul[1]->getOperand(1)))
            {
                // mul[i] is the candidate for src2 * (1.0 - src0), with the
                // (1.0 - src0) factor at operand j; mul[1 - i] is then the
                // src0 * src1 product, with src0 at operand k.
                for (uint i = 0; i < 2; i++)
                {
                    for (uint j = 0; j < 2; j++)
                    {
                        llvm::BinaryOperator* sub = llvm::dyn_cast<llvm::BinaryOperator>(mul[i]->getOperand(j));
                        if (sub && sub->getOpcode() == Instruction::FSub)
                        {
                            llvm::ConstantFP* one = llvm::dyn_cast<llvm::ConstantFP>(sub->getOperand(0));
                            if (one && one->isExactlyValue(1.f))
                            {
                                for (uint k = 0; k < 2; k++)
                                {
                                    if (sub->getOperand(1) == mul[1 - i]->getOperand(k))
                                    {
                                        sources[0] = sub->getOperand(1);
                                        sources[1] = mul[1 - i]->getOperand(1 - k);
                                        sources[2] = mul[i]->getOperand(1 - j);
                                        GetModifier(*sources[0], src_mod[0], sources[0]);
                                        GetModifier(*sources[1], src_mod[1], sources[1]);
                                        GetModifier(*sources[2], src_mod[2], sources[2]);
                                        if (!startPatternIsAdd)
                                        {
                                            // FSub root: negate whichever product is
                                            // the subtrahend (operand 1 of I).
                                            if (i == 1)
                                            {
                                                // handle patterns like this:
                                                // dst = (src1 * src0) - (src2 * (1.0 - src0))
                                                src_mod[2] = CombineModifier(EMOD_NEG, src_mod[2]);
                                            }
                                            else
                                            {
                                                // handle patterns like this:
                                                // dst = (src2 * (1.0 - src0)) - (src1 * src0)
                                                src_mod[1] = CombineModifier(EMOD_NEG, src_mod[1]);
                                            }
                                        }
                                        found = true;
                                        break;
                                    }
                                }
                            }
                        }
                        if (found)
                        {
                            break;
                        }
                    }
                    if (found)
                    {
                        break;
                    }
                }
            }
        }

        if (!found)
        {
            // match the case: dst = src2 - (src0 * src2) + (src0 * src1);
            // match the case: dst = (src0 * src1) + src2 - (src0 * src2);
            // match the case: dst = src2 + (src0 * src1) - (src0 * src2);
            if (I.getOpcode() == Instruction::FAdd || I.getOpcode() == Instruction::FSub)
            {
                // dst = op[0] +/- op[1] +/- op[2]
                // i.e. I = addSub1 +/- op[2] with addSub1 = op[0] +/- op[1]; only
                // operand 0 of I is inspected as the nested add/sub.
                llvm::Instruction* op[3];
                llvm::Instruction* addSub1 = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
                if (addSub1 && (addSub1->getOpcode() == Instruction::FSub || addSub1->getOpcode() == Instruction::FAdd))
                {
                    op[0] = llvm::dyn_cast<llvm::Instruction>(addSub1->getOperand(0));
                    op[1] = llvm::dyn_cast<llvm::Instruction>(addSub1->getOperand(1));
                    op[2] = llvm::dyn_cast<llvm::Instruction>(I.getOperand(1));

                    // All three terms must be instructions.
                    if (op[0] && op[1] && op[2])
                    {
                        // Try each choice of which term is the lone src2:
                        // casei 0 -> (i, j, k) = (0, 1, 2)
                        // casei 1 -> (i, j, k) = (0, 2, 1)
                        // casei 2 -> (i, j, k) = (1, 2, 0)
                        for (uint casei = 0; casei < 3; casei++)
                        {
                            // i, j, k are the index for op[]
                            uint i = (casei == 2 ? 1 : 0);
                            uint j = (casei == 0 ? 1 : 2);
                            uint k = 2 - casei;

                            //op[i] and op[j] should be fMul, and op[k] is src2
                            if (op[i]->getOpcode() == Instruction::FMul && op[j]->getOpcode() == Instruction::FMul)
                            {
                                for (uint srci = 0; srci < 2; srci++)
                                {
                                    for (uint srcj = 0; srcj < 2; srcj++)
                                    {
                                        // op[i] and op[j] needs to have one common source. this common source will be src0
                                        if (op[i]->getOperand(srci) == op[j]->getOperand(srcj))
                                        {
                                            // one of the non-common source from op[i] and op[j] needs to be the same as op[k], which is src2
                                            if (op[k] == op[i]->getOperand(1 - srci) ||
                                                op[k] == op[j]->getOperand(1 - srcj))
                                            {
                                                // disable if any of the sources is immediate
                                                // (this break only leaves the srcj loop)
                                                if (llvm::isa<llvm::ConstantFP>(op[i]->getOperand(srci)) ||
                                                    llvm::isa<llvm::ConstantFP>(op[i]->getOperand(1 - srci)) ||
                                                    llvm::isa<llvm::ConstantFP>(op[j]->getOperand(srcj)) ||
                                                    llvm::isa<llvm::ConstantFP>(op[j]->getOperand(1 - srcj)) ||
                                                    llvm::isa<llvm::ConstantFP>(op[k]))
                                                {
                                                    break;
                                                }

                                                // check the add/sub cases and add negate to the sources when needed.
                                                /*
                                                ( src0src1, -src0src2, src2 )   okay
                                                ( src0src1, -src0src2, -src2 )  skip
                                                ( src0src1, src0src2, src2 )    skip
                                                ( src0src1, src0src2, -src2 )   negate src2
                                                ( -src0src1, -src0src2, src2 )  negate src1
                                                ( -src0src1, -src0src2, -src2 ) skip
                                                ( -src0src1, src0src2, src2 )   skip
                                                ( -src0src1, src0src2, -src2 )  negate src1 src2
                                                */

                                                // Sign of each term of op[0] +/- op[1] +/- op[2];
                                                // the first term is always positive.
                                                bool SignPositiveOp[3];
                                                SignPositiveOp[0] = true;
                                                SignPositiveOp[1] = (addSub1->getOpcode() == Instruction::FAdd);
                                                SignPositiveOp[2] = (I.getOpcode() == Instruction::FAdd);

                                                // The multiply whose non-common operand equals
                                                // op[k] is src0 * src2; the other is src0 * src1.
                                                uint mulSrc0Src1Index = op[k] == op[i]->getOperand(1 - srci) ? j : i;
                                                uint mulSrc0Src2Index = op[k] == op[i]->getOperand(1 - srci) ? i : j;

                                                // src0 * src2 and src2 must have opposite signs.
                                                if (SignPositiveOp[mulSrc0Src2Index] == SignPositiveOp[k])
                                                {
                                                    // abort the cases marked as "skip" in the comment above
                                                    break;
                                                }

                                                sources[0] = op[i]->getOperand(srci);
                                                sources[1] = op[k] == op[i]->getOperand(1 - srci) ? op[j]->getOperand(1 - srcj) : op[i]->getOperand(1 - srci);
                                                sources[2] = op[k];
                                                GetModifier(*sources[0], src_mod[0], sources[0]);
                                                GetModifier(*sources[1], src_mod[1], sources[1]);
                                                GetModifier(*sources[2], src_mod[2], sources[2]);

                                                if (SignPositiveOp[mulSrc0Src1Index] == false)
                                                {
                                                    src_mod[1] = CombineModifier(EMOD_NEG, src_mod[1]);
                                                }
                                                if (SignPositiveOp[k] == false)
                                                {
                                                    src_mod[2] = CombineModifier(EMOD_NEG, src_mod[2]);
                                                }

                                                // NOTE: no break here - the srcj loop runs to
                                                // completion before the found checks below.
                                                found = true;
                                            }
                                        }
                                    }
                                    if (found)
                                    {
                                        break;
                                    }
                                }
                            }
                            if (found)
                            {
                                break;
                            }
                        }
                    }
                }
            }
        }

        if (found)
        {
            // Build the pattern from the three matched values and their modifiers.
            LRPPattern* pattern = new (m_allocator) LRPPattern();
            for (int i = 0; i < 3; i++)
            {
                pattern->sources[i] = GetSource(sources[i], src_mod[i], false);
            }
            AddPattern(pattern);
        }
        return found;
    }
2761 
MatchCmpSext(llvm::Instruction & I)2762     bool CodeGenPatternMatch::MatchCmpSext(llvm::Instruction& I)
2763     {
2764         /*
2765             %res_s42 = icmp eq i32 %src1_s41, 0
2766             %17 = sext i1 %res_s42 to i32
2767                 to
2768             %res_s42 (i32) = icmp eq i32 %src1_s41, 0
2769 
2770 
2771             %res_s73 = fcmp oge float %res_s61, %42
2772             %46 = sext i1 %res_s73 to i32
2773                 to
2774             %res_s73 (i32) = fcmp oge float %res_s61, %42
2775         */
2776 
2777         struct CmpSextPattern : Pattern
2778         {
2779             llvm::CmpInst* inst;
2780             SSource sources[2];
2781             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
2782             {
2783                 pass->Cmp(inst->getPredicate(), sources, modifier);
2784             }
2785         };
2786         bool match = false;
2787 
2788         if (CmpInst * cmpInst = dyn_cast<CmpInst>(I.getOperand(0)))
2789         {
2790             if (cmpInst->getOperand(0)->getType()->getPrimitiveSizeInBits() == I.getType()->getPrimitiveSizeInBits())
2791             {
2792                 CmpSextPattern* pattern = new (m_allocator) CmpSextPattern();
2793                 bool supportModifer = SupportsModifier(cmpInst);
2794 
2795                 pattern->inst = cmpInst;
2796                 pattern->sources[0] = GetSource(cmpInst->getOperand(0), supportModifer, false);
2797                 pattern->sources[1] = GetSource(cmpInst->getOperand(1), supportModifer, false);
2798                 AddPattern(pattern);
2799                 match = true;
2800             }
2801         }
2802 
2803         return match;
2804     }
2805 
2806     // Match the pattern of 32 x 32 = 64, a full 32-bit multiplication.
MatchFullMul32(llvm::Instruction & I)2807     bool CodeGenPatternMatch::MatchFullMul32(llvm::Instruction& I) {
2808         using namespace llvm::PatternMatch; // Scoped namespace using.
2809 
2810         struct FullMul32Pattern : public Pattern {
2811             SSource srcs[2];
2812             bool isUnsigned;
2813             virtual void Emit(EmitPass* pass, const DstModifier& dstMod)
2814             {
2815                 pass->EmitFullMul32(isUnsigned, srcs, dstMod);
2816             }
2817         };
2818 
2819         IGC_ASSERT_MESSAGE(I.getOpcode() == llvm::Instruction::Mul, "Mul instruction is expected!");
2820 
2821         if (!I.getType()->isIntegerTy(64))
2822             return false;
2823 
2824         llvm::Value* LHS = I.getOperand(0);
2825         llvm::Value* RHS = I.getOperand(1);
2826 
2827         // Swap operand to ensure the constant is always RHS.
2828         if (isa<ConstantInt>(LHS))
2829             std::swap(LHS, RHS);
2830 
2831         bool IsUnsigned = false;
2832         llvm::Value* L;
2833         llvm::Value* R;
2834 
2835         // Check LHS
2836         if (match(LHS, m_SExt(m_Value(L)))) {
2837             // Bail out if there's non 32-bit integer.
2838             if (!L->getType()->isIntegerTy(32))
2839                 return false;
2840         }
2841         else if (match(LHS, m_ZExt(m_Value(L)))) {
2842             // Bail out if there's non 32-bit integer.
2843             if (!L->getType()->isIntegerTy(32))
2844                 return false;
2845             IsUnsigned = true;
2846         }
2847         else {
2848             // Bailout if it's unknown that LHS have less significant bits than the
2849             // product.
2850             // NOTE we don't assertion fail the case where LHS is an constant to prevent
2851             // the assertion in O0 mode. Otherwise, we expect there's at most 1
2852             // constant operand.
2853             return false;
2854         }
2855 
2856         // Check RHS
2857         if (match(RHS, m_SExt(m_Value(R)))) {
2858             // Bail out if there's signedness mismatch or non 32-bit integer.
2859             if (IsUnsigned || !R->getType()->isIntegerTy(32))
2860                 return false;
2861         }
2862         else if (match(RHS, m_ZExt(m_Value(R)))) {
2863             // Bail out if there's signedness mismatch or non 32-bit integer.
2864             if (!IsUnsigned || !R->getType()->isIntegerTy(32))
2865                 return false;
2866             IsUnsigned = true;
2867         }
2868         else if (ConstantInt * CI = dyn_cast<ConstantInt>(RHS)) {
2869             APInt Val = CI->getValue();
2870             // 31-bit unsigned integer could be used as either signed or
2871             // unsigned one. Otherwise, we need special check how MSB is used.
2872             if (!Val.isIntN(31)) {
2873                 if (!(IsUnsigned && Val.isIntN(32)) &&
2874                     !(!IsUnsigned && Val.isSignedIntN(32))) {
2875                     return false;
2876                 }
2877             }
2878             R = ConstantExpr::getTrunc(CI, L->getType());
2879         }
2880         else {
2881             // Bailout if it's unknown that RHS have less significant bits than the
2882             // product.
2883             return false;
2884         }
2885 
2886         FullMul32Pattern* Pat = new (m_allocator) FullMul32Pattern();
2887         Pat->srcs[0] = GetSource(L, !IsUnsigned, false);
2888         Pat->srcs[1] = GetSource(R, !IsUnsigned, false);
2889         Pat->isUnsigned = IsUnsigned;
2890         AddPattern(Pat);
2891 
2892         return true;
2893     }
2894 
    // For 32-bit integer mul/add/sub, use 16-bit operands if possible. Thus,
    // this will match 16x16->32 and 16x32->32, and likewise for add/sub.
2897     //
2898     // For example:
2899     //   1.  before:
2900     //        %9 = ashr i32 %8, 16
2901     //        %10 = mul nsw i32 %9, -1024
2902     //        ( asr (16|M0)  r19.0<1>:d  r19.0<8;8,1>:d  16:w
2903     //          mul (16|M0)  r19.0<1>:d  r19.0<8;8,1>:d  -1024:w )
2904     //
2905     //      after:
    //      --> %10:d = mul %9.1<16;8,2>:w -1024:w
2907     //          (  mul (16|M0)  r23.0<1>:d   r19.1<2;1,0>:w   -1024:w )
2908     //
2909     //  2. before:
2910     //        %9  = lshr i32 %8, 16
2911     //        %10 = and i32 %8, 65535
2912     //        %11 = mul nuw i32 %9, %10
2913     //        ( shr  (16|M0)   r14.0<1>:d  r12.0<8;8,1>:ud   16:w
2914     //          and(16 | M0)   r12.0<1>:d  r12.0<8;8,1>:d  65535:d
2915     //          mul(16 | M0)   r14.0<1>:d  r14.0<8;8,1>:d  r12.0<8;8,1>:d )
2916     //
2917     //     after:
2918     //     --> %11:d = mul %8.1<16;8,2>:uw %8.0<16;8,2>:uw
2919     //         ( mul (16|M0)  r14.0<1>:d   r12.1<2;1,0>:w   r12.0<2;1,0>:w )
2920     //
MatchMulAdd16(Instruction & I)2921     bool CodeGenPatternMatch::MatchMulAdd16(Instruction& I) {
2922         using namespace llvm::PatternMatch;
2923 
2924         struct Oprd16Pattern : public Pattern {
2925             SSource srcs[2];
2926             Instruction* rootInst;
2927             virtual void Emit(EmitPass* pass, const DstModifier& dstMod)
2928             {
2929                 pass->emitMulAdd16(rootInst, srcs, dstMod);
2930             }
2931         };
2932 
2933         // The code is under the control of registry key EnableMixIntOperands.
2934         if (IGC_IS_FLAG_DISABLED(EnableMixIntOperands))
2935         {
2936             return false;
2937         }
2938 
2939         unsigned opc = I.getOpcode();
2940         IGC_ASSERT_MESSAGE((opc == Instruction::Mul) || (opc == Instruction::Add) || (opc == Instruction::Sub), "Mul instruction is expected!");
2941 
2942         // Handle 32 bit integer mul/add/sub only.
2943         if (!I.getType()->isIntegerTy(32))
2944         {
2945             return false;
2946         }
2947 
2948         // Try to replace any source operands with ones of type short if any. As vISA
2949         // allows the mix of any integer type, each operand is considered separately.
2950         struct {
2951             Value* src;
2952             bool useLower;
2953             bool isSigned;
2954         } oprdInfo[2];
2955         bool isCandidate = false;
2956 
2957         for (int i = 0; i < 2; ++i)
2958         {
2959             Value* oprd = I.getOperand(i);
2960             Value* L = nullptr;
2961 
2962             // oprdInfo[i].src == null --> no W operand replacement.
2963             oprdInfo[i].src = nullptr;
2964             if (ConstantInt * CI = dyn_cast<ConstantInt>(oprd))
2965             {
2966                 int64_t val =
2967                     CI->isNegative() ? CI->getSExtValue() : CI->getZExtValue();
2968                 // If src needs to be negated (y = x - a = x + (-a), as gen only
2969                 // has add), need to check if the negated src fits into W/UW.
2970                 bool isNegSrc = (opc == Instruction::Sub && i == 1);
2971                 if (isNegSrc)
2972                 {
2973                     val = -val;
2974                 }
2975                 if (INT16_MIN <= val && val <= INT16_MAX)
2976                 {
2977                     oprdInfo[i].src = oprd;
2978                     oprdInfo[i].useLower = true; // does not matter for const
2979                     oprdInfo[i].isSigned = true;
2980                     isCandidate = true;
2981                 }
2982                 else if (0 <= val && val <= UINT16_MAX)
2983                 {
2984                     oprdInfo[i].src = oprd;
2985                     oprdInfo[i].useLower = true; // does not matter for const
2986                     oprdInfo[i].isSigned = false;
2987                     isCandidate = true;
2988                 }
2989             }
2990             else if (match(oprd, m_And(m_Value(L), m_SpecificInt(0xFFFF))))
2991             {
2992                 oprdInfo[i].src = L;
2993                 oprdInfo[i].useLower = true;
2994                 oprdInfo[i].isSigned = false;
2995                 isCandidate = true;
2996             }
2997             else if (match(oprd, m_LShr(m_Value(L), m_SpecificInt(16))))
2998             {
2999                 oprdInfo[i].src = L;
3000                 oprdInfo[i].useLower = false;
3001                 oprdInfo[i].isSigned = false;
3002                 isCandidate = true;
3003             }
3004             else if (match(oprd, m_AShr(m_Shl(m_Value(L), m_SpecificInt(16)),
3005                 m_SpecificInt(16))))
3006             {
3007                 oprdInfo[i].src = L;
3008                 oprdInfo[i].useLower = true;
3009                 oprdInfo[i].isSigned = true;
3010                 isCandidate = true;
3011             }
3012             else if (match(oprd, m_AShr(m_Value(L), m_SpecificInt(16))))
3013             {
3014                 oprdInfo[i].src = L;
3015                 oprdInfo[i].useLower = false;
3016                 oprdInfo[i].isSigned = true;
3017                 isCandidate = true;
3018             }
3019         }
3020 
3021         if (!isCandidate) {
3022             return false;
3023         }
3024 
3025         Oprd16Pattern* Pat = new (m_allocator)Oprd16Pattern();
3026         for (int i = 0; i < 2; ++i)
3027         {
3028             if (oprdInfo[i].src)
3029             {
3030                 Pat->srcs[i] = GetSource(oprdInfo[i].src, false, false);
3031                 SSource& thisSrc = Pat->srcs[i];
3032 
3033                 // for now, Use W/UW only if region_set is false or the src is scalar
3034                 if (thisSrc.region_set &&
3035                     !(thisSrc.region[0] == 0 && thisSrc.region[1] == 1 && thisSrc.region[2] == 0))
3036                 {
3037                     Pat->srcs[i] = GetSource(I.getOperand(i), true, false);
3038                 }
3039                 else
3040                 {
3041                     // Note that SSource's type, if set by GetSource(), should be 32bit type. It's
3042                     // safe to override it with either UW or W. But for SSource's offset, need to
3043                     // re-calculate in term of 16bit, not 32bit.
3044                     thisSrc.type = oprdInfo[i].isSigned ? ISA_TYPE_W : ISA_TYPE_UW;
3045                     thisSrc.elementOffset = (2 * thisSrc.elementOffset) + (oprdInfo[i].useLower ? 0 : 1);
3046                 }
3047             }
3048             else
3049             {
3050                 Pat->srcs[i] = GetSource(I.getOperand(i), true, false);
3051             }
3052         }
3053         Pat->rootInst = &I;
3054         AddPattern(Pat);
3055 
3056         return true;
3057     }
3058 
3059 
BitcastSearch(SSource & source,llvm::Value * & value,bool broadcast)3060     bool CodeGenPatternMatch::BitcastSearch(SSource& source, llvm::Value*& value, bool broadcast)
3061     {
3062         if (auto elemInst = dyn_cast<ExtractElementInst>(value))
3063         {
3064             if (auto bTInst = dyn_cast<BitCastInst>(elemInst->getOperand(0)))
3065             {
3066                 // Pattern Matching (Instruction) + ExtractElem + (Vector)Bitcast
3067                 //
3068                 // In order to set the regioning for the ALU operand
3069                 // I require three things:
3070                 //      -The first is the source number of elements
3071                 //      -The second is the destination number of elements
3072                 //      -The third is the index from the extract element
3073                 //
3074                 // For example if I have <4 x i32> to <16 x i8> all I need is
3075                 // the 4 (vstride) and the i8 (b) in this case the operand would look
3076                 // like this -> r22.x <4;1,0>:b
3077                 // x is calculated below and later on using the simdsize
3078 
3079                 uint32_t index, srcNElts, dstNElts, nEltsRatio;
3080                 llvm::Type* srcTy = bTInst->getOperand(0)->getType();
3081                 llvm::Type* dstTy = bTInst->getType();
3082 
3083                 srcNElts = (srcTy->isVectorTy()) ? (uint32_t)cast<IGCLLVM::FixedVectorType>(srcTy)->getNumElements() : 1;
3084                 dstNElts = (dstTy->isVectorTy()) ? (uint32_t)cast<IGCLLVM::FixedVectorType>(dstTy)->getNumElements() : 1;
3085 
3086                 if (srcNElts < dstNElts && srcTy->getScalarSizeInBits() < 64)
3087                 {
3088                     if (isa<ConstantInt>(elemInst->getIndexOperand()))
3089                     {
3090                         index = int_cast<uint>(cast<ConstantInt>(elemInst->getIndexOperand())->getZExtValue());
3091                         nEltsRatio = dstNElts / srcNElts;
3092                         source.value = bTInst->getOperand(0);
3093                         source.SIMDOffset = iSTD::RoundDownNonPow2(index, nEltsRatio);
3094                         source.elementOffset = source.elementOffset * nEltsRatio + index % nEltsRatio;
3095                         value = source.value;
3096                         if (!broadcast)
3097                         {
3098                             source.region_set = true;
3099                             if (m_WI->isUniform(value))
3100                                 source.region[0] = 0;
3101                             else
3102                                 source.region[0] = (unsigned char)nEltsRatio;
3103                             source.region[1] = 1;
3104                             source.region[2] = 0;
3105                         }
3106                         return true;
3107                     }
3108                 }
3109             }
3110         }
3111         return false;
3112     }
3113 
MatchModifier(llvm::Instruction & I,bool SupportSrc0Mod)3114     bool CodeGenPatternMatch::MatchModifier(llvm::Instruction& I, bool SupportSrc0Mod)
3115     {
3116         struct ModifierPattern : public Pattern
3117         {
3118             SSource sources[2];
3119             llvm::Instruction* instruction;
3120             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3121             {
3122                 pass->BinaryUnary(instruction, sources, modifier);
3123             }
3124         };
3125 
3126         ModifierPattern* pattern = new (m_allocator)ModifierPattern();
3127         pattern->instruction = &I;
3128         uint nbSources = GetNbSources(I);
3129 
3130         bool supportModiferSrc0 = SupportsModifier(&I);
3131         bool supportRegioning = SupportsRegioning(&I);
3132         llvm::Instruction* src0Inst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
3133         if (I.getOpcode() == llvm::Instruction::UDiv && src0Inst && src0Inst->getOpcode() == llvm::Instruction::Sub) {
3134             supportModiferSrc0 = false;
3135         }
3136         pattern->sources[0] = GetSource(I.getOperand(0), supportModiferSrc0 && SupportSrc0Mod, supportRegioning);
3137         if (nbSources > 1)
3138         {
3139             bool supportModiferSrc1 = SupportsModifier(&I);
3140             llvm::Instruction* src1Inst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(1));
3141             if (I.getOpcode() == llvm::Instruction::UDiv && src1Inst && src1Inst->getOpcode() == llvm::Instruction::Sub) {
3142                 supportModiferSrc1 = false;
3143             }
3144             pattern->sources[1] = GetSource(I.getOperand(1), supportModiferSrc1, supportRegioning);
3145 
3146             // add df imm to constant pool for binary/ternary inst
3147             // we do 64-bit int imm bigger than 32 bits, since smaller may fit in D/W
3148             for (int i = 0, numSrc = (int)nbSources; i < numSrc; ++i)
3149             {
3150                 Value* op = I.getOperand(i);
3151                 if (isCandidateForConstantPool(op))
3152                 {
3153                     AddToConstantPool(I.getParent(), op);
3154                     pattern->sources[i].fromConstantPool = true;
3155                 }
3156             }
3157         }
3158 
3159         AddPattern(pattern);
3160 
3161         return true;
3162     }
3163 
MatchSingleInstruction(llvm::Instruction & I)3164     bool CodeGenPatternMatch::MatchSingleInstruction(llvm::Instruction& I)
3165     {
3166         struct SingleInstPattern : Pattern
3167         {
3168             llvm::Instruction* inst;
3169             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3170             {
3171                 IGC_ASSERT(modifier.sat == false);
3172                 IGC_ASSERT(modifier.flag == nullptr);
3173                 pass->EmitNoModifier(inst);
3174             }
3175         };
3176         SingleInstPattern* pattern = new (m_allocator) SingleInstPattern();
3177         pattern->inst = &I;
3178         uint numSources = GetNbSources(I);
3179         for (uint i = 0; i < numSources; i++)
3180         {
3181             MarkAsSource(I.getOperand(i));
3182         }
3183 
3184         if (CallInst * callinst = dyn_cast<CallInst>(&I))
3185         {
3186             // Mark the function pointer in indirect calls as a source
3187             if (!callinst->getCalledFunction())
3188             {
3189                 MarkAsSource(IGCLLVM::getCalledValue(callinst));
3190             }
3191         }
3192         AddPattern(pattern);
3193         return true;
3194     }
3195 
MatchCanonicalizeInstruction(llvm::Instruction & I)3196     bool CodeGenPatternMatch::MatchCanonicalizeInstruction(llvm::Instruction& I)
3197     {
3198         struct CanonicalizeInstPattern : Pattern
3199         {
3200             CanonicalizeInstPattern(llvm::Instruction* pInst) : m_pInst(pInst) {}
3201 
3202             llvm::Instruction* m_pInst;
3203             Pattern* m_pPattern = nullptr;
3204 
3205             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3206             {
3207                 if (m_pPattern)
3208                 {
3209                     m_pPattern->Emit(pass, modifier);
3210                 }
3211                 else
3212                 {
3213                     pass->emitCanonicalize(m_pInst, modifier);
3214                 }
3215             }
3216         };
3217 
3218 
3219         // FAdd, FSub, FMul, FDiv instructions flush subnormals to zero.
3220         // However, mix mode and math instructions preserve subnormals.
3221         // Other instructions also preserve subnormals.
3222         // FSat intrinsic instruction can be emitted i.e. as FAdd so such an
3223         // instruction should be inspected recursively.
3224         std::function<bool(llvm::Value*)> DetermineIfMixMode;
3225         DetermineIfMixMode = [&DetermineIfMixMode, this](llvm::Value* operand) -> bool
3226         {
3227             bool isMixModePossible = false;
3228             if (m_Platform.supportMixMode())
3229             {
3230                 if (llvm::BinaryOperator* pBianaryOperator = llvm::dyn_cast<llvm::BinaryOperator>(operand))
3231                 {
3232                     // the switch instruction is executed to break the recursion if it is unneeded.
3233                     // The cause for this recursion is a possibility of constructing mad instructions.
3234                     switch (pBianaryOperator->getOpcode())
3235                     {
3236                     case llvm::BinaryOperator::BinaryOps::FAdd:
3237                     case llvm::BinaryOperator::BinaryOps::FMul:
3238                     case llvm::BinaryOperator::BinaryOps::FSub:
3239                         isMixModePossible = pBianaryOperator->getType()->isDoubleTy() == false &&
3240                             (DetermineIfMixMode(pBianaryOperator->getOperand(0)) || DetermineIfMixMode(pBianaryOperator->getOperand(1)));
3241                         break;
3242                     default:
3243                         break;
3244                     }
3245                 }
3246                 else if (isa<FPTruncInst>(operand))
3247                 {
3248                     FPTruncInst* fptruncInst = llvm::cast<FPTruncInst>(operand);
3249                     isMixModePossible = fptruncInst->getSrcTy()->isDoubleTy() == false;
3250                 }
3251                 else if (isa<FPExtInst>(operand))
3252                 {
3253                     FPExtInst* fpextInst = llvm::cast<FPExtInst>(operand);
3254                     isMixModePossible = fpextInst->getDestTy()->isDoubleTy() == false;
3255                 }
3256             }
3257             return isMixModePossible;
3258         };
3259 
3260         std::function<bool(llvm::Value*)> DetermineIfNeeded;
3261         DetermineIfNeeded = [&DetermineIfNeeded, &DetermineIfMixMode](llvm::Value* operand) -> bool
3262         {
3263             bool isNeeded = true;
3264             if (llvm::BinaryOperator* pBianaryOperator = llvm::dyn_cast<llvm::BinaryOperator>(operand))
3265             {
3266                 // the switch instruction is to consider only the operations
3267                 // which support flushing denorms to zero.
3268                 switch (pBianaryOperator->getOpcode())
3269                 {
3270                 case llvm::BinaryOperator::BinaryOps::FAdd:
3271                 case llvm::BinaryOperator::BinaryOps::FMul:
3272                 case llvm::BinaryOperator::BinaryOps::FSub:
3273                 case llvm::BinaryOperator::BinaryOps::FDiv:
3274                     isNeeded = DetermineIfMixMode(pBianaryOperator);
3275                     break;
3276                 default:
3277                     break;
3278                 }
3279             }
3280             else if(GenIntrinsicInst* intrin = dyn_cast<GenIntrinsicInst>(operand))
3281             {
3282                 switch (intrin->getIntrinsicID())
3283                 {
3284                 case GenISAIntrinsic::GenISA_fsat:
3285                     isNeeded = DetermineIfNeeded(intrin->getOperand(0));
3286                     break;
3287                 default:
3288                     break;
3289                 }
3290             }
3291             else if (IntrinsicInst* intrin = dyn_cast<IntrinsicInst>(operand))
3292             {
3293                 switch (intrin->getIntrinsicID())
3294                 {
3295                 case Intrinsic::canonicalize:
3296                     isNeeded = DetermineIfNeeded(intrin->getOperand(0));
3297                     break;
3298                 default:
3299                     break;
3300                 }
3301             }
3302             return isNeeded;
3303         };
3304 
3305         CanonicalizeInstPattern* pattern = new (m_allocator) CanonicalizeInstPattern(&I);
3306         if (DetermineIfNeeded(I.getOperand(0)))
3307         {
3308             MarkAsSource(I.getOperand(0));
3309         }
3310         else
3311         {
3312             pattern->m_pPattern = Match(*llvm::cast<llvm::Instruction>(I.getOperand(0)));
3313         }
3314 
3315         AddPattern(pattern);
3316         return true;
3317     }
3318 
MatchBranch(llvm::BranchInst & I)3319     bool CodeGenPatternMatch::MatchBranch(llvm::BranchInst& I)
3320     {
3321         struct CondBrInstPattern : Pattern
3322         {
3323             SSource cond;
3324             llvm::BranchInst* inst;
3325             e_predMode predMode = EPRED_NORMAL;
3326             bool isDiscardBranch = false;
3327             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3328             {
3329                 if (isDiscardBranch)
3330                 {
3331                     pass->emitDiscardBranch(inst, cond);
3332                 }
3333                 else
3334                 {
3335                     pass->emitBranch(inst, cond, predMode);
3336                 }
3337             }
3338         };
3339         CondBrInstPattern* pattern = new (m_allocator) CondBrInstPattern();
3340         pattern->inst = &I;
3341 
3342         if (!I.isUnconditional())
3343         {
3344             Value* cond = I.getCondition();
3345             if (dyn_cast<GenIntrinsicInst>(cond,
3346                 GenISAIntrinsic::GenISA_UpdateDiscardMask))
3347             {
3348                 pattern->isDiscardBranch = true;
3349             }
3350             pattern->cond = GetSource(I.getCondition(), false, false);
3351         }
3352         AddPattern(pattern);
3353         return true;
3354     }
3355 
MatchFloatingPointSatModifier(llvm::Instruction & I)3356     bool CodeGenPatternMatch::MatchFloatingPointSatModifier(llvm::Instruction& I)
3357     {
3358         struct SatPattern : Pattern
3359         {
3360             Pattern* pattern;
3361             SSource source;
3362             bool isUnsigned;
3363             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3364             {
3365                 DstModifier mod = modifier;
3366                 mod.sat = true;
3367                 if (pattern)
3368                 {
3369                     pattern->Emit(pass, mod);
3370                 }
3371                 else
3372                 {
3373                     pass->Mov(source, mod);
3374                 }
3375             }
3376         };
3377         bool match = false;
3378         llvm::Value* source = nullptr;
3379         bool isUnsigned = false;
3380         if (isSat(&I, source, isUnsigned))
3381         {
3382             SatPattern* satPattern = new (m_allocator) SatPattern();
3383             if (llvm::Instruction * inst = llvm::dyn_cast<Instruction>(source))
3384             {
3385                 // As an heuristic we only match saturate if the instruction has one use
3386                 // to avoid duplicating expensive instructions and increasing reg pressure
3387                 // without improve code quality this may be refined in the future
3388                 if (inst->hasOneUse() && SupportsSaturate(inst))
3389                 {
3390                     auto *pattern = Match(*inst);
3391                     IGC_ASSERT_MESSAGE(pattern, "Failed to match pattern");
3392                     // Even though the original `inst` may support saturate,
3393                     // we need to know if the instruction(s) generated from
3394                     // the pattern support it.
3395                     if (pattern->supportsSaturate())
3396                     {
3397                         satPattern->pattern = pattern;
3398                         match = true;
3399                     }
3400                 }
3401             }
3402             if (!match)
3403             {
3404                 satPattern->pattern = nullptr;
3405                 satPattern->source = GetSource(source, true, false);
3406                 match = true;
3407             }
3408             if (isUniform(&I) && source->hasOneUse())
3409             {
3410                 gatherUniformBools(source);
3411             }
3412             AddPattern(satPattern);
3413         }
3414         return match;
3415     }
3416 
MatchIntegerSatModifier(llvm::Instruction & I)3417     bool CodeGenPatternMatch::MatchIntegerSatModifier(llvm::Instruction& I)
3418     {
3419         // a default pattern
3420         struct SatPattern : Pattern
3421         {
3422             Pattern* pattern;
3423             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3424             {
3425                 DstModifier mod = modifier;
3426                 mod.sat = true;
3427                 pattern->Emit(pass, mod);
3428             }
3429         };
3430 
3431         // a special pattern is required because of the fact that the instruction works on unsigned values
3432         // whereas the default type is signed for arithmetic instructions
3433         struct UAddPattern : Pattern
3434         {
3435             BinaryOperator* inst;
3436 
3437             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3438             {
3439                 DstModifier mod = modifier;
3440                 mod.sat = true;
3441                 pass->EmitUAdd(inst, mod);
3442             }
3443         };
3444 
3445         struct IntegerSatTruncPattern : public Pattern {
3446             SSource src;
3447             bool isSigned;
3448             virtual void Emit(EmitPass* pass, const DstModifier& dstMod)
3449             {
3450                 pass->EmitIntegerTruncWithSat(isSigned, isSigned, src, dstMod);
3451             }
3452         };
3453 
3454         // dp4a with modifiers
3455         struct Dp4aSatPattern : Pattern
3456         {
3457             GenIntrinsicInst* inst;
3458             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3459             {
3460                 DstModifier mod = modifier;
3461                 mod.sat = true;
3462                 pass->emitDP4A(inst, nullptr, mod);
3463             }
3464         };
3465 
3466 
3467         bool match = false;
3468         llvm::Value* source = nullptr;
3469         bool isUnsigned = false;
3470         if (isSat(&I, source, isUnsigned))
3471         {
3472             IGC_ASSERT(llvm::isa<Instruction>(source));
3473 
3474             // As an heuristic we only match saturate if the instruction has one use
3475             // to avoid duplicating expensive instructions and increasing reg pressure
3476             // without improve code quality this may be refined in the future.
3477             if (llvm::Instruction* sourceInst = llvm::cast<llvm::Instruction>(source);
3478                 sourceInst && sourceInst->hasOneUse() && SupportsSaturate(sourceInst))
3479             {
3480                 if (llvm::BinaryOperator* binaryOpInst = llvm::dyn_cast<llvm::BinaryOperator>(source);
3481                     binaryOpInst && (binaryOpInst->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) && isUnsigned)
3482                 {
3483                     match = true;
3484 
3485                     uint numSources = GetNbSources(*sourceInst);
3486                     for (uint i = 0; i < numSources; i++)
3487                     {
3488                         MarkAsSource(sourceInst->getOperand(i));
3489                     }
3490 
3491                     UAddPattern* uAddPattern = new (m_allocator) UAddPattern();
3492                     uAddPattern->inst = binaryOpInst;
3493                     AddPattern(uAddPattern);
3494                 }
3495                 else if (binaryOpInst && (binaryOpInst->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) && !isUnsigned)
3496                 {
3497                     match = true;
3498                     SatPattern* satPattern = new (m_allocator) SatPattern();
3499                     satPattern->pattern = Match(*sourceInst);
3500                     AddPattern(satPattern);
3501                 }
3502                 else if (llvm::TruncInst* truncInst = llvm::dyn_cast<llvm::TruncInst>(source);
3503                     truncInst)
3504                 {
3505                     match = true;
3506                     IntegerSatTruncPattern* satPattern = new (m_allocator) IntegerSatTruncPattern();
3507                     satPattern->isSigned = !isUnsigned;
3508                     satPattern->src = GetSource(truncInst->getOperand(0), !isUnsigned, false);
3509                     AddPattern(satPattern);
3510                 }
3511                 else if (llvm::GenIntrinsicInst * genIsaInst = llvm::dyn_cast<llvm::GenIntrinsicInst>(source);
3512                     genIsaInst &&
3513                     (genIsaInst->getIntrinsicID() == llvm::GenISAIntrinsic::ID::GenISA_dp4a_ss ||
3514                     genIsaInst->getIntrinsicID() == llvm::GenISAIntrinsic::ID::GenISA_dp4a_su ||
3515                     genIsaInst->getIntrinsicID() == llvm::GenISAIntrinsic::ID::GenISA_dp4a_uu ||
3516                     genIsaInst->getIntrinsicID() == llvm::GenISAIntrinsic::ID::GenISA_dp4a_us))
3517                 {
3518                     match = true;
3519 
3520                     uint numSources = GetNbSources(*sourceInst);
3521                     for (uint i = 0; i < numSources; i++)
3522                     {
3523                         MarkAsSource(sourceInst->getOperand(i));
3524                     }
3525 
3526                     Dp4aSatPattern* dp4aSatPattern = new (m_allocator) Dp4aSatPattern();
3527                     dp4aSatPattern->inst = genIsaInst;
3528                     AddPattern(dp4aSatPattern);
3529                 }
3530                 else
3531                 {
3532                     IGC_ASSERT_MESSAGE(0, "An undefined pattern match");
3533                 }
3534             }
3535         }
3536         return match;
3537     }
3538 
MatchPredicate(llvm::SelectInst & I)3539     bool CodeGenPatternMatch::MatchPredicate(llvm::SelectInst& I)
3540     {
3541         struct PredicatePattern : Pattern
3542         {
3543             bool invertFlag;
3544             Pattern* patternNotPredicated;
3545             Pattern* patternPredicated;
3546             SSource flag;
3547             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3548             {
3549                 DstModifier mod = modifier;
3550                 patternNotPredicated->Emit(pass, mod);
3551                 mod.flag = &flag;
3552                 mod.invertFlag = invertFlag;
3553                 patternPredicated->Emit(pass, mod);
3554             }
3555         };
3556         bool match = false;
3557         bool invertFlag = false;
3558         llvm::Instruction* source0 = llvm::dyn_cast<llvm::Instruction>(I.getTrueValue());
3559         llvm::Instruction* source1 = llvm::dyn_cast<llvm::Instruction>(I.getFalseValue());
3560         if (source0 && source0->hasOneUse() && source1 && source1->hasOneUse())
3561         {
3562             if (SupportsPredicate(source0))
3563             {
3564                 // temp fix until we find the best solution for this case
3565                 if (!isa<ExtractElementInst>(source1))
3566                 {
3567                     match = true;
3568                 }
3569             }
3570             else if (SupportsPredicate(source1))
3571             {
3572                 if (!isa<ExtractElementInst>(source0))
3573                 {
3574                     std::swap(source0, source1);
3575                     invertFlag = true;
3576                     match = true;
3577                 }
3578             }
3579         }
3580         if (match == true)
3581         {
3582             PredicatePattern* pattern = new (m_allocator) PredicatePattern();
3583             pattern->flag = GetSource(I.getCondition(), false, false);
3584             pattern->invertFlag = invertFlag;
3585             pattern->patternNotPredicated = Match(*source1);
3586             pattern->patternPredicated = Match(*source0);
3587             IGC_ASSERT_MESSAGE(pattern->patternNotPredicated, "Failed to match pattern");
3588             IGC_ASSERT_MESSAGE(pattern->patternPredicated, "Failed to match pattern");
3589             AddPattern(pattern);
3590         }
3591         return match;
3592     }
3593 
MatchSelectModifier(llvm::SelectInst & I)3594     bool CodeGenPatternMatch::MatchSelectModifier(llvm::SelectInst& I)
3595     {
3596         struct SelectPattern : Pattern
3597         {
3598             SSource sources[3];
3599             e_predMode predMode;
3600             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3601             {
3602                 DstModifier modf = modifier;
3603                 modf.predMode = predMode;
3604                 pass->Select(sources, modf);
3605             }
3606         };
3607         SelectPattern* pattern = new (m_allocator) SelectPattern();
3608         pattern->predMode = EPRED_NORMAL;
3609 
3610         pattern->sources[0] = GetSource(I.getCondition(), false, false);
3611         pattern->sources[1] = GetSource(I.getTrueValue(), true, false);
3612         pattern->sources[2] = GetSource(I.getFalseValue(), true, false);
3613 
3614         // try to add to constant pool whatever possible.
3615         if (isCandidateForConstantPool(I.getTrueValue()))
3616         {
3617             AddToConstantPool(I.getParent(), I.getTrueValue());
3618             pattern->sources[1].fromConstantPool = true;
3619         }
3620 
3621         if (isCandidateForConstantPool(I.getFalseValue()))
3622         {
3623             AddToConstantPool(I.getParent(), I.getFalseValue());
3624             pattern->sources[2].fromConstantPool = true;
3625         }
3626 
3627         AddPattern(pattern);
3628         return true;
3629     }
3630 
IsPositiveFloat(Value * v,unsigned int depth=0)3631     static bool IsPositiveFloat(Value* v, unsigned int depth = 0)
3632     {
3633         if (depth > 3)
3634         {
3635             // limit the depth of recursion to avoid compile time problem
3636             return false;
3637         }
3638         if (ConstantFP * cst = dyn_cast<ConstantFP>(v))
3639         {
3640             if (!cst->getValueAPF().isNegative())
3641             {
3642                 return true;
3643             }
3644         }
3645         else if (Instruction * I = dyn_cast<Instruction>(v))
3646         {
3647             switch (I->getOpcode())
3648             {
3649             case Instruction::FMul:
3650             case Instruction::FAdd:
3651                 return IsPositiveFloat(I->getOperand(0), depth + 1) && IsPositiveFloat(I->getOperand(1), depth + 1);
3652             case Instruction::Call:
3653                 if (IntrinsicInst * intrinsicInst = dyn_cast<IntrinsicInst>(I))
3654                 {
3655                     if (intrinsicInst->getIntrinsicID() == Intrinsic::fabs)
3656                     {
3657                         return true;
3658                     }
3659                 }
3660                 else if (isa<GenIntrinsicInst>(I, GenISAIntrinsic::GenISA_fsat))
3661                 {
3662                     return true;
3663                 }
3664                 break;
3665             default:
3666                 break;
3667             }
3668         }
3669         return false;
3670     }
3671 
MatchPow(llvm::IntrinsicInst & I)3672     bool CodeGenPatternMatch::MatchPow(llvm::IntrinsicInst& I)
3673     {
3674         if (IGC_IS_FLAG_ENABLED(DisableMatchPow))
3675         {
3676             return false;
3677         }
3678 
3679         // For this pattern match exp(log(x) * y) = pow
3680         // if x < 0 and y is an integer (ex: 1.0)
3681         // with pattern match : pow(x, 1.0) = x
3682         // without pattern match : exp(log(x) * 1.0) = NaN because log(x) is NaN.
3683         //
3684         // Since pow is 2x slower than exp/log, disabling this optimization might not hurt much.
3685         // Keep the code and disable MatchPow to track any performance change for now.
3686         struct PowPattern : public Pattern
3687         {
3688             SSource sources[2];
3689             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3690             {
3691                 pass->Pow(sources, modifier);
3692             }
3693         };
3694         bool found = false;
3695         llvm::Value* source0 = NULL;
3696         llvm::Value* source1 = NULL;
3697         if (I.getIntrinsicID() == Intrinsic::exp2)
3698         {
3699             llvm::BinaryOperator* mul = dyn_cast<BinaryOperator>(I.getOperand((0)));
3700             if (mul && mul->getOpcode() == Instruction::FMul)
3701             {
3702                 for (uint j = 0; j < 2; j++)
3703                 {
3704                     llvm::IntrinsicInst* log = dyn_cast<IntrinsicInst>(mul->getOperand(j));
3705                     if (log && log->getIntrinsicID() == Intrinsic::log2)
3706                     {
3707                         if (IsPositiveFloat(log->getOperand(0)))
3708                         {
3709                             source0 = log->getOperand(0);
3710                             source1 = mul->getOperand(1 - j);
3711                             found = true;
3712                             break;
3713                         }
3714                     }
3715                 }
3716             }
3717         }
3718         if (found)
3719         {
3720             PowPattern* pattern = new (m_allocator) PowPattern();
3721             pattern->sources[0] = GetSource(source0, true, false);
3722             pattern->sources[1] = GetSource(source1, true, false);
3723             AddPattern(pattern);
3724         }
3725         return found;
3726     }
3727 
3728     // We match this pattern
3729     // %1 = add %2 %3
3730     // %b = %cmp %1 0
    // right now we don't match if the alu has more than 1 use as it could generate worse code
MatchCondModifier(llvm::CmpInst & I)3732     bool CodeGenPatternMatch::MatchCondModifier(llvm::CmpInst& I)
3733     {
3734         struct CondModifierPattern : Pattern
3735         {
3736             Pattern* pattern;
3737             Instruction* alu;
3738             CmpInst* cmp;
3739             int aluOprdNum = 0;  // 0: alu is src0; otherwise, alu is src1
3740             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3741             {
3742                 IGC_ASSERT(modifier.flag == nullptr);
3743                 IGC_ASSERT(modifier.sat == false);
3744                 pass->emitAluConditionMod(pattern, alu, cmp, aluOprdNum);
3745             }
3746         };
3747         bool found = false;
3748         for (uint i = 0; i < 2; i++)
3749         {
3750             if (IsZero(I.getOperand(i)))
3751             {
3752                 llvm::Instruction* alu = dyn_cast<Instruction>(I.getOperand(1 - i));
3753                 if (alu && alu->hasOneUse() && SupportsCondModifier(alu))
3754                 {
3755                     CondModifierPattern* pattern = new (m_allocator) CondModifierPattern();
3756                     pattern->pattern = Match(*alu);
3757                     IGC_ASSERT_MESSAGE(pattern->pattern, "Failed to match pattern");
3758                     pattern->alu = alu;
3759                     pattern->cmp = &I;
3760                     pattern->aluOprdNum = 1 - i;
3761                     AddPattern(pattern);
3762                     found = true;
3763                     break;
3764                 }
3765             }
3766         }
3767         return found;
3768     }
3769 
3770     // we match the following pattern
3771     // %f = cmp %1 %2
3772     // %o = or/and %f %g
MatchBoolOp(llvm::BinaryOperator & I)3773     bool CodeGenPatternMatch::MatchBoolOp(llvm::BinaryOperator& I)
3774     {
3775         struct BoolOpPattern : public Pattern
3776         {
3777             llvm::BinaryOperator* boolOp;
3778             llvm::CmpInst::Predicate predicate;
3779             SSource cmpSource[2];
3780             SSource binarySource;
3781             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3782             {
3783                 pass->CmpBoolOp(boolOp, predicate, cmpSource, binarySource, modifier);
3784             }
3785         };
3786 
3787         IGC_ASSERT(I.getOpcode() == Instruction::Or || I.getOpcode() == Instruction::And);
3788         bool found = false;
3789         if (I.getType()->isIntegerTy(1))
3790         {
3791             for (uint i = 0; i < 2; i++)
3792             {
3793                 if (CmpInst * cmp = llvm::dyn_cast<CmpInst>(I.getOperand(i)))
3794                 {
3795                     // only beneficial if the other operand only have one use
3796                     if (I.getOperand(1 - i)->hasOneUse())
3797                     {
3798                         BoolOpPattern* pattern = new (m_allocator) BoolOpPattern();
3799                         bool supportsMod = SupportsModifier(cmp);
3800                         pattern->boolOp = &I;
3801                         pattern->predicate = cmp->getPredicate();
3802                         pattern->cmpSource[0] = GetSource(cmp->getOperand(0), supportsMod, false);
3803                         pattern->cmpSource[1] = GetSource(cmp->getOperand(1), supportsMod, false);
3804                         pattern->binarySource = GetSource(I.getOperand(1 - i), false, false);
3805                         AddPattern(pattern);
3806                         found = true;
3807                         break;
3808                     }
3809                 }
3810             }
3811         }
3812         return found;
3813     }
3814 
MatchFunnelShiftRotate(llvm::IntrinsicInst & I)3815     bool CodeGenPatternMatch::MatchFunnelShiftRotate(llvm::IntrinsicInst& I)
3816     {
3817         // Hanlde only funnel shift that can be turned into rotate.
3818         struct funnelShiftRotatePattern : public Pattern
3819         {
3820             SSource sources[2];
3821             llvm::IntrinsicInst* instruction;
3822             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3823             {
3824                 bool isShl = instruction->getIntrinsicID() == Intrinsic::fshl;
3825                 pass->Binary(isShl ? EOPCODE_ROL : EOPCODE_ROR, sources, modifier);
3826             }
3827         };
3828 
3829         if (!m_Platform.supportRotateInstruction() || I.getType()->isVectorTy())
3830         {
3831             return false;
3832         }
3833 
3834         Value* A = I.getOperand(0);
3835         Value* B = I.getOperand(1);
3836         Value* Amt = I.getOperand(2);
3837         uint32_t typebits = I.getType()->getScalarSizeInBits();
3838         if (A != B ||
3839             (typebits != 16 && typebits != 32 && typebits != 64))
3840         {
3841             return false;
3842         }
3843 
3844         // Found the pattern.
3845         funnelShiftRotatePattern* pattern = new (m_allocator) funnelShiftRotatePattern();
3846         pattern->instruction = &I;
3847         pattern->sources[0] = GetSource(A, true, false);
3848         pattern->sources[1] = GetSource(Amt, true, false);
3849 
3850         AddPattern(pattern);
3851         return true;
3852     }
3853 
MatchUnmaskedRegionBoundary(Instruction & I,bool start)3854     bool CodeGenPatternMatch::MatchUnmaskedRegionBoundary(Instruction &I, bool start)
3855     {
3856         struct UnmaskedBoundaryPattern : public Pattern
3857         {
3858             bool start;
3859             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3860             {
3861                 (void) modifier;
3862                 pass->emitUnmaskedRegionBoundary(start);
3863             }
3864         };
3865 
3866         UnmaskedBoundaryPattern* pattern = new (m_allocator) UnmaskedBoundaryPattern();
3867         pattern->start = start;
3868         AddPattern(pattern);
3869         return true;
3870     }
3871 
MatchAdd3(Instruction & I)3872     bool CodeGenPatternMatch::MatchAdd3(Instruction& I)
3873     {
3874         using namespace llvm::PatternMatch;
3875 
3876         struct Add3Pattern : public Pattern
3877         {
3878             SSource sources[3];
3879             Instruction* instruction;
3880             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3881             {
3882                 pass->Tenary(EOPCODE_ADD3, sources, modifier);
3883             }
3884         };
3885 
3886         if (IGC_IS_FLAG_DISABLED(EnableAdd3) ||
3887             !m_Platform.supportAdd3Instruction())
3888         {
3889             return false;
3890         }
3891 
3892         // Only handle D & W integer types.
3893         Type* Ty = I.getType();
3894         if (!(Ty->isIntegerTy(16) || Ty->isIntegerTy(32)))
3895         {
3896             return false;
3897         }
3898 
3899         Value* s0 = I.getOperand(0);
3900         Value* s1 = nullptr, * s2 = nullptr;
3901         e_modifier Mod1 = EMOD_NONE, Mod2 = EMOD_NONE;
3902         Instruction* I0 = dyn_cast<Instruction>(s0);
3903         if (I0)
3904         {
3905             if (I0->getOpcode() == Instruction::Sub)
3906             {
3907                 s0 = I0->getOperand(0);
3908                 s1 = I0->getOperand(1);
3909                 Mod1 = EMOD_NEG;
3910             }
3911             else if (I0->getOpcode() == Instruction::Add)
3912             {
3913                 s0 = I0->getOperand(0);
3914                 s1 = I0->getOperand(1);
3915             }
3916         }
3917 
3918         bool isNeg = (I.getOpcode() == Instruction::Sub);
3919         if (s1 == nullptr)
3920         {
3921             s1 = I.getOperand(1);
3922             Instruction* I1 = dyn_cast<Instruction>(s1);
3923             if (I1)
3924             {
3925                 if (I1->getOpcode() == Instruction::Sub)
3926                 {
3927                     s1 = I1->getOperand(0);
3928                     s2 = I1->getOperand(1);
3929                     if (isNeg) {
3930                         Mod1 = EMOD_NEG;
3931                     }
3932                     else {
3933                         Mod2 = EMOD_NEG;
3934                     }
3935                 }
3936                 else if (I1->getOpcode() == Instruction::Add)
3937                 {
3938                     s1 = I1->getOperand(0);
3939                     s2 = I1->getOperand(1);
3940                     if (isNeg) {
3941                         Mod1 = EMOD_NEG;
3942                         Mod2 = EMOD_NEG;
3943                     }
3944                 }
3945             }
3946         }
3947         else
3948         {
3949             s2 = I.getOperand(1);
3950             if (isNeg) {
3951                 Mod2 = EMOD_NEG;
3952             }
3953         }
3954 
3955         if (s2 == nullptr)
3956         {
3957             return false;
3958         }
3959 
3960         // Found the pattern.
3961         // Make sure that the middle one is not constant
3962         if (isa<ConstantInt>(s1))
3963         {
3964             std::swap(s1, s2);
3965             std::swap(Mod1, Mod2);
3966         }
3967 
3968         Add3Pattern* pattern = new (m_allocator) Add3Pattern();
3969         pattern->instruction = &I;
3970         pattern->sources[0] = GetSource(s0, true, false);
3971         pattern->sources[1] = GetSource(s1, true, false);
3972         pattern->sources[2] = GetSource(s2, true, false);
3973         if (Mod1 != EMOD_NONE) {
3974             pattern->sources[1].mod = CombineModifier(Mod1, pattern->sources[1].mod);
3975         }
3976         if (Mod2 != EMOD_NONE) {
3977             pattern->sources[2].mod = CombineModifier(Mod2, pattern->sources[2].mod);
3978         }
3979         AddPattern(pattern);
3980         return true;
3981     }
3982 
MatchBfn(Instruction & I)3983     bool CodeGenPatternMatch::MatchBfn(Instruction& I)
3984     {
3985         using namespace llvm::PatternMatch;
3986 
3987         struct BfnPattern : public Pattern
3988         {
3989             uint8_t booleanFuncCtrl = 0;
3990             SSource sources[3];
3991             Instruction* instruction;
3992             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
3993             {
3994                 pass->Bfn(booleanFuncCtrl, sources, modifier);
3995             }
3996         };
3997 
3998         if (IGC_IS_FLAG_DISABLED(EnableBfn) ||
3999             !m_Platform.supportBfnInstruction())
4000         {
4001             return false;
4002         }
4003 
4004         // Only handle D & W integer types.
4005         Type* Ty = I.getType();
4006         if (!(Ty->isIntegerTy(16) || Ty->isIntegerTy(32)))
4007         {
4008             return false;
4009         }
4010 
4011         struct CtrlCalculator {
4012             enum OP { NONE, AND, OR, XOR };
4013             enum SOURCE { S0 = 0, S1 = 1, S2 = 2 };
4014             // Value of the three sources for calculating BooleanFuncCtrl
4015             const uint8_t s[3] = { 0xAA, 0xCC, 0xF0 };
4016 
4017             typedef std::vector<OP> OpVecType;
4018             typedef std::vector<SOURCE> SourceVecType;
4019 
4020             // The sequence of matched operators, follow the calculation order
4021             OpVecType ops;
4022             // The sequence of matched sources, follow the calculation order
4023             SourceVecType source_idx;
4024 
4025             void addSource(SOURCE src) { source_idx.push_back(src); }
4026             void addOPFromLLVMOp(unsigned llvmOp)
4027             {
4028                 if (llvmOp == Instruction::And)
4029                     ops.push_back(AND);
4030                 else if (llvmOp == Instruction::Or)
4031                     ops.push_back(OR);
4032                 else if (llvmOp == Instruction::Xor)
4033                     ops.push_back(XOR);
4034                 else
4035                 {
4036                     IGC_ASSERT_MESSAGE(0, "MatchBfn: inccorect opcode");
4037                     return;
4038                 }
4039             }
4040 
4041             uint8_t getBooleanFuncCtrl()
4042             {
4043                 IGC_ASSERT_MESSAGE((source_idx.size() == (ops.size() + 1)), "MatchBfn: OPs and Sources length missmatched");
4044                 SourceVecType::iterator s_it = source_idx.begin(), s_end = source_idx.end();
4045                 // start from the first source
4046                 uint8_t result = s[*s_it];
4047                 ++s_it;
4048 
4049                 // iterate through the matched sequence and compute the BooleanFuncCtrl
4050                 OpVecType::iterator op_it = ops.begin();
4051                 for (; s_it != s_end; ++s_it, ++op_it)
4052                 {
4053                     switch (*op_it)
4054                     {
4055                     case AND:
4056                         result = result & s[*s_it];
4057                         break;
4058                     case OR:
4059                         result = result | s[*s_it];
4060                         break;
4061                     case XOR:
4062                         result = result ^ s[*s_it];
4063                         break;
4064                     default:
4065                         IGC_ASSERT_MESSAGE(0, "MatchBfn:CtrlCalculator: inccorect OP");
4066                         break;
4067                     }
4068                 }
4069                 return result;
4070             }
4071         };
4072 
4073         auto isBinaryLogic = [](Instruction::BinaryOps op) { return op == Instruction::Or || op == Instruction::And || op == Instruction::Xor; };
4074         // Find BFN patterns. Matched patterns: (op0 and op1 are boolean operations)
4075         // s0   s1                   s1  s2
4076         //   \ /                      \  /
4077         //    op    s2     OR    s0   op1
4078         //     \   /              \   /
4079         //      op0                op0   <-- match to BFN
4080         //
4081         // TODO: Match NOT pattern on source
4082 
4083         // if source operand has many uses the bfn pattern match is unlikely to be profitable,
4084         // as it increases register pressure and makes register bank conflicts more likely
4085         // ToDo: tune the value of N
4086         const int useThreshold = 4;
4087         // if both operands of the root is binary logic op, use simple heuristics
4088         // to fold one of them
4089         if (isa<BinaryOperator>(I.getOperand(0)) && isa<BinaryOperator>(I.getOperand(1)))
4090         {
4091             auto I0 = cast<BinaryOperator>(I.getOperand(0));
4092             auto I1 = cast<BinaryOperator>(I.getOperand(1));
4093             if (I0->hasNUsesOrMore(useThreshold) && I1->hasNUsesOrMore(useThreshold))
4094             {
4095                 // bfn is unlikely to be profitable.
4096                 return false;
4097             }
4098             else if (I0->getNumUses() > I1->getNumUses())
4099             {
4100                 I.setOperand(0, I1);
4101                 I.setOperand(1, I0);
4102             }
4103         }
4104 
4105         CtrlCalculator ctrlcal;
4106         Value* s0 = I.getOperand(0);
4107         Value* s1 = nullptr, * s2 = nullptr;
4108         BinaryOperator* I0 = dyn_cast<BinaryOperator>(s0);
4109         if (I0)
4110         {
4111             if (isBinaryLogic(I0->getOpcode()) && !I0->hasNUsesOrMore(useThreshold))
4112             {
4113                 s0 = I0->getOperand(0);
4114                 s1 = I0->getOperand(1);
4115             }
4116         }
4117 
4118         if (s1 == nullptr)
4119         {
4120             s1 = I.getOperand(1);
4121             BinaryOperator* I1 = dyn_cast<BinaryOperator>(s1);
4122             if (I1)
4123             {
4124                 if (isBinaryLogic(I1->getOpcode()) && !I1->hasNUsesOrMore(useThreshold))
4125                 {
4126                     s1 = I1->getOperand(0);
4127                     s2 = I1->getOperand(1);
4128 
4129                     // add ops and sources by execution order
4130                     ctrlcal.addSource(CtrlCalculator::S1);
4131                     ctrlcal.addOPFromLLVMOp(I1->getOpcode());
4132                     ctrlcal.addSource(CtrlCalculator::S2);
4133                     ctrlcal.addOPFromLLVMOp(I.getOpcode());
4134                     ctrlcal.addSource(CtrlCalculator::S0);
4135                 }
4136             }
4137         }
4138         else
4139         {
4140             s2 = I.getOperand(1);
4141 
4142             // add ops and sources by execution order
4143             ctrlcal.addSource(CtrlCalculator::S0);
4144             ctrlcal.addOPFromLLVMOp(I0->getOpcode());
4145             ctrlcal.addSource(CtrlCalculator::S1);
4146             ctrlcal.addOPFromLLVMOp(I.getOpcode());
4147             ctrlcal.addSource(CtrlCalculator::S2);
4148         }
4149 
4150         if (s2 == nullptr)
4151             return false;
4152 
4153         BfnPattern* pattern = new (m_allocator) BfnPattern();
4154         pattern->booleanFuncCtrl = ctrlcal.getBooleanFuncCtrl();
4155         pattern->instruction = &I;
4156         pattern->sources[0] = GetSource(s0, false, false);
4157         pattern->sources[1] = GetSource(s1, false, false);
4158         pattern->sources[2] = GetSource(s2, false, false);
4159 
4160         // BFN can use imm16 in src0 and src2, check for those;
4161         // otherwise try to add to constant pool even int32.
4162         if (dyn_cast<ConstantInt>(s0) && !s0->getType()->isIntegerTy(16) )
4163         {
4164             AddToConstantPool(I.getParent(), s0); pattern->sources[0].fromConstantPool = true;
4165         }
4166         if (dyn_cast<ConstantInt>(s1))
4167         {
4168             AddToConstantPool(I.getParent(), s1); pattern->sources[1].fromConstantPool = true;
4169         }
4170         if (dyn_cast<ConstantInt>(s2) && !s2->getType()->isIntegerTy(16))
4171         {
4172             AddToConstantPool(I.getParent(), s2); pattern->sources[2].fromConstantPool = true;
4173         }
4174 
4175         AddPattern(pattern);
4176         return true;
4177     }
4178 
4179     // Match this pattern
4180     // %1 = cmp %2 %3
    // %6 = select %1 %4 %5
4182     // to
4183     // %1 = cmp %2 %3
4184     // %6  = bfn %1 %4 %5
4185     //
4186     // For the original cmp-sel sequence, a flag-based sequence is generated.
4187     // We instead want to generate a non-flag cmp-bfn sequence which has shorter latency.
MatchCmpSelect(llvm::SelectInst & I)4188     bool CodeGenPatternMatch::MatchCmpSelect(llvm::SelectInst& I)
4189     {
4190         struct CmpSelectPattern : public Pattern
4191         {
4192             uint8_t execSize;
4193             llvm::CmpInst::Predicate predicate;
4194             SSource cmpSources[2];
4195             uint8_t booleanFuncCtrl = 0xD8; // represents s0&s1|~s0&s2
4196             SSource bfnSources[3];
4197             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4198             {
4199                 pass->CmpBfn(predicate, cmpSources, booleanFuncCtrl, bfnSources, modifier);
4200             }
4201         };
4202 
4203         if (IGC_IS_FLAG_DISABLED(EnableBfn) ||
4204             !m_Platform.supportBfnInstruction())
4205         {
4206             return false;
4207         }
4208 
4209         if (llvm::CmpInst* cmp = llvm::dyn_cast<llvm::CmpInst>(I.getOperand(0)))
4210         {
4211             // handle one use for now
4212             if (!cmp->hasOneUse())
4213             {
4214                 return false;
4215             }
4216 
4217             // BFN only supports D & W types.
4218             Type* selTy = I.getType();
4219             if (!(selTy->isIntegerTy(16) || selTy->isIntegerTy(32)))
4220             {
4221                 return false;
4222             }
4223 
4224             Type* cmpS0Ty = cmp->getOperand(0)->getType();
4225             if (selTy->getPrimitiveSizeInBits() != cmpS0Ty->getPrimitiveSizeInBits())
4226             {
4227                 return false;
4228             }
4229 
4230             llvm::Value* selSources[2];
4231             e_modifier   selMod[2];
4232             selSources[0] = I.getOperand(1);
4233             selSources[1] = I.getOperand(2);
4234 
4235             // BFN only supports 16bit immediate
4236             if ((isa<Constant>(selSources[0]) && selSources[0]->getType()->isIntegerTy(32)) ||
4237                 (isa<Constant>(selSources[1]) && selSources[1]->getType()->isIntegerTy(32)))
4238             {
4239                 return false;
4240             }
4241 
4242             // As BFN doesn't support src modifier, it is not worth to generate the cmp-bfn
4243             // sequence if one of its sources will need an extra move.
4244             if (GetModifier(*selSources[0], selMod[0], selSources[0]) ||
4245                 GetModifier(*selSources[1], selMod[1], selSources[1]))
4246             {
4247                 return false;
4248             }
4249 
4250             CmpSelectPattern* pattern = new (m_allocator) CmpSelectPattern();
4251             pattern->predicate = cmp->getPredicate();
4252             bool supportsModifer = SupportsModifier(cmp);
4253             pattern->cmpSources[0] = GetSource(cmp->getOperand(0), supportsModifer, false);
4254             pattern->cmpSources[1] = GetSource(cmp->getOperand(1), supportsModifer, false);
4255 
4256             pattern->bfnSources[1] = GetSource(selSources[0], false, false);
4257             pattern->bfnSources[2] = GetSource(selSources[1], false, false);
4258             AddPattern(pattern);
4259 
4260             return true;
4261         }
4262         return false;
4263     }
4264 
MatchDpas(GenIntrinsicInst & I)4265     bool CodeGenPatternMatch::MatchDpas(GenIntrinsicInst& I)
4266     {
4267         struct DpasPattern : public Pattern
4268         {
4269             SSource source[3];
4270             GenIntrinsicInst* instruction;
4271             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4272             {
4273                 pass->emitDpas(instruction, source, modifier);
4274                 //pass->BinaryUnary(instruction, source, modifier);
4275             }
4276         };
4277 
4278         GenISAIntrinsic::ID dpasID = I.getIntrinsicID();
4279         IGC_ASSERT_MESSAGE((dpasID == GenISAIntrinsic::GenISA_dpas || dpasID == GenISAIntrinsic::GenISA_sub_group_dpas), "Unexpected DPAS intrinsic!");
4280 
4281         Value* src0 = I.getOperand(0); // input
4282         Value* src1 = I.getOperand(1); // activation. operand(3) is its precision
4283         Value* src2 = I.getOperand(2); // weight. operand(4) is its precision
4284         //ConstantInt* pa = dyn_cast<ConstantInt>(I.getOperand(3));
4285         //ConstantInt* pb = dyn_cast<ConstantInt>(I.getOperand(4));
4286         ConstantInt* sdepth = dyn_cast<ConstantInt>(I.getOperand(5));
4287         ConstantInt* rcount = dyn_cast<ConstantInt>(I.getOperand(6));
4288         int SD = (int)sdepth->getZExtValue();
4289         int RC = (int)rcount->getZExtValue();
4290 
4291         if (dpasID == GenISAIntrinsic::GenISA_dpas && RC == 1 && SD == 8)
4292         {
4293             // Special-handling of activation (src1 - uniform). The case handled is:
4294             //
4295             // %s0 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 0)
4296             // %s1 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 1)
4297             // %s2 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 2)
4298             // %s3 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 3)
4299             // %s4 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 4)
4300             // %s5 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 5)
4301             // %s6 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 6)
4302             // %s7 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %5, i32 7)
4303             // %a0 = insertelement <8 x i32> undef, i32 %s0, i32 0
4304             // %a1 = insertelement <8 x i32> %a0,   i32 %s1, i32 1
4305             // %a2 = insertelement <8 x i32> %a1,   i32 %s2, i32 2
4306             // %a3 = insertelement <8 x i32> %a2,   i32 %s3, i32 3
4307             // %a4 = insertelement <8 x i32> %a3,   i32 %s4, i32 4
4308             // %a5 = insertelement <8 x i32> %a4,   i32 %s5, i32 5
4309             // %a6 = insertelement <8 x i32> %a5,   i32 %s6, i32 6
4310             // %a7 = insertelement <8 x i32> %a6,   i32 %s7, i32 7
4311             //
4312             // %c0 = call i32 @llvm.genx.GenISA.dpas.8(i32 %9, <8 x i32> %a7, i32 7, <8 x i32> %14, i32 7)
4313             //
4314             // %a7 will be replaced with %5 in vISA.
4315             //
4316 
4317             InsertElementInst* rootIEI = dyn_cast<InsertElementInst>(src1);
4318             if (rootIEI)
4319             {
4320                 // Currently, the src1 can be at most int8.
4321                 Value* srcVec[8];
4322                 Value* WSVal = nullptr;
4323                 for (int i = 0; i < 8; srcVec[i] = nullptr, ++i);
4324                 InsertElementInst* IEI = rootIEI;
4325                 while (IEI)
4326                 {
4327                     Value* elem = IEI->getOperand(1);
4328                     ConstantInt* CI = dyn_cast<ConstantInt>(IEI->getOperand(2));
4329                     if (!CI)
4330                     {
4331                         // Only handle constant index
4332                         WSVal = nullptr;
4333                         break;
4334                     }
4335                     uint32_t ix = (uint32_t)CI->getZExtValue();
4336                     if (ix > 7 || srcVec[ix])
4337                     {
4338                         // Assume each element is inserted once.
4339                         WSVal = nullptr;
4340                         break;
4341                     }
4342 
4343                     // Check if the inserted element is created
4344                     // by WaveShuffleIndex intrinsic
4345                     GenIntrinsicInst* WSI = dyn_cast<GenIntrinsicInst>(elem);
4346                     if (!WSI ||
4347                         WSI->getIntrinsicID() != GenISAIntrinsic::GenISA_WaveShuffleIndex)
4348                     {
4349                         WSVal = nullptr;
4350                         break;
4351                     }
4352                     Value* shuffleVal = WSI->getOperand(0);
4353                     Value* ixVal = WSI->getOperand(1);
4354                     if (WSVal == nullptr)
4355                     {
4356                         WSVal = shuffleVal;
4357                     }
4358                     else if (WSVal != shuffleVal)
4359                     {
4360                         WSVal = nullptr;
4361                         break;
4362                     }
4363                     if (ConstantInt * CIX = dyn_cast<ConstantInt>(ixVal))
4364                     {
4365                         if ((uint32_t)CIX->getZExtValue() != ix)
4366                         {
4367                             WSVal = nullptr;
4368                             break;
4369                         }
4370                     }
4371                     srcVec[ix] = elem;
4372                     IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
4373                 }
4374 
4375                 if (WSVal)
4376                 {
4377                     for (int i = 0; i < SD; ++i)
4378                     {
4379                         if (srcVec[i] == nullptr)
4380                         {
4381                             WSVal = nullptr;
4382                             break;
4383                         }
4384                     }
4385                 }
4386 
4387                 // If WSVal is set at this point, it is one that will
4388                 // replace src1.
4389                 if (WSVal)
4390                 {
4391                     src1 = WSVal;
4392                 }
4393             }
4394         }
4395 
4396         DpasPattern* pattern = new (m_allocator) DpasPattern();
4397         pattern->instruction = &I;
4398         pattern->source[0] = GetSource(src0, false, false);
4399         pattern->source[1] = GetSource(src1, false, false);
4400         pattern->source[2] = GetSource(src2, false, false);
4401         AddPattern(pattern);
4402 
4403         return true;
4404     }
4405 
MatchDp4a(GenIntrinsicInst & I)4406     bool CodeGenPatternMatch::MatchDp4a(GenIntrinsicInst& I)
4407     {
4408         struct MatchDp4a : public Pattern
4409         {
4410             SSource source[3];
4411             GenIntrinsicInst* instruction;
4412             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4413             {
4414                 pass->emitDP4A(instruction, source, modifier);
4415             }
4416         };
4417 
4418         // Attempt to find a pattern like this:
4419         // %scalar52.6.2474 = extractelement <4 x i8> %145, i32 0
4420         // %scalar53.6.2475 = extractelement <4 x i8> % 145, i32 1
4421         // %scalar54.6.2476 = extractelement <4 x i8> % 145, i32 2
4422         // %scalar55.6.2477 = extractelement <4 x i8> % 145, i32 3
4423         // %simdShuffle.6.2480 = call i8 @llvm.genx.GenISA.WaveShuffleIndex.i8(i8 % scalar52.6.2474, i32 0, i32 0)
4424         // %simdShuffle33.6.2481 = call i8 @llvm.genx.GenISA.WaveShuffleIndex.i8(i8 % scalar53.6.2475, i32 0, i32 0)
4425         // %simdShuffle34.6.2482 = call i8 @llvm.genx.GenISA.WaveShuffleIndex.i8(i8 % scalar54.6.2476, i32 0, i32 0)
4426         // %simdShuffle35.6.2483 = call i8 @llvm.genx.GenISA.WaveShuffleIndex.i8(i8 % scalar55.6.2477, i32 0, i32 0)
4427         // %assembled.vect.6.2484 = insertelement <4 x i8> undef, i8 % simdShuffle.6.2480, i32 0
4428         // %assembled.vect56.6.2485 = insertelement <4 x i8> % assembled.vect.6.2484, i8 % simdShuffle33.6.2481, i32 1
4429         // %assembled.vect57.6.2486 = insertelement <4 x i8> % assembled.vect56.6.2485, i8 % simdShuffle34.6.2482, i32 2
4430         // %assembled.vect58.6.2487 = insertelement <4 x i8> % assembled.vect57.6.2486, i8 % simdShuffle35.6.2483, i32 3
4431         // %astype.i45.6.2489 = bitcast <4 x i8> % assembled.vect58.6.2487 to i32
4432         // %call.i4738.6.2490 = call i32 @llvm.genx.GenISA.dp4a.uu(i32 % call.i4738.6.3.1, i32 % astype.i45.6.2489, i32 % 135)
4433         // %call.i1239.6.2492 = call i32 @llvm.genx.GenISA.dp4a.uu(i32 % call.i1239.6.3.1, i32 % astype.i45.6.2489, i32 16843009)
4434         //
4435         // If the pattern is found, we can avoid creating extra movs by changing the source of the dp4a
4436         // to use the input from the bitcast instead of using the result of the bitcast.
4437         // This pattern only happens if the WaveShuffleIndex uses 0,0 for the arguments.
4438         if (llvm::BitCastInst* bcInst = llvm::dyn_cast<BitCastInst> (I.getOperand(1)))
4439         {
4440             llvm::CallInst* waveInst[4] = { nullptr };
4441             llvm::InsertElementInst* insertInst[4] = { nullptr };
4442 
4443             // Find the 4 x insertelement
4444             insertInst[3] = llvm::dyn_cast<InsertElementInst>(bcInst->getOperand(0));
4445             if (insertInst[3]) {
4446                 for (int i = 2; i >= 0; i--) {
4447                     insertInst[i] = llvm::dyn_cast<InsertElementInst>(insertInst[i + 1]->getOperand(0));
4448                     if (!insertInst[i])
4449                         break;
4450                 }
4451             }
4452 
4453             // Find the 4 x WaveShuffleIndex
4454             for (int i = 3; i >= 0; i--) {
4455                 if (!insertInst[i])
4456                     break;
4457                 CallInst* temp = llvm::dyn_cast<CallInst>(insertInst[i]->getOperand(1));
4458                 if (!temp)
4459                     break;
4460 
4461                 llvm::GenIntrinsicInst* intrin = llvm::dyn_cast<llvm::GenIntrinsicInst>(temp);
4462                 if (!intrin || intrin->getIntrinsicID() != GenISAIntrinsic::GenISA_WaveShuffleIndex)
4463                     break;
4464                 waveInst[i] = temp;
4465             }
4466 
4467             // Check to see if the WaveShuffleIndex uses 0,0
4468             llvm::Constant * wavesrc1 = nullptr, * wavesrc2 = nullptr;
4469             if (waveInst[0]) {
4470                 wavesrc1 = llvm::dyn_cast<llvm::Constant>(waveInst[0]->getOperand(1));
4471                 wavesrc2 = llvm::dyn_cast<llvm::Constant>(waveInst[0]->getOperand(2));
4472             }
4473 
4474             if (wavesrc1 && wavesrc2 && wavesrc1->isZeroValue() && wavesrc2->isZeroValue())
4475             {
4476                 if (llvm::ExtractElementInst* wavesrc0 = llvm::dyn_cast<llvm::ExtractElementInst>(waveInst[0]->getOperand(0)))
4477                 {
4478                     llvm::ExtractElementInst* extractInst[4];
4479 
4480                     // Find the 4 x extractelement
4481                     extractInst[0] = llvm::dyn_cast<llvm::ExtractElementInst>(waveInst[0]->getOperand(0));
4482                     extractInst[1] = llvm::dyn_cast<llvm::ExtractElementInst>(waveInst[1]->getOperand(0));
4483                     extractInst[2] = llvm::dyn_cast<llvm::ExtractElementInst>(waveInst[2]->getOperand(0));
4484                     extractInst[3] = llvm::dyn_cast<llvm::ExtractElementInst>(waveInst[3]->getOperand(0));
4485 
4486                     if (extractInst[0] && extractInst[1] && extractInst[2] && extractInst[3] &&
4487                         extractInst[0]->getOperand(0) == extractInst[1]->getOperand(0) &&
4488                         extractInst[0]->getOperand(0) == extractInst[2]->getOperand(0) &&
4489                         extractInst[0]->getOperand(0) == extractInst[3]->getOperand(0))
4490                     {
4491                         MatchDp4a* pattern = new (m_allocator) MatchDp4a();
4492                         pattern->instruction = &I;
4493                         pattern->source[0] = GetSource(I.getOperand(0), false, false);
4494 
4495                         // set regioning to: <0;1,0>, as if we were still going to use the result of the WaveShuffleIndex
4496                         llvm::BitCastInst* bitCast = llvm::dyn_cast<BitCastInst>(extractInst[0]->getOperand(0));
4497                         if (bitCast)
4498                         {
4499                             pattern->source[1] = GetSource(bitCast->getOperand(0), false, false);
4500                             pattern->source[1].region[0] = 0;
4501                             pattern->source[1].region[1] = 1;
4502                             pattern->source[1].region[2] = 0;
4503                             pattern->source[1].region_set = true;
4504 
4505                             pattern->source[2] = GetSource(I.getOperand(2), false, false);
4506                             AddPattern(pattern);
4507                             return true;
4508                         }
4509 
4510                     }
4511                 }
4512             }
4513         }
4514         return MatchSingleInstruction(I);
4515     }
4516 
MatchLogicAlu(llvm::BinaryOperator & I)4517     bool CodeGenPatternMatch::MatchLogicAlu(llvm::BinaryOperator& I)
4518     {
4519         struct LogicInstPattern : public Pattern
4520         {
4521             SSource sources[2];
4522             llvm::Instruction* instruction;
4523             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4524             {
4525                 pass->BinaryUnary(instruction, sources, modifier);
4526             }
4527         };
4528         LogicInstPattern* pattern = new (m_allocator) LogicInstPattern();
4529         pattern->instruction = &I;
4530         for (unsigned int i = 0; i < 2; ++i)
4531         {
4532             e_modifier mod = EMOD_NONE;
4533             Value* src = I.getOperand(i);
4534             if (!I.getType()->isIntegerTy(1))
4535             {
4536                 if (BinaryOperator * notInst = dyn_cast<BinaryOperator>(src))
4537                 {
4538                     if (notInst->getOpcode() == Instruction::Xor)
4539                     {
4540                         if (ConstantInt * minusOne = dyn_cast<ConstantInt>(notInst->getOperand(1)))
4541                         {
4542                             if (minusOne->isMinusOne())
4543                             {
4544                                 mod = EMOD_NOT;
4545                                 src = notInst->getOperand(0);
4546                             }
4547                         }
4548                     }
4549                 }
4550             }
4551             pattern->sources[i] = GetSource(src, mod, false);
4552 
4553             if (isCandidateForConstantPool(src))
4554             {
4555                 AddToConstantPool(I.getParent(), src);
4556                 pattern->sources[i].fromConstantPool = true;
4557             }
4558 
4559 
4560         }
4561         AddPattern(pattern);
4562         return true;
4563     }
4564 
MatchRsqrt(llvm::BinaryOperator & I)4565     bool CodeGenPatternMatch::MatchRsqrt(llvm::BinaryOperator& I)
4566     {
4567         struct RsqrtPattern : public Pattern
4568         {
4569             SSource source;
4570             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4571             {
4572                 pass->Rsqrt(source, modifier);
4573             }
4574         };
4575 
4576         bool found = false;
4577         llvm::Value* source = NULL;
4578         if (I.getOpcode() == Instruction::FDiv)
4579         {
4580             // by vISA document, rsqrt doesn't support double type
4581             if (isOne(I.getOperand(0)) && I.getType()->getTypeID() != Type::DoubleTyID)
4582             {
4583                 if (llvm::IntrinsicInst * sqrt = dyn_cast<IntrinsicInst>(I.getOperand(1)))
4584                 {
4585                     if (sqrt->getIntrinsicID() == Intrinsic::sqrt)
4586                     {
4587                         source = sqrt->getOperand(0);
4588                         found = true;
4589                     }
4590                 }
4591             }
4592         }
4593         if (found)
4594         {
4595             RsqrtPattern* pattern = new (m_allocator) RsqrtPattern();
4596             pattern->source = GetSource(source, true, false);
4597             AddPattern(pattern);
4598         }
4599         return found;
4600     }
4601 
MatchGradient(llvm::GenIntrinsicInst & I)4602     bool CodeGenPatternMatch::MatchGradient(llvm::GenIntrinsicInst& I)
4603     {
4604         struct GradientPattern : public Pattern
4605         {
4606             SSource source;
4607             llvm::Instruction* instruction;
4608             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4609             {
4610                 pass->BinaryUnary(instruction, &source, modifier);
4611             }
4612         };
4613         GradientPattern* pattern = new (m_allocator) GradientPattern();
4614         pattern->instruction = &I;
4615         pattern->source = GetSource(I.getOperand(0), true, false);
4616         AddPattern(pattern);
4617         // mark the source as subspan use
4618         HandleSubspanUse(pattern->source.value);
4619         return true;
4620     }
4621 
MatchSampleDerivative(llvm::GenIntrinsicInst & I)4622     bool CodeGenPatternMatch::MatchSampleDerivative(llvm::GenIntrinsicInst& I)
4623     {
4624         HandleSampleDerivative(I);
4625         return MatchSingleInstruction(I);
4626     }
4627 
MatchDbgInstruction(llvm::DbgInfoIntrinsic & I)4628     bool CodeGenPatternMatch::MatchDbgInstruction(llvm::DbgInfoIntrinsic& I)
4629     {
4630         struct DbgInstPattern : Pattern
4631         {
4632             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4633             {
4634                 // Nothing to emit.
4635             }
4636         };
4637         DbgInstPattern* pattern = new (m_allocator) DbgInstPattern();
4638         if (DbgDeclareInst * pDbgDeclInst = dyn_cast<DbgDeclareInst>(&I))
4639         {
4640             if (pDbgDeclInst->getAddress())
4641             {
4642                 MarkAsSource(pDbgDeclInst->getAddress());
4643             }
4644         }
4645         else if (DbgValueInst * pDbgValInst = dyn_cast<DbgValueInst>(&I))
4646         {
4647             if (pDbgValInst->getValue())
4648             {
4649                 MarkAsSource(pDbgValInst->getValue());
4650             }
4651         }
4652         else
4653         {
4654             IGC_ASSERT_MESSAGE(0, "Unhandled Dbg intrinsic");
4655         }
4656         AddPattern(pattern);
4657         return true;
4658     }
4659 
MatchAvg(llvm::Instruction & I)4660     bool CodeGenPatternMatch::MatchAvg(llvm::Instruction& I)
4661     {
4662         // "Average value" pattern:
4663         // (x + y + 1) / 2  -->  avg(x, y)
4664         //
4665         // We're looking for patterns like this:
4666         //    % 14 = add nsw i32 % 10, % 13
4667         //    % 15 = add nsw i32 % 14, 1
4668         //    % 16 = ashr i32 % 15, 1
4669 
4670         struct AvgPattern : Pattern
4671         {
4672             SSource sources[2];
4673             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4674             {
4675                 pass->Avg(sources, modifier);
4676             }
4677         };
4678 
4679         bool found = false;
4680         llvm::Value* sources[2];
4681         e_modifier   src_mod[2];
4682 
4683         IGC_ASSERT(I.getOpcode() == Instruction::SDiv || I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::AShr);
4684 
4685         // We expect 2 for "div" and 1 for "right shift".
4686         int  expectedVal = (I.getOpcode() == Instruction::SDiv ? 2 : 1);
4687         Value* opnd1 = I.getOperand(1);   // Divisor or shift factor.
4688         if (!isa<ConstantInt>(opnd1) || (cast<ConstantInt>(opnd1))->getZExtValue() != expectedVal)
4689         {
4690             return false;
4691         }
4692 
4693         if (Instruction * divSrc = dyn_cast<Instruction>(I.getOperand(0)))
4694         {
4695             if (divSrc->getOpcode() == Instruction::Add && !NeedInstruction(*divSrc))
4696             {
4697                 Instruction* instAdd = cast<Instruction>(divSrc);
4698                 for (int i = 0; i < 2; i++)
4699                 {
4700                     if (ConstantInt * cnst = dyn_cast<ConstantInt>(instAdd->getOperand(i)))
4701                     {
4702                         // "otherArg" is the second argument of "instAdd" (which is not constant).
4703                         Value* otherArg = instAdd->getOperand(i == 0 ? 1 : 0);
4704                         if (cnst->getZExtValue() == 1 && isa<AddOperator>(otherArg) && !NeedInstruction(*cast<Instruction>(otherArg)))
4705                         {
4706                             Instruction* firstAdd = cast<Instruction>(otherArg);
4707                             sources[0] = firstAdd->getOperand(0);
4708                             sources[1] = firstAdd->getOperand(1);
4709                             GetModifier(*sources[0], src_mod[0], sources[0]);
4710                             GetModifier(*sources[1], src_mod[1], sources[1]);
4711                             found = true;
4712                             break;
4713                         }
4714                     }
4715                 }
4716             }
4717         }
4718 
4719         if (found)
4720         {
4721             AvgPattern* pattern = new (m_allocator)AvgPattern();
4722             pattern->sources[0] = GetSource(sources[0], src_mod[0], false);
4723             pattern->sources[1] = GetSource(sources[1], src_mod[1], false);
4724             AddPattern(pattern);
4725         }
4726         return found;
4727     }
4728 
MatchShuffleBroadCast(llvm::GenIntrinsicInst & I)4729     bool CodeGenPatternMatch::MatchShuffleBroadCast(llvm::GenIntrinsicInst& I)
4730     {
4731         // Match cases like:
4732         //    %84 = bitcast <2 x i32> %vCastload to <4 x half>
4733         //    %scalar269 = extractelement <4 x half> % 84, i32 0
4734         //    %simdShuffle = call half @llvm.genx.GenISA.simdShuffle.f.f16(half %scalar269, i32 0)
4735         //
4736         // to mov with region and offset
4737         struct BroadCastPattern : public Pattern
4738         {
4739             SSource source;
4740             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4741             {
4742                 pass->Mov(source, modifier);
4743             }
4744         };
4745         bool match = false;
4746         SSource source;
4747         Value* sourceV = &I;
4748         if (GetRegionModifier(source, sourceV, true))
4749         {
4750             BroadCastPattern* pattern = new (m_allocator) BroadCastPattern();
4751             GetModifier(*sourceV, source.mod, sourceV);
4752             source.value = sourceV;
4753             pattern->source = source;
4754             MarkAsSource(sourceV);
4755             match = true;
4756             AddPattern(pattern);
4757         }
4758         return match;
4759     }
4760 
MatchWaveShuffleIndex(llvm::GenIntrinsicInst & I)4761     bool CodeGenPatternMatch::MatchWaveShuffleIndex(llvm::GenIntrinsicInst& I)
4762     {
4763         llvm::Value* helperLaneMode = I.getOperand(2);
4764         IGC_ASSERT(helperLaneMode);
4765         if (int_cast<int>(cast<ConstantInt>(helperLaneMode)->getSExtValue()) == 1)
4766         {
4767             //only if helperLaneMode==1, we enable helper lane under some shuffleindex cases (not for all cases).
4768             HandleSubspanUse(I.getArgOperand(0));
4769             HandleSubspanUse(&I);
4770         }
4771         return MatchSingleInstruction(I);
4772     }
4773 
MatchRegisterRegion(llvm::GenIntrinsicInst & I)4774     bool CodeGenPatternMatch::MatchRegisterRegion(llvm::GenIntrinsicInst& I)
4775     {
4776         struct MatchRegionPattern : public Pattern
4777         {
4778             SSource source;
4779             virtual void Emit(EmitPass* pass, const DstModifier& modifier)
4780             {
4781                 pass->Mov(source, modifier);
4782             }
4783         };
4784 
4785         /*
4786         * Match case 1 - With SubReg Offset: Shuffle( data, (laneID << x) + y )
4787         *   %25 = call i16 @llvm.genx.GenISA.simdLaneId()
4788         *   %30 = zext i16 %25 to i32
4789         *   %31 = shl nuw nsw i32 %30, 1  - Current LaneID shifted by x
4790         *   %36 = add i32 %31, 1          - Current LaneID shifted by x + y  Shuffle( data, (laneID << x) + 1 )
4791         *   %37 = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %21, i32 %36)
4792 
4793         * Match case 2(Special case of Match Case 1) - No SubReg Offset: Shuffle( data, (laneID << x) + 0 )
4794         *    %25 = call i16 @llvm.genx.GenISA.simdLaneId()
4795         *    %30 = zext i16 %25 to i32
4796         *    %31 = shl nuw nsw i32 %30, 1 - Current LaneID shifted by x
4797         *    %32 = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %21, i32 %31)
4798         */
4799 
4800         Value* data = I.getOperand(0);
4801         Value* source = I.getOperand(1);
4802         uint typeByteSize = data->getType()->getScalarSizeInBits() / 8;
4803         bool isMatch = false;
4804         int subReg = 0;
4805         uint verticalStride = 1; //Default value for special case  Shuffle( data, (laneID << x) + y )  when x = 0
4806 
4807         if (auto binaryInst = dyn_cast<BinaryOperator>(source))
4808         {
4809             //Will be skipped for match case 2
4810             if (binaryInst->getOpcode() == Instruction::Add)
4811             {
4812                 if (llvm::ConstantInt * simDOffSetInst = llvm::dyn_cast<llvm::ConstantInt>(binaryInst->getOperand(1)))
4813                 {
4814                     subReg = int_cast<int>(cast<ConstantInt>(simDOffSetInst)->getSExtValue());
4815 
4816                     //Subregister must be a number between 0 and 15 for a valid region
4817                     // We could support up to 31 but we need to handle reading from different SIMD16 var chunks
4818                     if (subReg >= 0 && subReg < 16)
4819                     {
4820                         source = binaryInst->getOperand(0);
4821                     }
4822                 }
4823             }
4824         }
4825 
4826         if (auto binaryInst = dyn_cast<BinaryOperator>(source))
4827         {
4828             if (binaryInst->getOpcode() == Instruction::Shl)
4829             {
4830                 source = binaryInst->getOperand(0);
4831 
4832                 if (llvm::ConstantInt * simDOffSetInst = llvm::dyn_cast<llvm::ConstantInt>(binaryInst->getOperand(1)))
4833                 {
4834                     uint shiftFactor = int_cast<uint>(simDOffSetInst->getZExtValue());
4835                     //Check to make sure we dont end up with an invalid Vertical Stride.
4836                     //Only 1, 2, 4, 8, 16 are supported.
4837                     if (shiftFactor <= 4)
4838                         verticalStride = (1U << shiftFactor);
4839                     else
4840                         return false;
4841                 }
4842             }
4843         }
4844 
4845         if (auto zExtInst = llvm::dyn_cast<llvm::ZExtInst>(source))
4846         {
4847             source = zExtInst->getOperand(0);
4848         }
4849 
4850         llvm::GenIntrinsicInst* intrin = llvm::dyn_cast<llvm::GenIntrinsicInst>(source);
4851 
4852         //Finally check for simLaneID intrisic
4853         if (intrin && (intrin->getIntrinsicID() == GenISAIntrinsic::GenISA_simdLaneId))
4854         {
4855             //To avoid compiler crash, pattern match with direct mov will be disable
4856             //Conservetively, we assum simd16 for 32 bytes GRF platforms and simd32 for 64 bytes GRF platforms
4857             bool cross2GRFs = typeByteSize * (subReg + verticalStride * (m_Platform.getGRFSize() > 32 ? 32 : 16)) > (2 * m_Platform.getGRFSize());
4858             if (!cross2GRFs)
4859             {
4860                 MatchRegionPattern* pattern = new (m_allocator) MatchRegionPattern();
4861                 pattern->source.elementOffset = subReg;
4862 
4863                 //Set Region Parameters <VerString;Width,HorzString>
4864                 pattern->source.region_set = true;
4865                 pattern->source.region[0] = verticalStride;
4866                 pattern->source.region[1] = 1;
4867                 pattern->source.region[2] = 0;
4868 
4869                 pattern->source.value = data;
4870                 MarkAsSource(data);
4871                 HandleSubspanUse(data);
4872                 AddPattern(pattern);
4873 
4874                 isMatch = true;
4875             }
4876         }
4877 
4878         return isMatch;
4879     }
4880 
GetRegionModifier(SSource & sourceMod,llvm::Value * & source,bool regioning)4881     bool CodeGenPatternMatch::GetRegionModifier(SSource& sourceMod, llvm::Value*& source, bool regioning)
4882     {
4883         bool found = false;
4884         Value* OrignalSource = source;
4885         if (llvm::BitCastInst * bitCast = llvm::dyn_cast<BitCastInst>(source))
4886         {
4887             if (!bitCast->getType()->isVectorTy() && !bitCast->getOperand(0)->getType()->isVectorTy())
4888             {
4889                 source = bitCast->getOperand(0);
4890                 found = true;
4891             }
4892         }
4893 
4894         if (llvm::GenIntrinsicInst * intrin = llvm::dyn_cast<llvm::GenIntrinsicInst>(source))
4895         {
4896             GenISAIntrinsic::ID id = intrin->getIntrinsicID();
4897             if (id == GenISAIntrinsic::GenISA_WaveShuffleIndex)
4898             {
4899                 if (llvm::ConstantInt * channelVal = llvm::dyn_cast<llvm::ConstantInt>(intrin->getOperand(1)))
4900                 {
4901                     unsigned int offset = int_cast<unsigned int>(channelVal->getZExtValue());
4902                     if (offset < 16 && !isUniform(intrin->getOperand(0)))
4903                     {
4904                         sourceMod.elementOffset = offset;
4905                         // SIMD shuffle force region <0,1;0>
4906                         sourceMod.region_set = true;
4907                         sourceMod.region[0] = 0;
4908                         sourceMod.region[1] = 1;
4909                         sourceMod.region[2] = 0;
4910                         sourceMod.instance = EINSTANCE_FIRST_HALF;
4911                         source = intrin->getOperand(0);
4912                         found = true;
4913                         BitcastSearch(sourceMod, source, true);
4914                     }
4915                 }
4916             }
4917         }
4918         if (regioning && !sourceMod.region_set)
4919         {
4920             found |= BitcastSearch(sourceMod, source, false);
4921         }
4922         if (found && sourceMod.type == VISA_Type::ISA_TYPE_NUM)
4923         {
4924             // keep the original type
4925             sourceMod.type = GetType(OrignalSource->getType(), m_ctx);
4926         }
4927         return found;
4928     }
4929 
HandleSampleDerivative(llvm::GenIntrinsicInst & I)4930     void CodeGenPatternMatch::HandleSampleDerivative(llvm::GenIntrinsicInst& I)
4931     {
4932         switch (I.getIntrinsicID())
4933         {
4934         case GenISAIntrinsic::GenISA_sampleptr:
4935         case GenISAIntrinsic::GenISA_lodptr:
4936         case GenISAIntrinsic::GenISA_sampleKillPix:
4937             HandleSubspanUse(I.getOperand(0));
4938             HandleSubspanUse(I.getOperand(1));
4939             HandleSubspanUse(I.getOperand(2));
4940             break;
4941         case GenISAIntrinsic::GenISA_sampleBptr:
4942         case GenISAIntrinsic::GenISA_sampleCptr:
4943             HandleSubspanUse(I.getOperand(1));
4944             HandleSubspanUse(I.getOperand(2));
4945             HandleSubspanUse(I.getOperand(3));
4946             break;
4947         case GenISAIntrinsic::GenISA_sampleBCptr:
4948             HandleSubspanUse(I.getOperand(2));
4949             HandleSubspanUse(I.getOperand(3));
4950             HandleSubspanUse(I.getOperand(4));
4951             break;
4952         default:
4953             break;
4954         }
4955     }
4956 
4957     // helper function for pattern match
isLowerPredicate(llvm::CmpInst::Predicate pred)4958     static inline bool isLowerPredicate(llvm::CmpInst::Predicate pred)
4959     {
4960         switch (pred)
4961         {
4962         case llvm::CmpInst::FCMP_ULT:
4963         case llvm::CmpInst::FCMP_ULE:
4964         case llvm::CmpInst::FCMP_OLT:
4965         case llvm::CmpInst::FCMP_OLE:
4966         case llvm::CmpInst::ICMP_ULT:
4967         case llvm::CmpInst::ICMP_ULE:
4968         case llvm::CmpInst::ICMP_SLT:
4969         case llvm::CmpInst::ICMP_SLE:
4970             return true;
4971         default:
4972             break;
4973         }
4974         return false;
4975     }
4976 
4977     // helper function for pattern match
isGreaterOrLowerPredicate(llvm::CmpInst::Predicate pred)4978     static inline bool isGreaterOrLowerPredicate(llvm::CmpInst::Predicate pred)
4979     {
4980         switch (pred)
4981         {
4982         case llvm::CmpInst::FCMP_UGT:
4983         case llvm::CmpInst::FCMP_UGE:
4984         case llvm::CmpInst::FCMP_ULT:
4985         case llvm::CmpInst::FCMP_ULE:
4986         case llvm::CmpInst::FCMP_OGT:
4987         case llvm::CmpInst::FCMP_OGE:
4988         case llvm::CmpInst::FCMP_OLT:
4989         case llvm::CmpInst::FCMP_OLE:
4990         case llvm::CmpInst::ICMP_UGT:
4991         case llvm::CmpInst::ICMP_UGE:
4992         case llvm::CmpInst::ICMP_ULT:
4993         case llvm::CmpInst::ICMP_ULE:
4994         case llvm::CmpInst::ICMP_SGT:
4995         case llvm::CmpInst::ICMP_SGE:
4996         case llvm::CmpInst::ICMP_SLT:
4997         case llvm::CmpInst::ICMP_SLE:
4998             return true;
4999         default:
5000             break;
5001         }
5002         return false;
5003     }
5004 
isIntegerAbs(SelectInst * SI,e_modifier & mod,Value * & source)5005     static bool isIntegerAbs(SelectInst* SI, e_modifier& mod, Value*& source) {
5006         using namespace llvm::PatternMatch; // Scoped using declaration.
5007 
5008         Value* Cond = SI->getOperand(0);
5009         Value* TVal = SI->getOperand(1);
5010         Value* FVal = SI->getOperand(2);
5011 
5012         ICmpInst::Predicate IPred = FCmpInst::FCMP_FALSE;
5013         Value* LHS = nullptr;
5014         Value* RHS = nullptr;
5015 
5016         if (!match(Cond, m_ICmp(IPred, m_Value(LHS), m_Value(RHS))))
5017             return false;
5018 
5019         if (!ICmpInst::isSigned(IPred))
5020             return false;
5021 
5022         if (match(LHS, m_Zero())) {
5023             IPred = ICmpInst::getSwappedPredicate(IPred);
5024             std::swap(LHS, RHS);
5025         }
5026 
5027         if (!match(RHS, m_Zero()))
5028             return false;
5029 
5030         if (match(TVal, m_Neg(m_Specific(FVal)))) {
5031             IPred = ICmpInst::getInversePredicate(IPred);
5032             std::swap(TVal, FVal);
5033         }
5034 
5035         if (!match(FVal, m_Neg(m_Specific(TVal))))
5036             return false;
5037 
5038         if (LHS != TVal)
5039             return false;
5040 
5041         source = TVal;
5042         mod = (IPred == ICmpInst::ICMP_SGT || IPred == ICmpInst::ICMP_SGE) ? EMOD_ABS : EMOD_NEGABS;
5043 
5044         return true;
5045     }
5046 
    // Checks whether `abs` computes an absolute value (or its negation).
    // Three shapes are recognized:
    //   - a call to the llvm.fabs intrinsic;
    //   - a select between x and its negation, keyed on a floating point
    //     ordering compare of x with zero;
    //   - the integer equivalent, delegated to isIntegerAbs().
    // On success `source` receives x, `mod` receives EMOD_ABS or EMOD_NEGABS,
    // and true is returned. Returns false otherwise.
    bool isAbs(llvm::Value* abs, e_modifier& mod, llvm::Value*& source)
    {
        bool found = false;

        // Direct fabs intrinsic.
        if (IntrinsicInst * intrinsicInst = dyn_cast<IntrinsicInst>(abs))
        {
            if (intrinsicInst->getIntrinsicID() == Intrinsic::fabs)
            {
                source = intrinsicInst->getOperand(0);
                mod = EMOD_ABS;
                return true;
            }
        }

        // Everything else must be a select.
        llvm::SelectInst* select = llvm::dyn_cast<llvm::SelectInst>(abs);
        if (!select)
            return false;

        // Try to find floating point abs first
        if (llvm::FCmpInst * cmp = llvm::dyn_cast<llvm::FCmpInst>(select->getOperand(0)))
        {
            llvm::CmpInst::Predicate pred = cmp->getPredicate();
            if (isGreaterOrLowerPredicate(pred))
            {
                // Find which compare operand is the zero constant; only the
                // first one found is considered (see the break below).
                for (int zeroIndex = 0; zeroIndex < 2; zeroIndex++)
                {
                    llvm::ConstantFP* zero = llvm::dyn_cast<llvm::ConstantFP>(cmp->getOperand(zeroIndex));
                    if (zero && zero->isZero())
                    {
                        // The value compared against zero.
                        llvm::Value* cmpSource = cmp->getOperand(1 - zeroIndex);
                        // Find which select arm is that value; the other arm
                        // must then be its negation.
                        for (int sourceIndex = 0; sourceIndex < 2; sourceIndex++)
                        {
                            if (cmpSource == select->getOperand(1 + sourceIndex))
                            {
                                llvm::Instruction* opnd = llvm::dyn_cast<llvm::Instruction>(select->getOperand(1 + (1 - sourceIndex)));
                                llvm::Value* negateSource = NULL;
                                if (opnd && IsNegate(opnd, negateSource) && negateSource == cmpSource)
                                {
                                    found = true;
                                    source = cmpSource;
                                    // depending on the order source in cmp/select it can abs() or -abs()
                                    // Each of the three conditions flips the sign once: zero on the
                                    // left of the compare, a "lower" predicate, and the plain value
                                    // sitting in the select's false arm.
                                    bool isNegateAbs = (zeroIndex == 0) ^ isLowerPredicate(pred) ^ (sourceIndex == 1);
                                    mod = isNegateAbs ? EMOD_NEGABS : EMOD_ABS;
                                }
                                break;
                            }
                        }
                        break;
                    }
                }
            }
        }

        // If not found, try integer abs
        return found || isIntegerAbs(select, mod, source);
    }
5103 
5104     // combine two modifiers, this function is *not* communtative
CombineModifier(e_modifier mod1,e_modifier mod2)5105     e_modifier CombineModifier(e_modifier mod1, e_modifier mod2)
5106     {
5107         e_modifier mod = EMOD_NONE;
5108         switch (mod1)
5109         {
5110         case EMOD_ABS:
5111         case EMOD_NEGABS:
5112             mod = mod1;
5113             break;
5114         case EMOD_NEG:
5115             if (mod2 == EMOD_NEGABS)
5116             {
5117                 mod = EMOD_ABS;
5118             }
5119             else if (mod2 == EMOD_ABS)
5120             {
5121                 mod = EMOD_NEGABS;
5122             }
5123             else if (mod2 == EMOD_NEG)
5124             {
5125                 mod = EMOD_NONE;
5126             }
5127             else
5128             {
5129                 mod = EMOD_NEG;
5130             }
5131             break;
5132         default:
5133             mod = mod2;
5134         }
5135         return mod;
5136     }
5137 
GetModifier(llvm::Value & modifier,e_modifier & mod,llvm::Value * & source)5138     bool GetModifier(llvm::Value& modifier, e_modifier& mod, llvm::Value*& source)
5139     {
5140         mod = EMOD_NONE;
5141         if (llvm::Instruction * bin = llvm::dyn_cast<llvm::Instruction>(&modifier))
5142         {
5143             return GetModifier(*bin, mod, source);
5144         }
5145         return false;
5146     }
5147 
GetModifier(llvm::Instruction & modifier,e_modifier & mod,llvm::Value * & source)5148     bool GetModifier(llvm::Instruction& modifier, e_modifier& mod, llvm::Value*& source)
5149     {
5150         llvm::Value* modifierSource = NULL;
5151         mod = EMOD_NONE;
5152         if (IsNegate(&modifier, modifierSource))
5153         {
5154             e_modifier absModifier = EMOD_NONE;
5155             llvm::Value* absSource = NULL;
5156             if (isAbs(modifierSource, absModifier, absSource))
5157             {
5158                 source = absSource;
5159                 mod = IGC::CombineModifier(EMOD_NEG, absModifier);
5160             }
5161             else
5162             {
5163                 source = modifierSource;
5164                 mod = EMOD_NEG;
5165             }
5166             return true;
5167         }
5168         else if (isAbs(&modifier, mod, modifierSource))
5169         {
5170             source = modifierSource;
5171             return true;
5172         }
5173         return false;
5174     }
5175 
IsNegate(llvm::Instruction * inst,llvm::Value * & negateSource)5176     bool IsNegate(llvm::Instruction* inst, llvm::Value*& negateSource)
5177     {
5178         BinaryOperator* binop = dyn_cast<BinaryOperator>(inst);
5179         if (binop &&
5180             (inst->getOpcode() == Instruction::FSub || inst->getOpcode() == Instruction::Sub))
5181         {
5182             if (IsZero(inst->getOperand(0)))
5183             {
5184                 negateSource = inst->getOperand(1);
5185                 return true;
5186             }
5187         }
5188 #if LLVM_VERSION_MAJOR >= 10
5189         UnaryOperator* unop = dyn_cast<UnaryOperator>(inst);
5190         if (unop && inst->getOpcode() == Instruction::FNeg)
5191         {
5192             negateSource = inst->getOperand(0);
5193             return true;
5194         }
5195 #endif
5196         return false;
5197     }
5198 
IsZero(llvm::Value * zero)5199     bool IsZero(llvm::Value* zero)
5200     {
5201         if (llvm::ConstantFP * FCst = llvm::dyn_cast<llvm::ConstantFP>(zero))
5202         {
5203             if (FCst->isZero())
5204             {
5205                 return true;
5206             }
5207         }
5208         if (llvm::ConstantInt * ICst = llvm::dyn_cast<llvm::ConstantInt>(zero))
5209         {
5210             if (ICst->isZero())
5211             {
5212                 return true;
5213             }
5214         }
5215         return false;
5216     }
5217 
isMinOrMax(llvm::Value * inst,llvm::Value * & source0,llvm::Value * & source1,bool & isMin,bool & isUnsigned)5218     inline bool isMinOrMax(llvm::Value* inst, llvm::Value*& source0, llvm::Value*& source1, bool& isMin, bool& isUnsigned)
5219     {
5220         bool found = false;
5221         llvm::Instruction* max = llvm::dyn_cast<llvm::Instruction>(inst);
5222         if (!max)
5223             return false;
5224 
5225         EOPCODE op = GetOpCode(max);
5226         if (op == llvm_min || op == llvm_max)
5227         {
5228             source0 = max->getOperand(0);
5229             source1 = max->getOperand(1);
5230             isUnsigned = false;
5231             isMin = (op == llvm_min);
5232             return true;
5233         }
5234         else if (op == llvm_select)
5235         {
5236             if (llvm::CmpInst * cmp = llvm::dyn_cast<llvm::CmpInst>(max->getOperand(0)))
5237             {
5238                 if (isGreaterOrLowerPredicate(cmp->getPredicate()))
5239                 {
5240                     if ((cmp->getOperand(0) == max->getOperand(1) && cmp->getOperand(1) == max->getOperand(2)) ||
5241                         (cmp->getOperand(0) == max->getOperand(2) && cmp->getOperand(1) == max->getOperand(1)))
5242                     {
5243                         source0 = max->getOperand(1);
5244                         source1 = max->getOperand(2);
5245                         isMin = isLowerPredicate(cmp->getPredicate()) ^ (cmp->getOperand(0) == max->getOperand(2));
5246                         isUnsigned = IsUnsignedCmp(cmp->getPredicate());
5247                         found = true;
5248                     }
5249                 }
5250             }
5251         }
5252         return found;
5253     }
5254 
isMax(llvm::Value * max,llvm::Value * & source0,llvm::Value * & source1)5255     inline bool isMax(llvm::Value* max, llvm::Value*& source0, llvm::Value*& source1)
5256     {
5257         bool isMin, isUnsigned;
5258         llvm::Value* maxSource0;
5259         llvm::Value* maxSource1;
5260         if (isMinOrMax(max, maxSource0, maxSource1, isMin, isUnsigned))
5261         {
5262             if (!isMin)
5263             {
5264                 source0 = maxSource0;
5265                 source1 = maxSource1;
5266                 return true;
5267             }
5268         }
5269         return false;
5270     }
5271 
isMin(llvm::Value * min,llvm::Value * & source0,llvm::Value * & source1)5272     inline bool isMin(llvm::Value* min, llvm::Value*& source0, llvm::Value*& source1)
5273     {
5274         bool isMin, isUnsigned;
5275         llvm::Value* maxSource0;
5276         llvm::Value* maxSource1;
5277         if (isMinOrMax(min, maxSource0, maxSource1, isMin, isUnsigned))
5278         {
5279             if (isMin)
5280             {
5281                 source0 = maxSource0;
5282                 source1 = maxSource1;
5283                 return true;
5284             }
5285         }
5286         return false;
5287     }
5288 
5289 
isOne(llvm::Value * zero)5290     bool isOne(llvm::Value* zero)
5291     {
5292         if (llvm::ConstantFP * FCst = llvm::dyn_cast<llvm::ConstantFP>(zero))
5293         {
5294             if (FCst->isExactlyValue(1.f))
5295             {
5296                 return true;
5297             }
5298         }
5299         if (llvm::ConstantInt * ICst = llvm::dyn_cast<llvm::ConstantInt>(zero))
5300         {
5301             if (ICst->isOne())
5302             {
5303                 return true;
5304             }
5305         }
5306         return false;
5307     }
5308 
isSat(llvm::Instruction * sat,llvm::Value * & source,bool & isUnsigned)5309     bool isSat(llvm::Instruction* sat, llvm::Value*& source, bool& isUnsigned)
5310     {
5311         bool found = false;
5312         llvm::Value* sources[2] = { 0 };
5313         bool floatMatch = sat->getType()->isFloatingPointTy();
5314         GenIntrinsicInst* intrin = dyn_cast<GenIntrinsicInst>(sat);
5315         if (intrin &&
5316             (intrin->getIntrinsicID() == GenISAIntrinsic::GenISA_fsat ||
5317             intrin->getIntrinsicID() == GenISAIntrinsic::GenISA_usat ||
5318             intrin->getIntrinsicID() == GenISAIntrinsic::GenISA_isat))
5319         {
5320             source = intrin->getOperand(0);
5321             found = true;
5322             isUnsigned = intrin->getIntrinsicID() == GenISAIntrinsic::GenISA_usat;
5323         }
5324         else if (floatMatch && isMax(sat, sources[0], sources[1]))
5325         {
5326             for (int i = 0; i < 2; i++)
5327             {
5328                 if (IsZero(sources[i]))
5329                 {
5330                     llvm::Value* maxSources[2] = { 0 };
5331                     if (isMin(sources[1 - i], maxSources[0], maxSources[1]))
5332                     {
5333                         for (int j = 0; j < 2; j++)
5334                         {
5335                             if (isOne(maxSources[j]))
5336                             {
5337                                 found = true;
5338                                 source = maxSources[1 - j];
5339                                 isUnsigned = false;
5340                                 break;
5341                             }
5342                         }
5343                     }
5344                     break;
5345                 }
5346             }
5347         }
5348         else if (floatMatch && isMin(sat, sources[0], sources[1]))
5349         {
5350             for (int i = 0; i < 2; i++)
5351             {
5352                 if (isOne(sources[i]))
5353                 {
5354                     llvm::Value* maxSources[2] = { 0 };
5355                     if (isMax(sources[1 - i], maxSources[0], maxSources[1]))
5356                     {
5357                         for (int j = 0; j < 2; j++)
5358                         {
5359                             if (IsZero(maxSources[j]))
5360                             {
5361                                 found = true;
5362                                 source = maxSources[1 - j];
5363                                 isUnsigned = false;
5364                                 break;
5365                             }
5366                         }
5367                     }
5368                     break;
5369                 }
5370             }
5371         }
5372         return found;
5373     }
5374 
isCandidateForConstantPool(llvm::Value * val)5375     bool isCandidateForConstantPool(llvm::Value * val)
5376     {
5377         auto ci = dyn_cast<ConstantInt>(val);
5378         bool isBigQW = ci && !ci->getValue().isNullValue() && !ci->getValue().isSignedIntN(32);
5379         bool isDF = val->getType()->isDoubleTy();
5380         return (isBigQW || isDF);
5381     }
5382 
GetBlockId(llvm::BasicBlock * block)5383     uint CodeGenPatternMatch::GetBlockId(llvm::BasicBlock* block)
5384     {
5385         auto it = m_blockMap.find(block);
5386         IGC_ASSERT(it != m_blockMap.end());
5387 
5388         uint blockID = it->second->id;
5389         return blockID;
5390     }
5391 }//namespace IGC
5392