//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with the load/store, we transform the bitcast to an
/// AMX load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// The volatileTileData() will handle this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a constant value and it is not a function argument,
        // we create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //          i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //          %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //          x86_amx %117).
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a constant value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we just
  // traverse the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero will be abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = add <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = add <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = add <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = add <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, shortening the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tile-stored just after their def.
// 3) The memory of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI nodes.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads just in time.
// 2) All the defs of tile data are tile-stored to memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
/// A -> B    amxcast
/// PHI
/// B -> A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we ignore cases where the incoming value is a
      // constant. In the future, we might support constants.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //      %0 = amxcast ...
        //   bb.1:
        //      %1 = amxcast ...
        //   bb.2:
        //      %goodphi = phi %0, %1
        //      %3 = amxcast %goodphi
        //   bb.3:
        //      %goodphi2 = phi %0, %goodphi
        //      %4 = amxcast %goodphi2
        //   When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2
        //   is outside the phi-web, so the combination stops. When
        //   optimizeAMXCastFromPhi processes %4 and %goodphi2, the
        //   optimization will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a constant.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are stores and bitcasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is
  // really needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some dead code elimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of AMX code, not
    // just by checking Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}