//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast cannot be combined with a load/store, we transform the bitcast
/// into an AMX load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end does (e.g.
/// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is treated as volatile, because that is necessary for AMX fast
/// register allocation. (In fast register allocation, registers are allocated
/// before spill/reload, so there is no spare register for AMX to use during a
/// spill.) The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will transfer to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
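  // e.g. llvm.x86.tileloadd64.internal returns x86_amx, and
  // llvm.x86.tilestored64.internal takes an x86_amx argument, so both are
  // recognized here.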
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand is being asked for.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118 and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(
        //            <256 x i32> %115)
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %104,
        //            i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //            x86_amx %117)
        // If we created %row = udiv i16 %106, 4 before %118 (aka II), its
        // definition would come after its user (the new tileload for %117).
        // So the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value but is a function argument, we create
        // Row at the entry bb.
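        // The entry block dominates every other block, so a udiv created
        // there is available at all uses.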
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, here we
  // just traverse the first user. If we can find the shape, we return it;
  // otherwise we return nullptr and the optimization for undef/zero is
  // abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
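  // (An AMX tile row is at most 64 bytes, so a constant 64-byte stride
  // covers every tile shape.)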
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
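    // Spill the tile to a stack slot with tilestored64, then reload it as a
    // plain vector.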
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine bitcast and store into an
        // amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col,
        //                                                    %addr, %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
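        // (ST may still use Bitcast, e.g. on the single-use path above, so it
        // must be erased before Bitcast; DeadInsts is erased in order.)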
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data: shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
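    // (The tilestore just created also uses I; it is skipped below.)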
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should change to a tileload.
// 2) PHI incoming values should be tilestored right after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// None of its users is a PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data are loaded by a tileload just in time.
// 2) All defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;

public:
  X86LowerAMXCast(Function &F) : Func(F) {}
  void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///  A -> B    amxcast
///  PHI
///  B -> A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
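/// For example (with A = x86_amx and B = <256 x i32>):
///
///   %b = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %a)
///   %phi = phi <256 x i32> [ %b, ... ], ...
///   %a2 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %phi)
///
/// becomes a PHI over the x86_amx incoming values, and %a2 is replaced by
/// the new PHI directly.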
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we ignore cases where the incoming value is a
      // constant other than undef/zero. In the future we might support
      // other constants.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If Row and Col are not constant, they must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, None, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // Example:
        //   bb.0:
        //     %0 = amxcast ...
        //   bb.1:
        //     %1 = amxcast ...
        //   bb.2:
        //     %goodphi = phi %0, %1
        //     %3 = amxcast %goodphi
        //   bb.3:
        //     %goodphi2 = phi %0, %goodphi
        //     %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since it is an
        // operand of rootPN, and DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Cast->replaceAllUsesWith(NewInst);
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        combineCastStore(cast<IntrinsicInst>(Cast), Store);
        DeadStores.push_back(Store);
        Change = true;
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      combineLoadCast(cast<IntrinsicInst>(Cast), Load);
      // Set the operand to null so that the load instruction can be erased.
      Cast->setOperand(0, nullptr);
      Load->eraseFromParent();
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip dead AMXcasts.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some dead code elimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast, and they should be
// handled elegantly.
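// Any leftover cast is lowered through a stack slot, mirroring
// transformBitcast above: a vector-to-tile cast becomes a vector store plus a
// tileload, and a tile-to-vector cast becomes a tilestore plus a vector load.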
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60,
    //                                           x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60,
    //                                           x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast, and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code
    // itself, not just Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end does (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
      // the AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}