1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The ldtilecfg instruction
/// is inserted before the fast register allocation pass, but at that time the
/// shape of each physical tile register is not known, because register
/// allocation has not been done yet. This pass runs after the register
/// allocation pass. It collects the shape information of each physical tile
/// register and stores the shape in the stack slot that is allocated for
/// loading the config into the tile config register.
17 //
18 //===----------------------------------------------------------------------===//
19
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86RegisterInfo.h"
24 #include "X86Subtarget.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunctionPass.h"
27 #include "llvm/CodeGen/MachineInstr.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/TargetInstrInfo.h"
31 #include "llvm/CodeGen/TargetRegisterInfo.h"
32 #include "llvm/InitializePasses.h"
33
34 using namespace llvm;
35
36 #define DEBUG_TYPE "fasttileconfig"
37
38 namespace {
39
40 class X86FastTileConfig : public MachineFunctionPass {
41 // context
42 MachineFunction *MF = nullptr;
43 const X86Subtarget *ST = nullptr;
44 const TargetRegisterInfo *TRI = nullptr;
45 const TargetInstrInfo *TII = nullptr;
46 MachineRegisterInfo *MRI = nullptr;
47
48 MachineInstr *getTileConfigPoint();
49 void tileConfig();
50
51 public:
X86FastTileConfig()52 X86FastTileConfig() : MachineFunctionPass(ID) {}
53
54 bool fastTileConfig();
55 bool isTileLoad(MachineInstr &MI);
56 bool isTileStore(MachineInstr &MI);
57 bool isAMXInstr(MachineInstr &MI);
58 void getTileStoreShape(MachineInstr &MI,
59 SmallVector<MachineOperand *> &ShapedTiles);
60
61 MachineInstr *getKeyAMXInstr(MachineInstr *MI);
62 void getTileShapesCfg(MachineInstr *MI,
63 SmallVector<MachineOperand *> &ShapedTiles);
64 void getShapeCfgInstrs(MachineInstr *MI,
65 std::map<unsigned, MachineInstr *> &RowCfgs,
66 std::map<unsigned, MachineInstr *> &ColCfgs);
67
68 /// Return the pass name.
getPassName() const69 StringRef getPassName() const override {
70 return "Fast Tile Register Configure";
71 }
72
73 void materializeTileCfg(MachineInstr *MI);
74
75 void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles,
76 std::map<unsigned, MachineInstr *> &RowCfgs,
77 std::map<unsigned, MachineInstr *> &ColCfgs);
78
79 /// Perform register allocation.
80 bool runOnMachineFunction(MachineFunction &MFunc) override;
81
getRequiredProperties() const82 MachineFunctionProperties getRequiredProperties() const override {
83 return MachineFunctionProperties().set(
84 MachineFunctionProperties::Property::NoPHIs);
85 }
86
87 static char ID;
88 };
89
90 } // end anonymous namespace
91
92 char X86FastTileConfig::ID = 0;
93
94 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
95 "Fast Tile Register Configure", false, false)
96 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
97 "Fast Tile Register Configure", false, false)
98
isTilePhysReg(MachineOperand & Op)99 static bool isTilePhysReg(MachineOperand &Op) {
100 if (!Op.isReg())
101 return false;
102
103 Register Reg = Op.getReg();
104 if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
105 return true;
106 return false;
107 }
108
getTilePhysRegIdx(MachineOperand * Op)109 static unsigned getTilePhysRegIdx(MachineOperand *Op) {
110 assert(isTilePhysReg(*Op) && "Tile Operand is invalid");
111 return Op->getReg() - X86::TMM0;
112 }
113
adjustRowCfg(unsigned TIdx,MachineInstr * MI)114 static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) {
115 unsigned Offset = 48 + TIdx;
116 MI->getOperand(3).ChangeToImmediate(Offset);
117 }
118
adjustColCfg(unsigned TIdx,MachineInstr * MI)119 static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) {
120 unsigned Offset = 16 + TIdx * 2;
121 MI->getOperand(3).ChangeToImmediate(Offset);
122 }
123
isTileLoad(MachineInstr & MI)124 bool X86FastTileConfig::isTileLoad(MachineInstr &MI) {
125 return MI.getOpcode() == X86::PTILELOADDV ||
126 MI.getOpcode() == X86::PTILELOADDT1V;
127 }
isTileStore(MachineInstr & MI)128 bool X86FastTileConfig::isTileStore(MachineInstr &MI) {
129 return MI.getOpcode() == X86::PTILESTOREDV;
130 }
isAMXInstr(MachineInstr & MI)131 bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) {
132 // TODO: May need to handle some special nontile amx instrucion.
133 if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr())
134 return false;
135
136 for (MachineOperand &MO : MI.operands())
137 if (isTilePhysReg(MO))
138 return true;
139
140 return false;
141 }
142
getKeyAMXInstr(MachineInstr * MI)143 MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) {
144 auto Cfg = MachineBasicBlock::iterator(MI);
145 MachineBasicBlock *MBB = MI->getParent();
146 MachineInstr *KeyMI = nullptr;
147 int KeyAMXNum = 0;
148
149 for (auto II = Cfg; II != MBB->end(); II++) {
150 if (isTileLoad(*II)) {
151 KeyMI = &*II;
152 continue;
153 }
154
155 if (isTileStore(*II)) {
156 assert(KeyMI && "Key AMX Should be found before!");
157 break;
158 }
159
160 if (isAMXInstr(*II)) {
161 assert((KeyAMXNum == 0) && "Too many Key AMX instruction!");
162 KeyAMXNum++;
163 KeyMI = &*II;
164 }
165 }
166 assert(KeyMI && "There must be an AMX instruction.");
167 return KeyMI;
168 }
169
170 // Orderly get the tiles in key amx instruction, uses before defs.
getTileShapesCfg(MachineInstr * CfgMI,SmallVector<MachineOperand * > & ShapedTiles)171 void X86FastTileConfig::getTileShapesCfg(
172 MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) {
173 MachineInstr *KeyMI = getKeyAMXInstr(CfgMI);
174
175 SmallVector<MachineOperand *> DefTiles;
176 for (MachineOperand &MO : KeyMI->operands()) {
177 if (!isTilePhysReg(MO))
178 continue;
179 if (MO.isDef())
180 DefTiles.push_back(&MO);
181 else
182 ShapedTiles.push_back(&MO);
183 }
184 ShapedTiles.append(DefTiles);
185 }
186
187 // We pre-config the shapes at position named with "amx.tmm.N.shape.row* and
188 // amx.shape.N.col*" at pass "Pre AMX Tile Config".
189 // The 'N' implies the order of tiles in key amx intrinsic.
getShapeCfgInstrs(MachineInstr * MI,std::map<unsigned,MachineInstr * > & RowCfgs,std::map<unsigned,MachineInstr * > & ColCfgs)190 void X86FastTileConfig::getShapeCfgInstrs(
191 MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs,
192 std::map<unsigned, MachineInstr *> &ColCfgs) {
193 auto Cfg = MachineBasicBlock::iterator(MI);
194 MachineBasicBlock *MBB = MI->getParent();
195
196 for (auto II = Cfg; II != MBB->begin(); II--) {
197 if (isAMXInstr(*II) || II->isTerminator() || II->isCall())
198 break;
199 if (!II->mayStore() || !II->hasOneMemOperand())
200 continue;
201 const Value *MemPtr = II->memoperands()[0]->getValue();
202 if (!MemPtr)
203 continue;
204
205 StringRef Name = MemPtr->getName();
206 if (!Name.startswith("amx.tmm."))
207 continue;
208
209 // Get the 'N'th tile shape config in key amx instruction.
210 auto N = Name.find(".shape");
211 StringRef STileIdx = Name.slice(8, N);
212 unsigned Idx;
213 STileIdx.getAsInteger(10, Idx);
214
215 // And related them with their store instructions.
216 if (Name.contains("row"))
217 RowCfgs[Idx] = &*II;
218 else if (Name.contains("col"))
219 ColCfgs[Idx] = &*II;
220 else
221 llvm_unreachable("Invalid tile shape info!");
222 }
223 assert((RowCfgs.size() == ColCfgs.size()) &&
224 "The number of tile row and col must be equal!");
225 }
226
227 // Here is the data format for the tile config.
228 // 0 palette = 1 now.
229 // 1 start_row = 0 now.
230 // 2-15 reserved, must be zero
231 // 16-17 tile0.colsb Tile 0 bytes per row.
232 // 18-19 tile1.colsb Tile 1 bytes per row.
233 // 20-21 tile2.colsb Tile 2 bytes per row.
234 // ... (sequence continues)
235 // 30-31 tile7.colsb Tile 7 bytes per row.
236 // 32-47 reserved, must be zero
237 // 48 tile0.rows Tile 0 rows.
238 // 49 tile1.rows Tile 1 rows.
239 // 50 tile2.rows Tile 2 rows.
240 // ... (sequence continues)
241 // 55 tile7.rows Tile 7 rows.
242 // 56-63 reserved, must be zero
rewriteTileCfg(SmallVector<MachineOperand * > & ShapedTiles,std::map<unsigned,MachineInstr * > & RowCfgs,std::map<unsigned,MachineInstr * > & ColCfgs)243 void X86FastTileConfig::rewriteTileCfg(
244 SmallVector<MachineOperand *> &ShapedTiles,
245 std::map<unsigned, MachineInstr *> &RowCfgs,
246 std::map<unsigned, MachineInstr *> &ColCfgs) {
247 assert((RowCfgs.size() == ShapedTiles.size()) &&
248 "The number of tile shapes not equal with the number of tiles!");
249
250 // Orderly get the tiles and adjust the shape config.
251 for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) {
252 MachineOperand *MO = ShapedTiles[I];
253 unsigned TmmIdx = getTilePhysRegIdx(MO);
254 if (I == TmmIdx)
255 continue;
256 adjustRowCfg(TmmIdx, RowCfgs[I]);
257 adjustColCfg(TmmIdx, ColCfgs[I]);
258 }
259 }
260
261 // We have already preconfig the shapes before fast register allocation at
262 // X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register
263 // allocation, the shapes pre-written before may not rightly corresponding
264 // to the correct tmm registers, so we need adjust them.
materializeTileCfg(MachineInstr * CfgMI)265 void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) {
266 SmallVector<MachineOperand *> ShapedTiles;
267 std::map<unsigned, MachineInstr *> RowCfgs;
268 std::map<unsigned, MachineInstr *> ColCfgs;
269
270 // Orderly keep the tile uses and def in ShapedTiles;
271 getTileShapesCfg(CfgMI, ShapedTiles);
272 assert(ShapedTiles.size() && "Not find shapes config!");
273
274 getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs);
275
276 rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs);
277 }
278
fastTileConfig()279 bool X86FastTileConfig::fastTileConfig() {
280 bool Changed = false;
281
282 for (MachineBasicBlock &MBB : *MF) {
283 SmallVector<MachineInstr *, 2> CFGs;
284 for (MachineInstr &MI : MBB)
285 if (MI.getOpcode() == X86::PLDTILECFGV)
286 CFGs.push_back(&MI);
287 for (auto *MI : CFGs)
288 materializeTileCfg(MI);
289 if (!CFGs.empty())
290 Changed = true;
291 }
292 return Changed;
293 }
294
runOnMachineFunction(MachineFunction & MFunc)295 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
296 MF = &MFunc;
297 MRI = &MFunc.getRegInfo();
298 ST = &MFunc.getSubtarget<X86Subtarget>();
299 TRI = ST->getRegisterInfo();
300 TII = MFunc.getSubtarget().getInstrInfo();
301
302 return fastTileConfig();
303 }
304
createX86FastTileConfigPass()305 FunctionPass *llvm::createX86FastTileConfigPass() {
306 return new X86FastTileConfig();
307 }
308