1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 //===----------------------------------------------------------------------===//
10 //
11 /// CMABI
12 /// -----
13 ///
14 /// This pass fixes ABI issues for the genx backend. Currently, it
15 ///
16 /// - transforms pass by pointer argument into copy-in and copy-out;
17 ///
18 /// - localizes global scalar or vector variables into copy-in and copy-out;
19 ///
/// - passes bool arguments as i8 (matches cm-icl's behavior).
21 ///
22 //===----------------------------------------------------------------------===//
23 
24 #define DEBUG_TYPE "cmabi"
25 
26 #include "llvmWrapper/Analysis/CallGraph.h"
27 #include "llvmWrapper/IR/CallSite.h"
28 #include "llvmWrapper/IR/DerivedTypes.h"
29 #include "llvmWrapper/IR/Instructions.h"
30 #include "llvmWrapper/Support/Alignment.h"
31 
32 #include "Probe/Assertion.h"
33 
34 #include "vc/GenXOpts/GenXOpts.h"
35 #include "vc/GenXOpts/Utils/KernelInfo.h"
36 #include "vc/Support/BackendConfig.h"
37 #include "vc/Utils/GenX/BreakConst.h"
38 #include "vc/Utils/GenX/Printf.h"
39 #include "vc/Utils/General/DebugInfo.h"
40 #include "vc/Utils/General/FunctionAttrs.h"
41 #include "vc/Utils/General/InstRebuilder.h"
42 #include "vc/Utils/General/STLExtras.h"
43 #include "vc/Utils/General/Types.h"
44 
45 #include "llvm/ADT/DenseMap.h"
46 #include "llvm/ADT/PostOrderIterator.h"
47 #include "llvm/ADT/SCCIterator.h"
48 #include "llvm/ADT/STLExtras.h"
49 #include "llvm/ADT/SetVector.h"
50 #include "llvm/ADT/Statistic.h"
51 #include "llvm/Analysis/CallGraphSCCPass.h"
52 #include "llvm/Analysis/PostDominators.h"
53 #include "llvm/GenXIntrinsics/GenXIntrinsics.h"
54 #include "llvm/GenXIntrinsics/GenXMetadata.h"
55 #include "llvm/IR/CFG.h"
56 #include "llvm/IR/DebugInfo.h"
57 #include "llvm/IR/DiagnosticInfo.h"
58 #include "llvm/IR/DiagnosticPrinter.h"
59 #include "llvm/IR/Dominators.h"
60 #include "llvm/IR/Function.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/IRBuilder.h"
63 #include "llvm/IR/InstIterator.h"
64 #include "llvm/IR/InstVisitor.h"
65 #include "llvm/IR/IntrinsicInst.h"
66 #include "llvm/IR/Intrinsics.h"
67 #include "llvm/IR/Module.h"
68 #include "llvm/IR/Verifier.h"
69 #include "llvm/InitializePasses.h"
70 #include "llvm/Pass.h"
71 #include "llvm/Support/CommandLine.h"
72 #include "llvm/Support/Debug.h"
73 #include "llvm/Support/raw_ostream.h"
74 #include "llvm/Transforms/Scalar.h"
75 #include "llvm/Transforms/Utils/Local.h"
76 
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
85 
86 using namespace llvm;
87 
88 STATISTIC(NumArgumentsTransformed, "Number of pointer arguments transformed");
89 
90 namespace llvm {
91 void initializeCMABIAnalysisPass(PassRegistry &);
92 void initializeCMABIPass(PassRegistry &);
93 void initializeCMLowerVLoadVStorePass(PassRegistry &);
94 }
95 
96 /// Localizing global variables
97 /// ^^^^^^^^^^^^^^^^^^^^^^^^^^^
98 ///
/// General idea of localizing global variables into locals. Globals used in
/// different kernels get a separate copy and they are always invisible to
/// other kernels, so we can safely localize all globals used (including
/// indirectly) in a kernel. For example,
103 ///
104 /// .. code-block:: text
105 ///
106 ///   @gv1 = global <8 x float> zeroinitializer, align 32
107 ///   @gv2 = global <8 x float> zeroinitializer, align 32
108 ///   @gv3 = global <8 x float> zeroinitializer, align 32
109 ///
110 ///   define dllexport void @f0() {
111 ///     call @f1()
112 ///     call @f2()
113 ///     call @f3()
114 ///   }
115 ///
116 ///   define internal void @f1() {
117 ///     ; ...
118 ///     store <8 x float> %splat1, <8 x float>* @gv1, align 32
119 ///   }
120 ///
121 ///   define internal void @f2() {
122 ///     ; ...
123 ///     store <8 x float> %splat2, <8 x float>* @gv2, align 32
124 ///   }
125 ///
126 ///   define internal void @f3() {
127 ///     %1 = <8 x float>* @gv1, align 32
128 ///     %2 = <8 x float>* @gv2, align 32
129 ///     %3 = fadd <8 x float> %1, <8 x float> %2
130 ///     store <8 x float> %3, <8 x float>* @gv3, align 32
131 ///   }
132 ///
133 /// will be transformed into
134 ///
135 /// .. code-block:: text
136 ///
137 ///   define dllexport void @f0() {
138 ///     %v1 = alloca <8 x float>, align 32
139 ///     %v2 = alloca <8 x float>, align 32
140 ///     %v3 = alloca <8 x float>, align 32
141 ///
142 ///     %0 = load <8 x float> * %v1, align 32
143 ///     %1 = { <8 x float> } call @f1_transformed(<8 x float> %0)
144 ///     %2 = extractvalue { <8 x float> } %1, 0
145 ///     store <8  x float> %2, <8 x float>* %v1, align 32
146 ///
147 ///     %3 = load <8 x float> * %v2, align 32
148 ///     %4 = { <8 x float> } call @f2_transformed(<8 x float> %3)
149 ///     %5 = extractvalue { <8 x float> } %4, 0
150 ///     store <8  x float> %5, <8 x float>* %v1, align 32
151 ///
152 ///     %6 = load <8 x float> * %v1, align 32
153 ///     %7 = load <8 x float> * %v2, align 32
154 ///     %8 = load <8 x float> * %v3, align 32
155 ///
156 ///     %9 = { <8 x float>, <8 x float>, <8 x float> }
157 ///          call @f3_transformed(<8 x float> %6, <8 x float> %7, <8 x float> %8)
158 ///
159 ///     %10 = extractvalue { <8 x float>, <8 x float>, <8 x float> } %9, 0
160 ///     store <8  x float> %10, <8 x float>* %v1, align 32
161 ///     %11 = extractvalue { <8 x float>, <8 x float>, <8 x float> } %9, 1
162 ///     store <8  x float> %11, <8 x float>* %v2, align 32
163 ///     %12 = extractvalue { <8 x float>, <8 x float>, <8 x float> } %9, 2
164 ///     store <8  x float> %12, <8 x float>* %v3, align 32
165 ///   }
166 ///
167 /// All callees will be updated accordingly, E.g. f1_transformed becomes
168 ///
169 /// .. code-block:: text
170 ///
171 ///   define internal { <8 x float> } @f1_transformed(<8 x float> %v1) {
172 ///     %0 = alloca <8 x float>, align 32
173 ///     store <8 x float> %v1, <8 x float>* %0, align 32
174 ///     ; ...
175 ///     store <8 x float> %splat1, <8 x float>* @0, align 32
176 ///     ; ...
177 ///     %1 = load <8 x float>* %0, align 32
178 ///     %2 = insertvalue { <8 x float> } undef, <8 x float> %1, 0
179 ///     ret { <8 x float> } %2
180 ///   }
181 ///
182 namespace {
183 
184 // \brief Collect necessary information for global variable localization.
185 class LocalizationInfo {
186 public:
187   typedef SetVector<GlobalVariable *> GlobalSetTy;
188 
LocalizationInfo(Function * F)189   explicit LocalizationInfo(Function *F) : Fn(F) {}
LocalizationInfo()190   LocalizationInfo() : Fn(0) {}
191 
getFunction() const192   Function *getFunction() const { return Fn; }
empty() const193   bool empty() const { return Globals.empty(); }
getGlobals()194   GlobalSetTy &getGlobals() { return Globals; }
195 
196   // \brief Add a global.
addGlobal(GlobalVariable * GV)197   void addGlobal(GlobalVariable *GV) {
198     Globals.insert(GV);
199   }
200 
201   // \brief Add all globals from callee.
addGlobals(LocalizationInfo & LI)202   void addGlobals(LocalizationInfo &LI) {
203     Globals.insert(LI.getGlobals().begin(), LI.getGlobals().end());
204   }
205 
206 private:
207   // \brief The function being analyzed.
208   Function *Fn;
209 
210   // \brief Global variables that are used directly or indirectly.
211   GlobalSetTy Globals;
212 };
213 
214 // Diagnostic information for error/warning for overlapping arg
215 class DiagnosticInfoOverlappingArgs : public DiagnosticInfo {
216 private:
217   std::string Description;
218   StringRef Filename;
219   unsigned Line;
220   unsigned Col;
221   static int KindID;
getKindID()222   static int getKindID() {
223     if (KindID == 0)
224       KindID = llvm::getNextAvailablePluginDiagnosticKind();
225     return KindID;
226   }
227 public:
228   // Initialize from an Instruction and an Argument.
229   DiagnosticInfoOverlappingArgs(Instruction *Inst,
230       const Twine &Desc, DiagnosticSeverity Severity = DS_Error);
231   void print(DiagnosticPrinter &DP) const override;
232 
classof(const DiagnosticInfo * DI)233   static bool classof(const DiagnosticInfo *DI) {
234     return DI->getKind() == getKindID();
235   }
236 };
237 int DiagnosticInfoOverlappingArgs::KindID = 0;
238 
239 class CMABIAnalysis : public ModulePass {
240   // This map captures all global variables to be localized.
241   std::vector<LocalizationInfo *> LocalizationInfoObjs;
242   bool SaveStackCallLinkage = false;
243 
244 public:
245   static char ID;
246 
247   // Kernels in the module being processed.
248   SmallPtrSet<Function *, 8> Kernels;
249 
250   // Map from function to the index of its LI in LI storage
251   SmallDenseMap<Function *, LocalizationInfo *> GlobalInfo;
252 
253   // Function control option if any
254   FunctionControl FCtrl;
255 
CMABIAnalysis()256   CMABIAnalysis() : ModulePass{ID} {}
257 
getAnalysisUsage(AnalysisUsage & AU) const258   void getAnalysisUsage(AnalysisUsage &AU) const override {
259     AU.addRequired<CallGraphWrapperPass>();
260     AU.addRequired<GenXBackendConfig>();
261     AU.setPreservesAll();
262   }
263 
getPassName() const264   StringRef getPassName() const override { return "GenX CMABI analysis"; }
265 
266   bool runOnModule(Module &M) override;
267 
releaseMemory()268   void releaseMemory() override {
269     for (auto *LI : LocalizationInfoObjs)
270       delete LI;
271     LocalizationInfoObjs.clear();
272     Kernels.clear();
273     GlobalInfo.clear();
274   }
275 
276   // \brief Returns the localization info associated to a function.
getLocalizationInfo(Function * F)277   LocalizationInfo &getLocalizationInfo(Function *F) {
278     if (GlobalInfo.count(F))
279       return *GlobalInfo[F];
280     LocalizationInfo *LI = new LocalizationInfo{F};
281     LocalizationInfoObjs.push_back(LI);
282     GlobalInfo[F] = LI;
283     return *LI;
284   }
285 
286 private:
287   bool runOnCallGraph(CallGraph &CG);
288   void analyzeGlobals(CallGraph &CG);
289 
addDirectGlobal(Function * F,GlobalVariable * GV)290   void addDirectGlobal(Function *F, GlobalVariable *GV) {
291     getLocalizationInfo(F).addGlobal(GV);
292   }
293 
294   // \brief Add all globals from callee to caller.
addIndirectGlobal(Function * F,Function * Callee)295   void addIndirectGlobal(Function *F, Function *Callee) {
296     getLocalizationInfo(F).addGlobals(getLocalizationInfo(Callee));
297   }
298 
299   void defineGVDirectUsers(GlobalVariable &GV);
300 };
301 
// The transformation pass proper: runs over each SCC of the call graph and
// rewrites function signatures based on what CMABIAnalysis collected.
struct CMABI : public CallGraphSCCPass {
  static char ID;

CMABI__anon278d8c6d0111::CMABI305   CMABI() : CallGraphSCCPass(ID) {
    initializeCMABIPass(*PassRegistry::getPassRegistry());
  }

getAnalysisUsage__anon278d8c6d0111::CMABI309   void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
    AU.addRequired<CMABIAnalysis>();
  }

  bool runOnSCC(CallGraphSCC &SCC) override;

private:

  // Process one call-graph node; returns the replacement node when the
  // function was rewritten, null when nothing changed.
  CallGraphNode *ProcessNode(CallGraphNode *CGN);

  // Fix argument passing for kernels.
  CallGraphNode *TransformKernel(Function *F);

  // Major work is done in this method.
  CallGraphNode *TransformNode(Function &F,
                               SmallPtrSet<Argument *, 8> &ArgsToTransform,
                               LocalizationInfo &LI);

  // \brief Create allocas for globals and replace their uses.
  void LocalizeGlobals(LocalizationInfo &LI);

  // Return true if pointer type argument is only used to
  // load or store a simple value. This helps decide whether
  // it is safe to convert ptr arg to by-value arg or
  // simple-value copy-in-copy-out.
  bool OnlyUsedBySimpleValueLoadStore(Value *Arg);

  // \brief Diagnose illegal overlapping by-ref args.
  void diagnoseOverlappingArgs(CallInst *CI);

  // Already visited functions.
  SmallPtrSet<Function *, 8> AlreadyVisited;
  // Analysis results; owned by the CMABIAnalysis pass instance.
  CMABIAnalysis *Info;
};
344 
345 } // namespace
346 
char CMABIAnalysis::ID = 0;
// Register the analysis pass and its required dependencies with the LLVM
// pass registry.
INITIALIZE_PASS_BEGIN(CMABIAnalysis, "cmabi-analysis",
                      "helper analysis pass to get info for CMABI", false, true)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)350 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GenXBackendConfig)
