//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store instructions.
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are not
/// supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and only
/// AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions into AMX load/store intrinsics. If the
/// bitcast cannot be combined with a load/store, we transform the bitcast into
/// an AMX load/store plus a <256 x i32> store/load.
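///
/// For example, when the bitcast feeds an AMX intrinsic, a load/bitcast pair
/// such as
///   %src = load <256 x i32>, <256 x i32>* %addr, align 64
///   %2 = bitcast <256 x i32> %src to x86_amx
/// is rewritten into a single tile load
///   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                    i8* %addr, i64 64)
/// as described in the comments on the individual transforms below.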
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

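// Create an alloca of type <256 x i32> in the entry block of the function
// containing BB, aligned to the preferred alignment of x86_amx. It serves as
// a stack slot to bridge a bitcast that cannot be folded into an AMX
// load/store.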
static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

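// Return the (row, col) shape of the tile that is used as operand OpNo of the
// AMX intrinsic II. The shape values are the intrinsic's own row/column
// arguments.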
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand the tile is: the accumulator c
  // (operand 3) is m x n, a (operand 4) is m x k, and b (operand 5) is k x n,
  // where arguments 0, 1 and 2 of the intrinsic are m, n and k respectively.
  case Intrinsic::x86_tdpbssd_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column size (64 bytes) as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column size (64 bytes) as the stride. It must match the
  // stride used for the load.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform the bitcast into a <store, load> pair of instructions.
static bool transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

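  // Materialize a <256 x i32> stack slot plus the i8* pointer and 64-byte
  // stride that the AMX load/store intrinsics expect.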
  auto Prepare = [&]() {
    AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // The user may be a bitcast from x86_amx back to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // The source may be a bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

namespace {
class X86LowerAMXType {
  Function &Func;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
};

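// Walk the basic blocks in post order and the instructions of each block
// bottom-up, rewriting bitcasts between <256 x i32> and x86_amx into AMX
// load/store intrinsics. Dead bitcasts, loads and stores are collected and
// erased at the end. Returns true if the function was changed.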
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, keep the vector load and emit a
        // separate AMX tile load from the same address.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, it will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has a single use, combine the bitcast and the
        // store into an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform it as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Push the user (the store) first so it is erased before the bitcast.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    X86LowerAMXType LAT(F);
    bool C = LAT.visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}