1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2017-2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 //===----------------------------------------------------------------------===//
10 //
11 /// CMABI
12 /// -----
13 ///
14 /// This pass fixes ABI issues for the genx backend. Currently, it
15 ///
16 /// - transforms pass by pointer argument into copy-in and copy-out;
17 ///
18 /// - localizes global scalar or vector variables into copy-in and copy-out;
19 ///
/// - passes bool arguments as i8 (matches cm-icl's behavior).
21 ///
22 //===----------------------------------------------------------------------===//
23
24 #define DEBUG_TYPE "cmabi"
25
26 #include "llvmWrapper/Analysis/CallGraph.h"
27 #include "llvmWrapper/IR/CallSite.h"
28 #include "llvmWrapper/IR/DerivedTypes.h"
29 #include "llvmWrapper/IR/Instructions.h"
30 #include "llvmWrapper/Support/Alignment.h"
31
32 #include "Probe/Assertion.h"
33
34 #include "vc/GenXOpts/GenXOpts.h"
35 #include "vc/GenXOpts/Utils/KernelInfo.h"
36 #include "vc/Support/BackendConfig.h"
37 #include "vc/Utils/GenX/BreakConst.h"
38 #include "vc/Utils/GenX/Printf.h"
39 #include "vc/Utils/General/DebugInfo.h"
40 #include "vc/Utils/General/FunctionAttrs.h"
41 #include "vc/Utils/General/InstRebuilder.h"
42 #include "vc/Utils/General/STLExtras.h"
43 #include "vc/Utils/General/Types.h"
44
45 #include "llvm/ADT/DenseMap.h"
46 #include "llvm/ADT/PostOrderIterator.h"
47 #include "llvm/ADT/SCCIterator.h"
48 #include "llvm/ADT/STLExtras.h"
49 #include "llvm/ADT/SetVector.h"
50 #include "llvm/ADT/Statistic.h"
51 #include "llvm/Analysis/CallGraphSCCPass.h"
52 #include "llvm/Analysis/PostDominators.h"
53 #include "llvm/GenXIntrinsics/GenXIntrinsics.h"
54 #include "llvm/GenXIntrinsics/GenXMetadata.h"
55 #include "llvm/IR/CFG.h"
56 #include "llvm/IR/DebugInfo.h"
57 #include "llvm/IR/DiagnosticInfo.h"
58 #include "llvm/IR/DiagnosticPrinter.h"
59 #include "llvm/IR/Dominators.h"
60 #include "llvm/IR/Function.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/IRBuilder.h"
63 #include "llvm/IR/InstIterator.h"
64 #include "llvm/IR/InstVisitor.h"
65 #include "llvm/IR/IntrinsicInst.h"
66 #include "llvm/IR/Intrinsics.h"
67 #include "llvm/IR/Module.h"
68 #include "llvm/IR/Verifier.h"
69 #include "llvm/InitializePasses.h"
70 #include "llvm/Pass.h"
71 #include "llvm/Support/CommandLine.h"
72 #include "llvm/Support/Debug.h"
73 #include "llvm/Support/raw_ostream.h"
74 #include "llvm/Transforms/Scalar.h"
75 #include "llvm/Transforms/Utils/Local.h"
76
77 #include <algorithm>
78 #include <functional>
79 #include <iterator>
80 #include <numeric>
81 #include <stack>
82 #include <unordered_map>
83 #include <unordered_set>
84 #include <vector>
85
86 using namespace llvm;
87
88 STATISTIC(NumArgumentsTransformed, "Number of pointer arguments transformed");
89
90 namespace llvm {
91 void initializeCMABIAnalysisPass(PassRegistry &);
92 void initializeCMABIPass(PassRegistry &);
93 void initializeCMLowerVLoadVStorePass(PassRegistry &);
94 }
95
96 /// Localizing global variables
97 /// ^^^^^^^^^^^^^^^^^^^^^^^^^^^
98 ///
99 /// General idea of localizing global variables into locals. Globals used in
/// different kernels get a separate copy and they are always invisible to
101 /// other kernels and we can safely localize all globals used (including
102 /// indirectly) in a kernel. For example,
103 ///
104 /// .. code-block:: text
105 ///
106 /// @gv1 = global <8 x float> zeroinitializer, align 32
107 /// @gv2 = global <8 x float> zeroinitializer, align 32
108 /// @gv3 = global <8 x float> zeroinitializer, align 32
109 ///
110 /// define dllexport void @f0() {
111 /// call @f1()
112 /// call @f2()
113 /// call @f3()
114 /// }
115 ///
116 /// define internal void @f1() {
117 /// ; ...
118 /// store <8 x float> %splat1, <8 x float>* @gv1, align 32
119 /// }
120 ///
121 /// define internal void @f2() {
122 /// ; ...
123 /// store <8 x float> %splat2, <8 x float>* @gv2, align 32
124 /// }
125 ///
126 /// define internal void @f3() {
127 /// %1 = <8 x float>* @gv1, align 32
128 /// %2 = <8 x float>* @gv2, align 32
129 /// %3 = fadd <8 x float> %1, <8 x float> %2
130 /// store <8 x float> %3, <8 x float>* @gv3, align 32
131 /// }
132 ///
133 /// will be transformed into
134 ///
135 /// .. code-block:: text
136 ///
137 /// define dllexport void @f0() {
138 /// %v1 = alloca <8 x float>, align 32
139 /// %v2 = alloca <8 x float>, align 32
140 /// %v3 = alloca <8 x float>, align 32
141 ///
142 /// %0 = load <8 x float> * %v1, align 32
143 /// %1 = { <8 x float> } call @f1_transformed(<8 x float> %0)
144 /// %2 = extractvalue { <8 x float> } %1, 0
145 /// store <8 x float> %2, <8 x float>* %v1, align 32
146 ///
147 /// %3 = load <8 x float> * %v2, align 32
148 /// %4 = { <8 x float> } call @f2_transformed(<8 x float> %3)
149 /// %5 = extractvalue { <8 x float> } %4, 0
150 /// store <8 x float> %5, <8 x float>* %v1, align 32
151 ///
152 /// %6 = load <8 x float> * %v1, align 32
153 /// %7 = load <8 x float> * %v2, align 32
154 /// %8 = load <8 x float> * %v3, align 32
155 ///
156 /// %9 = { <8 x float>, <8 x float>, <8 x float> }
157 /// call @f3_transformed(<8 x float> %6, <8 x float> %7, <8 x float> %8)
158 ///
159 /// %10 = extractvalue { <8 x float>, <8 x float>, <8 x float> } %9, 0
160 /// store <8 x float> %10, <8 x float>* %v1, align 32
161 /// %11 = extractvalue { <8 x float>, <8 x float>, <8 x float> } %9, 1
162 /// store <8 x float> %11, <8 x float>* %v2, align 32
163 /// %12 = extractvalue { <8 x float>, <8 x float>, <8 x float> } %9, 2
164 /// store <8 x float> %12, <8 x float>* %v3, align 32
165 /// }
166 ///
167 /// All callees will be updated accordingly, E.g. f1_transformed becomes
168 ///
169 /// .. code-block:: text
170 ///
171 /// define internal { <8 x float> } @f1_transformed(<8 x float> %v1) {
172 /// %0 = alloca <8 x float>, align 32
173 /// store <8 x float> %v1, <8 x float>* %0, align 32
174 /// ; ...
175 /// store <8 x float> %splat1, <8 x float>* @0, align 32
176 /// ; ...
177 /// %1 = load <8 x float>* %0, align 32
178 /// %2 = insertvalue { <8 x float> } undef, <8 x float> %1, 0
179 /// ret { <8 x float> } %2
180 /// }
181 ///
182 namespace {
183
184 // \brief Collect necessary information for global variable localization.
185 class LocalizationInfo {
186 public:
187 typedef SetVector<GlobalVariable *> GlobalSetTy;
188
LocalizationInfo(Function * F)189 explicit LocalizationInfo(Function *F) : Fn(F) {}
LocalizationInfo()190 LocalizationInfo() : Fn(0) {}
191
getFunction() const192 Function *getFunction() const { return Fn; }
empty() const193 bool empty() const { return Globals.empty(); }
getGlobals()194 GlobalSetTy &getGlobals() { return Globals; }
195
196 // \brief Add a global.
addGlobal(GlobalVariable * GV)197 void addGlobal(GlobalVariable *GV) {
198 Globals.insert(GV);
199 }
200
201 // \brief Add all globals from callee.
addGlobals(LocalizationInfo & LI)202 void addGlobals(LocalizationInfo &LI) {
203 Globals.insert(LI.getGlobals().begin(), LI.getGlobals().end());
204 }
205
206 private:
207 // \brief The function being analyzed.
208 Function *Fn;
209
210 // \brief Global variables that are used directly or indirectly.
211 GlobalSetTy Globals;
212 };
213
// Diagnostic information for error/warning for overlapping arg.
// Reported through LLVM's diagnostic machinery so frontends can surface the
// message with source location info.
class DiagnosticInfoOverlappingArgs : public DiagnosticInfo {
private:
  std::string Description;
  StringRef Filename;
  unsigned Line;
  unsigned Col;
  // Plugin diagnostic kind id; 0 means "not allocated yet".
  static int KindID;

  // Lazily allocate the plugin diagnostic kind on first use.
  static int getKindID() {
    if (KindID == 0)
      KindID = llvm::getNextAvailablePluginDiagnosticKind();
    return KindID;
  }

public:
  // Initialize from an Instruction and an Argument.
  DiagnosticInfoOverlappingArgs(Instruction *Inst,
      const Twine &Desc, DiagnosticSeverity Severity = DS_Error);
  void print(DiagnosticPrinter &DP) const override;

  // RTTI support for llvm::isa/dyn_cast on DiagnosticInfo.
  static bool classof(const DiagnosticInfo *DI) {
    return DI->getKind() == getKindID();
  }
};
int DiagnosticInfoOverlappingArgs::KindID = 0;
238
239 class CMABIAnalysis : public ModulePass {
240 // This map captures all global variables to be localized.
241 std::vector<LocalizationInfo *> LocalizationInfoObjs;
242 bool SaveStackCallLinkage = false;
243
244 public:
245 static char ID;
246
247 // Kernels in the module being processed.
248 SmallPtrSet<Function *, 8> Kernels;
249
250 // Map from function to the index of its LI in LI storage
251 SmallDenseMap<Function *, LocalizationInfo *> GlobalInfo;
252
253 // Function control option if any
254 FunctionControl FCtrl;
255
CMABIAnalysis()256 CMABIAnalysis() : ModulePass{ID} {}
257
getAnalysisUsage(AnalysisUsage & AU) const258 void getAnalysisUsage(AnalysisUsage &AU) const override {
259 AU.addRequired<CallGraphWrapperPass>();
260 AU.addRequired<GenXBackendConfig>();
261 AU.setPreservesAll();
262 }
263
getPassName() const264 StringRef getPassName() const override { return "GenX CMABI analysis"; }
265
266 bool runOnModule(Module &M) override;
267
releaseMemory()268 void releaseMemory() override {
269 for (auto *LI : LocalizationInfoObjs)
270 delete LI;
271 LocalizationInfoObjs.clear();
272 Kernels.clear();
273 GlobalInfo.clear();
274 }
275
276 // \brief Returns the localization info associated to a function.
getLocalizationInfo(Function * F)277 LocalizationInfo &getLocalizationInfo(Function *F) {
278 if (GlobalInfo.count(F))
279 return *GlobalInfo[F];
280 LocalizationInfo *LI = new LocalizationInfo{F};
281 LocalizationInfoObjs.push_back(LI);
282 GlobalInfo[F] = LI;
283 return *LI;
284 }
285
286 private:
287 bool runOnCallGraph(CallGraph &CG);
288 void analyzeGlobals(CallGraph &CG);
289
addDirectGlobal(Function * F,GlobalVariable * GV)290 void addDirectGlobal(Function *F, GlobalVariable *GV) {
291 getLocalizationInfo(F).addGlobal(GV);
292 }
293
294 // \brief Add all globals from callee to caller.
addIndirectGlobal(Function * F,Function * Callee)295 void addIndirectGlobal(Function *F, Function *Callee) {
296 getLocalizationInfo(F).addGlobals(getLocalizationInfo(Callee));
297 }
298
299 void defineGVDirectUsers(GlobalVariable &GV);
300 };
301
// The CMABI transformation pass proper. Runs bottom-up over the call graph
// SCCs, rewriting function prototypes (pointer args -> copy-in/copy-out,
// i1 -> i8 for kernels) and localizing globals using the results of
// CMABIAnalysis.
struct CMABI : public CallGraphSCCPass {
  static char ID;

  CMABI() : CallGraphSCCPass(ID) {
    initializeCMABIPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
    AU.addRequired<CMABIAnalysis>();
  }

  bool runOnSCC(CallGraphSCC &SCC) override;

private:

  CallGraphNode *ProcessNode(CallGraphNode *CGN);

  // Fix argument passing for kernels.
  CallGraphNode *TransformKernel(Function *F);

  // Major work is done in this method.
  CallGraphNode *TransformNode(Function &F,
                               SmallPtrSet<Argument *, 8> &ArgsToTransform,
                               LocalizationInfo &LI);

  // \brief Create allocas for globals and replace their uses.
  void LocalizeGlobals(LocalizationInfo &LI);

  // Return true if pointer type argument is only used to
  // load or store a simple value. This helps decide whether
  // it is safe to convert ptr arg to by-value arg or
  // simple-value copy-in-copy-out.
  bool OnlyUsedBySimpleValueLoadStore(Value *Arg);

  // \brief Diagnose illegal overlapping by-ref args.
  void diagnoseOverlappingArgs(CallInst *CI);

  // Already visited functions.
  SmallPtrSet<Function *, 8> AlreadyVisited;
  // Analysis results; set at the start of each runOnSCC invocation.
  CMABIAnalysis *Info;
};
344
345 } // namespace
346
// Pass registration boilerplate for the analysis pass.
char CMABIAnalysis::ID = 0;
INITIALIZE_PASS_BEGIN(CMABIAnalysis, "cmabi-analysis",
                      "helper analysis pass to get info for CMABI", false, true)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GenXBackendConfig)
