/****************************************************************************
 * Copyright (C) 2014-2018 Intel Corporation.   All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * @file lower_x86.cpp
 *
 * @brief llvm pass to lower meta code to x86
 *
 * Notes:
 *
 ******************************************************************************/

#include "jit_pch.hpp"
#include "passes.h"
#include "JitManager.h"

#include "common/simdlib.hpp"

#include <unordered_map>

// C-linkage scatter fallback implemented outside this file. It is registered
// with the JIT's symbol resolver in the LowerX86 constructor (see
// sys::DynamicLibrary::AddSymbol below) so jitted code can call it by name.
extern "C" void ScatterPS_256(uint8_t*, SIMD256::Integer, SIMD256::Float, uint8_t, uint32_t);

namespace llvm
{
    // forward declare the initializer
    void initializeLowerX86Pass(PassRegistry&);
} // namespace llvm

namespace SwrJit
{
    using namespace llvm;

    // Instruction-set tier the pass lowers to; also the index into the
    // per-arch tables in getIntrinsicMapAdvanced().
    enum TargetArch
    {
        AVX    = 0,
        AVX2   = 1,
        AVX512 = 2
    };

    // SIMD register width being lowered; indexes X86Intrinsic::intrin.
    enum TargetWidth
    {
        W256       = 0,
        W512       = 1,
        NUM_WIDTHS = 2
    };

    struct LowerX86;

    // Signature of an emulation callback used when no native intrinsic exists
    // for a given (arch, width) combination.
    typedef std::function<Instruction*(LowerX86*, TargetArch, TargetWidth, CallInst*)> EmuFunc;

    // One entry of the advanced lowering table: a native intrinsic ID per
    // target width (Intrinsic::not_intrinsic if none), plus the emulation
    // fallback invoked when the ID slot is empty.
    struct X86Intrinsic
    {
        IntrinsicID intrin[NUM_WIDTHS];
        EmuFunc     emuFunc;
    };

    // Map of intrinsics that haven't been moved to the new mechanism yet. If used, these get the
    // previous behavior of mapping directly to avx/avx2 intrinsics.
    using intrinsicMap_t = std::map<std::string, IntrinsicID>;

    // Legacy table: meta intrinsic name -> single x86 intrinsic ID. Entries are
    // lowered 1:1 with the original call's arguments (see ProcessIntrinsic);
    // there is no per-arch or per-width selection. Function-local static so the
    // map is built on first use.
    static intrinsicMap_t& getIntrinsicMap() {
        static std::map<std::string, IntrinsicID> intrinsicMap = {
            {"meta.intrinsic.BEXTR_32", Intrinsic::x86_bmi_bextr_32},
            {"meta.intrinsic.VPSHUFB", Intrinsic::x86_avx2_pshuf_b},
            {"meta.intrinsic.VCVTPS2PH", Intrinsic::x86_vcvtps2ph_256},
            {"meta.intrinsic.VPTESTC", Intrinsic::x86_avx_ptestc_256},
            {"meta.intrinsic.VPTESTZ", Intrinsic::x86_avx_ptestz_256},
            {"meta.intrinsic.VPHADDD", Intrinsic::x86_avx2_phadd_d},
            {"meta.intrinsic.PDEP32", Intrinsic::x86_bmi_pdep_32},
            {"meta.intrinsic.RDTSC", Intrinsic::x86_rdtsc}
        };
        return intrinsicMap;
    }

