/****************************************************************************
 * Copyright (C) 2014-2018 Intel Corporation.   All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * @file lower_x86.cpp
 *
 * @brief LLVM pass that lowers meta intrinsics to native x86 intrinsics.
 *
 * Notes:
 *
 ******************************************************************************/

#include "jit_pch.hpp"
#include "passes.h"
#include "JitManager.h"

#include "common/simdlib.hpp"

#include <unordered_map>

extern "C" void ScatterPS_256(uint8_t*, SIMD256::Integer, SIMD256::Float, uint8_t, uint32_t);

namespace llvm
{
    // forward declare the initializer
    void initializeLowerX86Pass(PassRegistry&);
} // namespace llvm

namespace SwrJit
{
    using namespace llvm;

    enum TargetArch
    {
        AVX    = 0,
        AVX2   = 1,
        AVX512 = 2
    };

    enum TargetWidth
    {
        W256       = 0,
        W512       = 1,
        NUM_WIDTHS = 2
    };

    struct LowerX86;

    typedef std::function<Instruction*(LowerX86*, TargetArch, TargetWidth, CallInst*)> EmuFunc;

    struct X86Intrinsic
    {
        IntrinsicID intrin[NUM_WIDTHS];
        EmuFunc     emuFunc;
    };
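    // Each X86Intrinsic entry pairs the native intrinsic to use at each target width with an
    // emulation callback. Consumption sketch (illustrative only; see ProcessIntrinsicAdvanced for
    // the real logic):
    //
    //     const X86Intrinsic& entry = getIntrinsicMapAdvanced()[arch][name];
    //     if      (entry.intrin[width] == DOUBLE)                   { /* double pump next narrower width */ }
    //     else if (entry.intrin[width] != Intrinsic::not_intrinsic) { /* call the native intrinsic */ }
    //     else                                                      { entry.emuFunc(this, arch, width, pCall); }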

    // Map of intrinsics that haven't been moved to the new mechanism yet. If used, these get the
    // previous behavior of mapping directly to avx/avx2 intrinsics.
    using intrinsicMap_t = std::map<std::string, IntrinsicID>;
    static intrinsicMap_t& getIntrinsicMap() {
        static std::map<std::string, IntrinsicID> intrinsicMap = {
            {"meta.intrinsic.BEXTR_32", Intrinsic::x86_bmi_bextr_32},
            {"meta.intrinsic.VPSHUFB", Intrinsic::x86_avx2_pshuf_b},
            {"meta.intrinsic.VCVTPS2PH", Intrinsic::x86_vcvtps2ph_256},
            {"meta.intrinsic.VPTESTC", Intrinsic::x86_avx_ptestc_256},
            {"meta.intrinsic.VPTESTZ", Intrinsic::x86_avx_ptestz_256},
            {"meta.intrinsic.VPHADDD", Intrinsic::x86_avx2_phadd_d},
            {"meta.intrinsic.PDEP32", Intrinsic::x86_bmi_pdep_32},
            {"meta.intrinsic.RDTSC", Intrinsic::x86_rdtsc}
        };
        return intrinsicMap;
    }

    // Forward decls
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);

    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin);

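    // Sentinel intrinsic ID used in the table below: an entry of DOUBLE means "no native
    // intrinsic at this width, double pump the next narrower width instead" (see DOUBLE_EMU).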
    static Intrinsic::ID DOUBLE = (Intrinsic::ID)-1;

    using intrinsicMapAdvanced_t = std::vector<std::map<std::string, X86Intrinsic>>;

    static intrinsicMapAdvanced_t& getIntrinsicMapAdvanced()
    {
        // clang-format off
        static intrinsicMapAdvanced_t intrinsicMapAdvanced = {
            //                               256 wide                               512 wide
            {
                // AVX
                {"meta.intrinsic.VRCPPS",    {{Intrinsic::x86_avx_rcp_ps_256,       DOUBLE},                    NO_EMU}},
                {"meta.intrinsic.VPERMPS",   {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},  VPERM_EMU}},
                {"meta.intrinsic.VPERMD",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},  VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},  VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},  VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},  VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,           Intrinsic::not_intrinsic},  VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS", {{Intrinsic::x86_avx_cvt_pd2_ps_256,   Intrinsic::not_intrinsic},  NO_EMU}},
                {"meta.intrinsic.VROUND",    {{Intrinsic::x86_avx_round_ps_256,     DOUBLE},                    NO_EMU}},
                {"meta.intrinsic.VHSUBPS",   {{Intrinsic::x86_avx_hsub_ps_256,      DOUBLE},                    NO_EMU}},
            },
            {
                // AVX2
                {"meta.intrinsic.VRCPPS",       {{Intrinsic::x86_avx_rcp_ps_256,    DOUBLE},                    NO_EMU}},
                {"meta.intrinsic.VPERMPS",      {{Intrinsic::x86_avx2_permps,       Intrinsic::not_intrinsic},  VPERM_EMU}},
                {"meta.intrinsic.VPERMD",       {{Intrinsic::x86_avx2_permd,        Intrinsic::not_intrinsic},  VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",    {{Intrinsic::not_intrinsic,         Intrinsic::not_intrinsic},  VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",    {{Intrinsic::not_intrinsic,         Intrinsic::not_intrinsic},  VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",    {{Intrinsic::not_intrinsic,         Intrinsic::not_intrinsic},  VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS",   {{Intrinsic::not_intrinsic,         Intrinsic::not_intrinsic},  VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",    {{Intrinsic::x86_avx_cvt_pd2_ps_256, DOUBLE},                   NO_EMU}},
                {"meta.intrinsic.VROUND",       {{Intrinsic::x86_avx_round_ps_256,  DOUBLE},                    NO_EMU}},
                {"meta.intrinsic.VHSUBPS",      {{Intrinsic::x86_avx_hsub_ps_256,   DOUBLE},                    NO_EMU}},
            },
            {
                // AVX512
                {"meta.intrinsic.VRCPPS", {{Intrinsic::x86_avx512_rcp14_ps_256,     Intrinsic::x86_avx512_rcp14_ps_512}, NO_EMU}},
    #if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VPERMPS", {{Intrinsic::x86_avx512_mask_permvar_sf_256, Intrinsic::x86_avx512_mask_permvar_sf_512}, NO_EMU}},
                {"meta.intrinsic.VPERMD", {{Intrinsic::x86_avx512_mask_permvar_si_256, Intrinsic::x86_avx512_mask_permvar_si_512}, NO_EMU}},
    #else
                {"meta.intrinsic.VPERMPS", {{Intrinsic::not_intrinsic,              Intrinsic::not_intrinsic}, VPERM_EMU}},
                {"meta.intrinsic.VPERMD", {{Intrinsic::not_intrinsic,               Intrinsic::not_intrinsic}, VPERM_EMU}},
    #endif
                {"meta.intrinsic.VGATHERPD", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,           Intrinsic::not_intrinsic}, VSCATTER_EMU}},
    #if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VCVTPD2PS", {{Intrinsic::x86_avx512_mask_cvtpd2ps_256, Intrinsic::x86_avx512_mask_cvtpd2ps_512}, NO_EMU}},
    #else
                {"meta.intrinsic.VCVTPD2PS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic}, VCONVERT_EMU}},
    #endif
                {"meta.intrinsic.VROUND", {{Intrinsic::not_intrinsic,               Intrinsic::not_intrinsic}, VROUND_EMU}},
                {"meta.intrinsic.VHSUBPS", {{Intrinsic::not_intrinsic,              Intrinsic::not_intrinsic}, VHSUB_EMU}}
            }};
        // clang-format on
        return intrinsicMapAdvanced;
    }

    static uint32_t getBitWidth(VectorType* pVTy)
    {
#if LLVM_VERSION_MAJOR >= 12
        return cast<FixedVectorType>(pVTy)->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#elif LLVM_VERSION_MAJOR >= 11
        return pVTy->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#else
        return pVTy->getBitWidth();
#endif
    }

    struct LowerX86 : public FunctionPass
    {
        LowerX86(Builder* b = nullptr) : FunctionPass(ID), B(b)
        {
            initializeLowerX86Pass(*PassRegistry::getPassRegistry());

            // Determine target arch
            if (JM()->mArch.AVX512F())
            {
                mTarget = AVX512;
            }
            else if (JM()->mArch.AVX2())
            {
                mTarget = AVX2;
            }
            else if (JM()->mArch.AVX())
            {
                mTarget = AVX;
            }
            else
            {
                SWR_ASSERT(false, "Unsupported AVX architecture.");
                mTarget = AVX;
            }

            // Setup scatter function for 256 wide
            uint32_t curWidth = B->mVWidth;
            B->SetTargetWidth(8);
            std::vector<Type*> args = {
                B->mInt8PtrTy,   // pBase
                B->mSimdInt32Ty, // vIndices
                B->mSimdFP32Ty,  // vSrc
                B->mInt8Ty,      // mask
                B->mInt32Ty      // scale
            };

            FunctionType* pfnScatterTy = FunctionType::get(B->mVoidTy, args, false);
            mPfnScatter256             = cast<Function>(
#if LLVM_VERSION_MAJOR >= 9
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy).getCallee());
#else
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy));
#endif
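            // The JIT resolves unknown external symbols through sys::DynamicLibrary, so register
            // the C fallback explicitly in case ScatterPS_256 isn't already visible in the
            // process symbol table.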
            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ScatterPS_256") == nullptr)
            {
                sys::DynamicLibrary::AddSymbol("ScatterPS_256", (void*)&ScatterPS_256);
            }

            B->SetTargetWidth(curWidth);
        }

        // Try to decipher the vector type of the instruction. This does not work properly
        // across all intrinsics, and will have to be rethought. Probably need something
        // similar to llvm's getDeclaration() utility to map a set of inputs to a specific typed
        // intrinsic.
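        // For example (assuming 32-bit elements): a call typed <8 x float> resolves to W256 with
        // an f32 element type, while <16 x i32> resolves to W512 with an i32 element type.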
        void GetRequestedWidthAndType(CallInst*       pCallInst,
                                      const StringRef intrinName,
                                      TargetWidth*    pWidth,
                                      Type**          pTy)
        {
            assert(pCallInst);
            Type* pVecTy = pCallInst->getType();

            // Check for intrinsic specific types
            // VCVTPD2PS type comes from src, not dst
            if (intrinName.equals("meta.intrinsic.VCVTPD2PS"))
            {
                Value* pOp = pCallInst->getOperand(0);
                assert(pOp);
                pVecTy = pOp->getType();
            }

            if (!pVecTy->isVectorTy())
            {
                for (auto& op : pCallInst->arg_operands())
                {
                    if (op.get()->getType()->isVectorTy())
                    {
                        pVecTy = op.get()->getType();
                        break;
                    }
                }
            }
            SWR_ASSERT(pVecTy->isVectorTy(), "Couldn't determine vector size");

            uint32_t width = getBitWidth(cast<VectorType>(pVecTy));
            switch (width)
            {
            case 256:
                *pWidth = W256;
                break;
            case 512:
                *pWidth = W512;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width %d", width);
                *pWidth = W256;
            }

            *pTy = pVecTy->getScalarType();
        }

        Value* GetZeroVec(TargetWidth width, Type* pTy)
        {
            uint32_t numElem = 0;
            switch (width)
            {
            case W256:
                numElem = 8;
                break;
            case W512:
                numElem = 16;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }

            return ConstantVector::getNullValue(getVectorType(pTy, numElem));
        }

        Value* GetMask(TargetWidth width)
        {
            Value* mask;
            switch (width)
            {
            case W256:
                mask = B->C((uint8_t)-1);
                break;
            case W512:
                mask = B->C((uint16_t)-1);
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }
            return mask;
        }

        // Convert <N x i1> mask to <N x i32> x86 mask
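        // e.g. an i1 mask of <1, 0, 1, 1> sign extends to the i32 vector <-1, 0, -1, -1>, matching
        // the all-ones / all-zeros per-lane mask convention expected by the AVX/AVX2 intrinsics.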
        Value* VectorMask(Value* vi1Mask)
        {
#if LLVM_VERSION_MAJOR >= 12
            uint32_t numElem = cast<FixedVectorType>(vi1Mask->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(vi1Mask->getType())->getNumElements();
#else
            uint32_t numElem = vi1Mask->getType()->getVectorNumElements();
#endif
            return B->S_EXT(vi1Mask, getVectorType(B->mInt32Ty, numElem));
        }

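        // Lower one call to a meta intrinsic listed in the advanced table: use the native
        // intrinsic for the requested width when one exists, double pump the next narrower
        // width when the entry is the DOUBLE sentinel, and otherwise invoke the per-intrinsic
        // emulation callback.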
        Instruction* ProcessIntrinsicAdvanced(CallInst* pCallInst)
        {
            Function*   pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            auto&       intrinsic = getIntrinsicMapAdvanced()[mTarget][pFunc->getName().str()];
            TargetWidth vecWidth;
            Type*       pElemTy;
            GetRequestedWidthAndType(pCallInst, pFunc->getName(), &vecWidth, &pElemTy);

            // Check if there is a native intrinsic for this instruction
            IntrinsicID id = intrinsic.intrin[vecWidth];
            if (id == DOUBLE)
            {
                // Double pump the next smaller SIMD intrinsic
                SWR_ASSERT(vecWidth != 0, "Cannot double pump smallest SIMD width.");
                Intrinsic::ID id2 = intrinsic.intrin[vecWidth - 1];
                SWR_ASSERT(id2 != Intrinsic::not_intrinsic,
                           "Cannot find intrinsic to double pump.");
                return DOUBLE_EMU(this, mTarget, vecWidth, pCallInst, id2);
            }
            else if (id != Intrinsic::not_intrinsic)
            {
                Function* pIntrin = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, id);
                SmallVector<Value*, 8> args;
                for (auto& arg : pCallInst->arg_operands())
                {
                    args.push_back(arg.get());
                }

                // If AVX512, all instructions add a src operand and mask. We'll pass in a zero src
                // and a full mask for now, assuming the intrinsics are consistent and place the
                // src operand and mask last in the argument list.
                if (mTarget == AVX512)
                {
                    if (pFunc->getName().equals("meta.intrinsic.VCVTPD2PS"))
                    {
                        args.push_back(GetZeroVec(W256, pCallInst->getType()->getScalarType()));
                        args.push_back(GetMask(W256));
                        // for AVX512 VCVTPD2PS, we also have to add rounding mode
                        args.push_back(B->C(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
                    }
                    else
                    {
                        args.push_back(GetZeroVec(vecWidth, pElemTy));
                        args.push_back(GetMask(vecWidth));
                    }
                }

                return B->CALLA(pIntrin, args);
            }
            else
            {
                // No native intrinsic, call emulation function
                return intrinsic.emuFunc(this, mTarget, vecWidth, pCallInst);
            }

            SWR_ASSERT(false);
            return nullptr;
        }

        Instruction* ProcessIntrinsic(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            // Forward to the advanced support if found
            if (getIntrinsicMapAdvanced()[mTarget].find(pFunc->getName().str()) != getIntrinsicMapAdvanced()[mTarget].end())
            {
                return ProcessIntrinsicAdvanced(pCallInst);
            }

            SWR_ASSERT(getIntrinsicMap().find(pFunc->getName().str()) != getIntrinsicMap().end(),
                       "Unimplemented intrinsic %s.",
                       pFunc->getName().str().c_str());

            Intrinsic::ID x86Intrinsic = getIntrinsicMap()[pFunc->getName().str()];
            Function*     pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, x86Intrinsic);

            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                args.push_back(arg.get());
            }
            return B->CALLA(pX86IntrinFunc, args);
        }

        //////////////////////////////////////////////////////////////////////////
        /// @brief LLVM function pass run method.
        /// @param F - The function we're working on with this pass.
        virtual bool runOnFunction(Function& F)
        {
            std::vector<Instruction*> toRemove;
            std::vector<BasicBlock*>  bbs;

            // Make temp copy of the basic blocks and instructions, as the intrinsic
            // replacement code might invalidate the iterators
            for (auto& b : F.getBasicBlockList())
            {
                bbs.push_back(&b);
            }

            for (auto* BB : bbs)
            {
                std::vector<Instruction*> insts;
                for (auto& i : BB->getInstList())
                {
                    insts.push_back(&i);
                }

                for (auto* I : insts)
                {
                    if (CallInst* pCallInst = dyn_cast<CallInst>(I))
                    {
                        Function* pFunc = pCallInst->getCalledFunction();
                        if (pFunc)
                        {
                            if (pFunc->getName().startswith("meta.intrinsic"))
                            {
                                B->IRB()->SetInsertPoint(I);
                                Instruction* pReplace = ProcessIntrinsic(pCallInst);
                                toRemove.push_back(pCallInst);
                                if (pReplace)
                                {
                                    pCallInst->replaceAllUsesWith(pReplace);
                                }
                            }
                        }
                    }
                }
            }

            for (auto* pInst : toRemove)
            {
                pInst->eraseFromParent();
            }

            JitManager::DumpToFile(&F, "lowerx86");

            return true;
        }

        virtual void getAnalysisUsage(AnalysisUsage& AU) const {}

        JitManager* JM() { return B->JM(); }
        Builder*    B;
        TargetArch  mTarget;
        Function*   mPfnScatter256;

        static char ID; ///< Needed by LLVM to generate ID for FunctionPass.
    };

    char LowerX86::ID = 0; // LLVM uses address of ID as the actual ID.

    FunctionPass* createLowerX86Pass(Builder* b) { return new LowerX86(b); }

    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(false, "Unimplemented intrinsic emulation.");
        return nullptr;
    }

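    // AVX has no cross-lane variable permute, so VPERM is emulated two ways: a constant index
    // vector becomes a single shufflevector, while dynamic indices expand into a per-lane
    // extract/insert loop. Illustrative IR for the constant-index case (sketch only):
    //
    //     %r = shufflevector <8 x float> %a, <8 x float> %a, <8 x i32> <i32 3, i32 2, ...>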
    Instruction* VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        // Only need vperm emulation for AVX
        SWR_ASSERT(arch == AVX);

        Builder* B         = pThis->B;
        auto     v32A      = pCallInst->getArgOperand(0);
        auto     vi32Index = pCallInst->getArgOperand(1);

        Value* v32Result;
        if (isa<Constant>(vi32Index))
        {
            // Can use llvm shuffle vector directly with constant shuffle indices
            v32Result = B->VSHUFFLE(v32A, v32A, vi32Index);
        }
        else
        {
            v32Result = UndefValue::get(v32A->getType());
#if LLVM_VERSION_MAJOR >= 12
            uint32_t numElem = cast<FixedVectorType>(v32A->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(v32A->getType())->getNumElements();
#else
            uint32_t numElem = v32A->getType()->getVectorNumElements();
#endif
            for (uint32_t l = 0; l < numElem; ++l)
            {
                auto i32Index = B->VEXTRACT(vi32Index, B->C(l));
                auto val      = B->VEXTRACT(v32A, i32Index);
                v32Result     = B->VINSERT(v32Result, val, B->C(l));
            }
        }
        return cast<Instruction>(v32Result);
    }

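    // Gather lowering strategy: AVX has no gather instruction, so the gather is emulated with
    // scalar loads, spilling vSrc to the stack so inactive lanes have a valid address to read
    // their pass-through value from; AVX2 (and 256-wide AVX512) uses the avx2 gather intrinsics,
    // double pumped for 512-wide requests; 512-wide AVX512 uses the native avx512 gathers with
    // an integer lane mask.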
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     vSrc        = pCallInst->getArgOperand(0);
        auto     pBase       = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     vi1Mask     = pCallInst->getArgOperand(3);
        auto     i8Scale     = pCallInst->getArgOperand(4);

        pBase              = B->POINTER_CAST(pBase, PointerType::get(B->mInt8Ty, 0));
#if LLVM_VERSION_MAJOR >= 11
#if LLVM_VERSION_MAJOR >= 12
        FixedVectorType* pVectorType = cast<FixedVectorType>(vSrc->getType());
#else
        VectorType* pVectorType = cast<VectorType>(vSrc->getType());
#endif
        uint32_t    numElem     = pVectorType->getNumElements();
        auto        srcTy       = pVectorType->getElementType();
#else
        uint32_t numElem   = vSrc->getType()->getVectorNumElements();
        auto     srcTy     = vSrc->getType()->getVectorElementType();
#endif
        auto     i32Scale  = B->Z_EXT(i8Scale, B->mInt32Ty);

        Value*   v32Gather = nullptr;
        if (arch == AVX)
        {
            // Full emulation for AVX
            // Store source on stack to provide a valid address to load from inactive lanes
            auto pStack = B->STACKSAVE();
            auto pTmp   = B->ALLOCA(vSrc->getType());
            B->STORE(vSrc, pTmp);

            v32Gather        = UndefValue::get(vSrc->getType());
#if LLVM_VERSION_MAJOR <= 10
            auto vi32Scale   = ConstantVector::getSplat(numElem, cast<ConstantInt>(i32Scale));
#elif LLVM_VERSION_MAJOR == 11
            auto vi32Scale   = ConstantVector::getSplat(ElementCount(numElem, false), cast<ConstantInt>(i32Scale));
#else
            auto vi32Scale   = ConstantVector::getSplat(ElementCount::get(numElem, false), cast<ConstantInt>(i32Scale));
#endif
            auto vi32Offsets = B->MUL(vi32Indices, vi32Scale);

            for (uint32_t i = 0; i < numElem; ++i)
            {
                auto i32Offset          = B->VEXTRACT(vi32Offsets, B->C(i));
                auto pLoadAddress       = B->GEP(pBase, i32Offset);
                pLoadAddress            = B->BITCAST(pLoadAddress, PointerType::get(srcTy, 0));
                auto pMaskedLoadAddress = B->GEP(pTmp, {0, i});
                auto i1Mask             = B->VEXTRACT(vi1Mask, B->C(i));
                auto pValidAddress      = B->SELECT(i1Mask, pLoadAddress, pMaskedLoadAddress);
                auto val                = B->LOAD(pValidAddress);
                v32Gather               = B->VINSERT(v32Gather, val, B->C(i));
            }

            B->STACKRESTORE(pStack);
        }
        else if (arch == AVX2 || (arch == AVX512 && width == W256))
        {
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_ps_256);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_d_256);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_q_256);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            if (width == W256)
            {
                auto v32Mask = B->BITCAST(pThis->VectorMask(vi1Mask), vSrc->getType());
                v32Gather = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, v32Mask, i8Scale});
            }
            else if (width == W512)
            {
                // Double pump 4-wide for 64bit elements
#if LLVM_VERSION_MAJOR >= 12
                if (cast<FixedVectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#elif LLVM_VERSION_MAJOR >= 11
                if (cast<VectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#else
                if (vSrc->getType()->getVectorElementType() == B->mDoubleTy)
#endif
                {
                    auto v64Mask = pThis->VectorMask(vi1Mask);
#if LLVM_VERSION_MAJOR >= 12
                    uint32_t numElem = cast<FixedVectorType>(v64Mask->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
                    uint32_t numElem = cast<VectorType>(v64Mask->getType())->getNumElements();
#else
                    uint32_t numElem = v64Mask->getType()->getVectorNumElements();
#endif
                    v64Mask = B->S_EXT(v64Mask, getVectorType(B->mInt64Ty, numElem));
                    v64Mask = B->BITCAST(v64Mask, vSrc->getType());

                    Value* src0 = B->VSHUFFLE(vSrc, vSrc, B->C({0, 1, 2, 3}));
                    Value* src1 = B->VSHUFFLE(vSrc, vSrc, B->C({4, 5, 6, 7}));

                    Value* indices0 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3}));
                    Value* indices1 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({4, 5, 6, 7}));

                    Value* mask0 = B->VSHUFFLE(v64Mask, v64Mask, B->C({0, 1, 2, 3}));
                    Value* mask1 = B->VSHUFFLE(v64Mask, v64Mask, B->C({4, 5, 6, 7}));

#if LLVM_VERSION_MAJOR >= 12
                    uint32_t numElemSrc0  = cast<FixedVectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<FixedVectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<FixedVectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<FixedVectorType>(mask1->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
                    uint32_t numElemSrc0  = cast<VectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<VectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<VectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<VectorType>(mask1->getType())->getNumElements();
#else
                    uint32_t numElemSrc0  = src0->getType()->getVectorNumElements();
                    uint32_t numElemMask0 = mask0->getType()->getVectorNumElements();
                    uint32_t numElemSrc1  = src1->getType()->getVectorNumElements();
                    uint32_t numElemMask1 = mask1->getType()->getVectorNumElements();
#endif
                    src0 = B->BITCAST(src0, getVectorType(B->mInt64Ty, numElemSrc0));
                    mask0 = B->BITCAST(mask0, getVectorType(B->mInt64Ty, numElemMask0));
                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    src1 = B->BITCAST(src1, getVectorType(B->mInt64Ty, numElemSrc1));
                    mask1 = B->BITCAST(mask1, getVectorType(B->mInt64Ty, numElemMask1));
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});
                    v32Gather = B->VSHUFFLE(gather0, gather1, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                    v32Gather = B->BITCAST(v32Gather, vSrc->getType());
                }
                else
                {
                    // Double pump 8-wide for 32bit elements
                    auto v32Mask = pThis->VectorMask(vi1Mask);
                    v32Mask      = B->BITCAST(v32Mask, vSrc->getType());
                    Value* src0  = B->EXTRACT_16(vSrc, 0);
                    Value* src1  = B->EXTRACT_16(vSrc, 1);

                    Value* indices0 = B->EXTRACT_16(vi32Indices, 0);
                    Value* indices1 = B->EXTRACT_16(vi32Indices, 1);

                    Value* mask0 = B->EXTRACT_16(v32Mask, 0);
                    Value* mask1 = B->EXTRACT_16(v32Mask, 1);

                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->JOIN_16(gather0, gather1);
                }
            }
        }
        else if (arch == AVX512)
        {
            Value*    iMask = nullptr;
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dps_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpi_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpd_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt8Ty);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
            v32Gather     = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, iMask, i32Scale});
        }

        return cast<Instruction>(v32Gather);
    }
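
    // Scatter lowering strategy: pre-AVX512 targets call out to the ScatterPS_256 C helper
    // (512-wide requests are split into two 256-wide calls), while AVX512 uses the native
    // scatter intrinsics; the 256-wide AVX512 case is routed through the qps form with
    // zero-extended 64-bit indices since there is no 8x32-index scatter intrinsic.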
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     pBase       = pCallInst->getArgOperand(0);
        auto     vi1Mask     = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     v32Src      = pCallInst->getArgOperand(3);
        auto     i32Scale    = pCallInst->getArgOperand(4);

        if (arch != AVX512)
        {
            // Call into C function to do the scatter. This has significantly better compile perf
            // compared to jitting scatter loops for every scatter
            if (width == W256)
            {
                auto mask = B->BITCAST(vi1Mask, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, vi32Indices, v32Src, mask, i32Scale});
            }
            else
            {
                // Need to break up 512 wide scatter to two 256 wide
                auto maskLo = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto indicesLo =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto srcLo = B->VSHUFFLE(v32Src, v32Src, B->C({0, 1, 2, 3, 4, 5, 6, 7}));

                auto mask = B->BITCAST(maskLo, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesLo, srcLo, mask, i32Scale});

                auto maskHi = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto indicesHi =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto srcHi = B->VSHUFFLE(v32Src, v32Src, B->C({8, 9, 10, 11, 12, 13, 14, 15}));

                mask = B->BITCAST(maskHi, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesHi, srcHi, mask, i32Scale});
            }
            return nullptr;
        }

        Value*    iMask;
        Function* pX86IntrinFunc;
        if (width == W256)
        {
            // No direct intrinsic supported in llvm to scatter 8 elem with 32bit indices, but we
            // can use the scatter of 8 elements with 64bit indices
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_qps_512);

            auto vi32IndicesExt = B->Z_EXT(vi32Indices, B->mSimdInt64Ty);
            iMask               = B->BITCAST(vi1Mask, B->mInt8Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32IndicesExt, v32Src, i32Scale});
        }
        else if (width == W512)
        {
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_dps_512);
            iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32Indices, v32Src, i32Scale});
        }
        return nullptr;
    }

    // No support for vroundps in avx512 (it is available in kncni), so emulate with avx
    // instructions
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);
        assert(vf32Src);
        auto i8Round = pCallInst->getOperand(1);
        assert(i8Round);
        auto pfnFunc =
            Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_round_ps_256);

        if (width == W256)
        {
            return cast<Instruction>(B->CALL2(pfnFunc, vf32Src, i8Round));
        }
        else if (width == W512)
        {
            auto v8f32SrcLo = B->EXTRACT_16(vf32Src, 0);
            auto v8f32SrcHi = B->EXTRACT_16(vf32Src, 1);

            auto v8f32ResLo = B->CALL2(pfnFunc, v8f32SrcLo, i8Round);
            auto v8f32ResHi = B->CALL2(pfnFunc, v8f32SrcHi, i8Round);

            return cast<Instruction>(B->JOIN_16(v8f32ResLo, v8f32ResHi));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);

        if (width == W256)
        {
            auto vf32SrcRound = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                          Intrinsic::x86_avx_round_ps_256);
            return cast<Instruction>(B->FP_TRUNC(vf32SrcRound, B->mFP32Ty));
        }
        else if (width == W512)
        {
            // 512 can use intrinsic
            auto pfnFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                     Intrinsic::x86_avx512_mask_cvtpd2ps_512);
            return cast<Instruction>(B->CALL(pfnFunc, vf32Src));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // No support for hsub in AVX512
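    // A horizontal subtract produces a0-a1, a2-a3, ..., b0-b1, b2-b3, ... interleaved per 128-bit
    // lane, so the 512-wide path below shuffles the even-indexed elements into a minuend vector
    // and the odd-indexed elements into a subtrahend vector and subtracts the two.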
    Instruction* VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B    = pThis->B;
        auto src0 = pCallInst->getOperand(0);
        auto src1 = pCallInst->getOperand(1);

        // 256b hsub can just use avx intrinsic
        if (width == W256)
        {
            auto pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_hsub_ps_256);
            return cast<Instruction>(B->CALL2(pX86IntrinFunc, src0, src1));
        }
        else if (width == W512)
        {
            // 512b hsub can be accomplished with shuf/sub combo
            auto minuend    = B->VSHUFFLE(src0, src1, B->C({0, 2, 8, 10, 4, 6, 12, 14}));
            auto subtrahend = B->VSHUFFLE(src0, src1, B->C({1, 3, 9, 11, 5, 7, 13, 15}));
            return cast<Instruction>(B->SUB(minuend, subtrahend));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
            return nullptr;
        }
    }

    // Double pump the input using the given 256-wide intrinsic ID. This blindly extracts the
    // lower and upper 256 bits from each vector argument, calls the 256 wide intrinsic on each
    // half, then merges the results back to 512 wide
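    // For an 8-wide vector argument, iteration 0 shuffles out lanes {0,1,2,3} and iteration 1
    // lanes {4,5,6,7} (for 16-wide arguments, lanes {0..7} and {8..15}); a final shuffle over
    // all result lanes stitches the two halves back together.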
    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin)
    {
        auto B = pThis->B;
        SWR_ASSERT(width == W512);
        Value*    result[2];
        Function* pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, intrin);
        for (uint32_t i = 0; i < 2; ++i)
        {
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                auto argType = arg.get()->getType();
                if (argType->isVectorTy())
                {
#if LLVM_VERSION_MAJOR >= 12
                    uint32_t vecWidth  = cast<FixedVectorType>(argType)->getNumElements();
                    auto     elemTy    = cast<FixedVectorType>(argType)->getElementType();
#elif LLVM_VERSION_MAJOR >= 11
                    uint32_t vecWidth  = cast<VectorType>(argType)->getNumElements();
                    auto     elemTy    = cast<VectorType>(argType)->getElementType();
#else
                    uint32_t vecWidth  = argType->getVectorNumElements();
                    auto     elemTy    = argType->getVectorElementType();
#endif
                    Value*   lanes     = B->CInc<int>(i * vecWidth / 2, vecWidth / 2);
                    Value*   argToPush = B->VSHUFFLE(arg.get(), B->VUNDEF(elemTy, vecWidth), lanes);
                    args.push_back(argToPush);
                }
                else
                {
                    args.push_back(arg.get());
                }
            }
            result[i] = B->CALLA(pX86IntrinFunc, args);
        }
        uint32_t vecWidth;
        if (result[0]->getType()->isVectorTy())
        {
            assert(result[1]->getType()->isVectorTy());
#if LLVM_VERSION_MAJOR >= 12
            vecWidth = cast<FixedVectorType>(result[0]->getType())->getNumElements() +
                       cast<FixedVectorType>(result[1]->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
            vecWidth = cast<VectorType>(result[0]->getType())->getNumElements() +
                       cast<VectorType>(result[1]->getType())->getNumElements();
#else
            vecWidth = result[0]->getType()->getVectorNumElements() +
                       result[1]->getType()->getVectorNumElements();
#endif
        }
        else
        {
            vecWidth = 2;
        }
        Value* lanes = B->CInc<int>(0, vecWidth);
        return cast<Instruction>(B->VSHUFFLE(result[0], result[1], lanes));
    }

} // namespace SwrJit

using namespace SwrJit;

INITIALIZE_PASS_BEGIN(LowerX86, "LowerX86", "LowerX86", false, false)
INITIALIZE_PASS_END(LowerX86, "LowerX86", "LowerX86", false, false)
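
// Typical usage (illustrative sketch, not part of this file; names other than
// createLowerX86Pass and mpCurrentModule are assumed): the pass is created with the jitter's
// Builder and run from a legacy function pass manager alongside the other jitter passes, e.g.
//
//     llvm::legacy::FunctionPassManager fpm(pBuilder->JM()->mpCurrentModule);
//     fpm.add(SwrJit::createLowerX86Pass(pBuilder));
//     fpm.run(*pFunction);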