INITIALIZE_PASS_END(CMABIAnalysis, "cmabi-analysis",
                    "Fix ABI issues for the genx backend", false, true)
354 
355 bool CMABIAnalysis::runOnModule(Module &M) {
356   auto &&BCfg = getAnalysis<GenXBackendConfig>();
357   FCtrl = BCfg.getFCtrl();
358   SaveStackCallLinkage = BCfg.saveStackCallLinkage();
359 
360   runOnCallGraph(getAnalysis<CallGraphWrapperPass>().getCallGraph());
361   return false;
362 }
363 
runOnCallGraph(CallGraph & CG)364 bool CMABIAnalysis::runOnCallGraph(CallGraph &CG) {
365   // Analyze global variable usages and for each function attaches global
366   // variables to be copy-in and copy-out.
367   analyzeGlobals(CG);
368 
369   auto getValue = [](Metadata *M) -> Value * {
370     if (auto VM = dyn_cast<ValueAsMetadata>(M))
371       return VM->getValue();
372     return nullptr;
373   };
374 
375   // Collect all CM kernels from named metadata.
376   if (NamedMDNode *Named =
377           CG.getModule().getNamedMetadata(genx::FunctionMD::GenXKernels)) {
378     IGC_ASSERT(Named);
379     for (unsigned I = 0, E = Named->getNumOperands(); I != E; ++I) {
380       MDNode *Node = Named->getOperand(I);
381       if (Function *F =
382               dyn_cast_or_null<Function>(getValue(Node->getOperand(0))))
383         Kernels.insert(F);
384     }
385   }
386 
387   // no change.
388   return false;
389 }
390 
runOnSCC(CallGraphSCC & SCC)391 bool CMABI::runOnSCC(CallGraphSCC &SCC) {
392   Info = &getAnalysis<CMABIAnalysis>();
393   bool Changed = false;
394   bool LocalChange;
395 
396   // Diagnose overlapping by-ref args.
397   for (auto i = SCC.begin(), e = SCC.end(); i != e; ++i) {
398     Function *F = (*i)->getFunction();
399     if (!F || F->empty())
400       continue;
401     for (auto ui = F->use_begin(), ue = F->use_end(); ui != ue; ++ui) {
402       auto CI = dyn_cast<CallInst>(ui->getUser());
403       if (CI && CI->getNumArgOperands() == ui->getOperandNo())
404         diagnoseOverlappingArgs(CI);
405     }
406   }
407 
408   // Iterate until we stop transforming from this SCC.
409   do {
410     LocalChange = false;
411     for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) {
412       if (CallGraphNode *CGN = ProcessNode(*I)) {
413         LocalChange = true;
414         SCC.ReplaceNode(*I, CGN);
415       }
416     }
417     Changed |= LocalChange;
418   } while (LocalChange);
419 
420   return Changed;
421 }
422 
// Replaces uses of global variables with the corresponding allocas inside a
// specified function. More instructions may need to be rebuilt if the global
// variable's addrspace wasn't private.
replaceUsesWithinFunction(const SmallDenseMap<Value *,Value * > & GlobalsToReplace,Function * F)426 static void replaceUsesWithinFunction(
427     const SmallDenseMap<Value *, Value *> &GlobalsToReplace, Function *F) {
428   for (auto &BB : *F) {
429     for (auto &Inst : BB) {
430       for (unsigned i = 0, e = Inst.getNumOperands(); i < e; ++i) {
431         Value *Op = Inst.getOperand(i);
432         auto Iter = GlobalsToReplace.find(Op);
433         if (Iter != GlobalsToReplace.end()) {
434           IGC_ASSERT_MESSAGE(Op->getType() == Iter->second->getType(),
435                              "only global variables in private addrspace are "
436                              "localized, so types must match");
437           Inst.setOperand(i, Iter->second);
438         }
439       }
440     }
441   }
442 }
443 
444 // \brief Create allocas for globals directly used in this kernel and
445 // replace all uses.
446 //
// FIXME: it is not always possible to localize globals with addrspace different
448 // from private. In some cases type info link is lost - casts, stores of
449 // pointers.
LocalizeGlobals(LocalizationInfo & LI)450 void CMABI::LocalizeGlobals(LocalizationInfo &LI) {
451   const LocalizationInfo::GlobalSetTy &Globals = LI.getGlobals();
452   typedef LocalizationInfo::GlobalSetTy::const_iterator IteratorTy;
453 
454   SmallDenseMap<Value *, Value *> GlobalsToReplace;
455   Function *Fn = LI.getFunction();
456   for (IteratorTy I = Globals.begin(), E = Globals.end(); I != E; ++I) {
457     GlobalVariable *GV = (*I);
458     LLVM_DEBUG(dbgs() << "Localizing global: " << *GV << "\n  ");
459 
460     Instruction &FirstI = *Fn->getEntryBlock().begin();
461     Type *ElemTy = GV->getType()->getElementType();
462     AllocaInst *Alloca = new AllocaInst(ElemTy, vc::AddrSpace::Private,
463                                         GV->getName() + ".local", &FirstI);
464 
465     if (GV->getAlignment())
466       Alloca->setAlignment(IGCLLVM::getCorrectAlign(GV->getAlignment()));
467 
468     if (!isa<UndefValue>(GV->getInitializer()))
469       new StoreInst(GV->getInitializer(), Alloca, &FirstI);
470 
471     vc::DIBuilder::createDbgDeclareForLocalizedGlobal(*Alloca, *GV, FirstI);
472 
473     GlobalsToReplace.insert(std::make_pair(GV, Alloca));
474   }
475 
476   // Replaces all globals uses within this function.
477   replaceUsesWithinFunction(GlobalsToReplace, Fn);
478 }
479 
// Dispatch one call-graph node. Kernels get their globals localized and, if
// they have i1/vxi1 arguments, a prototype fix-up via TransformKernel.
// Subroutines may additionally get pointer arguments turned into
// copy-in/copy-out values via TransformNode. Returns the replacement
// CallGraphNode, or null when the prototype did not change.
ProcessNode(CallGraphNode * CGN)480 CallGraphNode *CMABI::ProcessNode(CallGraphNode *CGN) {
  Function *F = CGN->getFunction();

  // Nothing to do for declarations or already visited functions.
  if (!F || F->isDeclaration() || AlreadyVisited.count(F))
    return 0;

  vc::breakConstantExprs(F, vc::LegalizationStage::NotLegalized);

  // Variables to be localized.
  LocalizationInfo &LI = Info->getLocalizationInfo(F);

  // This is a kernel.
  if (Info->Kernels.count(F)) {
    // Localize globals for kernels.
    if (!LI.getGlobals().empty())
      LocalizeGlobals(LI);

    // Check whether there are i1 or vxi1 kernel arguments.
    for (auto AI = F->arg_begin(), AE = F->arg_end(); AI != AE; ++AI)
      if (AI->getType()->getScalarType()->isIntegerTy(1))
        return TransformKernel(F);

    // No changes to this kernel's prototype.
    return 0;
  }

  // Have to localize implicit arg globals in functions with fixed signature.
  // FIXME: There's no verification that globals are for implicit args. General
  //        private globals may be localized here, but it is not possible to
  //        use them in such functions at all. A nice place for diagnostics.
  if (vc::isFixedSignatureFunc(*F)) {
    if (!LI.getGlobals().empty())
      LocalizeGlobals(LI);
    return nullptr;
  }

  // Collect every pointer-typed argument as a transformation candidate.
  SmallVector<Argument*, 16> PointerArgs;
  for (auto &Arg: F->args())
    if (Arg.getType()->isPointerTy())
      PointerArgs.push_back(&Arg);

  // Check if there is any pointer arguments or globals to localize.
  if (PointerArgs.empty() && LI.empty())
    return 0;

  // Check transformable arguments.
  SmallPtrSet<Argument*, 8> ArgsToTransform;
  for (Argument *PtrArg: PointerArgs) {
    Type *ArgTy = cast<PointerType>(PtrArg->getType())->getElementType();
    // Only transform to simple types.
    if ((ArgTy->isVectorTy() || OnlyUsedBySimpleValueLoadStore(PtrArg)) &&
        (ArgTy->isIntOrIntVectorTy() || ArgTy->isFPOrFPVectorTy()))
      ArgsToTransform.insert(PtrArg);
  }

  if (ArgsToTransform.empty() && LI.empty())
    return 0;

  return TransformNode(*F, ArgsToTransform, LI);
}
541 
542 // Returns true if data is only read using load-like intrinsics. The result may
543 // be false negative.
isSinkedToLoadIntrinsics(const Instruction * Inst)544 static bool isSinkedToLoadIntrinsics(const Instruction *Inst) {
545   if (isa<CallInst>(Inst)) {
546     auto *CI = cast<CallInst>(Inst);
547     auto IID = GenXIntrinsic::getAnyIntrinsicID(CI->getCalledFunction());
548     return IID == GenXIntrinsic::genx_svm_gather ||
549            IID == GenXIntrinsic::genx_gather_scaled;
550   }
551   return std::all_of(Inst->user_begin(), Inst->user_end(), [](const User *U) {
552     if (isa<InsertElementInst>(U) || isa<ShuffleVectorInst>(U) ||
553         isa<BinaryOperator>(U) || isa<CallInst>(U))
554       return isSinkedToLoadIntrinsics(cast<Instruction>(U));
555     return false;
556   });
557 }
558 
559 // Arg is a ptr to a vector type. If data is only read using load, then false is
560 // returned. Otherwise, or if it is not clear, true is returned. This is a
561 // recursive function. The result may be false positive.
isPtrArgModified(const Value & Arg)562 static bool isPtrArgModified(const Value &Arg) {
563   // User iterator returns pointer both for star and arrow operators, because...
564   return std::any_of(Arg.user_begin(), Arg.user_end(), [](const User *U) {
565     if (isa<LoadInst>(U))
566       return false;
567     if (isa<AddrSpaceCastInst>(U) || isa<BitCastInst>(U) ||
568         isa<GetElementPtrInst>(U))
569       return isPtrArgModified(*U);
570     if (isa<PtrToIntInst>(U))
571       return !isSinkedToLoadIntrinsics(cast<Instruction>(U));
572     return true;
573   });
574 }
575 
OnlyUsedBySimpleValueLoadStore(Value * Arg)576 bool CMABI::OnlyUsedBySimpleValueLoadStore(Value *Arg) {
577   for (const auto &U : Arg->users()) {
578     auto *I = dyn_cast<Instruction>(U);
579     if (!I)
580       return false;
581 
582     if (auto LI = dyn_cast<LoadInst>(U)) {
583       if (Arg != LI->getPointerOperand())
584         return false;
585     }
586     else if (auto SI = dyn_cast<StoreInst>(U)) {
587       if (Arg != SI->getPointerOperand())
588         return false;
589     }
590     else if (auto GEP = dyn_cast<GetElementPtrInst>(U)) {
591       if (Arg != GEP->getPointerOperand())
592         return false;
593       else if (!GEP->hasAllZeroIndices())
594         return false;
595       if (!OnlyUsedBySimpleValueLoadStore(U))
596         return false;
597     }
598     else if (isa<AddrSpaceCastInst>(U) || isa<PtrToIntInst>(U)) {
599       if (!OnlyUsedBySimpleValueLoadStore(U))
600         return false;
601     }
602     else
603       return false;
604   }
605   return true;
606 }
607 
// \brief Fix argument passing for kernels: i1 -> i8.
// Builds a new function type with every i1/vxi1 argument widened to
// i8/vxi8, clones attributes for unchanged arguments, splices the body
// over, truncates the widened arguments back for existing users, and
// updates the call graph and kernel metadata.
TransformKernel(Function * F)609 CallGraphNode *CMABI::TransformKernel(Function *F) {
  IGC_ASSERT(F->getReturnType()->isVoidTy());
  LLVMContext &Context = F->getContext();

  AttributeList AttrVec;
  const AttributeList &PAL = F->getAttributes();

  // First, determine the new argument list
  SmallVector<Type *, 8> ArgTys;
  unsigned ArgIndex = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgIndex) {
    Type *ArgTy = I->getType();
    // Change i1 to i8 and vxi1 to vxi8
    if (ArgTy->getScalarType()->isIntegerTy(1)) {
      Type *Ty = IntegerType::get(F->getContext(), 8);
      if (ArgTy->isVectorTy())
        ArgTys.push_back(IGCLLVM::FixedVectorType::get(
            Ty, dyn_cast<IGCLLVM::FixedVectorType>(ArgTy)->getNumElements()));
      else
        ArgTys.push_back(Ty);
    } else {
      // Unchanged argument
      AttributeSet attrs = PAL.getParamAttributes(ArgIndex);
      if (attrs.hasAttributes()) {
        AttrBuilder B(attrs);
        AttrVec = AttrVec.addParamAttributes(Context, ArgTys.size(), B);
      }
      ArgTys.push_back(I->getType());
    }
  }

  FunctionType *NFTy = FunctionType::get(F->getReturnType(), ArgTys, false);
  IGC_ASSERT_MESSAGE((NFTy != F->getFunctionType()),
    "type out of sync, expect bool arguments");

  // Add any function attributes.
  AttributeSet FnAttrs = PAL.getFnAttributes();
  if (FnAttrs.hasAttributes()) {
    AttrBuilder B(FnAttrs);
    AttrVec = AttrVec.addAttributes(Context, AttributeList::FunctionIndex, B);
  }

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());

  LLVM_DEBUG(dbgs() << "\nCMABI: Transforming From:" << *F);
  vc::transferNameAndCCWithNewAttr(AttrVec, *F, *NF);
  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  vc::transferDISubprogram(*F, *NF);
  LLVM_DEBUG(dbgs() << "  --> To: " << *NF << "\n");

  // Since we have now created the new function, splice the body of the old
  // function right into the new function.
  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

  // Loop over the argument list, transferring uses of the old arguments over to
  // the new arguments, also transferring over the names as well.
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
                              I2 = NF->arg_begin();
       I != E; ++I, ++I2) {
    // For an unmodified argument, move the name and users over.
    if (!I->getType()->getScalarType()->isIntegerTy(1)) {
      I->replaceAllUsesWith(I2);
      I2->takeName(I);
    } else {
      // The argument now arrives as i8/vxi8: truncate it back to i1/vxi1 at
      // the top of the entry block so the original users still see bools.
      Instruction *InsertPt = &*(NF->begin()->begin());
      Instruction *Conv = new TruncInst(I2, I->getType(), "tobool", InsertPt);
      I->replaceAllUsesWith(Conv);
      I2->takeName(I);
    }
  }

  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);

  // Update the metadata entry.
  if (F->hasDLLExportStorageClass())
    NF->setDLLStorageClass(F->getDLLStorageClass());

  genx::replaceFunctionRefMD(*F, *NF);

  // Now that the old function is dead, delete it. If there is a dangling
  // reference to the CallgraphNode, just leave the dead function around.
  NF_CGN->stealCalledFunctionsFrom(CG[F]);
  CallGraphNode *CGN = CG[F];
  if (CGN->getNumReferences() == 0)
    delete CG.removeFunctionFromModule(CGN);
  else
    F->setLinkage(Function::ExternalLinkage);

  return NF_CGN;
}
702 
703 namespace {
// New signature of a transformed function: the components of the aggregate
// return value (original return plus copy-out values) and the new argument
// types.
struct TransformedFuncType {
  SmallVector<Type *, 8> Ret;
  SmallVector<Type *, 8> Args;
};
708 
// How an original argument is passed after the transformation.
enum class ArgKind { General, CopyIn, CopyInOut };
// How a localized global is communicated to the transformed function.
enum class GlobalArgKind { ByValueIn, ByValueInOut, ByPointer };

// A localized global together with its passing convention.
struct GlobalArgInfo {
  GlobalVariable *GV;
  GlobalArgKind Kind;
};
716 
717 struct GlobalArgsInfo {
718   static constexpr int UndefIdx = -1;
719   std::vector<GlobalArgInfo> Globals;
720   int FirstGlobalArgIdx = UndefIdx;
721 
getGlobalInfoForArgNo__anon278d8c6d0511::GlobalArgsInfo722   GlobalArgInfo getGlobalInfoForArgNo(int ArgIdx) const {
723     IGC_ASSERT_MESSAGE(FirstGlobalArgIdx != UndefIdx,
724                        "first global arg index isn't set");
725     auto Idx = ArgIdx - FirstGlobalArgIdx;
726     IGC_ASSERT_MESSAGE(Idx >= 0, "out of bound access");
727     IGC_ASSERT_MESSAGE(Idx < static_cast<int>(Globals.size()),
728                        "out of bound access");
729     return Globals[ArgIdx - FirstGlobalArgIdx];
730   }
731 
getGlobalForArgNo__anon278d8c6d0511::GlobalArgsInfo732   GlobalVariable *getGlobalForArgNo(int ArgIdx) const {
733     return getGlobalInfoForArgNo(ArgIdx).GV;
734   }
735 };
736 
// Maps each component of the new aggregate return value to the index of the
// argument it copies out to; OrigRetNoArg marks the slot that carries the
// original return value.
struct RetToArgInfo {
  static constexpr int OrigRetNoArg = -1;
  std::vector<int> Map;
};
741 
742 // Whether provided \p GV should be passed by pointer.
passLocalizedGlobalByPointer(const GlobalValue & GV)743 static bool passLocalizedGlobalByPointer(const GlobalValue &GV) {
744   auto *Type = GV.getType()->getPointerElementType();
745   return Type->isAggregateType();
746 }
747 
// A parameter attribute that had to be dropped from a transformed function,
// remembered so call sites can be updated to match.
struct ParameterAttrInfo {
  unsigned ArgIndex;
  Attribute::AttrKind Attr;
};
752 
753 // Computing a new prototype for the function. E.g.
754 //
755 // i32 @foo(i32, <8 x i32>*) becomes {i32, <8 x i32>} @bar(i32, <8 x i32>)
756 //
757 class TransformedFuncInfo {
758   TransformedFuncType NewFuncType;
759   AttributeList Attrs;
760   std::vector<ArgKind> ArgKinds;
761   std::vector<ParameterAttrInfo> DiscardedParameterAttrs;
762   RetToArgInfo RetToArg;
763   GlobalArgsInfo GlobalArgs;
764 
765 public:
TransformedFuncInfo(Function & OrigFunc,SmallPtrSetImpl<Argument * > & ArgsToTransform)766   TransformedFuncInfo(Function &OrigFunc,
767                       SmallPtrSetImpl<Argument *> &ArgsToTransform) {
768     FillCopyInOutInfo(OrigFunc, ArgsToTransform);
769     std::transform(OrigFunc.arg_begin(), OrigFunc.arg_end(), std::back_inserter(NewFuncType.Args),
770         [&ArgsToTransform](Argument &Arg) {
771           if (ArgsToTransform.count(&Arg))
772             return Arg.getType()->getPointerElementType();
773           return Arg.getType();
774         });
775     InheritAttributes(OrigFunc);
776 
777     // struct-returns are not supported for transformed functions,
778     // so we need to discard the attribute
779     if (OrigFunc.hasStructRetAttr() && OrigFunc.hasLocalLinkage())
780       DiscardStructRetAttr(OrigFunc.getContext());
781 
782     auto *OrigRetTy = OrigFunc.getFunctionType()->getReturnType();
783     if (!OrigRetTy->isVoidTy()) {
784       NewFuncType.Ret.push_back(OrigRetTy);
785       RetToArg.Map.push_back(RetToArgInfo::OrigRetNoArg);
786     }
787     AppendRetCopyOutInfo();
788   }
789 
  // Append localized globals to the new signature. Aggregates are passed by
  // private pointer; scalars/vectors are passed by value, and the
  // non-constant ones additionally get a return-aggregate slot for copy-out.
AppendGlobals(LocalizationInfo & LI)790   void AppendGlobals(LocalizationInfo &LI) {
    IGC_ASSERT_MESSAGE(GlobalArgs.FirstGlobalArgIdx == GlobalArgsInfo::UndefIdx,
                       "can only be initialized once");
    GlobalArgs.FirstGlobalArgIdx = NewFuncType.Args.size();
    for (auto *GV : LI.getGlobals()) {
      if (passLocalizedGlobalByPointer(*GV)) {
        NewFuncType.Args.push_back(vc::changeAddrSpace(
            cast<PointerType>(GV->getType()), vc::AddrSpace::Private));
        GlobalArgs.Globals.push_back({GV, GlobalArgKind::ByPointer});
      } else {
        // Capture the argument index before push_back extends the list.
        int ArgIdx = NewFuncType.Args.size();
        Type *PointeeTy = GV->getType()->getPointerElementType();
        NewFuncType.Args.push_back(PointeeTy);
        if (GV->isConstant())
          GlobalArgs.Globals.push_back({GV, GlobalArgKind::ByValueIn});
        else {
          // Writable global: its final value must also be copied out.
          GlobalArgs.Globals.push_back({GV, GlobalArgKind::ByValueInOut});
          NewFuncType.Ret.push_back(PointeeTy);
          RetToArg.Map.push_back(ArgIdx);
        }
      }
    }
  }
813 
  // Read-only accessors over the computed transformation info.
getType() const814   const TransformedFuncType &getType() const { return NewFuncType; }
getAttributes() const815   AttributeList getAttributes() const { return Attrs; }
getArgKinds() const816   const std::vector<ArgKind> &getArgKinds() const { return ArgKinds; }
getDiscardedParameterAttrs() const817   const std::vector<ParameterAttrInfo> &getDiscardedParameterAttrs() const {
    return DiscardedParameterAttrs;
  }
getGlobalArgsInfo() const820   const GlobalArgsInfo &getGlobalArgsInfo() const { return GlobalArgs; }
getRetToArgInfo() const821   const RetToArgInfo &getRetToArgInfo() const { return RetToArg; }
822 
823 private:
FillCopyInOutInfo(Function & OrigFunc,SmallPtrSetImpl<Argument * > & ArgsToTransform)824   void FillCopyInOutInfo(Function &OrigFunc,
825                          SmallPtrSetImpl<Argument *> &ArgsToTransform) {
826     IGC_ASSERT_MESSAGE(ArgKinds.empty(),
827                        "shouldn't be filled before this method");
828     llvm::transform(OrigFunc.args(), std::back_inserter(ArgKinds),
829                     [&ArgsToTransform](Argument &Arg) {
830                       if (!ArgsToTransform.count(&Arg))
831                         return ArgKind::General;
832                       if (isPtrArgModified(Arg))
833                         return ArgKind::CopyInOut;
834                       return ArgKind::CopyIn;
835                     });
836   }
837 
InheritAttributes(Function & OrigFunc)838   void InheritAttributes(Function &OrigFunc) {
839     LLVMContext &Context = OrigFunc.getContext();
840     const AttributeList &OrigAttrs = OrigFunc.getAttributes();
841 
842     // Inherit argument attributes
843     for (auto ArgInfo : enumerate(ArgKinds)) {
844       if (ArgInfo.value() == ArgKind::General) {
845         AttributeSet ArgAttrs = OrigAttrs.getParamAttributes(ArgInfo.index());
846         if (ArgAttrs.hasAttributes())
847           Attrs = Attrs.addParamAttributes(Context, ArgInfo.index(),
848                                            AttrBuilder{ArgAttrs});
849       }
850     }
851 
852     // Inherit function attributes.
853     AttributeSet FnAttrs = OrigAttrs.getFnAttributes();
854     if (FnAttrs.hasAttributes()) {
855       AttrBuilder B(FnAttrs);
856       Attrs = Attrs.addAttributes(Context, AttributeList::FunctionIndex, B);
857     }
858   }
859 
DiscardStructRetAttr(LLVMContext & Context)860   void DiscardStructRetAttr(LLVMContext &Context) {
861     constexpr auto SretAttr = Attribute::StructRet;
862     for (auto ArgInfo : enumerate(ArgKinds)) {
863       unsigned ParamIndex = ArgInfo.index();
864       if (Attrs.hasParamAttr(ParamIndex, SretAttr)) {
865         Attrs = Attrs.removeParamAttribute(Context, ParamIndex, SretAttr);
866         DiscardedParameterAttrs.push_back({ParamIndex, SretAttr});
867       }
868     }
869   }
870 
AppendRetCopyOutInfo()871   void AppendRetCopyOutInfo() {
872     for (auto ArgInfo : enumerate(ArgKinds)) {
873       if (ArgInfo.value() == ArgKind::CopyInOut) {
874         NewFuncType.Ret.push_back(NewFuncType.Args[ArgInfo.index()]);
875         RetToArg.Map.push_back(ArgInfo.index());
876       }
877     }
878   }
879 };
880 } // namespace
881 
getRetType(LLVMContext & Context,const TransformedFuncType & TFType)882 static Type *getRetType(LLVMContext &Context,
883                         const TransformedFuncType &TFType) {
884   if (TFType.Ret.empty())
885     return Type::getVoidTy(Context);
886   return StructType::get(Context, TFType.Ret);
887 }
888 
// Creates the declaration of the transformed function: the signature comes
// from \p TFuncInfo; name, calling convention, attributes and the debug-info
// subprogram are transferred from \p OrigFunc. The new function is inserted
// into the module right before the original one.
// Returns the newly created (still body-less) function.
Function *createTransformedFuncDecl(Function &OrigFunc,
                                    const TransformedFuncInfo &TFuncInfo) {
  LLVMContext &Context = OrigFunc.getContext();
  // Construct the new function type using the new arguments.
  FunctionType *NewFuncTy = FunctionType::get(
      getRetType(Context, TFuncInfo.getType()), TFuncInfo.getType().Args,
      OrigFunc.getFunctionType()->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NewFunc =
      Function::Create(NewFuncTy, OrigFunc.getLinkage(), OrigFunc.getName());

  // Note: the debug print of OrigFunc happens before the name transfer below
  // steals its name.
  LLVM_DEBUG(dbgs() << "\nCMABI: Transforming From:" << OrigFunc);
  vc::transferNameAndCCWithNewAttr(TFuncInfo.getAttributes(), OrigFunc,
                                   *NewFunc);
  OrigFunc.getParent()->getFunctionList().insert(OrigFunc.getIterator(),
                                                 NewFunc);
  vc::transferDISubprogram(OrigFunc, *NewFunc);
  LLVM_DEBUG(dbgs() << "  --> To: " << *NewFunc << "\n");

  return NewFunc;
}
911 
912 static std::vector<Value *>
getTransformedFuncCallArgs(CallInst & OrigCall,const TransformedFuncInfo & NewFuncInfo)913 getTransformedFuncCallArgs(CallInst &OrigCall,
914                            const TransformedFuncInfo &NewFuncInfo) {
915   std::vector<Value *> NewCallOps;
916 
917   // Loop over the operands, inserting loads in the caller.
918   for (auto &&[OrigArg, Kind] :
919        zip(IGCLLVM::args(OrigCall), NewFuncInfo.getArgKinds())) {
920     switch (Kind) {
921     case ArgKind::General:
922       NewCallOps.push_back(OrigArg.get());
923       break;
924     default: {
925       IGC_ASSERT_MESSAGE(Kind == ArgKind::CopyIn || Kind == ArgKind::CopyInOut,
926                          "unexpected arg kind");
927       LoadInst *Load =
928           new LoadInst(OrigArg.get()->getType()->getPointerElementType(),
929                        OrigArg.get(), OrigArg.get()->getName() + ".val",
930                        /* isVolatile */ false, &OrigCall);
931       NewCallOps.push_back(Load);
932       break;
933     }
934     }
935   }
936 
937   IGC_ASSERT_MESSAGE(NewCallOps.size() == IGCLLVM::arg_size(OrigCall),
938                      "varargs are unexpected");
939   return std::move(NewCallOps);
940 }
941 
942 static AttributeList
inheritCallAttributes(CallInst & OrigCall,int NumOrigFuncArgs,const TransformedFuncInfo & NewFuncInfo)943 inheritCallAttributes(CallInst &OrigCall, int NumOrigFuncArgs,
944                       const TransformedFuncInfo &NewFuncInfo) {
945   IGC_ASSERT_MESSAGE(OrigCall.getNumArgOperands() == NumOrigFuncArgs,
946                      "varargs aren't supported");
947   AttributeList NewCallAttrs;
948 
949   const AttributeList &CallPAL = OrigCall.getAttributes();
950   auto &Context = OrigCall.getContext();
951   for (auto ArgInfo : enumerate(NewFuncInfo.getArgKinds())) {
952     if (ArgInfo.value() == ArgKind::General) {
953       AttributeSet attrs =
954           OrigCall.getAttributes().getParamAttributes(ArgInfo.index());
955       if (attrs.hasAttributes()) {
956         AttrBuilder B(attrs);
957         NewCallAttrs =
958             NewCallAttrs.addParamAttributes(Context, ArgInfo.index(), B);
959       }
960     }
961   }
962 
963   for (auto DiscardInfo : NewFuncInfo.getDiscardedParameterAttrs()) {
964     NewCallAttrs = NewCallAttrs.removeParamAttribute(
965         Context, DiscardInfo.ArgIndex, DiscardInfo.Attr);
966   }
967 
968   // Add any function attributes.
969   if (CallPAL.hasAttributes(AttributeList::FunctionIndex)) {
970     AttrBuilder B(CallPAL.getFnAttributes());
971     NewCallAttrs =
972         NewCallAttrs.addAttributes(Context, AttributeList::FunctionIndex, B);
973   }
974 
975   return std::move(NewCallAttrs);
976 }
977 
handleRetValuePortion(int RetIdx,int ArgIdx,CallInst & OrigCall,CallInst & NewCall,IRBuilder<> & Builder,const TransformedFuncInfo & NewFuncInfo)978 static void handleRetValuePortion(int RetIdx, int ArgIdx, CallInst &OrigCall,
979                                   CallInst &NewCall, IRBuilder<> &Builder,
980                                   const TransformedFuncInfo &NewFuncInfo) {
981   // Original return value.
982   if (ArgIdx == RetToArgInfo::OrigRetNoArg) {
983     IGC_ASSERT_MESSAGE(RetIdx == 0, "only zero element of returned value can "
984                                     "be original function argument");
985     OrigCall.replaceAllUsesWith(
986         Builder.CreateExtractValue(&NewCall, RetIdx, "ret"));
987     return;
988   }
989   Value *OutVal = Builder.CreateExtractValue(&NewCall, RetIdx);
990   if (ArgIdx >= NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx) {
991     auto Kind =
992         NewFuncInfo.getGlobalArgsInfo().getGlobalInfoForArgNo(ArgIdx).Kind;
993     IGC_ASSERT_MESSAGE(Kind == GlobalArgKind::ByValueInOut,
994         "only passed by value localized global should be copied-out");
995     Builder.CreateStore(
996         OutVal, NewFuncInfo.getGlobalArgsInfo().getGlobalForArgNo(ArgIdx));
997   } else {
998     IGC_ASSERT_MESSAGE(NewFuncInfo.getArgKinds()[ArgIdx] == ArgKind::CopyInOut,
999                        "only copy in-out args are expected");
1000     Builder.CreateStore(OutVal, OrigCall.getArgOperand(ArgIdx));
1001   }
1002 }
1003 
handleGlobalArgs(Function & NewFunc,const GlobalArgsInfo & GlobalArgs)1004 static std::vector<Value *> handleGlobalArgs(Function &NewFunc,
1005                                              const GlobalArgsInfo &GlobalArgs) {
1006   // Collect all globals and their corresponding allocas.
1007   std::vector<Value *> LocalizedGloabls;
1008   Instruction *InsertPt = &*(NewFunc.begin()->getFirstInsertionPt());
1009 
1010   llvm::transform(drop_begin(NewFunc.args(), GlobalArgs.FirstGlobalArgIdx),
1011                   std::back_inserter(LocalizedGloabls),
1012                   [InsertPt](Argument &GVArg) -> Value * {
1013                     if (GVArg.getType()->isPointerTy())
1014                       return &GVArg;
1015                     AllocaInst *Alloca = new AllocaInst(
1016                         GVArg.getType(), vc::AddrSpace::Private, "", InsertPt);
1017                     new StoreInst(&GVArg, Alloca, InsertPt);
1018                     return Alloca;
1019                   });
1020   // Fancy naming and debug info.
1021   for (auto &&[GAI, GVArg, MaybeAlloca] :
1022        zip(GlobalArgs.Globals,
1023            drop_begin(NewFunc.args(), GlobalArgs.FirstGlobalArgIdx),
1024            LocalizedGloabls)) {
1025     GVArg.setName(GAI.GV->getName() + ".in");
1026     if (!GVArg.getType()->isPointerTy()) {
1027       IGC_ASSERT_MESSAGE(isa<AllocaInst>(MaybeAlloca),
1028           "an alloca is expected when pass localized global by value");
1029       MaybeAlloca->setName(GAI.GV->getName() + ".local");
1030 
1031       vc::DIBuilder::createDbgDeclareForLocalizedGlobal(
1032           *cast<AllocaInst>(MaybeAlloca), *GAI.GV, *InsertPt);
1033     }
1034   }
1035 
1036   SmallDenseMap<Value *, Value *> GlobalsToReplace;
1037   for (auto &&[GAI, LocalizedGlobal] :
1038        zip(GlobalArgs.Globals, LocalizedGloabls))
1039     GlobalsToReplace.insert(std::make_pair(GAI.GV, LocalizedGlobal));
1040   // Replaces all globals uses within this new function.
1041   replaceUsesWithinFunction(GlobalsToReplace, &NewFunc);
1042   return LocalizedGloabls;
1043 }
1044 
// Inserts one element of the transformed function's aggregate return value
// into \p NewRetVal at \p RetIdx, right before the original return \p
// OrigRet. The element is taken from: the original return value
// (ArgIdx == OrigRetNoArg), the local copy of a localized global, or the
// alloca that replaced a copy-in/copy-out argument.
// Returns the updated aggregate value.
static Value *
appendTransformedFuncRetPortion(Value &NewRetVal, int RetIdx, int ArgIdx,
                                ReturnInst &OrigRet, IRBuilder<> &Builder,
                                const TransformedFuncInfo &NewFuncInfo,
                                const std::vector<Value *> &OrigArgReplacements,
                                std::vector<Value *> &LocalizedGlobals) {
  if (ArgIdx == RetToArgInfo::OrigRetNoArg) {
    IGC_ASSERT_MESSAGE(RetIdx == 0,
                       "original return value must be at zero index");
    Value *OrigRetVal = OrigRet.getReturnValue();
    IGC_ASSERT_MESSAGE(OrigRetVal, "type unexpected");
    IGC_ASSERT_MESSAGE(OrigRetVal->getType()->isSingleValueType(),
               "type unexpected");
    return Builder.CreateInsertValue(&NewRetVal, OrigRetVal, RetIdx);
  }
  if (ArgIdx >= NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx) {
    // Copy-out of a localized global: load its local copy (an alloca) and
    // pack the value into the aggregate.
    auto Kind =
        NewFuncInfo.getGlobalArgsInfo().getGlobalInfoForArgNo(ArgIdx).Kind;
    IGC_ASSERT_MESSAGE(Kind == GlobalArgKind::ByValueInOut,
        "only passed by value localized global should be copied-out");
    Value *LocalizedGlobal =
        LocalizedGlobals[ArgIdx -
                         NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx];
    IGC_ASSERT_MESSAGE(isa<AllocaInst>(LocalizedGlobal),
        "an alloca is expected when pass localized global by value");
    Value *LocalizedGlobalVal = Builder.CreateLoad(
        LocalizedGlobal->getType()->getPointerElementType(), LocalizedGlobal);
    return Builder.CreateInsertValue(&NewRetVal, LocalizedGlobalVal, RetIdx);
  }
  // Copy-out of a copy-in/copy-out argument: its replacement is either the
  // alloca itself or an addrspace cast of it - peel the cast if present.
  IGC_ASSERT_MESSAGE(NewFuncInfo.getArgKinds()[ArgIdx] == ArgKind::CopyInOut,
                     "Only copy in-out values are expected");
  Value *CurRetByPtr = OrigArgReplacements[ArgIdx];
  IGC_ASSERT_MESSAGE(isa<PointerType>(CurRetByPtr->getType()),
                     "a pointer is expected");
  if (isa<AddrSpaceCastInst>(CurRetByPtr))
    CurRetByPtr = cast<AddrSpaceCastInst>(CurRetByPtr)->getOperand(0);
  IGC_ASSERT_MESSAGE(isa<AllocaInst>(CurRetByPtr),
                     "corresponding alloca is expected");
  Value *CurRetByVal = Builder.CreateLoad(
      CurRetByPtr->getType()->getPointerElementType(), CurRetByPtr);
  return Builder.CreateInsertValue(&NewRetVal, CurRetByVal, RetIdx);
}
1087 
1088 // Add some additional code before \p OrigCall to pass localized global value
1089 // \p GAI to the transformed function.
1090 // An argument corresponding to \p GAI is returned.
static Value *passGlobalAsCallArg(GlobalArgInfo GAI, CallInst &OrigCall) {
  // We should load the global first to pass it by value.
  if (GAI.Kind == GlobalArgKind::ByValueIn ||
      GAI.Kind == GlobalArgKind::ByValueInOut)
    return new LoadInst(GAI.GV->getType()->getPointerElementType(), GAI.GV,
                        GAI.GV->getName() + ".val",
                        /* isVolatile */ false, &OrigCall);
  IGC_ASSERT_MESSAGE(GAI.Kind == GlobalArgKind::ByPointer,
      "localized global can be passed only by value or by pointer");
  auto *GVTy = cast<PointerType>(GAI.GV->getType());
  // No additional work when addrspaces match.
  if (GVTy->getAddressSpace() == vc::AddrSpace::Private)
    return GAI.GV;
  // Need to add a temporary cast instruction to match the expected
  // private-addrspace parameter type. Per the original note, this cast is
  // expected to be removed later when the pass processes this caller.
  return new AddrSpaceCastInst{
      GAI.GV, vc::changeAddrSpace(GVTy, vc::AddrSpace::Private),
      GAI.GV->getName() + ".tmp", &OrigCall};
}
1110 
1111 namespace {
1112 class FuncUsersUpdater {
1113   Function &OrigFunc;
1114   Function &NewFunc;
1115   const TransformedFuncInfo &NewFuncInfo;
1116   CallGraphNode &NewFuncCGN;
1117   CallGraph &CG;
1118 
1119 public:
FuncUsersUpdater(Function & OrigFuncIn,Function & NewFuncIn,const TransformedFuncInfo & NewFuncInfoIn,CallGraphNode & NewFuncCGNIn,CallGraph & CGIn)1120   FuncUsersUpdater(Function &OrigFuncIn, Function &NewFuncIn,
1121                    const TransformedFuncInfo &NewFuncInfoIn,
1122                    CallGraphNode &NewFuncCGNIn, CallGraph &CGIn)
1123       : OrigFunc{OrigFuncIn}, NewFunc{NewFuncIn}, NewFuncInfo{NewFuncInfoIn},
1124         NewFuncCGN{NewFuncCGNIn}, CG{CGIn} {}
1125 
run()1126   void run() {
1127     std::vector<CallInst *> DirectUsers;
1128 
1129     for (auto *U : OrigFunc.users()) {
1130       IGC_ASSERT_MESSAGE(
1131           isa<CallInst>(U),
1132           "the transformation is not applied to indirectly called functions");
1133       DirectUsers.push_back(cast<CallInst>(U));
1134     }
1135 
1136     std::vector<CallInst *> NewDirectUsers;
1137     // Loop over all of the callers of the function, transforming the call sites
1138     // to pass in the loaded pointers.
1139     for (auto *OrigCall : DirectUsers) {
1140       IGC_ASSERT(OrigCall->getCalledFunction() == &OrigFunc);
1141       auto *NewCall = UpdateFuncDirectUser(*OrigCall);
1142       NewDirectUsers.push_back(NewCall);
1143     }
1144 
1145     for (auto *OrigCall : DirectUsers)
1146       OrigCall->eraseFromParent();
1147   }
1148 
1149 private:
UpdateFuncDirectUser(CallInst & OrigCall)1150   CallInst *UpdateFuncDirectUser(CallInst &OrigCall) {
1151     std::vector<Value *> NewCallOps =
1152         getTransformedFuncCallArgs(OrigCall, NewFuncInfo);
1153 
1154     AttributeList NewCallAttrs = inheritCallAttributes(
1155         OrigCall, OrigFunc.getFunctionType()->getNumParams(), NewFuncInfo);
1156 
1157     // Push any localized globals.
1158     IGC_ASSERT_MESSAGE(
1159         NewCallOps.size() == NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx,
1160         "call operands and called function info are inconsistent");
1161     llvm::transform(NewFuncInfo.getGlobalArgsInfo().Globals,
1162                     std::back_inserter(NewCallOps),
1163                     [&OrigCall](GlobalArgInfo GAI) {
1164                       return passGlobalAsCallArg(GAI, OrigCall);
1165                     });
1166 
1167     IGC_ASSERT_EXIT_MESSAGE(!isa<InvokeInst>(OrigCall),
1168                             "InvokeInst not supported");
1169 
1170     CallInst *NewCall = CallInst::Create(&NewFunc, NewCallOps, "", &OrigCall);
1171     IGC_ASSERT(nullptr != NewCall);
1172     NewCall->setCallingConv(OrigCall.getCallingConv());
1173     NewCall->setAttributes(NewCallAttrs);
1174     if (cast<CallInst>(OrigCall).isTailCall())
1175       NewCall->setTailCall();
1176     NewCall->setDebugLoc(OrigCall.getDebugLoc());
1177     NewCall->takeName(&OrigCall);
1178 
1179     // Update the callgraph to know that the callsite has been transformed.
1180     auto CalleeNode = static_cast<IGCLLVM::CallGraphNode *>(
1181         CG[OrigCall.getParent()->getParent()]);
1182     CalleeNode->replaceCallEdge(
1183 #if LLVM_VERSION_MAJOR <= 10
1184         CallSite(&OrigCall), NewCall,
1185 #else
1186         OrigCall, *NewCall,
1187 #endif
1188         &NewFuncCGN);
1189 
1190     IRBuilder<> Builder(&OrigCall);
1191     for (auto RetToArg : enumerate(NewFuncInfo.getRetToArgInfo().Map))
1192       handleRetValuePortion(RetToArg.index(), RetToArg.value(), OrigCall,
1193                             *NewCall, Builder, NewFuncInfo);
1194     return NewCall;
1195   }
1196 };
1197 
1198 class FuncBodyTransfer {
1199   Function &OrigFunc;
1200   Function &NewFunc;
1201   const TransformedFuncInfo &NewFuncInfo;
1202 
1203 public:
FuncBodyTransfer(Function & OrigFuncIn,Function & NewFuncIn,const TransformedFuncInfo & NewFuncInfoIn)1204   FuncBodyTransfer(Function &OrigFuncIn, Function &NewFuncIn,
1205                    const TransformedFuncInfo &NewFuncInfoIn)
1206       : OrigFunc{OrigFuncIn}, NewFunc{NewFuncIn}, NewFuncInfo{NewFuncInfoIn} {}
1207 
run()1208   void run() {
1209     // Since we have now created the new function, splice the body of the old
1210     // function right into the new function.
1211     NewFunc.getBasicBlockList().splice(NewFunc.begin(),
1212                                        OrigFunc.getBasicBlockList());
1213 
1214     std::vector<Value *> OrigArgReplacements = handleTransformedFuncArgs();
1215     std::vector<Value *> LocalizedGlobals =
1216         handleGlobalArgs(NewFunc, NewFuncInfo.getGlobalArgsInfo());
1217 
1218     handleTransformedFuncRets(OrigArgReplacements, LocalizedGlobals);
1219   }
1220 
1221 private:
handleTransformedFuncArgs()1222   std::vector<Value *> handleTransformedFuncArgs() {
1223     std::vector<Value *> OrigArgReplacements;
1224     Instruction *InsertPt = &*(NewFunc.begin()->getFirstInsertionPt());
1225 
1226     std::transform(
1227         NewFuncInfo.getArgKinds().begin(), NewFuncInfo.getArgKinds().end(),
1228         NewFunc.arg_begin(), std::back_inserter(OrigArgReplacements),
1229         [InsertPt](ArgKind Kind, Argument &NewArg) -> Value * {
1230           switch (Kind) {
1231           case ArgKind::CopyIn:
1232           case ArgKind::CopyInOut: {
1233             auto *Alloca = new AllocaInst(NewArg.getType(),
1234                                           vc::AddrSpace::Private, "", InsertPt);
1235             new StoreInst{&NewArg, Alloca, InsertPt};
1236             return Alloca;
1237           }
1238           default:
1239             IGC_ASSERT_MESSAGE(Kind == ArgKind::General,
1240                                "unexpected argument kind");
1241             return &NewArg;
1242           }
1243         });
1244 
1245     std::transform(
1246         OrigArgReplacements.begin(), OrigArgReplacements.end(),
1247         OrigFunc.arg_begin(), OrigArgReplacements.begin(),
1248         [InsertPt](Value *Replacement, Argument &OrigArg) -> Value * {
1249           if (Replacement->getType() == OrigArg.getType())
1250             return Replacement;
1251           IGC_ASSERT_MESSAGE(isa<PointerType>(Replacement->getType()),
1252             "only pointers can posibly mismatch");
1253           IGC_ASSERT_MESSAGE(isa<PointerType>(OrigArg.getType()),
1254             "only pointers can posibly mismatch");
1255           IGC_ASSERT_MESSAGE(
1256               Replacement->getType()->getPointerAddressSpace() !=
1257                   OrigArg.getType()->getPointerAddressSpace(),
1258               "pointers should have different addr spaces when they mismatch");
1259           IGC_ASSERT_MESSAGE(
1260               Replacement->getType()->getPointerElementType() ==
1261                   OrigArg.getType()->getPointerElementType(),
1262               "pointers must have same element type when they mismatch");
1263           return new AddrSpaceCastInst(Replacement, OrigArg.getType(), "",
1264                                        InsertPt);
1265         });
1266     for (auto &&[OrigArg, OrigArgReplacement] :
1267          zip(OrigFunc.args(), OrigArgReplacements)) {
1268       OrigArgReplacement->takeName(&OrigArg);
1269       OrigArg.replaceAllUsesWith(OrigArgReplacement);
1270     }
1271 
1272     return std::move(OrigArgReplacements);
1273   }
1274 
handleTransformedFuncRet(ReturnInst & OrigRet,const std::vector<Value * > & OrigArgReplacements,std::vector<Value * > & LocalizedGlobals)1275   void handleTransformedFuncRet(ReturnInst &OrigRet,
1276                                 const std::vector<Value *> &OrigArgReplacements,
1277                                 std::vector<Value *> &LocalizedGlobals) {
1278     Type *NewRetTy = NewFunc.getReturnType();
1279     IRBuilder<> Builder(&OrigRet);
1280     auto &&RetToArg = enumerate(NewFuncInfo.getRetToArgInfo().Map);
1281     Value *NewRetVal = std::accumulate(
1282         RetToArg.begin(), RetToArg.end(),
1283         cast<Value>(UndefValue::get(NewRetTy)),
1284         [&OrigRet, &Builder, &OrigArgReplacements, &LocalizedGlobals,
1285          this](Value *NewRet, auto NewRetPortionInfo) {
1286           return appendTransformedFuncRetPortion(
1287               *NewRet, NewRetPortionInfo.index(), NewRetPortionInfo.value(),
1288               OrigRet, Builder, NewFuncInfo, OrigArgReplacements,
1289               LocalizedGlobals);
1290         });
1291     Builder.CreateRet(NewRetVal);
1292     OrigRet.eraseFromParent();
1293   }
1294 
1295   void
handleTransformedFuncRets(const std::vector<Value * > & OrigArgReplacements,std::vector<Value * > & LocalizedGlobals)1296   handleTransformedFuncRets(const std::vector<Value *> &OrigArgReplacements,
1297                             std::vector<Value *> &LocalizedGlobals) {
1298     Type *NewRetTy = NewFunc.getReturnType();
1299     if (NewRetTy->isVoidTy())
1300       return;
1301     std::vector<ReturnInst *> OrigRets;
1302     llvm::transform(make_filter_range(instructions(NewFunc),
1303                                       [](Instruction &Inst) {
1304                                         return isa<ReturnInst>(Inst);
1305                                       }),
1306                     std::back_inserter(OrigRets),
1307                     [](Instruction &RI) { return &cast<ReturnInst>(RI); });
1308 
1309     for (ReturnInst *OrigRet : OrigRets)
1310       handleTransformedFuncRet(*OrigRet, OrigArgReplacements, LocalizedGlobals);
1311   }
1312 };
1313 } // namespace
1314 
1315 // \brief Actually performs the transformation of the specified arguments, and
1316 // returns the new function.
1317 //
1318 // Note this transformation does change the semantics as a C function, due to
1319 // possible pointer aliasing. But it is allowed as a CM function.
1320 //
1321 // The pass-by-reference scheme is useful to copy-out values from the
1322 // subprogram back to the caller. It also may be useful to convey large inputs
1323 // to subprograms, as the amount of parameter conveying code will be reduced.
1324 // There is a restriction imposed on arguments passed by reference in order to
1325 // allow for an efficient CM implementation. Specifically the restriction is
1326 // that for a subprogram that uses pass-by-reference, the behavior must be the
1327 // same as if we use a copy-in/copy-out semantic to convey the
1328 // pass-by-reference argument; otherwise the CM program is said to be erroneous
1329 // and may produce incorrect results. Such errors are not caught by the
1330 // compiler and it is up to the user to guarantee safety.
1331 //
1332 // The implication of the above stated restriction is that no pass-by-reference
1333 // argument that is written to in a subprogram (either directly or transitively
1334 // by means of a nested subprogram call pass-by-reference argument) may overlap
1335 // with another pass-by-reference parameter or a global variable that is
1336 // referenced in the subprogram; in addition no pass-by-reference subprogram
1337 // argument that is referenced may overlap with a global variable that is
1338 // written to in the subprogram.
1339 //
CallGraphNode *CMABI::TransformNode(Function &OrigFunc,
                                    SmallPtrSet<Argument *, 8> &ArgsToTransform,
                                    LocalizationInfo &LI) {
  NumArgumentsTransformed += ArgsToTransform.size();
  // Compute the new signature: copy-in/copy-out kinds for ArgsToTransform
  // plus extra arguments/return elements for the localized globals from LI.
  TransformedFuncInfo NewFuncInfo{OrigFunc, ArgsToTransform};
  NewFuncInfo.AppendGlobals(LI);

  // Create the new function declaration and insert it into the module.
  Function *NewFunc = createTransformedFuncDecl(OrigFunc, NewFuncInfo);

  // Get a new callgraph node for NF.
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  CallGraphNode *NewFuncCGN = CG.getOrInsertFunction(NewFunc);

  // First rewrite all call sites, then move and fix up the body.
  FuncUsersUpdater{OrigFunc, *NewFunc, NewFuncInfo, *NewFuncCGN, CG}.run();
  FuncBodyTransfer{OrigFunc, *NewFunc, NewFuncInfo}.run();

  // It turns out sometimes llvm will recycle function pointers which confuses
  // this pass. We delete its localization info and mark this function as
  // already visited.
  Info->GlobalInfo.erase(&OrigFunc);
  AlreadyVisited.insert(&OrigFunc);

  // The new function inherits all call edges of the (now empty) original.
  NewFuncCGN->stealCalledFunctionsFrom(CG[&OrigFunc]);

  // Now that the old function is dead, delete it. If there is a dangling
  // reference to the CallgraphNode, just leave the dead function around.
  CallGraphNode *CGN = CG[&OrigFunc];
  if (CGN->getNumReferences() == 0)
    delete CG.removeFunctionFromModule(CGN);
  else
    OrigFunc.setLinkage(Function::ExternalLinkage);

  return NewFuncCGN;
}
1375 
fillStackWithUsers(std::stack<User * > & Stack,User & CurUser)1376 static void fillStackWithUsers(std::stack<User *> &Stack, User &CurUser) {
1377   for (User *Usr : CurUser.users())
1378     Stack.push(Usr);
1379 }
1380 
1381 // Traverse in depth through GV constant users to find instruction users.
1382 // When instruction user is found, it is clear in which function GV is used.
defineGVDirectUsers(GlobalVariable & GV)1383 void CMABIAnalysis::defineGVDirectUsers(GlobalVariable &GV) {
1384   std::stack<User *> Stack;
1385   Stack.push(&GV);
1386   while (!Stack.empty()) {
1387     auto *CurUser = Stack.top();
1388     Stack.pop();
1389 
1390     // Continue go in depth when a constant is met.
1391     if (isa<Constant>(CurUser)) {
1392       fillStackWithUsers(Stack, *CurUser);
1393       continue;
1394     }
1395 
1396     // We've got what we looked for.
1397     auto *Inst = cast<Instruction>(CurUser);
1398     addDirectGlobal(Inst->getFunction(), &GV);
1399   }
1400 }
1401 
1402 // For each function, compute the list of globals that need to be passed as
1403 // copy-in and copy-out arguments.
void CMABIAnalysis::analyzeGlobals(CallGraph &CG) {
  Module &M = CG.getModule();
  // First pass over the functions: optionally promote non-kernels to stack
  // calls and internalize linkage so later localization may freely rewrite
  // signatures. Declarations, dllexport functions, GenX intrinsics and the
  // emulation helpers are left untouched.
  for (auto& F : M.getFunctionList()) {
    if (F.isDeclaration() || F.hasDLLExportStorageClass())
      continue;
    if (GenXIntrinsic::getAnyIntrinsicID(&F) !=
        GenXIntrinsic::not_any_intrinsic)
      continue;
    // __cm_intrinsic_impl_* could be used for emulation mul/div etc
    if (F.getName().contains("__cm_intrinsic_impl_"))
      continue;

    // Convert non-kernel to stack call if applicable
    if (FCtrl == FunctionControl::StackCall && !genx::requiresStackCall(&F)) {
      LLVM_DEBUG(dbgs() << "Adding stack call to: " << F.getName() << "\n");
      F.addFnAttr(genx::FunctionMD::CMStackCall);
    }

    // Do not change stack calls linkage as we may have both types of stack
    // calls.
    if (genx::requiresStackCall(&F) && SaveStackCallLinkage)
      continue;

    F.setLinkage(GlobalValue::InternalLinkage);
  }
  // No global variables.
  if (M.global_empty())
    return;

  // FIXME: String constants must be localized too. Excluding them there
  //        to WA legacy printf implementation in CM FE (printf strings are
  //        not in constant addrspace in legacy printf).
  // Localization candidates: private-addrspace, non-volatile, non-string
  // globals.
  auto ToLocalize =
      make_filter_range(M.globals(), [](const GlobalVariable &GV) {
        return GV.getAddressSpace() == vc::AddrSpace::Private &&
               !GV.hasAttribute(genx::FunctionMD::GenXVolatile) &&
               !vc::isConstantString(GV);
      });

  // Collect direct and indirect (GV is used in a called function)
  // uses of globals.
  for (GlobalVariable &GV : ToLocalize)
    defineGVDirectUsers(GV);
  // Walk the call graph in SCC order and record caller -> callee relations
  // for callees whose signature may change; addIndirectGlobal presumably
  // propagates the callee's global uses to the caller — confirm against its
  // definition.
  for (const std::vector<CallGraphNode *> &SCCNodes :
       make_range(scc_begin(&CG), scc_end(&CG)))
    for (const CallGraphNode *Caller : SCCNodes)
      for (const IGCLLVM::CallGraphNode::CallRecord &Callee : *Caller) {
        Function *CalleeF = Callee.second->getFunction();
        if (CalleeF && !vc::isFixedSignatureFunc(*CalleeF))
          addIndirectGlobal(Caller->getFunction(), CalleeF);
      }
}
1456 
1457 /***********************************************************************
1458  * diagnoseOverlappingArgs : attempt to diagnose overlapping by-ref args
1459  *
1460  * The CM language spec says you are not allowed a call with two by-ref args
1461  * that overlap. This is to give the compiler the freedom to implement with
1462  * copy-in copy-out semantics or with an address register.
1463  *
1464  * This function attempts to diagnose code that breaks this restriction. For
1465  * pointer args to the call, it attempts to track how values are loaded using
1466  * the pointer (assumed to be an alloca of the temporary used for copy-in
1467  * copy-out semantics), and how those values then get propagated through
1468  * wrregions and stores. If any vector element in a wrregion or store is found
1469  * that comes from more than one pointer arg, it is reported.
1470  *
1471  * This ignores variable index wrregions, and only traces through instructions
1472  * with the same debug location as the call, so does not work with -g0.
1473  */
diagnoseOverlappingArgs(CallInst * CI)1474 void CMABI::diagnoseOverlappingArgs(CallInst *CI)
1475 {
1476   LLVM_DEBUG(dbgs() << "diagnoseOverlappingArgs " << *CI << "\n");
1477   auto DL = CI->getDebugLoc();
1478   if (!DL)
1479     return;
1480   std::map<Value *, SmallVector<uint8_t, 16>> ValMap;
1481   SmallVector<Instruction *, 8> WorkList;
1482   std::set<Instruction *> InWorkList;
1483   std::set<std::pair<unsigned, unsigned>> Reported;
1484   // Using ArgIndex starting at 1 so we can reserve 0 to mean "element does not
1485   // come from any by-ref arg".
1486   for (unsigned ArgIndex = 1, NumArgs = CI->getNumArgOperands();
1487       ArgIndex <= NumArgs; ++ArgIndex) {
1488     Value *Arg = CI->getOperand(ArgIndex - 1);
1489     if (!Arg->getType()->isPointerTy())
1490       continue;
1491     LLVM_DEBUG(dbgs() << "arg " << ArgIndex << ": " << *Arg << "\n");
1492     // Got a pointer arg. Find its loads (with the same debug loc).
1493     for (auto ui = Arg->use_begin(), ue = Arg->use_end(); ui != ue; ++ui) {
1494       auto LI = dyn_cast<LoadInst>(ui->getUser());
1495       if (!LI || LI->getDebugLoc() != DL)
1496         continue;
1497       LLVM_DEBUG(dbgs() << "  " << *LI << "\n");
1498       // For a load, create a map entry that says that every vector element
1499       // comes from this arg.
1500       unsigned NumElements = 1;
1501       if (auto VT = dyn_cast<IGCLLVM::FixedVectorType>(LI->getType()))
1502         NumElements = VT->getNumElements();
1503       auto Entry = &ValMap[LI];
1504       Entry->resize(NumElements, ArgIndex);
1505       // Add its users (with the same debug location) to the work list.
1506       for (auto ui = LI->use_begin(), ue = LI->use_end(); ui != ue; ++ui) {
1507         auto Inst = cast<Instruction>(ui->getUser());
1508         if (Inst->getDebugLoc() == DL)
1509           if (InWorkList.insert(Inst).second)
1510             WorkList.push_back(Inst);
1511       }
1512     }
1513   }
1514   // Process the work list.
1515   while (!WorkList.empty()) {
1516     auto Inst = WorkList.back();
1517     WorkList.pop_back();
1518     InWorkList.erase(Inst);
1519     LLVM_DEBUG(dbgs() << "From worklist: " << *Inst << "\n");
1520     Value *Key = nullptr;
1521     SmallVector<uint8_t, 8> TempVector;
1522     SmallVectorImpl<uint8_t> *VectorToMerge = nullptr;
1523     if (auto SI = dyn_cast<StoreInst>(Inst)) {
1524       // Store: set the map entry using the store pointer as the key. It might
1525       // be an alloca of a local variable, or a global variable.
1526       // Strictly speaking this is not properly keeping track of what is being
1527       // merged using load-wrregion-store for a non-SROAd local variable or a
1528       // global variable. Instead it is just merging at the store itself, which
1529       // is good enough for our purposes.
1530       Key = SI->getPointerOperand();
1531       VectorToMerge = &ValMap[SI->getValueOperand()];
1532     } else if (auto BC = dyn_cast<BitCastInst>(Inst)) {
1533       // Bitcast: calculate the new map entry.
1534       Key = BC;
1535       uint64_t OutElementSize =
1536           BC->getType()->getScalarType()->getPrimitiveSizeInBits();
1537       uint64_t InElementSize = BC->getOperand(0)
1538                                    ->getType()
1539                                    ->getScalarType()
1540                                    ->getPrimitiveSizeInBits();
1541       int LogRatio = countTrailingZeros(OutElementSize, ZB_Undefined) -
1542                      countTrailingZeros(InElementSize, ZB_Undefined);
1543       auto OpndEntry = &ValMap[BC->getOperand(0)];
1544       if (!LogRatio)
1545         VectorToMerge = OpndEntry;
1546       else if (LogRatio > 0) {
1547         // Result element type is bigger than input element type, so there are
1548         // fewer result elements. Just use an arbitrarily chosen non-zero entry
1549         // of the N input elements to set the 1 result element.
1550         IGC_ASSERT(!(OpndEntry->size() & ((1U << LogRatio) - 1)));
1551         for (unsigned i = 0, e = OpndEntry->size(); i != e; i += 1U << LogRatio) {
1552           unsigned FoundArgIndex = 0;
1553           for (unsigned j = 0; j != 1U << LogRatio; ++j)
1554             FoundArgIndex = std::max(FoundArgIndex, (unsigned)(*OpndEntry)[i + j]);
1555           TempVector.push_back(FoundArgIndex);
1556         }
1557         VectorToMerge = &TempVector;
1558       } else {
1559         // Result element type is smaller than input element type, so there are
1560         // multiple result elements per input element.
1561         for (unsigned i = 0, e = OpndEntry->size(); i != e; ++i)
1562           for (unsigned j = 0; j != 1U << -LogRatio; ++j)
1563             TempVector.push_back((*OpndEntry)[i]);
1564         VectorToMerge = &TempVector;
1565       }
1566     } else if (auto CI = dyn_cast<CallInst>(Inst)) {
1567       if (auto CF = CI->getCalledFunction()) {
1568         switch (GenXIntrinsic::getGenXIntrinsicID(CF)) {
1569           default:
1570             break;
1571           case GenXIntrinsic::genx_wrregionf:
1572           case GenXIntrinsic::genx_wrregioni:
1573             // wrregion: As long as it is constant index, propagate the argument
1574             // indices into the appropriate elements of the result.
1575             if (auto IdxC = dyn_cast<Constant>(CI->getOperand(
1576                     GenXIntrinsic::GenXRegion::WrIndexOperandNum))) {
1577               unsigned Idx = 0;
1578               if (!IdxC->isNullValue()) {
1579                 auto IdxCI = dyn_cast<ConstantInt>(IdxC);
1580                 if (!IdxCI) {
1581                   LLVM_DEBUG(dbgs() << "Ignoring variable index wrregion\n");
1582                   break;
1583                 }
1584                 Idx = IdxCI->getZExtValue();
1585               }
1586               Idx /= (CI->getType()->getScalarType()->getPrimitiveSizeInBits() / 8U);
1587               // First copy the "old value" input to the map entry.
1588               auto OpndEntry = &ValMap[CI->getOperand(
1589                     GenXIntrinsic::GenXRegion::OldValueOperandNum)];
1590               auto Entry = &ValMap[CI];
1591               Entry->clear();
1592               Entry->insert(Entry->begin(), OpndEntry->begin(), OpndEntry->end());
1593               // Then copy the "new value" elements according to the region.
1594               TempVector.resize(
1595                   dyn_cast<IGCLLVM::FixedVectorType>(CI->getType())->getNumElements(), 0);
1596               int VStride = cast<ConstantInt>(CI->getOperand(
1597                     GenXIntrinsic::GenXRegion::WrVStrideOperandNum))->getSExtValue();
1598               unsigned Width = cast<ConstantInt>(CI->getOperand(
1599                     GenXIntrinsic::GenXRegion::WrWidthOperandNum))->getZExtValue();
1600               IGC_ASSERT_MESSAGE((Width > 0), "Width of a region must be non-zero");
1601               int Stride = cast<ConstantInt>(CI->getOperand(
1602                     GenXIntrinsic::GenXRegion::WrStrideOperandNum))->getSExtValue();
1603               OpndEntry = &ValMap[CI->getOperand(
1604                     GenXIntrinsic::GenXRegion::NewValueOperandNum)];
1605               unsigned NumElements = OpndEntry->size();
1606               if (!NumElements)
1607                 break;
1608               for (unsigned RowIdx = Idx, Row = 0, Col = 0,
1609                     NumRows = NumElements / Width;; Idx += Stride, ++Col) {
1610                 if (Col == Width) {
1611                   Col = 0;
1612                   if (++Row == NumRows)
1613                     break;
1614                   Idx = RowIdx += VStride;
1615                 }
1616                 TempVector[Idx] = (*OpndEntry)[Row * Width + Col];
1617               }
1618               VectorToMerge = &TempVector;
1619               Key = CI;
1620             }
1621             break;
1622         }
1623       }
1624     }
1625     if (!VectorToMerge)
1626       continue;
1627     auto Entry = &ValMap[Key];
1628     LLVM_DEBUG(dbgs() << "Merging :";
1629       for (unsigned i = 0; i != VectorToMerge->size(); ++i)
1630         dbgs() << " " << (unsigned)(*VectorToMerge)[i];
1631       dbgs() << "\ninto " << Key->getName() << ":";
1632       for (unsigned i = 0; i != Entry->size(); ++i)
1633         dbgs() << " " << (unsigned)(*Entry)[i];
1634       dbgs() << "\n");
1635     if (Entry->empty())
1636       Entry->insert(Entry->end(), VectorToMerge->begin(), VectorToMerge->end());
1637     else {
1638       IGC_ASSERT(VectorToMerge->size() == Entry->size());
1639       for (unsigned i = 0; i != VectorToMerge->size(); ++i) {
1640         unsigned ArgIdx1 = (*VectorToMerge)[i];
1641         unsigned ArgIdx2 = (*Entry)[i];
1642         if (ArgIdx1 && ArgIdx2 && ArgIdx1 != ArgIdx2) {
1643           LLVM_DEBUG(dbgs() << "By ref args overlap: args " << ArgIdx1 << " and " << ArgIdx2 << "\n");
1644           if (ArgIdx1 > ArgIdx2)
1645             std::swap(ArgIdx1, ArgIdx2);
1646           if (Reported.insert(std::pair<unsigned, unsigned>(ArgIdx1, ArgIdx2))
1647                 .second) {
1648             // Not already reported.
1649             DiagnosticInfoOverlappingArgs Err(CI, "by reference arguments "
1650                 + Twine(ArgIdx1) + " and " + Twine(ArgIdx2) + " overlap",
1651                 DS_Error);
1652             Inst->getContext().diagnose(Err);
1653           }
1654         }
1655         (*Entry)[i] = std::max((*Entry)[i], (*VectorToMerge)[i]);
1656       }
1657     }
1658     LLVM_DEBUG(dbgs() << "giving:";
1659       for (unsigned i = 0; i != Entry->size(); ++i)
1660         dbgs() << " " << (unsigned)(*Entry)[i];
1661       dbgs() << "\n");
1662     if (Key == Inst) {
1663       // Not the case that we have a store and we are using the pointer as
1664       // the key. In ther other cases that do a merge (bitcast and wrregion),
1665       // add users to the work list as long as they have the same debug loc.
1666       for (auto ui = Inst->use_begin(), ue = Inst->use_end(); ui != ue; ++ui) {
1667         auto User = cast<Instruction>(ui->getUser());
1668         if (User->getDebugLoc() == DL)
1669           if (InWorkList.insert(Inst).second)
1670             WorkList.push_back(User);
1671       }
1672     }
1673   }
1674 }
1675 
1676 /***********************************************************************
1677  * DiagnosticInfoOverlappingArgs initializer from Instruction
1678  *
1679  * If the Instruction has a DebugLoc, then that is used for the error
1680  * location.
1681  * Otherwise, the location is unknown.
1682  */
DiagnosticInfoOverlappingArgs(Instruction * Inst,const Twine & Desc,DiagnosticSeverity Severity)1683 DiagnosticInfoOverlappingArgs::DiagnosticInfoOverlappingArgs(Instruction *Inst,
1684     const Twine &Desc, DiagnosticSeverity Severity)
1685     : DiagnosticInfo(getKindID(), Severity), Line(0), Col(0)
1686 {
1687   auto DL = Inst->getDebugLoc();
1688   if (!DL) {
1689     Filename = DL.get()->getFilename();
1690     Line = DL.getLine();
1691     Col = DL.getCol();
1692   }
1693   Description = Desc.str();
1694 }
1695 
1696 /***********************************************************************
1697  * DiagnosticInfoOverlappingArgs::print : print the error/warning message
1698  */
print(DiagnosticPrinter & DP) const1699 void DiagnosticInfoOverlappingArgs::print(DiagnosticPrinter &DP) const
1700 {
1701   std::string Loc(
1702         (Twine(!Filename.empty() ? Filename : "<unknown>")
1703         + ":" + Twine(Line)
1704         + (!Col ? Twine() : Twine(":") + Twine(Col))
1705         + ": ")
1706       .str());
1707   DP << Loc << Description;
1708 }
1709 
1710 
// Pass registration: CMABI depends on the call graph and on the
// CMABIAnalysis result computed earlier.
char CMABI::ID = 0;
INITIALIZE_PASS_BEGIN(CMABI, "cmabi", "Fix ABI issues for the genx backend", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CMABIAnalysis)
INITIALIZE_PASS_END(CMABI, "cmabi", "Fix ABI issues for the genx backend", false, false)

// Factory used by the pass manager to create the CMABI pass.
Pass *llvm::createCMABIPass() { return new CMABI(); }
1718 
namespace {

// A well-formed passing argument by reference pattern.
//
// (Alloca)
// %argref1 = alloca <8 x float>, align 32
//
// (CopyInRegion/CopyInStore)
// %rdr = tail call <8 x float> @llvm.genx.rdregionf(<960 x float> %m, i32 0, i32 8, i32 1, i16 0, i32 undef)
// call void @llvm.genx.vstore(<8 x float> %rdr, <8 x float>* %argref)
//
// (CopyOutRegion/CopyOutLoad)
// %ld = call <8 x float> @llvm.genx.vload(<8 x float>* %argref)
// %wr = call <960 x float> @llvm.genx.wrregionf(<960 x float> %m, <8 x float> %ld, i32 0, i32 8, i32 1, i16 0, i32 undef, i1 true)
//
struct ArgRefPattern {
  // Alloca of this reference argument.
  AllocaInst *Alloca;

  // The input value: a rdregion whose result is vstore'd into Alloca.
  CallInst *CopyInRegion;
  CallInst *CopyInStore;

  // The output value: a vload of Alloca feeding a wrregion. CopyOutRegion
  // may remain null when the loaded value has no use (see match()).
  CallInst *CopyOutLoad;
  CallInst *CopyOutRegion;

  // All vload/vstore intrinsic calls on the arg alloca; filled by match().
  SmallVector<CallInst *, 8> VLoads;
  SmallVector<CallInst *, 8> VStores;

  explicit ArgRefPattern(AllocaInst *AI)
      : Alloca(AI), CopyInRegion(nullptr), CopyInStore(nullptr),
        CopyOutLoad(nullptr), CopyOutRegion(nullptr) {}

  // Match a copy-in and copy-out pattern. Return true on success.
  bool match(DominatorTree &DT, PostDominatorTree &PDT);
  // Rewrite the matched loads/stores against a new base-region alloca.
  void process(DominatorTree &DT);
};

// Function pass that promotes by-ref argument allocas into their base
// region and lowers remaining genx.vload/genx.vstore intrinsics.
struct CMLowerVLoadVStore : public FunctionPass {
  static char ID;
  CMLowerVLoadVStore() : FunctionPass(ID) {
    initializeCMLowerVLoadVStorePass(*PassRegistry::getPassRegistry());
  }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<PostDominatorTreeWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnFunction(Function &F) override;

private:
  // Promote arg-ref allocas matched by ArgRefPattern into the base region.
  bool promoteAllocas(Function &F);
  // Lower remaining vload/vstore intrinsic calls to plain load/store.
  bool lowerLoadStore(Function &F);
};

} // namespace
1778 
// Pass registration: requires dominator and post-dominator trees for the
// copy-in/copy-out pattern matching in promoteAllocas().
char CMLowerVLoadVStore::ID = 0;
INITIALIZE_PASS_BEGIN(CMLowerVLoadVStore, "CMLowerVLoadVStore",
                      "Lower CM reference vector loads and stores", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(CMLowerVLoadVStore, "CMLowerVLoadVStore",
                    "Lower CM reference vector loads and stores", false, false)
1787 
1788 bool CMLowerVLoadVStore::runOnFunction(Function &F) {
1789   bool Changed = false;
1790   Changed |= promoteAllocas(F);
1791   Changed |= lowerLoadStore(F);
1792   return Changed;
1793 }
1794 
// Lower remaining vector load/store intrinsic calls into normal load/store
// instructions.
bool CMLowerVLoadVStore::lowerLoadStore(Function &F) {
  auto M = F.getParent();
  // Maps an alloca to the genx-volatile global whose address was stored
  // into it, so indirect vloads/vstores can be traced back to the global.
  DenseMap<AllocaInst*, GlobalVariable*> AllocaMap;
  // collect all the allocas that store the address of genx-volatile variable
  for (auto& G : M->getGlobalList()) {
    if (!G.hasAttribute("genx_volatile"))
      continue;
    std::vector<User*> WL;
    for (auto UI = G.user_begin(); UI != G.user_end();) {
      auto U = *UI++;
      WL.push_back(U);
    }

    // Worklist walk through constant expressions and casts, looking for
    // stores of the global's address into an alloca.
    while (!WL.empty()) {
      auto Inst = WL.back();
      WL.pop_back();
      if (auto CE = dyn_cast<ConstantExpr>(Inst)) {
        for (auto UI = CE->user_begin(); UI != CE->user_end();) {
          auto U = *UI++;
          WL.push_back(U);
        }
      }
      else if (auto CI = dyn_cast<CastInst>(Inst)) {
        for (auto UI = CI->user_begin(); UI != CI->user_end();) {
          auto U = *UI++;
          WL.push_back(U);
        }
      }
      else if (auto SI = dyn_cast<StoreInst>(Inst)) {
        auto Ptr = SI->getPointerOperand()->stripPointerCasts();
        if (auto PI = dyn_cast<AllocaInst>(Ptr)) {
          AllocaMap[PI] = &G;
        }
      }
    }
  }

  // lower all vload/vstore into normal load/store.
  std::vector<Instruction *> ToErase;
  for (Instruction &Inst : instructions(F)) {
    if (GenXIntrinsic::isVLoadStore(&Inst)) {
      // The pointer operand is operand 0 for vload, operand 1 for vstore.
      auto *Ptr = Inst.getOperand(0);
      if (GenXIntrinsic::isVStore(&Inst))
        Ptr = Inst.getOperand(1);
      auto AS0 = cast<PointerType>(Ptr->getType())->getAddressSpace();
      Ptr = Ptr->stripPointerCasts();
      // Determine whether the access ultimately targets a genx-volatile
      // global: either directly, or via an alloca recorded above.
      auto GV = dyn_cast<GlobalVariable>(Ptr);
      if (GV) {
        if (!GV->hasAttribute("genx_volatile"))
          GV = nullptr;
      }
      else if (auto LI = dyn_cast<LoadInst>(Ptr)) {
        auto PV = LI->getPointerOperand()->stripPointerCasts();
        if (auto PI = dyn_cast<AllocaInst>(PV)) {
          if (AllocaMap.find(PI) != AllocaMap.end()) {
            GV = AllocaMap[PI];
          }
        }
      }
      if (GV == nullptr) {
        // change to load/store
        IRBuilder<> Builder(&Inst);
        if (GenXIntrinsic::isVStore(&Inst))
          Builder.CreateStore(Inst.getOperand(0), Inst.getOperand(1));
        else {
          Value *Op0 = Inst.getOperand(0);
          auto LI = Builder.CreateLoad(Op0->getType()->getPointerElementType(),
                                       Op0, Inst.getName());
          LI->setDebugLoc(Inst.getDebugLoc());
          Inst.replaceAllUsesWith(LI);
        }
        ToErase.push_back(&Inst);
      }
      else {
        // change to vload/vstore that has the same address space as
        // the global-var in order to clean up unnecessary addr-cast.
        auto AS1 = GV->getType()->getAddressSpace();
        if (AS0 != AS1) {
          IRBuilder<> Builder(&Inst);
          if (GenXIntrinsic::isVStore(&Inst)) {
            auto PtrTy = cast<PointerType>(Inst.getOperand(1)->getType());
            PtrTy = PointerType::get(PtrTy->getElementType(), AS1);
            auto PtrCast = Builder.CreateAddrSpaceCast(Inst.getOperand(1), PtrTy);
            Type* Tys[] = { Inst.getOperand(0)->getType(),
                           PtrCast->getType() };
            Value* Args[] = { Inst.getOperand(0), PtrCast };
            Function* Fn = GenXIntrinsic::getGenXDeclaration(
              F.getParent(), GenXIntrinsic::genx_vstore, Tys);
            Builder.CreateCall(Fn, Args, Inst.getName());
          }
          else {
            auto PtrTy = cast<PointerType>(Inst.getOperand(0)->getType());
            PtrTy = PointerType::get(PtrTy->getElementType(), AS1);
            auto PtrCast = Builder.CreateAddrSpaceCast(Inst.getOperand(0), PtrTy);
            Type* Tys[] = { Inst.getType(), PtrCast->getType() };
            Function* Fn = GenXIntrinsic::getGenXDeclaration(
              F.getParent(), GenXIntrinsic::genx_vload, Tys);
            Value* VLoad = Builder.CreateCall(Fn, PtrCast, Inst.getName());
            Inst.replaceAllUsesWith(VLoad);
          }
          ToErase.push_back(&Inst);
        }
        // NOTE(review): when AS0 == AS1 the intrinsic is deliberately left
        // in place — only mismatched address spaces are rewritten here.
      }
    }
  }

  // Erase after the scan: instructions() must not be invalidated mid-walk.
  for (auto Inst : ToErase) {
    Inst->eraseFromParent();
  }

  return !ToErase.empty();
}
1909 
isBitCastForLifetimeMarker(Value * V)1910 static bool isBitCastForLifetimeMarker(Value *V) {
1911   if (!V || !isa<BitCastInst>(V))
1912     return false;
1913   for (auto U : V->users()) {
1914     unsigned IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(U);
1915     if (IntrinsicID != Intrinsic::lifetime_start &&
1916         IntrinsicID != Intrinsic::lifetime_end)
1917       return false;
1918   }
1919   return true;
1920 }
1921 
// Check whether two values are bitwise identical.
//
// Conservative: only proves identity for the same SSA value (modulo a
// single bitcast) or for two vloads of the same local address with no
// intervening vstore; otherwise returns false ("cannot prove").
static bool isBitwiseIdentical(Value *V1, Value *V2) {
  IGC_ASSERT_MESSAGE(V1, "null value");
  IGC_ASSERT_MESSAGE(V2, "null value");
  if (V1 == V2)
    return true;
  // Look through one level of bitcast on either side.
  if (BitCastInst *BI = dyn_cast<BitCastInst>(V1))
    V1 = BI->getOperand(0);
  if (BitCastInst *BI = dyn_cast<BitCastInst>(V2))
    V2 = BI->getOperand(0);

  // Special case arises from vload/vstore.
  if (GenXIntrinsic::isVLoad(V1) && GenXIntrinsic::isVLoad(V2)) {
    auto L1 = cast<CallInst>(V1);
    auto L2 = cast<CallInst>(V2);
    // Check if loading from the same location.
    if (L1->getOperand(0) != L2->getOperand(0))
      return false;

    // Check if this pointer is local and only used in vload/vstore.
    Value *Addr = L1->getOperand(0);
    if (!isa<AllocaInst>(Addr))
      return false;
    for (auto UI : Addr->users()) {
      if (isa<BitCastInst>(UI)) {
        // Bitcasts are tolerated only when used purely for lifetime markers.
        for (auto U : UI->users()) {
          unsigned IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(U);
          if (IntrinsicID != Intrinsic::lifetime_start &&
              IntrinsicID != Intrinsic::lifetime_end)
            return false;
        }
      } else {
        if (!GenXIntrinsic::isVLoadStore(UI))
          return false;
      }
    }

    // Check if there is no store to the same location in between.
    // Both loads must be in one block so a linear scan between them works.
    if (L1->getParent() != L2->getParent())
      return false;
    // Find whichever of L1/L2 comes first in the block...
    BasicBlock::iterator I = L1->getParent()->begin();
    for (; &*I != L1 && &*I != L2; ++I)
      /*empty*/;
    IGC_ASSERT(&*I == L1 || &*I == L2);
    auto IEnd = (&*I == L1) ? L2->getIterator() : L1->getIterator();
    // ...then reject if any vstore to Addr lies between the two loads.
    for (; I != IEnd; ++I) {
      Instruction *Inst = &*I;
      if (GenXIntrinsic::isVStore(Inst) && Inst->getOperand(1) == Addr)
        return false;
    }

    // OK.
    return true;
  }

  // Cannot prove.
  return false;
}
1980 
// Match the copy-in/copy-out pattern on Alloca (see struct comment above).
// On success, fills CopyInStore/CopyInRegion, CopyOutLoad (and possibly
// CopyOutRegion), plus the full VLoads/VStores lists, and returns true.
bool ArgRefPattern::match(DominatorTree &DT, PostDominatorTree &PDT) {
  IGC_ASSERT(Alloca);
  if (Alloca->use_empty())
    return false;

  // check if all users are load/store.
  SmallVector<CallInst *, 8> Loads;
  SmallVector<CallInst *, 8> Stores;
  for (auto U : Alloca->users())
    if (GenXIntrinsic::isVLoad(U))
      Loads.push_back(cast<CallInst>(U));
    else if (GenXIntrinsic::isVStore(U))
      Stores.push_back(cast<CallInst>(U));
    else if (isBitCastForLifetimeMarker(U))
      continue;
    else
      return false;

  if (Loads.empty() || Stores.empty())
    return false;

  // find a unique store that dominates all other users if exists.
  auto Cmp = [&](CallInst *L, CallInst *R) { return DT.dominates(L, R); };
  CopyInStore = *std::min_element(Stores.begin(), Stores.end(), Cmp);
  // The copy-in store must be fed by a rdregion whose only use is this store.
  CopyInRegion = dyn_cast<CallInst>(CopyInStore->getArgOperand(0));
  if (!CopyInRegion || !CopyInRegion->hasOneUse() || !GenXIntrinsic::isRdRegion(CopyInRegion))
    return false;

  // Verify the chosen store really dominates every other access.
  // NOTE(review): the `LI != CopyInStore` check below compares a load
  // against a store call, so it looks trivially true — presumably the
  // intent is just the dominance check; confirm against upstream.
  for (auto SI : Stores)
    if (SI != CopyInStore && !Cmp(CopyInStore, SI))
      return false;
  for (auto LI : Loads)
    if (LI != CopyInStore && !Cmp(CopyInStore, LI))
      return false;

  // find a unique load that post-dominates all other users if exists.
  auto PostCmp = [&](CallInst *L, CallInst *R) {
      BasicBlock *LBB = L->getParent();
      BasicBlock *RBB = R->getParent();
      if (LBB != RBB)
          return PDT.dominates(LBB, RBB);

      // Loop through the basic block until we find L or R.
      BasicBlock::const_iterator I = LBB->begin();
      for (; &*I != L && &*I != R; ++I)
          /*empty*/;

      // L post-dominates R within a block iff R comes first.
      return &*I == R;
  };
  CopyOutLoad = *std::min_element(Loads.begin(), Loads.end(), PostCmp);

  // Expect copy-out load has one or zero use. It is possible there
  // is no use as the region becomes dead after this subroutine call.
  //
  if (!CopyOutLoad->use_empty()) {
    if (!CopyOutLoad->hasOneUse())
      return false;
    CopyOutRegion = dyn_cast<CallInst>(CopyOutLoad->user_back());
    if (!GenXIntrinsic::isWrRegion(CopyOutRegion))
      return false;
  }

  // Verify the chosen load really post-dominates every other access.
  for (auto SI : Stores)
    if (SI != CopyOutLoad && !PostCmp(CopyOutLoad, SI))
      return false;
  for (auto LI : Loads)
    if (LI != CopyOutLoad && !PostCmp(CopyOutLoad, LI))
      return false;

  // Ensure read-in and write-out to the same region. It is possible that region
  // collasping does not simplify region accesses completely.
  // Probably we should use an assertion statement on region descriptors.
  if (CopyOutRegion &&
      !isBitwiseIdentical(CopyInRegion->getOperand(0),
                          CopyOutRegion->getOperand(0)))
    return false;

  // It should be OK to rewrite all loads and stores into the argref.
  VLoads.swap(Loads);
  VStores.swap(Stores);
  return true;
}
2063 
// Rewrite the matched pattern: materialize the base region in a new alloca
// and turn every vload/vstore on the arg alloca into a rdregion/wrregion
// against that base alloca. Must only be called after match() succeeds.
void ArgRefPattern::process(DominatorTree &DT) {
  // 'Spill' the base region into memory during rewriting.
  IRBuilder<> Builder(Alloca);
  Function *RdFn = CopyInRegion->getCalledFunction();
  IGC_ASSERT(RdFn);
  // The base alloca holds the full source vector (rdregion's input type).
  Type *BaseAllocaTy = RdFn->getFunctionType()->getParamType(0);
  AllocaInst *BaseAlloca = Builder.CreateAlloca(BaseAllocaTy, nullptr,
                                                Alloca->getName() + ".refprom");

  // Copy-in: store the base vector right before the original rdregion.
  Builder.SetInsertPoint(CopyInRegion);
  Builder.CreateStore(CopyInRegion->getArgOperand(0), BaseAlloca);

  // Copy-out: feed the (possibly updated) base vector into the wrregion.
  if (CopyOutRegion) {
    Builder.SetInsertPoint(CopyOutRegion);
    CopyOutRegion->setArgOperand(
        0, Builder.CreateLoad(BaseAlloca->getType()->getPointerElementType(),
                              BaseAlloca));
  }

  // Rewrite all stores.
  for (auto ST : VStores) {
    Builder.SetInsertPoint(ST);
    Value *OldVal = Builder.CreateLoad(
        BaseAlloca->getType()->getPointerElementType(), BaseAlloca);
    // Always use copy-in region arguments as copy-out region
    // arguments do not dominate this store.
    auto M = ST->getParent()->getParent()->getParent();
    Value *Args[] = {OldVal,
                     ST->getArgOperand(0),
                     CopyInRegion->getArgOperand(1), // vstride
                     CopyInRegion->getArgOperand(2), // width
                     CopyInRegion->getArgOperand(3), // hstride
                     CopyInRegion->getArgOperand(4), // offset
                     CopyInRegion->getArgOperand(5), // parent width
                     ConstantInt::getTrue(Type::getInt1Ty(M->getContext()))};
    // Pick the float or int wrregion variant to match the element type.
    auto ID = OldVal->getType()->isFPOrFPVectorTy() ? GenXIntrinsic::genx_wrregionf
                                                    : GenXIntrinsic::genx_wrregioni;
    Type *Tys[] = {Args[0]->getType(), Args[1]->getType(), Args[5]->getType(),
                   Args[7]->getType()};
    auto WrFn = GenXIntrinsic::getGenXDeclaration(M, ID, Tys);
    Value *NewVal = Builder.CreateCall(WrFn, Args);
    Builder.CreateStore(NewVal, BaseAlloca);
    ST->eraseFromParent();
  }

  // Rewrite all loads
  for (auto LI : VLoads) {
    // Dead loads (e.g. the copy-out load when its region died) are skipped
    // here and left for later cleanup.
    if (LI->use_empty())
      continue;

    Builder.SetInsertPoint(LI);
    Value *SrcVal = Builder.CreateLoad(
        BaseAlloca->getType()->getPointerElementType(), BaseAlloca);
    // Re-read the same region from the base alloca's current value.
    SmallVector<Value *, 8> Args(CopyInRegion->arg_operands());
    Args[0] = SrcVal;
    Value *Val = Builder.CreateCall(RdFn, Args);
    LI->replaceAllUsesWith(Val);
    LI->eraseFromParent();
  }
  // BaseAlloca created manually, w/o RAUW, need fix debug-info for it
  llvm::replaceAllDbgUsesWith(*Alloca, *BaseAlloca, *BaseAlloca, DT);
}
2126 
2127 // Allocas that are used in reference argument passing may be promoted into the
2128 // base region.
promoteAllocas(Function & F)2129 bool CMLowerVLoadVStore::promoteAllocas(Function &F) {
2130   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2131   auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
2132   bool Modified = false;
2133 
2134   SmallVector<AllocaInst *, 8> Allocas;
2135   for (auto &Inst : F.front().getInstList()) {
2136     if (auto AI = dyn_cast<AllocaInst>(&Inst))
2137       Allocas.push_back(AI);
2138   }
2139 
2140   for (auto AI : Allocas) {
2141     ArgRefPattern ArgRef(AI);
2142     if (ArgRef.match(DT, PDT)) {
2143       ArgRef.process(DT);
2144       Modified = true;
2145     }
2146   }
2147 
2148   return Modified;
2149 }
2150 
// Factory used by the pass manager to create the CMLowerVLoadVStore pass.
Pass *llvm::createCMLowerVLoadVStorePass() { return new CMLowerVLoadVStore; }
2152