//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast can not be combined with the load/store, we transform the bitcast
/// into an amx load/store and a <256 x i32> store/load.
///
/// If the Front End does not use O0 but the Mid/Back end uses O0 (e.g. "Clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for amx to identify the step in spill.)
/// The volatileTileData() will handle this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 amx intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, we try to getshape for %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(
        //              <256 x i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %104,
        //              i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //              x86_amx %117).
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, here we
  // just traverse the first user. If we can find the shape, then return the
  // shape, otherwise just return nullptr and the optimization for undef/zero
  // will be abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand of the intrinsic is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform the bitcast into a <store, load> pair of instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col,
        //                                                    %addr, %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data; shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from a tileload in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///  A -> B    amxcast
///  PHI
///  B -> A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
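///
/// For example (illustrative IR only; the value and block names below are
/// hypothetical):
///
///   bb.if:
///     %v0 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t0)
///     br label %bb.end
///   bb.else:
///     %v1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///     br label %bb.end
///   bb.end:
///     %p = phi <256 x i32> [ %v0, %bb.if ], [ %v1, %bb.else ]
///     %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %p)
/// -->
///   bb.end:
///     %p.amx = phi x86_amx [ %t0, %bb.if ], [ %t1, %bb.else ]
/// and all uses of %t2 are rewritten to use %p.amx instead.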
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently, we ignore cases where it is a const. In the future,
      // we might support const.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        // bb.0:
        //   %0 = amxcast ...
        // bb.1:
        //   %1 = amxcast ...
        // bb.2:
        //   %goodphi = phi %0, %1
        //   %3 = amxcast %goodphi
        // bb.3:
        //   %goodphi2 = phi %0, %goodphi
        //   %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand of the intrinsic is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is really
  // needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts,
                     Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip dead AMXcasts.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the Front End does not use O0 but the Mid/Back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // amx data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}