//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of AMX registers
/// are encoded in the 1st and 2nd machine operands of AMX pseudo instructions.
///
/// The instruction ldtilecfg is used to config the shapes. It must be reachable
/// for all variable shapes. ldtilecfg will be inserted more than once if we
/// cannot find a dominating point for all AMX instructions.
///
/// The configure register is caller saved according to ABI. We need to insert
/// ldtilecfg again after the call instruction if callee clobbers any AMX
/// registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions aren't
/// met.
23 // 24 //===----------------------------------------------------------------------===// 25 26 #include "X86.h" 27 #include "X86InstrBuilder.h" 28 #include "X86RegisterInfo.h" 29 #include "X86Subtarget.h" 30 #include "llvm/CodeGen/MachineFunctionPass.h" 31 #include "llvm/CodeGen/MachineInstr.h" 32 #include "llvm/CodeGen/MachineLoopInfo.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/InitializePasses.h" 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "tile-pre-config" 42 #define REPORT_CONFIG_FAIL \ 43 report_fatal_error( \ 44 MF.getName() + \ 45 ": Failed to config tile register, please define the shape earlier"); 46 47 namespace { 48 49 struct MIRef { 50 MachineInstr *MI = nullptr; 51 MachineBasicBlock *MBB = nullptr; 52 // A virtual position for instruction that will be inserted after MI. 53 size_t Pos = 0; 54 MIRef() = default; 55 MIRef(MachineBasicBlock *MBB) : MBB(MBB) { 56 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); 57 ++I, ++Pos) 58 MI = &*I; 59 } 60 MIRef(MachineInstr *MI) 61 : MI(MI), MBB(MI->getParent()), 62 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} 63 MIRef(MachineInstr *MI, MachineBasicBlock *MBB) 64 : MI(MI), MBB(MBB), 65 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} 66 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) 67 : MI(MI), MBB(MBB), Pos(Pos) {} 68 operator bool() const { return MBB != nullptr; } 69 bool operator==(const MIRef &RHS) const { 70 return MI == RHS.MI && MBB == RHS.MBB; 71 } 72 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } 73 bool operator<(const MIRef &RHS) const { 74 // Comparison between different BBs happens when inserting a MIRef into set. 75 // So we compare MBB first to make the insertion happy. 
76 return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); 77 } 78 bool operator>(const MIRef &RHS) const { 79 // Comparison between different BBs happens when inserting a MIRef into set. 80 // So we compare MBB first to make the insertion happy. 81 return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); 82 } 83 }; 84 85 struct BBInfo { 86 MIRef FirstAMX; 87 MIRef LastCall; 88 bool HasAMXRegLiveIn = false; 89 bool TileCfgForbidden = false; 90 bool NeedTileCfgLiveIn = false; 91 }; 92 93 class X86PreTileConfig : public MachineFunctionPass { 94 MachineRegisterInfo *MRI; 95 const MachineLoopInfo *MLI; 96 SmallSet<MachineInstr *, 8> DefVisited; 97 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; 98 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; 99 100 /// Check if the callee will clobber AMX registers. 101 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { 102 auto Iter = llvm::find_if( 103 MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); }); 104 if (Iter == MI.operands_end()) 105 return false; 106 UsableRegs.clearBitsInMask(Iter->getRegMask()); 107 return !UsableRegs.none(); 108 } 109 110 /// Check if MI is AMX pseudo instruction. 111 bool isAMXInstruction(MachineInstr &MI) { 112 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) 113 return false; 114 MachineOperand &MO = MI.getOperand(0); 115 // We can simply check if it is AMX instruction by its def. 116 // But we should exclude old API which uses physical registers. 117 if (MO.isReg() && MO.getReg().isVirtual() && 118 MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) { 119 collectShapeInfo(MI); 120 return true; 121 } 122 // PTILESTOREDV is the only exception that doesn't def a AMX register. 123 return MI.getOpcode() == X86::PTILESTOREDV; 124 } 125 126 /// Check if it is an edge from loop bottom to loop head. 
127 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) { 128 if (!MLI->isLoopHeader(Header)) 129 return false; 130 auto *ML = MLI->getLoopFor(Header); 131 if (ML->contains(Bottom) && ML->isLoopLatch(Bottom)) 132 return true; 133 134 return false; 135 } 136 137 /// Collect the shape def information for later use. 138 void collectShapeInfo(MachineInstr &MI); 139 140 /// Try to hoist shapes definded below AMX instructions. 141 bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { 142 MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; 143 auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX); 144 auto InsertPoint = FirstAMX.MI->getIterator(); 145 for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { 146 // Do not hoist instructions that access memory. 147 if (I->MI->mayLoadOrStore()) 148 return false; 149 for (auto &MO : I->MI->operands()) { 150 if (MO.isDef()) 151 continue; 152 // Do not hoist instructions if the sources' def under AMX instruction. 153 // TODO: We can handle isMoveImmediate MI here. 154 if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX) 155 return false; 156 // TODO: Maybe need more checks here. 157 } 158 MBB->insert(InsertPoint, I->MI->removeFromParent()); 159 } 160 // We only need to mark the last shape in the BB now. 161 Shapes.clear(); 162 Shapes.push_back(MIRef(&*--InsertPoint, MBB)); 163 return true; 164 } 165 166 public: 167 X86PreTileConfig() : MachineFunctionPass(ID) {} 168 169 /// Return the pass name. 170 StringRef getPassName() const override { 171 return "Tile Register Pre-configure"; 172 } 173 174 /// X86PreTileConfig analysis usage. 175 void getAnalysisUsage(AnalysisUsage &AU) const override { 176 AU.setPreservesAll(); 177 AU.addRequired<MachineLoopInfo>(); 178 MachineFunctionPass::getAnalysisUsage(AU); 179 } 180 181 /// Clear MF related structures. 
182 void releaseMemory() override { 183 ShapeBBs.clear(); 184 DefVisited.clear(); 185 BBVisitedInfo.clear(); 186 } 187 188 /// Perform ldtilecfg instructions inserting. 189 bool runOnMachineFunction(MachineFunction &MF) override; 190 191 static char ID; 192 }; 193 194 } // end anonymous namespace 195 196 char X86PreTileConfig::ID = 0; 197 198 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", 199 "Tile Register Pre-configure", false, false) 200 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 201 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", 202 "Tile Register Pre-configure", false, false) 203 204 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { 205 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { 206 MIRef MIR(MI, MBB); 207 auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); 208 if (I == ShapeBBs[MBB].end() || *I != MIR) 209 ShapeBBs[MBB].insert(I, MIR); 210 }; 211 212 SmallVector<Register, 8> WorkList( 213 {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); 214 while (!WorkList.empty()) { 215 Register R = WorkList.pop_back_val(); 216 MachineInstr *DefMI = MRI->getVRegDef(R); 217 assert(DefMI && "R must has one define instruction"); 218 MachineBasicBlock *DefMBB = DefMI->getParent(); 219 if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) 220 continue; 221 if (DefMI->isPHI()) { 222 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) 223 if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) 224 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. 
225 else 226 WorkList.push_back(DefMI->getOperand(I).getReg()); 227 } else { 228 RecordShape(DefMI, DefMBB); 229 } 230 } 231 } 232 233 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { 234 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); 235 const TargetInstrInfo *TII = ST.getInstrInfo(); 236 const TargetRegisterInfo *TRI = ST.getRegisterInfo(); 237 const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); 238 239 BitVector AMXRegs(TRI->getNumRegs()); 240 for (unsigned I = 0; I < RC->getNumRegs(); I++) 241 AMXRegs.set(X86::TMM0 + I); 242 243 // Iterate MF to collect information. 244 MRI = &MF.getRegInfo(); 245 MLI = &getAnalysis<MachineLoopInfo>(); 246 SmallSet<MIRef, 8> CfgNeedInsert; 247 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs; 248 for (auto &MBB : MF) { 249 size_t Pos = 0; 250 for (auto &MI : MBB) { 251 ++Pos; 252 if (isAMXInstruction(MI)) { 253 // If there's call before the AMX, we need to reload tile config. 254 if (BBVisitedInfo[&MBB].LastCall) 255 CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall); 256 else // Otherwise, we need tile config to live in this BB. 257 BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; 258 // Always record the first AMX in case there's shape def after it. 259 if (!BBVisitedInfo[&MBB].FirstAMX) 260 BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); 261 } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) { 262 // Record the call only if the callee clobbers all AMX registers. 263 BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); 264 } 265 } 266 if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { 267 if (&MBB == &MF.front()) 268 CfgNeedInsert.insert(MIRef(&MBB)); 269 else 270 CfgLiveInBBs.push_back(&MBB); 271 } 272 if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) 273 for (auto *Succ : MBB.successors()) 274 if (!isLoopBackEdge(Succ, &MBB)) 275 BBVisitedInfo[Succ].HasAMXRegLiveIn = true; 276 } 277 278 // Update NeedTileCfgLiveIn for predecessors. 
279 while (!CfgLiveInBBs.empty()) { 280 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); 281 for (auto *Pred : MBB->predecessors()) { 282 if (BBVisitedInfo[Pred].LastCall) { 283 CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall); 284 } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { 285 BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; 286 if (Pred == &MF.front()) 287 CfgNeedInsert.insert(MIRef(Pred)); 288 else 289 CfgLiveInBBs.push_back(Pred); 290 } 291 } 292 } 293 294 // There's no AMX instruction if we didn't find a tile config live in point. 295 if (CfgNeedInsert.empty()) 296 return false; 297 298 // Avoid to insert ldtilecfg before any shape defs. 299 SmallVector<MachineBasicBlock *, 8> WorkList; 300 for (auto &I : ShapeBBs) { 301 // TODO: We can hoist shapes across BBs here. 302 if (BBVisitedInfo[I.first].HasAMXRegLiveIn) 303 REPORT_CONFIG_FAIL 304 if (BBVisitedInfo[I.first].FirstAMX && 305 BBVisitedInfo[I.first].FirstAMX < I.second.back() && 306 !hoistShapesInBB(I.first, I.second)) 307 REPORT_CONFIG_FAIL 308 WorkList.push_back(I.first); 309 } 310 while (!WorkList.empty()) { 311 MachineBasicBlock *MBB = WorkList.pop_back_val(); 312 for (auto *Pred : MBB->predecessors()) { 313 if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) { 314 BBVisitedInfo[Pred].TileCfgForbidden = true; 315 WorkList.push_back(Pred); 316 } 317 } 318 } 319 320 DebugLoc DL; 321 SmallSet<MIRef, 8> VisitedOrInserted; 322 int SS = MF.getFrameInfo().CreateStackObject( 323 ST.getTileConfigSize(), ST.getTileConfigAlignment(), false); 324 325 // Try to insert for the tile config live in points. 326 for (auto I : CfgNeedInsert) { 327 SmallSet<MIRef, 8> InsertPoints; 328 SmallVector<MIRef, 8> WorkList({I}); 329 while (!WorkList.empty()) { 330 MIRef I = WorkList.pop_back_val(); 331 if (!VisitedOrInserted.count(I)) { 332 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { 333 // If the BB is all shapes reachable, stop sink and try to insert. 
334 InsertPoints.insert(I); 335 } else { 336 // Avoid the BB to be multi visited. 337 VisitedOrInserted.insert(I); 338 // Sink the inserting point along the chain with NeedTileCfgLiveIn = 339 // true when MBB isn't all shapes reachable. 340 for (auto *Succ : I.MBB->successors()) 341 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) 342 WorkList.push_back(MIRef(Succ)); 343 } 344 } 345 } 346 347 // A given point might be forked due to shape conditions are not met. 348 for (MIRef I : InsertPoints) { 349 // Make sure we insert ldtilecfg after the last shape def in MBB. 350 if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back()) 351 I = ShapeBBs[I.MBB].back(); 352 // There're chances the MBB is sunk more than once. Record it to avoid 353 // multi insert. 354 if (VisitedOrInserted.insert(I).second) { 355 auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin(); 356 addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)), 357 SS); 358 } 359 } 360 } 361 362 // Zero stack slot. 363 MachineBasicBlock &MBB = MF.front(); 364 MachineInstr *MI = &*MBB.begin(); 365 if (ST.hasAVX512()) { 366 Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); 367 BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm) 368 .addReg(Zmm, RegState::Undef) 369 .addReg(Zmm, RegState::Undef); 370 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS) 371 .addReg(Zmm); 372 } else if (ST.hasAVX2()) { 373 Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); 374 BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm) 375 .addReg(Ymm, RegState::Undef) 376 .addReg(Ymm, RegState::Undef); 377 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS) 378 .addReg(Ymm); 379 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32) 380 .addReg(Ymm); 381 } else { 382 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled"); 383 Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); 384 BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm) 385 
.addReg(Xmm, RegState::Undef) 386 .addReg(Xmm, RegState::Undef); 387 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS) 388 .addReg(Xmm); 389 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16) 390 .addReg(Xmm); 391 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32) 392 .addReg(Xmm); 393 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48) 394 .addReg(Xmm); 395 } 396 // Fill in the palette first. 397 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1); 398 399 return true; 400 } 401 402 FunctionPass *llvm::createX86PreTileConfigPass() { 403 return new X86PreTileConfig(); 404 } 405