    // Forward decls of the emulation callbacks referenced by the tables below.
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);

    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin);

    // Sentinel intrinsic ID meaning "double-pump the next smaller width's
    // intrinsic" (handled in ProcessIntrinsicAdvanced via DOUBLE_EMU).
    static Intrinsic::ID DOUBLE = (Intrinsic::ID)-1;

    using intrinsicMapAdvanced_t = std::vector<std::map<std::string, X86Intrinsic>>;

    // Advanced table, indexed first by TargetArch, then by meta intrinsic name.
    // Each entry supplies a native intrinsic per width (or not_intrinsic) and an
    // emulation fallback. Several AVX512 entries differ by LLVM version because
    // the masked permvar/cvtpd2ps intrinsics were removed after LLVM 7.
    static intrinsicMapAdvanced_t& getIntrinsicMapAdvanced()
    {
        // clang-format off
        static intrinsicMapAdvanced_t intrinsicMapAdvanced = {
            //                             256 wide                                          512 wide
            {
                // AVX
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,       DOUBLE},                         NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256,   Intrinsic::not_intrinsic},       NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,     DOUBLE},                         NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,      DOUBLE},                         NO_EMU}},
            },
            {
                // AVX2
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,       DOUBLE},                         NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx2_permps,          Intrinsic::not_intrinsic},       VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx2_permd,           Intrinsic::not_intrinsic},       VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256,   DOUBLE},                         NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,     DOUBLE},                         NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,      DOUBLE},                         NO_EMU}},
            },
            {
                // AVX512
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx512_rcp14_ps_256,  Intrinsic::x86_avx512_rcp14_ps_512}, NO_EMU}},
#if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx512_mask_permvar_sf_256, Intrinsic::x86_avx512_mask_permvar_sf_512}, NO_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx512_mask_permvar_si_256, Intrinsic::x86_avx512_mask_permvar_si_512}, NO_EMU}},
#else
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VPERM_EMU}},
#endif
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VSCATTER_EMU}},
#if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx512_mask_cvtpd2ps_256, Intrinsic::x86_avx512_mask_cvtpd2ps_512}, NO_EMU}},
#else
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VCONVERT_EMU}},
#endif
                {"meta.intrinsic.VROUND",     {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VROUND_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},       VHSUB_EMU}}
            }};
        // clang-format on
        return intrinsicMapAdvanced;
    }

    // Total bit width of a vector type (elements * element bits). Spelled three
    // ways because the VectorType element-count API changed in LLVM 11 and 12.
    static uint32_t getBitWidth(VectorType *pVTy)
    {
#if LLVM_VERSION_MAJOR >= 12
        return cast<FixedVectorType>(pVTy)->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#elif LLVM_VERSION_MAJOR >= 11
        return pVTy->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#else
        return pVTy->getBitWidth();
#endif
    }

    // Function pass that rewrites "meta.intrinsic.*" calls emitted by the SWR
    // builder into native x86 intrinsic calls (or emulation sequences) for the
    // detected target architecture.
    struct LowerX86 : public FunctionPass
    {
        // NOTE(review): B defaults to nullptr but is dereferenced
        // unconditionally below (B->mVWidth etc.), so callers must always pass
        // a valid Builder — confirm whether the default is ever used.
        LowerX86(Builder* b = nullptr) : FunctionPass(ID), B(b)
        {
            initializeLowerX86Pass(*PassRegistry::getPassRegistry());

            // Determine target arch
            if (JM()->mArch.AVX512F())
            {
                mTarget = AVX512;
            }
            else if (JM()->mArch.AVX2())
            {
                mTarget = AVX2;
            }
            else if (JM()->mArch.AVX())
            {
                mTarget = AVX;
            }
            else
            {
                SWR_ASSERT(false, "Unsupported AVX architecture.");
                mTarget = AVX;
            }

            // Setup scatter function for 256 wide. Temporarily force the
            // builder to 8-wide so mSimdInt32Ty/mSimdFP32Ty describe 256-bit
            // vectors, then restore the previous width.
            uint32_t curWidth = B->mVWidth;
            B->SetTargetWidth(8);
            std::vector<Type*> args = {
                B->mInt8PtrTy,   // pBase
                B->mSimdInt32Ty, // vIndices
                B->mSimdFP32Ty,  // vSrc
                B->mInt8Ty,      // mask
                B->mInt32Ty      // scale
            };

            FunctionType* pfnScatterTy = FunctionType::get(B->mVoidTy, args, false);
            // getOrInsertFunction started returning FunctionCallee in LLVM 9.
            mPfnScatter256 = cast<Function>(
#if LLVM_VERSION_MAJOR >= 9
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy).getCallee());
#else
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy));
#endif
            // Make the C helper resolvable by the JIT's symbol lookup.
            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ScatterPS_256") == nullptr)
            {
                sys::DynamicLibrary::AddSymbol("ScatterPS_256", (void*)&ScatterPS_256);
            }

            B->SetTargetWidth(curWidth);
        }

        // Try to decipher the vector type of the instruction. This does not work properly
        // across all intrinsics, and will have to be rethought. Probably need something
        // similar to llvm's getDeclaration() utility to map a set of inputs to a specific typed
        // intrinsic.
        //
        // Outputs the lowering width (W256/W512) derived from the total bit
        // width, and the scalar element type, for the given meta-intrinsic call.
        void GetRequestedWidthAndType(CallInst*       pCallInst,
                                      const StringRef intrinName,
                                      TargetWidth*    pWidth,
                                      Type**          pTy)
        {
            assert(pCallInst);
            Type* pVecTy = pCallInst->getType();

            // Check for intrinsic specific types
            // VCVTPD2PS type comes from src, not dst
            if (intrinName.equals("meta.intrinsic.VCVTPD2PS"))
            {
                Value* pOp = pCallInst->getOperand(0);
                assert(pOp);
                pVecTy = pOp->getType();
            }

            // If the result isn't a vector (e.g. void/scalar), fall back to the
            // first vector-typed argument.
            if (!pVecTy->isVectorTy())
            {
                for (auto& op : pCallInst->arg_operands())
                {
                    if (op.get()->getType()->isVectorTy())
                    {
                        pVecTy = op.get()->getType();
                        break;
                    }
                }
            }
            SWR_ASSERT(pVecTy->isVectorTy(), "Couldn't determine vector size");

            uint32_t width = getBitWidth(cast<VectorType>(pVecTy));
            switch (width)
            {
            case 256:
                *pWidth = W256;
                break;
            case 512:
                *pWidth = W512;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width %d", width);
                *pWidth = W256;
            }

            *pTy = pVecTy->getScalarType();
        }

        // Zero-filled vector of the given element type, sized for the target
        // width (8 elements for W256, 16 for W512).
        Value* GetZeroVec(TargetWidth width, Type* pTy)
        {
            uint32_t numElem = 0;
            switch (width)
            {
            case W256:
                numElem = 8;
                break;
            case W512:
                numElem = 16;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }

            return ConstantVector::getNullValue(getVectorType(pTy, numElem));
        }

        // All-ones execution mask constant for the given width (i8 for 8 lanes,
        // i16 for 16 lanes), as expected by the AVX512 masked intrinsics.
        Value* GetMask(TargetWidth width)
        {
            Value* mask;
            switch (width)
            {
            case W256:
                mask = B->C((uint8_t)-1);
                break;
            case W512:
                mask = B->C((uint16_t)-1);
                break;
            default:
                // NOTE(review): on this path mask is returned uninitialized
                // when SWR_ASSERT is compiled out — consider initializing mask
                // to nullptr or a default constant.
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }
            return mask;
        }

        // Convert <N x i1> mask to <N x i32> x86 mask
        Value* VectorMask(Value* vi1Mask)
        {
#if LLVM_VERSION_MAJOR >= 12
            uint32_t numElem = cast<FixedVectorType>(vi1Mask->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(vi1Mask->getType())->getNumElements();
#else
            uint32_t numElem = vi1Mask->getType()->getVectorNumElements();
#endif
            // Sign-extend i1 lanes to i32 lanes (all-ones for active lanes).
            return B->S_EXT(vi1Mask, getVectorType(B->mInt32Ty, numElem));
        }

        // Lower one meta-intrinsic call via the advanced per-arch table:
        // native intrinsic if available, DOUBLE_EMU for the DOUBLE sentinel,
        // otherwise the entry's emulation callback.
        Instruction* ProcessIntrinsicAdvanced(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            // Note: operator[] default-inserts if the name is missing; callers
            // are expected to have checked membership (see ProcessIntrinsic).
            auto&       intrinsic = getIntrinsicMapAdvanced()[mTarget][pFunc->getName().str()];
            TargetWidth vecWidth;
            Type*       pElemTy;
            GetRequestedWidthAndType(pCallInst, pFunc->getName(), &vecWidth, &pElemTy);

            // Check if there is a native intrinsic for this instruction
            IntrinsicID id = intrinsic.intrin[vecWidth];
            if (id == DOUBLE)
            {
                // Double pump the next smaller SIMD intrinsic
                SWR_ASSERT(vecWidth != 0, "Cannot double pump smallest SIMD width.");
                Intrinsic::ID id2 = intrinsic.intrin[vecWidth - 1];
                SWR_ASSERT(id2 != Intrinsic::not_intrinsic,
                           "Cannot find intrinsic to double pump.");
                return DOUBLE_EMU(this, mTarget, vecWidth, pCallInst, id2);
            }
            else if (id != Intrinsic::not_intrinsic)
            {
                Function*              pIntrin = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, id);
                SmallVector<Value*, 8> args;
                for (auto& arg : pCallInst->arg_operands())
                {
                    args.push_back(arg.get());
                }

                // If AVX512, all instructions add a src operand and mask. We'll pass in 0 src and
                // full mask for now Assuming the intrinsics are consistent and place the src
                // operand and mask last in the argument list.
                if (mTarget == AVX512)
                {
                    if (pFunc->getName().equals("meta.intrinsic.VCVTPD2PS"))
                    {
                        // Passthrough/mask are 256-wide here: the result of a
                        // pd->ps convert is half the width of the double source.
                        args.push_back(GetZeroVec(W256, pCallInst->getType()->getScalarType()));
                        args.push_back(GetMask(W256));
                        // for AVX512 VCVTPD2PS, we also have to add rounding mode
                        args.push_back(B->C(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
                    }
                    else
                    {
                        args.push_back(GetZeroVec(vecWidth, pElemTy));
                        args.push_back(GetMask(vecWidth));
                    }
                }

                return B->CALLA(pIntrin, args);
            }
            else
            {
                // No native intrinsic, call emulation function
                return intrinsic.emuFunc(this, mTarget, vecWidth, pCallInst);
            }

            // Unreachable: every branch above returns.
            SWR_ASSERT(false);
            return nullptr;
        }

        // Dispatch a meta-intrinsic call to either the advanced per-arch table
        // or the legacy 1:1 map. Asserts if the intrinsic is in neither.
        Instruction* ProcessIntrinsic(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            // Forward to the advanced support if found
            if (getIntrinsicMapAdvanced()[mTarget].find(pFunc->getName().str()) != getIntrinsicMapAdvanced()[mTarget].end())
            {
                return ProcessIntrinsicAdvanced(pCallInst);
            }

            SWR_ASSERT(getIntrinsicMap().find(pFunc->getName().str()) != getIntrinsicMap().end(),
                       "Unimplemented intrinsic %s.",
                       pFunc->getName().str().c_str());

            Intrinsic::ID x86Intrinsic = getIntrinsicMap()[pFunc->getName().str()];
            Function*     pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, x86Intrinsic);

            // Legacy path: forward the original arguments unchanged.
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                args.push_back(arg.get());
            }
            return B->CALLA(pX86IntrinFunc, args);
        }

        //////////////////////////////////////////////////////////////////////////
        /// @brief LLVM function pass run method.
        /// @param f- The function we're working on with this pass.
        virtual bool runOnFunction(Function& F)
        {
            std::vector<Instruction*> toRemove;
            std::vector<BasicBlock*>  bbs;

            // Make temp copy of the basic blocks and instructions, as the intrinsic
            // replacement code might invalidate the iterators
            for (auto& b : F.getBasicBlockList())
            {
                bbs.push_back(&b);
            }

            for (auto* BB : bbs)
            {
                std::vector<Instruction*> insts;
                for (auto& i : BB->getInstList())
                {
                    insts.push_back(&i);
                }

                for (auto* I : insts)
                {
                    if (CallInst* pCallInst = dyn_cast<CallInst>(I))
                    {
                        Function* pFunc = pCallInst->getCalledFunction();
                        if (pFunc)
                        {
                            if (pFunc->getName().startswith("meta.intrinsic"))
                            {
                                // Emit the replacement at the call site, then
                                // queue the original call for removal.
                                B->IRB()->SetInsertPoint(I);
                                Instruction* pReplace = ProcessIntrinsic(pCallInst);
                                toRemove.push_back(pCallInst);
                                if (pReplace)
                                {
                                    pCallInst->replaceAllUsesWith(pReplace);
                                }
                            }
                        }
                    }
                }
            }

            // Erase only after the walk so iteration above stays valid.
            for (auto* pInst : toRemove)
            {
                pInst->eraseFromParent();
            }

            JitManager::DumpToFile(&F, "lowerx86");

            // IR was (potentially) modified.
            return true;
        }

        // This pass requires/preserves no analyses.
        virtual void getAnalysisUsage(AnalysisUsage& AU) const {}

        JitManager* JM() { return B->JM(); }
        Builder*    B;              // IR builder supplying types and emit helpers (not owned)
        TargetArch  mTarget;        // arch tier selected in the constructor
        Function*   mPfnScatter256; // declaration of the C scatter fallback

        static char ID; ///< Needed by LLVM to generate ID for FunctionPass.
    };

    char LowerX86::ID = 0; // LLVM uses address of ID as the actual ID.
    // Factory used by the pass-manager setup code.
    FunctionPass* createLowerX86Pass(Builder* b) { return new LowerX86(b); }

    // Placeholder emulation: reaching this means the table routed an intrinsic
    // with no native ID to an entry that was never expected to need emulation.
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(false, "Unimplemented intrinsic emulation.");
        return nullptr;
    }

    // Emulates VPERMPS/VPERMD (variable lane permute) without a native
    // cross-lane permute instruction.
    Instruction* VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        // Only need vperm emulation for AVX
        // NOTE(review): the AVX512 table (LLVM >= 7) also routes
        // VPERMPS/VPERMD here, which would trip this assert in debug builds —
        // confirm intent.
        SWR_ASSERT(arch == AVX);

        Builder* B         = pThis->B;
        auto     v32A      = pCallInst->getArgOperand(0);
        auto     vi32Index = pCallInst->getArgOperand(1);

        Value* v32Result;
        if (isa<Constant>(vi32Index))
        {
            // Can use llvm shuffle vector directly with constant shuffle indices
            v32Result = B->VSHUFFLE(v32A, v32A, vi32Index);
        }
        else
        {
            // Dynamic indices: extract/insert one lane at a time.
            v32Result = UndefValue::get(v32A->getType());
#if LLVM_VERSION_MAJOR >= 12
            uint32_t numElem = cast<FixedVectorType>(v32A->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(v32A->getType())->getNumElements();
#else
            uint32_t numElem = v32A->getType()->getVectorNumElements();
#endif
            for (uint32_t l = 0; l < numElem; ++l)
            {
                auto i32Index = B->VEXTRACT(vi32Index, B->C(l));
                auto val      = B->VEXTRACT(v32A, i32Index);
                v32Result     = B->VINSERT(v32Result, val, B->C(l));
            }
        }
        return cast<Instruction>(v32Result);
    }

    // Emulates masked gather (VGATHERPD/PS/DD). Strategy depends on arch:
    // AVX scalarizes, AVX2 (and AVX512 at 256-bit) uses avx2 gather intrinsics
    // (double-pumped for 512-bit), AVX512 at 512-bit uses native gathers.
    // Args: (vSrc passthrough, pBase, vi32Indices, vi1Mask, i8Scale).
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     vSrc        = pCallInst->getArgOperand(0);
        auto     pBase       = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     vi1Mask     = pCallInst->getArgOperand(3);
        auto     i8Scale     = pCallInst->getArgOperand(4);

        pBase = B->POINTER_CAST(pBase, PointerType::get(B->mInt8Ty, 0));