INITIALIZE_PASS_END(CMABIAnalysis, "cmabi-analysis",
                    "Fix ABI issues for the genx backend", false, true)
354
355 bool CMABIAnalysis::runOnModule(Module &M) {
356 auto &&BCfg = getAnalysis<GenXBackendConfig>();
357 FCtrl = BCfg.getFCtrl();
358 SaveStackCallLinkage = BCfg.saveStackCallLinkage();
359
360 runOnCallGraph(getAnalysis<CallGraphWrapperPass>().getCallGraph());
361 return false;
362 }
363
runOnCallGraph(CallGraph & CG)364 bool CMABIAnalysis::runOnCallGraph(CallGraph &CG) {
365 // Analyze global variable usages and for each function attaches global
366 // variables to be copy-in and copy-out.
367 analyzeGlobals(CG);
368
369 auto getValue = [](Metadata *M) -> Value * {
370 if (auto VM = dyn_cast<ValueAsMetadata>(M))
371 return VM->getValue();
372 return nullptr;
373 };
374
375 // Collect all CM kernels from named metadata.
376 if (NamedMDNode *Named =
377 CG.getModule().getNamedMetadata(genx::FunctionMD::GenXKernels)) {
378 IGC_ASSERT(Named);
379 for (unsigned I = 0, E = Named->getNumOperands(); I != E; ++I) {
380 MDNode *Node = Named->getOperand(I);
381 if (Function *F =
382 dyn_cast_or_null<Function>(getValue(Node->getOperand(0))))
383 Kernels.insert(F);
384 }
385 }
386
387 // no change.
388 return false;
389 }
390
runOnSCC(CallGraphSCC & SCC)391 bool CMABI::runOnSCC(CallGraphSCC &SCC) {
392 Info = &getAnalysis<CMABIAnalysis>();
393 bool Changed = false;
394 bool LocalChange;
395
396 // Diagnose overlapping by-ref args.
397 for (auto i = SCC.begin(), e = SCC.end(); i != e; ++i) {
398 Function *F = (*i)->getFunction();
399 if (!F || F->empty())
400 continue;
401 for (auto ui = F->use_begin(), ue = F->use_end(); ui != ue; ++ui) {
402 auto CI = dyn_cast<CallInst>(ui->getUser());
403 if (CI && CI->getNumArgOperands() == ui->getOperandNo())
404 diagnoseOverlappingArgs(CI);
405 }
406 }
407
408 // Iterate until we stop transforming from this SCC.
409 do {
410 LocalChange = false;
411 for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) {
412 if (CallGraphNode *CGN = ProcessNode(*I)) {
413 LocalChange = true;
414 SCC.ReplaceNode(*I, CGN);
415 }
416 }
417 Changed |= LocalChange;
418 } while (LocalChange);
419
420 return Changed;
421 }
422
423 // Replaces uses of global variables with the corresponding allocas inside a
424 // specified function. More insts can be rebuild if global variable addrspace
425 // wasn't private.
replaceUsesWithinFunction(const SmallDenseMap<Value *,Value * > & GlobalsToReplace,Function * F)426 static void replaceUsesWithinFunction(
427 const SmallDenseMap<Value *, Value *> &GlobalsToReplace, Function *F) {
428 for (auto &BB : *F) {
429 for (auto &Inst : BB) {
430 for (unsigned i = 0, e = Inst.getNumOperands(); i < e; ++i) {
431 Value *Op = Inst.getOperand(i);
432 auto Iter = GlobalsToReplace.find(Op);
433 if (Iter != GlobalsToReplace.end()) {
434 IGC_ASSERT_MESSAGE(Op->getType() == Iter->second->getType(),
435 "only global variables in private addrspace are "
436 "localized, so types must match");
437 Inst.setOperand(i, Iter->second);
438 }
439 }
440 }
441 }
442 }
443
444 // \brief Create allocas for globals directly used in this kernel and
445 // replace all uses.
446 //
447 // FIXME: it is not always posible to localize globals with addrspace different
448 // from private. In some cases type info link is lost - casts, stores of
449 // pointers.
LocalizeGlobals(LocalizationInfo & LI)450 void CMABI::LocalizeGlobals(LocalizationInfo &LI) {
451 const LocalizationInfo::GlobalSetTy &Globals = LI.getGlobals();
452 typedef LocalizationInfo::GlobalSetTy::const_iterator IteratorTy;
453
454 SmallDenseMap<Value *, Value *> GlobalsToReplace;
455 Function *Fn = LI.getFunction();
456 for (IteratorTy I = Globals.begin(), E = Globals.end(); I != E; ++I) {
457 GlobalVariable *GV = (*I);
458 LLVM_DEBUG(dbgs() << "Localizing global: " << *GV << "\n ");
459
460 Instruction &FirstI = *Fn->getEntryBlock().begin();
461 Type *ElemTy = GV->getType()->getElementType();
462 AllocaInst *Alloca = new AllocaInst(ElemTy, vc::AddrSpace::Private,
463 GV->getName() + ".local", &FirstI);
464
465 if (GV->getAlignment())
466 Alloca->setAlignment(IGCLLVM::getCorrectAlign(GV->getAlignment()));
467
468 if (!isa<UndefValue>(GV->getInitializer()))
469 new StoreInst(GV->getInitializer(), Alloca, &FirstI);
470
471 vc::DIBuilder::createDbgDeclareForLocalizedGlobal(*Alloca, *GV, FirstI);
472
473 GlobalsToReplace.insert(std::make_pair(GV, Alloca));
474 }
475
476 // Replaces all globals uses within this function.
477 replaceUsesWithinFunction(GlobalsToReplace, Fn);
478 }
479
// Process a single call graph node. Returns the new CallGraphNode if the
// function was rewritten with a new prototype, or null when nothing changed.
CallGraphNode *CMABI::ProcessNode(CallGraphNode *CGN) {
  Function *F = CGN->getFunction();

  // Nothing to do for declarations or already visited functions.
  if (!F || F->isDeclaration() || AlreadyVisited.count(F))
    return 0;

  // Constant expressions would hide uses of globals/args; break them into
  // instructions before analysis.
  vc::breakConstantExprs(F, vc::LegalizationStage::NotLegalized);

  // Variables to be localized.
  LocalizationInfo &LI = Info->getLocalizationInfo(F);

  // This is a kernel.
  if (Info->Kernels.count(F)) {
    // Localize globals for kernels.
    if (!LI.getGlobals().empty())
      LocalizeGlobals(LI);

    // Check whether there are i1 or vxi1 kernel arguments.
    for (auto AI = F->arg_begin(), AE = F->arg_end(); AI != AE; ++AI)
      if (AI->getType()->getScalarType()->isIntegerTy(1))
        return TransformKernel(F);

    // No changes to this kernel's prototype.
    return 0;
  }

  // Have to localize implicit arg globals in functions with fixed signature.
  // FIXME: There's no verification that globals are for implicit args. General
  //        private globals may be localized here, but it is not possible to
  //        use them in such functions at all. A nice place for diagnostics.
  if (vc::isFixedSignatureFunc(*F)) {
    if (!LI.getGlobals().empty())
      LocalizeGlobals(LI);
    return nullptr;
  }

  SmallVector<Argument*, 16> PointerArgs;
  for (auto &Arg: F->args())
    if (Arg.getType()->isPointerTy())
      PointerArgs.push_back(&Arg);

  // Check if there is any pointer arguments or globals to localize.
  if (PointerArgs.empty() && LI.empty())
    return 0;

  // Check transformable arguments: only int/fp scalars and vectors are
  // candidates for copy-in/copy-out.
  SmallPtrSet<Argument*, 8> ArgsToTransform;
  for (Argument *PtrArg: PointerArgs) {
    Type *ArgTy = cast<PointerType>(PtrArg->getType())->getElementType();
    // Only transform to simple types.
    if ((ArgTy->isVectorTy() || OnlyUsedBySimpleValueLoadStore(PtrArg)) &&
        (ArgTy->isIntOrIntVectorTy() || ArgTy->isFPOrFPVectorTy()))
      ArgsToTransform.insert(PtrArg);
  }

  if (ArgsToTransform.empty() && LI.empty())
    return 0;

  return TransformNode(*F, ArgsToTransform, LI);
}
541
542 // Returns true if data is only read using load-like intrinsics. The result may
543 // be false negative.
isSinkedToLoadIntrinsics(const Instruction * Inst)544 static bool isSinkedToLoadIntrinsics(const Instruction *Inst) {
545 if (isa<CallInst>(Inst)) {
546 auto *CI = cast<CallInst>(Inst);
547 auto IID = GenXIntrinsic::getAnyIntrinsicID(CI->getCalledFunction());
548 return IID == GenXIntrinsic::genx_svm_gather ||
549 IID == GenXIntrinsic::genx_gather_scaled;
550 }
551 return std::all_of(Inst->user_begin(), Inst->user_end(), [](const User *U) {
552 if (isa<InsertElementInst>(U) || isa<ShuffleVectorInst>(U) ||
553 isa<BinaryOperator>(U) || isa<CallInst>(U))
554 return isSinkedToLoadIntrinsics(cast<Instruction>(U));
555 return false;
556 });
557 }
558
559 // Arg is a ptr to a vector type. If data is only read using load, then false is
560 // returned. Otherwise, or if it is not clear, true is returned. This is a
561 // recursive function. The result may be false positive.
isPtrArgModified(const Value & Arg)562 static bool isPtrArgModified(const Value &Arg) {
563 // User iterator returns pointer both for star and arrow operators, because...
564 return std::any_of(Arg.user_begin(), Arg.user_end(), [](const User *U) {
565 if (isa<LoadInst>(U))
566 return false;
567 if (isa<AddrSpaceCastInst>(U) || isa<BitCastInst>(U) ||
568 isa<GetElementPtrInst>(U))
569 return isPtrArgModified(*U);
570 if (isa<PtrToIntInst>(U))
571 return !isSinkedToLoadIntrinsics(cast<Instruction>(U));
572 return true;
573 });
574 }
575
OnlyUsedBySimpleValueLoadStore(Value * Arg)576 bool CMABI::OnlyUsedBySimpleValueLoadStore(Value *Arg) {
577 for (const auto &U : Arg->users()) {
578 auto *I = dyn_cast<Instruction>(U);
579 if (!I)
580 return false;
581
582 if (auto LI = dyn_cast<LoadInst>(U)) {
583 if (Arg != LI->getPointerOperand())
584 return false;
585 }
586 else if (auto SI = dyn_cast<StoreInst>(U)) {
587 if (Arg != SI->getPointerOperand())
588 return false;
589 }
590 else if (auto GEP = dyn_cast<GetElementPtrInst>(U)) {
591 if (Arg != GEP->getPointerOperand())
592 return false;
593 else if (!GEP->hasAllZeroIndices())
594 return false;
595 if (!OnlyUsedBySimpleValueLoadStore(U))
596 return false;
597 }
598 else if (isa<AddrSpaceCastInst>(U) || isa<PtrToIntInst>(U)) {
599 if (!OnlyUsedBySimpleValueLoadStore(U))
600 return false;
601 }
602 else
603 return false;
604 }
605 return true;
606 }
607
// \brief Fix argument passing for kernels: i1 -> i8.
//
// Builds a new function with every i1/vxi1 parameter widened to i8/vxi8,
// splices the old body in, inserts trunc-to-bool conversions for the widened
// arguments, and rewires metadata and the call graph. Returns the call graph
// node of the replacement function.
CallGraphNode *CMABI::TransformKernel(Function *F) {
  IGC_ASSERT(F->getReturnType()->isVoidTy());
  LLVMContext &Context = F->getContext();

  AttributeList AttrVec;
  const AttributeList &PAL = F->getAttributes();

  // First, determine the new argument list
  SmallVector<Type *, 8> ArgTys;
  unsigned ArgIndex = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgIndex) {
    Type *ArgTy = I->getType();
    // Change i1 to i8 and vxi1 to vxi8
    if (ArgTy->getScalarType()->isIntegerTy(1)) {
      Type *Ty = IntegerType::get(F->getContext(), 8);
      if (ArgTy->isVectorTy())
        ArgTys.push_back(IGCLLVM::FixedVectorType::get(
            Ty, dyn_cast<IGCLLVM::FixedVectorType>(ArgTy)->getNumElements()));
      else
        ArgTys.push_back(Ty);
    } else {
      // Unchanged argument: carry its parameter attributes over to the same
      // position in the new signature.
      AttributeSet attrs = PAL.getParamAttributes(ArgIndex);
      if (attrs.hasAttributes()) {
        AttrBuilder B(attrs);
        AttrVec = AttrVec.addParamAttributes(Context, ArgTys.size(), B);
      }
      ArgTys.push_back(I->getType());
    }
  }

  FunctionType *NFTy = FunctionType::get(F->getReturnType(), ArgTys, false);
  IGC_ASSERT_MESSAGE((NFTy != F->getFunctionType()),
                     "type out of sync, expect bool arguments");

  // Add any function attributes.
  AttributeSet FnAttrs = PAL.getFnAttributes();
  if (FnAttrs.hasAttributes()) {
    AttrBuilder B(FnAttrs);
    AttrVec = AttrVec.addAttributes(Context, AttributeList::FunctionIndex, B);
  }

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());

  LLVM_DEBUG(dbgs() << "\nCMABI: Transforming From:" << *F);
  vc::transferNameAndCCWithNewAttr(AttrVec, *F, *NF);
  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  vc::transferDISubprogram(*F, *NF);
  LLVM_DEBUG(dbgs() << " --> To: " << *NF << "\n");

  // Since we have now created the new function, splice the body of the old
  // function right into the new function.
  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

  // Loop over the argument list, transferring uses of the old arguments over
  // to the new arguments, also transferring over the names as well.
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
                              I2 = NF->arg_begin();
       I != E; ++I, ++I2) {
    // For an unmodified argument, move the name and users over.
    if (!I->getType()->getScalarType()->isIntegerTy(1)) {
      I->replaceAllUsesWith(I2);
      I2->takeName(I);
    } else {
      // Widened bool argument: truncate the incoming i8/vxi8 value back to
      // the original i1/vxi1 type at the top of the entry block.
      Instruction *InsertPt = &*(NF->begin()->begin());
      Instruction *Conv = new TruncInst(I2, I->getType(), "tobool", InsertPt);
      I->replaceAllUsesWith(Conv);
      I2->takeName(I);
    }
  }

  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);

  // Update the metadata entry.
  if (F->hasDLLExportStorageClass())
    NF->setDLLStorageClass(F->getDLLStorageClass());

  genx::replaceFunctionRefMD(*F, *NF);

  // Now that the old function is dead, delete it. If there is a dangling
  // reference to the CallgraphNode, just leave the dead function around.
  NF_CGN->stealCalledFunctionsFrom(CG[F]);
  CallGraphNode *CGN = CG[F];
  if (CGN->getNumReferences() == 0)
    delete CG.removeFunctionFromModule(CGN);
  else
    F->setLinkage(Function::ExternalLinkage);

  return NF_CGN;
}
702
703 namespace {
704 struct TransformedFuncType {
705 SmallVector<Type *, 8> Ret;
706 SmallVector<Type *, 8> Args;
707 };
708
709 enum class ArgKind { General, CopyIn, CopyInOut };
710 enum class GlobalArgKind { ByValueIn, ByValueInOut, ByPointer };
711
712 struct GlobalArgInfo {
713 GlobalVariable *GV;
714 GlobalArgKind Kind;
715 };
716
717 struct GlobalArgsInfo {
718 static constexpr int UndefIdx = -1;
719 std::vector<GlobalArgInfo> Globals;
720 int FirstGlobalArgIdx = UndefIdx;
721
getGlobalInfoForArgNo__anon278d8c6d0511::GlobalArgsInfo722 GlobalArgInfo getGlobalInfoForArgNo(int ArgIdx) const {
723 IGC_ASSERT_MESSAGE(FirstGlobalArgIdx != UndefIdx,
724 "first global arg index isn't set");
725 auto Idx = ArgIdx - FirstGlobalArgIdx;
726 IGC_ASSERT_MESSAGE(Idx >= 0, "out of bound access");
727 IGC_ASSERT_MESSAGE(Idx < static_cast<int>(Globals.size()),
728 "out of bound access");
729 return Globals[ArgIdx - FirstGlobalArgIdx];
730 }
731
getGlobalForArgNo__anon278d8c6d0511::GlobalArgsInfo732 GlobalVariable *getGlobalForArgNo(int ArgIdx) const {
733 return getGlobalInfoForArgNo(ArgIdx).GV;
734 }
735 };
736
737 struct RetToArgInfo {
738 static constexpr int OrigRetNoArg = -1;
739 std::vector<int> Map;
740 };
741
742 // Whether provided \p GV should be passed by pointer.
passLocalizedGlobalByPointer(const GlobalValue & GV)743 static bool passLocalizedGlobalByPointer(const GlobalValue &GV) {
744 auto *Type = GV.getType()->getPointerElementType();
745 return Type->isAggregateType();
746 }
747
748 struct ParameterAttrInfo {
749 unsigned ArgIndex;
750 Attribute::AttrKind Attr;
751 };
752
// Computing a new prototype for the function. E.g.
//
// i32 @foo(i32, <8 x i32>*) becomes {i32, <8 x i32>} @bar(i32, <8 x i32>)
//
class TransformedFuncInfo {
  TransformedFuncType NewFuncType;
  AttributeList Attrs;
  std::vector<ArgKind> ArgKinds;
  std::vector<ParameterAttrInfo> DiscardedParameterAttrs;
  RetToArgInfo RetToArg;
  GlobalArgsInfo GlobalArgs;

public:
  // Build the transformed prototype for \p OrigFunc: classify arguments,
  // replace transformed pointer args with their pointee types, inherit
  // attributes, and set up the return-struct layout (original return value
  // first, then one slot per copy-in-out argument).
  TransformedFuncInfo(Function &OrigFunc,
                      SmallPtrSetImpl<Argument *> &ArgsToTransform) {
    FillCopyInOutInfo(OrigFunc, ArgsToTransform);
    std::transform(OrigFunc.arg_begin(), OrigFunc.arg_end(),
                   std::back_inserter(NewFuncType.Args),
                   [&ArgsToTransform](Argument &Arg) {
                     if (ArgsToTransform.count(&Arg))
                       return Arg.getType()->getPointerElementType();
                     return Arg.getType();
                   });
    InheritAttributes(OrigFunc);

    // struct-returns are not supported for transformed functions,
    // so we need to discard the attribute
    if (OrigFunc.hasStructRetAttr() && OrigFunc.hasLocalLinkage())
      DiscardStructRetAttr(OrigFunc.getContext());

    auto *OrigRetTy = OrigFunc.getFunctionType()->getReturnType();
    if (!OrigRetTy->isVoidTy()) {
      NewFuncType.Ret.push_back(OrigRetTy);
      RetToArg.Map.push_back(RetToArgInfo::OrigRetNoArg);
    }
    AppendRetCopyOutInfo();
  }

  // Append localized globals as extra arguments. Aggregates go by private
  // pointer; scalars/vectors go by value, with non-constant ones also added
  // to the return struct for copy-out. May only be called once.
  void AppendGlobals(LocalizationInfo &LI) {
    IGC_ASSERT_MESSAGE(GlobalArgs.FirstGlobalArgIdx == GlobalArgsInfo::UndefIdx,
                       "can only be initialized once");
    GlobalArgs.FirstGlobalArgIdx = NewFuncType.Args.size();
    for (auto *GV : LI.getGlobals()) {
      if (passLocalizedGlobalByPointer(*GV)) {
        NewFuncType.Args.push_back(vc::changeAddrSpace(
            cast<PointerType>(GV->getType()), vc::AddrSpace::Private));
        GlobalArgs.Globals.push_back({GV, GlobalArgKind::ByPointer});
      } else {
        int ArgIdx = NewFuncType.Args.size();
        Type *PointeeTy = GV->getType()->getPointerElementType();
        NewFuncType.Args.push_back(PointeeTy);
        if (GV->isConstant())
          GlobalArgs.Globals.push_back({GV, GlobalArgKind::ByValueIn});
        else {
          GlobalArgs.Globals.push_back({GV, GlobalArgKind::ByValueInOut});
          NewFuncType.Ret.push_back(PointeeTy);
          RetToArg.Map.push_back(ArgIdx);
        }
      }
    }
  }

  const TransformedFuncType &getType() const { return NewFuncType; }
  AttributeList getAttributes() const { return Attrs; }
  const std::vector<ArgKind> &getArgKinds() const { return ArgKinds; }
  const std::vector<ParameterAttrInfo> &getDiscardedParameterAttrs() const {
    return DiscardedParameterAttrs;
  }
  const GlobalArgsInfo &getGlobalArgsInfo() const { return GlobalArgs; }
  const RetToArgInfo &getRetToArgInfo() const { return RetToArg; }

private:
  // Classify every original argument: untouched args are General; args being
  // transformed are CopyInOut when written through, CopyIn otherwise.
  void FillCopyInOutInfo(Function &OrigFunc,
                         SmallPtrSetImpl<Argument *> &ArgsToTransform) {
    IGC_ASSERT_MESSAGE(ArgKinds.empty(),
                       "shouldn't be filled before this method");
    llvm::transform(OrigFunc.args(), std::back_inserter(ArgKinds),
                    [&ArgsToTransform](Argument &Arg) {
                      if (!ArgsToTransform.count(&Arg))
                        return ArgKind::General;
                      if (isPtrArgModified(Arg))
                        return ArgKind::CopyInOut;
                      return ArgKind::CopyIn;
                    });
  }

  // Copy attributes from the original function. Parameter attributes are
  // kept only for General args (transformed args change type).
  void InheritAttributes(Function &OrigFunc) {
    LLVMContext &Context = OrigFunc.getContext();
    const AttributeList &OrigAttrs = OrigFunc.getAttributes();

    // Inherit argument attributes
    for (auto ArgInfo : enumerate(ArgKinds)) {
      if (ArgInfo.value() == ArgKind::General) {
        AttributeSet ArgAttrs = OrigAttrs.getParamAttributes(ArgInfo.index());
        if (ArgAttrs.hasAttributes())
          Attrs = Attrs.addParamAttributes(Context, ArgInfo.index(),
                                           AttrBuilder{ArgAttrs});
      }
    }

    // Inherit function attributes.
    AttributeSet FnAttrs = OrigAttrs.getFnAttributes();
    if (FnAttrs.hasAttributes()) {
      AttrBuilder B(FnAttrs);
      Attrs = Attrs.addAttributes(Context, AttributeList::FunctionIndex, B);
    }
  }

  // Drop sret from every parameter that carries it, recording the removal in
  // DiscardedParameterAttrs for call-site fixup.
  void DiscardStructRetAttr(LLVMContext &Context) {
    constexpr auto SretAttr = Attribute::StructRet;
    for (auto ArgInfo : enumerate(ArgKinds)) {
      unsigned ParamIndex = ArgInfo.index();
      if (Attrs.hasParamAttr(ParamIndex, SretAttr)) {
        Attrs = Attrs.removeParamAttribute(Context, ParamIndex, SretAttr);
        DiscardedParameterAttrs.push_back({ParamIndex, SretAttr});
      }
    }
  }

  // Add one return-struct slot per CopyInOut argument, mapping it back to
  // the argument index it copies out.
  void AppendRetCopyOutInfo() {
    for (auto ArgInfo : enumerate(ArgKinds)) {
      if (ArgInfo.value() == ArgKind::CopyInOut) {
        NewFuncType.Ret.push_back(NewFuncType.Args[ArgInfo.index()]);
        RetToArg.Map.push_back(ArgInfo.index());
      }
    }
  }
};
880 } // namespace
881
getRetType(LLVMContext & Context,const TransformedFuncType & TFType)882 static Type *getRetType(LLVMContext &Context,
883 const TransformedFuncType &TFType) {
884 if (TFType.Ret.empty())
885 return Type::getVoidTy(Context);
886 return StructType::get(Context, TFType.Ret);
887 }
888
// Creates the declaration of the transformed function right before the
// original one, transferring name, calling convention, attributes and debug
// info subprogram. The body is not populated here.
Function *createTransformedFuncDecl(Function &OrigFunc,
                                    const TransformedFuncInfo &TFuncInfo) {
  LLVMContext &Context = OrigFunc.getContext();
  // Construct the new function type using the new arguments.
  FunctionType *NewFuncTy = FunctionType::get(
      getRetType(Context, TFuncInfo.getType()), TFuncInfo.getType().Args,
      OrigFunc.getFunctionType()->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NewFunc =
      Function::Create(NewFuncTy, OrigFunc.getLinkage(), OrigFunc.getName());

  LLVM_DEBUG(dbgs() << "\nCMABI: Transforming From:" << OrigFunc);
  vc::transferNameAndCCWithNewAttr(TFuncInfo.getAttributes(), OrigFunc,
                                   *NewFunc);
  OrigFunc.getParent()->getFunctionList().insert(OrigFunc.getIterator(),
                                                 NewFunc);
  vc::transferDISubprogram(OrigFunc, *NewFunc);
  LLVM_DEBUG(dbgs() << " --> To: " << *NewFunc << "\n");

  return NewFunc;
}
911
912 static std::vector<Value *>
getTransformedFuncCallArgs(CallInst & OrigCall,const TransformedFuncInfo & NewFuncInfo)913 getTransformedFuncCallArgs(CallInst &OrigCall,
914 const TransformedFuncInfo &NewFuncInfo) {
915 std::vector<Value *> NewCallOps;
916
917 // Loop over the operands, inserting loads in the caller.
918 for (auto &&[OrigArg, Kind] :
919 zip(IGCLLVM::args(OrigCall), NewFuncInfo.getArgKinds())) {
920 switch (Kind) {
921 case ArgKind::General:
922 NewCallOps.push_back(OrigArg.get());
923 break;
924 default: {
925 IGC_ASSERT_MESSAGE(Kind == ArgKind::CopyIn || Kind == ArgKind::CopyInOut,
926 "unexpected arg kind");
927 LoadInst *Load =
928 new LoadInst(OrigArg.get()->getType()->getPointerElementType(),
929 OrigArg.get(), OrigArg.get()->getName() + ".val",
930 /* isVolatile */ false, &OrigCall);
931 NewCallOps.push_back(Load);
932 break;
933 }
934 }
935 }
936
937 IGC_ASSERT_MESSAGE(NewCallOps.size() == IGCLLVM::arg_size(OrigCall),
938 "varargs are unexpected");
939 return std::move(NewCallOps);
940 }
941
942 static AttributeList
inheritCallAttributes(CallInst & OrigCall,int NumOrigFuncArgs,const TransformedFuncInfo & NewFuncInfo)943 inheritCallAttributes(CallInst &OrigCall, int NumOrigFuncArgs,
944 const TransformedFuncInfo &NewFuncInfo) {
945 IGC_ASSERT_MESSAGE(OrigCall.getNumArgOperands() == NumOrigFuncArgs,
946 "varargs aren't supported");
947 AttributeList NewCallAttrs;
948
949 const AttributeList &CallPAL = OrigCall.getAttributes();
950 auto &Context = OrigCall.getContext();
951 for (auto ArgInfo : enumerate(NewFuncInfo.getArgKinds())) {
952 if (ArgInfo.value() == ArgKind::General) {
953 AttributeSet attrs =
954 OrigCall.getAttributes().getParamAttributes(ArgInfo.index());
955 if (attrs.hasAttributes()) {
956 AttrBuilder B(attrs);
957 NewCallAttrs =
958 NewCallAttrs.addParamAttributes(Context, ArgInfo.index(), B);
959 }
960 }
961 }
962
963 for (auto DiscardInfo : NewFuncInfo.getDiscardedParameterAttrs()) {
964 NewCallAttrs = NewCallAttrs.removeParamAttribute(
965 Context, DiscardInfo.ArgIndex, DiscardInfo.Attr);
966 }
967
968 // Add any function attributes.
969 if (CallPAL.hasAttributes(AttributeList::FunctionIndex)) {
970 AttrBuilder B(CallPAL.getFnAttributes());
971 NewCallAttrs =
972 NewCallAttrs.addAttributes(Context, AttributeList::FunctionIndex, B);
973 }
974
975 return std::move(NewCallAttrs);
976 }
977
handleRetValuePortion(int RetIdx,int ArgIdx,CallInst & OrigCall,CallInst & NewCall,IRBuilder<> & Builder,const TransformedFuncInfo & NewFuncInfo)978 static void handleRetValuePortion(int RetIdx, int ArgIdx, CallInst &OrigCall,
979 CallInst &NewCall, IRBuilder<> &Builder,
980 const TransformedFuncInfo &NewFuncInfo) {
981 // Original return value.
982 if (ArgIdx == RetToArgInfo::OrigRetNoArg) {
983 IGC_ASSERT_MESSAGE(RetIdx == 0, "only zero element of returned value can "
984 "be original function argument");
985 OrigCall.replaceAllUsesWith(
986 Builder.CreateExtractValue(&NewCall, RetIdx, "ret"));
987 return;
988 }
989 Value *OutVal = Builder.CreateExtractValue(&NewCall, RetIdx);
990 if (ArgIdx >= NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx) {
991 auto Kind =
992 NewFuncInfo.getGlobalArgsInfo().getGlobalInfoForArgNo(ArgIdx).Kind;
993 IGC_ASSERT_MESSAGE(Kind == GlobalArgKind::ByValueInOut,
994 "only passed by value localized global should be copied-out");
995 Builder.CreateStore(
996 OutVal, NewFuncInfo.getGlobalArgsInfo().getGlobalForArgNo(ArgIdx));
997 } else {
998 IGC_ASSERT_MESSAGE(NewFuncInfo.getArgKinds()[ArgIdx] == ArgKind::CopyInOut,
999 "only copy in-out args are expected");
1000 Builder.CreateStore(OutVal, OrigCall.getArgOperand(ArgIdx));
1001 }
1002 }
1003
handleGlobalArgs(Function & NewFunc,const GlobalArgsInfo & GlobalArgs)1004 static std::vector<Value *> handleGlobalArgs(Function &NewFunc,
1005 const GlobalArgsInfo &GlobalArgs) {
1006 // Collect all globals and their corresponding allocas.
1007 std::vector<Value *> LocalizedGloabls;
1008 Instruction *InsertPt = &*(NewFunc.begin()->getFirstInsertionPt());
1009
1010 llvm::transform(drop_begin(NewFunc.args(), GlobalArgs.FirstGlobalArgIdx),
1011 std::back_inserter(LocalizedGloabls),
1012 [InsertPt](Argument &GVArg) -> Value * {
1013 if (GVArg.getType()->isPointerTy())
1014 return &GVArg;
1015 AllocaInst *Alloca = new AllocaInst(
1016 GVArg.getType(), vc::AddrSpace::Private, "", InsertPt);
1017 new StoreInst(&GVArg, Alloca, InsertPt);
1018 return Alloca;
1019 });
1020 // Fancy naming and debug info.
1021 for (auto &&[GAI, GVArg, MaybeAlloca] :
1022 zip(GlobalArgs.Globals,
1023 drop_begin(NewFunc.args(), GlobalArgs.FirstGlobalArgIdx),
1024 LocalizedGloabls)) {
1025 GVArg.setName(GAI.GV->getName() + ".in");
1026 if (!GVArg.getType()->isPointerTy()) {
1027 IGC_ASSERT_MESSAGE(isa<AllocaInst>(MaybeAlloca),
1028 "an alloca is expected when pass localized global by value");
1029 MaybeAlloca->setName(GAI.GV->getName() + ".local");
1030
1031 vc::DIBuilder::createDbgDeclareForLocalizedGlobal(
1032 *cast<AllocaInst>(MaybeAlloca), *GAI.GV, *InsertPt);
1033 }
1034 }
1035
1036 SmallDenseMap<Value *, Value *> GlobalsToReplace;
1037 for (auto &&[GAI, LocalizedGlobal] :
1038 zip(GlobalArgs.Globals, LocalizedGloabls))
1039 GlobalsToReplace.insert(std::make_pair(GAI.GV, LocalizedGlobal));
1040 // Replaces all globals uses within this new function.
1041 replaceUsesWithinFunction(GlobalsToReplace, &NewFunc);
1042 return LocalizedGloabls;
1043 }
1044
1045 static Value *
appendTransformedFuncRetPortion(Value & NewRetVal,int RetIdx,int ArgIdx,ReturnInst & OrigRet,IRBuilder<> & Builder,const TransformedFuncInfo & NewFuncInfo,const std::vector<Value * > & OrigArgReplacements,std::vector<Value * > & LocalizedGlobals)1046 appendTransformedFuncRetPortion(Value &NewRetVal, int RetIdx, int ArgIdx,
1047 ReturnInst &OrigRet, IRBuilder<> &Builder,
1048 const TransformedFuncInfo &NewFuncInfo,
1049 const std::vector<Value *> &OrigArgReplacements,
1050 std::vector<Value *> &LocalizedGlobals) {
1051 if (ArgIdx == RetToArgInfo::OrigRetNoArg) {
1052 IGC_ASSERT_MESSAGE(RetIdx == 0,
1053 "original return value must be at zero index");
1054 Value *OrigRetVal = OrigRet.getReturnValue();
1055 IGC_ASSERT_MESSAGE(OrigRetVal, "type unexpected");
1056 IGC_ASSERT_MESSAGE(OrigRetVal->getType()->isSingleValueType(),
1057 "type unexpected");
1058 return Builder.CreateInsertValue(&NewRetVal, OrigRetVal, RetIdx);
1059 }
1060 if (ArgIdx >= NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx) {
1061 auto Kind =
1062 NewFuncInfo.getGlobalArgsInfo().getGlobalInfoForArgNo(ArgIdx).Kind;
1063 IGC_ASSERT_MESSAGE(Kind == GlobalArgKind::ByValueInOut,
1064 "only passed by value localized global should be copied-out");
1065 Value *LocalizedGlobal =
1066 LocalizedGlobals[ArgIdx -
1067 NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx];
1068 IGC_ASSERT_MESSAGE(isa<AllocaInst>(LocalizedGlobal),
1069 "an alloca is expected when pass localized global by value");
1070 Value *LocalizedGlobalVal = Builder.CreateLoad(
1071 LocalizedGlobal->getType()->getPointerElementType(), LocalizedGlobal);
1072 return Builder.CreateInsertValue(&NewRetVal, LocalizedGlobalVal, RetIdx);
1073 }
1074 IGC_ASSERT_MESSAGE(NewFuncInfo.getArgKinds()[ArgIdx] == ArgKind::CopyInOut,
1075 "Only copy in-out values are expected");
1076 Value *CurRetByPtr = OrigArgReplacements[ArgIdx];
1077 IGC_ASSERT_MESSAGE(isa<PointerType>(CurRetByPtr->getType()),
1078 "a pointer is expected");
1079 if (isa<AddrSpaceCastInst>(CurRetByPtr))
1080 CurRetByPtr = cast<AddrSpaceCastInst>(CurRetByPtr)->getOperand(0);
1081 IGC_ASSERT_MESSAGE(isa<AllocaInst>(CurRetByPtr),
1082 "corresponding alloca is expected");
1083 Value *CurRetByVal = Builder.CreateLoad(
1084 CurRetByPtr->getType()->getPointerElementType(), CurRetByPtr);
1085 return Builder.CreateInsertValue(&NewRetVal, CurRetByVal, RetIdx);
1086 }
1087
1088 // Add some additional code before \p OrigCall to pass localized global value
1089 // \p GAI to the transformed function.
1090 // An argument corresponding to \p GAI is returned.
passGlobalAsCallArg(GlobalArgInfo GAI,CallInst & OrigCall)1091 static Value *passGlobalAsCallArg(GlobalArgInfo GAI, CallInst &OrigCall) {
1092 // We should should load the global first to pass it by value.
1093 if (GAI.Kind == GlobalArgKind::ByValueIn ||
1094 GAI.Kind == GlobalArgKind::ByValueInOut)
1095 return new LoadInst(GAI.GV->getType()->getPointerElementType(), GAI.GV,
1096 GAI.GV->getName() + ".val",
1097 /* isVolatile */ false, &OrigCall);
1098 IGC_ASSERT_MESSAGE(GAI.Kind == GlobalArgKind::ByPointer,
1099 "localized global can be passed only by value or by pointer");
1100 auto *GVTy = cast<PointerType>(GAI.GV->getType());
1101 // No additional work when addrspaces match
1102 if (GVTy->getAddressSpace() == vc::AddrSpace::Private)
1103 return GAI.GV;
1104 // Need to add a temprorary cast inst to match types.
1105 // When this switch to the caller, it'll remove this cast.
1106 return new AddrSpaceCastInst{
1107 GAI.GV, vc::changeAddrSpace(GVTy, vc::AddrSpace::Private),
1108 GAI.GV->getName() + ".tmp", &OrigCall};
1109 }
1110
1111 namespace {
1112 class FuncUsersUpdater {
1113 Function &OrigFunc;
1114 Function &NewFunc;
1115 const TransformedFuncInfo &NewFuncInfo;
1116 CallGraphNode &NewFuncCGN;
1117 CallGraph &CG;
1118
1119 public:
FuncUsersUpdater(Function & OrigFuncIn,Function & NewFuncIn,const TransformedFuncInfo & NewFuncInfoIn,CallGraphNode & NewFuncCGNIn,CallGraph & CGIn)1120 FuncUsersUpdater(Function &OrigFuncIn, Function &NewFuncIn,
1121 const TransformedFuncInfo &NewFuncInfoIn,
1122 CallGraphNode &NewFuncCGNIn, CallGraph &CGIn)
1123 : OrigFunc{OrigFuncIn}, NewFunc{NewFuncIn}, NewFuncInfo{NewFuncInfoIn},
1124 NewFuncCGN{NewFuncCGNIn}, CG{CGIn} {}
1125
run()1126 void run() {
1127 std::vector<CallInst *> DirectUsers;
1128
1129 for (auto *U : OrigFunc.users()) {
1130 IGC_ASSERT_MESSAGE(
1131 isa<CallInst>(U),
1132 "the transformation is not applied to indirectly called functions");
1133 DirectUsers.push_back(cast<CallInst>(U));
1134 }
1135
1136 std::vector<CallInst *> NewDirectUsers;
1137 // Loop over all of the callers of the function, transforming the call sites
1138 // to pass in the loaded pointers.
1139 for (auto *OrigCall : DirectUsers) {
1140 IGC_ASSERT(OrigCall->getCalledFunction() == &OrigFunc);
1141 auto *NewCall = UpdateFuncDirectUser(*OrigCall);
1142 NewDirectUsers.push_back(NewCall);
1143 }
1144
1145 for (auto *OrigCall : DirectUsers)
1146 OrigCall->eraseFromParent();
1147 }
1148
1149 private:
UpdateFuncDirectUser(CallInst & OrigCall)1150 CallInst *UpdateFuncDirectUser(CallInst &OrigCall) {
1151 std::vector<Value *> NewCallOps =
1152 getTransformedFuncCallArgs(OrigCall, NewFuncInfo);
1153
1154 AttributeList NewCallAttrs = inheritCallAttributes(
1155 OrigCall, OrigFunc.getFunctionType()->getNumParams(), NewFuncInfo);
1156
1157 // Push any localized globals.
1158 IGC_ASSERT_MESSAGE(
1159 NewCallOps.size() == NewFuncInfo.getGlobalArgsInfo().FirstGlobalArgIdx,
1160 "call operands and called function info are inconsistent");
1161 llvm::transform(NewFuncInfo.getGlobalArgsInfo().Globals,
1162 std::back_inserter(NewCallOps),
1163 [&OrigCall](GlobalArgInfo GAI) {
1164 return passGlobalAsCallArg(GAI, OrigCall);
1165 });
1166
1167 IGC_ASSERT_EXIT_MESSAGE(!isa<InvokeInst>(OrigCall),
1168 "InvokeInst not supported");
1169
1170 CallInst *NewCall = CallInst::Create(&NewFunc, NewCallOps, "", &OrigCall);
1171 IGC_ASSERT(nullptr != NewCall);
1172 NewCall->setCallingConv(OrigCall.getCallingConv());
1173 NewCall->setAttributes(NewCallAttrs);
1174 if (cast<CallInst>(OrigCall).isTailCall())
1175 NewCall->setTailCall();
1176 NewCall->setDebugLoc(OrigCall.getDebugLoc());
1177 NewCall->takeName(&OrigCall);
1178
1179 // Update the callgraph to know that the callsite has been transformed.
1180 auto CalleeNode = static_cast<IGCLLVM::CallGraphNode *>(
1181 CG[OrigCall.getParent()->getParent()]);
1182 CalleeNode->replaceCallEdge(
1183 #if LLVM_VERSION_MAJOR <= 10
1184 CallSite(&OrigCall), NewCall,
1185 #else
1186 OrigCall, *NewCall,
1187 #endif
1188 &NewFuncCGN);
1189
1190 IRBuilder<> Builder(&OrigCall);
1191 for (auto RetToArg : enumerate(NewFuncInfo.getRetToArgInfo().Map))
1192 handleRetValuePortion(RetToArg.index(), RetToArg.value(), OrigCall,
1193 *NewCall, Builder, NewFuncInfo);
1194 return NewCall;
1195 }
1196 };
1197
1198 class FuncBodyTransfer {
1199 Function &OrigFunc;
1200 Function &NewFunc;
1201 const TransformedFuncInfo &NewFuncInfo;
1202
1203 public:
FuncBodyTransfer(Function & OrigFuncIn,Function & NewFuncIn,const TransformedFuncInfo & NewFuncInfoIn)1204 FuncBodyTransfer(Function &OrigFuncIn, Function &NewFuncIn,
1205 const TransformedFuncInfo &NewFuncInfoIn)
1206 : OrigFunc{OrigFuncIn}, NewFunc{NewFuncIn}, NewFuncInfo{NewFuncInfoIn} {}
1207
run()1208 void run() {
1209 // Since we have now created the new function, splice the body of the old
1210 // function right into the new function.
1211 NewFunc.getBasicBlockList().splice(NewFunc.begin(),
1212 OrigFunc.getBasicBlockList());
1213
1214 std::vector<Value *> OrigArgReplacements = handleTransformedFuncArgs();
1215 std::vector<Value *> LocalizedGlobals =
1216 handleGlobalArgs(NewFunc, NewFuncInfo.getGlobalArgsInfo());
1217
1218 handleTransformedFuncRets(OrigArgReplacements, LocalizedGlobals);
1219 }
1220
1221 private:
handleTransformedFuncArgs()1222 std::vector<Value *> handleTransformedFuncArgs() {
1223 std::vector<Value *> OrigArgReplacements;
1224 Instruction *InsertPt = &*(NewFunc.begin()->getFirstInsertionPt());
1225
1226 std::transform(
1227 NewFuncInfo.getArgKinds().begin(), NewFuncInfo.getArgKinds().end(),
1228 NewFunc.arg_begin(), std::back_inserter(OrigArgReplacements),
1229 [InsertPt](ArgKind Kind, Argument &NewArg) -> Value * {
1230 switch (Kind) {
1231 case ArgKind::CopyIn:
1232 case ArgKind::CopyInOut: {
1233 auto *Alloca = new AllocaInst(NewArg.getType(),
1234 vc::AddrSpace::Private, "", InsertPt);
1235 new StoreInst{&NewArg, Alloca, InsertPt};
1236 return Alloca;
1237 }
1238 default:
1239 IGC_ASSERT_MESSAGE(Kind == ArgKind::General,
1240 "unexpected argument kind");
1241 return &NewArg;
1242 }
1243 });
1244
1245 std::transform(
1246 OrigArgReplacements.begin(), OrigArgReplacements.end(),
1247 OrigFunc.arg_begin(), OrigArgReplacements.begin(),
1248 [InsertPt](Value *Replacement, Argument &OrigArg) -> Value * {
1249 if (Replacement->getType() == OrigArg.getType())
1250 return Replacement;
1251 IGC_ASSERT_MESSAGE(isa<PointerType>(Replacement->getType()),
1252 "only pointers can posibly mismatch");
1253 IGC_ASSERT_MESSAGE(isa<PointerType>(OrigArg.getType()),
1254 "only pointers can posibly mismatch");
1255 IGC_ASSERT_MESSAGE(
1256 Replacement->getType()->getPointerAddressSpace() !=
1257 OrigArg.getType()->getPointerAddressSpace(),
1258 "pointers should have different addr spaces when they mismatch");
1259 IGC_ASSERT_MESSAGE(
1260 Replacement->getType()->getPointerElementType() ==
1261 OrigArg.getType()->getPointerElementType(),
1262 "pointers must have same element type when they mismatch");
1263 return new AddrSpaceCastInst(Replacement, OrigArg.getType(), "",
1264 InsertPt);
1265 });
1266 for (auto &&[OrigArg, OrigArgReplacement] :
1267 zip(OrigFunc.args(), OrigArgReplacements)) {
1268 OrigArgReplacement->takeName(&OrigArg);
1269 OrigArg.replaceAllUsesWith(OrigArgReplacement);
1270 }
1271
1272 return std::move(OrigArgReplacements);
1273 }
1274
handleTransformedFuncRet(ReturnInst & OrigRet,const std::vector<Value * > & OrigArgReplacements,std::vector<Value * > & LocalizedGlobals)1275 void handleTransformedFuncRet(ReturnInst &OrigRet,
1276 const std::vector<Value *> &OrigArgReplacements,
1277 std::vector<Value *> &LocalizedGlobals) {
1278 Type *NewRetTy = NewFunc.getReturnType();
1279 IRBuilder<> Builder(&OrigRet);
1280 auto &&RetToArg = enumerate(NewFuncInfo.getRetToArgInfo().Map);
1281 Value *NewRetVal = std::accumulate(
1282 RetToArg.begin(), RetToArg.end(),
1283 cast<Value>(UndefValue::get(NewRetTy)),
1284 [&OrigRet, &Builder, &OrigArgReplacements, &LocalizedGlobals,
1285 this](Value *NewRet, auto NewRetPortionInfo) {
1286 return appendTransformedFuncRetPortion(
1287 *NewRet, NewRetPortionInfo.index(), NewRetPortionInfo.value(),
1288 OrigRet, Builder, NewFuncInfo, OrigArgReplacements,
1289 LocalizedGlobals);
1290 });
1291 Builder.CreateRet(NewRetVal);
1292 OrigRet.eraseFromParent();
1293 }
1294
1295 void
handleTransformedFuncRets(const std::vector<Value * > & OrigArgReplacements,std::vector<Value * > & LocalizedGlobals)1296 handleTransformedFuncRets(const std::vector<Value *> &OrigArgReplacements,
1297 std::vector<Value *> &LocalizedGlobals) {
1298 Type *NewRetTy = NewFunc.getReturnType();
1299 if (NewRetTy->isVoidTy())
1300 return;
1301 std::vector<ReturnInst *> OrigRets;
1302 llvm::transform(make_filter_range(instructions(NewFunc),
1303 [](Instruction &Inst) {
1304 return isa<ReturnInst>(Inst);
1305 }),
1306 std::back_inserter(OrigRets),
1307 [](Instruction &RI) { return &cast<ReturnInst>(RI); });
1308
1309 for (ReturnInst *OrigRet : OrigRets)
1310 handleTransformedFuncRet(*OrigRet, OrigArgReplacements, LocalizedGlobals);
1311 }
1312 };
1313 } // namespace
1314
1315 // \brief Actually performs the transformation of the specified arguments, and
1316 // returns the new function.
1317 //
1318 // Note this transformation does change the semantics as a C function, due to
1319 // possible pointer aliasing. But it is allowed as a CM function.
1320 //
1321 // The pass-by-reference scheme is useful to copy-out values from the
1322 // subprogram back to the caller. It also may be useful to convey large inputs
1323 // to subprograms, as the amount of parameter conveying code will be reduced.
1324 // There is a restriction imposed on arguments passed by reference in order to
1325 // allow for an efficient CM implementation. Specifically the restriction is
1326 // that for a subprogram that uses pass-by-reference, the behavior must be the
1327 // same as if we use a copy-in/copy-out semantic to convey the
1328 // pass-by-reference argument; otherwise the CM program is said to be erroneous
1329 // and may produce incorrect results. Such errors are not caught by the
1330 // compiler and it is up to the user to guarantee safety.
1331 //
1332 // The implication of the above stated restriction is that no pass-by-reference
1333 // argument that is written to in a subprogram (either directly or transitively
1334 // by means of a nested subprogram call pass-by-reference argument) may overlap
1335 // with another pass-by-reference parameter or a global variable that is
1336 // referenced in the subprogram; in addition no pass-by-reference subprogram
1337 // argument that is referenced may overlap with a global variable that is
1338 // written to in the subprogram.
1339 //
TransformNode(Function & OrigFunc,SmallPtrSet<Argument *,8> & ArgsToTransform,LocalizationInfo & LI)1340 CallGraphNode *CMABI::TransformNode(Function &OrigFunc,
1341 SmallPtrSet<Argument *, 8> &ArgsToTransform,
1342 LocalizationInfo &LI) {
1343 NumArgumentsTransformed += ArgsToTransform.size();
1344 TransformedFuncInfo NewFuncInfo{OrigFunc, ArgsToTransform};
1345 NewFuncInfo.AppendGlobals(LI);
1346
1347 // Create the new function declaration and insert it into the module.
1348 Function *NewFunc = createTransformedFuncDecl(OrigFunc, NewFuncInfo);
1349
1350 // Get a new callgraph node for NF.
1351 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
1352 CallGraphNode *NewFuncCGN = CG.getOrInsertFunction(NewFunc);
1353
1354 FuncUsersUpdater{OrigFunc, *NewFunc, NewFuncInfo, *NewFuncCGN, CG}.run();
1355 FuncBodyTransfer{OrigFunc, *NewFunc, NewFuncInfo}.run();
1356
1357 // It turns out sometimes llvm will recycle function pointers which confuses
1358 // this pass. We delete its localization info and mark this function as
1359 // already visited.
1360 Info->GlobalInfo.erase(&OrigFunc);
1361 AlreadyVisited.insert(&OrigFunc);
1362
1363 NewFuncCGN->stealCalledFunctionsFrom(CG[&OrigFunc]);
1364
1365 // Now that the old function is dead, delete it. If there is a dangling
1366 // reference to the CallgraphNode, just leave the dead function around.
1367 CallGraphNode *CGN = CG[&OrigFunc];
1368 if (CGN->getNumReferences() == 0)
1369 delete CG.removeFunctionFromModule(CGN);
1370 else
1371 OrigFunc.setLinkage(Function::ExternalLinkage);
1372
1373 return NewFuncCGN;
1374 }
1375
fillStackWithUsers(std::stack<User * > & Stack,User & CurUser)1376 static void fillStackWithUsers(std::stack<User *> &Stack, User &CurUser) {
1377 for (User *Usr : CurUser.users())
1378 Stack.push(Usr);
1379 }
1380
1381 // Traverse in depth through GV constant users to find instruction users.
1382 // When instruction user is found, it is clear in which function GV is used.
defineGVDirectUsers(GlobalVariable & GV)1383 void CMABIAnalysis::defineGVDirectUsers(GlobalVariable &GV) {
1384 std::stack<User *> Stack;
1385 Stack.push(&GV);
1386 while (!Stack.empty()) {
1387 auto *CurUser = Stack.top();
1388 Stack.pop();
1389
1390 // Continue go in depth when a constant is met.
1391 if (isa<Constant>(CurUser)) {
1392 fillStackWithUsers(Stack, *CurUser);
1393 continue;
1394 }
1395
1396 // We've got what we looked for.
1397 auto *Inst = cast<Instruction>(CurUser);
1398 addDirectGlobal(Inst->getFunction(), &GV);
1399 }
1400 }
1401
1402 // For each function, compute the list of globals that need to be passed as
1403 // copy-in and copy-out arguments.
analyzeGlobals(CallGraph & CG)1404 void CMABIAnalysis::analyzeGlobals(CallGraph &CG) {
1405 Module &M = CG.getModule();
1406 for (auto& F : M.getFunctionList()) {
1407 if (F.isDeclaration() || F.hasDLLExportStorageClass())
1408 continue;
1409 if (GenXIntrinsic::getAnyIntrinsicID(&F) !=
1410 GenXIntrinsic::not_any_intrinsic)
1411 continue;
1412 // __cm_intrinsic_impl_* could be used for emulation mul/div etc
1413 if (F.getName().contains("__cm_intrinsic_impl_"))
1414 continue;
1415
1416 // Convert non-kernel to stack call if applicable
1417 if (FCtrl == FunctionControl::StackCall && !genx::requiresStackCall(&F)) {
1418 LLVM_DEBUG(dbgs() << "Adding stack call to: " << F.getName() << "\n");
1419 F.addFnAttr(genx::FunctionMD::CMStackCall);
1420 }
1421
1422 // Do not change stack calls linkage as we may have both types of stack
1423 // calls.
1424 if (genx::requiresStackCall(&F) && SaveStackCallLinkage)
1425 continue;
1426
1427 F.setLinkage(GlobalValue::InternalLinkage);
1428 }
1429 // No global variables.
1430 if (M.global_empty())
1431 return;
1432
1433 // FIXME: String constants must be localized too. Excluding them there
1434 // to WA legacy printf implementation in CM FE (printf strings are
1435 // not in constant addrspace in legacy printf).
1436 auto ToLocalize =
1437 make_filter_range(M.globals(), [](const GlobalVariable &GV) {
1438 return GV.getAddressSpace() == vc::AddrSpace::Private &&
1439 !GV.hasAttribute(genx::FunctionMD::GenXVolatile) &&
1440 !vc::isConstantString(GV);
1441 });
1442
1443 // Collect direct and indirect (GV is used in a called function)
1444 // uses of globals.
1445 for (GlobalVariable &GV : ToLocalize)
1446 defineGVDirectUsers(GV);
1447 for (const std::vector<CallGraphNode *> &SCCNodes :
1448 make_range(scc_begin(&CG), scc_end(&CG)))
1449 for (const CallGraphNode *Caller : SCCNodes)
1450 for (const IGCLLVM::CallGraphNode::CallRecord &Callee : *Caller) {
1451 Function *CalleeF = Callee.second->getFunction();
1452 if (CalleeF && !vc::isFixedSignatureFunc(*CalleeF))
1453 addIndirectGlobal(Caller->getFunction(), CalleeF);
1454 }
1455 }
1456
1457 /***********************************************************************
1458 * diagnoseOverlappingArgs : attempt to diagnose overlapping by-ref args
1459 *
1460 * The CM language spec says you are not allowed a call with two by-ref args
1461 * that overlap. This is to give the compiler the freedom to implement with
1462 * copy-in copy-out semantics or with an address register.
1463 *
1464 * This function attempts to diagnose code that breaks this restriction. For
1465 * pointer args to the call, it attempts to track how values are loaded using
1466 * the pointer (assumed to be an alloca of the temporary used for copy-in
1467 * copy-out semantics), and how those values then get propagated through
1468 * wrregions and stores. If any vector element in a wrregion or store is found
1469 * that comes from more than one pointer arg, it is reported.
1470 *
1471 * This ignores variable index wrregions, and only traces through instructions
1472 * with the same debug location as the call, so does not work with -g0.
1473 */
diagnoseOverlappingArgs(CallInst * CI)1474 void CMABI::diagnoseOverlappingArgs(CallInst *CI)
1475 {
1476 LLVM_DEBUG(dbgs() << "diagnoseOverlappingArgs " << *CI << "\n");
1477 auto DL = CI->getDebugLoc();
1478 if (!DL)
1479 return;
1480 std::map<Value *, SmallVector<uint8_t, 16>> ValMap;
1481 SmallVector<Instruction *, 8> WorkList;
1482 std::set<Instruction *> InWorkList;
1483 std::set<std::pair<unsigned, unsigned>> Reported;
1484 // Using ArgIndex starting at 1 so we can reserve 0 to mean "element does not
1485 // come from any by-ref arg".
1486 for (unsigned ArgIndex = 1, NumArgs = CI->getNumArgOperands();
1487 ArgIndex <= NumArgs; ++ArgIndex) {
1488 Value *Arg = CI->getOperand(ArgIndex - 1);
1489 if (!Arg->getType()->isPointerTy())
1490 continue;
1491 LLVM_DEBUG(dbgs() << "arg " << ArgIndex << ": " << *Arg << "\n");
1492 // Got a pointer arg. Find its loads (with the same debug loc).
1493 for (auto ui = Arg->use_begin(), ue = Arg->use_end(); ui != ue; ++ui) {
1494 auto LI = dyn_cast<LoadInst>(ui->getUser());
1495 if (!LI || LI->getDebugLoc() != DL)
1496 continue;
1497 LLVM_DEBUG(dbgs() << " " << *LI << "\n");
1498 // For a load, create a map entry that says that every vector element
1499 // comes from this arg.
1500 unsigned NumElements = 1;
1501 if (auto VT = dyn_cast<IGCLLVM::FixedVectorType>(LI->getType()))
1502 NumElements = VT->getNumElements();
1503 auto Entry = &ValMap[LI];
1504 Entry->resize(NumElements, ArgIndex);
1505 // Add its users (with the same debug location) to the work list.
1506 for (auto ui = LI->use_begin(), ue = LI->use_end(); ui != ue; ++ui) {
1507 auto Inst = cast<Instruction>(ui->getUser());
1508 if (Inst->getDebugLoc() == DL)
1509 if (InWorkList.insert(Inst).second)
1510 WorkList.push_back(Inst);
1511 }
1512 }
1513 }
1514 // Process the work list.
1515 while (!WorkList.empty()) {
1516 auto Inst = WorkList.back();
1517 WorkList.pop_back();
1518 InWorkList.erase(Inst);
1519 LLVM_DEBUG(dbgs() << "From worklist: " << *Inst << "\n");
1520 Value *Key = nullptr;
1521 SmallVector<uint8_t, 8> TempVector;
1522 SmallVectorImpl<uint8_t> *VectorToMerge = nullptr;
1523 if (auto SI = dyn_cast<StoreInst>(Inst)) {
1524 // Store: set the map entry using the store pointer as the key. It might
1525 // be an alloca of a local variable, or a global variable.
1526 // Strictly speaking this is not properly keeping track of what is being
1527 // merged using load-wrregion-store for a non-SROAd local variable or a
1528 // global variable. Instead it is just merging at the store itself, which
1529 // is good enough for our purposes.
1530 Key = SI->getPointerOperand();
1531 VectorToMerge = &ValMap[SI->getValueOperand()];
1532 } else if (auto BC = dyn_cast<BitCastInst>(Inst)) {
1533 // Bitcast: calculate the new map entry.
1534 Key = BC;
1535 uint64_t OutElementSize =
1536 BC->getType()->getScalarType()->getPrimitiveSizeInBits();
1537 uint64_t InElementSize = BC->getOperand(0)
1538 ->getType()
1539 ->getScalarType()
1540 ->getPrimitiveSizeInBits();
1541 int LogRatio = countTrailingZeros(OutElementSize, ZB_Undefined) -
1542 countTrailingZeros(InElementSize, ZB_Undefined);
1543 auto OpndEntry = &ValMap[BC->getOperand(0)];
1544 if (!LogRatio)
1545 VectorToMerge = OpndEntry;
1546 else if (LogRatio > 0) {
1547 // Result element type is bigger than input element type, so there are
1548 // fewer result elements. Just use an arbitrarily chosen non-zero entry
1549 // of the N input elements to set the 1 result element.
1550 IGC_ASSERT(!(OpndEntry->size() & ((1U << LogRatio) - 1)));
1551 for (unsigned i = 0, e = OpndEntry->size(); i != e; i += 1U << LogRatio) {
1552 unsigned FoundArgIndex = 0;
1553 for (unsigned j = 0; j != 1U << LogRatio; ++j)
1554 FoundArgIndex = std::max(FoundArgIndex, (unsigned)(*OpndEntry)[i + j]);
1555 TempVector.push_back(FoundArgIndex);
1556 }
1557 VectorToMerge = &TempVector;
1558 } else {
1559 // Result element type is smaller than input element type, so there are
1560 // multiple result elements per input element.
1561 for (unsigned i = 0, e = OpndEntry->size(); i != e; ++i)
1562 for (unsigned j = 0; j != 1U << -LogRatio; ++j)
1563 TempVector.push_back((*OpndEntry)[i]);
1564 VectorToMerge = &TempVector;
1565 }
1566 } else if (auto CI = dyn_cast<CallInst>(Inst)) {
1567 if (auto CF = CI->getCalledFunction()) {
1568 switch (GenXIntrinsic::getGenXIntrinsicID(CF)) {
1569 default:
1570 break;
1571 case GenXIntrinsic::genx_wrregionf:
1572 case GenXIntrinsic::genx_wrregioni:
1573 // wrregion: As long as it is constant index, propagate the argument
1574 // indices into the appropriate elements of the result.
1575 if (auto IdxC = dyn_cast<Constant>(CI->getOperand(
1576 GenXIntrinsic::GenXRegion::WrIndexOperandNum))) {
1577 unsigned Idx = 0;
1578 if (!IdxC->isNullValue()) {
1579 auto IdxCI = dyn_cast<ConstantInt>(IdxC);
1580 if (!IdxCI) {
1581 LLVM_DEBUG(dbgs() << "Ignoring variable index wrregion\n");
1582 break;
1583 }
1584 Idx = IdxCI->getZExtValue();
1585 }
1586 Idx /= (CI->getType()->getScalarType()->getPrimitiveSizeInBits() / 8U);
1587 // First copy the "old value" input to the map entry.
1588 auto OpndEntry = &ValMap[CI->getOperand(
1589 GenXIntrinsic::GenXRegion::OldValueOperandNum)];
1590 auto Entry = &ValMap[CI];
1591 Entry->clear();
1592 Entry->insert(Entry->begin(), OpndEntry->begin(), OpndEntry->end());
1593 // Then copy the "new value" elements according to the region.
1594 TempVector.resize(
1595 dyn_cast<IGCLLVM::FixedVectorType>(CI->getType())->getNumElements(), 0);
1596 int VStride = cast<ConstantInt>(CI->getOperand(
1597 GenXIntrinsic::GenXRegion::WrVStrideOperandNum))->getSExtValue();
1598 unsigned Width = cast<ConstantInt>(CI->getOperand(
1599 GenXIntrinsic::GenXRegion::WrWidthOperandNum))->getZExtValue();
1600 IGC_ASSERT_MESSAGE((Width > 0), "Width of a region must be non-zero");
1601 int Stride = cast<ConstantInt>(CI->getOperand(
1602 GenXIntrinsic::GenXRegion::WrStrideOperandNum))->getSExtValue();
1603 OpndEntry = &ValMap[CI->getOperand(
1604 GenXIntrinsic::GenXRegion::NewValueOperandNum)];
1605 unsigned NumElements = OpndEntry->size();
1606 if (!NumElements)
1607 break;
1608 for (unsigned RowIdx = Idx, Row = 0, Col = 0,
1609 NumRows = NumElements / Width;; Idx += Stride, ++Col) {
1610 if (Col == Width) {
1611 Col = 0;
1612 if (++Row == NumRows)
1613 break;
1614 Idx = RowIdx += VStride;
1615 }
1616 TempVector[Idx] = (*OpndEntry)[Row * Width + Col];
1617 }
1618 VectorToMerge = &TempVector;
1619 Key = CI;
1620 }
1621 break;
1622 }
1623 }
1624 }
1625 if (!VectorToMerge)
1626 continue;
1627 auto Entry = &ValMap[Key];
1628 LLVM_DEBUG(dbgs() << "Merging :";
1629 for (unsigned i = 0; i != VectorToMerge->size(); ++i)
1630 dbgs() << " " << (unsigned)(*VectorToMerge)[i];
1631 dbgs() << "\ninto " << Key->getName() << ":";
1632 for (unsigned i = 0; i != Entry->size(); ++i)
1633 dbgs() << " " << (unsigned)(*Entry)[i];
1634 dbgs() << "\n");
1635 if (Entry->empty())
1636 Entry->insert(Entry->end(), VectorToMerge->begin(), VectorToMerge->end());
1637 else {
1638 IGC_ASSERT(VectorToMerge->size() == Entry->size());
1639 for (unsigned i = 0; i != VectorToMerge->size(); ++i) {
1640 unsigned ArgIdx1 = (*VectorToMerge)[i];
1641 unsigned ArgIdx2 = (*Entry)[i];
1642 if (ArgIdx1 && ArgIdx2 && ArgIdx1 != ArgIdx2) {
1643 LLVM_DEBUG(dbgs() << "By ref args overlap: args " << ArgIdx1 << " and " << ArgIdx2 << "\n");
1644 if (ArgIdx1 > ArgIdx2)
1645 std::swap(ArgIdx1, ArgIdx2);
1646 if (Reported.insert(std::pair<unsigned, unsigned>(ArgIdx1, ArgIdx2))
1647 .second) {
1648 // Not already reported.
1649 DiagnosticInfoOverlappingArgs Err(CI, "by reference arguments "
1650 + Twine(ArgIdx1) + " and " + Twine(ArgIdx2) + " overlap",
1651 DS_Error);
1652 Inst->getContext().diagnose(Err);
1653 }
1654 }
1655 (*Entry)[i] = std::max((*Entry)[i], (*VectorToMerge)[i]);
1656 }
1657 }
1658 LLVM_DEBUG(dbgs() << "giving:";
1659 for (unsigned i = 0; i != Entry->size(); ++i)
1660 dbgs() << " " << (unsigned)(*Entry)[i];
1661 dbgs() << "\n");
1662 if (Key == Inst) {
1663 // Not the case that we have a store and we are using the pointer as
    // the key. In the other cases that do a merge (bitcast and wrregion),
1665 // add users to the work list as long as they have the same debug loc.
1666 for (auto ui = Inst->use_begin(), ue = Inst->use_end(); ui != ue; ++ui) {
1667 auto User = cast<Instruction>(ui->getUser());
1668 if (User->getDebugLoc() == DL)
1669 if (InWorkList.insert(Inst).second)
1670 WorkList.push_back(User);
1671 }
1672 }
1673 }
1674 }
1675
1676 /***********************************************************************
1677 * DiagnosticInfoOverlappingArgs initializer from Instruction
1678 *
1679 * If the Instruction has a DebugLoc, then that is used for the error
1680 * location.
1681 * Otherwise, the location is unknown.
1682 */
DiagnosticInfoOverlappingArgs(Instruction * Inst,const Twine & Desc,DiagnosticSeverity Severity)1683 DiagnosticInfoOverlappingArgs::DiagnosticInfoOverlappingArgs(Instruction *Inst,
1684 const Twine &Desc, DiagnosticSeverity Severity)
1685 : DiagnosticInfo(getKindID(), Severity), Line(0), Col(0)
1686 {
1687 auto DL = Inst->getDebugLoc();
1688 if (!DL) {
1689 Filename = DL.get()->getFilename();
1690 Line = DL.getLine();
1691 Col = DL.getCol();
1692 }
1693 Description = Desc.str();
1694 }
1695
1696 /***********************************************************************
1697 * DiagnosticInfoOverlappingArgs::print : print the error/warning message
1698 */
print(DiagnosticPrinter & DP) const1699 void DiagnosticInfoOverlappingArgs::print(DiagnosticPrinter &DP) const
1700 {
1701 std::string Loc(
1702 (Twine(!Filename.empty() ? Filename : "<unknown>")
1703 + ":" + Twine(Line)
1704 + (!Col ? Twine() : Twine(":") + Twine(Col))
1705 + ": ")
1706 .str());
1707 DP << Loc << Description;
1708 }
1709
1710
// Unique pass identifier; its address is what LLVM uses as the pass ID.
char CMABI::ID = 0;
// Register CMABI with the pass registry, declaring its analysis
// dependencies (call graph and the companion CMABIAnalysis pass).
INITIALIZE_PASS_BEGIN(CMABI, "cmabi", "Fix ABI issues for the genx backend", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CMABIAnalysis)
INITIALIZE_PASS_END(CMABI, "cmabi", "Fix ABI issues for the genx backend", false, false)

// Factory used by pass-pipeline setup code to create a CMABI instance.
Pass *llvm::createCMABIPass() { return new CMABI(); }
1718
namespace {

// A well-formed passing argument by reference pattern.
//
// (Alloca)
// %argref1 = alloca <8 x float>, align 32
//
// (CopyInRegion/CopyInStore)
// %rdr = tail call <8 x float> @llvm.genx.rdregionf(<960 x float> %m, i32 0, i32 8, i32 1, i16 0, i32 undef)
// call void @llvm.genx.vstore(<8 x float> %rdr, <8 x float>* %argref)
//
// (CopyOutRegion/CopyOutLoad)
// %ld = call <8 x float> @llvm.genx.vload(<8 x float>* %argref)
// %wr = call <960 x float> @llvm.genx.wrregionf(<960 x float> %m, <8 x float> %ld, i32 0, i32 8, i32 1, i16 0, i32 undef, i1 true)
//
struct ArgRefPattern {
  // Alloca of this reference argument.
  AllocaInst *Alloca;

  // The input value: the rdregion producing the copied-in data, and the
  // vstore that writes it into the alloca. Filled in by match().
  CallInst *CopyInRegion;
  CallInst *CopyInStore;

  // The output value: the final vload from the alloca and (if the result
  // is used) the wrregion writing it back. Filled in by match().
  CallInst *CopyOutLoad;
  CallInst *CopyOutRegion;

  // Load and store instructions on arg alloca.
  SmallVector<CallInst *, 8> VLoads;
  SmallVector<CallInst *, 8> VStores;

  explicit ArgRefPattern(AllocaInst *AI)
      : Alloca(AI), CopyInRegion(nullptr), CopyInStore(nullptr),
        CopyOutLoad(nullptr), CopyOutRegion(nullptr) {}

  // Match a copy-in and copy-out pattern. Return true on success.
  bool match(DominatorTree &DT, PostDominatorTree &PDT);
  // Rewrite the matched pattern so all accesses go through a new
  // "base" alloca holding the full source region.
  void process(DominatorTree &DT);
};

// Function pass that promotes reference-argument allocas into their base
// region and lowers remaining vload/vstore intrinsics to plain load/store.
struct CMLowerVLoadVStore : public FunctionPass {
  static char ID;
  CMLowerVLoadVStore() : FunctionPass(ID) {
    initializeCMLowerVLoadVStorePass(*PassRegistry::getPassRegistry());
  }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Dominator/post-dominator trees drive the copy-in/copy-out matching.
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<PostDominatorTreeWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnFunction(Function &F) override;

private:
  bool promoteAllocas(Function &F);
  bool lowerLoadStore(Function &F);
};

} // namespace
1778
// Unique pass identifier; its address is what LLVM uses as the pass ID.
char CMLowerVLoadVStore::ID = 0;
// Register CMLowerVLoadVStore, declaring the dominator/post-dominator
// analyses it requires for promoteAllocas().
INITIALIZE_PASS_BEGIN(CMLowerVLoadVStore, "CMLowerVLoadVStore",
                      "Lower CM reference vector loads and stores", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(CMLowerVLoadVStore, "CMLowerVLoadVStore",
                    "Lower CM reference vector loads and stores", false, false)
1786
1787
1788 bool CMLowerVLoadVStore::runOnFunction(Function &F) {
1789 bool Changed = false;
1790 Changed |= promoteAllocas(F);
1791 Changed |= lowerLoadStore(F);
1792 return Changed;
1793 }
1794
1795 // Lower remaining vector load/store intrinsic calls into normal load/store
1796 // instructions.
bool CMLowerVLoadVStore::lowerLoadStore(Function &F) {
  auto M = F.getParent();
  // Maps an alloca to the "genx_volatile" global whose address was stored
  // into it; vload/vstore through such an alloca must remain intrinsic calls.
  DenseMap<AllocaInst*, GlobalVariable*> AllocaMap;
  // collect all the allocas that store the address of genx-volatile variable
  for (auto& G : M->getGlobalList()) {
    if (!G.hasAttribute("genx_volatile"))
      continue;
    // Seed a worklist with the global's direct users, then walk transitively
    // through constant expressions and casts looking for stores of the
    // global's address into an alloca.
    std::vector<User*> WL;
    for (auto UI = G.user_begin(); UI != G.user_end();) {
      auto U = *UI++;
      WL.push_back(U);
    }

    while (!WL.empty()) {
      auto Inst = WL.back();
      WL.pop_back();
      if (auto CE = dyn_cast<ConstantExpr>(Inst)) {
        for (auto UI = CE->user_begin(); UI != CE->user_end();) {
          auto U = *UI++;
          WL.push_back(U);
        }
      }
      else if (auto CI = dyn_cast<CastInst>(Inst)) {
        for (auto UI = CI->user_begin(); UI != CI->user_end();) {
          auto U = *UI++;
          WL.push_back(U);
        }
      }
      else if (auto SI = dyn_cast<StoreInst>(Inst)) {
        // The (possibly cast) global address is stored into an alloca:
        // remember the association.
        auto Ptr = SI->getPointerOperand()->stripPointerCasts();
        if (auto PI = dyn_cast<AllocaInst>(Ptr)) {
          AllocaMap[PI] = &G;
        }
      }
    }
  }

  // lower all vload/vstore into normal load/store.
  std::vector<Instruction *> ToErase;
  for (Instruction &Inst : instructions(F)) {
    if (GenXIntrinsic::isVLoadStore(&Inst)) {
      // The pointer is operand 0 for vload and operand 1 for vstore.
      auto *Ptr = Inst.getOperand(0);
      if (GenXIntrinsic::isVStore(&Inst))
        Ptr = Inst.getOperand(1);
      auto AS0 = cast<PointerType>(Ptr->getType())->getAddressSpace();
      Ptr = Ptr->stripPointerCasts();
      // GV ends up non-null only when the access targets a genx_volatile
      // global, either directly or indirectly via an alloca in AllocaMap.
      auto GV = dyn_cast<GlobalVariable>(Ptr);
      if (GV) {
        if (!GV->hasAttribute("genx_volatile"))
          GV = nullptr;
      }
      else if (auto LI = dyn_cast<LoadInst>(Ptr)) {
        auto PV = LI->getPointerOperand()->stripPointerCasts();
        if (auto PI = dyn_cast<AllocaInst>(PV)) {
          if (AllocaMap.find(PI) != AllocaMap.end()) {
            GV = AllocaMap[PI];
          }
        }
      }
      if (GV == nullptr) {
        // change to load/store
        IRBuilder<> Builder(&Inst);
        if (GenXIntrinsic::isVStore(&Inst))
          Builder.CreateStore(Inst.getOperand(0), Inst.getOperand(1));
        else {
          Value *Op0 = Inst.getOperand(0);
          auto LI = Builder.CreateLoad(Op0->getType()->getPointerElementType(),
                                       Op0, Inst.getName());
          LI->setDebugLoc(Inst.getDebugLoc());
          Inst.replaceAllUsesWith(LI);
        }
        ToErase.push_back(&Inst);
      }
      else {
        // change to vload/vstore that has the same address space as
        // the global-var in order to clean up unnecessary addr-cast.
        auto AS1 = GV->getType()->getAddressSpace();
        if (AS0 != AS1) {
          IRBuilder<> Builder(&Inst);
          if (GenXIntrinsic::isVStore(&Inst)) {
            // Rebuild the vstore with the pointer cast to the global's
            // address space.
            auto PtrTy = cast<PointerType>(Inst.getOperand(1)->getType());
            PtrTy = PointerType::get(PtrTy->getElementType(), AS1);
            auto PtrCast = Builder.CreateAddrSpaceCast(Inst.getOperand(1), PtrTy);
            Type* Tys[] = { Inst.getOperand(0)->getType(),
                            PtrCast->getType() };
            Value* Args[] = { Inst.getOperand(0), PtrCast };
            Function* Fn = GenXIntrinsic::getGenXDeclaration(
                F.getParent(), GenXIntrinsic::genx_vstore, Tys);
            Builder.CreateCall(Fn, Args, Inst.getName());
          }
          else {
            // Likewise for vload.
            auto PtrTy = cast<PointerType>(Inst.getOperand(0)->getType());
            PtrTy = PointerType::get(PtrTy->getElementType(), AS1);
            auto PtrCast = Builder.CreateAddrSpaceCast(Inst.getOperand(0), PtrTy);
            Type* Tys[] = { Inst.getType(), PtrCast->getType() };
            Function* Fn = GenXIntrinsic::getGenXDeclaration(
                F.getParent(), GenXIntrinsic::genx_vload, Tys);
            Value* VLoad = Builder.CreateCall(Fn, PtrCast, Inst.getName());
            Inst.replaceAllUsesWith(VLoad);
          }
          ToErase.push_back(&Inst);
        }
      }
    }
  }

  // Erase replaced intrinsic calls only after the iteration above is done.
  for (auto Inst : ToErase) {
    Inst->eraseFromParent();
  }

  return !ToErase.empty();
}
1909
isBitCastForLifetimeMarker(Value * V)1910 static bool isBitCastForLifetimeMarker(Value *V) {
1911 if (!V || !isa<BitCastInst>(V))
1912 return false;
1913 for (auto U : V->users()) {
1914 unsigned IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(U);
1915 if (IntrinsicID != Intrinsic::lifetime_start &&
1916 IntrinsicID != Intrinsic::lifetime_end)
1917 return false;
1918 }
1919 return true;
1920 }
1921
1922 // Check whether two values are bitwise identical.
// Check whether two values are provably bitwise identical. Returns true for
// the same SSA value, or for two vloads from the same local address with no
// intervening vstore to it; returns false whenever identity cannot be proven.
static bool isBitwiseIdentical(Value *V1, Value *V2) {
  IGC_ASSERT_MESSAGE(V1, "null value");
  IGC_ASSERT_MESSAGE(V2, "null value");
  if (V1 == V2)
    return true;
  // Look through a single bitcast on either side.
  if (BitCastInst *BI = dyn_cast<BitCastInst>(V1))
    V1 = BI->getOperand(0);
  if (BitCastInst *BI = dyn_cast<BitCastInst>(V2))
    V2 = BI->getOperand(0);

  // Special case arises from vload/vstore.
  if (GenXIntrinsic::isVLoad(V1) && GenXIntrinsic::isVLoad(V2)) {
    auto L1 = cast<CallInst>(V1);
    auto L2 = cast<CallInst>(V2);
    // Check if loading from the same location.
    if (L1->getOperand(0) != L2->getOperand(0))
      return false;

    // Check if this pointer is local and only used in vload/vstore.
    Value *Addr = L1->getOperand(0);
    if (!isa<AllocaInst>(Addr))
      return false;
    for (auto UI : Addr->users()) {
      if (isa<BitCastInst>(UI)) {
        // A bitcast user is tolerated only when it feeds lifetime markers.
        for (auto U : UI->users()) {
          unsigned IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(U);
          if (IntrinsicID != Intrinsic::lifetime_start &&
              IntrinsicID != Intrinsic::lifetime_end)
            return false;
        }
      } else {
        if (!GenXIntrinsic::isVLoadStore(UI))
          return false;
      }
    }

    // Check if there is no store to the same location in between.
    // Only handle the single-basic-block case.
    if (L1->getParent() != L2->getParent())
      return false;
    // Find whichever of the two loads comes first in the block...
    BasicBlock::iterator I = L1->getParent()->begin();
    for (; &*I != L1 && &*I != L2; ++I)
      /*empty*/;
    IGC_ASSERT(&*I == L1 || &*I == L2);
    auto IEnd = (&*I == L1) ? L2->getIterator() : L1->getIterator();
    // ...then scan the instructions between them for a vstore to Addr.
    for (; I != IEnd; ++I) {
      Instruction *Inst = &*I;
      if (GenXIntrinsic::isVStore(Inst) && Inst->getOperand(1) == Addr)
        return false;
    }

    // OK.
    return true;
  }

  // Cannot prove.
  return false;
}
1980
// Match the copy-in/copy-out pattern on this->Alloca. On success, fills in
// CopyInRegion/CopyInStore/CopyOutLoad/CopyOutRegion and moves all vloads
// and vstores into VLoads/VStores; returns true. On failure returns false
// (the struct may be left partially filled and should be discarded).
bool ArgRefPattern::match(DominatorTree &DT, PostDominatorTree &PDT) {
  IGC_ASSERT(Alloca);
  if (Alloca->use_empty())
    return false;

  // check if all users are load/store.
  SmallVector<CallInst *, 8> Loads;
  SmallVector<CallInst *, 8> Stores;
  for (auto U : Alloca->users())
    if (GenXIntrinsic::isVLoad(U))
      Loads.push_back(cast<CallInst>(U));
    else if (GenXIntrinsic::isVStore(U))
      Stores.push_back(cast<CallInst>(U));
    else if (isBitCastForLifetimeMarker(U))
      continue;
    else
      return false;

  if (Loads.empty() || Stores.empty())
    return false;

  // find a unique store that dominates all other users if exists.
  auto Cmp = [&](CallInst *L, CallInst *R) { return DT.dominates(L, R); };
  CopyInStore = *std::min_element(Stores.begin(), Stores.end(), Cmp);
  // The copy-in store must be fed directly by a single-use rdregion.
  CopyInRegion = dyn_cast<CallInst>(CopyInStore->getArgOperand(0));
  if (!CopyInRegion || !CopyInRegion->hasOneUse() || !GenXIntrinsic::isRdRegion(CopyInRegion))
    return false;

  // Verify the chosen store really dominates every other access;
  // min_element alone does not guarantee a total dominance order.
  for (auto SI : Stores)
    if (SI != CopyInStore && !Cmp(CopyInStore, SI))
      return false;
  for (auto LI : Loads)
    if (LI != CopyInStore && !Cmp(CopyInStore, LI))
      return false;

  // find a unique load that post-dominates all other users if exists.
  auto PostCmp = [&](CallInst *L, CallInst *R) {
    BasicBlock *LBB = L->getParent();
    BasicBlock *RBB = R->getParent();
    if (LBB != RBB)
      return PDT.dominates(LBB, RBB);

    // Loop through the basic block until we find L or R.
    BasicBlock::const_iterator I = LBB->begin();
    for (; &*I != L && &*I != R; ++I)
      /*empty*/;

    // Within a block, L post-dominates R when R appears first.
    return &*I == R;
  };
  CopyOutLoad = *std::min_element(Loads.begin(), Loads.end(), PostCmp);

  // Expect copy-out load has one or zero use. It is possible there
  // is no use as the region becomes dead after this subroutine call.
  //
  if (!CopyOutLoad->use_empty()) {
    if (!CopyOutLoad->hasOneUse())
      return false;
    CopyOutRegion = dyn_cast<CallInst>(CopyOutLoad->user_back());
    if (!GenXIntrinsic::isWrRegion(CopyOutRegion))
      return false;
  }

  // Verify the chosen load really post-dominates every other access.
  for (auto SI : Stores)
    if (SI != CopyOutLoad && !PostCmp(CopyOutLoad, SI))
      return false;
  for (auto LI : Loads)
    if (LI != CopyOutLoad && !PostCmp(CopyOutLoad, LI))
      return false;

  // Ensure read-in and write-out to the same region. It is possible that region
  // collapsing does not simplify region accesses completely.
  // Probably we should use an assertion statement on region descriptors.
  if (CopyOutRegion &&
      !isBitwiseIdentical(CopyInRegion->getOperand(0),
                          CopyOutRegion->getOperand(0)))
    return false;

  // It should be OK to rewrite all loads and stores into the argref.
  VLoads.swap(Loads);
  VStores.swap(Stores);
  return true;
}
2063
// Rewrite a matched pattern: spill the full base region into a new alloca,
// turn every vstore into a wrregion+store of the base, and every vload into
// a load+rdregion from the base. The original arg alloca becomes unused by
// these accesses; debug info is retargeted to the new alloca.
void ArgRefPattern::process(DominatorTree &DT) {
  // 'Spill' the base region into memory during rewriting.
  IRBuilder<> Builder(Alloca);
  Function *RdFn = CopyInRegion->getCalledFunction();
  IGC_ASSERT(RdFn);
  // The base alloca holds the whole source vector (the rdregion's first
  // parameter type), not just the referenced sub-region.
  Type *BaseAllocaTy = RdFn->getFunctionType()->getParamType(0);
  AllocaInst *BaseAlloca = Builder.CreateAlloca(BaseAllocaTy, nullptr,
                                                Alloca->getName() + ".refprom");

  Builder.SetInsertPoint(CopyInRegion);
  Builder.CreateStore(CopyInRegion->getArgOperand(0), BaseAlloca);

  if (CopyOutRegion) {
    // The copy-out wrregion now takes its old value from the spilled base.
    Builder.SetInsertPoint(CopyOutRegion);
    CopyOutRegion->setArgOperand(
        0, Builder.CreateLoad(BaseAlloca->getType()->getPointerElementType(),
                              BaseAlloca));
  }

  // Rewrite all stores.
  for (auto ST : VStores) {
    Builder.SetInsertPoint(ST);
    Value *OldVal = Builder.CreateLoad(
        BaseAlloca->getType()->getPointerElementType(), BaseAlloca);
    // Always use copy-in region arguments as copy-out region
    // arguments do not dominate this store.
    auto M = ST->getParent()->getParent()->getParent();
    Value *Args[] = {OldVal,
                     ST->getArgOperand(0),
                     CopyInRegion->getArgOperand(1), // vstride
                     CopyInRegion->getArgOperand(2), // width
                     CopyInRegion->getArgOperand(3), // hstride
                     CopyInRegion->getArgOperand(4), // offset
                     CopyInRegion->getArgOperand(5), // parent width
                     ConstantInt::getTrue(Type::getInt1Ty(M->getContext()))};
    // Pick the float or integer wrregion variant from the element type.
    auto ID = OldVal->getType()->isFPOrFPVectorTy() ? GenXIntrinsic::genx_wrregionf
                                                    : GenXIntrinsic::genx_wrregioni;
    Type *Tys[] = {Args[0]->getType(), Args[1]->getType(), Args[5]->getType(),
                   Args[7]->getType()};
    auto WrFn = GenXIntrinsic::getGenXDeclaration(M, ID, Tys);
    Value *NewVal = Builder.CreateCall(WrFn, Args);
    Builder.CreateStore(NewVal, BaseAlloca);
    ST->eraseFromParent();
  }

  // Rewrite all loads
  for (auto LI : VLoads) {
    if (LI->use_empty())
      continue;

    Builder.SetInsertPoint(LI);
    Value *SrcVal = Builder.CreateLoad(
        BaseAlloca->getType()->getPointerElementType(), BaseAlloca);
    // Re-read the sub-region from the spilled base, reusing the original
    // rdregion's arguments with only the source value swapped.
    SmallVector<Value *, 8> Args(CopyInRegion->arg_operands());
    Args[0] = SrcVal;
    Value *Val = Builder.CreateCall(RdFn, Args);
    LI->replaceAllUsesWith(Val);
    LI->eraseFromParent();
  }
  // BaseAlloca created manually, w/o RAUW, need fix debug-info for it
  llvm::replaceAllDbgUsesWith(*Alloca, *BaseAlloca, *BaseAlloca, DT);
}
2126
2127 // Allocas that are used in reference argument passing may be promoted into the
2128 // base region.
promoteAllocas(Function & F)2129 bool CMLowerVLoadVStore::promoteAllocas(Function &F) {
2130 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2131 auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
2132 bool Modified = false;
2133
2134 SmallVector<AllocaInst *, 8> Allocas;
2135 for (auto &Inst : F.front().getInstList()) {
2136 if (auto AI = dyn_cast<AllocaInst>(&Inst))
2137 Allocas.push_back(AI);
2138 }
2139
2140 for (auto AI : Allocas) {
2141 ArgRefPattern ArgRef(AI);
2142 if (ArgRef.match(DT, PDT)) {
2143 ArgRef.process(DT);
2144 Modified = true;
2145 }
2146 }
2147
2148 return Modified;
2149 }
2150
// Factory used by pass-pipeline setup code to create the pass instance.
Pass *llvm::createCMLowerVLoadVStorePass() { return new CMLowerVLoadVStore; }
2152