//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions into AMX loads/stores. If the bitcast
/// cannot be combined with a load/store, we transform the bitcast into an AMX
/// load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end does (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for AMX to identify the
/// step in a spill.) volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

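// Create an alloca for a <256 x i32> value in the entry block of the
// function, aligned for x86_amx, so that tile data can be spilled to and
// reloaded from the stack slot.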
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
                                           BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

namespace {
class X86LowerAMXType {
  Function &Func;
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
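  // For example, a Col of 64 (bytes) with a granularity of 4 corresponds to
  // a Row of 64 / 4 = 16.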
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

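// Convert the column value V (in bytes) into a row value by dividing it by
// Granularity, caching the result in Col2Row. The division is emitted right
// after the definition of V when V is an instruction.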
Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

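// Return the {Row, Col} shape of the x86_amx operand OpNo of the AMX
// intrinsic II.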
std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape of each x86_amx operand depends on which operand it is.
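  // For the dot-product intrinsics tdpb*.internal(%m, %n, %k, %c, %a, %b),
  // which compute c += a * b, operand 3 (%c) has shape m x n, operand 4 (%a)
  // has shape m x k, and operand 5 (%b) has shape k x n, where the row count
  // k should really be k/4 (see the FIXME below).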
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      // FIXME: There is a design bug for AMX shape: the Col should be Col/4
      // if it is used as a Row, but the current greedy RA can't handle this
      // case well; it may fail if we generate a new shape definition. So
      // let's just do it at O0 for now.
      // Row = Row / 4
      if (TM->getOptLevel() == CodeGenOpt::None)
        Row = getRowFromCol(II, Row, 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // The tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as the stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform the bitcast into <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // The user may be a bitcast from x86_amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // Src may be a bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

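// Walk the function in post order and visit the instructions of each block
// bottom-up, combining each <256 x i32>/x86_amx bitcast with an adjacent
// load/store when possible, or expanding it through a stack slot otherwise.
// Dead instructions are collected and erased at the end.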
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, keep the vector load and create an
        // AMX load for the bitcast:
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, the load is eliminated as well:
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform it as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

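// Create a <256 x i32> stack slot in the entry block and return it as an i8*
// suitable for the memory operand of the tile load/store intrinsics.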
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

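// Store the tile defined by TileDef to Ptr right after its definition,
// reusing the shape operands of the defining AMX intrinsic.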
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

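// Replace the tile value used at U with a tile load from Ptr, inserted just
// before the user. For a PHI, the shape is taken from its first incoming
// value.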
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

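// Return true if I is used as an incoming value of some PHI node.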
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Make all AMX tile data volatile and shorten the live range of each tile
// register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

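// Store each incoming tile value of the PHI to a common stack slot right
// after its definition, and make its non-PHI uses reload from that slot.
// Return the i8 pointer of the stack slot.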
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored memory.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

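// Replace every use of the PHI-defined tile with a tile load from StorePtr,
// then erase the PHI.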
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, but this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The memory of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users is a PHI node.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data come from a tileload right before the use.
// 2) All defs of tile data are tilestored into memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminators, calls, or other AMX instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-PHI-related AMX intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();

    X86LowerAMXType LAT(F, TM);
    bool C = LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code
    // rather than just Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end does (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
      // the AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}