#if LLVM_VERSION_MAJOR >= 11
#if LLVM_VERSION_MAJOR >= 12
        FixedVectorType* pVectorType = cast<FixedVectorType>(vSrc->getType());
#else
        VectorType* pVectorType = cast<VectorType>(vSrc->getType());
#endif
        uint32_t numElem = pVectorType->getNumElements();
        auto     srcTy   = pVectorType->getElementType();
#else
        uint32_t numElem = vSrc->getType()->getVectorNumElements();
        auto     srcTy   = vSrc->getType()->getVectorElementType();
#endif
        auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);

        Value* v32Gather = nullptr;
        if (arch == AVX)
        {
            // Full emulation for AVX
            // Store source on stack to provide a valid address to load from inactive lanes
            auto pStack = B->STACKSAVE();
            auto pTmp   = B->ALLOCA(vSrc->getType());
            B->STORE(vSrc, pTmp);

            v32Gather = UndefValue::get(vSrc->getType());
            // ConstantVector::getSplat's signature changed in LLVM 11 and 12.
#if LLVM_VERSION_MAJOR <= 10
            auto vi32Scale = ConstantVector::getSplat(numElem, cast<ConstantInt>(i32Scale));
#elif LLVM_VERSION_MAJOR == 11
            auto vi32Scale = ConstantVector::getSplat(ElementCount(numElem, false), cast<ConstantInt>(i32Scale));
#else
            auto vi32Scale = ConstantVector::getSplat(ElementCount::get(numElem, false), cast<ConstantInt>(i32Scale));
#endif
            auto vi32Offsets = B->MUL(vi32Indices, vi32Scale);

            for (uint32_t i = 0; i < numElem; ++i)
            {
                auto i32Offset          = B->VEXTRACT(vi32Offsets, B->C(i));
                auto pLoadAddress       = B->GEP(pBase, i32Offset);
                pLoadAddress            = B->BITCAST(pLoadAddress, PointerType::get(srcTy, 0));
                // Inactive lanes load their own passthrough value from the
                // stack copy instead of a potentially invalid address.
                auto pMaskedLoadAddress = B->GEP(pTmp, {0, i});
                auto i1Mask             = B->VEXTRACT(vi1Mask, B->C(i));
                auto pValidAddress      = B->SELECT(i1Mask, pLoadAddress, pMaskedLoadAddress);
                auto val                = B->LOAD(pValidAddress);
                v32Gather               = B->VINSERT(v32Gather, val, B->C(i));
            }

            B->STACKRESTORE(pStack);
        }
        else if (arch == AVX2 || (arch == AVX512 && width == W256))
        {
            // Use the AVX2 256-bit gather intrinsic matching the element type.
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_ps_256);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_d_256);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_q_256);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            if (width == W256)
            {
                auto v32Mask = B->BITCAST(pThis->VectorMask(vi1Mask), vSrc->getType());
                v32Gather    = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, v32Mask, i8Scale});
            }
            else if (width == W512)
            {
                // Double pump 4-wide for 64bit elements
#if LLVM_VERSION_MAJOR >= 12
                if (cast<FixedVectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#elif LLVM_VERSION_MAJOR >= 11
                if (cast<VectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#else
                if (vSrc->getType()->getVectorElementType() == B->mDoubleTy)
#endif
                {
                    // Widen the i1 mask to i64 lanes so it matches the data type.
                    auto v64Mask = pThis->VectorMask(vi1Mask);
#if LLVM_VERSION_MAJOR >= 12
                    uint32_t numElem = cast<FixedVectorType>(v64Mask->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
                    uint32_t numElem = cast<VectorType>(v64Mask->getType())->getNumElements();
#else
                    uint32_t numElem = v64Mask->getType()->getVectorNumElements();
#endif
                    v64Mask = B->S_EXT(v64Mask, getVectorType(B->mInt64Ty, numElem));
                    v64Mask = B->BITCAST(v64Mask, vSrc->getType());

                    // Split src/indices/mask into low and high 4-lane halves.
                    Value* src0 = B->VSHUFFLE(vSrc, vSrc, B->C({0, 1, 2, 3}));
                    Value* src1 = B->VSHUFFLE(vSrc, vSrc, B->C({4, 5, 6, 7}));

                    Value* indices0 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3}));
                    Value* indices1 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({4, 5, 6, 7}));

                    Value* mask0 = B->VSHUFFLE(v64Mask, v64Mask, B->C({0, 1, 2, 3}));
                    Value* mask1 = B->VSHUFFLE(v64Mask, v64Mask, B->C({4, 5, 6, 7}));

#if LLVM_VERSION_MAJOR >= 12
                    uint32_t numElemSrc0  = cast<FixedVectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<FixedVectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<FixedVectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<FixedVectorType>(mask1->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
                    uint32_t numElemSrc0  = cast<VectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<VectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<VectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<VectorType>(mask1->getType())->getNumElements();
#else
                    uint32_t numElemSrc0  = src0->getType()->getVectorNumElements();
                    uint32_t numElemMask0 = mask0->getType()->getVectorNumElements();
                    uint32_t numElemSrc1  = src1->getType()->getVectorNumElements();
                    uint32_t numElemMask1 = mask1->getType()->getVectorNumElements();
#endif
                    // The q-gather intrinsic operates on i64 vectors; bitcast
                    // the double src/mask halves, gather, then merge + cast back.
                    src0           = B->BITCAST(src0, getVectorType(B->mInt64Ty, numElemSrc0));
                    mask0          = B->BITCAST(mask0, getVectorType(B->mInt64Ty, numElemMask0));
                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    src1           = B->BITCAST(src1, getVectorType(B->mInt64Ty, numElemSrc1));
                    mask1          = B->BITCAST(mask1, getVectorType(B->mInt64Ty, numElemMask1));
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});
                    v32Gather = B->VSHUFFLE(gather0, gather1, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                    v32Gather = B->BITCAST(v32Gather, vSrc->getType());
                }
                else
                {
                    // Double pump 8-wide for 32bit elements
                    auto v32Mask = pThis->VectorMask(vi1Mask);
                    v32Mask      = B->BITCAST(v32Mask, vSrc->getType());
                    Value* src0  = B->EXTRACT_16(vSrc, 0);
                    Value* src1  = B->EXTRACT_16(vSrc, 1);

                    Value* indices0 = B->EXTRACT_16(vi32Indices, 0);
                    Value* indices1 = B->EXTRACT_16(vi32Indices, 1);

                    Value* mask0 = B->EXTRACT_16(v32Mask, 0);
                    Value* mask1 = B->EXTRACT_16(v32Mask, 1);

                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->JOIN_16(gather0, gather1);
                }
            }
        }
        else if (arch == AVX512)
        {
            // Native 512-bit masked gather; mask is a compact bitmask
            // (i16 for 16 lanes, i8 for 8 x 64-bit lanes).
            Value*    iMask          = nullptr;
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dps_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpi_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpd_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt8Ty);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
            v32Gather     = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, iMask, i32Scale});
        }

        return cast<Instruction>(v32Gather);
    }

    // Emulates masked scatter (VSCATTERPS). Pre-AVX512 targets call out to the
    // C helper ScatterPS_256 (splitting 512-bit into two 256-bit halves);
    // AVX512 uses native scatter intrinsics. Returns nullptr: a scatter has no
    // result value for callers to use.
    // Args: (pBase, vi1Mask, vi32Indices, v32Src, i32Scale).
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     pBase       = pCallInst->getArgOperand(0);
        auto     vi1Mask     = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     v32Src      = pCallInst->getArgOperand(3);
        auto     i32Scale    = pCallInst->getArgOperand(4);

        if (arch != AVX512)
        {
            // Call into C function to do the scatter. This has significantly better compile perf
            // compared to jitting scatter loops for every scatter
            if (width == W256)
            {
                auto mask = B->BITCAST(vi1Mask, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, vi32Indices, v32Src, mask, i32Scale});
            }
            else
            {
                // Need to break up 512 wide scatter to two 256 wide
                auto maskLo    = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto indicesLo =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto srcLo = B->VSHUFFLE(v32Src, v32Src, B->C({0, 1, 2, 3, 4, 5, 6, 7}));

                auto mask = B->BITCAST(maskLo, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesLo, srcLo, mask, i32Scale});

                auto maskHi    = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto indicesHi =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto srcHi = B->VSHUFFLE(v32Src, v32Src, B->C({8, 9, 10, 11, 12, 13, 14, 15}));

                mask = B->BITCAST(maskHi, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesHi, srcHi, mask, i32Scale});
            }
            return nullptr;
        }

        Value*    iMask;
        Function* pX86IntrinFunc;
        if (width == W256)
        {
            // No direct intrinsic supported in llvm to scatter 8 elem with 32bit indices, but we
            // can use the scatter of 8 elements with 64bit indices
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_qps_512);

            auto vi32IndicesExt = B->Z_EXT(vi32Indices, B->mSimdInt64Ty);
            iMask               = B->BITCAST(vi1Mask, B->mInt8Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32IndicesExt, v32Src, i32Scale});
        }
        else if (width == W512)
        {
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_dps_512);
            iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32Indices, v32Src, i32Scale});
        }
        return nullptr;
    }

    // No support for vroundps in avx512 (it is available in kncni), so emulate with avx
    // instructions
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B        = pThis->B;
        auto vf32Src  = pCallInst->getOperand(0);
        assert(vf32Src);
        auto i8Round  = pCallInst->getOperand(1);
        assert(i8Round);
        auto pfnFunc =
            Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_round_ps_256);

        if (width == W256)
        {
            return cast<Instruction>(B->CALL2(pfnFunc, vf32Src, i8Round));
        }
        else if (width == W512)
        {
            // Round the low and high 256-bit halves separately, then rejoin.
            auto v8f32SrcLo = B->EXTRACT_16(vf32Src, 0);
            auto v8f32SrcHi = B->EXTRACT_16(vf32Src, 1);

            auto v8f32ResLo = B->CALL2(pfnFunc, v8f32SrcLo, i8Round);
            auto v8f32ResHi = B->CALL2(pfnFunc, v8f32SrcHi, i8Round);

            return cast<Instruction>(B->JOIN_16(v8f32ResLo, v8f32ResHi));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // AVX512 emulation of VCVTPD2PS (double -> float convert) for LLVM >= 7,
    // where the masked cvtpd2ps intrinsics are no longer exposed.
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);

        if (width == W256)
        {
            // NOTE(review): this looks broken — vf32SrcRound is the *declaration*
            // of the round intrinsic (a Function*), not a call result, yet it is
            // handed to FP_TRUNC while vf32Src goes unused on this path. The
            // intent was presumably to fptrunc the source operand; confirm and fix.
            auto vf32SrcRound = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                          Intrinsic::x86_avx_round_ps_256);
            return cast<Instruction>(B->FP_TRUNC(vf32SrcRound, B->mFP32Ty));
        }
        else if (width == W512)
        {
            // 512 can use intrinsic
            // NOTE(review): only the source is passed here, while the native
            // masked intrinsic normally takes passthrough/mask/rounding operands
            // as well (cf. ProcessIntrinsicAdvanced's VCVTPD2PS handling) —
            // verify B->CALL supplies the remaining arguments.
            auto pfnFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                     Intrinsic::x86_avx512_mask_cvtpd2ps_512);
            return cast<Instruction>(B->CALL(pfnFunc, vf32Src));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // No support for hsub in AVX512
    Instruction* VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B    = pThis->B;
        auto src0 = pCallInst->getOperand(0);
        auto src1 = pCallInst->getOperand(1);

        // 256b hsub can just use avx intrinsic
        if (width == W256)
        {
            auto pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_hsub_ps_256);
            return cast<Instruction>(B->CALL2(pX86IntrinFunc, src0, src1));
        }
        else if (width == W512)
        {
            // 512b hsub can be accomplished with shuf/sub combo
            auto minuend    = B->VSHUFFLE(src0, src1, B->C({0, 2, 8, 10, 4, 6, 12, 14}));
            auto subtrahend = B->VSHUFFLE(src0, src1, B->C({1, 3, 9, 11, 5, 7, 13, 15}));
            return cast<Instruction>(B->SUB(minuend, subtrahend));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
            return nullptr;
        }
    }

    // Double pump input using Intrin template arg. This blindly extracts lower and upper 256 from
    // each vector argument and calls the 256 wide intrinsic, then merges the results to 512 wide
    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin)
    {
        auto B = pThis->B;
        SWR_ASSERT(width == W512);
        Value*    result[2];
        Function* pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, intrin);
        // Run the 256-wide intrinsic twice: i==0 on the low halves of each
        // vector argument, i==1 on the high halves. Scalar args pass through.
        for (uint32_t i = 0; i < 2; ++i)
        {
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                auto argType = arg.get()->getType();
                if (argType->isVectorTy())
                {
#if LLVM_VERSION_MAJOR >= 12
                    uint32_t vecWidth = cast<FixedVectorType>(argType)->getNumElements();
                    auto     elemTy   = cast<FixedVectorType>(argType)->getElementType();
#elif LLVM_VERSION_MAJOR >= 11
                    uint32_t vecWidth = cast<VectorType>(argType)->getNumElements();
                    auto     elemTy   = cast<VectorType>(argType)->getElementType();
#else
                    uint32_t vecWidth = argType->getVectorNumElements();
                    auto     elemTy   = argType->getVectorElementType();
#endif
                    // Lane selector for this half: [i*W/2, i*W/2 + W/2).
                    Value* lanes     = B->CInc<int>(i * vecWidth / 2, vecWidth / 2);
                    Value* argToPush = B->VSHUFFLE(arg.get(), B->VUNDEF(elemTy, vecWidth), lanes);
                    args.push_back(argToPush);
                }
                else
                {
                    args.push_back(arg.get());
                }
            }
            result[i] = B->CALLA(pX86IntrinFunc, args);
        }
        // Merge the two half-results back into a single wide vector.
        uint32_t vecWidth;
        if (result[0]->getType()->isVectorTy())
        {
            assert(result[1]->getType()->isVectorTy());
#if LLVM_VERSION_MAJOR >= 12
            vecWidth = cast<FixedVectorType>(result[0]->getType())->getNumElements() +
                       cast<FixedVectorType>(result[1]->getType())->getNumElements();
#elif LLVM_VERSION_MAJOR >= 11
            vecWidth = cast<VectorType>(result[0]->getType())->getNumElements() +
                       cast<VectorType>(result[1]->getType())->getNumElements();
#else
            vecWidth = result[0]->getType()->getVectorNumElements() +
                       result[1]->getType()->getVectorNumElements();
#endif
        }
        else
        {
            vecWidth = 2;
        }
        Value* lanes = B->CInc<int>(0, vecWidth);
        return cast<Instruction>(B->VSHUFFLE(result[0], result[1], lanes));
    }

} // namespace SwrJit

using namespace SwrJit;

INITIALIZE_PASS_BEGIN(LowerX86, "LowerX86", "LowerX86", false, false)
INITIALIZE_PASS_END(LowerX86, "LowerX86", "LowerX86", false, false)