1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast cannot be combined with a load/store, we transform the bitcast
/// into an AMX load/store plus a <256 x i32> store/load.
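/// e.g. (the combineLoadBitcast case handled below):
/// ----------------------------------------------------------
/// | %src = load <256 x i32>, <256 x i32>* %addr, align 64   |
/// | %td  = bitcast <256 x i32> %src to x86_amx              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | %td = call x86_amx @llvm.x86.tileloadd64.internal(      |
/// |           i16 %row, i16 %col, i8* %addr, i64 64)        |
/// ----------------------------------------------------------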
17 ///
/// If the front end does not use O0 but the mid/back end does (e.g. "Clang -O2
/// -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ...                                          |
27 /// | ...                                                    |
28 /// | "use %td"                                              |
29 /// ----------------------------------------------------------
/// will be transformed to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ...                                          |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
34 /// | ...                                                    |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2"                                             |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
46 #include "llvm/Analysis/TargetLibraryInfo.h"
47 #include "llvm/Analysis/TargetTransformInfo.h"
48 #include "llvm/CodeGen/Passes.h"
49 #include "llvm/CodeGen/TargetPassConfig.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
60 #include "llvm/Target/TargetMachine.h"
61 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62 #include "llvm/Transforms/Utils/Local.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72   return match(II,
73                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
76 
77 static bool isAMXIntrinsic(Value *I) {
78   auto *II = dyn_cast<IntrinsicInst>(I);
79   if (!II)
80     return false;
81   if (isAMXCast(II))
82     return false;
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
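  // e.g. (illustrative IR, not from a specific test):
  //   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c,
  //                                                    i8* %p, i64 64)
  // returns x86_amx, so it is recognized as an AMX intrinsic, while the
  // cast intrinsics are excluded by the check above.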
85   if (II->getType()->isX86_AMXTy())
86     return true;
87   for (Value *V : II->args()) {
88     if (V->getType()->isX86_AMXTy())
89       return true;
90   }
91 
92   return false;
93 }
94 
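// Create an alloca of type Ty at the start of the entry block of the function
// containing BB, aligned to the preferred alignment of x86_amx.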
95 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
96                                            Type *Ty) {
97   Function &F = *BB->getParent();
98   Module *M = BB->getModule();
99   const DataLayout &DL = M->getDataLayout();
100 
101   LLVMContext &Ctx = Builder.getContext();
102   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
103   unsigned AllocaAS = DL.getAllocaAddrSpace();
104   AllocaInst *AllocaRes =
105       new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
106   AllocaRes->setAlignment(AllocaAlignment);
107   return AllocaRes;
108 }
109 
110 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
111   for (Instruction &I : F.getEntryBlock())
112     if (!isa<AllocaInst>(&I))
113       return &I;
114   llvm_unreachable("No terminator in the entry block!");
115 }
116 
117 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
118   IRBuilder<> Builder(II);
119   Value *Row = nullptr, *Col = nullptr;
120   switch (II->getIntrinsicID()) {
121   default:
122     llvm_unreachable("Expect amx intrinsics");
123   case Intrinsic::x86_tileloadd64_internal:
124   case Intrinsic::x86_tileloaddt164_internal:
125   case Intrinsic::x86_tilestored64_internal: {
126     Row = II->getArgOperand(0);
127     Col = II->getArgOperand(1);
128     break;
129   }
  // a * b + c
  // The shape depends on which operand is being queried.
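  // For example, for tdpbssd(m, n, k, C, A, B), which computes C += A * B:
  //   operand 3 (C) uses Row = m (arg 0) and Col = n (arg 1),
  //   operand 4 (A) uses Row = m (arg 0) and Col = k (arg 2),
  //   operand 5 (B) uses Row = k / 4 (arg 2 / 4) and Col = n (arg 1).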
132   case Intrinsic::x86_tcmmimfp16ps_internal:
133   case Intrinsic::x86_tcmmrlfp16ps_internal:
134   case Intrinsic::x86_tdpbssd_internal:
135   case Intrinsic::x86_tdpbsud_internal:
136   case Intrinsic::x86_tdpbusd_internal:
137   case Intrinsic::x86_tdpbuud_internal:
138   case Intrinsic::x86_tdpbf16ps_internal:
139   case Intrinsic::x86_tdpfp16ps_internal: {
140     switch (OpNo) {
141     case 3:
142       Row = II->getArgOperand(0);
143       Col = II->getArgOperand(1);
144       break;
145     case 4:
146       Row = II->getArgOperand(0);
147       Col = II->getArgOperand(2);
148       break;
149     case 5:
150       if (isa<ConstantInt>(II->getArgOperand(2)))
151         Row = Builder.getInt16(
152             (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
153       else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a constant value and not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118 and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //   i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        //   %117).
        // If we created %row = udiv i16 %106, 4 before %118 (aka. II), its
        // definition would be after its user (the new tileload for %117).
        // So the best choice is to create %row right after the definition of
        // %106.
166         Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
167         Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
168         cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
169       } else {
        // When it is not a constant value but is a function argument, we
        // create Row in the entry BB.
172         IRBuilder<> NewBuilder(
173             getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
174         Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
175       }
176       Col = II->getArgOperand(1);
177       break;
178     }
179     break;
180   }
181   }
182 
183   return std::make_pair(Row, Col);
184 }
185 
186 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
187   Use &U = *(Phi->use_begin());
188   unsigned OpNo = U.getOperandNo();
189   User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we just
  // follow the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero will be abandoned.
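  // For example (hypothetical IR):
  //   %r = phi <256 x i32> [ zeroinitializer, %bb0 ], [ %v, %bb1 ]
  //   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %r)
  //   %d = call x86_amx @llvm.x86.tdpbssd.internal(..., x86_amx %t, ...)
  // Starting from the first use of the phi (the cast), we walk to %d and
  // query its shape for the operand position that %t occupies.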
194   while (V) {
195     if (isAMXCast(dyn_cast<Instruction>(V))) {
196       if (V->use_empty())
197         break;
198       Use &U = *(V->use_begin());
199       OpNo = U.getOperandNo();
200       V = U.getUser();
201     } else if (isAMXIntrinsic(V)) {
202       return getShape(cast<IntrinsicInst>(V), OpNo);
203     } else if (isa<PHINode>(V)) {
204       if (V->use_empty())
205         break;
206       Use &U = *(V->use_begin());
207       V = U.getUser();
208     } else {
209       break;
210     }
211   }
212 
213   return std::make_pair(nullptr, nullptr);
214 }
215 
216 namespace {
217 class X86LowerAMXType {
218   Function &Func;
219 
  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
223   std::map<Value *, Value *> Col2Row;
224 
225 public:
226   X86LowerAMXType(Function &F) : Func(F) {}
227   bool visit();
228   void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
229   void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
230   bool transformBitcast(BitCastInst *Bitcast);
231 };
232 
233 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
234 // %2 = bitcast <256 x i32> %src to x86_amx
235 // -->
236 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
237 // i8* %addr, i64 %stride64)
238 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
239   Value *Row = nullptr, *Col = nullptr;
240   Use &U = *(Bitcast->use_begin());
241   unsigned OpNo = U.getOperandNo();
242   auto *II = cast<IntrinsicInst>(U.getUser());
243   std::tie(Row, Col) = getShape(II, OpNo);
244   IRBuilder<> Builder(Bitcast);
  // Use the maximum column as the stride.
246   Value *Stride = Builder.getInt64(64);
247   Value *I8Ptr = LD->getOperand(0);
248   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
249 
250   Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
251                                            std::nullopt, Args);
252   Bitcast->replaceAllUsesWith(NewInst);
253 }
254 
255 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
256 //                                                    %stride);
257 // %13 = bitcast x86_amx %src to <256 x i32>
258 // store <256 x i32> %13, <256 x i32>* %addr, align 64
259 // -->
260 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
261 //                                           %stride64, %13)
262 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
263 
264   Value *Tile = Bitcast->getOperand(0);
265   auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
268   Value *Row = II->getOperand(0);
269   Value *Col = II->getOperand(1);
270   IRBuilder<> Builder(ST);
  // Use the maximum column as the stride. It must be the same as the
  // load stride.
273   Value *Stride = Builder.getInt64(64);
274   Value *I8Ptr = ST->getOperand(1);
275   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
276   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
277                           Args);
278   if (Bitcast->hasOneUse())
279     return;
280   // %13 = bitcast x86_amx %src to <256 x i32>
281   // store <256 x i32> %13, <256 x i32>* %addr, align 64
282   // %add = <256 x i32> %13, <256 x i32> %src2
283   // -->
284   // %13 = bitcast x86_amx %src to <256 x i32>
285   // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
286   //                                           %stride64, %13)
287   // %14 = load <256 x i32>, %addr
288   // %add = <256 x i32> %14, <256 x i32> %src2
289   Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
290   Bitcast->replaceAllUsesWith(Vec);
291 }
292 
// Transform a bitcast into <store, load> instructions.
294 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
295   IRBuilder<> Builder(Bitcast);
296   AllocaInst *AllocaAddr;
297   Value *I8Ptr, *Stride;
298   auto *Src = Bitcast->getOperand(0);
299 
300   auto Prepare = [&](Type *MemTy) {
301     AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
302     I8Ptr = AllocaAddr;
303     Stride = Builder.getInt64(64);
304   };
305 
306   if (Bitcast->getType()->isX86_AMXTy()) {
307     // %2 = bitcast <256 x i32> %src to x86_amx
308     // -->
309     // %addr = alloca <256 x i32>, align 64
310     // store <256 x i32> %src, <256 x i32>* %addr, align 64
311     // %addr2 = bitcast <256 x i32>* to i8*
312     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
313     //                                                  i8* %addr2,
314     //                                                  i64 64)
315     Use &U = *(Bitcast->use_begin());
316     unsigned OpNo = U.getOperandNo();
317     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
318     if (!II)
319       return false; // May be bitcast from x86amx to <256 x i32>.
320     Prepare(Bitcast->getOperand(0)->getType());
321     Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
323     Value *Row = nullptr, *Col = nullptr;
324     std::tie(Row, Col) = getShape(II, OpNo);
325     std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
326     Value *NewInst = Builder.CreateIntrinsic(
327         Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
328     Bitcast->replaceAllUsesWith(NewInst);
329   } else {
330     // %2 = bitcast x86_amx %src to <256 x i32>
331     // -->
332     // %addr = alloca <256 x i32>, align 64
333     // %addr2 = bitcast <256 x i32>* to i8*
334     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
335     //                                           i8* %addr2, i64 %stride)
336     // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
337     auto *II = dyn_cast<IntrinsicInst>(Src);
338     if (!II)
339       return false; // May be bitcast from <256 x i32> to x86amx.
340     Prepare(Bitcast->getType());
341     Value *Row = II->getOperand(0);
342     Value *Col = II->getOperand(1);
343     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
344     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
345                             Args);
346     Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
347     Bitcast->replaceAllUsesWith(NewInst);
348   }
349 
350   return true;
351 }
352 
353 bool X86LowerAMXType::visit() {
354   SmallVector<Instruction *, 8> DeadInsts;
355   Col2Row.clear();
356 
357   for (BasicBlock *BB : post_order(&Func)) {
358     for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
359       auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
360       if (!Bitcast)
361         continue;
362 
363       Value *Src = Bitcast->getOperand(0);
364       if (Bitcast->getType()->isX86_AMXTy()) {
365         if (Bitcast->user_empty()) {
366           DeadInsts.push_back(Bitcast);
367           continue;
368         }
369         LoadInst *LD = dyn_cast<LoadInst>(Src);
370         if (!LD) {
371           if (transformBitcast(Bitcast))
372             DeadInsts.push_back(Bitcast);
373           continue;
374         }
        // If the load has multiple users, the vector load is kept and the
        // data is additionally loaded as a tile.
376         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
377         // %2 = bitcast <256 x i32> %src to x86_amx
378         // %add = add <256 x i32> %src, <256 x i32> %src2
379         // -->
380         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
381         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
382         //                                            i8* %addr, i64 %stride64)
383         // %add = add <256 x i32> %src, <256 x i32> %src2
384 
        // If the load has a single user, it will be eliminated in DAG ISel.
386         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
387         // %2 = bitcast <256 x i32> %src to x86_amx
388         // -->
389         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
390         //                                            i8* %addr, i64 %stride64)
391         combineLoadBitcast(LD, Bitcast);
392         DeadInsts.push_back(Bitcast);
393         if (LD->hasOneUse())
394           DeadInsts.push_back(LD);
395       } else if (Src->getType()->isX86_AMXTy()) {
396         if (Bitcast->user_empty()) {
397           DeadInsts.push_back(Bitcast);
398           continue;
399         }
400         StoreInst *ST = nullptr;
401         for (Use &U : Bitcast->uses()) {
402           ST = dyn_cast<StoreInst>(U.getUser());
403           if (ST)
404             break;
405         }
406         if (!ST) {
407           if (transformBitcast(Bitcast))
408             DeadInsts.push_back(Bitcast);
409           continue;
410         }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an AMX store.
412         // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
413         //                                                    %stride);
414         // %13 = bitcast x86_amx %src to <256 x i32>
415         // store <256 x i32> %13, <256 x i32>* %addr, align 64
416         // -->
417         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
418         //                                           %stride64, %13)
419         //
        // If the bitcast (%13) has multiple uses, transform as below.
421         // %13 = bitcast x86_amx %src to <256 x i32>
422         // store <256 x i32> %13, <256 x i32>* %addr, align 64
423         // %add = <256 x i32> %13, <256 x i32> %src2
424         // -->
425         // %13 = bitcast x86_amx %src to <256 x i32>
426         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
427         //                                           %stride64, %13)
428         // %14 = load <256 x i32>, %addr
429         // %add = <256 x i32> %14, <256 x i32> %src2
430         //
431         combineBitcastStore(Bitcast, ST);
432         // Delete user first.
433         DeadInsts.push_back(ST);
434         DeadInsts.push_back(Bitcast);
435       }
436     }
437   }
438 
439   bool C = !DeadInsts.empty();
440 
441   for (auto *Inst : DeadInsts)
442     Inst->eraseFromParent();
443 
444   return C;
445 }
446 } // anonymous namespace
447 
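// Allocate a <256 x i32> stack slot in the entry block and return a pointer
// to it. This is the memory used to make tile data "volatile": tiles are
// stored to it right after their def and reloaded right before each use.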
448 static Value *getAllocaPos(BasicBlock *BB) {
449   Module *M = BB->getModule();
450   Function *F = BB->getParent();
451   IRBuilder<> Builder(&F->getEntryBlock().front());
452   const DataLayout &DL = M->getDataLayout();
453   unsigned AllocaAS = DL.getAllocaAddrSpace();
454   Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
455   AllocaInst *AllocaRes =
456       new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
457   BasicBlock::iterator Iter = AllocaRes->getIterator();
458   ++Iter;
459   Builder.SetInsertPoint(&*Iter);
460   Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
461   return I8Ptr;
462 }
463 
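// Emit a tilestore of TileDef to Ptr immediately after the definition,
// reusing the row/col operands of the defining AMX intrinsic as the shape.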
464 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
465   assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
466   auto *II = cast<IntrinsicInst>(TileDef);
467   assert(II && "Not tile intrinsic!");
468   Value *Row = II->getOperand(0);
469   Value *Col = II->getOperand(1);
470 
471   BasicBlock *BB = TileDef->getParent();
472   BasicBlock::iterator Iter = TileDef->getIterator();
473   IRBuilder<> Builder(BB, ++Iter);
474   Value *Stride = Builder.getInt64(64);
475   std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
476 
477   Instruction *TileStore = Builder.CreateIntrinsic(
478       Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
479   return TileStore;
480 }
481 
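// Replace the use U of an AMX tile value with a tileload from Ptr, inserted
// right before the user. The shape is taken from the defining intrinsic (or,
// for a PHI, from its first incoming value).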
482 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
483   Value *V = U.get();
484   assert(V->getType()->isX86_AMXTy() && "Not define tile!");
485 
486   // Get tile shape.
487   IntrinsicInst *II = nullptr;
488   if (IsPHI) {
489     Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
490     II = cast<IntrinsicInst>(PhiOp);
491   } else {
492     II = cast<IntrinsicInst>(V);
493   }
494   Value *Row = II->getOperand(0);
495   Value *Col = II->getOperand(1);
496 
497   Instruction *UserI = cast<Instruction>(U.getUser());
498   IRBuilder<> Builder(UserI);
499   Value *Stride = Builder.getInt64(64);
500   std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
501 
502   Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
503                                             std::nullopt, Args);
504   UserI->replaceUsesOfWith(V, TileLoad);
505 }
506 
507 static bool isIncomingOfPHI(Instruction *I) {
508   for (Use &U : I->uses()) {
509     User *V = U.getUser();
510     if (isa<PHINode>(V))
511       return true;
512   }
513   return false;
514 }
515 
// Let all AMX tile data become volatile data to shorten the live range
// of each tile register before fast register allocation.
518 namespace {
519 class X86VolatileTileData {
520   Function &F;
521 
522 public:
523   X86VolatileTileData(Function &Func) : F(Func) {}
524   Value *updatePhiIncomings(BasicBlock *BB,
525                             SmallVector<Instruction *, 2> &Incomings);
526   void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
527   bool volatileTileData();
528   void volatileTilePHI(PHINode *PHI);
529   void volatileTileNonPHI(Instruction *I);
530 };
531 
532 Value *X86VolatileTileData::updatePhiIncomings(
533     BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
534   Value *I8Ptr = getAllocaPos(BB);
535 
536   for (auto *I : Incomings) {
537     User *Store = createTileStore(I, I8Ptr);
538 
539     // All its uses (except phi) should load from stored mem.
540     for (Use &U : I->uses()) {
541       User *V = U.getUser();
542       if (isa<PHINode>(V) || V == Store)
543         continue;
544       replaceWithTileLoad(U, I8Ptr);
545     }
546   }
547   return I8Ptr;
548 }
549 
550 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
551                                                 Value *StorePtr) {
552   for (Use &U : PHI->uses())
553     replaceWithTileLoad(U, StorePtr, true);
554   PHI->eraseFromParent();
555 }
556 
// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) PHI incoming values should be tilestored just after their defs.
// 3) The memory of these tileloads and tilestores should be the same.
562 // e.g.
563 // ------------------------------------------------------
564 // bb_dom:
565 //   ...
566 //   br i1 %bool.cond, label %if.else, label %if.then
567 //
568 // if.then:
569 //   def %t0 = ...
570 //   ...
571 //   use %t0
572 //   ...
573 //   br label %if.end
574 //
575 // if.else:
576 //   def %t1 = ...
577 //   br label %if.end
578 //
579 // if.end:
580 //   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
581 //   ...
582 //   use %td
583 // ------------------------------------------------------
584 // -->
585 // ------------------------------------------------------
586 // bb_entry:
587 //   %mem = alloca <256 x i32>, align 1024                  *
588 //   ...
589 // bb_dom:
590 //   ...
591 //   br i1 %bool.cond, label %if.else, label %if.then
592 //
593 // if.then:
594 //   def %t0 = ...
595 //   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
596 //   ...
597 //   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
598 //   use %t0`                                               *
599 //   ...
600 //   br label %if.end
601 //
602 // if.else:
603 //   def %t1 = ...
604 //   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
605 //   br label %if.end
606 //
607 // if.end:
608 //   ...
609 //   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
610 //   use %td
611 // ------------------------------------------------------
612 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
613   BasicBlock *BB = PHI->getParent();
614   SmallVector<Instruction *, 2> Incomings;
615 
616   for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
617     Value *Op = PHI->getIncomingValue(I);
618     Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
620     Incomings.push_back(Inst);
621   }
622 
623   Value *StorePtr = updatePhiIncomings(BB, Incomings);
624   replacePhiDefWithLoad(PHI, StorePtr);
625 }
626 
// Store the defined tile and load it before each use.
// None of its users are PHI nodes.
629 // e.g.
630 // ------------------------------------------------------
631 // def %td = ...
632 // ...
633 // "use %td"
634 // ------------------------------------------------------
635 // -->
636 // ------------------------------------------------------
637 // def %td = ...
638 // call void @llvm.x86.tilestored64.internal(mem, %td)
639 // ...
640 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
641 // "use %td2"
642 // ------------------------------------------------------
643 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
644   BasicBlock *BB = I->getParent();
645   Value *I8Ptr = getAllocaPos(BB);
646   User *Store = createTileStore(I, I8Ptr);
647 
648   // All its uses should load from stored mem.
649   for (Use &U : I->uses()) {
650     User *V = U.getUser();
651     assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
652     if (V != Store)
653       replaceWithTileLoad(U, I8Ptr);
654   }
655 }
656 
657 // Volatile Tile Model:
// 1) All uses of tile data come from tileloads just in time.
// 2) All defs of tile data are tilestored into memory immediately.
660 // For example:
661 // --------------------------------------------------------------------------
662 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
663 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
664 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
665 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
666 // call void @llvm.x86.tilestored64.internal(... td)                     area
667 // --------------------------------------------------------------------------
668 // 3) No terminator, call or other amx instructions in the key amx area.
669 bool X86VolatileTileData::volatileTileData() {
670   bool Changed = false;
671   for (BasicBlock &BB : F) {
672     SmallVector<Instruction *, 2> PHIInsts;
673     SmallVector<Instruction *, 8> AMXDefInsts;
674 
675     for (Instruction &I : BB) {
676       if (!I.getType()->isX86_AMXTy())
677         continue;
678       if (isa<PHINode>(&I))
679         PHIInsts.push_back(&I);
680       else
681         AMXDefInsts.push_back(&I);
682     }
683 
684     // First we "volatile" the non-phi related amx intrinsics.
685     for (Instruction *I : AMXDefInsts) {
686       if (isIncomingOfPHI(I))
687         continue;
688       volatileTileNonPHI(I);
689       Changed = true;
690     }
691 
692     for (Instruction *I : PHIInsts) {
693       volatileTilePHI(dyn_cast<PHINode>(I));
694       Changed = true;
695     }
696   }
697   return Changed;
698 }
699 
700 } // anonymous namespace
701 
702 namespace {
703 
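// Combine and lower the AMX cast intrinsics (x86_cast_vector_to_tile and
// x86_cast_tile_to_vector): casts are folded with adjacent loads/stores and
// PHI nodes where possible, and any remaining casts are lowered through a
// stack slot by transformAllAMXCast().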
704 class X86LowerAMXCast {
705   Function &Func;
706   std::unique_ptr<DominatorTree> DT;
707 
708 public:
709   X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
710   bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
711   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
712   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
713   bool combineAMXcast(TargetLibraryInfo *TLI);
714   bool transformAMXCast(IntrinsicInst *AMXCast);
715   bool transformAllAMXCast();
716   bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
717                               SmallSetVector<Instruction *, 16> &DeadInst);
718 };
719 
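// Erase I if it is trivially dead. Its operands are nulled out first; any
// operand that becomes trivially dead as a result is added to WorkList so it
// can be removed in a later iteration.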
720 static bool DCEInstruction(Instruction *I,
721                            SmallSetVector<Instruction *, 16> &WorkList,
722                            const TargetLibraryInfo *TLI) {
723   if (isInstructionTriviallyDead(I, TLI)) {
724     salvageDebugInfo(*I);
725     salvageKnowledge(I);
726 
727     // Null out all of the instruction's operands to see if any operand becomes
728     // dead as we go.
729     for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
730       Value *OpV = I->getOperand(i);
731       I->setOperand(i, nullptr);
732 
733       if (!OpV->use_empty() || I == OpV)
734         continue;
735 
736       // If the operand is an instruction that became dead as we nulled out the
737       // operand, and if it is 'trivially' dead, delete it in a future loop
738       // iteration.
739       if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
740         if (isInstructionTriviallyDead(OpI, TLI)) {
741           WorkList.insert(OpI);
742         }
743       }
744     }
745     I->eraseFromParent();
746     return true;
747   }
748   return false;
749 }
750 
751 /// This function handles following case
752 ///
753 ///     A  ->  B    amxcast
754 ///     PHI
755 ///     B  ->  A    amxcast
756 ///
757 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
758 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
759 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
760     IntrinsicInst *CI, PHINode *PN,
761     SmallSetVector<Instruction *, 16> &DeadInst) {
762   IRBuilder<> Builder(CI);
763   Value *Src = CI->getOperand(0);
764   Type *SrcTy = Src->getType(); // Type B
765   Type *DestTy = CI->getType(); // Type A
766 
767   SmallVector<PHINode *, 4> PhiWorklist;
768   SmallSetVector<PHINode *, 4> OldPhiNodes;
769 
770   // Find all of the A->B casts and PHI nodes.
771   // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
772   // OldPhiNodes is used to track all known PHI nodes, before adding a new
773   // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
774   PhiWorklist.push_back(PN);
775   OldPhiNodes.insert(PN);
776   while (!PhiWorklist.empty()) {
777     auto *OldPN = PhiWorklist.pop_back_val();
778     for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
779       Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we ignore cases where it is a constant. In the future,
      // we might support constants.
782       if (isa<Constant>(IncValue)) {
783         auto *IncConst = dyn_cast<Constant>(IncValue);
784         if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
785           return false;
786         Value *Row = nullptr, *Col = nullptr;
787         std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
790         if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
791           return false;
792         // Create tilezero at the end of incoming block.
793         auto *Block = OldPN->getIncomingBlock(I);
794         BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
795         Instruction *NewInst = Builder.CreateIntrinsic(
796             Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
797         NewInst->moveBefore(&*Iter);
798         NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
799                                           {IncValue->getType()}, {NewInst});
800         NewInst->moveBefore(&*Iter);
801         // Replace InValue with new Value.
802         OldPN->setIncomingValue(I, NewInst);
803         IncValue = NewInst;
804       }
805 
806       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
807         if (OldPhiNodes.insert(PNode))
808           PhiWorklist.push_back(PNode);
809         continue;
810       }
811       Instruction *ACI = dyn_cast<Instruction>(IncValue);
812       if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
814         Type *TyA = ACI->getOperand(0)->getType();
815         Type *TyB = ACI->getType();
816         if (TyA != DestTy || TyB != SrcTy)
817           return false;
818         continue;
819       }
820       return false;
821     }
822   }
823 
824   // Check that each user of each old PHI node is something that we can
825   // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
826   for (auto *OldPN : OldPhiNodes) {
827     for (User *V : OldPN->users()) {
828       Instruction *ACI = dyn_cast<Instruction>(V);
829       if (ACI && isAMXCast(ACI)) {
830         // Verify it's a B->A cast.
831         Type *TyB = ACI->getOperand(0)->getType();
832         Type *TyA = ACI->getType();
833         if (TyA != DestTy || TyB != SrcTy)
834           return false;
835       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
836         // As long as the user is another old PHI node, then even if we don't
837         // rewrite it, the PHI web we're considering won't have any users
838         // outside itself, so it'll be dead.
839         // example:
840         //   bb.0:
841         //      %0 = amxcast ...
842         //   bb.1:
843         //      %1 = amxcast ...
844         //   bb.2:
845         //      %goodphi = phi %0, %1
846         //      %3 = amxcast %goodphi
847         //   bb.3:
848         //      %goodphi2 = phi %0, %goodphi
849         //      %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
854         if (OldPhiNodes.count(PHI) == 0)
855           return false;
856       } else
857         return false;
858     }
859   }
860 
  // For each old PHI node, create a corresponding new PHI node with type A.
862   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
863   for (auto *OldPN : OldPhiNodes) {
864     Builder.SetInsertPoint(OldPN);
865     PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
866     NewPNodes[OldPN] = NewPN;
867   }
868 
869   // Fill in the operands of new PHI nodes.
870   for (auto *OldPN : OldPhiNodes) {
871     PHINode *NewPN = NewPNodes[OldPN];
872     for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
873       Value *V = OldPN->getOperand(j);
874       Value *NewV = nullptr;
875       Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXCast from a constant.
877       if (ACI && isAMXCast(ACI))
878         NewV = ACI->getOperand(0);
879       else if (auto *PrevPN = dyn_cast<PHINode>(V))
880         NewV = NewPNodes[PrevPN];
881       assert(NewV);
882       NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
883     }
884   }
885 
  // Traverse all accumulated PHI nodes and process their users,
  // which are stores and bitcasts. Without this processing,
  // new PHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
890   // If there is a store with type B, change it to type A.
891 
892   // Replace users of BitCast B->A with NewPHI. These will help
893   // later to get rid of a closure formed by OldPHI nodes.
894   for (auto *OldPN : OldPhiNodes) {
895     PHINode *NewPN = NewPNodes[OldPN];
896     for (User *V : make_early_inc_range(OldPN->users())) {
897       Instruction *ACI = dyn_cast<Instruction>(V);
898       if (ACI && isAMXCast(ACI)) {
899         Type *TyB = ACI->getOperand(0)->getType();
900         Type *TyA = ACI->getType();
901         assert(TyA == DestTy && TyB == SrcTy);
902         (void)TyA;
903         (void)TyB;
904         ACI->replaceAllUsesWith(NewPN);
905         DeadInst.insert(ACI);
906       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN: DCE can safely delete rootPN's operands if
        // rootPN is dead.
909         assert(OldPhiNodes.contains(PHI));
910         (void)PHI;
911       } else
912         llvm_unreachable("all uses should be handled");
913     }
914   }
915   return true;
916 }
917 
918 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
919 // store <256 x i32> %43, <256 x i32>* %p, align 64
920 // -->
921 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
922 //                                           i64 64, x86_amx %42)
923 bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
924   Value *Tile = Cast->getOperand(0);
925   // TODO: If it is cast intrinsic or phi node, we can propagate the
926   // shape information through def-use chain.
927   if (!isAMXIntrinsic(Tile))
928     return false;
929   auto *II = cast<IntrinsicInst>(Tile);
930   // Tile is output from AMX intrinsic. The first operand of the
931   // intrinsic is row, the second operand of the intrinsic is column.
932   Value *Row = II->getOperand(0);
933   Value *Col = II->getOperand(1);
934   IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
936   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
937   Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
938   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
939   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
940                           Args);
941   return true;
942 }
943 
944 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
945 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
946 // -->
947 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
948 //                                                   i8* %p, i64 64)
949 bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
950   bool EraseLoad = true;
951   Value *Row = nullptr, *Col = nullptr;
952   Use &U = *(Cast->use_begin());
953   unsigned OpNo = U.getOperandNo();
954   auto *II = cast<IntrinsicInst>(U.getUser());
955   // TODO: If it is cast intrinsic or phi node, we can propagate the
956   // shape information through def-use chain.
957   if (!isAMXIntrinsic(II))
958     return false;
959   std::tie(Row, Col) = getShape(II, OpNo);
960   IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
962   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
963   Value *I8Ptr;
964 
  // To save compile time, we create the dominator tree only when it is really
  // needed.
967   if (!DT)
968     DT.reset(new DominatorTree(Func));
969   if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
971     auto *AllocaAddr =
972         createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
973     Builder.SetInsertPoint(&*std::next(LD->getIterator()));
974     Builder.CreateStore(LD, AllocaAddr);
975 
976     Builder.SetInsertPoint(Cast);
977     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
978     EraseLoad = false;
979   } else {
980     I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
981   }
982   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
983 
984   Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
985                                            std::nullopt, Args);
986   Cast->replaceAllUsesWith(NewInst);
987 
988   return EraseLoad;
989 }
990 
991 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
992   bool Change = false;
993   for (auto *Cast : Casts) {
994     auto *II = cast<IntrinsicInst>(Cast);
995     // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
996     // store <256 x i32> %43, <256 x i32>* %p, align 64
997     // -->
998     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
999     //                                           i64 64, x86_amx %42)
1000     if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1001       SmallVector<Instruction *, 2> DeadStores;
1002       for (User *U : Cast->users()) {
1003         StoreInst *Store = dyn_cast<StoreInst>(U);
1004         if (!Store)
1005           continue;
1006         if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1007           DeadStores.push_back(Store);
1008           Change = true;
1009         }
1010       }
1011       for (auto *Store : DeadStores)
1012         Store->eraseFromParent();
1013     } else { // x86_cast_vector_to_tile
1014       SmallVector<Instruction *, 2> DeadLoads;
1015       auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1016       if (!Load || !Load->hasOneUse())
1017         continue;
1018       // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1019       // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1020       // -->
1021       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1022       //                                                   i8* %p, i64 64)
1023       if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
1025         Cast->setOperand(0, nullptr);
1026         Load->eraseFromParent();
1027       }
1028     }
1029   }
1030   return Change;
1031 }
1032 
1033 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1034   bool Change = false;
  // Collect tile cast instructions.
1036   SmallVector<Instruction *, 8> Vec2TileInsts;
1037   SmallVector<Instruction *, 8> Tile2VecInsts;
1038   SmallVector<Instruction *, 8> PhiCastWorkList;
1039   SmallSetVector<Instruction *, 16> DeadInst;
1040   for (BasicBlock &BB : Func) {
1041     for (Instruction &I : BB) {
1042       Value *Vec;
1043       if (match(&I,
1044                 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1045         Vec2TileInsts.push_back(&I);
1046       else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1047                              m_Value(Vec))))
1048         Tile2VecInsts.push_back(&I);
1049     }
1050   }
1051 
1052   auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1053     for (auto *Inst : Insts) {
1054       for (User *U : Inst->users()) {
1055         IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1056         if (!II || II->getIntrinsicID() != IID)
1057           continue;
1058         // T1 = vec2tile V0
1059         // V2 = tile2vec T1
1060         // V3 = OP V2
1061         // -->
1062         // T1 = vec2tile V0
1063         // V2 = tile2vec T1
1064         // V3 = OP V0
1065         II->replaceAllUsesWith(Inst->getOperand(0));
1066         Change = true;
1067       }
1068     }
1069   };
1070 
1071   Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1072   Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1073 
1074   SmallVector<Instruction *, 8> LiveCasts;
1075   auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1076     for (auto *Inst : Insts) {
1077       if (Inst->use_empty()) {
1078         Inst->eraseFromParent();
1079         Change = true;
1080       } else {
1081         LiveCasts.push_back(Inst);
1082       }
1083     }
1084   };
1085 
1086   EraseInst(Vec2TileInsts);
1087   EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1089                        "Vec2Tile and Tile2Vec:\n";
1090              Func.dump());
1091   Change |= combineLdSt(LiveCasts);
1092   EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1094                        "AMXCast and load/store:\n";
1095              Func.dump());
1096 
1097   // Handle the A->B->A cast, and there is an intervening PHI node.
1098   for (BasicBlock &BB : Func) {
1099     for (Instruction &I : BB) {
1100       if (isAMXCast(&I)) {
1101         if (isa<PHINode>(I.getOperand(0)))
1102           PhiCastWorkList.push_back(&I);
1103       }
1104     }
1105   }
1106   for (auto *I : PhiCastWorkList) {
    // We skip dead AMXCasts.
1108     if (DeadInst.contains(I))
1109       continue;
1110     PHINode *PN = cast<PHINode>(I->getOperand(0));
1111     if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1112       DeadInst.insert(PN);
1113       Change = true;
1114     }
1115   }
1116 
1117   // Since we create new phi and merge AMXCast, some old phis and AMXCast might
1118   // have no uses. We do some DeadCodeElimination for them.
1119   while (!DeadInst.empty()) {
1120     Instruction *I = DeadInst.pop_back_val();
1121     Change |= DCEInstruction(I, DeadInst, TLI);
1122   }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1124                        "optimizeAMXCastFromPhi:\n";
1125              Func.dump());
1126   return Change;
1127 }
1128 
// There might be remaining AMXCasts after combineAMXcast, and they should be
// handled elegantly.
1131 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1132   IRBuilder<> Builder(AMXCast);
1133   AllocaInst *AllocaAddr;
1134   Value *I8Ptr, *Stride;
1135   auto *Src = AMXCast->getOperand(0);
1136 
1137   auto Prepare = [&](Type *MemTy) {
1138     AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1139     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1140     Stride = Builder.getInt64(64);
1141   };
1142 
1143   if (AMXCast->getType()->isX86_AMXTy()) {
1144     // %2 = amxcast <225 x i32> %src to x86_amx
1145     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1146     //                                           i8* %addr3, i64 60, x86_amx %2)
1147     // -->
1148     // %addr = alloca <225 x i32>, align 64
1149     // store <225 x i32> %src, <225 x i32>* %addr, align 64
1150     // %addr2 = bitcast <225 x i32>* %addr to i8*
1151     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1152     //                                                  i8* %addr2,
1153     //                                                  i64 60)
1154     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1155     //                                           i8* %addr3, i64 60, x86_amx %2)
1156     if (AMXCast->use_empty()) {
1157       AMXCast->eraseFromParent();
1158       return true;
1159     }
1160     Use &U = *(AMXCast->use_begin());
1161     unsigned OpNo = U.getOperandNo();
1162     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1163     if (!II)
1164       return false; // May be bitcast from x86amx to <256 x i32>.
1165     Prepare(AMXCast->getOperand(0)->getType());
1166     Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
1168     Value *Row = nullptr, *Col = nullptr;
1169     std::tie(Row, Col) = getShape(II, OpNo);
1170     std::array<Value *, 4> Args = {
1171         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1172     Value *NewInst = Builder.CreateIntrinsic(
1173         Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
1174     AMXCast->replaceAllUsesWith(NewInst);
1175     AMXCast->eraseFromParent();
1176   } else {
1177     // %2 = amxcast x86_amx %src to <225 x i32>
1178     // -->
1179     // %addr = alloca <225 x i32>, align 64
1180     // %addr2 = bitcast <225 x i32>* to i8*
1181     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1182     //                                           i8* %addr2, i64 %stride)
1183     // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1184     auto *II = dyn_cast<IntrinsicInst>(Src);
1185     if (!II)
1186       return false; // May be bitcast from <256 x i32> to x86amx.
1187     Prepare(AMXCast->getType());
1188     Value *Row = II->getOperand(0);
1189     Value *Col = II->getOperand(1);
1190     std::array<Value *, 5> Args = {
1191         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1192     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
1193                             Args);
1194     Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1195     AMXCast->replaceAllUsesWith(NewInst);
1196     AMXCast->eraseFromParent();
1197   }
1198 
1199   return true;
1200 }
1201 
1202 bool X86LowerAMXCast::transformAllAMXCast() {
1203   bool Change = false;
  // Collect tile cast instructions.
1205   SmallVector<Instruction *, 8> WorkLists;
1206   for (BasicBlock &BB : Func) {
1207     for (Instruction &I : BB) {
1208       if (isAMXCast(&I))
1209         WorkLists.push_back(&I);
1210     }
1211   }
1212 
1213   for (auto *Inst : WorkLists) {
1214     Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1215   }
1216 
1217   return Change;
1218 }
1219 
1220 } // anonymous namespace
1221 
1222 namespace {
1223 
1224 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1225 public:
1226   static char ID;
1227 
1228   X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1229     initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1230   }
1231 
1232   bool runOnFunction(Function &F) override {
1233     bool C = false;
1234     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1235     TargetLibraryInfo *TLI =
1236         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1237 
1238     X86LowerAMXCast LAC(F);
1239     C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXCasts after combineAMXcast, and they should
    // be handled elegantly.
1242     C |= LAC.transformAllAMXCast();
1243 
1244     X86LowerAMXType LAT(F);
1245     C |= LAT.visit();
1246 
1247     // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code, not
    // just by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1250     if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end does (e.g.
      // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
      // sure the AMX data is volatile; that is necessary for AMX fast
      // register allocation.
1255       if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1256         X86VolatileTileData VTD(F);
1257         C = VTD.volatileTileData() || C;
1258       }
1259     }
1260 
1261     return C;
1262   }
1263 
1264   void getAnalysisUsage(AnalysisUsage &AU) const override {
1265     AU.setPreservesCFG();
1266     AU.addRequired<TargetPassConfig>();
1267     AU.addRequired<TargetLibraryInfoWrapperPass>();
1268   }
1269 };
1270 
1271 } // anonymous namespace
1272 
1273 static const char PassName[] = "Lower AMX type for load/store";
1274 char X86LowerAMXTypeLegacyPass::ID = 0;
1275 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1276                       false)
1277 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1278 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1279 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1280                     false)
1281 
1282 FunctionPass *llvm::createX86LowerAMXTypePass() {
1283   return new X86LowerAMXTypeLegacyPass();
1284 }
1285