//===- X86LowerAMXType.cpp - Lower AMX type for load/store -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store instructions.
/// On X86, <256 x i32> is bitcast to x86_amx, and the AMX instruction set
/// provides only simple operations on x86_amx; basic elementwise operations
/// are not supported. Since x86_amx is bitcast from vector <256 x i32> and
/// only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions into AMX load/store intrinsics. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// into an AMX load/store plus a <256 x i32> store/load.
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

// Create a <256 x i32> stack slot in the entry block, aligned for x86_amx.
static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

// Return the (row, col) shape of the tile that is operand OpNo of the AMX
// intrinsic II.
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand the tile is.
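  // Assumed operand layout of the internal dot-product intrinsic, as derived
  // from the cases below: (i16 m, i16 n, i16 k, x86_amx c, x86_amx a,
  // x86_amx b) computing c = a * b + c. The accumulator c (operand 3) is
  // m x n, a (operand 4) is m x k, and b (operand 5) is k x n, so the
  // row/col pair is read from the matching i16 shape operands.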
  case Intrinsic::x86_tdpbssd_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into a <store, load> pair through a stack slot.
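// There is no direct way to copy between a <256 x i32> vector value and an
// x86_amx tile, so the value is spilled to an entry-block alloca and
// reloaded on the other side of the bitcast with the matching AMX
// tileloadd64/tilestored64 intrinsic.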
static bool transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // The user may be a bitcast back to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // The source may be a bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}
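// X86LowerAMXType visits every bitcast between <256 x i32> and x86_amx in a
// function, combines each one with an adjacent load/store into an AMX
// tileloadd64/tilestored64 intrinsic where possible, and otherwise rewrites
// it through a stack slot via transformBitcast. Replaced instructions are
// collected and erased afterwards.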
namespace {
class X86LowerAMXType {
  Function &Func;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
};

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, keep the vector load around for
        // the other users and load the tile from the same address.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user (the bitcast), it is dead after the
        // combine and is erased below.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user (the store) before its operand (the bitcast).
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    X86LowerAMXType LAT(F);
    bool C = LAT.visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}
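// Note: this factory is presumably hooked into the X86 codegen pipeline
// (e.g. from X86PassConfig::addIRPasses in X86TargetMachine.cpp) so that the
// lowering runs before instruction selection.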