1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with the load/store, we transform the bitcast to an
/// AMX load/store plus a <256 x i32> store/load.
17 ///
/// If the front end does not use O0 but the mid/back end does (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in a spill.)
/// The volatileTileData() function handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ...                                          |
27 /// | ...                                                    |
28 /// | "use %td"                                              |
29 /// ----------------------------------------------------------
/// will be transformed to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ...                                          |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
34 /// | ...                                                    |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2"                                             |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
46 #include "llvm/Analysis/TargetLibraryInfo.h"
47 #include "llvm/Analysis/TargetTransformInfo.h"
48 #include "llvm/CodeGen/Passes.h"
49 #include "llvm/CodeGen/TargetPassConfig.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
60 #include "llvm/Target/TargetMachine.h"
61 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62 #include "llvm/Transforms/Utils/Local.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72   return match(II,
73                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
76 
77 static bool isAMXIntrinsic(Value *I) {
78   auto *II = dyn_cast<IntrinsicInst>(I);
79   if (!II)
80     return false;
81   if (isAMXCast(II))
82     return false;
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
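  // e.g. (illustrative) %t = call x86_amx @llvm.x86.tileloadd64.internal(...)
  // has an x86_amx result and is therefore an AMX intrinsic, while the cast
  // intrinsics were already filtered out above.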
85   if (II->getType()->isX86_AMXTy())
86     return true;
87   for (Value *V : II->args()) {
88     if (V->getType()->isX86_AMXTy())
89       return true;
90   }
91 
92   return false;
93 }
94 
95 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96                                            Type *Ty) {
97   Function &F = *BB->getParent();
98   Module *M = BB->getModule();
99   const DataLayout &DL = M->getDataLayout();
100 
101   LLVMContext &Ctx = Builder.getContext();
102   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103   unsigned AllocaAS = DL.getAllocaAddrSpace();
104   AllocaInst *AllocaRes =
105       new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106   AllocaRes->setAlignment(AllocaAlignment);
107   return AllocaRes;
108 }
109 
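// Return the first non-alloca instruction in the entry block. The entry block
// always ends with a terminator, which is not an alloca, so the loop below
// always finds one.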
110 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111   for (Instruction &I : F.getEntryBlock())
112     if (!isa<AllocaInst>(&I))
113       return &I;
114   llvm_unreachable("No terminator in the entry block!");
115 }
116 
117 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118   IRBuilder<> Builder(II);
119   Value *Row = nullptr, *Col = nullptr;
120   switch (II->getIntrinsicID()) {
121   default:
122     llvm_unreachable("Expect amx intrinsics");
123   case Intrinsic::x86_tileloadd64_internal:
124   case Intrinsic::x86_tileloaddt164_internal:
125   case Intrinsic::x86_tilestored64_internal: {
126     Row = II->getArgOperand(0);
127     Col = II->getArgOperand(1);
128     break;
129   }
130   // a * b + c
131   // The shape depends on which operand.
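  // Summarizing the cases below for a dot-product style intrinsic
  // op(i16 m, i16 n, i16 k, c, a, b), the shapes are:
  //   operand 3 (c): rows = m (arg 0), cols = n (arg 1)
  //   operand 4 (a): rows = m (arg 0), cols = k (arg 2)
  //   operand 5 (b): rows = k / 4 (arg 2 divided by 4), cols = n (arg 1)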
132   case Intrinsic::x86_tcmmimfp16ps_internal:
133   case Intrinsic::x86_tcmmrlfp16ps_internal:
134   case Intrinsic::x86_tdpbssd_internal:
135   case Intrinsic::x86_tdpbsud_internal:
136   case Intrinsic::x86_tdpbusd_internal:
137   case Intrinsic::x86_tdpbuud_internal:
138   case Intrinsic::x86_tdpbf16ps_internal:
139   case Intrinsic::x86_tdpfp16ps_internal: {
140     switch (OpNo) {
141     case 3:
142       Row = II->getArgOperand(0);
143       Col = II->getArgOperand(1);
144       break;
145     case 4:
146       Row = II->getArgOperand(0);
147       Col = II->getArgOperand(2);
148       break;
149     case 5:
150       if (isa<ConstantInt>(II->getArgOperand(2)))
151         Row = Builder.getInt16(
152             (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
153       else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a constant and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118 and we try to get the shape of
        // %117:
157         //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
158         //   i32> %115).
159         //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
160         //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
161         //   %117).
        // If we create %row = udiv i16 %106, 4 before %118 (i.e. II), then its
        // definition would come after its user (the new tileload for %117).
164         // So, the best choice is to create %row right after the definition of
165         // %106.
166         Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
167         Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
168         cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
169       } else {
        // When it is not a constant but is a function argument, we create Row
        // in the entry block.
172         IRBuilder<> NewBuilder(
173             getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
174         Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
175       }
176       Col = II->getArgOperand(1);
177       break;
178     }
179     break;
180   }
181   }
182 
183   return std::make_pair(Row, Col);
184 }
185 
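// Try to derive the shape for a PHI of tile casts: follow the first-use chain
// through AMX casts and PHIs until an AMX intrinsic is reached, then reuse
// that intrinsic's shape. Returns {nullptr, nullptr} if no shape is found.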
186 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
187   Use &U = *(Phi->use_begin());
188   unsigned OpNo = U.getOperandNo();
189   User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we only
  // follow the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero constants is
  // abandoned.
194   while (V) {
195     if (isAMXCast(dyn_cast<Instruction>(V))) {
196       if (V->use_empty())
197         break;
198       Use &U = *(V->use_begin());
199       OpNo = U.getOperandNo();
200       V = U.getUser();
201     } else if (isAMXIntrinsic(V)) {
202       return getShape(cast<IntrinsicInst>(V), OpNo);
203     } else if (isa<PHINode>(V)) {
204       if (V->use_empty())
205         break;
206       Use &U = *(V->use_begin());
207       V = U.getUser();
208     } else {
209       break;
210     }
211   }
212 
213   return std::make_pair(nullptr, nullptr);
214 }
215 
216 namespace {
217 class X86LowerAMXType {
218   Function &Func;
219 
  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
223   std::map<Value *, Value *> Col2Row;
224 
225 public:
226   X86LowerAMXType(Function &F) : Func(F) {}
227   bool visit();
228   void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
229   void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
230   bool transformBitcast(BitCastInst *Bitcast);
231 };
232 
233 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
234 // %2 = bitcast <256 x i32> %src to x86_amx
235 // -->
236 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
237 // i8* %addr, i64 %stride64)
238 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
239   Value *Row = nullptr, *Col = nullptr;
240   Use &U = *(Bitcast->use_begin());
241   unsigned OpNo = U.getOperandNo();
242   auto *II = cast<IntrinsicInst>(U.getUser());
243   std::tie(Row, Col) = getShape(II, OpNo);
244   IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
246   Value *Stride = Builder.getInt64(64);
247   Value *I8Ptr =
248       Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
249   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
250 
251   Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
252                                            std::nullopt, Args);
253   Bitcast->replaceAllUsesWith(NewInst);
254 }
255 
256 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
257 //                                                    %stride);
258 // %13 = bitcast x86_amx %src to <256 x i32>
259 // store <256 x i32> %13, <256 x i32>* %addr, align 64
260 // -->
261 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
262 //                                           %stride64, %13)
263 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
264 
265   Value *Tile = Bitcast->getOperand(0);
266   auto *II = cast<IntrinsicInst>(Tile);
267   // Tile is output from AMX intrinsic. The first operand of the
268   // intrinsic is row, the second operand of the intrinsic is column.
269   Value *Row = II->getOperand(0);
270   Value *Col = II->getOperand(1);
271   IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
274   Value *Stride = Builder.getInt64(64);
275   Value *I8Ptr =
276       Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
277   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
278   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
279                           Args);
280   if (Bitcast->hasOneUse())
281     return;
282   // %13 = bitcast x86_amx %src to <256 x i32>
283   // store <256 x i32> %13, <256 x i32>* %addr, align 64
284   // %add = <256 x i32> %13, <256 x i32> %src2
285   // -->
286   // %13 = bitcast x86_amx %src to <256 x i32>
287   // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
288   //                                           %stride64, %13)
289   // %14 = load <256 x i32>, %addr
290   // %add = <256 x i32> %14, <256 x i32> %src2
291   Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
292   Bitcast->replaceAllUsesWith(Vec);
293 }
294 
// Transform the bitcast into <store, load> instructions.
296 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
297   IRBuilder<> Builder(Bitcast);
298   AllocaInst *AllocaAddr;
299   Value *I8Ptr, *Stride;
300   auto *Src = Bitcast->getOperand(0);
301 
302   auto Prepare = [&](Type *MemTy) {
303     AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
304     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
305     Stride = Builder.getInt64(64);
306   };
307 
308   if (Bitcast->getType()->isX86_AMXTy()) {
309     // %2 = bitcast <256 x i32> %src to x86_amx
310     // -->
311     // %addr = alloca <256 x i32>, align 64
312     // store <256 x i32> %src, <256 x i32>* %addr, align 64
313     // %addr2 = bitcast <256 x i32>* to i8*
314     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
315     //                                                  i8* %addr2,
316     //                                                  i64 64)
317     Use &U = *(Bitcast->use_begin());
318     unsigned OpNo = U.getOperandNo();
319     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
320     if (!II)
      return false; // May be a bitcast from x86_amx to <256 x i32>.
322     Prepare(Bitcast->getOperand(0)->getType());
323     Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
325     Value *Row = nullptr, *Col = nullptr;
326     std::tie(Row, Col) = getShape(II, OpNo);
327     std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
328     Value *NewInst = Builder.CreateIntrinsic(
329         Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
330     Bitcast->replaceAllUsesWith(NewInst);
331   } else {
332     // %2 = bitcast x86_amx %src to <256 x i32>
333     // -->
334     // %addr = alloca <256 x i32>, align 64
335     // %addr2 = bitcast <256 x i32>* to i8*
336     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
337     //                                           i8* %addr2, i64 %stride)
338     // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
339     auto *II = dyn_cast<IntrinsicInst>(Src);
340     if (!II)
      return false; // May be a bitcast from <256 x i32> to x86_amx.
342     Prepare(Bitcast->getType());
343     Value *Row = II->getOperand(0);
344     Value *Col = II->getOperand(1);
345     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
346     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
347                             Args);
348     Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
349     Bitcast->replaceAllUsesWith(NewInst);
350   }
351 
352   return true;
353 }
354 
355 bool X86LowerAMXType::visit() {
356   SmallVector<Instruction *, 8> DeadInsts;
357   Col2Row.clear();
358 
359   for (BasicBlock *BB : post_order(&Func)) {
360     for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
361       auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
362       if (!Bitcast)
363         continue;
364 
365       Value *Src = Bitcast->getOperand(0);
366       if (Bitcast->getType()->isX86_AMXTy()) {
367         if (Bitcast->user_empty()) {
368           DeadInsts.push_back(Bitcast);
369           continue;
370         }
371         LoadInst *LD = dyn_cast<LoadInst>(Src);
372         if (!LD) {
373           if (transformBitcast(Bitcast))
374             DeadInsts.push_back(Bitcast);
375           continue;
376         }
        // If the load has multiple users, keep the vector load and add a tile
        // load.
378         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
379         // %2 = bitcast <256 x i32> %src to x86_amx
380         // %add = add <256 x i32> %src, <256 x i32> %src2
381         // -->
382         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
383         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
384         //                                            i8* %addr, i64 %stride64)
385         // %add = add <256 x i32> %src, <256 x i32> %src2
386 
        // If the load has a single user, it will be eliminated in DAG ISel.
388         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
389         // %2 = bitcast <256 x i32> %src to x86_amx
390         // -->
391         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
392         //                                            i8* %addr, i64 %stride64)
393         combineLoadBitcast(LD, Bitcast);
394         DeadInsts.push_back(Bitcast);
395         if (LD->hasOneUse())
396           DeadInsts.push_back(LD);
397       } else if (Src->getType()->isX86_AMXTy()) {
398         if (Bitcast->user_empty()) {
399           DeadInsts.push_back(Bitcast);
400           continue;
401         }
402         StoreInst *ST = nullptr;
403         for (Use &U : Bitcast->uses()) {
404           ST = dyn_cast<StoreInst>(U.getUser());
405           if (ST)
406             break;
407         }
408         if (!ST) {
409           if (transformBitcast(Bitcast))
410             DeadInsts.push_back(Bitcast);
411           continue;
412         }
        // If the bitcast (%13) has a single use, combine the bitcast and the
        // store into an AMX store.
414         // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
415         //                                                    %stride);
416         // %13 = bitcast x86_amx %src to <256 x i32>
417         // store <256 x i32> %13, <256 x i32>* %addr, align 64
418         // -->
419         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
420         //                                           %stride64, %13)
421         //
        // If the bitcast (%13) has multiple uses, transform it as below.
423         // %13 = bitcast x86_amx %src to <256 x i32>
424         // store <256 x i32> %13, <256 x i32>* %addr, align 64
425         // %add = <256 x i32> %13, <256 x i32> %src2
426         // -->
427         // %13 = bitcast x86_amx %src to <256 x i32>
428         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
429         //                                           %stride64, %13)
430         // %14 = load <256 x i32>, %addr
431         // %add = <256 x i32> %14, <256 x i32> %src2
432         //
433         combineBitcastStore(Bitcast, ST);
434         // Delete user first.
435         DeadInsts.push_back(ST);
436         DeadInsts.push_back(Bitcast);
437       }
438     }
439   }
440 
441   bool C = !DeadInsts.empty();
442 
443   for (auto *Inst : DeadInsts)
444     Inst->eraseFromParent();
445 
446   return C;
447 }
448 } // anonymous namespace
449 
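// Allocate a <256 x i32> stack slot in the entry block and return it as an
// i8* so it can be used as the memory operand of tile stores/loads.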
450 static Value *getAllocaPos(BasicBlock *BB) {
451   Module *M = BB->getModule();
452   Function *F = BB->getParent();
453   IRBuilder<> Builder(&F->getEntryBlock().front());
454   const DataLayout &DL = M->getDataLayout();
455   unsigned AllocaAS = DL.getAllocaAddrSpace();
456   Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
457   AllocaInst *AllocaRes =
458       new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
459   BasicBlock::iterator Iter = AllocaRes->getIterator();
460   ++Iter;
461   Builder.SetInsertPoint(&*Iter);
462   Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
463   return I8Ptr;
464 }
465 
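// Store the tile defined by \p TileDef to \p Ptr immediately after its
// definition, reusing the defining intrinsic's row/col operands and the
// maximum stride of 64 bytes.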
466 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
467   assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
468   auto *II = cast<IntrinsicInst>(TileDef);
469   assert(II && "Not tile intrinsic!");
470   Value *Row = II->getOperand(0);
471   Value *Col = II->getOperand(1);
472 
473   BasicBlock *BB = TileDef->getParent();
474   BasicBlock::iterator Iter = TileDef->getIterator();
475   IRBuilder<> Builder(BB, ++Iter);
476   Value *Stride = Builder.getInt64(64);
477   std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
478 
479   Instruction *TileStore = Builder.CreateIntrinsic(
480       Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
481   return TileStore;
482 }
483 
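// Replace the tile use \p U with a tileload from \p Ptr emitted right before
// the user. For a PHI value the shape is taken from its first incoming value.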
484 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
485   Value *V = U.get();
486   assert(V->getType()->isX86_AMXTy() && "Not define tile!");
487 
488   // Get tile shape.
489   IntrinsicInst *II = nullptr;
490   if (IsPHI) {
491     Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
492     II = cast<IntrinsicInst>(PhiOp);
493   } else {
494     II = cast<IntrinsicInst>(V);
495   }
496   Value *Row = II->getOperand(0);
497   Value *Col = II->getOperand(1);
498 
499   Instruction *UserI = dyn_cast<Instruction>(U.getUser());
500   IRBuilder<> Builder(UserI);
501   Value *Stride = Builder.getInt64(64);
502   std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
503 
504   Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
505                                             std::nullopt, Args);
506   UserI->replaceUsesOfWith(V, TileLoad);
507 }
508 
509 static bool isIncomingOfPHI(Instruction *I) {
510   for (Use &U : I->uses()) {
511     User *V = U.getUser();
512     if (isa<PHINode>(V))
513       return true;
514   }
515   return false;
516 }
517 
// Make all AMX tile data volatile, shortening the live range of each tile
// register before fast register allocation.
520 namespace {
521 class X86VolatileTileData {
522   Function &F;
523 
524 public:
525   X86VolatileTileData(Function &Func) : F(Func) {}
526   Value *updatePhiIncomings(BasicBlock *BB,
527                             SmallVector<Instruction *, 2> &Incomings);
528   void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
529   bool volatileTileData();
530   void volatileTilePHI(PHINode *PHI);
531   void volatileTileNonPHI(Instruction *I);
532 };
533 
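// Store every incoming tile of the PHI into a common stack slot right after
// its definition and make all non-PHI users reload the tile from that slot.
// Returns the i8* of the slot.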
534 Value *X86VolatileTileData::updatePhiIncomings(
535     BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
536   Value *I8Ptr = getAllocaPos(BB);
537 
538   for (auto *I : Incomings) {
539     User *Store = createTileStore(I, I8Ptr);
540 
541     // All its uses (except phi) should load from stored mem.
542     for (Use &U : I->uses()) {
543       User *V = U.getUser();
544       if (isa<PHINode>(V) || V == Store)
545         continue;
546       replaceWithTileLoad(U, I8Ptr);
547     }
548   }
549   return I8Ptr;
550 }
551 
552 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
553                                                 Value *StorePtr) {
554   for (Use &U : PHI->uses())
555     replaceWithTileLoad(U, StorePtr, true);
556   PHI->eraseFromParent();
557 }
558 
// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def is changed to a tileload.
// 2) Each PHI incoming value is tilestored right after its def.
// 3) The memory used by these tileloads and tilestores must be the same.
564 // e.g.
565 // ------------------------------------------------------
566 // bb_dom:
567 //   ...
568 //   br i1 %bool.cond, label %if.else, label %if.then
569 //
570 // if.then:
571 //   def %t0 = ...
572 //   ...
573 //   use %t0
574 //   ...
575 //   br label %if.end
576 //
577 // if.else:
578 //   def %t1 = ...
579 //   br label %if.end
580 //
581 // if.end:
582 //   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
583 //   ...
584 //   use %td
585 // ------------------------------------------------------
586 // -->
587 // ------------------------------------------------------
588 // bb_entry:
589 //   %mem = alloca <256 x i32>, align 1024                  *
590 //   ...
591 // bb_dom:
592 //   ...
593 //   br i1 %bool.cond, label %if.else, label %if.then
594 //
595 // if.then:
596 //   def %t0 = ...
597 //   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
598 //   ...
599 //   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
600 //   use %t0`                                               *
601 //   ...
602 //   br label %if.end
603 //
604 // if.else:
605 //   def %t1 = ...
606 //   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
607 //   br label %if.end
608 //
609 // if.end:
610 //   ...
611 //   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
612 //   use %td
613 // ------------------------------------------------------
614 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
615   BasicBlock *BB = PHI->getParent();
616   SmallVector<Instruction *, 2> Incomings;
617 
618   for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
619     Value *Op = PHI->getIncomingValue(I);
620     Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
622     Incomings.push_back(Inst);
623   }
624 
625   Value *StorePtr = updatePhiIncomings(BB, Incomings);
626   replacePhiDefWithLoad(PHI, StorePtr);
627 }
628 
// Store the defined tile and load it before each use.
// None of its users is a PHI node.
631 // e.g.
632 // ------------------------------------------------------
633 // def %td = ...
634 // ...
635 // "use %td"
636 // ------------------------------------------------------
637 // -->
638 // ------------------------------------------------------
639 // def %td = ...
640 // call void @llvm.x86.tilestored64.internal(mem, %td)
641 // ...
642 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
643 // "use %td2"
644 // ------------------------------------------------------
645 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
646   BasicBlock *BB = I->getParent();
647   Value *I8Ptr = getAllocaPos(BB);
648   User *Store = createTileStore(I, I8Ptr);
649 
650   // All its uses should load from stored mem.
651   for (Use &U : I->uses()) {
652     User *V = U.getUser();
653     assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
654     if (V != Store)
655       replaceWithTileLoad(U, I8Ptr);
656   }
657 }
658 
659 // Volatile Tile Model:
// 1) Every use of tile data comes from a tileload right before the use.
// 2) Every def of tile data is tilestored to memory immediately after the def.
662 // For example:
663 // --------------------------------------------------------------------------
664 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
665 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
666 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t3, t1, t2)
668 // call void @llvm.x86.tilestored64.internal(... td)                     area
669 // --------------------------------------------------------------------------
670 // 3) No terminator, call or other amx instructions in the key amx area.
671 bool X86VolatileTileData::volatileTileData() {
672   bool Changed = false;
673   for (BasicBlock &BB : F) {
674     SmallVector<Instruction *, 2> PHIInsts;
675     SmallVector<Instruction *, 8> AMXDefInsts;
676 
677     for (Instruction &I : BB) {
678       if (!I.getType()->isX86_AMXTy())
679         continue;
680       if (isa<PHINode>(&I))
681         PHIInsts.push_back(&I);
682       else
683         AMXDefInsts.push_back(&I);
684     }
685 
686     // First we "volatile" the non-phi related amx intrinsics.
687     for (Instruction *I : AMXDefInsts) {
688       if (isIncomingOfPHI(I))
689         continue;
690       volatileTileNonPHI(I);
691       Changed = true;
692     }
693 
694     for (Instruction *I : PHIInsts) {
695       volatileTilePHI(dyn_cast<PHINode>(I));
696       Changed = true;
697     }
698   }
699   return Changed;
700 }
701 
702 } // anonymous namespace
703 
704 namespace {
705 
706 class X86LowerAMXCast {
707   Function &Func;
708   std::unique_ptr<DominatorTree> DT;
709 
710 public:
711   X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
712   bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
713   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
714   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
715   bool combineAMXcast(TargetLibraryInfo *TLI);
716   bool transformAMXCast(IntrinsicInst *AMXCast);
717   bool transformAllAMXCast();
718   bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
719                               SmallSetVector<Instruction *, 16> &DeadInst);
720 };
721 
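// Erase \p I if it is trivially dead and queue any operands that become
// trivially dead as a result on \p WorkList for a later iteration. Returns
// true if the instruction was erased.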
722 static bool DCEInstruction(Instruction *I,
723                            SmallSetVector<Instruction *, 16> &WorkList,
724                            const TargetLibraryInfo *TLI) {
725   if (isInstructionTriviallyDead(I, TLI)) {
726     salvageDebugInfo(*I);
727     salvageKnowledge(I);
728 
729     // Null out all of the instruction's operands to see if any operand becomes
730     // dead as we go.
731     for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
732       Value *OpV = I->getOperand(i);
733       I->setOperand(i, nullptr);
734 
735       if (!OpV->use_empty() || I == OpV)
736         continue;
737 
738       // If the operand is an instruction that became dead as we nulled out the
739       // operand, and if it is 'trivially' dead, delete it in a future loop
740       // iteration.
741       if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
742         if (isInstructionTriviallyDead(OpI, TLI)) {
743           WorkList.insert(OpI);
744         }
745       }
746     }
747     I->eraseFromParent();
748     return true;
749   }
750   return false;
751 }
752 
753 /// This function handles following case
754 ///
755 ///     A  ->  B    amxcast
756 ///     PHI
757 ///     B  ->  A    amxcast
758 ///
759 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
760 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
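///
/// e.g. (illustrative sketch, with A = x86_amx and B = <256 x i32>):
///   bb0: %v0 = call <256 x i32>
///              @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t0)
///   bb1: %v1 = call <256 x i32>
///              @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
///   bb2: %vphi = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
///        %t2 = call x86_amx
///              @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vphi)
/// -->
///   bb2: %tphi = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
/// All uses of %t2 are rewritten to use %tphi; the old casts and the old PHI
/// become dead and are cleaned up afterwards.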
761 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
762     IntrinsicInst *CI, PHINode *PN,
763     SmallSetVector<Instruction *, 16> &DeadInst) {
764   IRBuilder<> Builder(CI);
765   Value *Src = CI->getOperand(0);
766   Type *SrcTy = Src->getType(); // Type B
767   Type *DestTy = CI->getType(); // Type A
768 
769   SmallVector<PHINode *, 4> PhiWorklist;
770   SmallSetVector<PHINode *, 4> OldPhiNodes;
771 
772   // Find all of the A->B casts and PHI nodes.
773   // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
774   // OldPhiNodes is used to track all known PHI nodes, before adding a new
775   // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
776   PhiWorklist.push_back(PN);
777   OldPhiNodes.insert(PN);
778   while (!PhiWorklist.empty()) {
779     auto *OldPN = PhiWorklist.pop_back_val();
780     for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
781       Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we only handle incoming constants that are undef or
      // zero; in the future we might support other constants.
784       if (isa<Constant>(IncValue)) {
785         auto *IncConst = dyn_cast<Constant>(IncValue);
786         if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
787           return false;
788         Value *Row = nullptr, *Col = nullptr;
789         std::tie(Row, Col) = getShape(OldPN);
        // TODO: If they are not constants, Row and Col must dominate the
        // tilezero that we are going to create.
792         if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
793           return false;
794         // Create tilezero at the end of incoming block.
795         auto *Block = OldPN->getIncomingBlock(I);
796         BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
797         Instruction *NewInst = Builder.CreateIntrinsic(
798             Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
799         NewInst->moveBefore(&*Iter);
800         NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
801                                           {IncValue->getType()}, {NewInst});
802         NewInst->moveBefore(&*Iter);
803         // Replace InValue with new Value.
804         OldPN->setIncomingValue(I, NewInst);
805         IncValue = NewInst;
806       }
807 
808       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
809         if (OldPhiNodes.insert(PNode))
810           PhiWorklist.push_back(PNode);
811         continue;
812       }
813       Instruction *ACI = dyn_cast<Instruction>(IncValue);
814       if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
816         Type *TyA = ACI->getOperand(0)->getType();
817         Type *TyB = ACI->getType();
818         if (TyA != DestTy || TyB != SrcTy)
819           return false;
820         continue;
821       }
822       return false;
823     }
824   }
825 
826   // Check that each user of each old PHI node is something that we can
827   // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
828   for (auto *OldPN : OldPhiNodes) {
829     for (User *V : OldPN->users()) {
830       Instruction *ACI = dyn_cast<Instruction>(V);
831       if (ACI && isAMXCast(ACI)) {
832         // Verify it's a B->A cast.
833         Type *TyB = ACI->getOperand(0)->getType();
834         Type *TyA = ACI->getType();
835         if (TyA != DestTy || TyB != SrcTy)
836           return false;
837       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
838         // As long as the user is another old PHI node, then even if we don't
839         // rewrite it, the PHI web we're considering won't have any users
840         // outside itself, so it'll be dead.
841         // example:
842         //   bb.0:
843         //      %0 = amxcast ...
844         //   bb.1:
845         //      %1 = amxcast ...
846         //   bb.2:
847         //      %goodphi = phi %0, %1
848         //      %3 = amxcast %goodphi
849         //   bb.3:
850         //      %goodphi2 = phi %0, %goodphi
851         //      %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
856         if (OldPhiNodes.count(PHI) == 0)
857           return false;
858       } else
859         return false;
860     }
861   }
862 
  // For each old PHI node, create a corresponding new PHI node with type A.
864   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
865   for (auto *OldPN : OldPhiNodes) {
866     Builder.SetInsertPoint(OldPN);
867     PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
868     NewPNodes[OldPN] = NewPN;
869   }
870 
871   // Fill in the operands of new PHI nodes.
872   for (auto *OldPN : OldPhiNodes) {
873     PHINode *NewPN = NewPNodes[OldPN];
874     for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
875       Value *V = OldPN->getOperand(j);
876       Value *NewV = nullptr;
877       Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a constant.
879       if (ACI && isAMXCast(ACI))
880         NewV = ACI->getOperand(0);
881       else if (auto *PrevPN = dyn_cast<PHINode>(V))
882         NewV = NewPNodes[PrevPN];
883       assert(NewV);
884       NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
885     }
886   }
887 
  // Traverse all accumulated PHI nodes and process their users, which are
  // Stores and BitCasts. Without this processing, NewPHI nodes could be
  // replicated and could lead to extra moves generated after DeSSA.
  // If there is a store with type B, change it to type A.
893 
894   // Replace users of BitCast B->A with NewPHI. These will help
895   // later to get rid of a closure formed by OldPHI nodes.
896   for (auto *OldPN : OldPhiNodes) {
897     PHINode *NewPN = NewPNodes[OldPN];
898     for (User *V : make_early_inc_range(OldPN->users())) {
899       Instruction *ACI = dyn_cast<Instruction>(V);
900       if (ACI && isAMXCast(ACI)) {
901         Type *TyB = ACI->getOperand(0)->getType();
902         Type *TyA = ACI->getType();
903         assert(TyA == DestTy && TyB == SrcTy);
904         (void)TyA;
905         (void)TyB;
906         ACI->replaceAllUsesWith(NewPN);
907         DeadInst.insert(ACI);
908       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since it is an
        // operand of rootPN; DCE can safely delete rootPN's operands if rootPN
        // is dead.
911         assert(OldPhiNodes.contains(PHI));
912         (void)PHI;
913       } else
914         llvm_unreachable("all uses should be handled");
915     }
916   }
917   return true;
918 }
919 
920 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
921 // store <256 x i32> %43, <256 x i32>* %p, align 64
922 // -->
923 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
924 //                                           i64 64, x86_amx %42)
925 bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
926   Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
929   if (!isAMXIntrinsic(Tile))
930     return false;
931   auto *II = cast<IntrinsicInst>(Tile);
932   // Tile is output from AMX intrinsic. The first operand of the
933   // intrinsic is row, the second operand of the intrinsic is column.
934   Value *Row = II->getOperand(0);
935   Value *Col = II->getOperand(1);
936   IRBuilder<> Builder(ST);
  // Stride should be equal to Col (measured in bytes).
938   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
939   Value *I8Ptr =
940       Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
941   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
942   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
943                           Args);
944   return true;
945 }
946 
947 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
948 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
949 // -->
950 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
951 //                                                   i8* %p, i64 64)
952 bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
953   bool EraseLoad = true;
954   Value *Row = nullptr, *Col = nullptr;
955   Use &U = *(Cast->use_begin());
956   unsigned OpNo = U.getOperandNo();
957   auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
960   if (!isAMXIntrinsic(II))
961     return false;
962   std::tie(Row, Col) = getShape(II, OpNo);
963   IRBuilder<> Builder(LD);
  // Stride should be equal to Col (measured in bytes).
965   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
966   Value *I8Ptr;
967 
  // To save compile time, we create the dominator tree only when it is
  // really needed.
970   if (!DT)
971     DT.reset(new DominatorTree(Func));
972   if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
974     auto *AllocaAddr =
975         createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
976     Builder.SetInsertPoint(&*std::next(LD->getIterator()));
977     Builder.CreateStore(LD, AllocaAddr);
978 
979     Builder.SetInsertPoint(Cast);
980     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
981     EraseLoad = false;
982   } else {
983     I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
984   }
985   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
986 
987   Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
988                                            std::nullopt, Args);
989   Cast->replaceAllUsesWith(NewInst);
990 
991   return EraseLoad;
992 }
993 
994 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
995   bool Change = false;
996   for (auto *Cast : Casts) {
997     auto *II = cast<IntrinsicInst>(Cast);
998     // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
999     // store <256 x i32> %43, <256 x i32>* %p, align 64
1000     // -->
1001     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1002     //                                           i64 64, x86_amx %42)
1003     if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1004       SmallVector<Instruction *, 2> DeadStores;
1005       for (User *U : Cast->users()) {
1006         StoreInst *Store = dyn_cast<StoreInst>(U);
1007         if (!Store)
1008           continue;
1009         if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1010           DeadStores.push_back(Store);
1011           Change = true;
1012         }
1013       }
1014       for (auto *Store : DeadStores)
1015         Store->eraseFromParent();
1016     } else { // x86_cast_vector_to_tile
1017       SmallVector<Instruction *, 2> DeadLoads;
1018       auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1019       if (!Load || !Load->hasOneUse())
1020         continue;
1021       // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1022       // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1023       // -->
1024       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1025       //                                                   i8* %p, i64 64)
1026       if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
1028         Cast->setOperand(0, nullptr);
1029         Load->eraseFromParent();
1030       }
1031     }
1032   }
1033   return Change;
1034 }
1035 
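// Combine AMX cast intrinsics: fold matching vec2tile/tile2vec pairs, erase
// casts that become dead, combine the surviving casts with adjacent
// loads/stores, and optimize A->B->A casts that go through PHI nodes. Any
// instructions that become dead in the process are DCE'd at the end.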
1036 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1037   bool Change = false;
  // Collect tile cast instructions.
1039   SmallVector<Instruction *, 8> Vec2TileInsts;
1040   SmallVector<Instruction *, 8> Tile2VecInsts;
1041   SmallVector<Instruction *, 8> PhiCastWorkList;
1042   SmallSetVector<Instruction *, 16> DeadInst;
1043   for (BasicBlock &BB : Func) {
1044     for (Instruction &I : BB) {
1045       Value *Vec;
1046       if (match(&I,
1047                 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1048         Vec2TileInsts.push_back(&I);
1049       else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1050                              m_Value(Vec))))
1051         Tile2VecInsts.push_back(&I);
1052     }
1053   }
1054 
1055   auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1056     for (auto *Inst : Insts) {
1057       for (User *U : Inst->users()) {
1058         IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1059         if (!II || II->getIntrinsicID() != IID)
1060           continue;
1061         // T1 = vec2tile V0
1062         // V2 = tile2vec T1
1063         // V3 = OP V2
1064         // -->
1065         // T1 = vec2tile V0
1066         // V2 = tile2vec T1
1067         // V3 = OP V0
1068         II->replaceAllUsesWith(Inst->getOperand(0));
1069         Change = true;
1070       }
1071     }
1072   };
1073 
1074   Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1075   Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1076 
1077   SmallVector<Instruction *, 8> LiveCasts;
1078   auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1079     for (auto *Inst : Insts) {
1080       if (Inst->use_empty()) {
1081         Inst->eraseFromParent();
1082         Change = true;
1083       } else {
1084         LiveCasts.push_back(Inst);
1085       }
1086     }
1087   };
1088 
1089   EraseInst(Vec2TileInsts);
1090   EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1092                        "Vec2Tile and Tile2Vec:\n";
1093              Func.dump());
1094   Change |= combineLdSt(LiveCasts);
1095   EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1097                        "AMXCast and load/store:\n";
1098              Func.dump());
1099 
  // Handle the A->B->A cast where there is an intervening PHI node.
1101   for (BasicBlock &BB : Func) {
1102     for (Instruction &I : BB) {
1103       if (isAMXCast(&I)) {
1104         if (isa<PHINode>(I.getOperand(0)))
1105           PhiCastWorkList.push_back(&I);
1106       }
1107     }
1108   }
1109   for (auto *I : PhiCastWorkList) {
    // Skip dead AMXcasts.
1111     if (DeadInst.contains(I))
1112       continue;
1113     PHINode *PN = cast<PHINode>(I->getOperand(0));
1114     if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1115       DeadInst.insert(PN);
1116       Change = true;
1117     }
1118   }
1119 
  // Since we created new PHIs and merged AMXcasts, some old PHIs and AMXcasts
  // might have no uses. Do dead code elimination on them.
1122   while (!DeadInst.empty()) {
1123     Instruction *I = DeadInst.pop_back_val();
1124     Change |= DCEInstruction(I, DeadInst, TLI);
1125   }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1127                        "optimizeAMXCastFromPhi:\n";
1128              Func.dump());
1129   return Change;
1130 }
1131 
// There might be remaining AMXcasts after combineAMXcast, and they should be
// handled elegantly.
1134 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1135   IRBuilder<> Builder(AMXCast);
1136   AllocaInst *AllocaAddr;
1137   Value *I8Ptr, *Stride;
1138   auto *Src = AMXCast->getOperand(0);
1139 
1140   auto Prepare = [&](Type *MemTy) {
1141     AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1142     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
1143     Stride = Builder.getInt64(64);
1144   };
1145 
1146   if (AMXCast->getType()->isX86_AMXTy()) {
1147     // %2 = amxcast <225 x i32> %src to x86_amx
1148     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1149     //                                           i8* %addr3, i64 60, x86_amx %2)
1150     // -->
1151     // %addr = alloca <225 x i32>, align 64
1152     // store <225 x i32> %src, <225 x i32>* %addr, align 64
1153     // %addr2 = bitcast <225 x i32>* %addr to i8*
1154     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1155     //                                                  i8* %addr2,
1156     //                                                  i64 60)
1157     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1158     //                                           i8* %addr3, i64 60, x86_amx %2)
1159     if (AMXCast->use_empty()) {
1160       AMXCast->eraseFromParent();
1161       return true;
1162     }
1163     Use &U = *(AMXCast->use_begin());
1164     unsigned OpNo = U.getOperandNo();
1165     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1166     if (!II)
      return false; // May be a bitcast from x86_amx to <256 x i32>.
1168     Prepare(AMXCast->getOperand(0)->getType());
1169     Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
1171     Value *Row = nullptr, *Col = nullptr;
1172     std::tie(Row, Col) = getShape(II, OpNo);
1173     std::array<Value *, 4> Args = {
1174         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1175     Value *NewInst = Builder.CreateIntrinsic(
1176         Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
1177     AMXCast->replaceAllUsesWith(NewInst);
1178     AMXCast->eraseFromParent();
1179   } else {
1180     // %2 = amxcast x86_amx %src to <225 x i32>
1181     // -->
1182     // %addr = alloca <225 x i32>, align 64
1183     // %addr2 = bitcast <225 x i32>* to i8*
1184     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1185     //                                           i8* %addr2, i64 %stride)
1186     // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1187     auto *II = dyn_cast<IntrinsicInst>(Src);
1188     if (!II)
      return false; // May be a bitcast from <256 x i32> to x86_amx.
1190     Prepare(AMXCast->getType());
1191     Value *Row = II->getOperand(0);
1192     Value *Col = II->getOperand(1);
1193     std::array<Value *, 5> Args = {
1194         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1195     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
1196                             Args);
1197     Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1198     AMXCast->replaceAllUsesWith(NewInst);
1199     AMXCast->eraseFromParent();
1200   }
1201 
1202   return true;
1203 }
1204 
1205 bool X86LowerAMXCast::transformAllAMXCast() {
1206   bool Change = false;
  // Collect tile cast instructions.
1208   SmallVector<Instruction *, 8> WorkLists;
1209   for (BasicBlock &BB : Func) {
1210     for (Instruction &I : BB) {
1211       if (isAMXCast(&I))
1212         WorkLists.push_back(&I);
1213     }
1214   }
1215 
1216   for (auto *Inst : WorkLists) {
1217     Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1218   }
1219 
1220   return Change;
1221 }
1222 
1223 } // anonymous namespace
1224 
1225 namespace {
1226 
1227 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1228 public:
1229   static char ID;
1230 
1231   X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1232     initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1233   }
1234 
1235   bool runOnFunction(Function &F) override {
1236     bool C = false;
1237     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1238     TargetLibraryInfo *TLI =
1239         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1240 
1241     X86LowerAMXCast LAC(F);
1242     C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast, and they should
    // be handled elegantly.
1245     C |= LAC.transformAllAMXCast();
1246 
1247     X86LowerAMXType LAT(F);
1248     C |= LAT.visit();
1249 
1250     // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just Attribute::OptimizeNone and CodeGenOpt::None.
1253     if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end does (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
      // the AMX data is volatile; that is necessary for AMX fast register
      // allocation.
1258       if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1259         X86VolatileTileData VTD(F);
1260         C = VTD.volatileTileData() || C;
1261       }
1262     }
1263 
1264     return C;
1265   }
1266 
1267   void getAnalysisUsage(AnalysisUsage &AU) const override {
1268     AU.setPreservesCFG();
1269     AU.addRequired<TargetPassConfig>();
1270     AU.addRequired<TargetLibraryInfoWrapperPass>();
1271   }
1272 };
1273 
1274 } // anonymous namespace
1275 
1276 static const char PassName[] = "Lower AMX type for load/store";
1277 char X86LowerAMXTypeLegacyPass::ID = 0;
1278 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1279                       false)
1280 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1281 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1282 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1283                     false)
1284 
1285 FunctionPass *llvm::createX86LowerAMXTypePass() {
1286   return new X86LowerAMXTypeLegacyPass();
1287 }
1288