1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with the load/store, we transform the bitcast to an
/// AMX load/store and a <256 x i32> store/load.
17 ///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// The volatileTileData() function handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ...                                          |
27 /// | ...                                                    |
28 /// | "use %td"                                              |
29 /// ----------------------------------------------------------
/// will be transformed to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ...                                          |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
34 /// | ...                                                    |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2"                                             |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
46 #include "llvm/Analysis/TargetLibraryInfo.h"
47 #include "llvm/Analysis/TargetTransformInfo.h"
48 #include "llvm/CodeGen/Passes.h"
49 #include "llvm/CodeGen/TargetPassConfig.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
60 #include "llvm/Target/TargetMachine.h"
61 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62 #include "llvm/Transforms/Utils/Local.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72   return match(II,
73                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
76 
77 static bool isAMXIntrinsic(Value *I) {
78   auto *II = dyn_cast<IntrinsicInst>(I);
79   if (!II)
80     return false;
81   if (isAMXCast(II))
82     return false;
  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 AMX intrinsic.
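  // For example, @llvm.x86.tileloadd64.internal returns x86_amx and
  // @llvm.x86.tilestored64.internal takes an x86_amx operand, so both are
  // recognized here, while the cast intrinsics are filtered out above.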
85   if (II->getType()->isX86_AMXTy())
86     return true;
87   for (Value *V : II->args()) {
88     if (V->getType()->isX86_AMXTy())
89       return true;
90   }
91 
92   return false;
93 }
94 
95 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96                                            Type *Ty) {
97   Function &F = *BB->getParent();
98   Module *M = BB->getModule();
99   const DataLayout &DL = M->getDataLayout();
100 
101   LLVMContext &Ctx = Builder.getContext();
102   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103   unsigned AllocaAS = DL.getAllocaAddrSpace();
104   AllocaInst *AllocaRes =
105       new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106   AllocaRes->setAlignment(AllocaAlignment);
107   return AllocaRes;
108 }
109 
110 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111   for (Instruction &I : F.getEntryBlock())
112     if (!isa<AllocaInst>(&I))
113       return &I;
114   llvm_unreachable("No terminator in the entry block!");
115 }
116 
117 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118   IRBuilder<> Builder(II);
119   Value *Row = nullptr, *Col = nullptr;
120   switch (II->getIntrinsicID()) {
121   default:
122     llvm_unreachable("Expect amx intrinsics");
123   case Intrinsic::x86_tileloadd64_internal:
124   case Intrinsic::x86_tileloaddt164_internal:
125   case Intrinsic::x86_tilestored64_internal: {
126     Row = II->getArgOperand(0);
127     Col = II->getArgOperand(1);
128     break;
129   }
130   // a * b + c
131   // The shape depends on which operand.
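  // For example (mirroring the mapping implemented below), the dot-product
  // intrinsics take arguments of the form (i16 %m, i16 %n, i16 %k,
  // x86_amx %c, x86_amx %a, x86_amx %b), so operand 3 (%c) has shape m x n,
  // operand 4 (%a) has shape m x k, and operand 5 (%b) has shape k/4 x n.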
132   case Intrinsic::x86_tdpbssd_internal:
133   case Intrinsic::x86_tdpbsud_internal:
134   case Intrinsic::x86_tdpbusd_internal:
135   case Intrinsic::x86_tdpbuud_internal:
136   case Intrinsic::x86_tdpbf16ps_internal: {
137     switch (OpNo) {
138     case 3:
139       Row = II->getArgOperand(0);
140       Col = II->getArgOperand(1);
141       break;
142     case 4:
143       Row = II->getArgOperand(0);
144       Col = II->getArgOperand(2);
145       break;
146     case 5:
147       if (isa<ConstantInt>(II->getArgOperand(2)))
148         Row = Builder.getInt16(
149             (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
150       else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a constant value and it is not a function argument,
        // we create Row after the definition of II->getOperand(2) instead of
        // before II. For example, if II is %118 and we try to get the shape
        // for %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //   i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        //   %117).
        // If we create %row = udiv i16 %106, 4 before %118 (a.k.a. II), then
        // its definition is after its user (the new tileload for %117).
        // So the best choice is to create %row right after the definition of
        // %106.
163         Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
164         Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
165         cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
166       } else {
        // When it is not a constant value and it is a function argument, we
        // create Row in the entry basic block.
169         IRBuilder<> NewBuilder(
170             getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
171         Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
172       }
173       Col = II->getArgOperand(1);
174       break;
175     }
176     break;
177   }
178   }
179 
180   return std::make_pair(Row, Col);
181 }
182 
183 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
184   Use &U = *(Phi->use_begin());
185   unsigned OpNo = U.getOperandNo();
186   User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, here we
  // just traverse the first user. If we can find the shape, return it;
  // otherwise return nullptr, and the optimization for undef/zero will be
  // abandoned.
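  // For example, following only the first use at each step, a chain like
  //   %phi -> @llvm.x86.cast.vector.to.tile -> @llvm.x86.tdpbssd.internal
  // reaches an AMX intrinsic, and getShape(II, OpNo) then recovers the shape
  // implied for the phi.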
191   while (V) {
192     if (isAMXCast(dyn_cast<Instruction>(V))) {
193       if (V->use_empty())
194         break;
195       Use &U = *(V->use_begin());
196       OpNo = U.getOperandNo();
197       V = U.getUser();
198     } else if (isAMXIntrinsic(V)) {
199       return getShape(cast<IntrinsicInst>(V), OpNo);
200     } else if (isa<PHINode>(V)) {
201       if (V->use_empty())
202         break;
203       Use &U = *(V->use_begin());
204       V = U.getUser();
205     } else {
206       break;
207     }
208   }
209 
210   return std::make_pair(nullptr, nullptr);
211 }
212 
213 namespace {
214 class X86LowerAMXType {
215   Function &Func;
216 
  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use RealCol
  // as a new Row for other newly created AMX intrinsics.
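  // For example (illustrative), with i32 elements a Col of 64 bytes
  // corresponds to RealCol = 64 / 4 = 16.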
220   std::map<Value *, Value *> Col2Row;
221 
222 public:
223   X86LowerAMXType(Function &F) : Func(F) {}
224   bool visit();
225   void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
226   void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
227   bool transformBitcast(BitCastInst *Bitcast);
228 };
229 
230 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
231 // %2 = bitcast <256 x i32> %src to x86_amx
232 // -->
233 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
234 // i8* %addr, i64 %stride64)
235 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
236   Value *Row = nullptr, *Col = nullptr;
237   Use &U = *(Bitcast->use_begin());
238   unsigned OpNo = U.getOperandNo();
239   auto *II = cast<IntrinsicInst>(U.getUser());
240   std::tie(Row, Col) = getShape(II, OpNo);
241   IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
243   Value *Stride = Builder.getInt64(64);
244   Value *I8Ptr =
245       Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
246   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
247 
248   Value *NewInst =
249       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
250   Bitcast->replaceAllUsesWith(NewInst);
251 }
252 
253 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
254 //                                                    %stride);
255 // %13 = bitcast x86_amx %src to <256 x i32>
256 // store <256 x i32> %13, <256 x i32>* %addr, align 64
257 // -->
258 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
259 //                                           %stride64, %13)
260 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
261 
262   Value *Tile = Bitcast->getOperand(0);
263   auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
266   Value *Row = II->getOperand(0);
267   Value *Col = II->getOperand(1);
268   IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
271   Value *Stride = Builder.getInt64(64);
272   Value *I8Ptr =
273       Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
274   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
275   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
276   if (Bitcast->hasOneUse())
277     return;
278   // %13 = bitcast x86_amx %src to <256 x i32>
279   // store <256 x i32> %13, <256 x i32>* %addr, align 64
280   // %add = <256 x i32> %13, <256 x i32> %src2
281   // -->
282   // %13 = bitcast x86_amx %src to <256 x i32>
283   // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
284   //                                           %stride64, %13)
285   // %14 = load <256 x i32>, %addr
286   // %add = <256 x i32> %14, <256 x i32> %src2
287   Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
288   Bitcast->replaceAllUsesWith(Vec);
289 }
290 
// Transform a bitcast into <store, load> instructions.
292 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
293   IRBuilder<> Builder(Bitcast);
294   AllocaInst *AllocaAddr;
295   Value *I8Ptr, *Stride;
296   auto *Src = Bitcast->getOperand(0);
297 
298   auto Prepare = [&](Type *MemTy) {
299     AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
300     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
301     Stride = Builder.getInt64(64);
302   };
303 
304   if (Bitcast->getType()->isX86_AMXTy()) {
305     // %2 = bitcast <256 x i32> %src to x86_amx
306     // -->
307     // %addr = alloca <256 x i32>, align 64
308     // store <256 x i32> %src, <256 x i32>* %addr, align 64
309     // %addr2 = bitcast <256 x i32>* to i8*
310     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
311     //                                                  i8* %addr2,
312     //                                                  i64 64)
313     Use &U = *(Bitcast->use_begin());
314     unsigned OpNo = U.getOperandNo();
315     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
316     if (!II)
317       return false; // May be bitcast from x86amx to <256 x i32>.
318     Prepare(Bitcast->getOperand(0)->getType());
319     Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
321     Value *Row = nullptr, *Col = nullptr;
322     std::tie(Row, Col) = getShape(II, OpNo);
323     std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
324     Value *NewInst = Builder.CreateIntrinsic(
325         Intrinsic::x86_tileloadd64_internal, None, Args);
326     Bitcast->replaceAllUsesWith(NewInst);
327   } else {
328     // %2 = bitcast x86_amx %src to <256 x i32>
329     // -->
330     // %addr = alloca <256 x i32>, align 64
331     // %addr2 = bitcast <256 x i32>* to i8*
332     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
333     //                                           i8* %addr2, i64 %stride)
334     // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
335     auto *II = dyn_cast<IntrinsicInst>(Src);
336     if (!II)
337       return false; // May be bitcast from <256 x i32> to x86amx.
338     Prepare(Bitcast->getType());
339     Value *Row = II->getOperand(0);
340     Value *Col = II->getOperand(1);
341     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
342     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
343     Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
344     Bitcast->replaceAllUsesWith(NewInst);
345   }
346 
347   return true;
348 }
349 
350 bool X86LowerAMXType::visit() {
351   SmallVector<Instruction *, 8> DeadInsts;
352   Col2Row.clear();
353 
354   for (BasicBlock *BB : post_order(&Func)) {
355     for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
356       auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
357       if (!Bitcast)
358         continue;
359 
360       Value *Src = Bitcast->getOperand(0);
361       if (Bitcast->getType()->isX86_AMXTy()) {
362         if (Bitcast->user_empty()) {
363           DeadInsts.push_back(Bitcast);
364           continue;
365         }
366         LoadInst *LD = dyn_cast<LoadInst>(Src);
367         if (!LD) {
368           if (transformBitcast(Bitcast))
369             DeadInsts.push_back(Bitcast);
370           continue;
371         }
        // If the load has multiple users, keep the vector load and lower the
        // bitcast to a tile load.
373         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
374         // %2 = bitcast <256 x i32> %src to x86_amx
375         // %add = add <256 x i32> %src, <256 x i32> %src2
376         // -->
377         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
378         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
379         //                                            i8* %addr, i64 %stride64)
380         // %add = add <256 x i32> %src, <256 x i32> %src2
381 
        // If the load has a single user, it will be eliminated in DAG ISel.
383         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
384         // %2 = bitcast <256 x i32> %src to x86_amx
385         // -->
386         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
387         //                                            i8* %addr, i64 %stride64)
388         combineLoadBitcast(LD, Bitcast);
389         DeadInsts.push_back(Bitcast);
390         if (LD->hasOneUse())
391           DeadInsts.push_back(LD);
392       } else if (Src->getType()->isX86_AMXTy()) {
393         if (Bitcast->user_empty()) {
394           DeadInsts.push_back(Bitcast);
395           continue;
396         }
397         StoreInst *ST = nullptr;
398         for (Use &U : Bitcast->uses()) {
399           ST = dyn_cast<StoreInst>(U.getUser());
400           if (ST)
401             break;
402         }
403         if (!ST) {
404           if (transformBitcast(Bitcast))
405             DeadInsts.push_back(Bitcast);
406           continue;
407         }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
409         // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
410         //                                                    %stride);
411         // %13 = bitcast x86_amx %src to <256 x i32>
412         // store <256 x i32> %13, <256 x i32>* %addr, align 64
413         // -->
414         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
415         //                                           %stride64, %13)
416         //
        // If the bitcast (%13) has multiple uses, transform as below.
418         // %13 = bitcast x86_amx %src to <256 x i32>
419         // store <256 x i32> %13, <256 x i32>* %addr, align 64
420         // %add = <256 x i32> %13, <256 x i32> %src2
421         // -->
422         // %13 = bitcast x86_amx %src to <256 x i32>
423         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
424         //                                           %stride64, %13)
425         // %14 = load <256 x i32>, %addr
426         // %add = <256 x i32> %14, <256 x i32> %src2
427         //
428         combineBitcastStore(Bitcast, ST);
429         // Delete user first.
430         DeadInsts.push_back(ST);
431         DeadInsts.push_back(Bitcast);
432       }
433     }
434   }
435 
436   bool C = !DeadInsts.empty();
437 
438   for (auto *Inst : DeadInsts)
439     Inst->eraseFromParent();
440 
441   return C;
442 }
443 } // anonymous namespace
444 
445 static Value *getAllocaPos(BasicBlock *BB) {
446   Module *M = BB->getModule();
447   Function *F = BB->getParent();
448   IRBuilder<> Builder(&F->getEntryBlock().front());
449   const DataLayout &DL = M->getDataLayout();
450   unsigned AllocaAS = DL.getAllocaAddrSpace();
451   Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
452   AllocaInst *AllocaRes =
453       new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
454   BasicBlock::iterator Iter = AllocaRes->getIterator();
455   ++Iter;
456   Builder.SetInsertPoint(&*Iter);
457   Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
458   return I8Ptr;
459 }
460 
461 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
462   assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
463   auto *II = cast<IntrinsicInst>(TileDef);
464   assert(II && "Not tile intrinsic!");
465   Value *Row = II->getOperand(0);
466   Value *Col = II->getOperand(1);
467 
468   BasicBlock *BB = TileDef->getParent();
469   BasicBlock::iterator Iter = TileDef->getIterator();
470   IRBuilder<> Builder(BB, ++Iter);
471   Value *Stride = Builder.getInt64(64);
472   std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
473 
474   Instruction *TileStore =
475       Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
476   return TileStore;
477 }
478 
479 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
480   Value *V = U.get();
481   assert(V->getType()->isX86_AMXTy() && "Not define tile!");
482 
483   // Get tile shape.
484   IntrinsicInst *II = nullptr;
485   if (IsPHI) {
486     Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
487     II = cast<IntrinsicInst>(PhiOp);
488   } else {
489     II = cast<IntrinsicInst>(V);
490   }
491   Value *Row = II->getOperand(0);
492   Value *Col = II->getOperand(1);
493 
494   Instruction *UserI = dyn_cast<Instruction>(U.getUser());
495   IRBuilder<> Builder(UserI);
496   Value *Stride = Builder.getInt64(64);
497   std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
498 
499   Value *TileLoad =
500       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
501   UserI->replaceUsesOfWith(V, TileLoad);
502 }
503 
504 static bool isIncomingOfPHI(Instruction *I) {
505   for (Use &U : I->uses()) {
506     User *V = U.getUser();
507     if (isa<PHINode>(V))
508       return true;
509   }
510   return false;
511 }
512 
// Make all AMX tile data volatile, shortening the live range of each tile
// register before fast register allocation.
515 namespace {
516 class X86VolatileTileData {
517   Function &F;
518 
519 public:
520   X86VolatileTileData(Function &Func) : F(Func) {}
521   Value *updatePhiIncomings(BasicBlock *BB,
522                             SmallVector<Instruction *, 2> &Incomings);
523   void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
524   bool volatileTileData();
525   void volatileTilePHI(PHINode *Inst);
526   void volatileTileNonPHI(Instruction *I);
527 };
528 
529 Value *X86VolatileTileData::updatePhiIncomings(
530     BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
531   Value *I8Ptr = getAllocaPos(BB);
532 
533   for (auto *I : Incomings) {
534     User *Store = createTileStore(I, I8Ptr);
535 
536     // All its uses (except phi) should load from stored mem.
537     for (Use &U : I->uses()) {
538       User *V = U.getUser();
539       if (isa<PHINode>(V) || V == Store)
540         continue;
541       replaceWithTileLoad(U, I8Ptr);
542     }
543   }
544   return I8Ptr;
545 }
546 
547 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
548                                                 Value *StorePtr) {
549   for (Use &U : PHI->uses())
550     replaceWithTileLoad(U, StorePtr, true);
551   PHI->eraseFromParent();
552 }
553 
// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tile load.
// 2) The PHI incoming values should be tile-stored just after their defs.
// 3) The memory of these tile loads and tile stores should be the same.
559 // e.g.
560 // ------------------------------------------------------
561 // bb_dom:
562 //   ...
563 //   br i1 %bool.cond, label %if.else, label %if.then
564 //
565 // if.then:
566 //   def %t0 = ...
567 //   ...
568 //   use %t0
569 //   ...
570 //   br label %if.end
571 //
572 // if.else:
573 //   def %t1 = ...
574 //   br label %if.end
575 //
576 // if.end:
577 //   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
578 //   ...
579 //   use %td
580 // ------------------------------------------------------
581 // -->
582 // ------------------------------------------------------
583 // bb_entry:
584 //   %mem = alloca <256 x i32>, align 1024                  *
585 //   ...
586 // bb_dom:
587 //   ...
588 //   br i1 %bool.cond, label %if.else, label %if.then
589 //
590 // if.then:
591 //   def %t0 = ...
592 //   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
593 //   ...
594 //   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
595 //   use %t0`                                               *
596 //   ...
597 //   br label %if.end
598 //
599 // if.else:
600 //   def %t1 = ...
601 //   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
602 //   br label %if.end
603 //
604 // if.end:
605 //   ...
606 //   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
607 //   use %td
608 // ------------------------------------------------------
609 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
610   BasicBlock *BB = PHI->getParent();
611   SmallVector<Instruction *, 2> Incomings;
612 
613   for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
614     Value *Op = PHI->getIncomingValue(I);
615     Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
617     Incomings.push_back(Inst);
618   }
619 
620   Value *StorePtr = updatePhiIncomings(BB, Incomings);
621   replacePhiDefWithLoad(PHI, StorePtr);
622 }
623 
// Store the defined tile and load it before each use.
// None of its users is a PHI node.
626 // e.g.
627 // ------------------------------------------------------
628 // def %td = ...
629 // ...
630 // "use %td"
631 // ------------------------------------------------------
632 // -->
633 // ------------------------------------------------------
634 // def %td = ...
635 // call void @llvm.x86.tilestored64.internal(mem, %td)
636 // ...
637 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
638 // "use %td2"
639 // ------------------------------------------------------
640 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
641   BasicBlock *BB = I->getParent();
642   Value *I8Ptr = getAllocaPos(BB);
643   User *Store = createTileStore(I, I8Ptr);
644 
645   // All its uses should load from stored mem.
646   for (Use &U : I->uses()) {
647     User *V = U.getUser();
648     assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
649     if (V != Store)
650       replaceWithTileLoad(U, I8Ptr);
651   }
652 }
653 
// Volatile Tile Model:
// 1) All uses of tile data come from tile loads, just in time.
// 2) All defs of tile data are tile-stored to memory immediately.
657 // For example:
658 // --------------------------------------------------------------------------
659 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
660 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
661 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
662 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
663 // call void @llvm.x86.tilestored64.internal(... td)                     area
664 // --------------------------------------------------------------------------
// 3) No terminator, call, or other AMX instructions in the key AMX area.
666 bool X86VolatileTileData::volatileTileData() {
667   bool Changed = false;
668   for (BasicBlock &BB : F) {
669     SmallVector<Instruction *, 2> PHIInsts;
670     SmallVector<Instruction *, 8> AMXDefInsts;
671 
672     for (Instruction &I : BB) {
673       if (!I.getType()->isX86_AMXTy())
674         continue;
675       if (isa<PHINode>(&I))
676         PHIInsts.push_back(&I);
677       else
678         AMXDefInsts.push_back(&I);
679     }
680 
    // First we make the non-PHI-related AMX intrinsics "volatile".
682     for (Instruction *I : AMXDefInsts) {
683       if (isIncomingOfPHI(I))
684         continue;
685       volatileTileNonPHI(I);
686       Changed = true;
687     }
688 
689     for (Instruction *I : PHIInsts) {
690       volatileTilePHI(dyn_cast<PHINode>(I));
691       Changed = true;
692     }
693   }
694   return Changed;
695 }
696 
697 } // anonymous namespace
698 
699 namespace {
700 
701 class X86LowerAMXCast {
702   Function &Func;
703 
704 public:
705   X86LowerAMXCast(Function &F) : Func(F) {}
706   void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
707   void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
708   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
709   bool combineAMXcast(TargetLibraryInfo *TLI);
710   bool transformAMXCast(IntrinsicInst *AMXCast);
711   bool transformAllAMXCast();
712   bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
713                               SmallSetVector<Instruction *, 16> &DeadInst);
714 };
715 
716 static bool DCEInstruction(Instruction *I,
717                            SmallSetVector<Instruction *, 16> &WorkList,
718                            const TargetLibraryInfo *TLI) {
719   if (isInstructionTriviallyDead(I, TLI)) {
720     salvageDebugInfo(*I);
721     salvageKnowledge(I);
722 
723     // Null out all of the instruction's operands to see if any operand becomes
724     // dead as we go.
725     for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
726       Value *OpV = I->getOperand(i);
727       I->setOperand(i, nullptr);
728 
729       if (!OpV->use_empty() || I == OpV)
730         continue;
731 
732       // If the operand is an instruction that became dead as we nulled out the
733       // operand, and if it is 'trivially' dead, delete it in a future loop
734       // iteration.
735       if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
736         if (isInstructionTriviallyDead(OpI, TLI)) {
737           WorkList.insert(OpI);
738         }
739       }
740     }
741     I->eraseFromParent();
742     return true;
743   }
744   return false;
745 }
746 
/// This function handles the following case:
748 ///
749 ///     A  ->  B    amxcast
750 ///     PHI
751 ///     B  ->  A    amxcast
752 ///
753 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
754 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
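///
/// For example (schematically), with A = x86_amx and B = <256 x i32>:
///   bb0:
///     %v0 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t0)
///     br label %bb2
///   bb1:
///     %v1 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///     br label %bb2
///   bb2:
///     %phi = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
///     %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %phi)
/// is rewritten so the PHI works directly on x86_amx:
///     %phi.amx = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
/// and the old casts and PHI become dead.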
755 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
756     IntrinsicInst *CI, PHINode *PN,
757     SmallSetVector<Instruction *, 16> &DeadInst) {
758   IRBuilder<> Builder(CI);
759   Value *Src = CI->getOperand(0);
760   Type *SrcTy = Src->getType(); // Type B
761   Type *DestTy = CI->getType(); // Type A
762 
763   SmallVector<PHINode *, 4> PhiWorklist;
764   SmallSetVector<PHINode *, 4> OldPhiNodes;
765 
766   // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
770   PhiWorklist.push_back(PN);
771   OldPhiNodes.insert(PN);
772   while (!PhiWorklist.empty()) {
773     auto *OldPN = PhiWorklist.pop_back_val();
774     for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
775       Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we only handle undef and zero constants. In the
      // future, we might support other constants.
778       if (isa<Constant>(IncValue)) {
779         auto *IncConst = dyn_cast<Constant>(IncValue);
780         if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
781           return false;
782         Value *Row = nullptr, *Col = nullptr;
783         std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not a constant, the Row and Col must dominate the
        // tilezero that we are going to create.
786         if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
787           return false;
788         // Create tilezero at the end of incoming block.
789         auto *Block = OldPN->getIncomingBlock(I);
790         BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
791         Instruction *NewInst = Builder.CreateIntrinsic(
792             Intrinsic::x86_tilezero_internal, None, {Row, Col});
793         NewInst->moveBefore(&*Iter);
794         NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
795                                           {IncValue->getType()}, {NewInst});
796         NewInst->moveBefore(&*Iter);
797         // Replace InValue with new Value.
798         OldPN->setIncomingValue(I, NewInst);
799         IncValue = NewInst;
800       }
801 
802       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
803         if (OldPhiNodes.insert(PNode))
804           PhiWorklist.push_back(PNode);
805         continue;
806       }
807       Instruction *ACI = dyn_cast<Instruction>(IncValue);
808       if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
810         Type *TyA = ACI->getOperand(0)->getType();
811         Type *TyB = ACI->getType();
812         if (TyA != DestTy || TyB != SrcTy)
813           return false;
814         continue;
815       }
816       return false;
817     }
818   }
819 
820   // Check that each user of each old PHI node is something that we can
821   // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
822   for (auto *OldPN : OldPhiNodes) {
823     for (User *V : OldPN->users()) {
824       Instruction *ACI = dyn_cast<Instruction>(V);
825       if (ACI && isAMXCast(ACI)) {
826         // Verify it's a B->A cast.
827         Type *TyB = ACI->getOperand(0)->getType();
828         Type *TyA = ACI->getType();
829         if (TyA != DestTy || TyB != SrcTy)
830           return false;
831       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
832         // As long as the user is another old PHI node, then even if we don't
833         // rewrite it, the PHI web we're considering won't have any users
834         // outside itself, so it'll be dead.
835         // example:
836         //   bb.0:
837         //      %0 = amxcast ...
838         //   bb.1:
839         //      %1 = amxcast ...
840         //   bb.2:
841         //      %goodphi = phi %0, %1
842         //      %3 = amxcast %goodphi
843         //   bb.3:
844         //      %goodphi2 = phi %0, %goodphi
845         //      %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
850         if (OldPhiNodes.count(PHI) == 0)
851           return false;
852       } else
853         return false;
854     }
855   }
856 
  // For each old PHI node, create a corresponding new PHI node with type A.
858   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
859   for (auto *OldPN : OldPhiNodes) {
860     Builder.SetInsertPoint(OldPN);
861     PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
862     NewPNodes[OldPN] = NewPN;
863   }
864 
865   // Fill in the operands of new PHI nodes.
866   for (auto *OldPN : OldPhiNodes) {
867     PHINode *NewPN = NewPNodes[OldPN];
868     for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
869       Value *V = OldPN->getOperand(j);
870       Value *NewV = nullptr;
871       Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMX cast from a constant.
873       if (ACI && isAMXCast(ACI))
874         NewV = ACI->getOperand(0);
875       else if (auto *PrevPN = dyn_cast<PHINode>(V))
876         NewV = NewPNodes[PrevPN];
877       assert(NewV);
878       NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
879     }
880   }
881 
  // Traverse all accumulated PHI nodes and process their users,
  // which are stores and AMX casts. Without this processing,
  // new PHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.
887 
888   // Replace users of BitCast B->A with NewPHI. These will help
889   // later to get rid of a closure formed by OldPHI nodes.
890   for (auto *OldPN : OldPhiNodes) {
891     PHINode *NewPN = NewPNodes[OldPN];
892     for (User *V : make_early_inc_range(OldPN->users())) {
893       Instruction *ACI = dyn_cast<Instruction>(V);
894       if (ACI && isAMXCast(ACI)) {
895         Type *TyB = ACI->getOperand(0)->getType();
896         Type *TyA = ACI->getType();
897         assert(TyA == DestTy && TyB == SrcTy);
898         (void)TyA;
899         (void)TyB;
900         ACI->replaceAllUsesWith(NewPN);
901         DeadInst.insert(ACI);
902       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHI nodes into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
905         assert(OldPhiNodes.contains(PHI));
906         (void)PHI;
907       } else
908         llvm_unreachable("all uses should be handled");
909     }
910   }
911   return true;
912 }
913 
914 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
915 // store <256 x i32> %43, <256 x i32>* %p, align 64
916 // -->
917 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
918 //                                           i64 64, x86_amx %42)
919 void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
920   Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
923   if (!isAMXIntrinsic(Tile))
924     return;
925   auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
928   Value *Row = II->getOperand(0);
929   Value *Col = II->getOperand(1);
930   IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
933   Value *Stride = Builder.getInt64(64);
934   Value *I8Ptr =
935       Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
936   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
937   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
938 }
939 
940 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
941 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
942 // -->
943 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
944 //                                                   i8* %p, i64 64)
945 void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
946   Value *Row = nullptr, *Col = nullptr;
947   Use &U = *(Cast->use_begin());
948   unsigned OpNo = U.getOperandNo();
949   auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
952   if (!isAMXIntrinsic(II))
953     return;
954   std::tie(Row, Col) = getShape(II, OpNo);
955   IRBuilder<> Builder(LD);
  // Use the maximum column as stride.
957   Value *Stride = Builder.getInt64(64);
958   Value *I8Ptr =
959       Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
960   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
961 
962   Value *NewInst =
963       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
964   Cast->replaceAllUsesWith(NewInst);
965 }
966 
967 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
968   bool Change = false;
969   for (auto *Cast : Casts) {
970     auto *II = cast<IntrinsicInst>(Cast);
971     // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
972     // store <256 x i32> %43, <256 x i32>* %p, align 64
973     // -->
974     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
975     //                                           i64 64, x86_amx %42)
976     if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
977       SmallVector<Instruction *, 2> DeadStores;
978       for (User *U : Cast->users()) {
979         StoreInst *Store = dyn_cast<StoreInst>(U);
980         if (!Store)
981           continue;
982         combineCastStore(cast<IntrinsicInst>(Cast), Store);
983         DeadStores.push_back(Store);
984         Change = true;
985       }
986       for (auto *Store : DeadStores)
987         Store->eraseFromParent();
988     } else { // x86_cast_vector_to_tile
989       SmallVector<Instruction *, 2> DeadLoads;
990       auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
991       if (!Load || !Load->hasOneUse())
992         continue;
993       // %65 = load <256 x i32>, <256 x i32>* %p, align 64
994       // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
995       // -->
996       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
997       //                                                   i8* %p, i64 64)
998       combineLoadCast(cast<IntrinsicInst>(Cast), Load);
      // Set the operand to null so that the load instruction can be erased.
1000       Cast->setOperand(0, nullptr);
1001       Load->eraseFromParent();
1002     }
1003   }
1004   return Change;
1005 }
1006 
1007 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1008   bool Change = false;
  // Collect tile cast instructions.
1010   SmallVector<Instruction *, 8> Vec2TileInsts;
1011   SmallVector<Instruction *, 8> Tile2VecInsts;
1012   SmallVector<Instruction *, 8> PhiCastWorkList;
1013   SmallSetVector<Instruction *, 16> DeadInst;
1014   for (BasicBlock &BB : Func) {
1015     for (Instruction &I : BB) {
1016       Value *Vec;
1017       if (match(&I,
1018                 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1019         Vec2TileInsts.push_back(&I);
1020       else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1021                              m_Value(Vec))))
1022         Tile2VecInsts.push_back(&I);
1023     }
1024   }
1025 
1026   auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1027     for (auto *Inst : Insts) {
1028       for (User *U : Inst->users()) {
1029         IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1030         if (!II || II->getIntrinsicID() != IID)
1031           continue;
1032         // T1 = vec2tile V0
1033         // V2 = tile2vec T1
1034         // V3 = OP V2
1035         // -->
1036         // T1 = vec2tile V0
1037         // V2 = tile2vec T1
1038         // V3 = OP V0
1039         II->replaceAllUsesWith(Inst->getOperand(0));
1040         Change = true;
1041       }
1042     }
1043   };
1044 
1045   Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1046   Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1047 
1048   SmallVector<Instruction *, 8> LiveCasts;
1049   auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1050     for (auto *Inst : Insts) {
1051       if (Inst->use_empty()) {
1052         Inst->eraseFromParent();
1053         Change = true;
1054       } else {
1055         LiveCasts.push_back(Inst);
1056       }
1057     }
1058   };
1059 
1060   EraseInst(Vec2TileInsts);
1061   EraseInst(Tile2VecInsts);
1062   Change |= combineLdSt(LiveCasts);
1063   EraseInst(LiveCasts);
1064 
  // Handle the A->B->A cast where there is an intervening PHI node.
1066   for (BasicBlock &BB : Func) {
1067     for (Instruction &I : BB) {
1068       if (isAMXCast(&I)) {
1069         if (isa<PHINode>(I.getOperand(0)))
1070           PhiCastWorkList.push_back(&I);
1071       }
1072     }
1073   }
1074   for (auto *I : PhiCastWorkList) {
    // Skip dead AMX casts.
1076     if (DeadInst.contains(I))
1077       continue;
1078     PHINode *PN = cast<PHINode>(I->getOperand(0));
1079     if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1080       DeadInst.insert(PN);
1081       Change = true;
1082     }
1083   }
1084 
  // Since we create new PHIs and merge AMX casts, some old PHIs and AMX casts
  // might have no uses. We do some dead code elimination for them.
1087   while (!DeadInst.empty()) {
1088     Instruction *I = DeadInst.pop_back_val();
1089     Change |= DCEInstruction(I, DeadInst, TLI);
1090   }
1091   return Change;
1092 }
1093 
// There might be remaining AMX casts after combineAMXcast, and they should be
// handled elegantly.
1096 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1097   IRBuilder<> Builder(AMXCast);
1098   AllocaInst *AllocaAddr;
1099   Value *I8Ptr, *Stride;
1100   auto *Src = AMXCast->getOperand(0);
1101 
1102   auto Prepare = [&](Type *MemTy) {
1103     AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1104     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
1105     Stride = Builder.getInt64(64);
1106   };
1107 
1108   if (AMXCast->getType()->isX86_AMXTy()) {
1109     // %2 = amxcast <225 x i32> %src to x86_amx
1110     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1111     //                                           i8* %addr3, i64 60, x86_amx %2)
1112     // -->
1113     // %addr = alloca <225 x i32>, align 64
1114     // store <225 x i32> %src, <225 x i32>* %addr, align 64
1115     // %addr2 = bitcast <225 x i32>* %addr to i8*
1116     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1117     //                                                  i8* %addr2,
1118     //                                                  i64 60)
1119     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1120     //                                           i8* %addr3, i64 60, x86_amx %2)
1121     if (AMXCast->use_empty()) {
1122       AMXCast->eraseFromParent();
1123       return true;
1124     }
1125     Use &U = *(AMXCast->use_begin());
1126     unsigned OpNo = U.getOperandNo();
1127     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1128     if (!II)
1129       return false; // May be bitcast from x86amx to <256 x i32>.
1130     Prepare(AMXCast->getOperand(0)->getType());
1131     Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
1133     Value *Row = nullptr, *Col = nullptr;
1134     std::tie(Row, Col) = getShape(II, OpNo);
1135     std::array<Value *, 4> Args = {
1136         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1137     Value *NewInst = Builder.CreateIntrinsic(
1138         Intrinsic::x86_tileloadd64_internal, None, Args);
1139     AMXCast->replaceAllUsesWith(NewInst);
1140     AMXCast->eraseFromParent();
1141   } else {
1142     // %2 = amxcast x86_amx %src to <225 x i32>
1143     // -->
1144     // %addr = alloca <225 x i32>, align 64
1145     // %addr2 = bitcast <225 x i32>* to i8*
1146     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1147     //                                           i8* %addr2, i64 %stride)
1148     // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1149     auto *II = dyn_cast<IntrinsicInst>(Src);
1150     if (!II)
1151       return false; // May be bitcast from <256 x i32> to x86amx.
1152     Prepare(AMXCast->getType());
1153     Value *Row = II->getOperand(0);
1154     Value *Col = II->getOperand(1);
1155     std::array<Value *, 5> Args = {
1156         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1157     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
1158     Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1159     AMXCast->replaceAllUsesWith(NewInst);
1160     AMXCast->eraseFromParent();
1161   }
1162 
1163   return true;
1164 }
1165 
1166 bool X86LowerAMXCast::transformAllAMXCast() {
1167   bool Change = false;
  // Collect tile cast instructions.
1169   SmallVector<Instruction *, 8> WorkLists;
1170   for (BasicBlock &BB : Func) {
1171     for (Instruction &I : BB) {
1172       if (isAMXCast(&I))
1173         WorkLists.push_back(&I);
1174     }
1175   }
1176 
1177   for (auto *Inst : WorkLists) {
1178     Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1179   }
1180 
1181   return Change;
1182 }
1183 
1184 } // anonymous namespace
1185 
1186 namespace {
1187 
1188 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1189 public:
1190   static char ID;
1191 
1192   X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1193     initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1194   }
1195 
1196   bool runOnFunction(Function &F) override {
1197     bool C = false;
1198     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1199     TargetLibraryInfo *TLI =
1200         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1201     X86LowerAMXCast LAC(F);
1202     C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMX casts after combineAMXcast, and they
    // should be handled elegantly.
1205     C |= LAC.transformAllAMXCast();
1206 
1207     X86LowerAMXType LAT(F);
1208     C |= LAT.visit();
1209 
1210     // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just by checking Attribute::OptimizeNone and CodeGenOpt::None.
1213     if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
1218       if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1219         X86VolatileTileData VTD(F);
1220         C = VTD.volatileTileData() || C;
1221       }
1222     }
1223 
1224     return C;
1225   }
1226 
1227   void getAnalysisUsage(AnalysisUsage &AU) const override {
1228     AU.setPreservesCFG();
1229     AU.addRequired<TargetPassConfig>();
1230     AU.addRequired<TargetLibraryInfoWrapperPass>();
1231   }
1232 };
1233 
1234 } // anonymous namespace
1235 
1236 static const char PassName[] = "Lower AMX type for load/store";
1237 char X86LowerAMXTypeLegacyPass::ID = 0;
1238 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1239                       false)
1240 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1241 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1242 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1243                     false)
1244 
1245 FunctionPass *llvm::createX86LowerAMXTypePass() {
1246   return new X86LowerAMXTypeLegacyPass();
1247 }
1248