1 /*
2   Copyright (c) 2010-2021, Intel Corporation
3   All rights reserved.
4 
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions are
7   met:
8 
9     * Redistributions of source code must retain the above copyright
10       notice, this list of conditions and the following disclaimer.
11 
12     * Redistributions in binary form must reproduce the above copyright
13       notice, this list of conditions and the following disclaimer in the
14       documentation and/or other materials provided with the distribution.
15 
16     * Neither the name of Intel Corporation nor the names of its
17       contributors may be used to endorse or promote products derived from
18       this software without specific prior written permission.
19 
20 
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
22    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23    TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
25    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33 
34 /** @file opt.cpp
35     @brief Implementations of various ispc optimization passes that operate
36            on the LLVM IR.
37 */
38 
39 #include "opt.h"
40 #include "ctx.h"
41 #include "llvmutil.h"
42 #include "module.h"
43 #include "sym.h"
44 #include "util.h"
45 
46 #include <map>
47 #include <regex>
48 #include <set>
49 #include <stdio.h>
50 
51 #include <llvm/ADT/SmallSet.h>
52 #include <llvm/ADT/Triple.h>
53 #include <llvm/Analysis/BasicAliasAnalysis.h>
54 #include <llvm/Analysis/ConstantFolding.h>
55 #include <llvm/Analysis/Passes.h>
56 #include <llvm/Analysis/TargetLibraryInfo.h>
57 #include <llvm/Analysis/TargetTransformInfo.h>
58 #include <llvm/Analysis/TypeBasedAliasAnalysis.h>
59 #include <llvm/BinaryFormat/Dwarf.h>
60 #include <llvm/IR/BasicBlock.h>
61 #include <llvm/IR/Constants.h>
62 #include <llvm/IR/DataLayout.h>
63 #include <llvm/IR/DebugInfo.h>
64 #include <llvm/IR/Function.h>
65 #include <llvm/IR/IRPrintingPasses.h>
66 #include <llvm/IR/Instructions.h>
67 #include <llvm/IR/IntrinsicInst.h>
68 #include <llvm/IR/Intrinsics.h>
69 #include <llvm/IR/IntrinsicsX86.h>
70 #include <llvm/IR/LegacyPassManager.h>
71 #include <llvm/IR/Module.h>
72 #include <llvm/IR/PatternMatch.h>
73 #include <llvm/IR/Verifier.h>
74 #include <llvm/InitializePasses.h>
75 #include <llvm/Pass.h>
76 #include <llvm/PassRegistry.h>
77 #include <llvm/Support/raw_ostream.h>
78 #include <llvm/Target/TargetMachine.h>
79 #include <llvm/Target/TargetOptions.h>
80 #include <llvm/Transforms/IPO.h>
81 #include <llvm/Transforms/IPO/FunctionAttrs.h>
82 #include <llvm/Transforms/InstCombine/InstCombine.h>
83 #include <llvm/Transforms/Instrumentation.h>
84 #include <llvm/Transforms/Scalar.h>
85 #include <llvm/Transforms/Scalar/GVN.h>
86 #include <llvm/Transforms/Scalar/InstSimplifyPass.h>
87 #include <llvm/Transforms/Utils.h>
88 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
89 
90 #ifdef ISPC_HOST_IS_LINUX
91 #include <alloca.h>
92 #elif defined(ISPC_HOST_IS_WINDOWS)
93 #include <malloc.h>
94 #ifndef __MINGW32__
95 #define alloca _alloca
96 #endif
97 #endif // ISPC_HOST_IS_WINDOWS
98 
99 #ifndef PRId64
100 #define PRId64 "lld"
101 #endif
102 #ifndef PRIu64
103 #define PRIu64 "llu"
104 #endif
105 
106 #ifndef ISPC_NO_DUMPS
107 #include <llvm/Support/FileSystem.h>
108 #include <llvm/Support/Regex.h>
109 #endif
110 #ifdef ISPC_GENX_ENABLED
111 #include "gen/GlobalsLocalization.h"
112 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
113 #include <llvm/GenXIntrinsics/GenXIntrOpts.h>
114 #include <llvm/GenXIntrinsics/GenXIntrinsics.h>
115 #include <llvm/GenXIntrinsics/GenXSPIRVWriterAdaptor.h>
116 // Used for GenX gather coalescing
117 #include <llvm/Transforms/Utils/Local.h>
118 
119 // Constants, in number of bytes.
120 enum { BYTE = 1, WORD = 2, DWORD = 4, QWORD = 8, OWORD = 16, GRF = 32 };
121 #endif
122 
123 using namespace ispc;
124 
125 static llvm::Pass *CreateIntrinsicsOptPass();
126 static llvm::Pass *CreateInstructionSimplifyPass();
127 static llvm::Pass *CreatePeepholePass();
128 
129 static llvm::Pass *CreateImproveMemoryOpsPass();
130 static llvm::Pass *CreateGatherCoalescePass();
131 static llvm::Pass *CreateReplacePseudoMemoryOpsPass();
132 
133 static llvm::Pass *CreateIsCompileTimeConstantPass(bool isLastTry);
134 static llvm::Pass *CreateMakeInternalFuncsStaticPass();
135 
136 #ifndef ISPC_NO_DUMPS
137 static llvm::Pass *CreateDebugPass(char *output);
138 static llvm::Pass *CreateDebugPassFile(int number, llvm::StringRef name);
139 #endif
140 
141 static llvm::Pass *CreateReplaceStdlibShiftPass();
142 
143 #ifdef ISPC_GENX_ENABLED
144 static llvm::Pass *CreateGenXGatherCoalescingPass();
145 static llvm::Pass *CreateReplaceLLVMIntrinsics();
146 static llvm::Pass *CreateFixDivisionInstructions();
147 static llvm::Pass *CreateFixAddressSpace();
148 static llvm::Pass *CreateDemotePHIs();
149 static llvm::Pass *CreateCheckUnsupportedInsts();
150 static llvm::Pass *CreateMangleOpenCLBuiltins();
151 #endif
152 
153 #ifndef ISPC_NO_DUMPS
154 #define DEBUG_START_PASS(NAME)                                                                                         \
155     if (g->debugPrint &&                                                                                               \
156         (getenv("FUNC") == NULL || (getenv("FUNC") != NULL && !strncmp(bb.getParent()->getName().str().c_str(),        \
157                                                                        getenv("FUNC"), strlen(getenv("FUNC")))))) {    \
158         fprintf(stderr, "Start of " NAME "\n");                                                                        \
159         fprintf(stderr, "---------------\n");                                                                          \
160         /*bb.dump();*/                                                                                                     \
161         fprintf(stderr, "---------------\n\n");                                                                        \
162     } else /* eat semicolon */
163 
164 #define DEBUG_END_PASS(NAME)                                                                                           \
165     if (g->debugPrint &&                                                                                               \
166         (getenv("FUNC") == NULL || (getenv("FUNC") != NULL && !strncmp(bb.getParent()->getName().str().c_str(),        \
167                                                                        getenv("FUNC"), strlen(getenv("FUNC")))))) {    \
168         fprintf(stderr, "End of " NAME " %s\n", modifiedAny ? "** CHANGES **" : "");                                   \
169         fprintf(stderr, "---------------\n");                                                                          \
170         /*bb.dump();*/                                                                                                     \
171         fprintf(stderr, "---------------\n\n");                                                                        \
172     } else /* eat semicolon */
173 #else
174 #define DEBUG_START_PASS(NAME)
175 #define DEBUG_END_PASS(NAME)
176 #endif
177 
178 ///////////////////////////////////////////////////////////////////////////
179 
180 /** This utility routine copies the metadata (if any) attached to the
181     'from' instruction in the IR to the 'to' instruction.
182 
183     For flexibility, this function takes an llvm::Value rather than an
184     llvm::Instruction for the 'to' parameter; in some places in the code
185     below, we start out storing a value in an llvm::Value and only later
186     store instructions.  If an llvm::Value is passed to this, the
187     routine just returns without doing anything; if it is in fact an
188     llvm::Instruction, then the metadata can be copied to it.
189  */
190 static void lCopyMetadata(llvm::Value *vto, const llvm::Instruction *from) {
191     llvm::Instruction *to = llvm::dyn_cast<llvm::Instruction>(vto);
192     if (!to)
193         return;
194 
195     llvm::SmallVector<std::pair<unsigned int, llvm::MDNode *>, 8> metadata;
196 
197     from->getAllMetadata(metadata);
198     for (unsigned int i = 0; i < metadata.size(); ++i)
199         to->setMetadata(metadata[i].first, metadata[i].second);
200 }
201 
202 /** We have a protocol with the front-end LLVM IR code generation process
203     that allows us to encode the source file position that corresponds with
204     instructions.  (For example, this allows us to issue performance
205     warnings related to things like scatter and gather after optimization
206     has been performed, so that we aren't warning about scatters and
207     gathers that have been improved to stores and loads by optimization
208     passes.)  Note that this is slightly redundant with the source file
209     position encoding generated for debugging symbols; however, while we
210     don't always generate debugging information, we do always generate this
211     position data.
212 
213     This function finds the SourcePos that the metadata in the instruction
214     (if present) corresponds to.  See the implementation of
215     FunctionEmitContext::addGSMetadata(), which encodes the source position during
216     code generation.
217 
218     @param inst   Instruction to try to find the source position of
219     @param pos    Output variable in which to store the position
220     @returns      True if source file position metadata was present and *pos
221                   has been set.  False otherwise.
222 */
223 static bool lGetSourcePosFromMetadata(const llvm::Instruction *inst, SourcePos *pos) {
224     llvm::MDNode *filename = inst->getMetadata("filename");
225     llvm::MDNode *first_line = inst->getMetadata("first_line");
226     llvm::MDNode *first_column = inst->getMetadata("first_column");
227     llvm::MDNode *last_line = inst->getMetadata("last_line");
228     llvm::MDNode *last_column = inst->getMetadata("last_column");
229 
230     if (!filename || !first_line || !first_column || !last_line || !last_column)
231         return false;
232 
233     // All of these asserts are things that FunctionEmitContext::addGSMetadata() is
234     // expected to have done in its operation
235     llvm::MDString *str = llvm::dyn_cast<llvm::MDString>(filename->getOperand(0));
236     Assert(str);
237     llvm::ConstantInt *first_lnum =
238 
239         llvm::mdconst::extract<llvm::ConstantInt>(first_line->getOperand(0));
240     Assert(first_lnum);
241 
242     llvm::ConstantInt *first_colnum =
243 
244         llvm::mdconst::extract<llvm::ConstantInt>(first_column->getOperand(0));
245     Assert(first_column);
246 
247     llvm::ConstantInt *last_lnum =
248 
249         llvm::mdconst::extract<llvm::ConstantInt>(last_line->getOperand(0));
250     Assert(last_lnum);
251 
252     llvm::ConstantInt *last_colnum = llvm::mdconst::extract<llvm::ConstantInt>(last_column->getOperand(0));
253     Assert(last_column);
254 
255     *pos = SourcePos(str->getString().data(), (int)first_lnum->getZExtValue(), (int)first_colnum->getZExtValue(),
256                      (int)last_lnum->getZExtValue(), (int)last_colnum->getZExtValue());
257     return true;
258 }
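// Illustrative sketch only (hypothetical IR, not emitted verbatim by the
// compiler): the metadata layout this function expects, with node names
// matching the getMetadata() calls above, looks roughly like
//
//   %v = call <4 x i32> @some_gather(...), !filename !0, !first_line !1,
//            !first_column !2, !last_line !3, !last_column !4
//   !0 = !{!"foo.ispc"}
//   !1 = !{i32 42}
//
// where !0 wraps an MDString and !1 through !4 wrap ConstantInts -- the same
// shapes that are extracted and asserted on above.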
259 
260 static llvm::Instruction *lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1, const llvm::Twine &name,
261                                     llvm::Instruction *insertBefore = NULL) {
262     llvm::Value *args[2] = {arg0, arg1};
263     llvm::ArrayRef<llvm::Value *> newArgArray(&args[0], &args[2]);
264     return llvm::CallInst::Create(func, newArgArray, name, insertBefore);
265 }
266 
267 static llvm::Instruction *lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1, llvm::Value *arg2,
268                                     const llvm::Twine &name, llvm::Instruction *insertBefore = NULL) {
269     llvm::Value *args[3] = {arg0, arg1, arg2};
270     llvm::ArrayRef<llvm::Value *> newArgArray(&args[0], &args[3]);
271     return llvm::CallInst::Create(func, newArgArray, name, insertBefore);
272 }
273 
274 static llvm::Instruction *lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1, llvm::Value *arg2,
275                                     llvm::Value *arg3, const llvm::Twine &name,
276                                     llvm::Instruction *insertBefore = NULL) {
277     llvm::Value *args[4] = {arg0, arg1, arg2, arg3};
278     llvm::ArrayRef<llvm::Value *> newArgArray(&args[0], &args[4]);
279     return llvm::CallInst::Create(func, newArgArray, name, insertBefore);
280 }
281 
282 static llvm::Instruction *lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1, llvm::Value *arg2,
283                                     llvm::Value *arg3, llvm::Value *arg4, const llvm::Twine &name,
284                                     llvm::Instruction *insertBefore = NULL) {
285     llvm::Value *args[5] = {arg0, arg1, arg2, arg3, arg4};
286     llvm::ArrayRef<llvm::Value *> newArgArray(&args[0], &args[5]);
287     return llvm::CallInst::Create(func, newArgArray, name, insertBefore);
288 }
289 
290 static llvm::Instruction *lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1, llvm::Value *arg2,
291                                     llvm::Value *arg3, llvm::Value *arg4, llvm::Value *arg5, const llvm::Twine &name,
292                                     llvm::Instruction *insertBefore = NULL) {
293     llvm::Value *args[6] = {arg0, arg1, arg2, arg3, arg4, arg5};
294     llvm::ArrayRef<llvm::Value *> newArgArray(&args[0], &args[6]);
295     return llvm::CallInst::Create(func, newArgArray, name, insertBefore);
296 }
297 
298 static llvm::Instruction *lGEPInst(llvm::Value *ptr, llvm::Value *offset, const char *name,
299                                    llvm::Instruction *insertBefore) {
300     llvm::Value *index[1] = {offset};
301     llvm::ArrayRef<llvm::Value *> arrayRef(&index[0], &index[1]);
302 
303     return llvm::GetElementPtrInst::Create(PTYPE(ptr), ptr, arrayRef, name, insertBefore);
304 }
305 
306 /** Given a vector of constant values (int, float, or bool) representing an
307     execution mask, convert it to a bitvector where the 0th bit corresponds
308     to the first vector value and so forth.
309 */
310 static uint64_t lConstElementsToMask(const llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> &elements) {
311     Assert(elements.size() <= 64);
312 
313     uint64_t mask = 0;
314     uint64_t undefSetMask = 0;
315     llvm::APInt intMaskValue;
316     for (unsigned int i = 0; i < elements.size(); ++i) {
317         // SSE has the "interesting" approach of encoding blending
318         // masks as <n x float>.
319         if (llvm::ConstantFP *cf = llvm::dyn_cast<llvm::ConstantFP>(elements[i])) {
320             llvm::APFloat apf = cf->getValueAPF();
321             intMaskValue = apf.bitcastToAPInt();
322         } else if (llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(elements[i])) {
323             // Otherwise get it as an int
324             intMaskValue = ci->getValue();
325         } else {
326             // We create a separate 'undef mask' with all undef bits set.
327             // This mask will have no bits set if there are no 'undef' elements.
328             llvm::UndefValue *uv = llvm::dyn_cast<llvm::UndefValue>(elements[i]);
329             Assert(uv != NULL); // vs return -1 if NULL?
330             undefSetMask |= (1ull << i);
331             continue;
332         }
333         // Is the high-bit set?  If so, OR in the appropriate bit in
334         // the result mask
335         if (intMaskValue.countLeadingOnes() > 0)
336             mask |= (1ull << i);
337     }
338 
339     // If no bits are set in mask, there is no need to consider undefs: the
340     // result is always 'all_off'.
341     // If any bits are set in mask, treat 'undef' bits as '1'. This ensures
342     // that cases with only '1's and 'undef's are considered 'all_on'.
343     if (mask != 0)
344         mask |= undefSetMask;
345 
346     return mask;
347 }
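// Worked example (illustrative, assuming a 4-wide mask): for the constant
// vector <-1, undef, -1, -1>, the defined elements set bits 0, 2, and 3
// (mask = 0b1101) and the undef element sets bit 1 of undefSetMask (0b0010).
// Since mask != 0, the undef bit is OR'd in and the function returns 0b1111,
// so a mask consisting only of '1's and 'undef's is treated as all-on.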
348 
349 /** Given an llvm::Value representing a vector mask, see if the value is a
350     constant.  If so, return true and set *bits to be the integer mask
351     found by taking the high bits of the mask values in turn and
352     concatenating them into a single integer.  In other words, given the
353     4-wide mask: < 0xffffffff, 0, 0, 0xffffffff >, we have 0b1001 = 9.
354  */
355 static bool lGetMask(llvm::Value *factor, uint64_t *mask) {
356     llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(factor);
357     if (cdv != NULL) {
358         llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
359         for (int i = 0; i < (int)cdv->getNumElements(); ++i)
360             elements.push_back(cdv->getElementAsConstant(i));
361         *mask = lConstElementsToMask(elements);
362         return true;
363     }
364 
365     llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(factor);
366     if (cv != NULL) {
367         llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
368         for (int i = 0; i < (int)cv->getNumOperands(); ++i) {
369             llvm::Constant *c = llvm::dyn_cast<llvm::Constant>(cv->getOperand(i));
370             if (c == NULL)
371                 return false;
372             if (llvm::isa<llvm::ConstantExpr>(cv->getOperand(i)))
373                 return false; // We can not handle constant expressions here
374             elements.push_back(c);
375         }
376         *mask = lConstElementsToMask(elements);
377         return true;
378     } else if (llvm::isa<llvm::ConstantAggregateZero>(factor)) {
379         *mask = 0;
380         return true;
381     } else {
382 #if 0
383         llvm::ConstantExpr *ce = llvm::dyn_cast<llvm::ConstantExpr>(factor);
384         if (ce != NULL) {
385             llvm::TargetMachine *targetMachine = g->target->GetTargetMachine();
386             const llvm::TargetData *td = targetMachine->getTargetData();
387             llvm::Constant *c = llvm::ConstantFoldConstantExpression(ce, td);
388             c->dump();
389             factor = c;
390         }
391         // else we should be able to handle it above...
392         Assert(!llvm::isa<llvm::Constant>(factor));
393 #endif
394         return false;
395     }
396 }
397 
398 enum class MaskStatus { all_on, all_off, mixed, unknown };
399 
400 /** Determines if the given mask value is all on, all off, mixed, or
401     unknown at compile time.
402 */
403 static MaskStatus lGetMaskStatus(llvm::Value *mask, int vecWidth = -1) {
404     uint64_t bits;
405     if (lGetMask(mask, &bits) == false)
406         return MaskStatus::unknown;
407 
408     if (bits == 0)
409         return MaskStatus::all_off;
410 
411     if (vecWidth == -1)
412         vecWidth = g->target->getVectorWidth();
413     Assert(vecWidth <= 64);
414 
415     for (int i = 0; i < vecWidth; ++i) {
416         if ((bits & (1ull << i)) == 0)
417             return MaskStatus::mixed;
418     }
419     return MaskStatus::all_on;
420 }
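// Usage examples (illustrative, 4-wide): a constant mask <-1, -1, -1, -1>
// yields MaskStatus::all_on, <0, 0, 0, 0> yields MaskStatus::all_off,
// <-1, 0, 0, -1> yields MaskStatus::mixed, and any mask value that lGetMask()
// cannot evaluate at compile time yields MaskStatus::unknown.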
421 
422 ///////////////////////////////////////////////////////////////////////////
423 // This is a wrapper over llvm::legacy::PassManager. It duplicates the PassManager run()
424 //   function and extends the add() function with some checks and debug passes.
425 //   This wrapper can control:
426 //   - whether to switch off the optimization with a given number;
427 //   - whether to dump the LLVM IR after the optimization with a given number;
428 //   - whether to generate LLVM IR debug info for gdb after the optimization with a given number.
429 class DebugPassManager {
430   public:
431     DebugPassManager() : number(0) {}
432     void add(llvm::Pass *P, int stage);
433     bool run(llvm::Module &M) { return PM.run(M); }
434     llvm::legacy::PassManager &getPM() { return PM; }
435 
436   private:
437     llvm::legacy::PassManager PM;
438     int number;
439 };
440 
441 void DebugPassManager::add(llvm::Pass *P, int stage = -1) {
442     // Determine the optimization stage number.
443     if (stage == -1) {
444         number++;
445     } else {
446         number = stage;
447     }
448     if (g->off_stages.find(number) == g->off_stages.end()) {
449         // add the optimization (it has not been switched off)
450         PM.add(P);
451 #ifndef ISPC_NO_DUMPS
452         if (g->debug_stages.find(number) != g->debug_stages.end()) {
453             // adding dump of LLVM IR after optimization
454             if (g->dumpFile) {
455                 PM.add(CreateDebugPassFile(number, P->getPassName()));
456             } else {
457                 char buf[100];
458                 snprintf(buf, sizeof(buf), "\n\n*****LLVM IR after phase %d: %s*****\n\n", number,
459                          P->getPassName().data());
460                 PM.add(CreateDebugPass(buf));
461             }
462         }
463 #endif
464     }
465 }
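// A minimal usage sketch (this mirrors how Optimize() below drives the
// wrapper; the particular passes and stage numbers are just examples):
//
//   DebugPassManager optPM;
//   optPM.add(llvm::createVerifierPass(), 0);   // explicit stage number
//   optPM.add(llvm::createGlobalDCEPass());     // stage = previous + 1
//   optPM.run(*module);
//
// Stages listed in g->off_stages are skipped entirely; stages listed in
// g->debug_stages additionally get the IR printed after the pass (or written
// to a file when g->dumpFile is set).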
466 ///////////////////////////////////////////////////////////////////////////
467 
468 void ispc::Optimize(llvm::Module *module, int optLevel) {
469 #ifndef ISPC_NO_DUMPS
470     if (g->debugPrint) {
471         printf("*** Code going into optimization ***\n");
472         //module->dump();
473     }
474 #endif
475     DebugPassManager optPM;
476 
477     if (g->enableLLVMIntrinsics) {
478         // Required for matrix intrinsics. This needs to happen before VerifierPass.
479         // TODO : Limit pass to only when llvm.matrix.* intrinsics are used.
480         optPM.add(llvm::createLowerMatrixIntrinsicsPass()); // llvm.matrix
481     }
482     optPM.add(llvm::createVerifierPass(), 0);
483 
484     optPM.add(new llvm::TargetLibraryInfoWrapperPass(llvm::Triple(module->getTargetTriple())));
485     if (!g->target->isGenXTarget()) {
486         llvm::TargetMachine *targetMachine = g->target->GetTargetMachine();
487         optPM.getPM().add(createTargetTransformInfoWrapperPass(targetMachine->getTargetIRAnalysis()));
488     }
489     optPM.add(llvm::createIndVarSimplifyPass());
490 
491 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
492     llvm::SimplifyCFGOptions simplifyCFGopt;
493     simplifyCFGopt.HoistCommonInsts = true;
494 #endif
495     if (optLevel == 0) {
496         // This is more or less the minimum set of optimizations that we
497         // need to do to generate code that will actually run.  (We can't
498         // run absolutely no optimizations, since the front-end needs us to
499         // take the various __pseudo_* functions it has emitted and turn
500         // them into something that can actually execute.)
501 #ifdef ISPC_GENX_ENABLED
502         if (g->target->isGenXTarget()) {
503             // Global DCE is required for ISPCSimdCFLoweringPass
504             optPM.add(llvm::createGlobalDCEPass());
505             // FIXME: temporary solution
506             optPM.add(llvm::createBreakCriticalEdgesPass());
507             optPM.add(CreateDemotePHIs());
508             optPM.add(llvm::createISPCSimdCFLoweringPass());
509             // FIXME: temporary solution
510             optPM.add(llvm::createPromoteMemoryToRegisterPass());
511         }
512 #endif
513         optPM.add(CreateImproveMemoryOpsPass(), 100);
514 
515         if (g->opt.disableHandlePseudoMemoryOps == false)
516             optPM.add(CreateReplacePseudoMemoryOpsPass());
517 
518         optPM.add(CreateIntrinsicsOptPass(), 102);
519         optPM.add(CreateIsCompileTimeConstantPass(true));
520 #ifdef ISPC_GENX_ENABLED
521         if (g->target->isGenXTarget()) {
522             // InstructionCombining pass is required for FixDivisionInstructions
523             optPM.add(llvm::createInstructionCombiningPass());
524             optPM.add(CreateFixDivisionInstructions());
525         }
526 #endif
527         optPM.add(llvm::createFunctionInliningPass());
528         optPM.add(CreateMakeInternalFuncsStaticPass());
529 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
530         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
531 #else
532         optPM.add(llvm::createCFGSimplificationPass());
533 #endif
534 #ifdef ISPC_GENX_ENABLED
535         if (g->target->isGenXTarget()) {
536             optPM.add(llvm::createPromoteMemoryToRegisterPass());
537             optPM.add(llvm::createGlobalsLocalizationPass());
538             // Remove dead globals after localization
539             optPM.add(llvm::createGlobalDCEPass());
540             // This pass is needed for prints to work correctly
541             optPM.add(llvm::createSROAPass());
542             optPM.add(CreateReplaceLLVMIntrinsics());
543             optPM.add(CreateCheckUnsupportedInsts());
544             optPM.add(CreateFixAddressSpace());
545             optPM.add(CreateMangleOpenCLBuiltins());
546             // This pass is required to prepare LLVM IR for open source SPIR-V translator
547             optPM.add(
548                 llvm::createGenXSPIRVWriterAdaptorPass(true /*RewriteTypes*/, false /*RewriteSingleElementVectors*/));
549         }
550 #endif
551         optPM.add(llvm::createGlobalDCEPass());
552     } else {
553         llvm::PassRegistry *registry = llvm::PassRegistry::getPassRegistry();
554         llvm::initializeCore(*registry);
555         llvm::initializeScalarOpts(*registry);
556         llvm::initializeIPO(*registry);
557         llvm::initializeAnalysis(*registry);
558         llvm::initializeTransformUtils(*registry);
559         llvm::initializeInstCombine(*registry);
560         llvm::initializeInstrumentation(*registry);
561         llvm::initializeTarget(*registry);
562 
563         optPM.add(llvm::createGlobalDCEPass(), 184);
564 
565 #ifdef ISPC_GENX_ENABLED
566         if (g->target->isGenXTarget()) {
567             // FIXME: temporary solution
568             optPM.add(llvm::createBreakCriticalEdgesPass());
569             optPM.add(CreateDemotePHIs());
570             optPM.add(llvm::createISPCSimdCFLoweringPass());
571             // FIXME: temporary solution
572             optPM.add(llvm::createPromoteMemoryToRegisterPass());
573         }
574 #endif
575         // Set up to use LLVM's default AliasAnalysis.
576         // Ideally, we want to call:
577         //    llvm::PassManagerBuilder pm_Builder;
578         //    pm_Builder.OptLevel = optLevel;
579         //    pm_Builder.addInitialAliasAnalysisPasses(optPM);
580         // but the addInitialAliasAnalysisPasses() is a private function
581         // so we explicitly enable them here.
582         // This needs to be kept in sync with future LLVM changes.
583         // An alternative is to call populateFunctionPassManager()
584         optPM.add(llvm::createTypeBasedAAWrapperPass(), 190);
585         optPM.add(llvm::createBasicAAWrapperPass());
586 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
587         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
588 #else
589         optPM.add(llvm::createCFGSimplificationPass());
590 #endif
591 
592         optPM.add(llvm::createSROAPass());
593 
594         optPM.add(llvm::createEarlyCSEPass());
595         optPM.add(llvm::createLowerExpectIntrinsicPass());
596 
597         // Early optimizations to try to reduce the total amount of code to
598         // work with if we can
599         optPM.add(llvm::createReassociatePass(), 200);
600         optPM.add(llvm::createInstSimplifyLegacyPass());
601         optPM.add(llvm::createDeadCodeEliminationPass());
602 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
603         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
604 #else
605         optPM.add(llvm::createCFGSimplificationPass());
606 #endif
607 
608         optPM.add(llvm::createPromoteMemoryToRegisterPass());
609         optPM.add(llvm::createAggressiveDCEPass());
610 
611         if (g->opt.disableGatherScatterOptimizations == false && g->target->getVectorWidth() > 1) {
612             optPM.add(llvm::createInstructionCombiningPass(), 210);
613             optPM.add(CreateImproveMemoryOpsPass());
614         }
615         if (!g->opt.disableMaskAllOnOptimizations) {
616             optPM.add(CreateIntrinsicsOptPass(), 215);
617             optPM.add(CreateInstructionSimplifyPass());
618         }
619         optPM.add(llvm::createDeadCodeEliminationPass(), 220);
620 
621         // On to more serious optimizations
622         optPM.add(llvm::createSROAPass());
623         optPM.add(llvm::createInstructionCombiningPass());
624 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
625         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
626 #else
627         optPM.add(llvm::createCFGSimplificationPass());
628 #endif
629         optPM.add(llvm::createPromoteMemoryToRegisterPass());
630         optPM.add(llvm::createGlobalOptimizerPass());
631         optPM.add(llvm::createReassociatePass());
632         // IPConstProp will not be supported by LLVM moving forward.
633         // Switching to IPSCCP which is its recommended functional equivalent.
634         // TODO : Make IPSCCP the default after ISPC 1.14 release.
635 #if ISPC_LLVM_VERSION < ISPC_LLVM_12_0
636         optPM.add(llvm::createIPConstantPropagationPass());
637 #else
638         optPM.add(llvm::createIPSCCPPass());
639 #endif
640 
641         optPM.add(CreateReplaceStdlibShiftPass(), 229);
642 
643         optPM.add(llvm::createDeadArgEliminationPass(), 230);
644         optPM.add(llvm::createInstructionCombiningPass());
645 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
646         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
647 #else
648         optPM.add(llvm::createCFGSimplificationPass());
649 #endif
650         optPM.add(llvm::createPruneEHPass());
651 #ifdef ISPC_GENX_ENABLED
652         if (g->target->isGenXTarget()) {
653             optPM.add(CreateFixDivisionInstructions());
654         }
655 #endif
656         optPM.add(llvm::createPostOrderFunctionAttrsLegacyPass());
657         optPM.add(llvm::createReversePostOrderFunctionAttrsPass());
658 
659         // The next inlining pass will remove the functions saved by __keep_funcs_live
660         optPM.add(llvm::createFunctionInliningPass());
661         optPM.add(llvm::createInstSimplifyLegacyPass());
662         optPM.add(llvm::createDeadCodeEliminationPass());
663 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
664         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
665 #else
666         optPM.add(llvm::createCFGSimplificationPass());
667 #endif
668 
669         optPM.add(llvm::createArgumentPromotionPass());
670 
671         optPM.add(llvm::createAggressiveDCEPass());
672         optPM.add(llvm::createInstructionCombiningPass(), 241);
673         optPM.add(llvm::createJumpThreadingPass());
674 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
675         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
676 #else
677         optPM.add(llvm::createCFGSimplificationPass());
678 #endif
679 
680         optPM.add(llvm::createSROAPass());
681 
682         optPM.add(llvm::createInstructionCombiningPass());
683 #ifdef ISPC_GENX_ENABLED
684         if (g->target->isGenXTarget()) {
685             // Inline
686             optPM.add(llvm::createCorrelatedValuePropagationPass());
687             optPM.add(llvm::createInstructionCombiningPass());
688             optPM.add(llvm::createGlobalDCEPass());
689             optPM.add(llvm::createInstructionCombiningPass());
690             optPM.add(llvm::createEarlyCSEPass());
691             optPM.add(llvm::createDeadCodeEliminationPass());
692             optPM.add(llvm::createGlobalsLocalizationPass());
693             // remove dead globals after localization
694             optPM.add(llvm::createGlobalDCEPass());
695         }
696 #endif
697         optPM.add(llvm::createTailCallEliminationPass());
698 
699         if (!g->opt.disableMaskAllOnOptimizations) {
700             optPM.add(CreateIntrinsicsOptPass(), 250);
701             optPM.add(CreateInstructionSimplifyPass());
702         }
703 
704         if (g->opt.disableGatherScatterOptimizations == false && g->target->getVectorWidth() > 1) {
705             optPM.add(llvm::createInstructionCombiningPass(), 255);
706 #ifdef ISPC_GENX_ENABLED
707             if (g->target->isGenXTarget() && !g->opt.disableGenXGatherCoalescing)
708                 optPM.add(CreateGenXGatherCoalescingPass());
709 #endif
710             optPM.add(CreateImproveMemoryOpsPass());
711 
712             if (g->opt.disableCoalescing == false) {
713                 // It is important to run this here to make it easier to
714                 // find matching gathers that we can coalesce.
715                 optPM.add(llvm::createEarlyCSEPass(), 260);
716                 optPM.add(CreateGatherCoalescePass());
717             }
718         }
719 
720         optPM.add(llvm::createFunctionInliningPass(), 265);
721         optPM.add(llvm::createInstSimplifyLegacyPass());
722         optPM.add(CreateIntrinsicsOptPass());
723         optPM.add(CreateInstructionSimplifyPass());
724 
725         if (g->opt.disableGatherScatterOptimizations == false && g->target->getVectorWidth() > 1) {
726             optPM.add(llvm::createInstructionCombiningPass(), 270);
727             optPM.add(CreateImproveMemoryOpsPass());
728         }
729 
730         optPM.add(llvm::createIPSCCPPass(), 275);
731         optPM.add(llvm::createDeadArgEliminationPass());
732         optPM.add(llvm::createAggressiveDCEPass());
733         optPM.add(llvm::createInstructionCombiningPass());
734 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
735         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
736 #else
737         optPM.add(llvm::createCFGSimplificationPass());
738 #endif
739 
740         if (g->opt.disableHandlePseudoMemoryOps == false) {
741             optPM.add(CreateReplacePseudoMemoryOpsPass(), 280);
742         }
743         optPM.add(CreateIntrinsicsOptPass(), 281);
744         optPM.add(CreateInstructionSimplifyPass());
745 
746         optPM.add(llvm::createFunctionInliningPass());
747         optPM.add(llvm::createArgumentPromotionPass());
748 
749         optPM.add(llvm::createSROAPass());
750 
751         optPM.add(llvm::createInstructionCombiningPass());
752         optPM.add(CreateInstructionSimplifyPass());
753 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
754         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
755 #else
756         optPM.add(llvm::createCFGSimplificationPass());
757 #endif
758         optPM.add(llvm::createReassociatePass());
759         optPM.add(llvm::createLoopRotatePass());
760         optPM.add(llvm::createLICMPass());
761         optPM.add(llvm::createLoopUnswitchPass(false));
762         optPM.add(llvm::createInstructionCombiningPass());
763         optPM.add(CreateInstructionSimplifyPass());
764         optPM.add(llvm::createIndVarSimplifyPass());
765         // Currently CM does not support memset/memcpy
766         // so this pass is temporarily disabled for GEN.
767         if (!g->target->isGenXTarget()) {
768             optPM.add(llvm::createLoopIdiomPass());
769         }
770         optPM.add(llvm::createLoopDeletionPass());
771         if (g->opt.unrollLoops) {
772             optPM.add(llvm::createLoopUnrollPass(), 300);
773         }
774         optPM.add(llvm::createGVNPass(), 301);
775 
776         optPM.add(CreateIsCompileTimeConstantPass(true));
777         optPM.add(CreateIntrinsicsOptPass());
778         optPM.add(CreateInstructionSimplifyPass());
779         // Currently CM does not support memset/memcpy
780         // so this pass is temporarily disabled for GEN.
781         if (!g->target->isGenXTarget()) {
782             optPM.add(llvm::createMemCpyOptPass());
783         }
784         optPM.add(llvm::createSCCPPass());
785         optPM.add(llvm::createInstructionCombiningPass());
786         optPM.add(CreateInstructionSimplifyPass());
787         optPM.add(llvm::createJumpThreadingPass());
788         optPM.add(llvm::createCorrelatedValuePropagationPass());
789         optPM.add(llvm::createDeadStoreEliminationPass());
790         optPM.add(llvm::createAggressiveDCEPass());
791 #if ISPC_LLVM_VERSION >= ISPC_LLVM_12_0
792         optPM.add(llvm::createCFGSimplificationPass(simplifyCFGopt));
793 #else
794         optPM.add(llvm::createCFGSimplificationPass());
795 #endif
796         optPM.add(llvm::createInstructionCombiningPass());
797         optPM.add(CreateInstructionSimplifyPass());
798 #ifdef ISPC_GENX_ENABLED
799         if (g->target->isGenXTarget()) {
800             optPM.add(CreateReplaceLLVMIntrinsics());
801         }
802 #endif
803         optPM.add(CreatePeepholePass());
804         optPM.add(llvm::createFunctionInliningPass());
805         optPM.add(llvm::createAggressiveDCEPass());
806         optPM.add(llvm::createStripDeadPrototypesPass());
807         optPM.add(CreateMakeInternalFuncsStaticPass());
808         optPM.add(llvm::createGlobalDCEPass());
809         optPM.add(llvm::createConstantMergePass());
810 #ifdef ISPC_GENX_ENABLED
811         if (g->target->isGenXTarget()) {
812             optPM.add(CreateCheckUnsupportedInsts());
813             optPM.add(CreateFixAddressSpace());
814             optPM.add(CreateMangleOpenCLBuiltins());
815             // This pass is required to prepare LLVM IR for open source SPIR-V translator
816             optPM.add(
817                 llvm::createGenXSPIRVWriterAdaptorPass(true /*RewriteTypes*/, false /*RewriteSingleElementVectors*/));
818         }
819 #endif
820     }
821 
822     // Finish up by making sure we didn't mess anything up in the IR along
823     // the way.
824     optPM.add(llvm::createVerifierPass(), LAST_OPT_NUMBER);
825     optPM.run(*module);
826 
827 #ifndef ISPC_NO_DUMPS
828     if (g->debugPrint) {
829         printf("\n*****\nFINAL OUTPUT\n*****\n");
830         //module->dump();
831     }
832 #endif
833 }
834 
835 ///////////////////////////////////////////////////////////////////////////
836 // IntrinsicsOpt
837 
838 /** This is a relatively simple optimization pass that does a few small
839     optimizations that LLVM's x86 optimizer doesn't currently handle.
840     (Specifically, MOVMSK of a constant can be replaced with the
841     corresponding constant value, and BLENDVPS and AVX masked load/store with
842     either an 'all on' or 'all off' mask can be replaced with simpler
843     operations.)
844 
845     @todo The better thing to do would be to submit a patch to LLVM to get
846     these; they're presumably pretty simple patterns to match.
847 */
848 class IntrinsicsOpt : public llvm::FunctionPass {
849   public:
850     IntrinsicsOpt() : FunctionPass(ID){};
851 
852     llvm::StringRef getPassName() const { return "Intrinsics Cleanup Optimization"; }
853 
854     bool runOnBasicBlock(llvm::BasicBlock &BB);
855 
856     bool runOnFunction(llvm::Function &F);
857 
858     static char ID;
859 
860   private:
861     struct MaskInstruction {
862         MaskInstruction(llvm::Function *f) { function = f; }
863         llvm::Function *function;
864     };
865     std::vector<MaskInstruction> maskInstructions;
866 
867     /** Structure that records everything we need to know about a blend
868         instruction for this optimization pass.
869      */
870     struct BlendInstruction {
871         BlendInstruction(llvm::Function *f, uint64_t ao, int o0, int o1, int of)
872             : function(f), allOnMask(ao), op0(o0), op1(o1), opFactor(of) {}
873         /** Function pointer for the blend instruction */
874         llvm::Function *function;
875         /** Mask value for an "all on" mask for this instruction */
876         uint64_t allOnMask;
877         /** The operand number in the llvm CallInst corresponding to the
878             first operand to blend with. */
879         int op0;
880         /** The operand number in the CallInst corresponding to the second
881             operand to blend with. */
882         int op1;
883         /** The operand in the call inst where the blending factor is
884             found. */
885         int opFactor;
886     };
887     std::vector<BlendInstruction> blendInstructions;
888 
889     bool matchesMaskInstruction(llvm::Function *function);
890     BlendInstruction *matchingBlendInstruction(llvm::Function *function);
891 };
892 
893 char IntrinsicsOpt::ID = 0;
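// Illustrative examples of the rewrites performed below (sketches, not exact
// IR from the compiler):
//   - a movmsk-style call whose vector argument is a compile-time constant is
//     replaced with the integer constant built from the elements' high bits;
//   - a blend call (e.g. BLENDVPS) whose mask is all-off is replaced with its
//     first operand, and one whose mask is all-on with its second operand;
//   - an AVX masked load/store with an all-on mask becomes a regular
//     load/store, while an all-off mask turns the load into an undef value and
//     removes the store entirely.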
894 
895 /** Given an llvm::Value, return true if we can determine that it's an
896     undefined value.  This only makes a weak attempt at chasing this down,
897     only detecting flat-out undef values, and bitcasts of undef values.
898 
899     @todo Is it worth working harder to find more of these?  It starts to
900     get tricky, since having an undef operand doesn't necessarily mean that
901     the result will be undefined.  (And for that matter, is there an LLVM
902     call that will do this for us?)
903  */
904 static bool lIsUndef(llvm::Value *value) {
905     if (llvm::isa<llvm::UndefValue>(value))
906         return true;
907 
908     llvm::BitCastInst *bci = llvm::dyn_cast<llvm::BitCastInst>(value);
909     if (bci)
910         return lIsUndef(bci->getOperand(0));
911 
912     return false;
913 }
914 
915 bool IntrinsicsOpt::runOnBasicBlock(llvm::BasicBlock &bb) {
916     DEBUG_START_PASS("IntrinsicsOpt");
917 
918     // We can't initialize mask/blend function vector during pass initialization,
919     // as they may be optimized out by the time the pass is invoked.
920 
921     // All of the mask instructions we may encounter.  Note that even if
922     // compiling for AVX, we may still encounter the regular 4-wide SSE
923     // MOVMSK instruction.
924     if (llvm::Function *ssei8Movmsk =
925             m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_sse2_pmovmskb_128))) {
926         maskInstructions.push_back(ssei8Movmsk);
927     }
928     if (llvm::Function *sseFloatMovmsk =
929             m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_sse_movmsk_ps))) {
930         maskInstructions.push_back(sseFloatMovmsk);
931     }
932     if (llvm::Function *__movmsk = m->module->getFunction("__movmsk")) {
933         maskInstructions.push_back(__movmsk);
934     }
935     if (llvm::Function *avxFloatMovmsk =
936             m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_avx_movmsk_ps_256))) {
937         maskInstructions.push_back(avxFloatMovmsk);
938     }
939 
940     // And all of the blend instructions
941     blendInstructions.push_back(BlendInstruction(
942         m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_sse41_blendvps)), 0xf, 0, 1, 2));
943     blendInstructions.push_back(BlendInstruction(
944         m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_avx_blendv_ps_256)), 0xff, 0, 1, 2));
945 
946     llvm::Function *avxMaskedLoad32 =
947         m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_avx_maskload_ps_256));
948     llvm::Function *avxMaskedLoad64 =
949         m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_avx_maskload_pd_256));
950     llvm::Function *avxMaskedStore32 =
951         m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_avx_maskstore_ps_256));
952     llvm::Function *avxMaskedStore64 =
953         m->module->getFunction(llvm::Intrinsic::getName(llvm::Intrinsic::x86_avx_maskstore_pd_256));
954 
955     bool modifiedAny = false;
956 restart:
957     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
958         llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
959         if (callInst == NULL || callInst->getCalledFunction() == NULL)
960             continue;
961 
962         BlendInstruction *blend = matchingBlendInstruction(callInst->getCalledFunction());
963         if (blend != NULL) {
964             llvm::Value *v[2] = {callInst->getArgOperand(blend->op0), callInst->getArgOperand(blend->op1)};
965             llvm::Value *factor = callInst->getArgOperand(blend->opFactor);
966 
967             // If the values are the same, there is no need to blend.
968             if (v[0] == v[1]) {
969                 llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, v[0]);
970                 modifiedAny = true;
971                 goto restart;
972             }
973 
974             // If one of the two is undefined, we're allowed to replace
975             // with the value of the other.  (In other words, the only
976             // valid case is that the blend factor ends up having a value
977             // that only selects from the defined one of the two operands,
978             // otherwise the result is undefined and any value is fine,
979             // ergo the defined one is an acceptable result.)
980             if (lIsUndef(v[0])) {
981                 llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, v[1]);
982                 modifiedAny = true;
983                 goto restart;
984             }
985             if (lIsUndef(v[1])) {
986                 llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, v[0]);
987                 modifiedAny = true;
988                 goto restart;
989             }
990 
991             MaskStatus maskStatus = lGetMaskStatus(factor);
992             llvm::Value *value = NULL;
993             if (maskStatus == MaskStatus::all_off) {
994                 // Mask all off -> replace with the first blend value
995                 value = v[0];
996             } else if (maskStatus == MaskStatus::all_on) {
997                 // Mask all on -> replace with the second blend value
998                 value = v[1];
999             }
1000 
1001             if (value != NULL) {
1002                 llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, value);
1003                 modifiedAny = true;
1004                 goto restart;
1005             }
1006         } else if (matchesMaskInstruction(callInst->getCalledFunction())) {
1007             llvm::Value *factor = callInst->getArgOperand(0);
1008             uint64_t mask;
1009             if (lGetMask(factor, &mask) == true) {
1010                 // If the vector-valued mask has a known value, replace it
1011                 // with the corresponding integer mask from its elements
1012                 // high bits.
1013                 llvm::Value *value = (callInst->getType() == LLVMTypes::Int32Type) ? LLVMInt32(mask) : LLVMInt64(mask);
1014                 llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, value);
1015                 modifiedAny = true;
1016                 goto restart;
1017             }
1018         } else if (callInst->getCalledFunction() == avxMaskedLoad32 ||
1019                    callInst->getCalledFunction() == avxMaskedLoad64) {
1020             llvm::Value *factor = callInst->getArgOperand(1);
1021             MaskStatus maskStatus = lGetMaskStatus(factor);
1022             if (maskStatus == MaskStatus::all_off) {
1023                 // nothing being loaded, replace with undef value
1024                 llvm::Type *returnType = callInst->getType();
1025                 Assert(llvm::isa<llvm::VectorType>(returnType));
1026                 llvm::Value *undefValue = llvm::UndefValue::get(returnType);
1027                 llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, undefValue);
1028                 modifiedAny = true;
1029                 goto restart;
1030             } else if (maskStatus == MaskStatus::all_on) {
1031                 // all lanes active; replace with a regular load
1032                 llvm::Type *returnType = callInst->getType();
1033                 Assert(llvm::isa<llvm::VectorType>(returnType));
1034                 // cast the i8 * to the appropriate type
1035                 llvm::Value *castPtr =
1036                     new llvm::BitCastInst(callInst->getArgOperand(0), llvm::PointerType::get(returnType, 0),
1037                                           llvm::Twine(callInst->getArgOperand(0)->getName()) + "_cast", callInst);
1038                 lCopyMetadata(castPtr, callInst);
1039                 int align;
1040                 if (g->opt.forceAlignedMemory)
1041                     align = g->target->getNativeVectorAlignment();
1042                 else
1043                     align = callInst->getCalledFunction() == avxMaskedLoad32 ? 4 : 8;
1044 #if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
1045                 llvm::Instruction *loadInst =
1046                     new llvm::LoadInst(castPtr, llvm::Twine(callInst->getArgOperand(0)->getName()) + "_load",
1047                                        false /* not volatile */, llvm::MaybeAlign(align), (llvm::Instruction *)NULL);
1048 #else
1049                 llvm::Instruction *loadInst = new llvm::LoadInst(
1050                     llvm::dyn_cast<llvm::PointerType>(castPtr->getType())->getPointerElementType(), castPtr,
1051                     llvm::Twine(callInst->getArgOperand(0)->getName()) + "_load", false /* not volatile */,
1052                     llvm::MaybeAlign(align).valueOrOne(), (llvm::Instruction *)NULL);
1053 #endif
1054                 lCopyMetadata(loadInst, callInst);
1055                 llvm::ReplaceInstWithInst(callInst, loadInst);
1056                 modifiedAny = true;
1057                 goto restart;
1058             }
1059         } else if (callInst->getCalledFunction() == avxMaskedStore32 ||
1060                    callInst->getCalledFunction() == avxMaskedStore64) {
1061             // NOTE: mask is the 2nd parameter, not the 3rd one!!
1062             llvm::Value *factor = callInst->getArgOperand(1);
1063             MaskStatus maskStatus = lGetMaskStatus(factor);
1064             if (maskStatus == MaskStatus::all_off) {
1065                 // nothing actually being stored, just remove the inst
1066                 callInst->eraseFromParent();
1067                 modifiedAny = true;
1068                 goto restart;
1069             } else if (maskStatus == MaskStatus::all_on) {
1070                 // all lanes storing, so replace with a regular store
1071                 llvm::Value *rvalue = callInst->getArgOperand(2);
1072                 llvm::Type *storeType = rvalue->getType();
1073                 llvm::Value *castPtr =
1074                     new llvm::BitCastInst(callInst->getArgOperand(0), llvm::PointerType::get(storeType, 0),
1075                                           llvm::Twine(callInst->getArgOperand(0)->getName()) + "_ptrcast", callInst);
1076                 lCopyMetadata(castPtr, callInst);
1077 
1078                 int align;
1079                 if (g->opt.forceAlignedMemory)
1080                     align = g->target->getNativeVectorAlignment();
1081                 else
1082                     align = callInst->getCalledFunction() == avxMaskedStore32 ? 4 : 8;
1083                 llvm::StoreInst *storeInst = new llvm::StoreInst(rvalue, castPtr, (llvm::Instruction *)NULL,
1084                                                                  llvm::MaybeAlign(align).valueOrOne());
1085                 lCopyMetadata(storeInst, callInst);
1086                 llvm::ReplaceInstWithInst(callInst, storeInst);
1087 
1088                 modifiedAny = true;
1089                 goto restart;
1090             }
1091         }
1092     }
1093 
1094     DEBUG_END_PASS("IntrinsicsOpt");
1095 
1096     return modifiedAny;
1097 }
1098 
1099 bool IntrinsicsOpt::runOnFunction(llvm::Function &F) {
1100 
1101     llvm::TimeTraceScope FuncScope("IntrinsicsOpt::runOnFunction", F.getName());
1102     bool modifiedAny = false;
1103     for (llvm::BasicBlock &BB : F) {
1104         modifiedAny |= runOnBasicBlock(BB);
1105     }
1106     return modifiedAny;
1107 }
1108 
1109 bool IntrinsicsOpt::matchesMaskInstruction(llvm::Function *function) {
1110     for (unsigned int i = 0; i < maskInstructions.size(); ++i) {
1111         if (maskInstructions[i].function != NULL && function == maskInstructions[i].function) {
1112             return true;
1113         }
1114     }
1115     return false;
1116 }
1117 
1118 IntrinsicsOpt::BlendInstruction *IntrinsicsOpt::matchingBlendInstruction(llvm::Function *function) {
1119     for (unsigned int i = 0; i < blendInstructions.size(); ++i) {
1120         if (blendInstructions[i].function != NULL && function == blendInstructions[i].function) {
1121             return &blendInstructions[i];
1122         }
1123     }
1124     return NULL;
1125 }
1126 
1127 static llvm::Pass *CreateIntrinsicsOptPass() { return new IntrinsicsOpt; }
1128 
1129 ///////////////////////////////////////////////////////////////////////////
1130 
1131 /** This simple optimization pass looks for a vector select instruction
1132     with an all-on or all-off constant mask, simplifying it to the
1133     appropriate operand if so.
1134 
1135     @todo The better thing to do would be to submit a patch to LLVM to get
1136     these; they're presumably pretty simple patterns to match.
1137 */
1138 class InstructionSimplifyPass : public llvm::FunctionPass {
1139   public:
1140     InstructionSimplifyPass() : FunctionPass(ID) {}
1141 
1142     llvm::StringRef getPassName() const { return "Vector Select Optimization"; }
1143     bool runOnBasicBlock(llvm::BasicBlock &BB);
1144     bool runOnFunction(llvm::Function &F);
1145 
1146     static char ID;
1147 
1148   private:
1149     static bool simplifySelect(llvm::SelectInst *selectInst, llvm::BasicBlock::iterator iter);
1150     static llvm::Value *simplifyBoolVec(llvm::Value *value);
1151     static bool simplifyCall(llvm::CallInst *callInst, llvm::BasicBlock::iterator iter);
1152 };
1153 
1154 char InstructionSimplifyPass::ID = 0;
1155 
1156 llvm::Value *InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) {
1157     llvm::TruncInst *trunc = llvm::dyn_cast<llvm::TruncInst>(value);
1158     if (trunc != NULL) {
1159         // Convert trunc({sext,zext}(i1 vector)) -> (i1 vector)
1160         llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(trunc->getOperand(0));
1161         if (sext && sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
1162             return sext->getOperand(0);
1163 
1164         llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(trunc->getOperand(0));
1165         if (zext && zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
1166             return zext->getOperand(0);
1167     }
1168     /*
1169       // This optimization has discernable benefit on the perf
1170       // suite on latest LLVM versions.
1171       // On 3.4+ (maybe even older), it can result in illegal
1172       // operations, so it's being disabled.
1173     llvm::ICmpInst *icmp = llvm::dyn_cast<llvm::ICmpInst>(value);
1174     if (icmp != NULL) {
1175         // icmp(ne, {sext,zext}(foo), zeroinitializer) -> foo
1176         if (icmp->getSignedPredicate() == llvm::CmpInst::ICMP_NE) {
1177             llvm::Value *op1 = icmp->getOperand(1);
1178             if (llvm::isa<llvm::ConstantAggregateZero>(op1)) {
1179                 llvm::Value *op0 = icmp->getOperand(0);
1180                 llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(op0);
1181                 if (sext)
1182                     return sext->getOperand(0);
1183                 llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(op0);
1184                 if (zext)
1185                     return zext->getOperand(0);
1186             }
1187         }
1188 
1189     }
1190     */
1191     return NULL;
1192 }
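
// For illustration only (hypothetical IR, not actual compiler output): given
//     %wide   = sext <8 x i1> %m to <8 x i32>
//     %narrow = trunc <8 x i32> %wide to <8 x i1>
// the intended simplification is for simplifyBoolVec(%narrow) to return %m, so
// callers can use the original i1 vector directly as a select condition.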
1193 
1194 bool InstructionSimplifyPass::simplifySelect(llvm::SelectInst *selectInst, llvm::BasicBlock::iterator iter) {
1195     if (selectInst->getType()->isVectorTy() == false)
1196         return false;
1197     Assert(selectInst->getOperand(1) != NULL);
1198     Assert(selectInst->getOperand(2) != NULL);
1199     llvm::Value *factor = selectInst->getOperand(0);
1200 
1201     // Simplify all-on or all-off mask values
1202     MaskStatus maskStatus = lGetMaskStatus(factor);
1203     llvm::Value *value = NULL;
1204     if (maskStatus == MaskStatus::all_on)
1205         // Mask all on -> replace with the first select value
1206         value = selectInst->getOperand(1);
1207     else if (maskStatus == MaskStatus::all_off)
1208         // Mask all off -> replace with the second select value
1209         value = selectInst->getOperand(2);
1210     if (value != NULL) {
1211         llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, value);
1212         return true;
1213     }
1214 
1215     // Sometimes earlier LLVM optimization passes generate unnecessarily
1216     // complex expressions for the selection vector, which in turn confuses
1217     // the code generators and leads to sub-optimal code (particularly for
1218     // 8 and 16-bit masks).  We'll try to simplify them out here so that
1219     // the code generator's patterns match.
1220     if ((factor = simplifyBoolVec(factor)) != NULL) {
1221         llvm::Instruction *newSelect = llvm::SelectInst::Create(factor, selectInst->getOperand(1),
1222                                                                 selectInst->getOperand(2), selectInst->getName());
1223         llvm::ReplaceInstWithInst(selectInst, newSelect);
1224         return true;
1225     }
1226 
1227     return false;
1228 }
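
// A sketch of the all-on-mask case handled above, with hypothetical IR (not
// taken from real output): on a 4-wide target,
//     %r = select <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> %a, <4 x float> %b
// has an all-on mask, so every use of %r is replaced with %a; with an all-off
// mask it would be replaced with %b instead.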
1229 
1230 bool InstructionSimplifyPass::simplifyCall(llvm::CallInst *callInst, llvm::BasicBlock::iterator iter) {
1231     llvm::Function *calledFunc = callInst->getCalledFunction();
1232 
1233     // Turn a __movmsk call with a compile-time constant vector into the
1234     // equivalent scalar value.
1235     if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
1236         return false;
1237 
1238     uint64_t mask;
1239     if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
1240         llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, LLVMInt64(mask));
1241         return true;
1242     }
1243     return false;
1244 }
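
// Illustrative example (hypothetical mask type and width): on a 4-wide target
// whose mask is <4 x i32>, a call such as
//     %m = call i64 @__movmsk(<4 x i32> <i32 -1, i32 -1, i32 -1, i32 -1>)
// folds to the constant i64 15 (all four lane bits set), since the argument is
// known at compile time.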
1245 
1246 bool InstructionSimplifyPass::runOnBasicBlock(llvm::BasicBlock &bb) {
1247     DEBUG_START_PASS("InstructionSimplify");
1248 
1249     bool modifiedAny = false;
1250 
1251 restart:
1252     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
1253         llvm::SelectInst *selectInst = llvm::dyn_cast<llvm::SelectInst>(&*iter);
1254         if (selectInst && simplifySelect(selectInst, iter)) {
1255             modifiedAny = true;
1256             goto restart;
1257         }
1258         llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
1259         if (callInst && simplifyCall(callInst, iter)) {
1260             modifiedAny = true;
1261             goto restart;
1262         }
1263     }
1264 
1265     DEBUG_END_PASS("InstructionSimplify");
1266 
1267     return modifiedAny;
1268 }
1269 
1270 bool InstructionSimplifyPass::runOnFunction(llvm::Function &F) {
1271 
1272     llvm::TimeTraceScope FuncScope("InstructionSimplifyPass::runOnFunction", F.getName());
1273     bool modifiedAny = false;
1274     for (llvm::BasicBlock &BB : F) {
1275         modifiedAny |= runOnBasicBlock(BB);
1276     }
1277     return modifiedAny;
1278 }
1279 
1280 static llvm::Pass *CreateInstructionSimplifyPass() { return new InstructionSimplifyPass; }
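
// Minimal usage sketch (for illustration only; the pass-manager wiring below
// is hypothetical and only mirrors how the factory functions in this file are
// consumed when ispc builds its optimization pipeline):
#if 0
static void lAddSimplificationPasses(llvm::legacy::PassManager &pm) {
    // Run the intrinsics cleanup first, then the vector-select/__movmsk
    // simplification over every function in the module.
    pm.add(CreateIntrinsicsOptPass());
    pm.add(CreateInstructionSimplifyPass());
}
#endif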
1281 
1282 ///////////////////////////////////////////////////////////////////////////
1283 // ImproveMemoryOpsPass
1284 
1285 /** When the front-end emits gathers and scatters, it generates an array of
1286     vector-width pointers to represent the set of addresses to read from or
1287     write to.  This optimization detects cases when the base pointer is a
1288     uniform pointer or when the indexing is into an array that can be
1289     converted into scatters/gathers from a single base pointer and an array
1290     of offsets.
1291 
1292     See for example the comments discussing the __pseudo_gather functions
1293     in builtins.cpp for more information about this.
1294  */
1295 class ImproveMemoryOpsPass : public llvm::FunctionPass {
1296   public:
1297     static char ID;
1298     ImproveMemoryOpsPass() : FunctionPass(ID) {}
1299 
1300     llvm::StringRef getPassName() const { return "Improve Memory Ops"; }
1301 
1302     bool runOnBasicBlock(llvm::BasicBlock &BB);
1303 
1304     bool runOnFunction(llvm::Function &F);
1305 };
1306 
1307 char ImproveMemoryOpsPass::ID = 0;
1308 
1309 /** Check to make sure that this value is actually a pointer in the end.
1310     We need to make sure that given an expression like vec(offset) +
1311     ptr2int(ptr), lGetBasePointer() doesn't return vec(offset) for the base
1312     pointer such that we then treat ptr2int(ptr) as an offset.  This ends
1313     up being important so that we don't generate LLVM GEP instructions like
1314     "gep inttoptr 8, i64 %ptr", which in turn can lead to incorrect code
1315     since LLVM's pointer aliasing analysis assumes that operands after the
1316     first one to a GEP aren't pointers.
1317  */
1318 static llvm::Value *lCheckForActualPointer(llvm::Value *v) {
1319     if (v == NULL) {
1320         return NULL;
1321     } else if (llvm::isa<llvm::PointerType>(v->getType())) {
1322         return v;
1323     } else if (llvm::isa<llvm::PtrToIntInst>(v)) {
1324         return v;
1325     }
1326     // This one is tricky: it's a heuristic tuned for LLVM 3.7+, which may
1327     // optimize a load of a double* followed by a ptrtoint into a straight load of an i64.
1328     // The heuristic should be good enough to catch all the cases we need to
1329     // detect and nothing else.
1330     else if (llvm::isa<llvm::LoadInst>(v)) {
1331         return v;
1332     }
1333 
1334     else if (llvm::CastInst *ci = llvm::dyn_cast<llvm::CastInst>(v)) {
1335         llvm::Value *t = lCheckForActualPointer(ci->getOperand(0));
1336         if (t == NULL) {
1337             return NULL;
1338         } else {
1339             return v;
1340         }
1341     } else {
1342         llvm::ConstantExpr *uce = llvm::dyn_cast<llvm::ConstantExpr>(v);
1343         if (uce != NULL && uce->getOpcode() == llvm::Instruction::PtrToInt)
1344             return v;
1345         return NULL;
1346     }
1347 }
1348 
1349 /** Given an llvm::Value representing a varying pointer, this function
1350     checks to see if all of the elements of the vector have the same value
1351     (i.e. there's a common base pointer). If a broadcast has already been detected,
1352     it instead checks that the first element of the vector is not undef. If either
1353     condition holds, it returns the common pointer value; otherwise it returns NULL.
1354  */
1355 static llvm::Value *lGetBasePointer(llvm::Value *v, llvm::Instruction *insertBefore, bool broadcastDetected) {
1356     if (llvm::isa<llvm::InsertElementInst>(v) || llvm::isa<llvm::ShuffleVectorInst>(v)) {
1357         // If we have already detected a broadcast, we want to look for
1358         // the first non-undef element of the vector
1359         llvm::Value *element = LLVMFlattenInsertChain(v, g->target->getVectorWidth(), true, false, broadcastDetected);
1360         // TODO: it's probably ok to allow undefined elements and return
1361         // the base pointer if all of the other elements have the same
1362         // value.
1363         if (element != NULL) {
1364             // all elements are the same and not NULLs
1365             return lCheckForActualPointer(element);
1366         } else {
1367             return NULL;
1368         }
1369     }
1370 
1371     // This case comes up with global/static arrays
1372     if (llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v)) {
1373         return lCheckForActualPointer(cv->getSplatValue());
1374     } else if (llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v)) {
1375         return lCheckForActualPointer(cdv->getSplatValue());
1376     }
1377     // Operations on pointers that have been cast to integers of a different bit width
1378     // are a little tricky, but they do come up, so we handle that case here.
1379     else if (llvm::CastInst *ci = llvm::dyn_cast<llvm::CastInst>(v)) {
1380         llvm::Value *t = lGetBasePointer(ci->getOperand(0), insertBefore, broadcastDetected);
1381         if (t == NULL) {
1382             return NULL;
1383         } else {
1384             return llvm::CastInst::Create(ci->getOpcode(), t, ci->getType()->getScalarType(),
1385                                           llvm::Twine(t->getName()) + "_cast", insertBefore);
1386         }
1387     }
1388 
1389     return NULL;
1390 }
1391 
1392 /** Given the two operands to a constant add expression, see if we have the
1393     form "base pointer + offset", where op0 is the base pointer and op1 is
1394     the offset; if so return the base and the offset. */
1395 static llvm::Constant *lGetConstantAddExprBaseOffset(llvm::Constant *op0, llvm::Constant *op1, llvm::Constant **delta) {
1396     llvm::ConstantExpr *op = llvm::dyn_cast<llvm::ConstantExpr>(op0);
1397     if (op == NULL || op->getOpcode() != llvm::Instruction::PtrToInt)
1398         // the first operand isn't a pointer
1399         return NULL;
1400 
1401     llvm::ConstantInt *opDelta = llvm::dyn_cast<llvm::ConstantInt>(op1);
1402     if (opDelta == NULL)
1403         // the second operand isn't an integer operand
1404         return NULL;
1405 
1406     *delta = opDelta;
1407     return op0;
1408 }
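
// Worked example (hypothetical constants): for the constant expression
//     add (i64 ptrtoint ([16 x float]* @array to i64), i64 8)
// op0 is the ptrtoint of @array and op1 is the i64 constant 8, so the function
// returns op0 as the base and stores the constant 8 into *delta.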
1409 
1410 static llvm::Value *lExtractFromInserts(llvm::Value *v, unsigned int index) {
1411     llvm::InsertValueInst *iv = llvm::dyn_cast<llvm::InsertValueInst>(v);
1412     if (iv == NULL)
1413         return NULL;
1414 
1415     Assert(iv->hasIndices() && iv->getNumIndices() == 1);
1416     if (iv->getIndices()[0] == index)
1417         return iv->getInsertedValueOperand();
1418     else
1419         return lExtractFromInserts(iv->getAggregateOperand(), index);
1420 }
1421 
1422 /** Given a varying pointer in ptrs, this function checks to see if it can
1423     be determined to be indexing from a common uniform base pointer.  If
1424     so, the function returns the base pointer llvm::Value and initializes
1425     *offsets with an int vector of the per-lane offsets
1426  */
1427 static llvm::Value *lGetBasePtrAndOffsets(llvm::Value *ptrs, llvm::Value **offsets, llvm::Instruction *insertBefore) {
1428 #ifndef ISPC_NO_DUMPS
1429     if (g->debugPrint) {
1430         fprintf(stderr, "lGetBasePtrAndOffsets\n");
1431         LLVMDumpValue(ptrs);
1432     }
1433 #endif
1434 
1435     bool broadcastDetected = false;
1436     // Looking for %gep_offset = shufflevector <8 x i64> %0, <8 x i64> undef, <8 x i32> zeroinitializer
1437     llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(ptrs);
1438     if (shuffle != NULL) {
1439 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
1440         llvm::Value *indices = shuffle->getShuffleMaskForBitcode();
1441 #else
1442         llvm::Value *indices = shuffle->getOperand(2);
1443 #endif
1444         llvm::Value *vec = shuffle->getOperand(1);
1445 
1446         if (lIsUndef(vec) && llvm::isa<llvm::ConstantAggregateZero>(indices)) {
1447             broadcastDetected = true;
1448         }
1449     }
1450     llvm::Value *base = lGetBasePointer(ptrs, insertBefore, broadcastDetected);
1451     if (base != NULL) {
1452         // We have a straight up varying pointer with no indexing that's
1453         // actually all the same value.
1454         if (g->target->is32Bit())
1455             *offsets = LLVMInt32Vector(0);
1456         else
1457             *offsets = LLVMInt64Vector((int64_t)0);
1458 
1459         if (broadcastDetected) {
1460             llvm::Value *op = shuffle->getOperand(0);
1461             llvm::BinaryOperator *bop_var = llvm::dyn_cast<llvm::BinaryOperator>(op);
1462             if (bop_var != NULL && ((bop_var->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop_var))) {
1463                 // Here we expect a ConstantVector such as
1464                 // <i64 4, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef, i64 undef>
1465                 llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(bop_var->getOperand(1));
1466                 llvm::Value *shuffle_offset = NULL;
1467                 if (cv != NULL) {
1468                     llvm::Value *zeroMask =
1469 #if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
1470                         llvm::ConstantVector::getSplat(cv->getType()->getVectorNumElements(),
1471 #elif ISPC_LLVM_VERSION < ISPC_LLVM_12_0
1472                         llvm::ConstantVector::getSplat(
1473                             {llvm::dyn_cast<llvm::VectorType>(cv->getType())->getNumElements(), false},
1474 #else // LLVM 12.0+
1475                         llvm::ConstantVector::getSplat(
1476                             llvm::ElementCount::get(
1477                                 llvm::dyn_cast<llvm::FixedVectorType>(cv->getType())->getNumElements(), false),
1478 #endif
1479                                                        llvm::Constant::getNullValue(llvm::Type::getInt32Ty(*g->ctx)));
1480                     // Create offset
1481                     shuffle_offset = new llvm::ShuffleVectorInst(cv, llvm::UndefValue::get(cv->getType()), zeroMask,
1482                                                                  "shuffle", bop_var);
1483                 } else {
1484                     // or it binaryoperator can accept another binary operator
1485                     // Alternatively, the binary operator's operand may itself be another
1486                     // binary operator that computes another part of the offset:
1487                     // %offsets = add <16 x i32> %another_bop, %base
1488                     bop_var = llvm::dyn_cast<llvm::BinaryOperator>(bop_var->getOperand(0));
1489                     if (bop_var != NULL) {
1490                         llvm::Type *bop_var_type = bop_var->getType();
1491                         llvm::Value *zeroMask = llvm::ConstantVector::getSplat(
1492 #if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
1493                             bop_var_type->getVectorNumElements(),
1494 #elif ISPC_LLVM_VERSION < ISPC_LLVM_12_0
1495                             {llvm::dyn_cast<llvm::VectorType>(bop_var_type)->getNumElements(), false},
1496 #else // LLVM 12.0+
1497                             llvm::ElementCount::get(
1498                                 llvm::dyn_cast<llvm::FixedVectorType>(bop_var_type)->getNumElements(), false),
1499 #endif
1500                             llvm::Constant::getNullValue(llvm::Type::getInt32Ty(*g->ctx)));
1501                         shuffle_offset = new llvm::ShuffleVectorInst(bop_var, llvm::UndefValue::get(bop_var_type),
1502                                                                      zeroMask, "shuffle", bop_var);
1503                     }
1504                 }
1505                 if (shuffle_offset != NULL) {
1506                     *offsets = llvm::BinaryOperator::Create(llvm::Instruction::Add, *offsets, shuffle_offset,
1507                                                             "new_offsets", insertBefore);
1508                     return base;
1509                 } else {
1510                     // Base + offset pattern was not recognized
1511                     return NULL;
1512                 }
1513             }
1514         }
1515         return base;
1516     }
1517 
1518     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(ptrs);
1519     if (bop != NULL && ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop))) {
1520         // If we have a common pointer plus something, then we're also
1521         // good.
1522         if ((base = lGetBasePtrAndOffsets(bop->getOperand(0), offsets, insertBefore)) != NULL) {
1523             *offsets = llvm::BinaryOperator::Create(llvm::Instruction::Add, *offsets, bop->getOperand(1), "new_offsets",
1524                                                     insertBefore);
1525             return base;
1526         } else if ((base = lGetBasePtrAndOffsets(bop->getOperand(1), offsets, insertBefore)) != NULL) {
1527             *offsets = llvm::BinaryOperator::Create(llvm::Instruction::Add, *offsets, bop->getOperand(0), "new_offsets",
1528                                                     insertBefore);
1529             return base;
1530         }
1531     }
1532     llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(ptrs);
1533     if (cv != NULL) {
1534         // Indexing into global arrays can lead to this form, with
1535         // ConstantVectors.
1536         llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
1537         for (int i = 0; i < (int)cv->getNumOperands(); ++i) {
1538             llvm::Constant *c = llvm::dyn_cast<llvm::Constant>(cv->getOperand(i));
1539             if (c == NULL)
1540                 return NULL;
1541             elements.push_back(c);
1542         }
1543 
1544         llvm::Constant *delta[ISPC_MAX_NVEC];
1545         for (unsigned int i = 0; i < elements.size(); ++i) {
1546             // For each element, try to decompose it into either a straight
1547             // up base pointer, or a base pointer plus an integer value.
1548             llvm::ConstantExpr *ce = llvm::dyn_cast<llvm::ConstantExpr>(elements[i]);
1549             if (ce == NULL)
1550                 return NULL;
1551 
1552             delta[i] = NULL;
1553             llvm::Value *elementBase = NULL; // base pointer for this element
1554             if (ce->getOpcode() == llvm::Instruction::PtrToInt) {
1555                 // If the element is just a ptr to int instruction, treat
1556                 // it as having an offset of zero
1557                 elementBase = ce;
1558                 delta[i] = g->target->is32Bit() ? LLVMInt32(0) : LLVMInt64(0);
1559             } else if ((ce->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(ce)) {
1560                 // Try both orderings of the operands to see if we can get
1561                 // a pointer+offset out of them.
1562                 elementBase = lGetConstantAddExprBaseOffset(ce->getOperand(0), ce->getOperand(1), &delta[i]);
1563                 if (elementBase == NULL)
1564                     elementBase = lGetConstantAddExprBaseOffset(ce->getOperand(1), ce->getOperand(0), &delta[i]);
1565             }
1566 
1567             // We weren't able to find a base pointer in the above.  (We
1568             // don't expect this to happen; if it does, it may be necessary
1569             // to handle more cases in the decomposition above.)
1570             if (elementBase == NULL)
1571                 return NULL;
1572 
1573             Assert(delta[i] != NULL);
1574             if (base == NULL)
1575                 // The first time we've found a base pointer
1576                 base = elementBase;
1577             else if (base != elementBase)
1578                 // Different program instances have different base
1579                 // pointers, so no luck.
1580                 return NULL;
1581         }
1582 
1583         Assert(base != NULL);
1584         llvm::ArrayRef<llvm::Constant *> deltas(&delta[0], &delta[elements.size()]);
1585         *offsets = llvm::ConstantVector::get(deltas);
1586         return base;
1587     }
1588 
1589     llvm::ExtractValueInst *ev = llvm::dyn_cast<llvm::ExtractValueInst>(ptrs);
1590     if (ev != NULL) {
1591         Assert(ev->getNumIndices() == 1);
1592         int index = ev->getIndices()[0];
1593         ptrs = lExtractFromInserts(ev->getAggregateOperand(), index);
1594         if (ptrs != NULL)
1595             return lGetBasePtrAndOffsets(ptrs, offsets, insertBefore);
1596     }
1597 
1598     return NULL;
1599 }
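
// Overall sketch of what this decomposition enables, with hypothetical IR: a
// varying pointer built as
//     %pbase = ptrtoint float* %base to i64
//     %splat = broadcast of %pbase via insertelement + shufflevector
//     %ptrs  = add <8 x i64> %splat, %offsets
// comes back as the uniform base %pbase plus the per-lane vector %offsets,
// which later lets the gather/scatter use a single base pointer.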
1600 
1601 /** Given a vector expression in vec, separate it into a compile-time
1602     constant component and a variable component, returning the two parts in
1603     *constOffset and *variableOffset.  (It should be the case that the sum
1604     of these two is exactly equal to the original vector.)
1605 
1606     This routine only handles some (important) patterns; in some cases it
1607     will fail and return components that are actually compile-time
1608     constants in *variableOffset.
1609 
1610     Finally, if there aren't any constant (or, respectively, variable)
1611     components, the corresponding return value may be set to NULL.
1612  */
1613 static void lExtractConstantOffset(llvm::Value *vec, llvm::Value **constOffset, llvm::Value **variableOffset,
1614                                    llvm::Instruction *insertBefore) {
1615     if (llvm::isa<llvm::ConstantVector>(vec) || llvm::isa<llvm::ConstantDataVector>(vec) ||
1616         llvm::isa<llvm::ConstantAggregateZero>(vec)) {
1617         *constOffset = vec;
1618         *variableOffset = NULL;
1619         return;
1620     }
1621 
1622     llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(vec);
1623     if (cast != NULL) {
1624         // Check the cast target.
1625         llvm::Value *co, *vo;
1626         lExtractConstantOffset(cast->getOperand(0), &co, &vo, insertBefore);
1627 
1628         // make new cast instructions for the two parts
1629         if (co == NULL)
1630             *constOffset = NULL;
1631         else
1632             *constOffset = llvm::CastInst::Create(cast->getOpcode(), co, cast->getType(),
1633                                                   llvm::Twine(co->getName()) + "_cast", insertBefore);
1634         if (vo == NULL)
1635             *variableOffset = NULL;
1636         else
1637             *variableOffset = llvm::CastInst::Create(cast->getOpcode(), vo, cast->getType(),
1638                                                      llvm::Twine(vo->getName()) + "_cast", insertBefore);
1639         return;
1640     }
1641 
1642     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(vec);
1643     if (bop != NULL) {
1644         llvm::Value *op0 = bop->getOperand(0);
1645         llvm::Value *op1 = bop->getOperand(1);
1646         llvm::Value *c0, *v0, *c1, *v1;
1647 
1648         if ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) {
1649             lExtractConstantOffset(op0, &c0, &v0, insertBefore);
1650             lExtractConstantOffset(op1, &c1, &v1, insertBefore);
1651 
1652             if (c0 == NULL || llvm::isa<llvm::ConstantAggregateZero>(c0))
1653                 *constOffset = c1;
1654             else if (c1 == NULL || llvm::isa<llvm::ConstantAggregateZero>(c1))
1655                 *constOffset = c0;
1656             else
1657                 *constOffset = llvm::BinaryOperator::Create(
1658                     llvm::Instruction::Add, c0, c1, ((llvm::Twine("add_") + c0->getName()) + "_") + c1->getName(),
1659                     insertBefore);
1660 
1661             if (v0 == NULL || llvm::isa<llvm::ConstantAggregateZero>(v0))
1662                 *variableOffset = v1;
1663             else if (v1 == NULL || llvm::isa<llvm::ConstantAggregateZero>(v1))
1664                 *variableOffset = v0;
1665             else
1666                 *variableOffset = llvm::BinaryOperator::Create(
1667                     llvm::Instruction::Add, v0, v1, ((llvm::Twine("add_") + v0->getName()) + "_") + v1->getName(),
1668                     insertBefore);
1669             return;
1670         } else if (bop->getOpcode() == llvm::Instruction::Shl) {
1671             lExtractConstantOffset(op0, &c0, &v0, insertBefore);
1672             lExtractConstantOffset(op1, &c1, &v1, insertBefore);
1673 
1674             // Given the product of constant and variable terms, we have:
1675             // (c0 + v0) * (2^(c1 + v1))  = c0 * 2^c1 * 2^v1 + v0 * 2^c1 * 2^v1
1676             // We can optimize only if v1 == NULL.
1677             if ((v1 != NULL) || (c0 == NULL) || (c1 == NULL)) {
1678                 *constOffset = NULL;
1679                 *variableOffset = vec;
1680             } else if (v0 == NULL) {
1681                 *constOffset = vec;
1682                 *variableOffset = NULL;
1683             } else {
1684                 *constOffset = llvm::BinaryOperator::Create(
1685                     llvm::Instruction::Shl, c0, c1, ((llvm::Twine("shl_") + c0->getName()) + "_") + c1->getName(),
1686                     insertBefore);
1687                 *variableOffset = llvm::BinaryOperator::Create(
1688                     llvm::Instruction::Shl, v0, c1, ((llvm::Twine("shl_") + v0->getName()) + "_") + c1->getName(),
1689                     insertBefore);
1690             }
1691             return;
1692         } else if (bop->getOpcode() == llvm::Instruction::Mul) {
1693             lExtractConstantOffset(op0, &c0, &v0, insertBefore);
1694             lExtractConstantOffset(op1, &c1, &v1, insertBefore);
1695 
1696             // Given the product of constant and variable terms, we have:
1697             // (c0 + v0) * (c1 + v1) == (c0 c1) + (v0 c1 + c0 v1 + v0 v1)
1698             // Note that the first term is a constant and the last three are
1699             // variable.
1700             if (c0 != NULL && c1 != NULL)
1701                 *constOffset = llvm::BinaryOperator::Create(
1702                     llvm::Instruction::Mul, c0, c1, ((llvm::Twine("mul_") + c0->getName()) + "_") + c1->getName(),
1703                     insertBefore);
1704             else
1705                 *constOffset = NULL;
1706 
1707             llvm::Value *va = NULL, *vb = NULL, *vc = NULL;
1708             if (v0 != NULL && c1 != NULL)
1709                 va = llvm::BinaryOperator::Create(llvm::Instruction::Mul, v0, c1,
1710                                                   ((llvm::Twine("mul_") + v0->getName()) + "_") + c1->getName(),
1711                                                   insertBefore);
1712             if (c0 != NULL && v1 != NULL)
1713                 vb = llvm::BinaryOperator::Create(llvm::Instruction::Mul, c0, v1,
1714                                                   ((llvm::Twine("mul_") + c0->getName()) + "_") + v1->getName(),
1715                                                   insertBefore);
1716             if (v0 != NULL && v1 != NULL)
1717                 vc = llvm::BinaryOperator::Create(llvm::Instruction::Mul, v0, v1,
1718                                                   ((llvm::Twine("mul_") + v0->getName()) + "_") + v1->getName(),
1719                                                   insertBefore);
1720 
1721             llvm::Value *vab = NULL;
1722             if (va != NULL && vb != NULL)
1723                 vab = llvm::BinaryOperator::Create(llvm::Instruction::Add, va, vb,
1724                                                    ((llvm::Twine("add_") + va->getName()) + "_") + vb->getName(),
1725                                                    insertBefore);
1726             else if (va != NULL)
1727                 vab = va;
1728             else
1729                 vab = vb;
1730 
1731             if (vab != NULL && vc != NULL)
1732                 *variableOffset = llvm::BinaryOperator::Create(
1733                     llvm::Instruction::Add, vab, vc, ((llvm::Twine("add_") + vab->getName()) + "_") + vc->getName(),
1734                     insertBefore);
1735             else if (vab != NULL)
1736                 *variableOffset = vab;
1737             else
1738                 *variableOffset = vc;
1739 
1740             return;
1741         }
1742     }
1743 
1744     // Nothing matched, just return what we have as a variable component
1745     *constOffset = NULL;
1746     *variableOffset = vec;
1747 }
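
// Worked example (hypothetical values): for
//     vec = add <8 x i32> (mul <8 x i32> %i, splat 8), splat 4
// the add case recurses into both operands; the mul contributes only a
// variable part (a multiply of %i by 8) and the splat contributes only a
// constant part, so *variableOffset ends up as the multiply and *constOffset
// as the splat-of-4 vector, and their sum equals the original vec.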
1748 
1749 /* Returns true if the given value is a constant vector of integers with
1750    the same value in all of the elements.  (Returns the splatted value in
1751    *splat, if so). */
1752 static bool lIsIntegerSplat(llvm::Value *v, int *splat) {
1753     llvm::ConstantDataVector *cvec = llvm::dyn_cast<llvm::ConstantDataVector>(v);
1754     if (cvec == NULL)
1755         return false;
1756 
1757     llvm::Constant *splatConst = cvec->getSplatValue();
1758     if (splatConst == NULL)
1759         return false;
1760 
1761     llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(splatConst);
1762     if (ci == NULL)
1763         return false;
1764 
1765     int64_t splatVal = ci->getSExtValue();
1766     *splat = (int)splatVal;
1767     return true;
1768 }
1769 
1770 static llvm::Value *lExtract248Scale(llvm::Value *splatOperand, int splatValue, llvm::Value *otherOperand,
1771                                      llvm::Value **result) {
1772     if (splatValue == 2 || splatValue == 4 || splatValue == 8) {
1773         *result = otherOperand;
1774         return LLVMInt32(splatValue);
1775     }
1776     // Even if we don't have a common scale by exactly 2, 4, or 8, we'll
1777     // see if we can pull out that much of the scale anyway; this may in
1778     // turn allow other optimizations later.
1779     for (int scale = 8; scale >= 2; scale /= 2) {
1780         llvm::Instruction *insertBefore = llvm::dyn_cast<llvm::Instruction>(*result);
1781         Assert(insertBefore != NULL);
1782 
1783         if ((splatValue % scale) == 0) {
1784             // *result = otherOperand * splatOperand / scale;
1785             llvm::Value *splatScaleVec = (splatOperand->getType() == LLVMTypes::Int32VectorType)
1786                                              ? LLVMInt32Vector(scale)
1787                                              : LLVMInt64Vector(scale);
1788             llvm::Value *splatDiv =
1789                 llvm::BinaryOperator::Create(llvm::Instruction::SDiv, splatOperand, splatScaleVec, "div", insertBefore);
1790             *result = llvm::BinaryOperator::Create(llvm::Instruction::Mul, splatDiv, otherOperand, "mul", insertBefore);
1791             return LLVMInt32(scale);
1792         }
1793     }
1794     return LLVMInt32(1);
1795 }
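
// Worked example (hypothetical values): with splatValue == 12, neither 2, 4,
// nor 8 matches directly, but 12 % 4 == 0, so the function rewrites *result to
// (splatOperand / 4) * otherOperand and returns an i32 4, leaving the
// remaining factor of 3 inside the rewritten offsets.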
1796 
1797 /** Given a vector of integer offsets to a base pointer being used for a
1798     gather or a scatter, see if its root operation is a multiply of some
1799     vector of values by a vector of all 2s, 4s, or 8s.  If not, a scale
1800     factor of 1 is returned.
1801     If it is, return an i32 value of 2, 4, or 8 from the function and modify
1802     *vec so that it points to the operand that is being multiplied by
1803     2/4/8.
1804 
1805     We go through all this trouble so that we can pass the i32 scale factor
1806     to the {gather,scatter}_base_offsets function as a separate scale
1807     factor for the offsets.  This in turn is used in a way so that the LLVM
1808     x86 code generator matches it to apply x86's free scale by 2x, 4x, or
1809     8x to one of two registers being added together for an addressing
1810     calculation.
1811  */
1812 static llvm::Value *lExtractOffsetVector248Scale(llvm::Value **vec) {
1813     llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(*vec);
1814     if (cast != NULL) {
1815         llvm::Value *castOp = cast->getOperand(0);
1816         // Check the cast target.
1817         llvm::Value *scale = lExtractOffsetVector248Scale(&castOp);
1818         if (scale == NULL)
1819             return NULL;
1820 
1821         // make a new cast instruction so that we end up with the right
1822         // type
1823         *vec = llvm::CastInst::Create(cast->getOpcode(), castOp, cast->getType(), "offset_cast", cast);
1824         return scale;
1825     }
1826 
1827     // If we don't have a binary operator, then just give up
1828     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(*vec);
1829     if (bop == NULL)
1830         return LLVMInt32(1);
1831 
1832     llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
1833     if ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) {
1834         if (llvm::isa<llvm::ConstantAggregateZero>(op0)) {
1835             *vec = op1;
1836             return lExtractOffsetVector248Scale(vec);
1837         } else if (llvm::isa<llvm::ConstantAggregateZero>(op1)) {
1838             *vec = op0;
1839             return lExtractOffsetVector248Scale(vec);
1840         } else {
1841             llvm::Value *s0 = lExtractOffsetVector248Scale(&op0);
1842             llvm::Value *s1 = lExtractOffsetVector248Scale(&op1);
1843             if (s0 == s1) {
1844                 *vec = llvm::BinaryOperator::Create(llvm::Instruction::Add, op0, op1, "new_add", bop);
1845                 return s0;
1846             } else
1847                 return LLVMInt32(1);
1848         }
1849     } else if (bop->getOpcode() == llvm::Instruction::Mul) {
1850         // Check each operand for being one of the scale factors we care about.
1851         int splat;
1852         if (lIsIntegerSplat(op0, &splat))
1853             return lExtract248Scale(op0, splat, op1, vec);
1854         else if (lIsIntegerSplat(op1, &splat))
1855             return lExtract248Scale(op1, splat, op0, vec);
1856         else
1857             return LLVMInt32(1);
1858     } else
1859         return LLVMInt32(1);
1860 }
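
// Illustrative example (hypothetical IR): for offsets computed as
//     %offs = mul <8 x i32> %index, <8 x i32> splat of 4
// the splat of 4 is recognized, *vec is updated to point at %index, and an i32
// constant 4 is returned; the x86 backend can then fold that scale into the
// addressing mode of the gather/scatter.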
1861 
1862 #if 0
1863 static llvm::Value *
1864 lExtractUniforms(llvm::Value **vec, llvm::Instruction *insertBefore) {
1865     fprintf(stderr, " lextract: ");
1866     (*vec)->dump();
1867     fprintf(stderr, "\n");
1868 
1869     if (llvm::isa<llvm::ConstantVector>(*vec) ||
1870         llvm::isa<llvm::ConstantDataVector>(*vec) ||
1871         llvm::isa<llvm::ConstantAggregateZero>(*vec))
1872         return NULL;
1873 
1874     llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(*vec);
1875     if (sext != NULL) {
1876         llvm::Value *sextOp = sext->getOperand(0);
1877         // Check the sext target.
1878         llvm::Value *unif = lExtractUniforms(&sextOp, insertBefore);
1879         if (unif == NULL)
1880             return NULL;
1881 
1882         // make a new sext instruction so that we end up with the right
1883         // type
1884         *vec = new llvm::SExtInst(sextOp, sext->getType(), "offset_sext", sext);
1885         return unif;
1886     }
1887 
1888     if (LLVMVectorValuesAllEqual(*vec)) {
1889         // FIXME: we may want to redo all of the expression here, in scalar
1890         // form (if at all possible), for code quality...
1891         llvm::Value *unif =
1892             llvm::ExtractElementInst::Create(*vec, LLVMInt32(0),
1893                                              "first_uniform", insertBefore);
1894         *vec = NULL;
1895         return unif;
1896     }
1897 
1898     llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(*vec);
1899     if (bop == NULL)
1900         return NULL;
1901 
1902     llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
1903     if (bop->getOpcode() == llvm::Instruction::Add) {
1904         llvm::Value *s0 = lExtractUniforms(&op0, insertBefore);
1905         llvm::Value *s1 = lExtractUniforms(&op1, insertBefore);
1906         if (s0 == NULL && s1 == NULL)
1907             return NULL;
1908 
1909         if (op0 == NULL)
1910             *vec = op1;
1911         else if (op1 == NULL)
1912             *vec = op0;
1913         else
1914             *vec = llvm::BinaryOperator::Create(llvm::Instruction::Add,
1915                                                 op0, op1, "new_add", insertBefore);
1916 
1917         if (s0 == NULL)
1918             return s1;
1919         else if (s1 == NULL)
1920             return s0;
1921         else
1922             return llvm::BinaryOperator::Create(llvm::Instruction::Add, s0, s1,
1923                                                 "add_unif", insertBefore);
1924     }
1925 #if 0
1926     else if (bop->getOpcode() == llvm::Instruction::Mul) {
1927         // Check each operand for being one of the scale factors we care about.
1928         int splat;
1929         if (lIs248Splat(op0, &splat)) {
1930             *vec = op1;
1931             return LLVMInt32(splat);
1932         }
1933         else if (lIs248Splat(op1, &splat)) {
1934             *vec = op0;
1935             return LLVMInt32(splat);
1936         }
1937         else
1938             return LLVMInt32(1);
1939     }
1940 #endif
1941     else
1942         return NULL;
1943 }
1944 
1945 
1946 static void
1947 lExtractUniformsFromOffset(llvm::Value **basePtr, llvm::Value **offsetVector,
1948                            llvm::Value *offsetScale,
1949                            llvm::Instruction *insertBefore) {
1950 #if 1
1951     //(*basePtr)->dump();
1952     printf("\n");
1953     //(*offsetVector)->dump();
1954     printf("\n");
1955     //offsetScale->dump();
1956     printf("-----\n");
1957 #endif
1958 
1959     llvm::Value *uniformDelta = lExtractUniforms(offsetVector, insertBefore);
1960     if (uniformDelta == NULL)
1961         return;
1962 
1963     *basePtr = lGEPInst(*basePtr, arrayRef, "new_base", insertBefore);
1964 
1965     // this should only happen if we have only uniforms, but that in turn
1966     // shouldn't be a gather/scatter!
1967     Assert(*offsetVector != NULL);
1968 }
1969 #endif
1970 
1971 static bool lVectorIs32BitInts(llvm::Value *v) {
1972     int nElts;
1973     int64_t elts[ISPC_MAX_NVEC];
1974     if (!LLVMExtractVectorInts(v, elts, &nElts))
1975         return false;
1976 
1977     for (int i = 0; i < nElts; ++i)
1978         if ((int32_t)elts[i] != elts[i])
1979             return false;
1980 
1981     return true;
1982 }
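
// Illustrative example (hypothetical constants): an <8 x i64> vector whose
// elements are all, say, 0 or 1024 passes the check, while a vector containing
// 0x100000000 does not, since that value changes when truncated to 32 bits.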
1983 
1984 /** Check to see if the two offset vectors can safely be represented with
1985     32-bit values.  If so, return true and update the pointed-to
1986     llvm::Value *s to be the 32-bit equivalents. */
1987 static bool lOffsets32BitSafe(llvm::Value **variableOffsetPtr, llvm::Value **constOffsetPtr,
1988                               llvm::Instruction *insertBefore) {
1989     llvm::Value *variableOffset = *variableOffsetPtr;
1990     llvm::Value *constOffset = *constOffsetPtr;
1991 
1992     if (variableOffset->getType() != LLVMTypes::Int32VectorType) {
1993         llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(variableOffset);
1994         if (sext != NULL && sext->getOperand(0)->getType() == LLVMTypes::Int32VectorType)
1995             // sext of a 32-bit vector -> the 32-bit vector is good
1996             variableOffset = sext->getOperand(0);
1997         else if (lVectorIs32BitInts(variableOffset))
1998             // The only constant vector we should have here is a vector of
1999             // all zeros (i.e. a ConstantAggregateZero), but just in case,
2000             // do the more general check with lVectorIs32BitInts().
2001             variableOffset = new llvm::TruncInst(variableOffset, LLVMTypes::Int32VectorType,
2002                                                  llvm::Twine(variableOffset->getName()) + "_trunc", insertBefore);
2003         else
2004             return false;
2005     }
2006 
2007     if (constOffset->getType() != LLVMTypes::Int32VectorType) {
2008         if (lVectorIs32BitInts(constOffset)) {
2009             // Truncate them so we have a 32-bit vector type for them.
2010             constOffset = new llvm::TruncInst(constOffset, LLVMTypes::Int32VectorType,
2011                                               llvm::Twine(constOffset->getName()) + "_trunc", insertBefore);
2012         } else {
2013             // FIXME: otherwise we just assume that all constant offsets
2014             // can actually always fit into 32 bits...  (This could be
2015             // wrong, but only in pretty esoteric cases.)  We
2016             // make this assumption for now since we sometimes generate
2017             // constants that need constant folding before we really have a
2018             // constant vector out of them, and
2019             // llvm::ConstantFoldInstruction() doesn't seem to be doing
2020             // enough for us in some cases if we call it from here.
2021             constOffset = new llvm::TruncInst(constOffset, LLVMTypes::Int32VectorType,
2022                                               llvm::Twine(constOffset->getName()) + "_trunc", insertBefore);
2023         }
2024     }
2025 
2026     *variableOffsetPtr = variableOffset;
2027     *constOffsetPtr = constOffset;
2028     return true;
2029 }
2030 
2031 /** Check to see if the offset value is composed of a string of Adds,
2032     SExts, and Constant Vectors that are 32-bit safe.  Recursively
2033     explores the operands of Add instructions (as they might themselves
2034     be adds that eventually terminate in constant vectors or a SExt).
2035  */
2036 
2037 static bool lIs32BitSafeHelper(llvm::Value *v) {
2038     // handle Adds, SExts, Constant Vectors
2039     if (llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v)) {
2040         if ((bop->getOpcode() == llvm::Instruction::Add) || IsOrEquivalentToAdd(bop)) {
2041             return lIs32BitSafeHelper(bop->getOperand(0)) && lIs32BitSafeHelper(bop->getOperand(1));
2042         }
2043         return false;
2044     } else if (llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(v)) {
2045         return sext->getOperand(0)->getType() == LLVMTypes::Int32VectorType;
2046     } else
2047         return lVectorIs32BitInts(v);
2048 }
2049 
2050 /** Check to see if the single offset vector can safely be represented with
2051     32-bit values.  If so, return true and update the pointed-to
2052     llvm::Value * to be the 32-bit equivalent. */
2053 static bool lOffsets32BitSafe(llvm::Value **offsetPtr, llvm::Instruction *insertBefore) {
2054     llvm::Value *offset = *offsetPtr;
2055 
2056     if (offset->getType() == LLVMTypes::Int32VectorType)
2057         return true;
2058 
2059     llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(offset);
2060     if (sext != NULL && sext->getOperand(0)->getType() == LLVMTypes::Int32VectorType) {
2061         // sext of a 32-bit vector -> the 32-bit vector is good
2062         *offsetPtr = sext->getOperand(0);
2063         return true;
2064     } else if (lIs32BitSafeHelper(offset)) {
2065         // The only constant vector we should have here is a vector of
2066         // all zeros (i.e. a ConstantAggregateZero), but just in case,
2067         // do the more general check with lVectorIs32BitInts().
2068 
2069         // Alternatively, offset could be a sequence of adds terminating
2070         // in safe constant vectors or a SExt.
2071         *offsetPtr = new llvm::TruncInst(offset, LLVMTypes::Int32VectorType, llvm::Twine(offset->getName()) + "_trunc",
2072                                          insertBefore);
2073         return true;
2074     } else
2075         return false;
2076 }
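
// Illustrative example (hypothetical IR): if the 64-bit offsets were produced as
//     %offs64 = sext <8 x i32> %offs32 to <8 x i64>
// the sext is simply peeled off and %offs32 is used directly; a chain of adds of
// such sexts and small constant vectors is instead truncated back to 32 bits.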
2077 
2078 static bool lGSToGSBaseOffsets(llvm::CallInst *callInst) {
2079     struct GSInfo {
2080         GSInfo(const char *pgFuncName, const char *pgboFuncName, const char *pgbo32FuncName, bool ig, bool ip)
2081             : isGather(ig), isPrefetch(ip) {
2082             func = m->module->getFunction(pgFuncName);
2083             baseOffsetsFunc = m->module->getFunction(pgboFuncName);
2084             baseOffsets32Func = m->module->getFunction(pgbo32FuncName);
2085         }
2086         llvm::Function *func;
2087         llvm::Function *baseOffsetsFunc, *baseOffsets32Func;
2088         const bool isGather;
2089         const bool isPrefetch;
2090     };
2091 
2092     GSInfo gsFuncs[] = {
2093         GSInfo(
2094             "__pseudo_gather32_i8",
2095             g->target->hasGather() ? "__pseudo_gather_base_offsets32_i8" : "__pseudo_gather_factored_base_offsets32_i8",
2096             g->target->hasGather() ? "__pseudo_gather_base_offsets32_i8" : "__pseudo_gather_factored_base_offsets32_i8",
2097             true, false),
2098         GSInfo("__pseudo_gather32_i16",
2099                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i16"
2100                                       : "__pseudo_gather_factored_base_offsets32_i16",
2101                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i16"
2102                                       : "__pseudo_gather_factored_base_offsets32_i16",
2103                true, false),
2104         GSInfo("__pseudo_gather32_i32",
2105                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i32"
2106                                       : "__pseudo_gather_factored_base_offsets32_i32",
2107                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i32"
2108                                       : "__pseudo_gather_factored_base_offsets32_i32",
2109                true, false),
2110         GSInfo("__pseudo_gather32_float",
2111                g->target->hasGather() ? "__pseudo_gather_base_offsets32_float"
2112                                       : "__pseudo_gather_factored_base_offsets32_float",
2113                g->target->hasGather() ? "__pseudo_gather_base_offsets32_float"
2114                                       : "__pseudo_gather_factored_base_offsets32_float",
2115                true, false),
2116         GSInfo("__pseudo_gather32_i64",
2117                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i64"
2118                                       : "__pseudo_gather_factored_base_offsets32_i64",
2119                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i64"
2120                                       : "__pseudo_gather_factored_base_offsets32_i64",
2121                true, false),
2122         GSInfo("__pseudo_gather32_double",
2123                g->target->hasGather() ? "__pseudo_gather_base_offsets32_double"
2124                                       : "__pseudo_gather_factored_base_offsets32_double",
2125                g->target->hasGather() ? "__pseudo_gather_base_offsets32_double"
2126                                       : "__pseudo_gather_factored_base_offsets32_double",
2127                true, false),
2128 
2129         GSInfo("__pseudo_scatter32_i8",
2130                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i8"
2131                                        : "__pseudo_scatter_factored_base_offsets32_i8",
2132                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i8"
2133                                        : "__pseudo_scatter_factored_base_offsets32_i8",
2134                false, false),
2135         GSInfo("__pseudo_scatter32_i16",
2136                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i16"
2137                                        : "__pseudo_scatter_factored_base_offsets32_i16",
2138                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i16"
2139                                        : "__pseudo_scatter_factored_base_offsets32_i16",
2140                false, false),
2141         GSInfo("__pseudo_scatter32_i32",
2142                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i32"
2143                                        : "__pseudo_scatter_factored_base_offsets32_i32",
2144                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i32"
2145                                        : "__pseudo_scatter_factored_base_offsets32_i32",
2146                false, false),
2147         GSInfo("__pseudo_scatter32_float",
2148                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_float"
2149                                        : "__pseudo_scatter_factored_base_offsets32_float",
2150                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_float"
2151                                        : "__pseudo_scatter_factored_base_offsets32_float",
2152                false, false),
2153         GSInfo("__pseudo_scatter32_i64",
2154                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i64"
2155                                        : "__pseudo_scatter_factored_base_offsets32_i64",
2156                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i64"
2157                                        : "__pseudo_scatter_factored_base_offsets32_i64",
2158                false, false),
2159         GSInfo("__pseudo_scatter32_double",
2160                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_double"
2161                                        : "__pseudo_scatter_factored_base_offsets32_double",
2162                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_double"
2163                                        : "__pseudo_scatter_factored_base_offsets32_double",
2164                false, false),
2165 
2166         GSInfo(
2167             "__pseudo_gather64_i8",
2168             g->target->hasGather() ? "__pseudo_gather_base_offsets64_i8" : "__pseudo_gather_factored_base_offsets64_i8",
2169             g->target->hasGather() ? "__pseudo_gather_base_offsets32_i8" : "__pseudo_gather_factored_base_offsets32_i8",
2170             true, false),
2171         GSInfo("__pseudo_gather64_i16",
2172                g->target->hasGather() ? "__pseudo_gather_base_offsets64_i16"
2173                                       : "__pseudo_gather_factored_base_offsets64_i16",
2174                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i16"
2175                                       : "__pseudo_gather_factored_base_offsets32_i16",
2176                true, false),
2177         GSInfo("__pseudo_gather64_i32",
2178                g->target->hasGather() ? "__pseudo_gather_base_offsets64_i32"
2179                                       : "__pseudo_gather_factored_base_offsets64_i32",
2180                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i32"
2181                                       : "__pseudo_gather_factored_base_offsets32_i32",
2182                true, false),
2183         GSInfo("__pseudo_gather64_float",
2184                g->target->hasGather() ? "__pseudo_gather_base_offsets64_float"
2185                                       : "__pseudo_gather_factored_base_offsets64_float",
2186                g->target->hasGather() ? "__pseudo_gather_base_offsets32_float"
2187                                       : "__pseudo_gather_factored_base_offsets32_float",
2188                true, false),
2189         GSInfo("__pseudo_gather64_i64",
2190                g->target->hasGather() ? "__pseudo_gather_base_offsets64_i64"
2191                                       : "__pseudo_gather_factored_base_offsets64_i64",
2192                g->target->hasGather() ? "__pseudo_gather_base_offsets32_i64"
2193                                       : "__pseudo_gather_factored_base_offsets32_i64",
2194                true, false),
2195         GSInfo("__pseudo_gather64_double",
2196                g->target->hasGather() ? "__pseudo_gather_base_offsets64_double"
2197                                       : "__pseudo_gather_factored_base_offsets64_double",
2198                g->target->hasGather() ? "__pseudo_gather_base_offsets32_double"
2199                                       : "__pseudo_gather_factored_base_offsets32_double",
2200                true, false),
2201 
2202         GSInfo("__pseudo_scatter64_i8",
2203                g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i8"
2204                                        : "__pseudo_scatter_factored_base_offsets64_i8",
2205                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i8"
2206                                        : "__pseudo_scatter_factored_base_offsets32_i8",
2207                false, false),
2208         GSInfo("__pseudo_scatter64_i16",
2209                g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i16"
2210                                        : "__pseudo_scatter_factored_base_offsets64_i16",
2211                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i16"
2212                                        : "__pseudo_scatter_factored_base_offsets32_i16",
2213                false, false),
2214         GSInfo("__pseudo_scatter64_i32",
2215                g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i32"
2216                                        : "__pseudo_scatter_factored_base_offsets64_i32",
2217                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i32"
2218                                        : "__pseudo_scatter_factored_base_offsets32_i32",
2219                false, false),
2220         GSInfo("__pseudo_scatter64_float",
2221                g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_float"
2222                                        : "__pseudo_scatter_factored_base_offsets64_float",
2223                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_float"
2224                                        : "__pseudo_scatter_factored_base_offsets32_float",
2225                false, false),
2226         GSInfo("__pseudo_scatter64_i64",
2227                g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i64"
2228                                        : "__pseudo_scatter_factored_base_offsets64_i64",
2229                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i64"
2230                                        : "__pseudo_scatter_factored_base_offsets32_i64",
2231                false, false),
2232         GSInfo("__pseudo_scatter64_double",
2233                g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_double"
2234                                        : "__pseudo_scatter_factored_base_offsets64_double",
2235                g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_double"
2236                                        : "__pseudo_scatter_factored_base_offsets32_double",
2237                false, false),
2238         GSInfo("__pseudo_prefetch_read_varying_1",
2239                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_1_native" : "__prefetch_read_varying_1",
2240                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_1_native" : "__prefetch_read_varying_1",
2241                false, true),
2242 
2243         GSInfo("__pseudo_prefetch_read_varying_2",
2244                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_2_native" : "__prefetch_read_varying_2",
2245                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_2_native" : "__prefetch_read_varying_2",
2246                false, true),
2247 
2248         GSInfo("__pseudo_prefetch_read_varying_3",
2249                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_3_native" : "__prefetch_read_varying_3",
2250                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_3_native" : "__prefetch_read_varying_3",
2251                false, true),
2252 
2253         GSInfo("__pseudo_prefetch_read_varying_nt",
2254                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_nt_native" : "__prefetch_read_varying_nt",
2255                g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_nt_native" : "__prefetch_read_varying_nt",
2256                false, true),
2257     };
2258 
2259     int numGSFuncs = sizeof(gsFuncs) / sizeof(gsFuncs[0]);
2260     for (int i = 0; i < numGSFuncs; ++i)
2261         Assert(gsFuncs[i].func != NULL && gsFuncs[i].baseOffsetsFunc != NULL && gsFuncs[i].baseOffsets32Func != NULL);
2262 
2263     GSInfo *info = NULL;
2264     for (int i = 0; i < numGSFuncs; ++i)
2265         if (gsFuncs[i].func != NULL && callInst->getCalledFunction() == gsFuncs[i].func) {
2266             info = &gsFuncs[i];
2267             break;
2268         }
2269     if (info == NULL)
2270         return false;
2271 
2272     // Try to transform the array of pointers to a single base pointer
2273     // and an array of int32 offsets.  (All the hard work is done by
2274     // lGetBasePtrAndOffsets).
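    // Illustrative sketch (hypothetical IR values; the argument shapes follow the calls built
    // below): a call such as
    //   %v = __pseudo_gather64_i32(%ptrs, %mask)
    // where %ptrs decomposes into a common base plus per-lane offsets becomes, on targets with
    // native gather support,
    //   %v = __pseudo_gather_base_offsets64_i32(%base_void, %scale, %offsets, %mask)
    // while other targets use the "factored" variant that takes separate variable and constant
    // offset operands.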
2275     llvm::Value *ptrs = callInst->getArgOperand(0);
2276     llvm::Value *offsetVector = NULL;
2277     llvm::Value *basePtr = lGetBasePtrAndOffsets(ptrs, &offsetVector, callInst);
2278 
2279     if (basePtr == NULL || offsetVector == NULL ||
2280         (info->isGather == false && info->isPrefetch == true && g->target->hasVecPrefetch() == false)) {
2281         // It's actually a fully general gather/scatter with a varying
2282         // set of base pointers, so leave it as is and continue onward
2283         // to the next instruction...
2284         return false;
2285     }
2286     // Cast the base pointer to a void *, since that's what the
2287     // __pseudo_*_base_offsets_* functions want.
2288     basePtr = new llvm::IntToPtrInst(basePtr, LLVMTypes::VoidPointerType, llvm::Twine(basePtr->getName()) + "_2void",
2289                                      callInst);
2290     lCopyMetadata(basePtr, callInst);
2291     llvm::Function *gatherScatterFunc = info->baseOffsetsFunc;
2292 
2293     if ((info->isGather == true && g->target->hasGather()) ||
2294         (info->isGather == false && info->isPrefetch == false && g->target->hasScatter()) ||
2295         (info->isGather == false && info->isPrefetch == true && g->target->hasVecPrefetch())) {
2296 
2297         // See if the offsets are scaled by 2, 4, or 8.  If so,
2298         // extract that scale factor and rewrite the offsets to remove
2299         // it.
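        // For example (hypothetical offsets): offsets = (0, 4, 8, 12) would be rewritten as
        // offsetScale = 4 with offsets = (0, 1, 2, 3), matching the 2/4/8 scales that x86
        // addressing modes can encode directly.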
2300         llvm::Value *offsetScale = lExtractOffsetVector248Scale(&offsetVector);
2301 
2302         // If we're doing 32-bit addressing on a 64-bit target, here we
2303         // will see if we can call one of the 32-bit variants of the pseudo
2304         // gather/scatter functions.
2305         if (g->opt.force32BitAddressing && lOffsets32BitSafe(&offsetVector, callInst)) {
2306             gatherScatterFunc = info->baseOffsets32Func;
2307         }
2308 
2309         if (info->isGather || info->isPrefetch) {
2310             llvm::Value *mask = callInst->getArgOperand(1);
2311 
2312             // Generate a new function call to the next pseudo gather
2313             // base+offsets instruction.  Note that we're passing a NULL
2314             // llvm::Instruction to llvm::CallInst::Create; this means that
2315             // the instruction isn't inserted into a basic block and that
2316             // way we can then call ReplaceInstWithInst().
2317             llvm::Instruction *newCall = lCallInst(gatherScatterFunc, basePtr, offsetScale, offsetVector, mask,
2318                                                    callInst->getName().str().c_str(), NULL);
2319             lCopyMetadata(newCall, callInst);
2320             llvm::ReplaceInstWithInst(callInst, newCall);
2321         } else {
2322             llvm::Value *storeValue = callInst->getArgOperand(1);
2323             llvm::Value *mask = callInst->getArgOperand(2);
2324 
2325             // Generate a new function call to the next pseudo scatter
2326             // base+offsets instruction.  See above for why passing NULL
2327             // for the Instruction * is intentional.
2328             llvm::Instruction *newCall =
2329                 lCallInst(gatherScatterFunc, basePtr, offsetScale, offsetVector, storeValue, mask, "", NULL);
2330             lCopyMetadata(newCall, callInst);
2331             llvm::ReplaceInstWithInst(callInst, newCall);
2332         }
2333     } else {
2334         // Try to decompose the offset vector into a compile time constant
2335         // component and a varying component.  The constant component is
2336         // passed as a separate parameter to the gather/scatter functions,
2337         // which in turn allows their implementations to end up emitting
2338         // x86 instructions with constant offsets encoded in them.
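        // Illustrative sketch (hypothetical per-lane offsets): offsets = 8*i + 16 decomposes into
        // variableOffset = 8*i and constOffset = 16, so the constant part can end up in the
        // displacement of the final addressing mode.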
2339         llvm::Value *constOffset = NULL;
2340         llvm::Value *variableOffset = NULL;
2341         lExtractConstantOffset(offsetVector, &constOffset, &variableOffset, callInst);
2342         if (constOffset == NULL)
2343             constOffset = LLVMIntAsType(0, offsetVector->getType());
2344         if (variableOffset == NULL)
2345             variableOffset = LLVMIntAsType(0, offsetVector->getType());
2346 
2347         // See if the varying component is scaled by 2, 4, or 8.  If so,
2348         // extract that scale factor and rewrite variableOffset to remove
2349         // it.  (This also is pulled out so that we can match the scales by
2350         // 2/4/8 offered by x86 addressing operators.)
2351         llvm::Value *offsetScale = lExtractOffsetVector248Scale(&variableOffset);
2352 
2353         // If we're doing 32-bit addressing on a 64-bit target, here we
2354         // will see if we can call one of the 32-bit variants of the pseudo
2355         // gather/scatter functions.
2356         if (g->opt.force32BitAddressing && lOffsets32BitSafe(&variableOffset, &constOffset, callInst)) {
2357             gatherScatterFunc = info->baseOffsets32Func;
2358         }
2359 
2360         if (info->isGather || info->isPrefetch) {
2361             llvm::Value *mask = callInst->getArgOperand(1);
2362 
2363             // Generate a new function call to the next pseudo gather
2364             // base+offsets instruction.  Note that we're passing a NULL
2365             // llvm::Instruction to llvm::CallInst::Create; this means that
2366             // the instruction isn't inserted into a basic block and that
2367             // way we can then call ReplaceInstWithInst().
2368             llvm::Instruction *newCall = lCallInst(gatherScatterFunc, basePtr, variableOffset, offsetScale, constOffset,
2369                                                    mask, callInst->getName().str().c_str(), NULL);
2370             lCopyMetadata(newCall, callInst);
2371             llvm::ReplaceInstWithInst(callInst, newCall);
2372         } else {
2373             llvm::Value *storeValue = callInst->getArgOperand(1);
2374             llvm::Value *mask = callInst->getArgOperand(2);
2375 
2376             // Generate a new function call to the next pseudo scatter
2377             // base+offsets instruction.  See above for why passing NULL
2378             // for the Instruction * is intentional.
2379             llvm::Instruction *newCall = lCallInst(gatherScatterFunc, basePtr, variableOffset, offsetScale, constOffset,
2380                                                    storeValue, mask, "", NULL);
2381             lCopyMetadata(newCall, callInst);
2382             llvm::ReplaceInstWithInst(callInst, newCall);
2383         }
2384     }
2385     return true;
2386 }
2387 
2388 /** Try to improve the decomposition between compile-time constant and
2389     compile-time unknown offsets in calls to the __pseudo_*_base_offsets*
2390     functions.  After other optimizations have run, we will sometimes be
2391     able to pull more terms out of the unknown part and add them into the
2392     compile-time-known part.
2393  */
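/* Illustrative sketch (hypothetical operand values): for a call of the form
     __pseudo_gather_base_offsets32_i32(base, varying = v + (1, 2, 3, 4), scale = 4,
                                        const = (16, 16, 16, 16), mask)
   the (1, 2, 3, 4) term can be pulled out of the varying offsets, scaled by 4, and folded into
   the constant offsets, leaving
     __pseudo_gather_base_offsets32_i32(base, varying = v, scale = 4,
                                        const = (20, 24, 28, 32), mask)
 */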
2394 static bool lGSBaseOffsetsGetMoreConst(llvm::CallInst *callInst) {
2395     struct GSBOInfo {
2396         GSBOInfo(const char *pgboFuncName, const char *pgbo32FuncName, bool ig, bool ip)
2397             : isGather(ig), isPrefetch(ip) {
2398             baseOffsetsFunc = m->module->getFunction(pgboFuncName);
2399             baseOffsets32Func = m->module->getFunction(pgbo32FuncName);
2400         }
2401         llvm::Function *baseOffsetsFunc, *baseOffsets32Func;
2402         const bool isGather;
2403         const bool isPrefetch;
2404     };
2405 
2406     GSBOInfo gsFuncs[] = {
2407         GSBOInfo(
2408             g->target->hasGather() ? "__pseudo_gather_base_offsets32_i8" : "__pseudo_gather_factored_base_offsets32_i8",
2409             g->target->hasGather() ? "__pseudo_gather_base_offsets32_i8" : "__pseudo_gather_factored_base_offsets32_i8",
2410             true, false),
2411         GSBOInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i16"
2412                                         : "__pseudo_gather_factored_base_offsets32_i16",
2413                  g->target->hasGather() ? "__pseudo_gather_base_offsets32_i16"
2414                                         : "__pseudo_gather_factored_base_offsets32_i16",
2415                  true, false),
2416         GSBOInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i32"
2417                                         : "__pseudo_gather_factored_base_offsets32_i32",
2418                  g->target->hasGather() ? "__pseudo_gather_base_offsets32_i32"
2419                                         : "__pseudo_gather_factored_base_offsets32_i32",
2420                  true, false),
2421         GSBOInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_float"
2422                                         : "__pseudo_gather_factored_base_offsets32_float",
2423                  g->target->hasGather() ? "__pseudo_gather_base_offsets32_float"
2424                                         : "__pseudo_gather_factored_base_offsets32_float",
2425                  true, false),
2426         GSBOInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i64"
2427                                         : "__pseudo_gather_factored_base_offsets32_i64",
2428                  g->target->hasGather() ? "__pseudo_gather_base_offsets32_i64"
2429                                         : "__pseudo_gather_factored_base_offsets32_i64",
2430                  true, false),
2431         GSBOInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_double"
2432                                         : "__pseudo_gather_factored_base_offsets32_double",
2433                  g->target->hasGather() ? "__pseudo_gather_base_offsets32_double"
2434                                         : "__pseudo_gather_factored_base_offsets32_double",
2435                  true, false),
2436 
2437         GSBOInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i8"
2438                                          : "__pseudo_scatter_factored_base_offsets32_i8",
2439                  g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i8"
2440                                          : "__pseudo_scatter_factored_base_offsets32_i8",
2441                  false, false),
2442         GSBOInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i16"
2443                                          : "__pseudo_scatter_factored_base_offsets32_i16",
2444                  g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i16"
2445                                          : "__pseudo_scatter_factored_base_offsets32_i16",
2446                  false, false),
2447         GSBOInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i32"
2448                                          : "__pseudo_scatter_factored_base_offsets32_i32",
2449                  g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i32"
2450                                          : "__pseudo_scatter_factored_base_offsets32_i32",
2451                  false, false),
2452         GSBOInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_float"
2453                                          : "__pseudo_scatter_factored_base_offsets32_float",
2454                  g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_float"
2455                                          : "__pseudo_scatter_factored_base_offsets32_float",
2456                  false, false),
2457         GSBOInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i64"
2458                                          : "__pseudo_scatter_factored_base_offsets32_i64",
2459                  g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i64"
2460                                          : "__pseudo_scatter_factored_base_offsets32_i64",
2461                  false, false),
2462         GSBOInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_double"
2463                                          : "__pseudo_scatter_factored_base_offsets32_double",
2464                  g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_double"
2465                                          : "__pseudo_scatter_factored_base_offsets32_double",
2466                  false, false),
2467 
2468         GSBOInfo(g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_1_native" : "__prefetch_read_varying_1",
2469                  g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_1_native" : "__prefetch_read_varying_1",
2470                  false, true),
2471 
2472         GSBOInfo(g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_2_native" : "__prefetch_read_varying_2",
2473                  g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_2_native" : "__prefetch_read_varying_2",
2474                  false, true),
2475 
2476         GSBOInfo(g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_3_native" : "__prefetch_read_varying_3",
2477                  g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_3_native" : "__prefetch_read_varying_3",
2478                  false, true),
2479 
2480         GSBOInfo(
2481             g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_nt_native" : "__prefetch_read_varying_nt",
2482             g->target->hasVecPrefetch() ? "__pseudo_prefetch_read_varying_nt_native" : "__prefetch_read_varying_nt",
2483             false, true),
2484     };
2485 
2486     int numGSFuncs = sizeof(gsFuncs) / sizeof(gsFuncs[0]);
2487     for (int i = 0; i < numGSFuncs; ++i)
2488         Assert(gsFuncs[i].baseOffsetsFunc != NULL && gsFuncs[i].baseOffsets32Func != NULL);
2489 
2490     llvm::Function *calledFunc = callInst->getCalledFunction();
2491     Assert(calledFunc != NULL);
2492 
2493     // Is one of the gather/scatter functions that decompose into
2494     // base+offsets being called?
2495     GSBOInfo *info = NULL;
2496     for (int i = 0; i < numGSFuncs; ++i)
2497         if (calledFunc == gsFuncs[i].baseOffsetsFunc || calledFunc == gsFuncs[i].baseOffsets32Func) {
2498             info = &gsFuncs[i];
2499             break;
2500         }
2501     if (info == NULL)
2502         return false;
2503 
2504     // Grab the old variable offset
2505     llvm::Value *origVariableOffset = callInst->getArgOperand(1);
2506 
2507     // If it's zero, we're done.  Don't go and think that we're clever by
2508     // adding these zeros to the constant offsets.
2509     if (llvm::isa<llvm::ConstantAggregateZero>(origVariableOffset))
2510         return false;
2511 
2512     // Try to decompose the old variable offset
2513     llvm::Value *constOffset = NULL;
2514     llvm::Value *variableOffset = NULL;
2515     lExtractConstantOffset(origVariableOffset, &constOffset, &variableOffset, callInst);
2516 
2517     // No luck
2518     if (constOffset == NULL)
2519         return false;
2520 
2521     // Total luck: everything could be moved to the constant offset
2522     if (variableOffset == NULL)
2523         variableOffset = LLVMIntAsType(0, origVariableOffset->getType());
2524 
2525     // We need to scale the value we add to the constant offset by the
2526     // 2/4/8 scale for the variable offset, if present.
2527     llvm::ConstantInt *varScale = llvm::dyn_cast<llvm::ConstantInt>(callInst->getArgOperand(2));
2528     Assert(varScale != NULL);
2529 
2530     llvm::Value *scaleSmear;
2531     if (origVariableOffset->getType() == LLVMTypes::Int64VectorType)
2532         scaleSmear = LLVMInt64Vector((int64_t)varScale->getZExtValue());
2533     else
2534         scaleSmear = LLVMInt32Vector((int32_t)varScale->getZExtValue());
2535 
2536     constOffset =
2537         llvm::BinaryOperator::Create(llvm::Instruction::Mul, constOffset, scaleSmear, constOffset->getName(), callInst);
2538 
2539     // And add the additional offset to the original constant offset
2540     constOffset = llvm::BinaryOperator::Create(llvm::Instruction::Add, constOffset, callInst->getArgOperand(3),
2541                                                callInst->getArgOperand(3)->getName(), callInst);
2542 
2543     // Finally, update the values of the operands to the gather/scatter
2544     // function.
2545     callInst->setArgOperand(1, variableOffset);
2546     callInst->setArgOperand(3, constOffset);
2547 
2548     return true;
2549 }
2550 
2551 static llvm::Value *lComputeCommonPointer(llvm::Value *base, llvm::Value *offsets, llvm::Instruction *insertBefore,
2552                                           int typeScale = 1) {
2553     llvm::Value *firstOffset = LLVMExtractFirstVectorElement(offsets);
2554     Assert(firstOffset != NULL);
2555     llvm::Value *typeScaleValue =
2556         firstOffset->getType() == LLVMTypes::Int32Type ? LLVMInt32(typeScale) : LLVMInt64(typeScale);
2557     if (g->target->isGenXTarget() && typeScale > 1) {
2558         firstOffset = llvm::BinaryOperator::Create(llvm::Instruction::SDiv, firstOffset, typeScaleValue,
2559                                                    "scaled_offset", insertBefore);
2560     }
2561 
2562     return lGEPInst(base, firstOffset, "ptr", insertBefore);
2563 }
2564 
2565 static llvm::Constant *lGetOffsetScaleVec(llvm::Value *offsetScale, llvm::Type *vecType) {
2566     llvm::ConstantInt *offsetScaleInt = llvm::dyn_cast<llvm::ConstantInt>(offsetScale);
2567     Assert(offsetScaleInt != NULL);
2568     uint64_t scaleValue = offsetScaleInt->getZExtValue();
2569 
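    // For example (hypothetical values): offsetScale = 4 with an 8-wide i32 vecType produces the
    // splat constant <8 x i32> <i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4, i32 4>.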
2570     std::vector<llvm::Constant *> scales;
2571     for (int i = 0; i < g->target->getVectorWidth(); ++i) {
2572         if (vecType == LLVMTypes::Int64VectorType)
2573             scales.push_back(LLVMInt64(scaleValue));
2574         else {
2575             Assert(vecType == LLVMTypes::Int32VectorType);
2576             scales.push_back(LLVMInt32((int32_t)scaleValue));
2577         }
2578     }
2579     return llvm::ConstantVector::get(scales);
2580 }
2581 
2582 /** After earlier optimization passes have run, we are sometimes able to
2583     determine that gathers/scatters are actually accessing memory in a more
2584     regular fashion and then change the operation to something simpler and
2585     more efficient.  For example, if all of the lanes in a gather are
2586     reading from the same location, we can instead do a scalar load and
2587     broadcast.  This pass examines gathers and scatters and tries to
2588     simplify them if at all possible.
2589 
2590     @todo Currently, this only looks for all program instances going to the
2591     same location and all going to a linear sequence of locations in
2592     memory.  There are a number of other cases that might make sense to
2593     look for, including things that could be handled with a vector load +
2594     shuffle or things that could be handled with hybrids of e.g. 2 4-wide
2595     vector loads with AVX, etc.
2596 */
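/* Illustrative sketch (hypothetical offsets, 4-wide gang): offsets = (24, 24, 24, 24) lets a
   gather become a scalar load from base+24 plus a broadcast across the lanes, while
   offsets = (0, 4, 8, 12) for a 32-bit element gather lets it become a single (possibly masked)
   vector load starting at base+0.
 */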
2597 static bool lGSToLoadStore(llvm::CallInst *callInst) {
2598     struct GatherImpInfo {
2599         GatherImpInfo(const char *pName, const char *lmName, const char *bmName, llvm::Type *st, int a)
2600             : align(a), isFactored(!g->target->hasGather()) {
2601             pseudoFunc = m->module->getFunction(pName);
2602             loadMaskedFunc = m->module->getFunction(lmName);
2603             blendMaskedFunc = m->module->getFunction(bmName);
2604             Assert(pseudoFunc != NULL && loadMaskedFunc != NULL);
2605             scalarType = st;
2606         }
2607 
2608         llvm::Function *pseudoFunc;
2609         llvm::Function *loadMaskedFunc;
2610         llvm::Function *blendMaskedFunc;
2611         llvm::Type *scalarType;
2612         const int align;
2613         const bool isFactored;
2614     };
2615 
2616     GatherImpInfo gInfo[] = {
2617         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i8"
2618                                              : "__pseudo_gather_factored_base_offsets32_i8",
2619                       "__masked_load_i8", "__masked_load_blend_i8", LLVMTypes::Int8Type, 1),
2620         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i16"
2621                                              : "__pseudo_gather_factored_base_offsets32_i16",
2622                       "__masked_load_i16", "__masked_load_blend_i16", LLVMTypes::Int16Type, 2),
2623         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i32"
2624                                              : "__pseudo_gather_factored_base_offsets32_i32",
2625                       "__masked_load_i32", "__masked_load_blend_i32", LLVMTypes::Int32Type, 4),
2626         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_float"
2627                                              : "__pseudo_gather_factored_base_offsets32_float",
2628                       "__masked_load_float", "__masked_load_blend_float", LLVMTypes::FloatType, 4),
2629         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_i64"
2630                                              : "__pseudo_gather_factored_base_offsets32_i64",
2631                       "__masked_load_i64", "__masked_load_blend_i64", LLVMTypes::Int64Type, 8),
2632         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets32_double"
2633                                              : "__pseudo_gather_factored_base_offsets32_double",
2634                       "__masked_load_double", "__masked_load_blend_double", LLVMTypes::DoubleType, 8),
2635         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets64_i8"
2636                                              : "__pseudo_gather_factored_base_offsets64_i8",
2637                       "__masked_load_i8", "__masked_load_blend_i8", LLVMTypes::Int8Type, 1),
2638         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets64_i16"
2639                                              : "__pseudo_gather_factored_base_offsets64_i16",
2640                       "__masked_load_i16", "__masked_load_blend_i16", LLVMTypes::Int16Type, 2),
2641         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets64_i32"
2642                                              : "__pseudo_gather_factored_base_offsets64_i32",
2643                       "__masked_load_i32", "__masked_load_blend_i32", LLVMTypes::Int32Type, 4),
2644         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets64_float"
2645                                              : "__pseudo_gather_factored_base_offsets64_float",
2646                       "__masked_load_float", "__masked_load_blend_float", LLVMTypes::FloatType, 4),
2647         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets64_i64"
2648                                              : "__pseudo_gather_factored_base_offsets64_i64",
2649                       "__masked_load_i64", "__masked_load_blend_i64", LLVMTypes::Int64Type, 8),
2650         GatherImpInfo(g->target->hasGather() ? "__pseudo_gather_base_offsets64_double"
2651                                              : "__pseudo_gather_factored_base_offsets64_double",
2652                       "__masked_load_double", "__masked_load_blend_double", LLVMTypes::DoubleType, 8),
2653     };
2654 
2655     struct ScatterImpInfo {
2656         ScatterImpInfo(const char *pName, const char *msName, llvm::Type *vpt, int a)
2657             : align(a), isFactored(!g->target->hasScatter()) {
2658             pseudoFunc = m->module->getFunction(pName);
2659             maskedStoreFunc = m->module->getFunction(msName);
2660             vecPtrType = vpt;
2661             Assert(pseudoFunc != NULL && maskedStoreFunc != NULL);
2662         }
2663         llvm::Function *pseudoFunc;
2664         llvm::Function *maskedStoreFunc;
2665         llvm::Type *vecPtrType;
2666         const int align;
2667         const bool isFactored;
2668     };
2669 
2670     ScatterImpInfo sInfo[] = {
2671         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i8"
2672                                                : "__pseudo_scatter_factored_base_offsets32_i8",
2673                        "__pseudo_masked_store_i8", LLVMTypes::Int8VectorPointerType, 1),
2674         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i16"
2675                                                : "__pseudo_scatter_factored_base_offsets32_i16",
2676                        "__pseudo_masked_store_i16", LLVMTypes::Int16VectorPointerType, 2),
2677         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i32"
2678                                                : "__pseudo_scatter_factored_base_offsets32_i32",
2679                        "__pseudo_masked_store_i32", LLVMTypes::Int32VectorPointerType, 4),
2680         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_float"
2681                                                : "__pseudo_scatter_factored_base_offsets32_float",
2682                        "__pseudo_masked_store_float", LLVMTypes::FloatVectorPointerType, 4),
2683         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_i64"
2684                                                : "__pseudo_scatter_factored_base_offsets32_i64",
2685                        "__pseudo_masked_store_i64", LLVMTypes::Int64VectorPointerType, 8),
2686         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets32_double"
2687                                                : "__pseudo_scatter_factored_base_offsets32_double",
2688                        "__pseudo_masked_store_double", LLVMTypes::DoubleVectorPointerType, 8),
2689         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i8"
2690                                                : "__pseudo_scatter_factored_base_offsets64_i8",
2691                        "__pseudo_masked_store_i8", LLVMTypes::Int8VectorPointerType, 1),
2692         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i16"
2693                                                : "__pseudo_scatter_factored_base_offsets64_i16",
2694                        "__pseudo_masked_store_i16", LLVMTypes::Int16VectorPointerType, 2),
2695         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i32"
2696                                                : "__pseudo_scatter_factored_base_offsets64_i32",
2697                        "__pseudo_masked_store_i32", LLVMTypes::Int32VectorPointerType, 4),
2698         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_float"
2699                                                : "__pseudo_scatter_factored_base_offsets64_float",
2700                        "__pseudo_masked_store_float", LLVMTypes::FloatVectorPointerType, 4),
2701         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_i64"
2702                                                : "__pseudo_scatter_factored_base_offsets64_i64",
2703                        "__pseudo_masked_store_i64", LLVMTypes::Int64VectorPointerType, 8),
2704         ScatterImpInfo(g->target->hasScatter() ? "__pseudo_scatter_base_offsets64_double"
2705                                                : "__pseudo_scatter_factored_base_offsets64_double",
2706                        "__pseudo_masked_store_double", LLVMTypes::DoubleVectorPointerType, 8),
2707     };
2708 
2709     llvm::Function *calledFunc = callInst->getCalledFunction();
2710 
2711     GatherImpInfo *gatherInfo = NULL;
2712     ScatterImpInfo *scatterInfo = NULL;
2713     for (unsigned int i = 0; i < sizeof(gInfo) / sizeof(gInfo[0]); ++i) {
2714         if (gInfo[i].pseudoFunc != NULL && calledFunc == gInfo[i].pseudoFunc) {
2715             gatherInfo = &gInfo[i];
2716             break;
2717         }
2718     }
2719     for (unsigned int i = 0; i < sizeof(sInfo) / sizeof(sInfo[0]); ++i) {
2720         if (sInfo[i].pseudoFunc != NULL && calledFunc == sInfo[i].pseudoFunc) {
2721             scatterInfo = &sInfo[i];
2722             break;
2723         }
2724     }
2725     if (gatherInfo == NULL && scatterInfo == NULL)
2726         return false;
2727 
2728     SourcePos pos;
2729     lGetSourcePosFromMetadata(callInst, &pos);
2730 
2731     llvm::Value *base = callInst->getArgOperand(0);
2732     llvm::Value *fullOffsets = NULL;
2733     llvm::Value *storeValue = NULL;
2734     llvm::Value *mask = NULL;
2735     if ((gatherInfo != NULL && gatherInfo->isFactored) || (scatterInfo != NULL && scatterInfo->isFactored)) {
2736         llvm::Value *varyingOffsets = callInst->getArgOperand(1);
2737         llvm::Value *offsetScale = callInst->getArgOperand(2);
2738         llvm::Value *constOffsets = callInst->getArgOperand(3);
2739         if (scatterInfo)
2740             storeValue = callInst->getArgOperand(4);
2741         mask = callInst->getArgOperand((gatherInfo != NULL) ? 4 : 5);
2742 
2743         // Compute the full offset vector: offsetScale * varyingOffsets + constOffsets
2744         llvm::Constant *offsetScaleVec = lGetOffsetScaleVec(offsetScale, varyingOffsets->getType());
2745 
2746         llvm::Value *scaledVarying = llvm::BinaryOperator::Create(llvm::Instruction::Mul, offsetScaleVec,
2747                                                                   varyingOffsets, "scaled_varying", callInst);
2748         fullOffsets = llvm::BinaryOperator::Create(llvm::Instruction::Add, scaledVarying, constOffsets,
2749                                                    "varying+const_offsets", callInst);
2750     } else {
2751         if (scatterInfo)
2752             storeValue = callInst->getArgOperand(3);
2753         mask = callInst->getArgOperand((gatherInfo != NULL) ? 3 : 4);
2754 
2755         llvm::Value *offsetScale = callInst->getArgOperand(1);
2756         llvm::Value *offsets = callInst->getArgOperand(2);
2757         llvm::Value *offsetScaleVec = lGetOffsetScaleVec(offsetScale, offsets->getType());
2758 
2759         fullOffsets =
2760             llvm::BinaryOperator::Create(llvm::Instruction::Mul, offsetScaleVec, offsets, "scaled_offsets", callInst);
2761     }
2762 
2763     Debug(SourcePos(), "GSToLoadStore: %s.", fullOffsets->getName().str().c_str());
2764     llvm::Type *scalarType = (gatherInfo != NULL) ? gatherInfo->scalarType : scatterInfo->vecPtrType->getScalarType();
2765     int typeScale = g->target->getDataLayout()->getTypeStoreSize(scalarType) /
2766                     g->target->getDataLayout()->getTypeStoreSize(base->getType()->getContainedType(0));
2767 
2768     if (LLVMVectorValuesAllEqual(fullOffsets)) {
2769         // If all the offsets are equal, then compute the single
2770         // pointer they all represent based on the first one of them
2771         // (arbitrarily).
2772         if (gatherInfo != NULL) {
2773             // A gather with everyone going to the same location is
2774             // handled as a scalar load and broadcast across the lanes.
2775             Debug(pos, "Transformed gather to scalar load and broadcast!");
2776             llvm::Value *ptr;
2777             // For genx targets we need to cast the base pointer first and only then compute the common pointer;
2778             // otherwise the CM backend breaks on a bitcast from i8* to T* followed by a load.
2779             // To do this, we recalculate the offset based on the type sizes.
2780             if (g->target->isGenXTarget()) {
2781                 base = new llvm::BitCastInst(base, llvm::PointerType::get(scalarType, 0), base->getName(), callInst);
2782                 ptr = lComputeCommonPointer(base, fullOffsets, callInst, typeScale);
2783             } else {
2784                 ptr = lComputeCommonPointer(base, fullOffsets, callInst);
2785                 ptr = new llvm::BitCastInst(ptr, llvm::PointerType::get(scalarType, 0), base->getName(), callInst);
2786             }
2787 
2788             lCopyMetadata(ptr, callInst);
2789 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
2790             Assert(llvm::isa<llvm::PointerType>(ptr->getType()));
2791             llvm::Value *scalarValue =
2792                 new llvm::LoadInst(llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getPointerElementType(), ptr,
2793                                    callInst->getName(), callInst);
2794 #else
2795             llvm::Value *scalarValue = new llvm::LoadInst(ptr, callInst->getName(), callInst);
2796 #endif
2797 
2798             // Generate the following sequence:
2799             //   %name123 = insertelement <4 x i32> undef, i32 %val, i32 0
2800             //   %name124 = shufflevector <4 x i32> %name123, <4 x i32> undef,
2801             //                                              <4 x i32> zeroinitializer
2802             llvm::Value *undef1Value = llvm::UndefValue::get(callInst->getType());
2803             llvm::Value *undef2Value = llvm::UndefValue::get(callInst->getType());
2804             llvm::Value *insertVec =
2805                 llvm::InsertElementInst::Create(undef1Value, scalarValue, LLVMInt32(0), callInst->getName(), callInst);
2806             llvm::Value *zeroMask =
2807 #if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
2808                 llvm::ConstantVector::getSplat(callInst->getType()->getVectorNumElements(),
2809 #elif ISPC_LLVM_VERSION < ISPC_LLVM_12_0
2810                 llvm::ConstantVector::getSplat(
2811                     {llvm::dyn_cast<llvm::VectorType>(callInst->getType())->getNumElements(), false},
2812 
2813 #else // LLVM 12.0+
2814                 llvm::ConstantVector::getSplat(
2815                     llvm::ElementCount::get(
2816                         llvm::dyn_cast<llvm::FixedVectorType>(callInst->getType())->getNumElements(), false),
2817 #endif
2818                                                llvm::Constant::getNullValue(llvm::Type::getInt32Ty(*g->ctx)));
2819             llvm::Value *shufValue = new llvm::ShuffleVectorInst(insertVec, undef2Value, zeroMask, callInst->getName());
2820 
2821             lCopyMetadata(shufValue, callInst);
2822             llvm::ReplaceInstWithInst(callInst, llvm::dyn_cast<llvm::Instruction>(shufValue));
2823             return true;
2824         } else {
2825             // A scatter with everyone going to the same location is
2826             // undefined (if there's more than one program instance in
2827             // the gang).  Issue a warning.
2828             if (g->target->getVectorWidth() > 1)
2829                 Warning(pos, "Undefined behavior: all program instances are "
2830                              "writing to the same location!");
2831 
2832             // We could do something similar to the gather case, where
2833             // we arbitrarily write one of the values, but we would need to
2834             // a) check to be sure the mask isn't all off and b) pick
2835             // the value from an executing program instance in that
2836             // case.  We'll just let a bunch of the program instances
2837             // do redundant writes, since this isn't important to make
2838             // fast anyway...
2839             return false;
2840         }
2841     } else {
2842         int step = gatherInfo ? gatherInfo->align : scatterInfo->align;
2843         if (step > 0 && LLVMVectorIsLinear(fullOffsets, step)) {
2844             // We have a linear sequence of memory locations being accessed
2845             // starting with the location given by the offset from
2846             // offsetElements[0], with a stride of 4 or 8 bytes (for 32-bit
2847             // and 64-bit gather/scatters, respectively).
2848             llvm::Value *ptr;
2849 
2850             if (gatherInfo != NULL) {
2851                 if (g->target->isGenXTarget()) {
2852                     // For genx targets we need to cast the base pointer first and only then compute the common
2853                     // pointer; otherwise the CM backend breaks on a bitcast from i8* to T* followed by a load.
2854                     // To do this, we recalculate the offset based on the type sizes.
2855                     // The second bitcast to void* does not cause such a problem in the backend.
2856                     base =
2857                         new llvm::BitCastInst(base, llvm::PointerType::get(scalarType, 0), base->getName(), callInst);
2858                     ptr = lComputeCommonPointer(base, fullOffsets, callInst, typeScale);
2859                     ptr = new llvm::BitCastInst(ptr, LLVMTypes::Int8PointerType, base->getName(), callInst);
2860                 } else {
2861                     ptr = lComputeCommonPointer(base, fullOffsets, callInst);
2862                 }
2863                 lCopyMetadata(ptr, callInst);
2864                 Debug(pos, "Transformed gather to unaligned vector load!");
2865                 bool doBlendLoad = false;
2866 #ifdef ISPC_GENX_ENABLED
2867                 doBlendLoad = g->target->isGenXTarget() && g->opt.enableGenXUnsafeMaskedLoad;
2868 #endif
2869                 llvm::Instruction *newCall =
2870                     lCallInst(doBlendLoad ? gatherInfo->blendMaskedFunc : gatherInfo->loadMaskedFunc, ptr, mask,
2871                               llvm::Twine(ptr->getName()) + "_masked_load");
2872                 lCopyMetadata(newCall, callInst);
2873                 llvm::ReplaceInstWithInst(callInst, newCall);
2874                 return true;
2875             } else {
2876                 Debug(pos, "Transformed scatter to unaligned vector store!");
2877                 ptr = lComputeCommonPointer(base, fullOffsets, callInst);
2878                 ptr = new llvm::BitCastInst(ptr, scatterInfo->vecPtrType, "ptrcast", callInst);
2879                 llvm::Instruction *newCall = lCallInst(scatterInfo->maskedStoreFunc, ptr, storeValue, mask, "");
2880                 lCopyMetadata(newCall, callInst);
2881                 llvm::ReplaceInstWithInst(callInst, newCall);
2882                 return true;
2883             }
2884         }
2885         return false;
2886     }
2887 }
2888 
2889 ///////////////////////////////////////////////////////////////////////////
2890 // MaskedStoreOptPass
2891 
2892 #ifdef ISPC_GENX_ENABLED
2893 static llvm::Function *lGenXMaskedInt8Inst(llvm::Instruction *inst, bool isStore) {
2894     std::string maskedFuncName;
2895     if (isStore) {
2896         maskedFuncName = "masked_store_i8";
2897     } else {
2898         maskedFuncName = "masked_load_i8";
2899     }
2900     llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(inst);
2901     if (callInst != NULL && callInst->getCalledFunction()->getName().contains(maskedFuncName)) {
2902         return NULL;
2903     }
2904     return m->module->getFunction("__" + maskedFuncName);
2905 }
2906 
2907 static llvm::CallInst *lGenXStoreInst(llvm::Value *val, llvm::Value *ptr, llvm::Instruction *inst) {
2908     Assert(g->target->isGenXTarget());
2909 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
2910     Assert(llvm::isa<llvm::FixedVectorType>(val->getType()));
2911     llvm::FixedVectorType *valVecType = llvm::dyn_cast<llvm::FixedVectorType>(val->getType());
2912 #else
2913     Assert(llvm::isa<llvm::VectorType>(val->getType()));
2914     llvm::VectorType *valVecType = llvm::dyn_cast<llvm::VectorType>(val->getType());
2915 #endif
2916     Assert(llvm::isPowerOf2_32(valVecType->getNumElements()));
2917     Assert(valVecType->getPrimitiveSizeInBits() / 8 <= 8 * OWORD);
2918 
2919     // The data write of an svm store must have a size that is a power of two from 16 to 128
2920     // bytes. However, for the int8 type and simd width = 8, the data write size is only 8.
2921     // So we use the masked store function here instead of svm store, since it handles the
2922     // int8 type correctly.
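    // For example (hypothetical widths): an <8 x i8> value is only 8 bytes, so it is stored via
    // __masked_store_i8 with an all-on mask, whereas an <8 x i32> value is 32 bytes and can use
    // the genx_svm_block_st intrinsic directly.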
2923     if (valVecType->getPrimitiveSizeInBits() / 8 < 16) {
2924         Assert(valVecType->getScalarType() == LLVMTypes::Int8Type);
2925         if (llvm::Function *maskedFunc = lGenXMaskedInt8Inst(inst, true))
2926             return llvm::dyn_cast<llvm::CallInst>(lCallInst(maskedFunc, ptr, val, LLVMMaskAllOn, ""));
2927         else {
2928             return NULL;
2929         }
2930     }
2931     llvm::Instruction *svm_st_zext = new llvm::PtrToIntInst(ptr, LLVMTypes::Int64Type, "svm_st_ptrtoint", inst);
2932 
2933     llvm::Type *argTypes[] = {svm_st_zext->getType(), val->getType()};
2934     auto Fn = llvm::GenXIntrinsic::getGenXDeclaration(m->module, llvm::GenXIntrinsic::genx_svm_block_st, argTypes);
2935     return llvm::CallInst::Create(Fn, {svm_st_zext, val}, inst->getName());
2936 }
2937 
2938 static llvm::CallInst *lGenXLoadInst(llvm::Value *ptr, llvm::Type *retType, llvm::Instruction *inst) {
2939 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
2940     Assert(llvm::isa<llvm::FixedVectorType>(retType));
2941     llvm::FixedVectorType *retVecType = llvm::dyn_cast<llvm::FixedVectorType>(retType);
2942 #else
2943     Assert(llvm::isa<llvm::VectorType>(retType));
2944     llvm::VectorType *retVecType = llvm::dyn_cast<llvm::VectorType>(retType);
2945 #endif
2946     Assert(llvm::isPowerOf2_32(retVecType->getNumElements()));
2947     Assert(retVecType->getPrimitiveSizeInBits());
2948     Assert(retVecType->getPrimitiveSizeInBits() / 8 <= 8 * OWORD);
2949     // The data read of an svm load must have a size that is a power of two from 16 to 128
2950     // bytes. However, for the int8 type and simd width = 8, the data read size is only 8.
2951     // So we use the masked load function here instead of svm load, since it handles the
2952     // int8 type correctly.
2953     if (retVecType->getPrimitiveSizeInBits() / 8 < 16) {
2954         Assert(retVecType->getScalarType() == LLVMTypes::Int8Type);
2955         if (llvm::Function *maskedFunc = lGenXMaskedInt8Inst(inst, false))
2956             return llvm::dyn_cast<llvm::CallInst>(lCallInst(maskedFunc, ptr, LLVMMaskAllOn, ""));
2957         else {
2958             return NULL;
2959         }
2960     }
2961 
2962     llvm::Value *svm_ld_ptrtoint = new llvm::PtrToIntInst(ptr, LLVMTypes::Int64Type, "svm_ld_ptrtoint", inst);
2963 
2964     auto Fn = llvm::GenXIntrinsic::getGenXDeclaration(m->module, llvm::GenXIntrinsic::genx_svm_block_ld_unaligned,
2965                                                       {retType, svm_ld_ptrtoint->getType()});
2966 
2967     return llvm::CallInst::Create(Fn, svm_ld_ptrtoint, inst->getName());
2968 }
2969 #endif
2970 /** Masked stores are generally more complex than regular stores; for
2971     example, they require multiple instructions to simulate under SSE.
2972     This optimization detects cases where masked stores can be replaced
2973     with regular stores or removed entirely, for the cases of an 'all on'
2974     mask and an 'all off' mask, respectively.
2975 */
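/* Illustrative sketch (hypothetical mask values): for
     __pseudo_masked_store_i32(ptr, value, mask)
   a compile-time all-off mask lets the call be erased entirely, while an all-on mask lets it be
   rewritten as a plain store of 'value' to 'ptr' with the appropriate alignment.
 */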
2976 static bool lImproveMaskedStore(llvm::CallInst *callInst) {
2977     struct MSInfo {
2978         MSInfo(const char *name, const int a) : align(a) {
2979             func = m->module->getFunction(name);
2980             Assert(func != NULL);
2981         }
2982         llvm::Function *func;
2983         const int align;
2984     };
2985 
2986     MSInfo msInfo[] = {MSInfo("__pseudo_masked_store_i8", 1),  MSInfo("__pseudo_masked_store_i16", 2),
2987                        MSInfo("__pseudo_masked_store_i32", 4), MSInfo("__pseudo_masked_store_float", 4),
2988                        MSInfo("__pseudo_masked_store_i64", 8), MSInfo("__pseudo_masked_store_double", 8),
2989                        MSInfo("__masked_store_blend_i8", 1),   MSInfo("__masked_store_blend_i16", 2),
2990                        MSInfo("__masked_store_blend_i32", 4),  MSInfo("__masked_store_blend_float", 4),
2991                        MSInfo("__masked_store_blend_i64", 8),  MSInfo("__masked_store_blend_double", 8),
2992                        MSInfo("__masked_store_i8", 1),         MSInfo("__masked_store_i16", 2),
2993                        MSInfo("__masked_store_i32", 4),        MSInfo("__masked_store_float", 4),
2994                        MSInfo("__masked_store_i64", 8),        MSInfo("__masked_store_double", 8)};
2995     llvm::Function *called = callInst->getCalledFunction();
2996 
2997     int nMSFuncs = sizeof(msInfo) / sizeof(msInfo[0]);
2998     MSInfo *info = NULL;
2999     for (int i = 0; i < nMSFuncs; ++i) {
3000         if (msInfo[i].func != NULL && called == msInfo[i].func) {
3001             info = &msInfo[i];
3002             break;
3003         }
3004     }
3005     if (info == NULL)
3006         return false;
3007 
3008     // Got one; grab the operands
3009     llvm::Value *lvalue = callInst->getArgOperand(0);
3010     llvm::Value *rvalue = callInst->getArgOperand(1);
3011     llvm::Value *mask = callInst->getArgOperand(2);
3012 
3013     MaskStatus maskStatus = lGetMaskStatus(mask);
3014     if (maskStatus == MaskStatus::all_off) {
3015         // Zero mask - no-op, so remove the store completely.  (This
3016         // may in turn lead to being able to optimize out instructions
3017         // that compute the rvalue...)
3018         callInst->eraseFromParent();
3019         return true;
3020     } else if (maskStatus == MaskStatus::all_on) {
3021         // The mask is all on, so turn this into a regular store
3022         llvm::Type *rvalueType = rvalue->getType();
3023         llvm::Instruction *store = NULL;
3024 #ifdef ISPC_GENX_ENABLED
3025         // The InternalLinkage check prevents generation of an SVM store when the pointer came from the caller.
3026         // Since it may be allocated in the caller, it may live in a register; any possible svm store
3027         // is resolved after inlining. TODO: problems may arise here in the case of stack calls.
3028         if (g->target->isGenXTarget() && GetAddressSpace(lvalue) == AddressSpace::External &&
3029             callInst->getParent()->getParent()->getLinkage() != llvm::GlobalValue::LinkageTypes::InternalLinkage) {
3030             store = lGenXStoreInst(rvalue, lvalue, callInst);
3031         } else if (!g->target->isGenXTarget() ||
3032                    (g->target->isGenXTarget() && GetAddressSpace(lvalue) == AddressSpace::Local))
3033 #endif
3034         {
3035             llvm::Type *ptrType = llvm::PointerType::get(rvalueType, 0);
3036 
3037             lvalue = new llvm::BitCastInst(lvalue, ptrType, "lvalue_to_ptr_type", callInst);
3038             lCopyMetadata(lvalue, callInst);
3039             store = new llvm::StoreInst(
3040                 rvalue, lvalue, false /* not volatile */,
3041                 llvm::MaybeAlign(g->opt.forceAlignedMemory ? g->target->getNativeVectorAlignment() : info->align)
3042                     .valueOrOne());
3043         }
3044         if (store != NULL) {
3045             lCopyMetadata(store, callInst);
3046             llvm::ReplaceInstWithInst(callInst, store);
3047             return true;
3048         }
3049 #ifdef ISPC_GENX_ENABLED
3050     } else {
3051         if (g->target->isGenXTarget() && GetAddressSpace(lvalue) == AddressSpace::External) {
3052             // In this case we use masked_store, which on the genx target results in a scatter.
3053             // Get the source position from the metadata attached to the call
3054             // instruction so that we can issue PerformanceWarning()s below.
3055             SourcePos pos;
3056             bool gotPosition = lGetSourcePosFromMetadata(callInst, &pos);
3057             if (gotPosition) {
3058                 PerformanceWarning(pos, "Scatter required to store value.");
3059             }
3060         }
3061 #endif
3062     }
3063     return false;
3064 }
3065 
3066 static bool lImproveMaskedLoad(llvm::CallInst *callInst, llvm::BasicBlock::iterator iter) {
3067     struct MLInfo {
3068         MLInfo(const char *name, const int a) : align(a) {
3069             func = m->module->getFunction(name);
3070             Assert(func != NULL);
3071         }
3072         llvm::Function *func;
3073         const int align;
3074     };
3075 
3076     llvm::Function *called = callInst->getCalledFunction();
3077     // TODO: we should use a dynamic data structure for MLInfo and fill
3078     // it differently for GenX and CPU targets. That would also help
3079     // avoid declaring GenX intrinsics for CPU targets.
3080     // The change should be made consistently here and in all similar places in this file.
3081     MLInfo mlInfo[] = {MLInfo("__masked_load_i8", 1),  MLInfo("__masked_load_i16", 2),
3082                        MLInfo("__masked_load_i32", 4), MLInfo("__masked_load_float", 4),
3083                        MLInfo("__masked_load_i64", 8), MLInfo("__masked_load_double", 8)};
3084     MLInfo genxInfo[] = {MLInfo("__masked_load_i8", 1),        MLInfo("__masked_load_i16", 2),
3085                          MLInfo("__masked_load_i32", 4),       MLInfo("__masked_load_float", 4),
3086                          MLInfo("__masked_load_i64", 8),       MLInfo("__masked_load_double", 8),
3087                          MLInfo("__masked_load_blend_i8", 1),  MLInfo("__masked_load_blend_i16", 2),
3088                          MLInfo("__masked_load_blend_i32", 4), MLInfo("__masked_load_blend_float", 4),
3089                          MLInfo("__masked_load_blend_i64", 8), MLInfo("__masked_load_blend_double", 8)};
3090     MLInfo *info = NULL;
3091     if (g->target->isGenXTarget()) {
3092         int nFuncs = sizeof(genxInfo) / sizeof(genxInfo[0]);
3093         for (int i = 0; i < nFuncs; ++i) {
3094             if (genxInfo[i].func != NULL && called == genxInfo[i].func) {
3095                 info = &genxInfo[i];
3096                 break;
3097             }
3098         }
3099     } else {
3100         int nFuncs = sizeof(mlInfo) / sizeof(mlInfo[0]);
3101         for (int i = 0; i < nFuncs; ++i) {
3102             if (mlInfo[i].func != NULL && called == mlInfo[i].func) {
3103                 info = &mlInfo[i];
3104                 break;
3105             }
3106         }
3107     }
3108     if (info == NULL)
3109         return false;
3110 
3111     // Got one; grab the operands
3112     llvm::Value *ptr = callInst->getArgOperand(0);
3113     llvm::Value *mask = callInst->getArgOperand(1);
3114 
3115     MaskStatus maskStatus = lGetMaskStatus(mask);
3116     if (maskStatus == MaskStatus::all_off) {
3117         // Zero mask - no-op, so replace the load with an undef value
3118         llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), iter, llvm::UndefValue::get(callInst->getType()));
3119         return true;
3120     } else if (maskStatus == MaskStatus::all_on) {
3121         // The mask is all on, so turn this into a regular load
3122         llvm::Instruction *load = NULL;
3123 #ifdef ISPC_GENX_ENABLED
3124         // The InternalLinkage check prevents generation of an SVM load when the pointer came from the caller.
3125         // Since it may be allocated in the caller, it may live in a register; any possible svm load
3126         // is resolved after inlining. TODO: problems may arise here in the case of stack calls.
3127         if (g->target->isGenXTarget() && GetAddressSpace(ptr) == AddressSpace::External &&
3128             callInst->getParent()->getParent()->getLinkage() != llvm::GlobalValue::LinkageTypes::InternalLinkage) {
3129             load = lGenXLoadInst(ptr, callInst->getType(), callInst);
3130         } else if (!g->target->isGenXTarget() ||
3131                    (g->target->isGenXTarget() && GetAddressSpace(ptr) == AddressSpace::Local))
3132 #endif
3133         {
3134             llvm::Type *ptrType = llvm::PointerType::get(callInst->getType(), 0);
3135             ptr = new llvm::BitCastInst(ptr, ptrType, "ptr_cast_for_load", callInst);
3136 #if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
3137             load = new llvm::LoadInst(
3138                 ptr, callInst->getName(), false /* not volatile */,
3139                 llvm::MaybeAlign(g->opt.forceAlignedMemory ? g->target->getNativeVectorAlignment() : info->align)
3140                     .valueOrOne(),
3141                 (llvm::Instruction *)NULL);
3142 #else // LLVM 11.0+
3143             Assert(llvm::isa<llvm::PointerType>(ptr->getType()));
3144             load = new llvm::LoadInst(
3145                 llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getPointerElementType(), ptr, callInst->getName(),
3146                 false /* not volatile */,
3147                 llvm::MaybeAlign(g->opt.forceAlignedMemory ? g->target->getNativeVectorAlignment() : info->align)
3148                     .valueOrOne(),
3149                 (llvm::Instruction *)NULL);
3150 #endif
3151         }
3152         if (load != NULL) {
3153             lCopyMetadata(load, callInst);
3154             llvm::ReplaceInstWithInst(callInst, load);
3155             return true;
3156         }
3157     }
3158     return false;
3159 }
3160 
3161 bool ImproveMemoryOpsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
3162     DEBUG_START_PASS("ImproveMemoryOps");
3163 
3164     bool modifiedAny = false;
3165 restart:
3166     // Iterate through all of the instructions in the basic block.
3167     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
3168         llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
3169         // If we don't have a call to one of the
3170         // __pseudo_{gather,scatter}_* functions, then just go on to the
3171         // next instruction.
3172         if (callInst == NULL || callInst->getCalledFunction() == NULL)
3173             continue;
3174         if (lGSToGSBaseOffsets(callInst)) {
3175             modifiedAny = true;
3176             goto restart;
3177         }
3178         if (lGSBaseOffsetsGetMoreConst(callInst)) {
3179             modifiedAny = true;
3180             goto restart;
3181         }
3182         if (lGSToLoadStore(callInst)) {
3183             modifiedAny = true;
3184             goto restart;
3185         }
3186         if (lImproveMaskedStore(callInst)) {
3187             modifiedAny = true;
3188             goto restart;
3189         }
3190         if (lImproveMaskedLoad(callInst, iter)) {
3191             modifiedAny = true;
3192             goto restart;
3193         }
3194     }
3195 
3196     DEBUG_END_PASS("ImproveMemoryOps");
3197 
3198     return modifiedAny;
3199 }
3200 
3201 bool ImproveMemoryOpsPass::runOnFunction(llvm::Function &F) {
3202 
3203     llvm::TimeTraceScope FuncScope("ImproveMemoryOpsPass::runOnFunction", F.getName());
3204     bool modifiedAny = false;
3205     for (llvm::BasicBlock &BB : F) {
3206         modifiedAny |= runOnBasicBlock(BB);
3207     }
3208     return modifiedAny;
3209 }
3210 
3211 static llvm::Pass *CreateImproveMemoryOpsPass() { return new ImproveMemoryOpsPass; }
3212 
3213 ///////////////////////////////////////////////////////////////////////////
3214 // GatherCoalescePass
3215 
3216 // This pass implements two optimizations to improve the performance of
3217 // gathers; currently only gathers of 32-bit values where it can be
3218 // determined at compile time that the mask is all on are supported, though
3219 // both of those limitations may be generalized in the future.
3220 //
3221 //  First, for any single gather, see if it's worthwhile to break it into
3222 //  any of scalar, 2-wide (i.e. 64-bit), 4-wide, or 8-wide loads.  Further,
3223 //  we generate code that shuffles the loaded values into place.  Doing fewer, larger
3224 //  loads in this manner, when possible, can be more efficient.
3225 //
3226 //  Second, this pass can coalesce memory accesses across multiple
3227 //  gathers. If we have a series of gathers without any memory writes in
3228 //  the middle, then we try to analyze their reads collectively and choose
3229 //  an efficient set of loads for them.  Not only does this help if
3230 //  different gathers reuse values from the same location in memory, but
3231 //  it's specifically helpful when data with AOS layout is being accessed;
3232 //  in this case, we're often able to generate wide vector loads and
3233 //  appropriate shuffles automatically.
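//
//  As a purely illustrative example (not part of the original comments): a
//  group of gathers that read the x, y, and z fields of an array of
//  struct { float x, y, z; } using the same varying offset can often be
//  served by a few wide contiguous loads plus shuffles instead of three
//  separate gathers.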
3234 
3235 class GatherCoalescePass : public llvm::FunctionPass {
3236   public:
3237     static char ID;
3238     GatherCoalescePass() : FunctionPass(ID) {}
3239 
3240     llvm::StringRef getPassName() const { return "Gather Coalescing"; }
3241     bool runOnBasicBlock(llvm::BasicBlock &BB);
3242     bool runOnFunction(llvm::Function &F);
3243 };
3244 
3245 char GatherCoalescePass::ID = 0;
3246 
3247 /** Representation of a memory load that the gather coalescing code has
3248     decided to generate.
3249  */
3250 struct CoalescedLoadOp {
3251     CoalescedLoadOp(int64_t s, int c) {
3252         start = s;
3253         count = c;
3254         load = element0 = element1 = NULL;
3255     }
3256 
3257     /** Starting offset of the load from the common base pointer (in terms
3258         of numbers of items of the underlying element type--*not* in terms
3259         of bytes). */
3260     int64_t start;
3261 
3262     /** Number of elements to load at this location */
3263     int count;
3264 
3265     /** Value loaded from memory for this load op */
3266     llvm::Value *load;
3267 
3268     /** For 2-wide loads (i.e. 64-bit loads), these store the lower and
3269         upper 32 bits of the result, respectively. */
3270     llvm::Value *element0, *element1;
3271 };
3272 
3273 /** This function determines whether it makes sense (and is safe) to
3274     generate a vector load of width vectorWidth, starting at *iter.  It
3275     returns true if so, setting *newIter to point to the next element in
3276     the set that isn't taken care of by the generated load.  If a vector
3277     load of the given width doesn't make sense, then false is returned.
3278  */
3279 static bool lVectorLoadIsEfficient(std::set<int64_t>::iterator iter, std::set<int64_t>::iterator end,
3280                                    std::set<int64_t>::iterator *newIter, int vectorWidth) {
3281     // We're considering a vector load of width vectorWidth, starting at
3282     // the offset "start".
3283     int64_t start = *iter;
3284 
3285     // The basic idea is that we'll look at the subsequent elements in the
3286     // load set after the initial one at start.  As long as subsequent
3287     // elements:
3288     //
3289     // 1. Aren't so far separated that they no longer fit into the range
3290     //    [start, start+vectorWidth)
3291     //
3292     // 2. And don't have too large a gap in between them (e.g., it's not
3293     //    worth generating an 8-wide load for two elements with offsets 0
3294     //    and 7, but no loads requested in between).
3295     //
3296     // Then we continue moving forward through the elements until we either
3297     // fill up the vector or run out of elements.
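    //
    // Illustrative example (assumed for exposition, not from the original
    // comments): for the offset set {0, 1, 3, 9} with vectorWidth == 4,
    // the offsets 0, 1, and 3 all fall in [0, 4) with gaps of at most 3,
    // so a 4-wide load starting at offset 0 is accepted and *newIter is
    // left pointing at offset 9.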
3298 
3299     // lastAccepted holds the last offset we've processed and accepted as
3300     // valid for the vector load under consideration
3301     int64_t lastAccepted = start;
3302 
3303     while (iter != end) {
3304         // What is the separation in offset values from the last element we
3305         // added to the set for this load?
3306         int64_t delta = *iter - lastAccepted;
3307         if (delta > 3)
3308             // If there's too big a gap, then we won't issue the load
3309             return false;
3310 
3311         int64_t span = *iter - start + 1;
3312 
3313         if (span == vectorWidth) {
3314             // We've extended far enough that we have exactly filled up the
3315             // entire vector width; we can't go any further, so return with
3316             // success.  (Update *newIter to point at the next element
3317             // after the last one accepted here.)
3318             *newIter = ++iter;
3319             return true;
3320         } else if (span > vectorWidth) {
3321             // The current offset won't fit into a vectorWidth-wide load
3322             // starting from start.  It's still generally worthwhile
3323             // issuing the load we've been considering, though, since it
3324             // will provide values for a number of previous offsets.  This
3325             // load will have one or more elements at the end of its range
3326             //    that are not needed by any of the offsets under
3327             // consideration.  As such, there are three cases where issuing
3328             // this load is a bad idea:
3329             //
3330             // 1. 2-wide loads: we know that we haven't completely filled
3331             //    the 2-wide vector, since otherwise the if() test above
3332             //    would have succeeded previously.  Therefore, we must have
3333             //    a situation with offsets like (4,6,...); it would be a
3334             //    silly idea to issue a 2-wide load to get the value for
3335             //    the 4 offset, versus failing here and issuing a scalar
3336             //    load instead.
3337             //
3338             // 2. If there are too many unnecessary values at the end of
3339             //    the load extent (defined as more than half of them)--in
3340             //    this case, it'd be better to issue a vector load of
3341             //    smaller width anyway.
3342             //
3343             // 3. If the gap between the last accepted offset and the
3344             //    current one under consideration is more than the page
3345             //    size.  In this case we can't be sure whether or not some
3346             //    of the unused elements at the end of the load will
3347             //    straddle a page boundary and thus lead to an undesirable
3348             //    fault.  (It's hard to imagine this happening in practice,
3349             //    except under contrived circumstances, but better safe
3350             //    than sorry.)
3351             const int pageSize = 4096;
3352             if (vectorWidth != 2 && (lastAccepted - start) > (vectorWidth / 2) && (*iter - lastAccepted) < pageSize) {
3353                 *newIter = iter;
3354                 return true;
3355             } else
3356                 return false;
3357         }
3358 
3359         // Continue moving forward
3360         lastAccepted = *iter;
3361         ++iter;
3362     }
3363 
3364     return false;
3365 }
3366 
3367 /** Given a set of offsets from a common base pointer that we need to get
3368     loaded into memory, determine a reasonable set of load operations that
3369     gets all of the corresponding values in memory (ideally, including as
3370     many as possible wider vector loads rather than scalar loads).  Return
3371     a CoalescedLoadOp for each one in the *loads array.
3372  */
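// As an illustrative example (assumed for exposition): for the offsets
// {0,1,2,3,4,5,6,7,16}, this typically selects one 8-wide load starting at
// offset 0 plus a single scalar load at offset 16.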
3373 static void lSelectLoads(const std::vector<int64_t> &loadOffsets, std::vector<CoalescedLoadOp> *loads) {
3374     // First, get a sorted set of unique offsets to load from.
3375     std::set<int64_t> allOffsets;
3376     for (unsigned int i = 0; i < loadOffsets.size(); ++i)
3377         allOffsets.insert(loadOffsets[i]);
3378 
3379     std::set<int64_t>::iterator iter = allOffsets.begin();
3380     while (iter != allOffsets.end()) {
3381         Debug(SourcePos(), "Load needed at %" PRId64 ".", *iter);
3382         ++iter;
3383     }
3384 
3385     // Now, iterate over the offsets from low to high.  Starting at the
3386     // current offset, we see if a vector load starting from that offset
3387     // will cover loads at subsequent offsets as well.
3388     iter = allOffsets.begin();
3389     while (iter != allOffsets.end()) {
3390         // Consider vector loads of each of the widths in vectorWidths[],
3391         // in order.
3392         int vectorWidths[] = {8, 4, 2};
3393         int nVectorWidths = sizeof(vectorWidths) / sizeof(vectorWidths[0]);
3394         bool gotOne = false;
3395         for (int i = 0; i < nVectorWidths; ++i) {
3396             // See if a load of vector with width vectorWidths[i] would be
3397             // effective (i.e. would cover a reasonable number of the
3398             // offsets that need to be loaded from).
3399             std::set<int64_t>::iterator newIter;
3400             if (lVectorLoadIsEfficient(iter, allOffsets.end(), &newIter, vectorWidths[i])) {
3401                 // Yes: create the corresponding coalesced load and update
3402                 // the iterator to the returned iterator; doing so skips
3403                 // over the additional offsets that are taken care of by
3404                 // this load.
3405                 loads->push_back(CoalescedLoadOp(*iter, vectorWidths[i]));
3406                 iter = newIter;
3407                 gotOne = true;
3408                 break;
3409             }
3410         }
3411 
3412         if (gotOne == false) {
3413             // We couldn't find a vector load starting from this offset
3414             // that made sense, so emit a scalar load and continue onward.
3415             loads->push_back(CoalescedLoadOp(*iter, 1));
3416             ++iter;
3417         }
3418     }
3419 }
3420 
3421 /** Print a performance message with the details of the result of
3422     coalescing over a group of gathers. */
3423 static void lCoalescePerfInfo(const std::vector<llvm::CallInst *> &coalesceGroup,
3424                               const std::vector<CoalescedLoadOp> &loadOps) {
3425     SourcePos pos;
3426     lGetSourcePosFromMetadata(coalesceGroup[0], &pos);
3427 
3428     // Create a string that indicates the line numbers of the subsequent
3429     // gathers from the first one that were coalesced here.
3430     char otherPositions[512];
3431     otherPositions[0] = '\0';
3432     if (coalesceGroup.size() > 1) {
3433         const char *plural = (coalesceGroup.size() > 2) ? "s" : "";
3434         char otherBuf[32];
3435         snprintf(otherBuf, sizeof(otherBuf), "(other%s at line%s ", plural, plural);
3436         strncat(otherPositions, otherBuf, sizeof(otherPositions) - strlen(otherPositions) - 1);
3437 
3438         for (int i = 1; i < (int)coalesceGroup.size(); ++i) {
3439             SourcePos p;
3440             bool ok = lGetSourcePosFromMetadata(coalesceGroup[i], &p);
3441             if (ok) {
3442                 char buf[32];
3443                 snprintf(buf, sizeof(buf), "%d", p.first_line);
3444                 strncat(otherPositions, buf, sizeof(otherPositions) - strlen(otherPositions) - 1);
3445                 if (i < (int)coalesceGroup.size() - 1)
3446                     strncat(otherPositions, ", ", sizeof(otherPositions) - strlen(otherPositions) - 1);
3447             }
3448         }
3449         strncat(otherPositions, ") ", sizeof(otherPositions) - strlen(otherPositions) - 1);
3450     }
3451 
3452     // Count how many loads of each size there were.
3453     std::map<int, int> loadOpsCount;
3454     for (int i = 0; i < (int)loadOps.size(); ++i)
3455         ++loadOpsCount[loadOps[i].count];
3456 
3457     // Generate a string that describes the mix of load ops
3458     char loadOpsInfo[512];
3459     loadOpsInfo[0] = '\0';
3460     std::map<int, int>::const_iterator iter = loadOpsCount.begin();
3461     while (iter != loadOpsCount.end()) {
3462         char buf[32];
3463         snprintf(buf, sizeof(buf), "%d x %d-wide", iter->second, iter->first);
3464         if ((strlen(loadOpsInfo) + strlen(buf)) >= 512) {
3465             break;
3466         }
3467         strncat(loadOpsInfo, buf, sizeof(loadOpsInfo) - strlen(loadOpsInfo) - 1);
3468         ++iter;
3469         if (iter != loadOpsCount.end())
3470             strncat(loadOpsInfo, ", ", sizeof(loadOpsInfo) - strlen(loadOpsInfo) - 1);
3471     }
3472 
3473     if (g->opt.level > 0) {
3474         if (coalesceGroup.size() == 1)
3475             PerformanceWarning(pos, "Coalesced gather into %d load%s (%s).", (int)loadOps.size(),
3476                                (loadOps.size() > 1) ? "s" : "", loadOpsInfo);
3477         else
3478             PerformanceWarning(pos,
3479                                "Coalesced %d gathers starting here %sinto %d "
3480                                "load%s (%s).",
3481                                (int)coalesceGroup.size(), otherPositions, (int)loadOps.size(),
3482                                (loadOps.size() > 1) ? "s" : "", loadOpsInfo);
3483     }
3484 }
3485 
3486 /** Utility routine that computes an offset from a base pointer and then
3487     returns the result of a load of the given type from the resulting
3488     location:
3489 
3490     return *((type *)(basePtr + offset))
3491  */
3492 llvm::Value *lGEPAndLoad(llvm::Value *basePtr, int64_t offset, int align, llvm::Instruction *insertBefore,
3493                          llvm::Type *type) {
3494     llvm::Value *ptr = lGEPInst(basePtr, LLVMInt64(offset), "new_base", insertBefore);
3495     ptr = new llvm::BitCastInst(ptr, llvm::PointerType::get(type, 0), "ptr_cast", insertBefore);
3496 #if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
3497     return new llvm::LoadInst(ptr, "gather_load", false /* not volatile */, llvm::MaybeAlign(align), insertBefore);
3498 #else // LLVM 11.0+
3499     Assert(llvm::isa<llvm::PointerType>(ptr->getType()));
3500     return new llvm::LoadInst(llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getPointerElementType(), ptr,
3501                               "gather_load", false /* not volatile */, llvm::MaybeAlign(align).valueOrOne(),
3502                               insertBefore);
3503 #endif
3504 }
3505 
3506 /* Having decided that we're going to emit a series of loads, as encoded in
3507    the loadOps array, this function emits the corresponding load
3508    instructions.
3509  */
3510 static void lEmitLoads(llvm::Value *basePtr, std::vector<CoalescedLoadOp> &loadOps, int elementSize,
3511                        llvm::Instruction *insertBefore) {
3512     Debug(SourcePos(), "Coalesce doing %d loads.", (int)loadOps.size());
3513     for (int i = 0; i < (int)loadOps.size(); ++i) {
3514         Debug(SourcePos(), "Load #%d @ %" PRId64 ", %d items", i, loadOps[i].start, loadOps[i].count);
3515 
3516         // basePtr is an i8 *, so the offset from it should be in terms of
3517         // bytes, not underlying i32 elements.
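        // (For example, an element index of 3 with 4-byte elements
        // becomes a byte offset of 12.)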
3518         int64_t start = loadOps[i].start * elementSize;
3519 
3520         int align = 4;
3521         switch (loadOps[i].count) {
3522         case 1:
3523             // Single 32-bit scalar load
3524             loadOps[i].load = lGEPAndLoad(basePtr, start, align, insertBefore, LLVMTypes::Int32Type);
3525             break;
3526         case 2: {
3527             // Emit 2 x i32 loads as i64 loads and then break the result
3528             // into two 32-bit parts.
3529             loadOps[i].load = lGEPAndLoad(basePtr, start, align, insertBefore, LLVMTypes::Int64Type);
3530             // element0 = (int32)value;
3531             loadOps[i].element0 =
3532                 new llvm::TruncInst(loadOps[i].load, LLVMTypes::Int32Type, "load64_elt0", insertBefore);
3533             // element1 = (int32)(value >> 32)
3534             llvm::Value *shift = llvm::BinaryOperator::Create(llvm::Instruction::LShr, loadOps[i].load, LLVMInt64(32),
3535                                                               "load64_shift", insertBefore);
3536             loadOps[i].element1 = new llvm::TruncInst(shift, LLVMTypes::Int32Type, "load64_elt1", insertBefore);
3537             break;
3538         }
3539         case 4: {
3540             // 4-wide vector load
3541             if (g->opt.forceAlignedMemory) {
3542                 align = g->target->getNativeVectorAlignment();
3543             }
3544             llvm::VectorType *vt = LLVMVECTOR::get(LLVMTypes::Int32Type, 4);
3545             loadOps[i].load = lGEPAndLoad(basePtr, start, align, insertBefore, vt);
3546             break;
3547         }
3548         case 8: {
3549             // 8-wide vector load
3550             if (g->opt.forceAlignedMemory) {
3551                 align = g->target->getNativeVectorAlignment();
3552             }
3553             llvm::VectorType *vt = LLVMVECTOR::get(LLVMTypes::Int32Type, 8);
3554             loadOps[i].load = lGEPAndLoad(basePtr, start, align, insertBefore, vt);
3555             break;
3556         }
3557         default:
3558             FATAL("Unexpected load count in lEmitLoads()");
3559         }
3560     }
3561 }
3562 
3563 /** Convert any loads of 8-wide vectors into two 4-wide vectors
3564     (logically).  This allows the vector-assembly code below to always operate on
3565     4-wide vectors, which leads to better code.  Returns a new vector of
3566     load operations.
3567  */
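// For example (illustrative): an 8-wide load starting at offset 16 is
// rewritten as two 4-wide CoalescedLoadOps at offsets 16 and 20, whose
// "load" values are shuffles selecting elements <0..3> and <4..7> of the
// original 8-wide loaded value.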
3568 static std::vector<CoalescedLoadOp> lSplit8WideLoads(const std::vector<CoalescedLoadOp> &loadOps,
3569                                                      llvm::Instruction *insertBefore) {
3570     std::vector<CoalescedLoadOp> ret;
3571     for (unsigned int i = 0; i < loadOps.size(); ++i) {
3572         if (loadOps[i].count == 8) {
3573             // Create fake CoalescedLoadOps, where the load llvm::Value is
3574             // actually a shuffle that pulls either the first 4 or the last
3575             // 4 values out of the original 8-wide loaded value.
3576             int32_t shuf[2][4] = {{0, 1, 2, 3}, {4, 5, 6, 7}};
3577 
3578             ret.push_back(CoalescedLoadOp(loadOps[i].start, 4));
3579             ret.back().load = LLVMShuffleVectors(loadOps[i].load, loadOps[i].load, shuf[0], 4, insertBefore);
3580 
3581             ret.push_back(CoalescedLoadOp(loadOps[i].start + 4, 4));
3582             ret.back().load = LLVMShuffleVectors(loadOps[i].load, loadOps[i].load, shuf[1], 4, insertBefore);
3583         } else
3584             ret.push_back(loadOps[i]);
3585     }
3586 
3587     return ret;
3588 }
3589 
3590 /** Given a 1-wide load of a 32-bit value, merge its value into the result
3591     vector for any and all elements for which it applies.
3592  */
3593 static llvm::Value *lApplyLoad1(llvm::Value *result, const CoalescedLoadOp &load, const int64_t offsets[4], bool set[4],
3594                                 llvm::Instruction *insertBefore) {
3595     for (int elt = 0; elt < 4; ++elt) {
3596         if (offsets[elt] >= load.start && offsets[elt] < load.start + load.count) {
3597             Debug(SourcePos(),
3598                   "Load 1 @ %" PRId64 " matches for element #%d "
3599                   "(value %" PRId64 ")",
3600                   load.start, elt, offsets[elt]);
3601             // If this load gives one of the values that we need, then we
3602             // can just insert it in directly
3603             Assert(set[elt] == false);
3604             result = llvm::InsertElementInst::Create(result, load.load, LLVMInt32(elt), "insert_load", insertBefore);
3605             set[elt] = true;
3606         }
3607     }
3608 
3609     return result;
3610 }
3611 
3612 /** Similarly, incorporate the values from a 2-wide load into any vector
3613     elements that they apply to. */
3614 static llvm::Value *lApplyLoad2(llvm::Value *result, const CoalescedLoadOp &load, const int64_t offsets[4], bool set[4],
3615                                 llvm::Instruction *insertBefore) {
3616     int elt = 0;
3617     while (elt < 4) {
3618         // First, try to do a 64-bit-wide insert into the result vector.
3619         // We can do this when we're currently at an even element, when the
3620         // current and next element have consecutive values, and where the
3621         // original 64-bit load is at the offset needed by the current
3622         // element.
3623         if ((elt & 1) == 0 && offsets[elt] + 1 == offsets[elt + 1] && offsets[elt] == load.start) {
3624             Debug(SourcePos(),
3625                   "Load 2 @ %" PRId64 " matches for elements #%d,%d "
3626                   "(values %" PRId64 ",%" PRId64 ")",
3627                   load.start, elt, elt + 1, offsets[elt], offsets[elt + 1]);
3628             Assert(set[elt] == false && ((elt < 3) && set[elt + 1] == false));
3629 
3630             // In this case, we bitcast from a 4xi32 to a 2xi64 vector
3631             llvm::Type *vec2x64Type = LLVMVECTOR::get(LLVMTypes::Int64Type, 2);
3632             result = new llvm::BitCastInst(result, vec2x64Type, "to2x64", insertBefore);
3633 
3634             // And now we can insert the 64-bit wide value into the
3635             // appropriate element
3636             result = llvm::InsertElementInst::Create(result, load.load, LLVMInt32(elt / 2), "insert64", insertBefore);
3637 
3638             // And back to 4xi32.
3639             llvm::Type *vec4x32Type = LLVMVECTOR::get(LLVMTypes::Int32Type, 4);
3640             result = new llvm::BitCastInst(result, vec4x32Type, "to4x32", insertBefore);
3641 
3642             set[elt] = true;
3643             if (elt < 3) {
3644                 set[elt + 1] = true;
3645             }
3646             // Advance elt one extra time, since we just took care of two
3647             // elements
3648             ++elt;
3649         } else if (offsets[elt] >= load.start && offsets[elt] < load.start + load.count) {
3650             Debug(SourcePos(),
3651                   "Load 2 @ %" PRId64 " matches for element #%d "
3652                   "(value %" PRId64 ")",
3653                   load.start, elt, offsets[elt]);
3654             // Otherwise, insert one of the 32-bit pieces into an element
3655             // of the final vector
3656             Assert(set[elt] == false);
3657             llvm::Value *toInsert = (offsets[elt] == load.start) ? load.element0 : load.element1;
3658             result = llvm::InsertElementInst::Create(result, toInsert, LLVMInt32(elt), "insert_load", insertBefore);
3659             set[elt] = true;
3660         }
3661         ++elt;
3662     }
3663 
3664     return result;
3665 }
3666 
3667 #if 1
3668 /* This approach works better with AVX, while the #else path generates
3669    slightly better code with SSE.  Need to continue to dig into performance
3670    details with this stuff in general... */
3671 
3672 /** And handle a 4-wide load */
3673 static llvm::Value *lApplyLoad4(llvm::Value *result, const CoalescedLoadOp &load, const int64_t offsets[4], bool set[4],
3674                                 llvm::Instruction *insertBefore) {
3675     // Conceptually, we're going to consider doing a shufflevector with
3676     // the 4-wide load and the 4-wide result we have so far to generate a
3677     // new 4-wide vector.  We'll start with shuffle indices that just
3678     // select each element of the result so far for the result.
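    //
    // For example (illustrative): if offsets are {8, 9, 2, 3} and this
    // load covers [8, 12), the shuffle indices become {0, 1, 6, 7}: lanes
    // 0 and 1 come from the load, lanes 2 and 3 from the existing result.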
3679     int32_t shuf[4] = {4, 5, 6, 7};
3680 
3681     for (int elt = 0; elt < 4; ++elt) {
3682         if (offsets[elt] >= load.start && offsets[elt] < load.start + load.count) {
3683             Debug(SourcePos(),
3684                   "Load 4 @ %" PRId64 " matches for element #%d "
3685                   "(value %" PRId64 ")",
3686                   load.start, elt, offsets[elt]);
3687 
3688             // If the current element falls within the range of locations
3689             // that the 4-wide load covers, then compute the appropriate
3690             // shuffle index that extracts the appropriate element from the
3691             // load.
3692             Assert(set[elt] == false);
3693             shuf[elt] = int32_t(offsets[elt] - load.start);
3694             set[elt] = true;
3695         }
3696     }
3697 
3698     // Now, issue a shufflevector instruction if any of the values from the
3699     // load we just considered were applicable.
3700     if (shuf[0] != 4 || shuf[1] != 5 || shuf[2] != 6 || shuf[3] != 7)
3701         result = LLVMShuffleVectors(load.load, result, shuf, 4, insertBefore);
3702 
3703     return result;
3704 }
3705 
3706 /** We need to fill in the values for a 4-wide result vector.  This
3707     function looks at all of the generated loads and extracts the
3708     appropriate elements from the appropriate loads to assemble the result.
3709     Here the offsets[] parameter gives the 4 offsets from the base pointer
3710     for the four elements of the result.
3711 */
3712 static llvm::Value *lAssemble4Vector(const std::vector<CoalescedLoadOp> &loadOps, const int64_t offsets[4],
3713                                      llvm::Instruction *insertBefore) {
3714     llvm::Type *returnType = LLVMVECTOR::get(LLVMTypes::Int32Type, 4);
3715     llvm::Value *result = llvm::UndefValue::get(returnType);
3716 
3717     Debug(SourcePos(), "Starting search for loads [%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "].", offsets[0],
3718           offsets[1], offsets[2], offsets[3]);
3719 
3720     // Track whether we have found a valid value for each of the four
3721     // elements of the result
3722     bool set[4] = {false, false, false, false};
3723 
3724     // Loop over all of the loads and check each one to see if it provides
3725     // a value that's applicable to the result
3726     for (int load = 0; load < (int)loadOps.size(); ++load) {
3727         const CoalescedLoadOp &li = loadOps[load];
3728 
3729         switch (li.count) {
3730         case 1:
3731             result = lApplyLoad1(result, li, offsets, set, insertBefore);
3732             break;
3733         case 2:
3734             result = lApplyLoad2(result, li, offsets, set, insertBefore);
3735             break;
3736         case 4:
3737             result = lApplyLoad4(result, li, offsets, set, insertBefore);
3738             break;
3739         default:
3740             FATAL("Unexpected load count in lAssemble4Vector()");
3741         }
3742     }
3743 
3744     Debug(SourcePos(), "Done with search for loads [%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "].", offsets[0],
3745           offsets[1], offsets[2], offsets[3]);
3746 
3747     for (int i = 0; i < 4; ++i)
3748         Assert(set[i] == true);
3749 
3750     return result;
3751 }
3752 
3753 #else
3754 
3755 static llvm::Value *lApplyLoad4s(llvm::Value *result, const std::vector<CoalescedLoadOp> &loadOps,
3756                                  const int64_t offsets[4], bool set[4], llvm::Instruction *insertBefore) {
3757     int32_t firstMatchElements[4] = {-1, -1, -1, -1};
3758     const CoalescedLoadOp *firstMatch = NULL;
3759 
3760     Assert(llvm::isa<llvm::UndefValue>(result));
3761 
3762     for (int load = 0; load < (int)loadOps.size(); ++load) {
3763         const CoalescedLoadOp &loadop = loadOps[load];
3764         if (loadop.count != 4)
3765             continue;
3766 
3767         int32_t matchElements[4] = {-1, -1, -1, -1};
3768         bool anyMatched = false;
3769         for (int elt = 0; elt < 4; ++elt) {
3770             if (offsets[elt] >= loadop.start && offsets[elt] < loadop.start + loadop.count) {
3771                 Debug(SourcePos(),
3772                       "Load 4 @ %" PRId64 " matches for element #%d "
3773                       "(value %" PRId64 ")",
3774                       loadop.start, elt, offsets[elt]);
3775                 anyMatched = true;
3776                 Assert(set[elt] == false);
3777                 matchElements[elt] = offsets[elt] - loadop.start;
3778                 set[elt] = true;
3779             }
3780         }
3781 
3782         if (anyMatched) {
3783             if (llvm::isa<llvm::UndefValue>(result)) {
3784                 if (firstMatch == NULL) {
3785                     firstMatch = &loadop;
3786                     for (int i = 0; i < 4; ++i)
3787                         firstMatchElements[i] = matchElements[i];
3788                 } else {
3789                     int32_t shuffle[4] = {-1, -1, -1, -1};
3790                     for (int i = 0; i < 4; ++i) {
3791                         if (firstMatchElements[i] != -1)
3792                             shuffle[i] = firstMatchElements[i];
3793                         else
3794                             shuffle[i] = 4 + matchElements[i];
3795                     }
3796                     result = LLVMShuffleVectors(firstMatch->load, loadop.load, shuffle, 4, insertBefore);
3797                     firstMatch = NULL;
3798                 }
3799             } else {
3800                 int32_t shuffle[4] = {-1, -1, -1, -1};
3801                 for (int i = 0; i < 4; ++i) {
3802                     if (matchElements[i] != -1)
3803                         shuffle[i] = 4 + matchElements[i];
3804                     else
3805                         shuffle[i] = i;
3806                 }
3807                 result = LLVMShuffleVectors(result, loadop.load, shuffle, 4, insertBefore);
3808             }
3809         }
3810     }
3811 
3812     if (firstMatch != NULL && llvm::isa<llvm::UndefValue>(result))
3813         return LLVMShuffleVectors(firstMatch->load, result, firstMatchElements, 4, insertBefore);
3814     else
3815         return result;
3816 }
3817 
3818 static llvm::Value *lApplyLoad12s(llvm::Value *result, const std::vector<CoalescedLoadOp> &loadOps,
3819                                   const int64_t offsets[4], bool set[4], llvm::Instruction *insertBefore) {
3820     // Loop over all of the loads and check each one to see if it provides
3821     // a value that's applicable to the result
3822     for (int load = 0; load < (int)loadOps.size(); ++load) {
3823         const CoalescedLoadOp &loadop = loadOps[load];
3824         Assert(loadop.count == 1 || loadop.count == 2 || loadop.count == 4);
3825 
3826         if (loadop.count == 1)
3827             result = lApplyLoad1(result, loadop, offsets, set, insertBefore);
3828         else if (loadop.count == 2)
3829             result = lApplyLoad2(result, loadop, offsets, set, insertBefore);
3830     }
3831     return result;
3832 }
3833 
3834 /** We need to fill in the values for a 4-wide result vector.  This
3835     function looks at all of the generated loads and extracts the
3836     appropriate elements from the appropriate loads to assemble the result.
3837     Here the offsets[] parameter gives the 4 offsets from the base pointer
3838     for the four elements of the result.
3839 */
3840 static llvm::Value *lAssemble4Vector(const std::vector<CoalescedLoadOp> &loadOps, const int64_t offsets[4],
3841                                      llvm::Instruction *insertBefore) {
3842     llvm::Type *returnType = LLVMVECTOR::get(LLVMTypes::Int32Type, 4);
3843     llvm::Value *result = llvm::UndefValue::get(returnType);
3844 
3845     Debug(SourcePos(), "Starting search for loads [%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "].", offsets[0],
3846           offsets[1], offsets[2], offsets[3]);
3847 
3848     // Track whether we have found a valid value for each of the four
3849     // elements of the result
3850     bool set[4] = {false, false, false, false};
3851 
3852     result = lApplyLoad4s(result, loadOps, offsets, set, insertBefore);
3853     result = lApplyLoad12s(result, loadOps, offsets, set, insertBefore);
3854 
3855     Debug(SourcePos(), "Done with search for loads [%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "].", offsets[0],
3856           offsets[1], offsets[2], offsets[3]);
3857 
3858     for (int i = 0; i < 4; ++i)
3859         Assert(set[i] == true);
3860 
3861     return result;
3862 }
3863 #endif
3864 
3865 /** Given the set of loads that we've done and the set of result values to
3866     be computed, this function computes the final llvm::Value *s for each
3867     result vector.
3868  */
3869 static void lAssembleResultVectors(const std::vector<CoalescedLoadOp> &loadOps,
3870                                    const std::vector<int64_t> &constOffsets, std::vector<llvm::Value *> &results,
3871                                    llvm::Instruction *insertBefore) {
3872     // We work on 4-wide chunks of the final values, even when we're
3873     // computing 8-wide or 16-wide vectors.  This gives better code from
3874     // LLVM's SSE/AVX code generators.
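    // (Illustrative example: on an 8-wide target with 3 coalesced gathers,
    // constOffsets has 24 entries, which yields 6 4-wide chunks; each final
    // result is then the concatenation of 2 of those chunks.)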
3875     Assert((constOffsets.size() % 4) == 0);
3876     std::vector<llvm::Value *> vec4s;
3877     for (int i = 0; i < (int)constOffsets.size(); i += 4)
3878         vec4s.push_back(lAssemble4Vector(loadOps, &constOffsets[i], insertBefore));
3879 
3880     // And now concatenate 1, 2, or 4 of the 4-wide vectors computed above
3881     // into 4, 8, or 16-wide final result vectors.
3882     int numGathers = constOffsets.size() / g->target->getVectorWidth();
3883     for (int i = 0; i < numGathers; ++i) {
3884         llvm::Value *result = NULL;
3885         switch (g->target->getVectorWidth()) {
3886         case 4:
3887             result = vec4s[i];
3888             break;
3889         case 8:
3890             result = LLVMConcatVectors(vec4s[2 * i], vec4s[2 * i + 1], insertBefore);
3891             break;
3892         case 16: {
3893             llvm::Value *v1 = LLVMConcatVectors(vec4s[4 * i], vec4s[4 * i + 1], insertBefore);
3894             llvm::Value *v2 = LLVMConcatVectors(vec4s[4 * i + 2], vec4s[4 * i + 3], insertBefore);
3895             result = LLVMConcatVectors(v1, v2, insertBefore);
3896             break;
3897         }
3898         default:
3899             FATAL("Unhandled vector width in lAssembleResultVectors()");
3900         }
3901 
3902         results.push_back(result);
3903     }
3904 }
3905 
3906 /** Given a call to a gather function, extract the base pointer, the 2/4/8
3907     scale, and the first varying offsets value, and use them to compute the
3908     scalar base pointer that is shared by all of the gathers in the group.
3909     (Thus, this base pointer plus the constant offsets term for each gather
3910     gives the set of addresses to use for each gather.)
3911  */
3912 static llvm::Value *lComputeBasePtr(llvm::CallInst *gatherInst, llvm::Instruction *insertBefore) {
3913     llvm::Value *basePtr = gatherInst->getArgOperand(0);
3914     llvm::Value *variableOffsets = gatherInst->getArgOperand(1);
3915     llvm::Value *offsetScale = gatherInst->getArgOperand(2);
3916 
3917     // All of the variable offsets values should be the same, due to
3918     // checking for this in GatherCoalescePass::runOnBasicBlock().  Thus,
3919     // extract the first value and use that as a scalar.
3920     llvm::Value *variable = LLVMExtractFirstVectorElement(variableOffsets);
3921     Assert(variable != NULL);
3922     if (variable->getType() == LLVMTypes::Int64Type)
3923         offsetScale = new llvm::ZExtInst(offsetScale, LLVMTypes::Int64Type, "scale_to64", insertBefore);
3924     llvm::Value *offset =
3925         llvm::BinaryOperator::Create(llvm::Instruction::Mul, variable, offsetScale, "offset", insertBefore);
3926 
3927     return lGEPInst(basePtr, offset, "new_base", insertBefore);
3928 }
3929 
3930 /** Extract the constant offsets (from the common base pointer) from each
3931     of the gathers in a set to be coalesced.  These come in as byte
3932     offsets, but we'll transform them into offsets in terms of the size of
3933     the base scalar type being gathered.  (e.g. for an i32 gather, we might
3934     have offsets like <0,4,16,20>, which would be transformed to <0,1,4,5>
3935     here.)
3936  */
3937 static void lExtractConstOffsets(const std::vector<llvm::CallInst *> &coalesceGroup, int elementSize,
3938                                  std::vector<int64_t> *constOffsets) {
3939     int width = g->target->getVectorWidth();
3940     *constOffsets = std::vector<int64_t>(coalesceGroup.size() * width, 0);
3941 
3942     int64_t *endPtr = &((*constOffsets)[0]);
3943     for (int i = 0; i < (int)coalesceGroup.size(); ++i, endPtr += width) {
3944         llvm::Value *offsets = coalesceGroup[i]->getArgOperand(3);
3945         int nElts;
3946         bool ok = LLVMExtractVectorInts(offsets, endPtr, &nElts);
3947         Assert(ok && nElts == width);
3948     }
3949 
3950     for (int i = 0; i < (int)constOffsets->size(); ++i)
3951         (*constOffsets)[i] /= elementSize;
3952 }
3953 
3954 /** Actually do the coalescing.  We have a set of gathers all accessing
3955     addresses of the form:
3956 
3957     (ptr + {1,2,4,8} * varyingOffset) + constOffset, a.k.a.
3958     basePtr + constOffset
3959 
3960     where varyingOffset actually has the same value across all of the SIMD
3961     lanes and where the part in parentheses has the same value for all of
3962     the gathers in the group.
3963  */
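// Illustrative example (assumed for exposition): two 4-wide gathers with
// per-lane constant element offsets <0,3,6,9> and <1,4,7,10> (the x and y
// fields of an AOS array of three floats) share the same base pointer and
// end up being served by one 8-wide load plus one 2-wide load, followed by
// shuffles.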
3964 static bool lCoalesceGathers(const std::vector<llvm::CallInst *> &coalesceGroup) {
3965     llvm::Instruction *insertBefore = coalesceGroup[0];
3966 
3967     // First, compute the shared base pointer for all of the gathers
3968     llvm::Value *basePtr = lComputeBasePtr(coalesceGroup[0], insertBefore);
3969 
3970     int elementSize = 0;
3971     if (coalesceGroup[0]->getType() == LLVMTypes::Int32VectorType ||
3972         coalesceGroup[0]->getType() == LLVMTypes::FloatVectorType)
3973         elementSize = 4;
3974     else if (coalesceGroup[0]->getType() == LLVMTypes::Int64VectorType ||
3975              coalesceGroup[0]->getType() == LLVMTypes::DoubleVectorType)
3976         elementSize = 8;
3977     else
3978         FATAL("Unexpected gather type in lCoalesceGathers");
3979 
3980     // Extract the constant offsets from the gathers into the constOffsets
3981     // vector: the first vectorWidth elements will be those for the first
3982     // gather, the next vectorWidth those for the next gather, and so
3983     // forth.
3984     std::vector<int64_t> constOffsets;
3985     lExtractConstOffsets(coalesceGroup, elementSize, &constOffsets);
3986 
3987     // Determine a set of loads to perform to get all of the values we need
3988     // loaded.
3989     std::vector<CoalescedLoadOp> loadOps;
3990     lSelectLoads(constOffsets, &loadOps);
3991 
3992     lCoalescePerfInfo(coalesceGroup, loadOps);
3993 
3994     // Actually emit load instructions for them
3995     lEmitLoads(basePtr, loadOps, elementSize, insertBefore);
3996 
3997     // Now, for any loads that give us <8 x i32> vectors, split their
3998     // values into two <4 x i32> vectors; it turns out that LLVM gives us
3999     // better code on AVX when we assemble the pieces from 4-wide vectors.
4000     loadOps = lSplit8WideLoads(loadOps, insertBefore);
4001 
4002     // Given all of these chunks of values, shuffle together a vector that
4003     // gives us each result value; the i'th element of results[] gives the
4004     // result for the i'th gather in coalesceGroup.
4005     std::vector<llvm::Value *> results;
4006     lAssembleResultVectors(loadOps, constOffsets, results, insertBefore);
4007 
4008     // Finally, replace each of the original gathers with the instruction
4009     // that gives the value from the coalescing process.
4010     Assert(results.size() == coalesceGroup.size());
4011     for (int i = 0; i < (int)results.size(); ++i) {
4012         llvm::Instruction *ir = llvm::dyn_cast<llvm::Instruction>(results[i]);
4013         Assert(ir != NULL);
4014 
4015         llvm::Type *origType = coalesceGroup[i]->getType();
4016         if (origType != ir->getType())
4017             ir = new llvm::BitCastInst(ir, origType, ir->getName(), coalesceGroup[i]);
4018 
4019         // Previously, all of the instructions to compute the final result
4020         // were inserted into the basic block here; now we remove the very last
4021         // one of them (the one that holds the final result) from the basic block.
4022         // This way, the following ReplaceInstWithInst() call will operate
4023         // successfully. (It expects that the second argument not be in any
4024         // basic block.)
4025         ir->removeFromParent();
4026 
4027         llvm::ReplaceInstWithInst(coalesceGroup[i], ir);
4028     }
4029 
4030     return true;
4031 }
4032 
4033 /** Given an instruction, returns true if the instruction may write to
4034     memory.  This is a conservative test in that it may return true for
4035     some instructions that don't actually end up writing to memory, but
4036     should never return false for an instruction that does write to
4037     memory. */
4038 static bool lInstructionMayWriteToMemory(llvm::Instruction *inst) {
4039     if (llvm::isa<llvm::StoreInst>(inst) || llvm::isa<llvm::AtomicRMWInst>(inst) ||
4040         llvm::isa<llvm::AtomicCmpXchgInst>(inst))
4041         // FIXME: we could be less conservative and try to allow stores if
4042         // we are sure that the pointers don't overlap.
4043         return true;
4044 
4045     // Otherwise, any call instruction that doesn't have an attribute
4046     // indicating it won't write to memory has to be treated as a potential
4047     // store.
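    // (In LLVM IR terms, this corresponds to functions carrying the
    // readonly or readnone attributes.)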
4048     llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst);
4049     if (ci != NULL) {
4050         llvm::Function *calledFunc = ci->getCalledFunction();
4051         if (calledFunc == NULL)
4052             return true;
4053 
4054         if (calledFunc->onlyReadsMemory() || calledFunc->doesNotAccessMemory())
4055             return false;
4056         return true;
4057     }
4058 
4059     return false;
4060 }
4061 
4062 bool GatherCoalescePass::runOnBasicBlock(llvm::BasicBlock &bb) {
4063     DEBUG_START_PASS("GatherCoalescePass");
4064 
4065     llvm::Function *gatherFuncs[] = {
4066         m->module->getFunction("__pseudo_gather_factored_base_offsets32_i32"),
4067         m->module->getFunction("__pseudo_gather_factored_base_offsets32_float"),
4068         m->module->getFunction("__pseudo_gather_factored_base_offsets64_i32"),
4069         m->module->getFunction("__pseudo_gather_factored_base_offsets64_float"),
4070     };
4071     int nGatherFuncs = sizeof(gatherFuncs) / sizeof(gatherFuncs[0]);
4072 
4073     bool modifiedAny = false;
4074 
4075 restart:
4076     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
4077         // Iterate over all of the instructions and look for calls to
4078         // __pseudo_gather_factored_base_offsets{32,64}_{i32,float} calls.
4079         llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
4080         if (callInst == NULL)
4081             continue;
4082 
4083         llvm::Function *calledFunc = callInst->getCalledFunction();
4084         if (calledFunc == NULL)
4085             continue;
4086 
4087         int i;
4088         for (i = 0; i < nGatherFuncs; ++i)
4089             if (gatherFuncs[i] != NULL && calledFunc == gatherFuncs[i])
4090                 break;
4091         if (i == nGatherFuncs)
4092             // Doesn't match any of the types of gathers we care about
4093             continue;
4094 
4095         SourcePos pos;
4096         lGetSourcePosFromMetadata(callInst, &pos);
4097         Debug(pos, "Checking for coalescable gathers starting here...");
4098 
4099         llvm::Value *base = callInst->getArgOperand(0);
4100         llvm::Value *variableOffsets = callInst->getArgOperand(1);
4101         llvm::Value *offsetScale = callInst->getArgOperand(2);
4102         llvm::Value *mask = callInst->getArgOperand(4);
4103 
4104         // To apply this optimization, we need a set of one or more gathers
4105         // that fulfill the following conditions:
4106         //
4107         // - Mask all on
4108         // - The variable offsets to all have the same value (i.e., to be
4109         //   uniform).
4110         // - Same base pointer, variable offsets, and offset scale (for
4111         //   more than one gather)
4112         //
4113         // Then and only then do we have a common base pointer with all
4114         // offsets from it being constants (in which case we can potentially
4115         // coalesce).
4116         if (lGetMaskStatus(mask) != MaskStatus::all_on)
4117             continue;
4118 
4119         if (!LLVMVectorValuesAllEqual(variableOffsets))
4120             continue;
4121 
4122         // coalesceGroup stores the set of gathers that we're going to try to
4123         // coalesce over
4124         std::vector<llvm::CallInst *> coalesceGroup;
4125         coalesceGroup.push_back(callInst);
4126 
4127         // Start iterating at the instruction after the initial gather;
4128         // look at the remainder of instructions in the basic block (up
4129         // until we reach a write to memory) to try to find any other
4130         // gathers that can coalesce with this one.
4131         llvm::BasicBlock::iterator fwdIter = iter;
4132         ++fwdIter;
4133         for (; fwdIter != bb.end(); ++fwdIter) {
4134             // Must stop once we come to an instruction that may write to
4135             // memory; otherwise we could end up moving a read before this
4136             // write.
4137             if (lInstructionMayWriteToMemory(&*fwdIter))
4138                 break;
4139 
4140             llvm::CallInst *fwdCall = llvm::dyn_cast<llvm::CallInst>(&*fwdIter);
4141             if (fwdCall == NULL || fwdCall->getCalledFunction() != calledFunc)
4142                 continue;
4143 
4144             SourcePos fwdPos;
4145             // TODO: need to redesign metadata attached to pseudo calls,
4146             // LLVM drops metadata frequently, and that results in bad diagnostics.
4147             lGetSourcePosFromMetadata(fwdCall, &fwdPos);
4148 
4149 #ifndef ISPC_NO_DUMPS
4150             if (g->debugPrint) {
4151                 if (base != fwdCall->getArgOperand(0)) {
4152                     Debug(fwdPos, "base pointers mismatch");
4153                     LLVMDumpValue(base);
4154                     LLVMDumpValue(fwdCall->getArgOperand(0));
4155                 }
4156                 if (variableOffsets != fwdCall->getArgOperand(1)) {
4157                     Debug(fwdPos, "varying offsets mismatch");
4158                     LLVMDumpValue(variableOffsets);
4159                     LLVMDumpValue(fwdCall->getArgOperand(1));
4160                 }
4161                 if (offsetScale != fwdCall->getArgOperand(2)) {
4162                     Debug(fwdPos, "offset scales mismatch");
4163                     LLVMDumpValue(offsetScale);
4164                     LLVMDumpValue(fwdCall->getArgOperand(2));
4165                 }
4166                 if (mask != fwdCall->getArgOperand(4)) {
4167                     Debug(fwdPos, "masks mismatch");
4168                     LLVMDumpValue(mask);
4169                     LLVMDumpValue(fwdCall->getArgOperand(4));
4170                 }
4171             }
4172 #endif
4173 
4174             if (base == fwdCall->getArgOperand(0) && variableOffsets == fwdCall->getArgOperand(1) &&
4175                 offsetScale == fwdCall->getArgOperand(2) && mask == fwdCall->getArgOperand(4)) {
4176                 Debug(fwdPos, "This gather can be coalesced.");
4177                 coalesceGroup.push_back(fwdCall);
4178 
4179                 if (coalesceGroup.size() == 4)
4180                     // FIXME: untested heuristic: don't try to coalesce
4181                     // over a window of more than 4 gathers, so that we
4182                     // don't cause too much register pressure and end up
4183                     // spilling to memory anyway.
4184                     break;
4185             } else
4186                 Debug(fwdPos, "This gather doesn't match the initial one.");
4187         }
4188 
4189         Debug(pos, "Done with checking for matching gathers");
4190 
4191         // Now that we have a group of gathers, see if we can coalesce them
4192         // into something more efficient than the original set of gathers.
4193         if (lCoalesceGathers(coalesceGroup)) {
4194             modifiedAny = true;
4195             goto restart;
4196         }
4197     }
4198 
4199     DEBUG_END_PASS("GatherCoalescePass");
4200 
4201     return modifiedAny;
4202 }
4203 
4204 bool GatherCoalescePass::runOnFunction(llvm::Function &F) {
4205 
4206     llvm::TimeTraceScope FuncScope("GatherCoalescePass::runOnFunction", F.getName());
4207     bool modifiedAny = false;
4208     for (llvm::BasicBlock &BB : F) {
4209         modifiedAny |= runOnBasicBlock(BB);
4210     }
4211     return modifiedAny;
4212 }
4213 
4214 static llvm::Pass *CreateGatherCoalescePass() { return new GatherCoalescePass; }
4215 
4216 ///////////////////////////////////////////////////////////////////////////
4217 // ReplacePseudoMemoryOpsPass
4218 
4219 /** For any gathers and scatters remaining after the GSToLoadStorePass
4220     runs, we need to turn them into actual native gathers and scatters.
4221     This task is handled by the ReplacePseudoMemoryOpsPass here.
4222  */
4223 class ReplacePseudoMemoryOpsPass : public llvm::FunctionPass {
4224   public:
4225     static char ID;
4226     ReplacePseudoMemoryOpsPass() : FunctionPass(ID) {}
4227 
4228     llvm::StringRef getPassName() const { return "Replace Pseudo Memory Ops"; }
4229     bool runOnBasicBlock(llvm::BasicBlock &BB);
4230     bool runOnFunction(llvm::Function &F);
4231 };
4232 
4233 char ReplacePseudoMemoryOpsPass::ID = 0;
4234 
4235 /** This routine attempts to determine if the given pointer in lvalue is
4236     pointing to stack-allocated memory.  It's conservative in that it
4237     should never return true for non-stack allocated memory, but may return
4238     false for memory that actually is stack allocated.  The basic strategy
4239     is to traverse through the operands and see if the pointer originally
4240     comes from an AllocaInst.
4241 */
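// For example (illustrative): a pointer formed as a bitcast of a GEP into
// an alloca of [N x <W x i32>] (with W equal to the target vector width) is
// traced back to the AllocaInst and accepted, while a pointer passed in as
// a function argument or loaded from memory is rejected.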
4242 static bool lIsSafeToBlend(llvm::Value *lvalue) {
4243     llvm::BitCastInst *bc = llvm::dyn_cast<llvm::BitCastInst>(lvalue);
4244     if (bc != NULL)
4245         return lIsSafeToBlend(bc->getOperand(0));
4246     else {
4247         llvm::AllocaInst *ai = llvm::dyn_cast<llvm::AllocaInst>(lvalue);
4248         if (ai) {
4249             llvm::Type *type = ai->getType();
4250             llvm::PointerType *pt = llvm::dyn_cast<llvm::PointerType>(type);
4251             Assert(pt != NULL);
4252             type = pt->getElementType();
4253             llvm::ArrayType *at;
4254             while ((at = llvm::dyn_cast<llvm::ArrayType>(type))) {
4255                 type = at->getElementType();
4256             }
4257 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
4258             llvm::FixedVectorType *vt = llvm::dyn_cast<llvm::FixedVectorType>(type);
4259 #else
4260             llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(type);
4261 #endif
4262             return (vt != NULL && (int)vt->getNumElements() == g->target->getVectorWidth());
4263         } else {
4264             llvm::GetElementPtrInst *gep = llvm::dyn_cast<llvm::GetElementPtrInst>(lvalue);
4265             if (gep != NULL)
4266                 return lIsSafeToBlend(gep->getOperand(0));
4267             else
4268                 return false;
4269         }
4270     }
4271 }
4272 
4273 static bool lReplacePseudoMaskedStore(llvm::CallInst *callInst) {
4274     struct LMSInfo {
4275         LMSInfo(const char *pname, const char *bname, const char *msname) {
4276             pseudoFunc = m->module->getFunction(pname);
4277             blendFunc = m->module->getFunction(bname);
4278             maskedStoreFunc = m->module->getFunction(msname);
4279             Assert(pseudoFunc != NULL && blendFunc != NULL && maskedStoreFunc != NULL);
4280         }
4281         llvm::Function *pseudoFunc;
4282         llvm::Function *blendFunc;
4283         llvm::Function *maskedStoreFunc;
4284     };
4285 
4286     LMSInfo msInfo[] = {
4287         LMSInfo("__pseudo_masked_store_i8", "__masked_store_blend_i8", "__masked_store_i8"),
4288         LMSInfo("__pseudo_masked_store_i16", "__masked_store_blend_i16", "__masked_store_i16"),
4289         LMSInfo("__pseudo_masked_store_i32", "__masked_store_blend_i32", "__masked_store_i32"),
4290         LMSInfo("__pseudo_masked_store_float", "__masked_store_blend_float", "__masked_store_float"),
4291         LMSInfo("__pseudo_masked_store_i64", "__masked_store_blend_i64", "__masked_store_i64"),
4292         LMSInfo("__pseudo_masked_store_double", "__masked_store_blend_double", "__masked_store_double")};
4293     LMSInfo *info = NULL;
4294     for (unsigned int i = 0; i < sizeof(msInfo) / sizeof(msInfo[0]); ++i) {
4295         if (msInfo[i].pseudoFunc != NULL && callInst->getCalledFunction() == msInfo[i].pseudoFunc) {
4296             info = &msInfo[i];
4297             break;
4298         }
4299     }
4300     if (info == NULL)
4301         return false;
4302 
4303     llvm::Value *lvalue = callInst->getArgOperand(0);
4304     llvm::Value *rvalue = callInst->getArgOperand(1);
4305     llvm::Value *mask = callInst->getArgOperand(2);
4306 
4307     // We need to choose between doing the load + blend + store trick,
4308     // or serializing the masked store.  Even on targets with a native
4309     // masked store instruction, this is preferable since it lets us
4310     // keep values in registers rather than going out to the stack.
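    // (Blending here means: load the old vector contents, select between
    // the old and new values lane-by-lane under the mask, and store the
    // full vector back.)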
4311     bool doBlend = (!g->opt.disableBlendedMaskedStores && lIsSafeToBlend(lvalue));
4312 
4313     // Generate the call to the appropriate masked store function and
4314     // replace the __pseudo_* one with it.
4315     llvm::Function *fms = doBlend ? info->blendFunc : info->maskedStoreFunc;
4316     llvm::Instruction *inst = lCallInst(fms, lvalue, rvalue, mask, "", callInst);
4317     lCopyMetadata(inst, callInst);
4318 
4319     callInst->eraseFromParent();
4320     return true;
4321 }
4322 
4323 static bool lReplacePseudoGS(llvm::CallInst *callInst) {
4324     struct LowerGSInfo {
4325         LowerGSInfo(const char *pName, const char *aName, bool ig, bool ip) : isGather(ig), isPrefetch(ip) {
4326             pseudoFunc = m->module->getFunction(pName);
4327             actualFunc = m->module->getFunction(aName);
4328         }
4329         llvm::Function *pseudoFunc;
4330         llvm::Function *actualFunc;
4331         const bool isGather;
4332         const bool isPrefetch;
4333     };
4334 
4335     LowerGSInfo lgsInfo[] = {
4336         LowerGSInfo("__pseudo_gather32_i8", "__gather32_i8", true, false),
4337         LowerGSInfo("__pseudo_gather32_i16", "__gather32_i16", true, false),
4338         LowerGSInfo("__pseudo_gather32_i32", "__gather32_i32", true, false),
4339         LowerGSInfo("__pseudo_gather32_float", "__gather32_float", true, false),
4340         LowerGSInfo("__pseudo_gather32_i64", "__gather32_i64", true, false),
4341         LowerGSInfo("__pseudo_gather32_double", "__gather32_double", true, false),
4342 
4343         LowerGSInfo("__pseudo_gather64_i8", "__gather64_i8", true, false),
4344         LowerGSInfo("__pseudo_gather64_i16", "__gather64_i16", true, false),
4345         LowerGSInfo("__pseudo_gather64_i32", "__gather64_i32", true, false),
4346         LowerGSInfo("__pseudo_gather64_float", "__gather64_float", true, false),
4347         LowerGSInfo("__pseudo_gather64_i64", "__gather64_i64", true, false),
4348         LowerGSInfo("__pseudo_gather64_double", "__gather64_double", true, false),
4349 
4350         LowerGSInfo("__pseudo_gather_factored_base_offsets32_i8", "__gather_factored_base_offsets32_i8", true, false),
4351         LowerGSInfo("__pseudo_gather_factored_base_offsets32_i16", "__gather_factored_base_offsets32_i16", true, false),
4352         LowerGSInfo("__pseudo_gather_factored_base_offsets32_i32", "__gather_factored_base_offsets32_i32", true, false),
4353         LowerGSInfo("__pseudo_gather_factored_base_offsets32_float", "__gather_factored_base_offsets32_float", true,
4354                     false),
4355         LowerGSInfo("__pseudo_gather_factored_base_offsets32_i64", "__gather_factored_base_offsets32_i64", true, false),
4356         LowerGSInfo("__pseudo_gather_factored_base_offsets32_double", "__gather_factored_base_offsets32_double", true,
4357                     false),
4358 
4359         LowerGSInfo("__pseudo_gather_factored_base_offsets64_i8", "__gather_factored_base_offsets64_i8", true, false),
4360         LowerGSInfo("__pseudo_gather_factored_base_offsets64_i16", "__gather_factored_base_offsets64_i16", true, false),
4361         LowerGSInfo("__pseudo_gather_factored_base_offsets64_i32", "__gather_factored_base_offsets64_i32", true, false),
4362         LowerGSInfo("__pseudo_gather_factored_base_offsets64_float", "__gather_factored_base_offsets64_float", true,
4363                     false),
4364         LowerGSInfo("__pseudo_gather_factored_base_offsets64_i64", "__gather_factored_base_offsets64_i64", true, false),
4365         LowerGSInfo("__pseudo_gather_factored_base_offsets64_double", "__gather_factored_base_offsets64_double", true,
4366                     false),
4367 
4368         LowerGSInfo("__pseudo_gather_base_offsets32_i8", "__gather_base_offsets32_i8", true, false),
4369         LowerGSInfo("__pseudo_gather_base_offsets32_i16", "__gather_base_offsets32_i16", true, false),
4370         LowerGSInfo("__pseudo_gather_base_offsets32_i32", "__gather_base_offsets32_i32", true, false),
4371         LowerGSInfo("__pseudo_gather_base_offsets32_float", "__gather_base_offsets32_float", true, false),
4372         LowerGSInfo("__pseudo_gather_base_offsets32_i64", "__gather_base_offsets32_i64", true, false),
4373         LowerGSInfo("__pseudo_gather_base_offsets32_double", "__gather_base_offsets32_double", true, false),
4374 
4375         LowerGSInfo("__pseudo_gather_base_offsets64_i8", "__gather_base_offsets64_i8", true, false),
4376         LowerGSInfo("__pseudo_gather_base_offsets64_i16", "__gather_base_offsets64_i16", true, false),
4377         LowerGSInfo("__pseudo_gather_base_offsets64_i32", "__gather_base_offsets64_i32", true, false),
4378         LowerGSInfo("__pseudo_gather_base_offsets64_float", "__gather_base_offsets64_float", true, false),
4379         LowerGSInfo("__pseudo_gather_base_offsets64_i64", "__gather_base_offsets64_i64", true, false),
4380         LowerGSInfo("__pseudo_gather_base_offsets64_double", "__gather_base_offsets64_double", true, false),
4381 
4382         LowerGSInfo("__pseudo_scatter32_i8", "__scatter32_i8", false, false),
4383         LowerGSInfo("__pseudo_scatter32_i16", "__scatter32_i16", false, false),
4384         LowerGSInfo("__pseudo_scatter32_i32", "__scatter32_i32", false, false),
4385         LowerGSInfo("__pseudo_scatter32_float", "__scatter32_float", false, false),
4386         LowerGSInfo("__pseudo_scatter32_i64", "__scatter32_i64", false, false),
4387         LowerGSInfo("__pseudo_scatter32_double", "__scatter32_double", false, false),
4388 
4389         LowerGSInfo("__pseudo_scatter64_i8", "__scatter64_i8", false, false),
4390         LowerGSInfo("__pseudo_scatter64_i16", "__scatter64_i16", false, false),
4391         LowerGSInfo("__pseudo_scatter64_i32", "__scatter64_i32", false, false),
4392         LowerGSInfo("__pseudo_scatter64_float", "__scatter64_float", false, false),
4393         LowerGSInfo("__pseudo_scatter64_i64", "__scatter64_i64", false, false),
4394         LowerGSInfo("__pseudo_scatter64_double", "__scatter64_double", false, false),
4395 
4396         LowerGSInfo("__pseudo_scatter_factored_base_offsets32_i8", "__scatter_factored_base_offsets32_i8", false,
4397                     false),
4398         LowerGSInfo("__pseudo_scatter_factored_base_offsets32_i16", "__scatter_factored_base_offsets32_i16", false,
4399                     false),
4400         LowerGSInfo("__pseudo_scatter_factored_base_offsets32_i32", "__scatter_factored_base_offsets32_i32", false,
4401                     false),
4402         LowerGSInfo("__pseudo_scatter_factored_base_offsets32_float", "__scatter_factored_base_offsets32_float", false,
4403                     false),
4404         LowerGSInfo("__pseudo_scatter_factored_base_offsets32_i64", "__scatter_factored_base_offsets32_i64", false,
4405                     false),
4406         LowerGSInfo("__pseudo_scatter_factored_base_offsets32_double", "__scatter_factored_base_offsets32_double",
4407                     false, false),
4408 
4409         LowerGSInfo("__pseudo_scatter_factored_base_offsets64_i8", "__scatter_factored_base_offsets64_i8", false,
4410                     false),
4411         LowerGSInfo("__pseudo_scatter_factored_base_offsets64_i16", "__scatter_factored_base_offsets64_i16", false,
4412                     false),
4413         LowerGSInfo("__pseudo_scatter_factored_base_offsets64_i32", "__scatter_factored_base_offsets64_i32", false,
4414                     false),
4415         LowerGSInfo("__pseudo_scatter_factored_base_offsets64_float", "__scatter_factored_base_offsets64_float", false,
4416                     false),
4417         LowerGSInfo("__pseudo_scatter_factored_base_offsets64_i64", "__scatter_factored_base_offsets64_i64", false,
4418                     false),
4419         LowerGSInfo("__pseudo_scatter_factored_base_offsets64_double", "__scatter_factored_base_offsets64_double",
4420                     false, false),
4421 
4422         LowerGSInfo("__pseudo_scatter_base_offsets32_i8", "__scatter_base_offsets32_i8", false, false),
4423         LowerGSInfo("__pseudo_scatter_base_offsets32_i16", "__scatter_base_offsets32_i16", false, false),
4424         LowerGSInfo("__pseudo_scatter_base_offsets32_i32", "__scatter_base_offsets32_i32", false, false),
4425         LowerGSInfo("__pseudo_scatter_base_offsets32_float", "__scatter_base_offsets32_float", false, false),
4426         LowerGSInfo("__pseudo_scatter_base_offsets32_i64", "__scatter_base_offsets32_i64", false, false),
4427         LowerGSInfo("__pseudo_scatter_base_offsets32_double", "__scatter_base_offsets32_double", false, false),
4428 
4429         LowerGSInfo("__pseudo_scatter_base_offsets64_i8", "__scatter_base_offsets64_i8", false, false),
4430         LowerGSInfo("__pseudo_scatter_base_offsets64_i16", "__scatter_base_offsets64_i16", false, false),
4431         LowerGSInfo("__pseudo_scatter_base_offsets64_i32", "__scatter_base_offsets64_i32", false, false),
4432         LowerGSInfo("__pseudo_scatter_base_offsets64_float", "__scatter_base_offsets64_float", false, false),
4433         LowerGSInfo("__pseudo_scatter_base_offsets64_i64", "__scatter_base_offsets64_i64", false, false),
4434         LowerGSInfo("__pseudo_scatter_base_offsets64_double", "__scatter_base_offsets64_double", false, false),
4435 
4436         LowerGSInfo("__pseudo_prefetch_read_varying_1", "__prefetch_read_varying_1", false, true),
4437         LowerGSInfo("__pseudo_prefetch_read_varying_1_native", "__prefetch_read_varying_1_native", false, true),
4438 
4439         LowerGSInfo("__pseudo_prefetch_read_varying_2", "__prefetch_read_varying_2", false, true),
4440         LowerGSInfo("__pseudo_prefetch_read_varying_2_native", "__prefetch_read_varying_2_native", false, true),
4441 
4442         LowerGSInfo("__pseudo_prefetch_read_varying_3", "__prefetch_read_varying_3", false, true),
4443         LowerGSInfo("__pseudo_prefetch_read_varying_3_native", "__prefetch_read_varying_3_native", false, true),
4444 
4445         LowerGSInfo("__pseudo_prefetch_read_varying_nt", "__prefetch_read_varying_nt", false, true),
4446         LowerGSInfo("__pseudo_prefetch_read_varying_nt_native", "__prefetch_read_varying_nt_native", false, true),
4447     };
4448 
4449     llvm::Function *calledFunc = callInst->getCalledFunction();
4450 
4451     LowerGSInfo *info = NULL;
4452     for (unsigned int i = 0; i < sizeof(lgsInfo) / sizeof(lgsInfo[0]); ++i) {
4453         if (lgsInfo[i].pseudoFunc != NULL && calledFunc == lgsInfo[i].pseudoFunc) {
4454             info = &lgsInfo[i];
4455             break;
4456         }
4457     }
4458     if (info == NULL)
4459         return false;
4460 
4461     Assert(info->actualFunc != NULL);
4462 
4463     // Get the source position from the metadata attached to the call
4464     // instruction so that we can issue PerformanceWarning()s below.
4465     SourcePos pos;
4466     bool gotPosition = lGetSourcePosFromMetadata(callInst, &pos);
4467 
4468     callInst->setCalledFunction(info->actualFunc);
4469     // Check for alloca; if it is not an alloca, generate __gather and change the arguments
4470     if (gotPosition && (g->target->getVectorWidth() > 1) && (g->opt.level > 0)) {
4471         if (info->isGather)
4472             PerformanceWarning(pos, "Gather required to load value.");
4473         else if (!info->isPrefetch)
4474             PerformanceWarning(pos, "Scatter required to store value.");
4475     }
4476     return true;
4477 }
4478 
4479 bool ReplacePseudoMemoryOpsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
4480     DEBUG_START_PASS("ReplacePseudoMemoryOpsPass");
4481 
4482     bool modifiedAny = false;
4483 
4484 restart:
4485     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
4486         llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
4487         if (callInst == NULL || callInst->getCalledFunction() == NULL)
4488             continue;
4489 
4490         if (lReplacePseudoGS(callInst)) {
4491             modifiedAny = true;
4492             goto restart;
4493         } else if (lReplacePseudoMaskedStore(callInst)) {
4494             modifiedAny = true;
4495             goto restart;
4496         }
4497     }
4498 
4499     DEBUG_END_PASS("ReplacePseudoMemoryOpsPass");
4500 
4501     return modifiedAny;
4502 }
4503 
4504 bool ReplacePseudoMemoryOpsPass::runOnFunction(llvm::Function &F) {
4505 
4506     llvm::TimeTraceScope FuncScope("ReplacePseudoMemoryOpsPass::runOnFunction", F.getName());
4507     bool modifiedAny = false;
4508     for (llvm::BasicBlock &BB : F) {
4509         modifiedAny |= runOnBasicBlock(BB);
4510     }
4511     return modifiedAny;
4512 }
4513 
4514 static llvm::Pass *CreateReplacePseudoMemoryOpsPass() { return new ReplacePseudoMemoryOpsPass; }
4515 
4516 ///////////////////////////////////////////////////////////////////////////
4517 // IsCompileTimeConstantPass
4518 
4519 /** LLVM IR implementations of target-specific functions may include calls
4520     to the functions "bool __is_compile_time_constant_*(...)"; these allow
4521     them to have specialized code paths for the case where the corresponding
4522     value is known at compile time.  For masks, for example, this lets them
4523     avoid the cost of a MOVMSK call at runtime to compute the mask's value
4524     in cases where the mask value is in fact known at compile time.
4525 
4526     This pass resolves these calls into either 'true' or 'false' values so
4527     that later optimization passes can operate with these as constants.
4528 
4529     See stdlib.m4 for a number of uses of this idiom.
4530  */
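// A sketch of the idiom this pass resolves (the real uses live in stdlib.m4
// and the target builtins, not here; the shape below is illustrative):
//
//     if (__is_compile_time_constant_mask(mask)) {
//         // specialized path: the mask is a known constant (e.g. "all on")
//     } else {
//         // general path: compute the mask value at runtime (e.g. via MOVMSK)
//     }
//
// Replacing the call with 'true' lets later passes fold away the general
// branch; replacing it with 'false' on the last try removes the specialized one.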
4531 
4532 class IsCompileTimeConstantPass : public llvm::FunctionPass {
4533   public:
4534     static char ID;
4535     IsCompileTimeConstantPass(bool last = false) : FunctionPass(ID) { isLastTry = last; }
4536 
4537     llvm::StringRef getPassName() const { return "Resolve \"is compile time constant\""; }
4538     bool runOnBasicBlock(llvm::BasicBlock &BB);
4539     bool runOnFunction(llvm::Function &F);
4540 
4541     bool isLastTry;
4542 };
4543 
4544 char IsCompileTimeConstantPass::ID = 0;
4545 
4546 bool IsCompileTimeConstantPass::runOnBasicBlock(llvm::BasicBlock &bb) {
4547     DEBUG_START_PASS("IsCompileTimeConstantPass");
4548 
4549     llvm::Function *funcs[] = {m->module->getFunction("__is_compile_time_constant_mask"),
4550                                m->module->getFunction("__is_compile_time_constant_uniform_int32"),
4551                                m->module->getFunction("__is_compile_time_constant_varying_int32")};
4552 
4553     bool modifiedAny = false;
4554 restart:
4555     for (llvm::BasicBlock::iterator i = bb.begin(), e = bb.end(); i != e; ++i) {
4556         // Iterate through the instructions looking for calls to the
4557         // __is_compile_time_constant_*() functions
4558         llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*i);
4559         if (callInst == NULL)
4560             continue;
4561 
4562         int j;
4563         int nFuncs = sizeof(funcs) / sizeof(funcs[0]);
4564         for (j = 0; j < nFuncs; ++j) {
4565             if (funcs[j] != NULL && callInst->getCalledFunction() == funcs[j])
4566                 break;
4567         }
4568         if (j == nFuncs)
4569             // not a __is_compile_time_constant_* function
4570             continue;
4571 
4572         // This optimization pass can be disabled with either the (poorly
4573         // named) disableGatherScatterFlattening option or
4574         // disableMaskAllOnOptimizations.
4575         if (g->opt.disableGatherScatterFlattening || g->opt.disableMaskAllOnOptimizations) {
4576             llvm::ReplaceInstWithValue(i->getParent()->getInstList(), i, LLVMFalse);
4577             modifiedAny = true;
4578             goto restart;
4579         }
4580 
4581         // Is it a constant?  Bingo, turn the call's value into a constant
4582         // true value.
4583         llvm::Value *operand = callInst->getArgOperand(0);
4584         if (llvm::isa<llvm::Constant>(operand)) {
4585             llvm::ReplaceInstWithValue(i->getParent()->getInstList(), i, LLVMTrue);
4586             modifiedAny = true;
4587             goto restart;
4588         }
4589 
4590         // This pass runs multiple times during optimization.  Up until the
4591         // very last time, it only replaces the call with a 'true' if the
4592         // value is known to be constant and otherwise leaves the call
4593         // alone, in case further optimization passes can help resolve its
4594         // value.  The last time through, it eventually has to give up, and
4595         // replaces any remaining ones with 'false' constants.
4596         if (isLastTry) {
4597             llvm::ReplaceInstWithValue(i->getParent()->getInstList(), i, LLVMFalse);
4598             modifiedAny = true;
4599             goto restart;
4600         }
4601     }
4602 
4603     DEBUG_END_PASS("IsCompileTimeConstantPass");
4604 
4605     return modifiedAny;
4606 }
4607 
4608 bool IsCompileTimeConstantPass::runOnFunction(llvm::Function &F) {
4609 
4610     llvm::TimeTraceScope FuncScope("IsCompileTimeConstantPass::runOnFunction", F.getName());
4611     bool modifiedAny = false;
4612     for (llvm::BasicBlock &BB : F) {
4613         modifiedAny |= runOnBasicBlock(BB);
4614     }
4615     return modifiedAny;
4616 }
4617 
4618 static llvm::Pass *CreateIsCompileTimeConstantPass(bool isLastTry) { return new IsCompileTimeConstantPass(isLastTry); }
4619 
4620 //////////////////////////////////////////////////////////////////////////
4621 // DebugPass
4622 
4623 /** This pass is inserted into the list of passes after optimizations that
4624     we want to debug; it dumps the LLVM IR to stderr and prints the name
4625     and number of the preceding optimization pass.
4626  */
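// A minimal usage sketch (the banner text and pass-manager variable below are
// hypothetical, not the actual call site):
//
//     char banner[100];
//     snprintf(banner, sizeof(banner), "\n*** LLVM IR after pass %d ***\n", passNum);
//     optPM.add(CreateDebugPass(banner));
//
// runOnModule() then prints the banner to stderr; the module dump itself is
// currently commented out below.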
4627 #ifndef ISPC_NO_DUMPS
4628 class DebugPass : public llvm::ModulePass {
4629   public:
4630     static char ID;
4631     DebugPass(char *output) : ModulePass(ID) { snprintf(str_output, sizeof(str_output), "%s", output); }
4632 
4633     llvm::StringRef getPassName() const { return "Dump LLVM IR"; }
4634     bool runOnModule(llvm::Module &m);
4635 
4636   private:
4637     char str_output[100];
4638 };
4639 
4640 char DebugPass::ID = 0;
4641 
4642 bool DebugPass::runOnModule(llvm::Module &module) {
4643     fprintf(stderr, "%s", str_output);
4644     fflush(stderr);
4645     //module.dump();
4646     return true;
4647 }
4648 
4649 static llvm::Pass *CreateDebugPass(char *output) { return new DebugPass(output); }
4650 #endif
4651 
4652 //////////////////////////////////////////////////////////////////////////
4653 // DebugPassFile
4654 
4655 /** This pass is inserted into the list of passes after optimizations that
4656     we want to debug; it dumps the LLVM IR to a file.
4657  */
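// The dump is written to "init_<passNumber>_<sanitizedPassName>.ll" (from
// doInitialization) or "ir_<passNumber>_<sanitizedPassName>.ll" (from
// runOnModule); for example, a hypothetical pass number 7 named "GVN" would
// produce "ir_7_GVN.ll".  See run() below for the exact formatting.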
4658 #ifndef ISPC_NO_DUMPS
4659 class DebugPassFile : public llvm::ModulePass {
4660   public:
4661     static char ID;
4662     DebugPassFile(int number, llvm::StringRef name) : ModulePass(ID), pnum(number), pname(name) {}
4663 
4664     llvm::StringRef getPassName() const { return "Dump LLVM IR"; }
4665     bool runOnModule(llvm::Module &m);
4666     bool doInitialization(llvm::Module &m);
4667 
4668   private:
4669     void run(llvm::Module &m, bool init);
4670     int pnum;
4671     llvm::StringRef pname;
4672 };
4673 
4674 char DebugPassFile::ID = 0;
4675 
4676 /**
4677  * Strips all non-alphanumeric characters from the given string.
4678  */
4679 std::string sanitize(std::string in) {
4680     llvm::Regex r("[^[:alnum:]]");
4681     while (r.match(in))
4682         in = r.sub("", in);
4683     return in;
4684 }
4685 
4686 void DebugPassFile::run(llvm::Module &module, bool init) {
4687     std::error_code EC;
4688     char fname[100];
4689     snprintf(fname, sizeof(fname), "%s_%d_%s.ll", init ? "init" : "ir", pnum, sanitize(std::string(pname)).c_str());
4690     llvm::raw_fd_ostream OS(fname, EC, llvm::sys::fs::F_None);
4691     Assert(!EC && "IR dump file creation failed!");
4692     module.print(OS, 0);
4693 }
4694 
4695 bool DebugPassFile::runOnModule(llvm::Module &module) {
4696     run(module, false);
4697     return true;
4698 }
4699 
4700 bool DebugPassFile::doInitialization(llvm::Module &module) {
4701     run(module, true);
4702     return true;
4703 }
4704 
4705 static llvm::Pass *CreateDebugPassFile(int number, llvm::StringRef name) { return new DebugPassFile(number, name); }
4706 #endif
4707 
4708 ///////////////////////////////////////////////////////////////////////////
4709 // MakeInternalFuncsStaticPass
4710 
4711 /** There are a number of target-specific functions that we use during
4712     these optimization passes.  By the time we are done with optimization,
4713     any uses of these should be inlined and no calls to these functions
4714     should remain.  This pass marks all of these functions as having
4715     private linkage so that subsequent passes can eliminate them as dead
4716     code, thus cleaning up the final code output by the compiler.  We can't
4717     just declare these as static from the start, however, since then they
4718     end up being eliminated as dead code during early optimization passes
4719     even though we may need to generate calls to them during later
4720     optimization passes.
4721  */
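// In IR terms the change is roughly the following, for each listed function
// that has a definition (a sketch, not literal compiler output):
//
//   before:  define <WIDTH x float> @__masked_load_float(...) { ... }
//   after:   define internal <WIDTH x float> @__masked_load_float(...) { ... }
//
// With internal linkage, later dead-code elimination can drop any of these
// functions that remain uncalled after inlining.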
4722 class MakeInternalFuncsStaticPass : public llvm::ModulePass {
4723   public:
4724     static char ID;
4725     MakeInternalFuncsStaticPass(bool last = false) : ModulePass(ID) {}
4726 
4727     void getAnalysisUsage(llvm::AnalysisUsage &AU) const { AU.setPreservesCFG(); }
4728 
4729     llvm::StringRef getPassName() const { return "Make internal funcs \"static\""; }
4730     bool runOnModule(llvm::Module &m);
4731 };
4732 
4733 char MakeInternalFuncsStaticPass::ID = 0;
4734 
4735 bool MakeInternalFuncsStaticPass::runOnModule(llvm::Module &module) {
4736     const char *names[] = {
4737         "__avg_up_uint8",
4738         "__avg_up_int8",
4739         "__avg_up_uint16",
4740         "__avg_up_int16",
4741         "__avg_down_uint8",
4742         "__avg_down_int8",
4743         "__avg_down_uint16",
4744         "__avg_down_int16",
4745         "__fast_masked_vload",
4746         "__gather_factored_base_offsets32_i8",
4747         "__gather_factored_base_offsets32_i16",
4748         "__gather_factored_base_offsets32_i32",
4749         "__gather_factored_base_offsets32_i64",
4750         "__gather_factored_base_offsets32_float",
4751         "__gather_factored_base_offsets32_double",
4752         "__gather_factored_base_offsets64_i8",
4753         "__gather_factored_base_offsets64_i16",
4754         "__gather_factored_base_offsets64_i32",
4755         "__gather_factored_base_offsets64_i64",
4756         "__gather_factored_base_offsets64_float",
4757         "__gather_factored_base_offsets64_double",
4758         "__gather_base_offsets32_i8",
4759         "__gather_base_offsets32_i16",
4760         "__gather_base_offsets32_i32",
4761         "__gather_base_offsets32_i64",
4762         "__gather_base_offsets32_float",
4763         "__gather_base_offsets32_double",
4764         "__gather_base_offsets64_i8",
4765         "__gather_base_offsets64_i16",
4766         "__gather_base_offsets64_i32",
4767         "__gather_base_offsets64_i64",
4768         "__gather_base_offsets64_float",
4769         "__gather_base_offsets64_double",
4770         "__gather32_i8",
4771         "__gather32_i16",
4772         "__gather32_i32",
4773         "__gather32_i64",
4774         "__gather32_float",
4775         "__gather32_double",
4776         "__gather64_i8",
4777         "__gather64_i16",
4778         "__gather64_i32",
4779         "__gather64_i64",
4780         "__gather64_float",
4781         "__gather64_double",
4782         "__gather_elt32_i8",
4783         "__gather_elt32_i16",
4784         "__gather_elt32_i32",
4785         "__gather_elt32_i64",
4786         "__gather_elt32_float",
4787         "__gather_elt32_double",
4788         "__gather_elt64_i8",
4789         "__gather_elt64_i16",
4790         "__gather_elt64_i32",
4791         "__gather_elt64_i64",
4792         "__gather_elt64_float",
4793         "__gather_elt64_double",
4794         "__masked_load_i8",
4795         "__masked_load_i16",
4796         "__masked_load_i32",
4797         "__masked_load_i64",
4798         "__masked_load_float",
4799         "__masked_load_double",
4800         "__masked_store_i8",
4801         "__masked_store_i16",
4802         "__masked_store_i32",
4803         "__masked_store_i64",
4804         "__masked_store_float",
4805         "__masked_store_double",
4806         "__masked_store_blend_i8",
4807         "__masked_store_blend_i16",
4808         "__masked_store_blend_i32",
4809         "__masked_store_blend_i64",
4810         "__masked_store_blend_float",
4811         "__masked_store_blend_double",
4812         "__scatter_factored_base_offsets32_i8",
4813         "__scatter_factored_base_offsets32_i16",
4814         "__scatter_factored_base_offsets32_i32",
4815         "__scatter_factored_base_offsets32_i64",
4816         "__scatter_factored_base_offsets32_float",
4817         "__scatter_factored_base_offsets32_double",
4818         "__scatter_factored_base_offsets64_i8",
4819         "__scatter_factored_base_offsets64_i16",
4820         "__scatter_factored_base_offsets64_i32",
4821         "__scatter_factored_base_offsets64_i64",
4822         "__scatter_factored_base_offsets64_float",
4823         "__scatter_factored_base_offsets64_double",
4824         "__scatter_base_offsets32_i8",
4825         "__scatter_base_offsets32_i16",
4826         "__scatter_base_offsets32_i32",
4827         "__scatter_base_offsets32_i64",
4828         "__scatter_base_offsets32_float",
4829         "__scatter_base_offsets32_double",
4830         "__scatter_base_offsets64_i8",
4831         "__scatter_base_offsets64_i16",
4832         "__scatter_base_offsets64_i32",
4833         "__scatter_base_offsets64_i64",
4834         "__scatter_base_offsets64_float",
4835         "__scatter_base_offsets64_double",
4836         "__scatter_elt32_i8",
4837         "__scatter_elt32_i16",
4838         "__scatter_elt32_i32",
4839         "__scatter_elt32_i64",
4840         "__scatter_elt32_float",
4841         "__scatter_elt32_double",
4842         "__scatter_elt64_i8",
4843         "__scatter_elt64_i16",
4844         "__scatter_elt64_i32",
4845         "__scatter_elt64_i64",
4846         "__scatter_elt64_float",
4847         "__scatter_elt64_double",
4848         "__scatter32_i8",
4849         "__scatter32_i16",
4850         "__scatter32_i32",
4851         "__scatter32_i64",
4852         "__scatter32_float",
4853         "__scatter32_double",
4854         "__scatter64_i8",
4855         "__scatter64_i16",
4856         "__scatter64_i32",
4857         "__scatter64_i64",
4858         "__scatter64_float",
4859         "__scatter64_double",
4860         "__prefetch_read_varying_1",
4861         "__prefetch_read_varying_2",
4862         "__prefetch_read_varying_3",
4863         "__prefetch_read_varying_nt",
4864         "__keep_funcs_live",
4865 #ifdef ISPC_GENX_ENABLED
4866         "__masked_load_blend_i8",
4867         "__masked_load_blend_i16",
4868         "__masked_load_blend_i32",
4869         "__masked_load_blend_i64",
4870         "__masked_load_blend_float",
4871         "__masked_load_blend_double",
4872 #endif
4873     };
4874 
4875     bool modifiedAny = false;
4876     int count = sizeof(names) / sizeof(names[0]);
4877     for (int i = 0; i < count; ++i) {
4878         llvm::Function *f = m->module->getFunction(names[i]);
4879         if (f != NULL && f->empty() == false) {
4880             f->setLinkage(llvm::GlobalValue::InternalLinkage);
4881             modifiedAny = true;
4882         }
4883     }
4884 
4885     return modifiedAny;
4886 }
4887 
4888 static llvm::Pass *CreateMakeInternalFuncsStaticPass() { return new MakeInternalFuncsStaticPass; }
4889 
4890 ///////////////////////////////////////////////////////////////////////////
4891 // PeepholePass
4892 
4893 class PeepholePass : public llvm::FunctionPass {
4894   public:
4895     PeepholePass();
4896 
4897     llvm::StringRef getPassName() const { return "Peephole Optimizations"; }
4898     bool runOnBasicBlock(llvm::BasicBlock &BB);
4899     bool runOnFunction(llvm::Function &F);
4900 
4901     static char ID;
4902 };
4903 
4904 char PeepholePass::ID = 0;
4905 
4906 PeepholePass::PeepholePass() : FunctionPass(ID) {}
4907 
4908 using namespace llvm::PatternMatch;
4909 
4910 template <typename Op_t, unsigned Opcode> struct CastClassTypes_match {
4911     Op_t Op;
4912     const llvm::Type *fromType, *toType;
4913 
4914     CastClassTypes_match(const Op_t &OpMatch, const llvm::Type *f, const llvm::Type *t)
4915         : Op(OpMatch), fromType(f), toType(t) {}
4916 
4917     template <typename OpTy> bool match(OpTy *V) {
4918         if (llvm::Operator *O = llvm::dyn_cast<llvm::Operator>(V))
4919             return (O->getOpcode() == Opcode && Op.match(O->getOperand(0)) && O->getType() == toType &&
4920                     O->getOperand(0)->getType() == fromType);
4921         return false;
4922     }
4923 };
4924 
4925 template <typename OpTy> inline CastClassTypes_match<OpTy, llvm::Instruction::SExt> m_SExt8To16(const OpTy &Op) {
4926     return CastClassTypes_match<OpTy, llvm::Instruction::SExt>(Op, LLVMTypes::Int8VectorType,
4927                                                                LLVMTypes::Int16VectorType);
4928 }
4929 
4930 template <typename OpTy> inline CastClassTypes_match<OpTy, llvm::Instruction::ZExt> m_ZExt8To16(const OpTy &Op) {
4931     return CastClassTypes_match<OpTy, llvm::Instruction::ZExt>(Op, LLVMTypes::Int8VectorType,
4932                                                                LLVMTypes::Int16VectorType);
4933 }
4934 
4935 template <typename OpTy> inline CastClassTypes_match<OpTy, llvm::Instruction::Trunc> m_Trunc16To8(const OpTy &Op) {
4936     return CastClassTypes_match<OpTy, llvm::Instruction::Trunc>(Op, LLVMTypes::Int16VectorType,
4937                                                                 LLVMTypes::Int8VectorType);
4938 }
4939 
4940 template <typename OpTy> inline CastClassTypes_match<OpTy, llvm::Instruction::SExt> m_SExt16To32(const OpTy &Op) {
4941     return CastClassTypes_match<OpTy, llvm::Instruction::SExt>(Op, LLVMTypes::Int16VectorType,
4942                                                                LLVMTypes::Int32VectorType);
4943 }
4944 
4945 template <typename OpTy> inline CastClassTypes_match<OpTy, llvm::Instruction::ZExt> m_ZExt16To32(const OpTy &Op) {
4946     return CastClassTypes_match<OpTy, llvm::Instruction::ZExt>(Op, LLVMTypes::Int16VectorType,
4947                                                                LLVMTypes::Int32VectorType);
4948 }
4949 
4950 template <typename OpTy> inline CastClassTypes_match<OpTy, llvm::Instruction::Trunc> m_Trunc32To16(const OpTy &Op) {
4951     return CastClassTypes_match<OpTy, llvm::Instruction::Trunc>(Op, LLVMTypes::Int32VectorType,
4952                                                                 LLVMTypes::Int16VectorType);
4953 }
4954 
4955 template <typename Op_t> struct UDiv2_match {
4956     Op_t Op;
4957 
4958     UDiv2_match(const Op_t &OpMatch) : Op(OpMatch) {}
4959 
4960     template <typename OpTy> bool match(OpTy *V) {
4961         llvm::BinaryOperator *bop;
4962         llvm::ConstantDataVector *cdv;
4963         if ((bop = llvm::dyn_cast<llvm::BinaryOperator>(V)) &&
4964             (cdv = llvm::dyn_cast<llvm::ConstantDataVector>(bop->getOperand(1))) && cdv->getSplatValue() != NULL) {
4965             const llvm::APInt &apInt = cdv->getUniqueInteger();
4966 
4967             switch (bop->getOpcode()) {
4968             case llvm::Instruction::UDiv:
4969                 // divide by 2
4970                 return (apInt.isIntN(2) && Op.match(bop->getOperand(0)));
4971             case llvm::Instruction::LShr:
4972                 // logical shift right by 1 (i.e., divide by 2)
4973                 return (apInt.isIntN(1) && Op.match(bop->getOperand(0)));
4974             default:
4975                 return false;
4976             }
4977         }
4978         return false;
4979     }
4980 };
4981 
4982 template <typename V> inline UDiv2_match<V> m_UDiv2(const V &v) { return UDiv2_match<V>(v); }
4983 
4984 template <typename Op_t> struct SDiv2_match {
4985     Op_t Op;
4986 
4987     SDiv2_match(const Op_t &OpMatch) : Op(OpMatch) {}
4988 
4989     template <typename OpTy> bool match(OpTy *V) {
4990         llvm::BinaryOperator *bop;
4991         llvm::ConstantDataVector *cdv;
4992         if ((bop = llvm::dyn_cast<llvm::BinaryOperator>(V)) &&
4993             (cdv = llvm::dyn_cast<llvm::ConstantDataVector>(bop->getOperand(1))) && cdv->getSplatValue() != NULL) {
4994             const llvm::APInt &apInt = cdv->getUniqueInteger();
4995 
4996             switch (bop->getOpcode()) {
4997             case llvm::Instruction::SDiv:
4998                 // divide by 2
4999                 return (apInt.isIntN(2) && Op.match(bop->getOperand(0)));
5000             case llvm::Instruction::AShr:
5001                 // shift left by 1
5002                 return (apInt.isIntN(1) && Op.match(bop->getOperand(0)));
5003             default:
5004                 return false;
5005             }
5006         }
5007         return false;
5008     }
5009 };
5010 
5011 template <typename V> inline SDiv2_match<V> m_SDiv2(const V &v) { return SDiv2_match<V>(v); }
5012 
5013 // Returns true if the given function has a call to an intrinsic function
5014 // in its definition.
5015 static bool lHasIntrinsicInDefinition(llvm::Function *func) {
5016     llvm::Function::iterator bbiter = func->begin();
5017     for (; bbiter != func->end(); ++bbiter) {
5018         for (llvm::BasicBlock::iterator institer = bbiter->begin(); institer != bbiter->end(); ++institer) {
5019             if (llvm::isa<llvm::IntrinsicInst>(institer))
5020                 return true;
5021         }
5022     }
5023     return false;
5024 }
5025 
5026 static llvm::Instruction *lGetBinaryIntrinsic(const char *name, llvm::Value *opa, llvm::Value *opb) {
5027     llvm::Function *func = m->module->getFunction(name);
5028     Assert(func != NULL);
5029 
5030     // Make sure that the definition of the llvm::Function has a call to an
5031     // intrinsic function in its instructions; otherwise we will generate
5032     // infinite loops where we "helpfully" turn the default implementations
5033     // of target builtins like __avg_up_uint8 that are implemented with plain
5034     // arithmetic ops into recursive calls to themselves.
5035     if (lHasIntrinsicInDefinition(func))
5036         return lCallInst(func, opa, opb, name);
5037     else
5038         return NULL;
5039 }
5040 
5041 //////////////////////////////////////////////////
5042 
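// The lMatchAvg* helpers below recognize IR that computes an average in a
// wider integer type and then truncates.  For the unsigned 8-bit "round up"
// case, the matched chain looks roughly like this (a sketch; <N> is the
// vector width):
//
//     %wa  = zext <N x i8> %a to <N x i16>
//     %wb  = zext <N x i8> %b to <N x i16>
//     %sum = add <N x i16> %wa, %wb
//     %inc = add <N x i16> %sum, <i16 1, ...>
//     %div = udiv <N x i16> %inc, <i16 2, ...>     ; or lshr by 1
//     %r   = trunc <N x i16> %div to <N x i8>
//
// The whole chain is then replaced with a single call to the corresponding
// __avg_up_* / __avg_down_* builtin.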
5043 static llvm::Instruction *lMatchAvgUpUInt8(llvm::Value *inst) {
5044     // (unsigned int8)(((unsigned int16)a + (unsigned int16)b + 1)/2)
5045     llvm::Value *opa, *opb;
5046     const llvm::APInt *delta;
5047     if (match(inst, m_Trunc16To8(m_UDiv2(m_CombineOr(
5048                         m_CombineOr(m_Add(m_ZExt8To16(m_Value(opa)), m_Add(m_ZExt8To16(m_Value(opb)), m_APInt(delta))),
5049                                     m_Add(m_Add(m_ZExt8To16(m_Value(opa)), m_APInt(delta)), m_ZExt8To16(m_Value(opb)))),
5050                         m_Add(m_Add(m_ZExt8To16(m_Value(opa)), m_ZExt8To16(m_Value(opb))), m_APInt(delta))))))) {
5051         if (delta->isIntN(1) == false)
5052             return NULL;
5053 
5054         return lGetBinaryIntrinsic("__avg_up_uint8", opa, opb);
5055     }
5056     return NULL;
5057 }
5058 
5059 static llvm::Instruction *lMatchAvgDownUInt8(llvm::Value *inst) {
5060     // (unsigned int8)(((unsigned int16)a + (unsigned int16)b)/2)
5061     llvm::Value *opa, *opb;
5062     if (match(inst, m_Trunc16To8(m_UDiv2(m_Add(m_ZExt8To16(m_Value(opa)), m_ZExt8To16(m_Value(opb))))))) {
5063         return lGetBinaryIntrinsic("__avg_down_uint8", opa, opb);
5064     }
5065     return NULL;
5066 }
5067 
5068 static llvm::Instruction *lMatchAvgUpUInt16(llvm::Value *inst) {
5069     // (unsigned int16)(((unsigned int32)a + (unsigned int32)b + 1)/2)
5070     llvm::Value *opa, *opb;
5071     const llvm::APInt *delta;
5072     if (match(inst,
5073               m_Trunc32To16(m_UDiv2(m_CombineOr(
5074                   m_CombineOr(m_Add(m_ZExt16To32(m_Value(opa)), m_Add(m_ZExt16To32(m_Value(opb)), m_APInt(delta))),
5075                               m_Add(m_Add(m_ZExt16To32(m_Value(opa)), m_APInt(delta)), m_ZExt16To32(m_Value(opb)))),
5076                   m_Add(m_Add(m_ZExt16To32(m_Value(opa)), m_ZExt16To32(m_Value(opb))), m_APInt(delta))))))) {
5077         if (delta->isIntN(1) == false)
5078             return NULL;
5079 
5080         return lGetBinaryIntrinsic("__avg_up_uint16", opa, opb);
5081     }
5082     return NULL;
5083 }
5084 
5085 static llvm::Instruction *lMatchAvgDownUInt16(llvm::Value *inst) {
5086     // (unsigned int16)(((unsigned int32)a + (unsigned int32)b)/2)
5087     llvm::Value *opa, *opb;
5088     if (match(inst, m_Trunc32To16(m_UDiv2(m_Add(m_ZExt16To32(m_Value(opa)), m_ZExt16To32(m_Value(opb))))))) {
5089         return lGetBinaryIntrinsic("__avg_down_uint16", opa, opb);
5090     }
5091     return NULL;
5092 }
5093 
5094 static llvm::Instruction *lMatchAvgUpInt8(llvm::Value *inst) {
5095     // (int8)(((int16)a + (int16)b + 1)/2)
5096     llvm::Value *opa, *opb;
5097     const llvm::APInt *delta;
5098     if (match(inst, m_Trunc16To8(m_SDiv2(m_CombineOr(
5099                         m_CombineOr(m_Add(m_SExt8To16(m_Value(opa)), m_Add(m_SExt8To16(m_Value(opb)), m_APInt(delta))),
5100                                     m_Add(m_Add(m_SExt8To16(m_Value(opa)), m_APInt(delta)), m_SExt8To16(m_Value(opb)))),
5101                         m_Add(m_Add(m_SExt8To16(m_Value(opa)), m_SExt8To16(m_Value(opb))), m_APInt(delta))))))) {
5102         if (delta->isIntN(1) == false)
5103             return NULL;
5104 
5105         return lGetBinaryIntrinsic("__avg_up_int8", opa, opb);
5106     }
5107     return NULL;
5108 }
5109 
5110 static llvm::Instruction *lMatchAvgDownInt8(llvm::Value *inst) {
5111     // (int8)(((int16)a + (int16)b)/2)
5112     llvm::Value *opa, *opb;
5113     if (match(inst, m_Trunc16To8(m_SDiv2(m_Add(m_SExt8To16(m_Value(opa)), m_SExt8To16(m_Value(opb))))))) {
5114         return lGetBinaryIntrinsic("__avg_down_int8", opa, opb);
5115     }
5116     return NULL;
5117 }
5118 
5119 static llvm::Instruction *lMatchAvgUpInt16(llvm::Value *inst) {
5120     // (int16)(((int32)a + (int32)b + 1)/2)
5121     llvm::Value *opa, *opb;
5122     const llvm::APInt *delta;
5123     if (match(inst,
5124               m_Trunc32To16(m_SDiv2(m_CombineOr(
5125                   m_CombineOr(m_Add(m_SExt16To32(m_Value(opa)), m_Add(m_SExt16To32(m_Value(opb)), m_APInt(delta))),
5126                               m_Add(m_Add(m_SExt16To32(m_Value(opa)), m_APInt(delta)), m_SExt16To32(m_Value(opb)))),
5127                   m_Add(m_Add(m_SExt16To32(m_Value(opa)), m_SExt16To32(m_Value(opb))), m_APInt(delta))))))) {
5128         if (delta->isIntN(1) == false)
5129             return NULL;
5130 
5131         return lGetBinaryIntrinsic("__avg_up_int16", opa, opb);
5132     }
5133     return NULL;
5134 }
5135 
5136 static llvm::Instruction *lMatchAvgDownInt16(llvm::Value *inst) {
5137     // (int16)(((int32)a + (int32)b)/2)
5138     llvm::Value *opa, *opb;
5139     if (match(inst, m_Trunc32To16(m_SDiv2(m_Add(m_SExt16To32(m_Value(opa)), m_SExt16To32(m_Value(opb))))))) {
5140         return lGetBinaryIntrinsic("__avg_down_int16", opa, opb);
5141     }
5142     return NULL;
5143 }
5144 
5145 bool PeepholePass::runOnBasicBlock(llvm::BasicBlock &bb) {
5146     DEBUG_START_PASS("PeepholePass");
5147 
5148     bool modifiedAny = false;
5149 restart:
5150     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
5151         llvm::Instruction *inst = &*iter;
5152 
5153         llvm::Instruction *builtinCall = lMatchAvgUpUInt8(inst);
5154         if (!builtinCall)
5155             builtinCall = lMatchAvgUpUInt16(inst);
5156         if (!builtinCall)
5157             builtinCall = lMatchAvgDownUInt8(inst);
5158         if (!builtinCall)
5159             builtinCall = lMatchAvgDownUInt16(inst);
5160         if (!builtinCall)
5161             builtinCall = lMatchAvgUpInt8(inst);
5162         if (!builtinCall)
5163             builtinCall = lMatchAvgUpInt16(inst);
5164         if (!builtinCall)
5165             builtinCall = lMatchAvgDownInt8(inst);
5166         if (!builtinCall)
5167             builtinCall = lMatchAvgDownInt16(inst);
5168         if (builtinCall != NULL) {
5169             llvm::ReplaceInstWithInst(inst, builtinCall);
5170             modifiedAny = true;
5171             goto restart;
5172         }
5173     }
5174 
5175     DEBUG_END_PASS("PeepholePass");
5176 
5177     return modifiedAny;
5178 }
5179 
5180 bool PeepholePass::runOnFunction(llvm::Function &F) {
5181 
5182     llvm::TimeTraceScope FuncScope("PeepholePass::runOnFunction", F.getName());
5183     bool modifiedAny = false;
5184     for (llvm::BasicBlock &BB : F) {
5185         modifiedAny |= runOnBasicBlock(BB);
5186     }
5187     return modifiedAny;
5188 }
5189 
5190 static llvm::Pass *CreatePeepholePass() { return new PeepholePass; }
5191 
5192 /** Given an llvm::Value known to be an integer, return its value as
5193     an int64_t.
5194 */
5195 static int64_t lGetIntValue(llvm::Value *offset) {
5196     llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
5197     Assert(intOffset && (intOffset->getBitWidth() == 32 || intOffset->getBitWidth() == 64));
5198     return intOffset->getSExtValue();
5199 }
5200 
5201 ///////////////////////////////////////////////////////////////////////////
5202 // ReplaceStdlibShiftPass
5203 
5204 class ReplaceStdlibShiftPass : public llvm::FunctionPass {
5205   public:
5206     static char ID;
5207     ReplaceStdlibShiftPass() : FunctionPass(ID) {}
5208 
5209     llvm::StringRef getPassName() const { return "Resolve \"replace extract insert chains\""; }
5210 
5211     bool runOnBasicBlock(llvm::BasicBlock &BB);
5212 
5213     bool runOnFunction(llvm::Function &F);
5214 };
5215 
5216 char ReplaceStdlibShiftPass::ID = 0;
5217 
5218 // This pass replaces shift() with a ShuffleVector when the shift amount is a constant.
5219 // rotate(), which is similar in functionality, has a slightly different
5220 // implementation.  This is because LLVM (createInstructionCombiningPass)
5221 // optimizes the rotate() implementation better when similar implementations
5222 // are used for both.  This is a hack to produce similarly well-optimized code
5223 // for shift().
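// For example (a sketch, assuming a 4-wide target), a call to shift() with a
// constant amount of 1 becomes:
//
//     shufflevector <4 x T> %v, <4 x T> zeroinitializer, <4 x i32> <1, 2, 3, 4>
//
// Indices clamped to vectorWidth select from the zero vector, so lanes shifted
// in from outside the vector are filled with zeros.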
5224 bool ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
5225     DEBUG_START_PASS("ReplaceStdlibShiftPass");
5226     bool modifiedAny = false;
5227 
5228     llvm::Function *shifts[6];
5229     shifts[0] = m->module->getFunction("shift___vytuni");
5230     shifts[1] = m->module->getFunction("shift___vysuni");
5231     shifts[2] = m->module->getFunction("shift___vyiuni");
5232     shifts[3] = m->module->getFunction("shift___vyIuni");
5233     shifts[4] = m->module->getFunction("shift___vyfuni");
5234     shifts[5] = m->module->getFunction("shift___vyduni");
5235 
5236     for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
5237         llvm::Instruction *inst = &*iter;
5238 
5239         if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst)) {
5240             llvm::Function *func = ci->getCalledFunction();
5241             for (int i = 0; i < 6; i++) {
5242                 if (shifts[i] && (shifts[i] == func)) {
5243                     // we matched a call
5244                     llvm::Value *shiftedVec = ci->getArgOperand(0);
5245                     llvm::Value *shiftAmt = ci->getArgOperand(1);
5246                     if (llvm::isa<llvm::Constant>(shiftAmt)) {
5247                         int vectorWidth = g->target->getVectorWidth();
5248                         int *shuffleVals = new int[vectorWidth];
5249                         int shiftInt = lGetIntValue(shiftAmt);
5250                         for (int i = 0; i < vectorWidth; i++) {
5251                             int s = i + shiftInt;
5252                             s = (s < 0) ? vectorWidth : s;
5253                             s = (s >= vectorWidth) ? vectorWidth : s;
5254                             shuffleVals[i] = s;
5255                         }
5256                         llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals);
5257                         llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType());
5258                         llvm::Value *shuffle =
5259                             new llvm::ShuffleVectorInst(shiftedVec, zeroVec, shuffleIdxs, "vecShift", ci);
5260                         ci->replaceAllUsesWith(shuffle);
5261                         modifiedAny = true;
5262                         delete[] shuffleVals;
5263                     } else if (g->opt.level > 0) {
5264                         PerformanceWarning(SourcePos(), "Stdlib shift() called without constant shift amount.");
5265                     }
5266                 }
5267             }
5268         }
5269     }
5270 
5271     DEBUG_END_PASS("ReplaceStdlibShiftPass");
5272 
5273     return modifiedAny;
5274 }
5275 
5276 bool ReplaceStdlibShiftPass::runOnFunction(llvm::Function &F) {
5277 
5278     llvm::TimeTraceScope FuncScope("ReplaceStdlibShiftPass::runOnFunction", F.getName());
5279     bool modifiedAny = false;
5280     for (llvm::BasicBlock &BB : F) {
5281         modifiedAny |= runOnBasicBlock(BB);
5282     }
5283     return modifiedAny;
5284 }
5285 
5286 static llvm::Pass *CreateReplaceStdlibShiftPass() { return new ReplaceStdlibShiftPass(); }
5287 
5288 #ifdef ISPC_GENX_ENABLED
5289 
5290 ///////////////////////////////////////////////////////////////////////////
5291 // GenXGatherCoalescingPass
5292 
5293 /** This pass performs gather coalescing for GenX target.
5294  */
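// A rough sketch of the transformation (see lVectorizeGEPs below for the
// details): several scalar loads and/or __pseudo_gather_base_offsets* calls
// whose addresses are constant offsets from a common base pointer, e.g.
//
//     %x = load float, float* %p0          ; p0 = base + 0
//     %y = load float, float* %p8          ; p8 = base + 8
//     %g = call ... @__pseudo_gather_base_offsets32_float(base, ...)
//
// are replaced by a single genx_svm_block_ld_unaligned covering the whole
// range, plus extractelement instructions feeding the original users.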
5295 class GenXGatherCoalescing : public llvm::FunctionPass {
5296   public:
5297     static char ID;
5298     GenXGatherCoalescing() : FunctionPass(ID) {}
5299 
5300     llvm::StringRef getPassName() const { return "GenX Gather Coalescing"; }
5301 
5302     bool runOnBasicBlock(llvm::BasicBlock &BB);
5303 
5304     bool runOnFunction(llvm::Function &Fn);
5305 };
5306 
5307 char GenXGatherCoalescing::ID = 0;
5308 
5309 // Returns pointer to pseudo_gather CallInst if inst is
5310 // actually a pseudo_gather
5311 static llvm::CallInst *lGetPseudoGather(llvm::Instruction *inst) {
5312     if (auto CI = llvm::dyn_cast<llvm::CallInst>(inst)) {
5313         if (CI->getCalledFunction()->getName().startswith("__pseudo_gather_base_offsets"))
5314             return CI;
5315     }
5316     return nullptr;
5317 }
5318 
5319 // Returns the common type of the GEP's users, or nullptr if the users
5320 // have differing types.
5321 //
5322 // TODO: currently only a few instructions are handled:
5323 //   - load (scalar, see TODO below)
5324 //   - bitcast (checking its users)
5325 //   - pseudo_gather: see lGetPseudoGather
5326 static llvm::Type *lGetLoadsType(llvm::Instruction *inst) {
5327     llvm::Type *ty = nullptr;
5328 
5329     // Bad user was provided
5330     if (!inst)
5331         return nullptr;
5332 
5333     for (auto ui = inst->use_begin(), ue = inst->use_end(); ui != ue; ++ui) {
5334 
5335         llvm::Instruction *user = llvm::dyn_cast<llvm::Instruction>(ui->getUser());
5336         llvm::Type *curr_type = nullptr;
5337 
5338         // Bitcast appeared, go through it
5339         if (auto BCI = llvm::dyn_cast<llvm::BitCastInst>(user)) {
5340             if (!(curr_type = lGetLoadsType(BCI)))
5341                 return nullptr;
5342         } else {
5343             // Load is OK if it is scalar
5344             // TODO: perform full investigation on situation
5345             // with non-scalar loads
5346             if (auto LI = llvm::dyn_cast<llvm::LoadInst>(user)) {
5347                 if (LI->getType()->isVectorTy())
5348                     return nullptr;
5349                 curr_type = LI->getType();
5350             } else if (lGetPseudoGather(user)) {
5351                 curr_type = user->getType();
5352             } else
5353                 continue;
5354         }
5355 
5356         if (!curr_type || (ty && ty != curr_type))
5357             return nullptr;
5358 
5359         ty = curr_type;
5360     }
5361 
5362     return ty;
5363 }
5364 
5365 // Struct that contains info about memory users
5366 struct PtrUse {
5367     llvm::Instruction *user; // Ptr
5368 
5369     std::vector<int64_t> idxs; // Users are represented in this list as if each access were Ptr[idx]
5370                                // TODO: current implementation depends on traverse order:
5371                                //       see lCreateBlockLDUse and lCollectIdxs
5372 
5373     llvm::Type *type; // Currently loads are grouped by types.
5374                       // TODO: it can be optimized
5375 
5376     int64_t &getScalarIdx() { return idxs[0]; }
5377 };
5378 
5379 // Create users for newly created block ld
5380 // TODO: current implementation depends on traverse order: see lCollectIdxs
5381 static void lCreateBlockLDUse(llvm::Instruction *currInst, std::vector<llvm::ExtractElementInst *> EEIs, PtrUse &ptrUse,
5382                               int &curr_idx, std::vector<llvm::Instruction *> &dead) {
5383     for (auto ui = currInst->use_begin(), ue = currInst->use_end(); ui != ue; ++ui) {
5384         if (auto BCI = llvm::dyn_cast<llvm::BitCastInst>(ui->getUser())) {
5385             // Go through bitcast
5386             lCreateBlockLDUse(BCI, EEIs, ptrUse, curr_idx, dead);
5387         } else if (auto LI = llvm::dyn_cast<llvm::LoadInst>(ui->getUser())) {
5388             // Only scalar loads are supported now
5389             llvm::ExtractElementInst *EEI = EEIs[ptrUse.idxs[curr_idx++]];
5390             LI->replaceAllUsesWith(EEI);
5391             lCopyMetadata(EEI, LI);
5392             dead.push_back(LI);
5393         } else if (auto gather = lGetPseudoGather(llvm::dyn_cast<llvm::Instruction>(ui->getUser()))) {
5394             // Collect idxs from gather and fix users
5395             llvm::Value *res = llvm::UndefValue::get(gather->getType());
5396 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
5397             for (unsigned i = 0, end = llvm::dyn_cast<llvm::FixedVectorType>(res->getType())->getNumElements(); i < end;
5398                  ++i)
5399 #else
5400             for (unsigned i = 0, end = llvm::dyn_cast<llvm::VectorType>(res->getType())->getNumElements(); i < end; ++i)
5401 #endif
5402             {
5403                 // Get element via IEI
5404                 res = llvm::InsertElementInst::Create(res, EEIs[ptrUse.idxs[curr_idx++]],
5405                                                       llvm::ConstantInt::get(LLVMTypes::Int32Type, i),
5406                                                       "coalesced_gather_iei", gather);
5407             }
5408             gather->replaceAllUsesWith(res);
5409             lCopyMetadata(res, gather);
5410             dead.push_back(gather);
5411         } else {
5412             continue;
5413         }
5414     }
5415 }
5416 
5417 // Perform the coalescing optimization for a base pointer and its collected GEP uses
5418 static bool lVectorizeGEPs(llvm::Value *ptr, std::vector<PtrUse> &ptrUses, std::vector<llvm::Instruction *> &dead) {
5419     // Calculate memory width for GEPs
5420     int64_t min_idx = INT64_MAX;
5421     int64_t max_idx = INT64_MIN;
5422     for (auto elem : ptrUses) {
5423         for (auto i : elem.idxs) {
5424             int64_t idx = i;
5425             max_idx = std::max(idx, max_idx);
5426             min_idx = std::min(idx, min_idx);
5427         }
5428     }
5429     llvm::Type *type = ptrUses[0].type;
5430     llvm::Type *scalar_type = type;
5431     llvm::Instruction *insertBefore = ptrUses[0].user;
5432 
5433     // Calculate element size in bytes
5434     uint64_t t_size = type->getPrimitiveSizeInBits() >> 3;
5435 
5436     // Adjust values for vector load
5437 #if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
5438     if (auto vecTy = llvm::dyn_cast<llvm::FixedVectorType>(type))
5439 #else
5440     if (auto vecTy = llvm::dyn_cast<llvm::VectorType>(type))
5441 #endif
5442     {
5443         // Get single element size
5444         t_size /= vecTy->getNumElements();
5445         // Get single element type
5446         scalar_type = vecTy->getScalarType();
5447     }
5448 
5449     // Pointer load should be done via inttoptr: replace type
5450     bool loadingPtr = false;
5451     llvm::Type *originalType = scalar_type;
5452     if (auto pTy = llvm::dyn_cast<llvm::PointerType>(scalar_type)) {
5453         scalar_type = g->target->is32Bit() ? LLVMTypes::Int32Type : LLVMTypes::Int64Type;
5454         t_size = scalar_type->getPrimitiveSizeInBits() >> 3;
5455         loadingPtr = true;
5456     }
5457 
5458     // Calculate length of array that needs to be loaded.
5459     // Idxs are in bytes now.
5460     uint64_t data_size = max_idx - min_idx + t_size;
5461 
5462     unsigned loads_needed = 1;
5463     uint64_t reqSize = 1;
5464 
5465     // Required load size: it can be 1, 2, 4, or 8 OWORDs
5466     while (reqSize < data_size)
5467         reqSize <<= 1;
5468 
5469     // Adjust number of loads.
5470     // TODO: it is not clear if we should adjust
5471     // load size here
5472     if (reqSize > 8 * OWORD) {
5473         loads_needed = reqSize / (8 * OWORD);
5474         reqSize = 8 * OWORD;
5475     }
5476 
5477     // TODO: it is not clear if we should skip it or not
5478     if (reqSize < OWORD) {
5479         // Skip it for now
5480         return false;
5481     }

    // Coalesce loads/gathers
    std::vector<llvm::ExtractElementInst *> EEIs;
    for (unsigned i = 0; i < loads_needed; ++i) {
        llvm::Constant *offset = llvm::ConstantInt::get(LLVMTypes::Int64Type, min_idx + i * reqSize);
        llvm::PtrToIntInst *ptrToInt =
            new llvm::PtrToIntInst(ptr, LLVMTypes::Int64Type, "vectorized_ptrtoint", insertBefore);
        llvm::Instruction *addr = llvm::BinaryOperator::CreateAdd(ptrToInt, offset, "vectorized_address", insertBefore);
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
        llvm::Type *retType = llvm::FixedVectorType::get(scalar_type, reqSize / t_size);
#else
        llvm::Type *retType = llvm::VectorType::get(scalar_type, reqSize / t_size);
#endif
        llvm::Function *fn = llvm::GenXIntrinsic::getGenXDeclaration(
            m->module, llvm::GenXIntrinsic::genx_svm_block_ld_unaligned, {retType, addr->getType()});
        llvm::Instruction *ld = llvm::CallInst::Create(fn, {addr}, "vectorized_ld", insertBefore);

        if (loadingPtr) {
            // Cast int to ptr via inttoptr
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
            ld = new llvm::IntToPtrInst(ld, llvm::FixedVectorType::get(originalType, reqSize / t_size),
                                        "vectorized_inttoptr", insertBefore);
#else
            ld = new llvm::IntToPtrInst(ld, llvm::VectorType::get(originalType, reqSize / t_size),
                                        "vectorized_inttoptr", insertBefore);
#endif
        }

        // Scalar extracts for all loaded elements
        for (unsigned j = 0; j < reqSize / t_size; ++j) {
            auto EEI = llvm::ExtractElementInst::Create(ld, llvm::ConstantInt::get(LLVMTypes::Int64Type, j),
                                                        "vectorized_extr_elem", ld->getParent());
            EEIs.push_back(EEI);
            EEI->moveAfter(ld);
        }
    }

    // Change value in users and delete redundant insts
    for (auto elem : ptrUses) {
        // Recalculate idx: we are going to use extractelement
        for (auto &index : elem.idxs) {
            index -= min_idx;
            index /= t_size;
        }
        int curr_idx = 0;
        lCreateBlockLDUse(elem.user, EEIs, elem, curr_idx, dead);
    }

    return true;
}

// Get idxs from current instruction users
// TODO: current implementation depends on traverse order: see lCreateBlockLDUse
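// A hypothetical illustration of what gets collected: with original_idx == 8,
// a scalar load user contributes just {8}, while a pseudo gather user with
// scale 4 and constant offsets <0, 1, 2, 3> contributes {8, 12, 16, 20}.
// All collected values are byte offsets relative to the common base pointer.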
static void lCollectIdxs(llvm::Instruction *inst, std::vector<int64_t> &idxs, int64_t original_idx) {
    for (auto ui = inst->use_begin(), ue = inst->use_end(); ui != ue; ++ui) {

        llvm::Instruction *user = llvm::dyn_cast<llvm::Instruction>(ui->getUser());

        if (auto BCI = llvm::dyn_cast<llvm::BitCastInst>(user)) {
            // Bitcast appeared, go through it
            lCollectIdxs(BCI, idxs, original_idx);
        } else {
            if (auto LI = llvm::dyn_cast<llvm::LoadInst>(user)) {
                // Only scalar loads are supported now. It should be checked earlier.
                Assert(!LI->getType()->isVectorTy());
                idxs.push_back(original_idx);
            } else if (auto gather = lGetPseudoGather(user)) {
                // Collect gather's idxs
                auto scale_c = llvm::dyn_cast<llvm::ConstantInt>(gather->getOperand(1));
                int64_t scale = scale_c->getSExtValue();
                auto offset = llvm::dyn_cast<llvm::ConstantDataVector>(gather->getOperand(2));

                for (unsigned i = 0, size = offset->getType()->getNumElements(); i < size; ++i) {
                    int64_t curr_offset =
                        llvm::dyn_cast<llvm::ConstantInt>(offset->getElementAsConstant(i))->getSExtValue() * scale +
                        original_idx;
                    idxs.push_back(curr_offset);
                }
            } else {
                continue;
            }
        }
    }
}

// Update GEP index in case of nested GEP interaction
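// Illustration with hypothetical IR: for
//   %bypass = getelementptr float, float* %base, i64 0
//   %gep    = getelementptr float, float* %bypass, i64 %i
// the pointer operand of %gep is rewritten to %base, so later analysis sees
// the real base pointer rather than the zero-index bypass; chains of such
// bypasses are collapsed recursively.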
static void lSetRealPtr(llvm::GetElementPtrInst *GEP) {
    llvm::GetElementPtrInst *predGEP = llvm::dyn_cast<llvm::GetElementPtrInst>(GEP->getPointerOperand());
    // The deepest GEP
    if (!predGEP)
        return;
    // Pred is a bypass
    if (predGEP->hasAllConstantIndices() && predGEP->getNumOperands() == 2 &&
        llvm::dyn_cast<llvm::Constant>(predGEP->getOperand(1))->isNullValue()) {
        // Go through several bypasses
        lSetRealPtr(predGEP);
        GEP->setOperand(0, predGEP->getPointerOperand());
    }
}

static bool lSkipInst(llvm::Instruction *inst) {
    // TODO: this variable should have different
    // value for scatter coalescing
    unsigned min_operands = 1;

    // Skip phis
    if (llvm::isa<llvm::PHINode>(inst))
        return true;
    // Skip obviously wrong instruction
    if (inst->getNumOperands() < min_operands)
        return true;
    // Skip GEPs: handle their users instead
    if (llvm::isa<llvm::GetElementPtrInst>(inst))
        return true;

    return false;
}

// Prepare GEPs for further optimization:
//   - Simplify complex GEPs
//   - Copy GEPs that are located in other BBs
//   - Resolve possible duplications in current BB
// TODO: it looks possible to find common pointers in
//       some cases. Not implemented now.
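// A sketch of the rewrite performed below (hypothetical IR): a GEP with a
// variable leading index and constant trailing indices such as
//   %g = getelementptr %struct.S, %struct.S* %p, i64 %i, i32 2
// is split into
//   %lowered_gep      = getelementptr %struct.S, %struct.S* %p, i64 %i
//   %lowered_gep_succ = getelementptr %struct.S, %struct.S* %lowered_gep, i64 0, i32 2
// so GEPs sharing the same (%p, %i) pair reuse one %lowered_gep and their
// remaining constant offsets become visible to the coalescing below.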
static bool lPrepareGEPs(llvm::BasicBlock &bb) {
    bool changed = true;
    bool modified = false;
    // TODO: this variable should have different
    // value for scatter coalescing
    unsigned gep_operand_index = 0;

    while (changed) {
        changed = false;

        std::map<llvm::Value *, std::map<llvm::Value *, llvm::Value *>> ptrData;
        std::vector<llvm::Instruction *> dead;
        std::map<llvm::Value *, llvm::Value *> movedGEPs;
        for (auto bi = bb.begin(), be = bb.end(); bi != be; ++bi) {
            llvm::Instruction *inst = &*bi;

            if (lSkipInst(inst))
                continue;

            llvm::GetElementPtrInst *GEP = llvm::dyn_cast<llvm::GetElementPtrInst>(inst->getOperand(gep_operand_index));
            if (!GEP) {
                continue;
            }

            // Set real pointer in order to ignore possible bypasses
            lSetRealPtr(GEP);

            // Copy GEP to current BB
            if (GEP->getParent() != &bb) {
                auto it = movedGEPs.end();
                if ((it = movedGEPs.find(GEP)) == movedGEPs.end()) {
                    auto newGEP = llvm::dyn_cast<llvm::GetElementPtrInst>(GEP->clone());
                    Assert(newGEP != nullptr);
                    newGEP->insertBefore(inst);
                    inst->setOperand(gep_operand_index, newGEP);
                    movedGEPs[GEP] = newGEP;
                    dead.push_back(GEP);
                    GEP = newGEP;
                    changed = true;
                } else {
                    GEP = llvm::dyn_cast<llvm::GetElementPtrInst>(it->second);
                    Assert(GEP != nullptr);
                    inst->setOperand(gep_operand_index, GEP);
                    // Do not handle GEP during one BB run twice
                    continue;
                }
            }

            // GEP is ready for optimization
            if (GEP->hasAllConstantIndices())
                continue;
            // TODO: Currently skipping constants at first index
            // We can handle const-(const)*-arg-(arg)*-const-(const)*
            // pattern if the first 4 fields are the same for some GEPs
            if (llvm::dyn_cast<llvm::Constant>(GEP->getOperand(1)))
                continue;
            // Not interested in it
            if (!GEP->getNumUses())
                continue;

            llvm::Value *ptr = GEP->getPointerOperand();
            llvm::Value *idx = GEP->getOperand(1);
            llvm::Value *foundGEP = nullptr;
            auto it = ptrData[ptr].find(idx);

            if (it != ptrData[ptr].end()) {
                foundGEP = it->second;
                // Duplication: just move uses
                if (GEP->getNumOperands() == 2) {
                    if (GEP != foundGEP) {
                        GEP->replaceAllUsesWith(foundGEP);
                        lCopyMetadata(foundGEP, GEP);
                        dead.push_back(GEP);
                        changed = true;
                    }

                    continue;
                }
            } else {
                // Don't need to create new GEP, use this one
                if (GEP->getNumOperands() == 2)
                    foundGEP = GEP;
                else
                    foundGEP = llvm::GetElementPtrInst::Create(nullptr, ptr, {idx}, "lowered_gep", GEP);

                ptrData[ptr][idx] = foundGEP;
            }

            if (foundGEP == GEP)
                continue;

            std::vector<llvm::Value *> args;
            args.push_back(llvm::Constant::getNullValue(GEP->getOperand(1)->getType()));
            for (unsigned i = 2, end = GEP->getNumOperands(); i < end; ++i)
                args.push_back(GEP->getOperand(i));
            auto foundGEPUser = llvm::GetElementPtrInst::Create(nullptr, foundGEP, args, "lowered_gep_succ", GEP);
            GEP->replaceAllUsesWith(foundGEPUser);
            lCopyMetadata(foundGEPUser, GEP);
            dead.push_back(GEP);
            changed = true;
        }

        for (auto inst : dead) {
            llvm::RecursivelyDeleteTriviallyDeadInstructions(inst);
        }

        if (changed)
            modified = true;
    }

    return modified;
}

// Return true if there is a possibility of
// load-store-load sequence due to current instruction
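// For example (hypothetical IR), two loads must not be coalesced across %st
// or across a call that may write memory:
//   %a = load float, float* %p
//   store float %v, float* %q      ; SVM store, or: call void @may_write(...)
//   %b = load float, float* %p
// so when such an instruction is met, the candidates collected so far are
// flushed by running the coalescing early (see lVectorizeLoads).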
static bool lCheckForPossibleStore(llvm::Instruction *inst) {
    if (auto CI = llvm::dyn_cast<llvm::CallInst>(inst)) {
        // Such call may introduce load-store-load sequence
        if (!CI->onlyReadsMemory())
            return true;
    } else if (auto SI = llvm::dyn_cast<llvm::StoreInst>(inst)) {
        // May introduce load-store-load sequence
        if (GetAddressSpace(SI->getPointerOperand()) == AddressSpace::External)
            return true;
    }

    // No possible bad store detected
    return false;
}

// Run optimization for all collected GEPs
static bool lRunGenXCoalescing(std::map<llvm::Value *, std::map<llvm::Type *, std::vector<PtrUse>>> &ptrUses,
                               std::vector<llvm::Instruction *> &dead) {
    bool modifiedAny = false;

    for (auto data : ptrUses) {
        for (auto data_t : data.second)
            modifiedAny |= lVectorizeGEPs(data.first, data_t.second, dead);
    }

    return modifiedAny;
}

// Analyze and optimize loads/gathers in current BB
static bool lVectorizeLoads(llvm::BasicBlock &bb) {
    bool modifiedAny = false;

    // TODO: this variable should have different
    // value for scatter coalescing
    unsigned gep_operand_index = 0;

    std::map<llvm::Value *, std::map<llvm::Type *, std::vector<PtrUse>>> ptrUses;
    std::vector<llvm::Instruction *> dead;
    std::set<llvm::GetElementPtrInst *> handledGEPs;
    for (auto bi = bb.begin(), be = bb.end(); bi != be; ++bi) {
        llvm::Instruction *inst = &*bi;

        if (lSkipInst(inst))
            continue;

        // Check for possible load-store-load sequence
        if (lCheckForPossibleStore(inst)) {
            // Perform early run of the optimization
            modifiedAny |= lRunGenXCoalescing(ptrUses, dead);
            // All current changes were applied
            ptrUses.clear();
            handledGEPs.clear();
            continue;
        }

        // Get GEP operand: current algorithm is based on GEPs
        // TODO: this fact leads to a problem with PTR[0]: it can be
        // accessed without GEP. The possible solution is to build GEP
        // during preprocessing GEPs.
        llvm::GetElementPtrInst *GEP = llvm::dyn_cast<llvm::GetElementPtrInst>(inst->getOperand(gep_operand_index));
        if (!GEP)
            continue;

        // Since gathers work with SVM, this optimization
        // should be applied to external address space only
        llvm::Value *ptr = GEP->getPointerOperand();
        if (GetAddressSpace(ptr) != AddressSpace::External)
            continue;

        // GEP was already added to the list for optimization
        if (handledGEPs.count(GEP))
            continue;
        handledGEPs.insert(GEP);

        // All GEPs that are located in other BB should be copied during preprocessing
        Assert(GEP->getParent() == &bb);

        // Create bypass for bad GEPs
        // TODO: that can be applied to several GEPs, so they can be grouped.
        // It is not done now.
        if (!GEP->hasAllConstantIndices()) {
            llvm::GetElementPtrInst *newGEP = llvm::GetElementPtrInst::Create(
                nullptr, GEP, {llvm::dyn_cast<llvm::ConstantInt>(llvm::ConstantInt::get(LLVMTypes::Int32Type, 0))},
                "ptr_bypass", GEP->getParent());
            newGEP->moveAfter(GEP);
            GEP->replaceAllUsesWith(newGEP);
            lCopyMetadata(newGEP, GEP);
            newGEP->setOperand(0, GEP);
            GEP = newGEP;
            continue;
        }

        // Calculate GEP offset
        llvm::APInt acc(g->target->is32Bit() ? 32 : 64, 0, true);
        bool checker = GEP->accumulateConstantOffset(GEP->getParent()->getParent()->getParent()->getDataLayout(), acc);
        Assert(checker);

        // Skip unsuitable GEPs
        llvm::Type *ty = nullptr;
        if (!(ty = lGetLoadsType(GEP)))
            continue;

        // Collect used idxs
        int64_t idx = acc.getSExtValue();
        std::vector<int64_t> idxs;
        lCollectIdxs(GEP, idxs, idx);

        // Store data for GEP
        ptrUses[ptr][ty].push_back({GEP, idxs, ty});
    }

    // Run optimization
    modifiedAny |= lRunGenXCoalescing(ptrUses, dead);

    // TODO: perform investigation on compilation failure with it
#if 0
    for (auto d : dead) {
        llvm::RecursivelyDeleteTriviallyDeadInstructions(d);
    }
#endif

    return modifiedAny;
}

bool GenXGatherCoalescing::runOnBasicBlock(llvm::BasicBlock &bb) {
    DEBUG_START_PASS("GenXGatherCoalescing");

    bool modifiedAny = lPrepareGEPs(bb);

    modifiedAny |= lVectorizeLoads(bb);

    DEBUG_END_PASS("GenXGatherCoalescing");

    return modifiedAny;
}

bool GenXGatherCoalescing::runOnFunction(llvm::Function &F) {
    llvm::TimeTraceScope FuncScope("GenXGatherCoalescing::runOnFunction", F.getName());
    bool modifiedAny = false;
    for (llvm::BasicBlock &BB : F) {
        modifiedAny |= runOnBasicBlock(BB);
    }
    return modifiedAny;
}

static llvm::Pass *CreateGenXGatherCoalescingPass() { return new GenXGatherCoalescing; }

///////////////////////////////////////////////////////////////////////////
// ReplaceLLVMIntrinsics

/** This pass replaces LLVM intrinsics unsupported on GenX
 */
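// Roughly, the rewrites performed below are (illustrative only; the exact
// operands and overloads are built in the code via getGenXDeclaration):
//   call void @llvm.trap()        -> a genx.raw.send.noresult call that raises
//                                    the corresponding GPU-side exception
//   call void @llvm.assume(i1 %c) -> erased (unsupported by the VC backend)
//   call @llvm.abs(%x, ...)       -> a call to genx.absi (integer types) or
//                                    genx.absf (floating-point types) on %x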

class ReplaceLLVMIntrinsics : public llvm::FunctionPass {
  public:
    static char ID;
    ReplaceLLVMIntrinsics() : FunctionPass(ID) {}
    llvm::StringRef getPassName() const { return "LLVM intrinsics replacement"; }
    bool runOnBasicBlock(llvm::BasicBlock &BB);
    bool runOnFunction(llvm::Function &F);
};

char ReplaceLLVMIntrinsics::ID = 0;

bool ReplaceLLVMIntrinsics::runOnBasicBlock(llvm::BasicBlock &bb) {
    DEBUG_START_PASS("LLVM intrinsics replacement");
    std::vector<llvm::AllocaInst *> Allocas;

    bool modifiedAny = false;

restart:
    for (llvm::BasicBlock::iterator I = bb.begin(), E = --bb.end(); I != E; ++I) {
        llvm::Instruction *inst = &*I;
        if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst)) {
            llvm::Function *func = ci->getCalledFunction();
            if (func == NULL || !func->isIntrinsic())
                continue;

            if (func->getName().equals("llvm.trap")) {
                llvm::Type *argTypes[] = {LLVMTypes::Int1VectorType, LLVMTypes::Int16VectorType};
                // Description of parameters for genx_raw_send_noresult can be found in target-genx.ll
                auto Fn = llvm::GenXIntrinsic::getGenXDeclaration(
                    m->module, llvm::GenXIntrinsic::genx_raw_send_noresult, argTypes);
                llvm::SmallVector<llvm::Value *, 8> Args;
                Args.push_back(llvm::ConstantInt::get(LLVMTypes::Int32Type, 0));
                Args.push_back(llvm::ConstantVector::getSplat(
#if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
                    g->target->getNativeVectorWidth(),
#elif ISPC_LLVM_VERSION < ISPC_LLVM_12_0
                    {static_cast<unsigned int>(g->target->getNativeVectorWidth()), false},
#else // LLVM 12.0+
                    llvm::ElementCount::get(static_cast<unsigned int>(g->target->getNativeVectorWidth()), false),
#endif
                    llvm::ConstantInt::getTrue(*g->ctx)));

                Args.push_back(llvm::ConstantInt::get(LLVMTypes::Int32Type, 39));
                Args.push_back(llvm::ConstantInt::get(LLVMTypes::Int32Type, 33554448));
                llvm::Value *zeroMask = llvm::ConstantVector::getSplat(
#if ISPC_LLVM_VERSION < ISPC_LLVM_11_0
                    g->target->getNativeVectorWidth(),
#elif ISPC_LLVM_VERSION < ISPC_LLVM_12_0
                    {static_cast<unsigned int>(g->target->getNativeVectorWidth()), false},
#else // LLVM 12.0+
                    llvm::ElementCount::get(static_cast<unsigned int>(g->target->getNativeVectorWidth()), false),
#endif
                    llvm::Constant::getNullValue(llvm::Type::getInt16Ty(*g->ctx)));
                Args.push_back(zeroMask);

                llvm::Instruction *newInst = llvm::CallInst::Create(Fn, Args, ci->getName());
                if (newInst != NULL) {
                    llvm::ReplaceInstWithInst(ci, newInst);
                    modifiedAny = true;
                    goto restart;
                }
            } else if (func->getName().equals("llvm.assume") ||
                       func->getName().equals("llvm.experimental.noalias.scope.decl")) {
                // These intrinsics are not supported by backend so remove them.
                ci->eraseFromParent();
                modifiedAny = true;
                goto restart;
            } else if (func->getName().contains("llvm.abs")) {
                // Replace llvm.abs with the GenX equivalent: llvm.genx.absi or llvm.genx.absf
                Assert(ci->getOperand(0));
                llvm::Type *argType = ci->getOperand(0)->getType();

                llvm::Type *Tys[2];
                Tys[0] = func->getReturnType(); // return type
                Tys[1] = argType;               // value type

                llvm::GenXIntrinsic::ID genxAbsID =
                    argType->isIntOrIntVectorTy() ? llvm::GenXIntrinsic::genx_absi : llvm::GenXIntrinsic::genx_absf;
                auto Fn = llvm::GenXIntrinsic::getGenXDeclaration(m->module, genxAbsID, Tys);
                Assert(Fn);
                llvm::Instruction *newInst = llvm::CallInst::Create(Fn, ci->getOperand(0), "");
                if (newInst != NULL) {
                    lCopyMetadata(newInst, ci);
                    llvm::ReplaceInstWithInst(ci, newInst);
                    modifiedAny = true;
                    goto restart;
                }
            }
        }
    }
    DEBUG_END_PASS("LLVM intrinsics replacement");
    return modifiedAny;
}

bool ReplaceLLVMIntrinsics::runOnFunction(llvm::Function &F) {

    llvm::TimeTraceScope FuncScope("ReplaceLLVMIntrinsics::runOnFunction", F.getName());
    bool modifiedAny = false;
    for (llvm::BasicBlock &BB : F) {
        modifiedAny |= runOnBasicBlock(BB);
    }
    return modifiedAny;
}

static llvm::Pass *CreateReplaceLLVMIntrinsics() { return new ReplaceLLVMIntrinsics(); }

///////////////////////////////////////////////////////////////////////////
// FixDivisionInstructions

/** This pass replaces instructions that are supported by LLVM IR but not by the GenX backend.
    LLVM IR has i64 division, but VISA has no i64 div instruction and the GenX backend does not
    handle it, so ISPC must not generate IR with i64 division.
 */
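// A minimal sketch of the replacement (hypothetical IR; WIDTH stands for the
// target vector width, and the callee names match the switch below and are
// expected to be provided by the ISPC builtins module):
//   %q = sdiv i64 %a, %b            ->  %q = call i64 @__divs_ui64(i64 %a, i64 %b)
//   %r = urem <WIDTH x i64> %a, %b  ->  %r = call <WIDTH x i64> @__remus_vi64(...)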

class FixDivisionInstructions : public llvm::FunctionPass {
  public:
    static char ID;
    FixDivisionInstructions() : FunctionPass(ID) {}
    llvm::StringRef getPassName() const { return "Fix division instructions unsupported by VC backend"; }
    bool runOnBasicBlock(llvm::BasicBlock &BB);
    bool runOnFunction(llvm::Function &F);
};

char FixDivisionInstructions::ID = 0;

bool FixDivisionInstructions::runOnBasicBlock(llvm::BasicBlock &bb) {
    DEBUG_START_PASS("FixDivisionInstructions");
    bool modifiedAny = false;

    std::string name = "";
    llvm::Function *func;
    for (llvm::BasicBlock::iterator I = bb.begin(), E = --bb.end(); I != E; ++I) {
        llvm::Instruction *inst = &*I;
        // for now, all replaced inst have 2 operands
        if (inst->getNumOperands() > 1) {
            auto type = inst->getOperand(0)->getType();
            // for now, all replaced inst have i64 operands
            if ((type == LLVMTypes::Int64Type) || (type == LLVMTypes::Int64VectorType)) {
                switch (inst->getOpcode()) {
                case llvm::Instruction::BinaryOps::UDiv:
                    if (type == LLVMTypes::Int64Type) {
                        name = "__divus_ui64";
                    } else {
                        name = "__divus_vi64";
                    }
                    break;
                case llvm::Instruction::BinaryOps::SDiv:
                    if (type == LLVMTypes::Int64Type) {
                        name = "__divs_ui64";
                    } else {
                        name = "__divs_vi64";
                    }
                    break;
                case llvm::Instruction::BinaryOps::URem:
                    if (type == LLVMTypes::Int64Type) {
                        name = "__remus_ui64";
                    } else {
                        name = "__remus_vi64";
                    }
                    break;
                case llvm::Instruction::BinaryOps::SRem:
                    if (type == LLVMTypes::Int64Type) {
                        name = "__rems_ui64";
                    } else {
                        name = "__rems_vi64";
                    }
                    break;
                default:
                    name = "";
                    break;
                }
                if (name != "") {
                    func = m->module->getFunction(name);
                    Assert(func != NULL && "FixDivisionInstructions: Can't find correct function!!!");
                    llvm::SmallVector<llvm::Value *, 8> args;
                    args.push_back(inst->getOperand(0));
                    args.push_back(inst->getOperand(1));
                    llvm::Instruction *newInst = llvm::CallInst::Create(func, args, name);
                    llvm::ReplaceInstWithInst(inst, newInst);
                    modifiedAny = true;
                    I = newInst->getIterator();
                }
            }
        }
    }
    DEBUG_END_PASS("FixDivisionInstructions");
    return modifiedAny;
}

bool FixDivisionInstructions::runOnFunction(llvm::Function &F) {

    llvm::TimeTraceScope FuncScope("FixDivisionInstructions::runOnFunction", F.getName());
    bool modifiedAny = false;
    for (llvm::BasicBlock &BB : F) {
        modifiedAny |= runOnBasicBlock(BB);
    }
    return modifiedAny;
}

static llvm::Pass *CreateFixDivisionInstructions() { return new FixDivisionInstructions(); }

///////////////////////////////////////////////////////////////////////////
// CheckUnsupportedInsts

/** This pass checks whether any instructions or functions that are not currently supported for
    the gen target are used, reports an error, and stops compilation.
 */
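// For instance, calling prefetch() on a target without LSC prefetch support,
// or touching a double operand on a target without fp64 support, is reported
// here with a source-located error instead of failing later in the VC backend.
// (Illustrative summary of the checks below, not an exhaustive list.)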

class CheckUnsupportedInsts : public llvm::FunctionPass {
  public:
    static char ID;
    CheckUnsupportedInsts(bool last = false) : FunctionPass(ID) {}

    llvm::StringRef getPassName() const { return "Check unsupported instructions for gen target"; }
    bool runOnBasicBlock(llvm::BasicBlock &BB);
    bool runOnFunction(llvm::Function &F);
};

char CheckUnsupportedInsts::ID = 0;

bool CheckUnsupportedInsts::runOnBasicBlock(llvm::BasicBlock &bb) {
    DEBUG_START_PASS("CheckUnsupportedInsts");
    bool modifiedAny = false;
    // This list contains regex expr for unsupported function names
    // To be extended

    for (llvm::BasicBlock::iterator I = bb.begin(), E = --bb.end(); I != E; ++I) {
        llvm::Instruction *inst = &*I;
        SourcePos pos;
        lGetSourcePosFromMetadata(inst, &pos);
        if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst)) {
            llvm::Function *func = ci->getCalledFunction();
            // Report error that prefetch is not supported on SKL and TGLLP
            if (func && func->getName().contains("genx.lsc.prefetch.stateless")) {
                if (!g->target->hasGenxPrefetch()) {
                    Error(pos, "\'prefetch\' is not supported by %s\n", g->target->getCPU().c_str());
                }
            }
        }
        // Report error if double type is not supported by the target
        if (!g->target->hasFp64Support()) {
            for (int i = 0; i < (int)inst->getNumOperands(); ++i) {
                llvm::Type *t = inst->getOperand(i)->getType();
                if (t == LLVMTypes::DoubleType || t == LLVMTypes::DoublePointerType ||
                    t == LLVMTypes::DoubleVectorType || t == LLVMTypes::DoubleVectorPointerType) {
                    Error(pos, "\'double\' type is not supported by the target\n");
                }
            }
        }
    }
    DEBUG_END_PASS("CheckUnsupportedInsts");

    return modifiedAny;
}

bool CheckUnsupportedInsts::runOnFunction(llvm::Function &F) {
    llvm::TimeTraceScope FuncScope("CheckUnsupportedInsts::runOnFunction", F.getName());
    bool modifiedAny = false;
    for (llvm::BasicBlock &BB : F) {
        modifiedAny |= runOnBasicBlock(BB);
    }
    return modifiedAny;
}

static llvm::Pass *CreateCheckUnsupportedInsts() { return new CheckUnsupportedInsts(); }

///////////////////////////////////////////////////////////////////////////
// MangleOpenCLBuiltins

/** This pass mangles the SPIR-V OpenCL builtins used in the gen target file.
 */
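// For example (illustrative only; the exact mangled string is produced by
// mangleOpenClBuiltin): a declaration such as
//   declare <WIDTH x double> @__spirv_ocl_sqrt_DvWIDTH(<WIDTH x double>)
// has its "_DvWIDTH" suffix stripped and is renamed to the Itanium-mangled
// form of __spirv_ocl_sqrt taking a double vector, so the SPIR-V translator
// recognizes it as the standard OpenCL builtin.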

class MangleOpenCLBuiltins : public llvm::FunctionPass {
  public:
    static char ID;
    MangleOpenCLBuiltins(bool last = false) : FunctionPass(ID) {}

    llvm::StringRef getPassName() const { return "Mangle OpenCL builtins"; }
    bool runOnBasicBlock(llvm::BasicBlock &BB);
    bool runOnFunction(llvm::Function &F);
};

char MangleOpenCLBuiltins::ID = 0;

bool MangleOpenCLBuiltins::runOnBasicBlock(llvm::BasicBlock &bb) {
    DEBUG_START_PASS("MangleOpenCLBuiltins");
    bool modifiedAny = false;
    for (llvm::BasicBlock::iterator I = bb.begin(), E = --bb.end(); I != E; ++I) {
        llvm::Instruction *inst = &*I;
        if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst)) {
            llvm::Function *func = ci->getCalledFunction();
            if (func == NULL)
                continue;
            if (func->getName().startswith("__spirv_ocl")) {
                std::string mangledName;
                llvm::Type *retType = func->getReturnType();
                std::string funcName = func->getName().str();
                std::vector<llvm::Type *> ArgTy;
                // SPIR-V OpenCL builtins are used for double types only
                Assert(retType->isVectorTy() &&
                           llvm::dyn_cast<llvm::VectorType>(retType)->getElementType()->isDoubleTy() ||
                       retType->isSingleValueType() && retType->isDoubleTy());
                if (retType->isVectorTy() &&
                    llvm::dyn_cast<llvm::VectorType>(retType)->getElementType()->isDoubleTy()) {
                    ArgTy.push_back(LLVMTypes::DoubleVectorType);
                    // _DvWIDTH suffix is used in target file to differentiate scalar
                    // and vector versions of intrinsics. Here we remove this
                    // suffix and mangle the name.
                    size_t pos = funcName.find("_DvWIDTH");
                    if (pos != std::string::npos) {
                        funcName.erase(pos, 8);
                    }
                } else if (retType->isSingleValueType() && retType->isDoubleTy()) {
                    ArgTy.push_back(LLVMTypes::DoubleType);
                }
                mangleOpenClBuiltin(funcName, ArgTy, mangledName);
                func->setName(mangledName);
                modifiedAny = true;
            }
        }
    }
    DEBUG_END_PASS("MangleOpenCLBuiltins");

    return modifiedAny;
}

bool MangleOpenCLBuiltins::runOnFunction(llvm::Function &F) {
    llvm::TimeTraceScope FuncScope("MangleOpenCLBuiltins::runOnFunction", F.getName());
    bool modifiedAny = false;
    for (llvm::BasicBlock &BB : F) {
        modifiedAny |= runOnBasicBlock(BB);
    }
    return modifiedAny;
}

static llvm::Pass *CreateMangleOpenCLBuiltins() { return new MangleOpenCLBuiltins(); }

///////////////////////////////////////////////////////////////////////////
// FixAddressSpace

/** This pass should run close to the end of the optimization pipeline. It fixes address space
    issues caused by function inlining and LLVM optimizations.
    TODO: the implementation is not complete yet; it just unblocks work on the openvkl kernel.
    1. fix where needed svm.block.ld -> load, svm.block.st -> store
    2. svm gathers/scatters <-> private gathers/scatters
    3. ideally generate LLVM vector loads/stores in the ImproveMemoryOps pass and fix the address space here
 */
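// A sketch of the two directions handled here (hypothetical cases):
//   - a plain LLVM vector load/store whose pointer lives in the external (SVM)
//     address space is rewritten into genx.svm.block.ld.unaligned /
//     genx.svm.block.st via lGenXLoadInst / lGenXStoreInst;
//   - a genx.svm.block.ld.unaligned / genx.svm.block.st whose address is not
//     external (e.g. it points to a private alloca after inlining) is turned
//     back into an ordinary LLVM load/store through an inttoptr of the address.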

class FixAddressSpace : public llvm::FunctionPass {
    llvm::Instruction *processVectorLoad(llvm::LoadInst *LI);
    llvm::Instruction *processSVMVectorLoad(llvm::Instruction *LI);
    llvm::Instruction *processVectorStore(llvm::StoreInst *SI);
    llvm::Instruction *processSVMVectorStore(llvm::Instruction *LI);
    void applyReplace(llvm::Instruction *Inst, llvm::Instruction *ToReplace);

  public:
    static char ID;
    FixAddressSpace() : FunctionPass(ID) {}
    llvm::StringRef getPassName() const { return "Fix address space"; }
    bool runOnBasicBlock(llvm::BasicBlock &BB);
    bool runOnFunction(llvm::Function &F);
};

char FixAddressSpace::ID = 0;

void FixAddressSpace::applyReplace(llvm::Instruction *Inst, llvm::Instruction *ToReplace) {
    lCopyMetadata(Inst, ToReplace);
    llvm::ReplaceInstWithInst(ToReplace, Inst);
}

llvm::Instruction *FixAddressSpace::processVectorLoad(llvm::LoadInst *LI) {
    llvm::Value *ptr = LI->getOperand(0);
    llvm::Type *retType = LI->getType();

    Assert(llvm::isa<llvm::VectorType>(LI->getType()));

    if (GetAddressSpace(ptr) != AddressSpace::External)
        return NULL;

    // Ptr load should be done via inttoptr
    bool isPtrLoad = false;
    if (retType->getScalarType()->isPointerTy()) {
        isPtrLoad = true;
        auto scalarType = g->target->is32Bit() ? LLVMTypes::Int32Type : LLVMTypes::Int64Type;
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
        retType =
            llvm::FixedVectorType::get(scalarType, llvm::dyn_cast<llvm::FixedVectorType>(retType)->getNumElements());
#else
        retType = llvm::VectorType::get(scalarType, retType->getVectorNumElements());
#endif
    }
    llvm::Instruction *res = lGenXLoadInst(ptr, retType, llvm::dyn_cast<llvm::Instruction>(LI));
    Assert(res);

    if (isPtrLoad) {
        res->insertBefore(LI);
        res = new llvm::IntToPtrInst(res, LI->getType(), "svm_ld_inttoptr");
    }

    return res;
}

llvm::Instruction *FixAddressSpace::processSVMVectorLoad(llvm::Instruction *CI) {
    llvm::Value *ptr = CI->getOperand(0);
    llvm::Type *retType = CI->getType();

    Assert(llvm::isa<llvm::VectorType>(retType));

    if (GetAddressSpace(ptr) == AddressSpace::External)
        return NULL;

    // Convert the i64 address into a pointer
    ptr = new llvm::IntToPtrInst(ptr, llvm::PointerType::get(retType, 0), CI->getName() + "_inttoptr", CI);
    llvm::Instruction *loadInst = NULL;
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
    Assert(llvm::isa<llvm::PointerType>(ptr->getType()));
    loadInst = new llvm::LoadInst(llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getPointerElementType(), ptr,
                                  CI->getName(), false /* not volatile */,
                                  llvm::MaybeAlign(g->target->getNativeVectorAlignment()).valueOrOne(),
                                  (llvm::Instruction *)NULL);
#else
    loadInst = new llvm::LoadInst(ptr, CI->getName(), false,
                                  llvm::MaybeAlign(g->target->getNativeVectorAlignment()).valueOrOne(),
                                  (llvm::Instruction *)NULL);
#endif

    Assert(loadInst);
    return loadInst;
}

llvm::Instruction *FixAddressSpace::processVectorStore(llvm::StoreInst *SI) {
    llvm::Value *ptr = SI->getOperand(1);
    llvm::Value *val = SI->getOperand(0);
    Assert(ptr != NULL);
    Assert(val != NULL);

    llvm::Type *valType = val->getType();

    Assert(llvm::isa<llvm::VectorType>(valType));

    if (GetAddressSpace(ptr) != AddressSpace::External)
        return NULL;

    // Ptr store should be done via ptrtoint
    // Note: it doesn't look like a normal case for GenX target
    if (valType->getScalarType()->isPointerTy()) {
        auto scalarType = g->target->is32Bit() ? LLVMTypes::Int32Type : LLVMTypes::Int64Type;
#if ISPC_LLVM_VERSION >= ISPC_LLVM_11_0
        valType =
            llvm::FixedVectorType::get(scalarType, llvm::dyn_cast<llvm::FixedVectorType>(valType)->getNumElements());
#else
        valType = llvm::VectorType::get(scalarType, valType->getVectorNumElements());
#endif
        val = new llvm::PtrToIntInst(val, valType, "svm_st_val_ptrtoint", SI);
    }

    return lGenXStoreInst(val, ptr, llvm::dyn_cast<llvm::Instruction>(SI));
}

llvm::Instruction *FixAddressSpace::processSVMVectorStore(llvm::Instruction *CI) {
    llvm::Value *ptr = CI->getOperand(0);
    llvm::Value *val = CI->getOperand(1);

    Assert(ptr != NULL);
    Assert(val != NULL);

    llvm::Type *valType = val->getType();

    Assert(llvm::isa<llvm::VectorType>(valType));

    if (GetAddressSpace(ptr) == AddressSpace::External)
        return NULL;

    // Convert the i64 address into a pointer
    ptr = new llvm::IntToPtrInst(ptr, llvm::PointerType::get(valType, 0), CI->getName() + "_inttoptr", CI);

    llvm::Instruction *storeInst = NULL;
    storeInst = new llvm::StoreInst(val, ptr, (llvm::Instruction *)NULL,
                                    llvm::MaybeAlign(g->target->getNativeVectorAlignment()).valueOrOne());
    Assert(storeInst);
    return storeInst;
}

bool FixAddressSpace::runOnBasicBlock(llvm::BasicBlock &bb) {
    DEBUG_START_PASS("FixAddressSpace");
    bool modifiedAny = false;

restart:
    for (llvm::BasicBlock::iterator I = bb.begin(), E = --bb.end(); I != E; ++I) {
        llvm::Instruction *inst = &*I;
        if (llvm::LoadInst *ld = llvm::dyn_cast<llvm::LoadInst>(inst)) {
            if (llvm::isa<llvm::VectorType>(ld->getType())) {
                llvm::Instruction *load_inst = processVectorLoad(ld);
                if (load_inst != NULL) {
                    applyReplace(load_inst, ld);
                    modifiedAny = true;
                    goto restart;
                }
            }
        } else if (llvm::StoreInst *st = llvm::dyn_cast<llvm::StoreInst>(inst)) {
            llvm::Value *val = st->getOperand(0);
            Assert(val != NULL);
            if (llvm::isa<llvm::VectorType>(val->getType())) {
                llvm::Instruction *store_inst = processVectorStore(st);
                if (store_inst != NULL) {
                    applyReplace(store_inst, st);
                    modifiedAny = true;
                    goto restart;
                }
            }
        } else if (llvm::GenXIntrinsic::getGenXIntrinsicID(inst) == llvm::GenXIntrinsic::genx_svm_block_ld_unaligned) {
            llvm::Instruction *load_inst = processSVMVectorLoad(inst);
            if (load_inst != NULL) {
                applyReplace(load_inst, inst);
                modifiedAny = true;
                goto restart;
            }
        } else if (llvm::GenXIntrinsic::getGenXIntrinsicID(inst) == llvm::GenXIntrinsic::genx_svm_block_st) {
            llvm::Instruction *store_inst = processSVMVectorStore(inst);
            if (store_inst != NULL) {
                applyReplace(store_inst, inst);
                modifiedAny = true;
                goto restart;
            }
        }
    }
    DEBUG_END_PASS("FixAddressSpace");
    return modifiedAny;
}

bool FixAddressSpace::runOnFunction(llvm::Function &F) {
    // Transformations are correct only when the function is not internal.
    // This is a consequence of the address space calculation algorithm.
    // TODO: problems may arise in the case of stack calls
    llvm::TimeTraceScope FuncScope("FixAddressSpace::runOnFunction", F.getName());
    if (F.getLinkage() == llvm::GlobalValue::LinkageTypes::InternalLinkage)
        return false;

    bool modifiedAny = false;
    for (llvm::BasicBlock &BB : F) {
        modifiedAny |= runOnBasicBlock(BB);
    }
    return modifiedAny;
}

static llvm::Pass *CreateFixAddressSpace() { return new FixAddressSpace(); }

class DemotePHIs : public llvm::FunctionPass {
  public:
    static char ID;
    DemotePHIs() : FunctionPass(ID) {}
    llvm::StringRef getPassName() const { return "Demote PHI nodes"; }
    bool runOnFunction(llvm::Function &F);
};

char DemotePHIs::ID = 0;

bool DemotePHIs::runOnFunction(llvm::Function &F) {
    llvm::TimeTraceScope FuncScope("DemotePHIs::runOnFunction", F.getName());
    if (F.isDeclaration() || skipFunction(F))
        return false;
    std::vector<llvm::Instruction *> WorkList;
    for (auto &ibb : F)
        for (llvm::BasicBlock::iterator iib = ibb.begin(), iie = ibb.end(); iib != iie; ++iib)
            if (llvm::isa<llvm::PHINode>(iib))
                WorkList.push_back(&*iib);

    // Demote phi nodes
    for (auto *ilb : llvm::reverse(WorkList))
        DemotePHIToStack(llvm::cast<llvm::PHINode>(ilb), nullptr);

    return !WorkList.empty();
}

static llvm::Pass *CreateDemotePHIs() { return new DemotePHIs(); }

#endif
