//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast can not be combined with a load/store, we transform the bitcast
/// to an AMX load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g.
/// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in spill.) The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

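// Create a <256 x i32> alloca in the function's entry block, aligned for
// x86_amx. The alloca serves as scratch memory when a bitcast between
// <256 x i32> and x86_amx has to be lowered through a store/load pair.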
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
                                           BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

namespace {
class X86LowerAMXType {
  Function &Func;
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

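// Convert a column value V into the row value it represents when reused in a
// different role: RealRow = V / Granularity. Results are cached in Col2Row so
// the division is only materialized once per value.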
Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

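// Return the {Row, Col} shape operands describing the tile used at operand
// OpNo of the AMX intrinsic II. Load/store intrinsics carry a single shape;
// for the dot-product intrinsics the shape depends on which of the three tile
// operands (3, 4 or 5) is being queried.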
std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      // FIXME: There is a design bug for AMX shape, where the Col should be
      // Col/4 if it will be used as a Row, but the current greedy RA can't
      // handle this case well; it may fail if we generate a new shape
      // definition. So let's just do it at O0 first.
      // Row = Row / 4
      if (TM->getOptLevel() == CodeGenOpt::None)
        Row = getRowFromCol(II, Row, 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

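// Walk the function in post order and combine each bitcast between
// <256 x i32> and x86_amx with its adjacent load/store, falling back to
// transformBitcast() when no such load/store exists. Returns true if any
// instruction was rewritten.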
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If bitcast (%13) has one use, combine bitcast and store to amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

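// Create a <256 x i32> stack slot in the entry block and return it as an i8*
// suitable for the memory operand of the tile load/store intrinsics.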
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

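// Emit a tilestored64 intrinsic right after the tile definition TileDef,
// spilling the tile to Ptr with the shape taken from the defining intrinsic.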
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

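// Rewrite use U of a tile value so that the user reads the tile back from Ptr
// through a tileloadd64 intrinsic inserted right before it. When the used
// value is a PHI, the shape is taken from the PHI's first incoming tile
// definition.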
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

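// Return true if instruction I is used as an incoming value of some PHI node.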
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, shortening the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

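// Store each PHI incoming tile to a common stack slot right after its
// definition and rewrite the incoming tiles' non-PHI uses to loads from that
// slot. Returns the slot's i8* so the PHI def itself can be replaced by a
// load from the same memory.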
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except phi) should load from stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

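// Replace every use of the PHI-defined tile with a tile load from StorePtr,
// then erase the PHI itself.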
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tile-stored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// None of its users is a PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from a tileload in time.
// 2) All the defs of tile data are tile-stored to mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();

    X86LowerAMXType LAT(F, TM);
    bool C = LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code,
    // not just by checking Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
      // the AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}