//===- X86LowerAMXType.cpp - ------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations are not
/// supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and only
/// AMX intrinsics can operate on the type, we need to transform load/store
/// <256 x i32> instructions into AMX load/store intrinsics. If the bitcast
/// cannot be combined with an adjacent load/store, we transform the bitcast
/// itself into an AMX load/store plus a <256 x i32> store/load.
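///
/// For example (illustrative IR; %row, %col and the value names are
/// placeholders):
///   %src = load <256 x i32>, <256 x i32>* %addr, align 64
///   %amx = bitcast <256 x i32> %src to x86_amx
/// becomes
///   %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                      i8* %addr, i64 64)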
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

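// Create a <256 x i32> alloca in the entry block of the function containing
// BB, aligned for x86_amx, to act as a stack slot when a tile value has to be
// passed through memory.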
static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

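// Return the (row, col) shape operands for the x86_amx value used at operand
// OpNo of the AMX intrinsic II.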
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect AMX intrinsic");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c, i.e. tdpbssd computes C(MxN) += A(MxK) * B(KxN), so the shape
  // depends on which operand of the intrinsic the x86_amx value is:
  //   operand 3 is the accumulator C with shape M (arg 0) x N (arg 1),
  //   operand 4 is A with shape M (arg 0) x K (arg 2),
  //   operand 5 is B with shape K (arg 2) x N (arg 1).
  case Intrinsic::x86_tdpbssd_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column size (64 bytes) as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column size (64 bytes) as the stride. It must match the
  // load stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // The bitcast has other users; let them read the value back from the
  // address we just stored to.
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = add <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = add <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast between <256 x i32> and x86_amx into a <store, load>
// pair through a stack slot.
static bool transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86_amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

namespace {
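// Rewrites all bitcasts between <256 x i32> and x86_amx in a single function,
// either by combining them with adjacent load/store instructions or by
// spilling through a stack slot.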
class X86LowerAMXType {
  Function &Func;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
};

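// Walk the function in post order and, for each bitcast involving x86_amx,
// either fold it into an adjacent vector load/store or lower it via
// transformBitcast. Returns true if any instruction was changed.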
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, keep the vector load; the bitcast
        // user becomes a tile load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, it will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col,
        //                                                    %addr, %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = add <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = add <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

namespace {

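// Legacy PassManager wrapper that runs X86LowerAMXType over each function.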
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    X86LowerAMXType LAT(F);
    bool C = LAT.visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

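// Factory for the legacy pass. The X86 target is expected to add this pass to
// its IR pipeline ahead of instruction selection, so ISel never sees bitcasts
// between <256 x i32> and x86_amx.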
FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}