//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx; basic elementwise operations are not
/// supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and only
/// AMX intrinsics can operate on the type, we need to transform load/store
/// <256 x i32> instructions into AMX load/store. If the bitcast cannot be
/// combined with a load/store, we transform the bitcast into an amx load/store
/// plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the middle/back end does (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for amx to identify the
/// step in spill.) The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
                                           BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

namespace {
class X86LowerAMXType {
  Function &Func;
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand is queried.
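  // The tile dot-product intrinsics below take (m, n, k, c, a, b):
  // operand 3 is the m x n accumulator c, operand 4 is the m x k tile a,
  // and operand 5 is the k x n tile b, which is what the OpNo switch
  // below encodes.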
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      // FIXME: There is a design bug in the AMX shape model: Col should be
      // Col/4 when it is used as a Row, but the current greedy RA can't
      // handle this case well; it may fail if we generate a new shape
      // definition. So let's only do it at O0 for now.
      // Row = Row / 4
      if (TM->getOptLevel() == CodeGenOpt::None)
        Row = getRowFromCol(II, Row, 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
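  // A tile row occupies at most 64 bytes and the <256 x i32> buffer is
  // 1024 bytes (16 rows x 64 bytes), so a constant stride of 64 is valid
  // for every tile shape.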
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into a <store, load> pair of instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86_amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86_amx.
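    // The source tile comes from an AMX intrinsic, so its first two operands
    // give the shape of the tile being stored to the stack slot below.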
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, the vector load is kept and the
        // data is additionally loaded as a tile from the same address.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, it is dead after the combine and is
        // erased.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform it as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
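        // The store still uses the bitcast, so it has to be erased before its
        // operand when DeadInsts is processed below.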
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Make all AMX tile data volatile, i.e. shorten the live range of each tile
// register, before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored memory.
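    // The tile store just created is itself a use of I, so it is skipped
    // below along with any PHI users.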
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def is changed to a tileload.
// 2) Each PHI incoming value is tilestored right after its def.
// 3) These tileloads and tilestores must use the same memory.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users is a PHI node.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored memory.
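  // Skip the tile store created above; every other user is rewritten to
  // reload the tile from the stack slot right before the use.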
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data come from a tileload issued just in time.
// 2) All defs of tile data are tilestored into memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi-related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();

    X86LowerAMXType LAT(F, TM);
    bool C = LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the middle/back end does (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // amx data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}