1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform amx intrinsics to scalar operations.
10 /// This pass is always enabled and it skips when it is not -O0 and has no
11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx
12 /// intrinsics is near the amx intrinsics code. We are not able to find a
13 /// point which post-dominate all the shape and dominate all amx intrinsics.
14 /// To decouple the dependency of the shape, we transform amx intrinsics
15 /// to scalar operation, so that compiling doesn't fail. In long term, we
16 /// should improve fast register allocation to allocate amx register.
17 //===----------------------------------------------------------------------===//
18 //
19 #include "X86.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/Analysis/DomTreeUpdater.h"
23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27 #include "llvm/CodeGen/ValueTypes.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/IntrinsicsX86.h"
34 #include "llvm/IR/PatternMatch.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Target/TargetMachine.h"
39 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
40 #include "llvm/Transforms/Utils/LoopUtils.h"
41
42 using namespace llvm;
43 using namespace PatternMatch;
44
45 #define DEBUG_TYPE "lower-amx-intrinsics"
46
47 #ifndef NDEBUG
isV256I32Ty(Type * Ty)48 static bool isV256I32Ty(Type *Ty) {
49 if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
50 return FVT->getNumElements() == 256 &&
51 FVT->getElementType()->isIntegerTy(32);
52 return false;
53 }
54 #endif
55
56 static cl::opt<bool>
57 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
58 cl::desc("X86: enable AMX scalarizition."));
59
60 namespace {
/// Rewrites AMX intrinsics in a function into scalar loops over
/// <256 x i32> vectors so no AMX tile registers are required.
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  // Walk Func and lower every supported AMX intrinsic; returns true if the
  // function was changed.
  bool visit();

private:
  DomTreeUpdater &DTU; // Keeps the dominator tree in sync with new blocks.
  LoopInfo *LI;        // May be null; updated with loops this pass creates.
  // Create a header/body/latch loop iterating an i16 IV from 0 up to Bound
  // by Step, spliced between Preheader and Exit; returns the body block.
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  // Emit the row/col loop nest for tileload/tilestore between Start and End.
  // For a load, returns the <256 x i32> result vector; for a store, nullptr.
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  // Emit the row/col/inner loop nest computing one tdp* dot-product update;
  // only instantiable for the supported dot-product intrinsic IDs.
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbsud_internal ||
                              IntrID == Intrinsic::x86_tdpbusd_internal ||
                              IntrID == Intrinsic::x86_tdpbuud_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          Value *>::type
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  // Lower one tileloadd64/tilestored64 intrinsic in place.
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  // Lower one tdp* dot-product intrinsic in place.
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbsud_internal ||
                              IntrID == Intrinsic::x86_tdpbusd_internal ||
                              IntrID == Intrinsic::x86_tdpbuud_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          bool>::type
  lowerTileDP(Instruction *TileDP);
  // Lower one tilezero intrinsic in place.
  bool lowerTileZero(Instruction *TileZero);
};
101 } // anonymous namespace
102
// Build a rotated loop (header -> body -> latch) whose i16 induction
// variable runs from 0 up to (but excluding) Bound in increments of Step.
// The loop is spliced in between Preheader and Exit: Preheader's first
// successor is redirected to the new header, and the latch either branches
// back to the header or falls through to Exit. The dominator tree (via DTU)
// and LoopInfo (if available) are kept up to date. Returns the body block,
// into which callers insert the per-iteration work.
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  // Unconditional header -> body and body -> latch edges; the latch's
  // conditional back-edge is created below once the IV increment exists.
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  // The IV PHI is deliberately the first instruction of the header; callers
  // rely on that to retrieve it via `&*Header->begin()`.
  PHINode *IV =
      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  // Retarget the preheader's branch from its old successor to the new
  // header, and mirror the CFG changes into the dominator tree.
  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
146
// Emit a row/column loop nest between Start and End that scalarizes a
// tileloadd64 (IsTileLoad == true) or tilestored64 (IsTileLoad == false).
// Row/Col are i16 trip counts, Ptr/Stride address memory in i32 units, and
// Tile (store only) must be a bitcast from a <256 x i32> vector. For a load
// the function returns the accumulated <256 x i32> result vector; for a
// store it returns nullptr.
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    // Pre-allocate the two-deep loop nest and nest it in any enclosing loop.
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  // createLoop guarantees the IV PHI is the first instruction of the header.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem (element offset in memory) and %idxvec (element index
  // in the <256 x i32> vector, assuming a fixed row stride of 16 dwords).
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %ptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    // Close both PHI cycles with the per-iteration result.
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    // Stores require the tile operand to be a bitcast from <256 x i32>; at
    // -O0 this is guaranteed by how the front end emits the intrinsics.
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <16 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
235
// Emit a three-deep loop nest (rows x cols x inner/K) between Start and End
// that scalarizes one tdp* dot-product intrinsic. Acc/LHS/RHS must each be
// bitcasts from <256 x i32> vectors; Row, Col and K are i16 trip counts
// already scaled to dword granularity by the caller. Returns the final
// <256 x i32> result vector (NewVecD), which holds only the elements the
// loop nest actually computed.
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbsud_internal ||
                            IntrID == Intrinsic::x86_tdpbusd_internal ||
                            IntrID == Intrinsic::x86_tdpbuud_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        Value *>::type
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  // Block-name prefix for readable IR; keyed off the intrinsic flavor.
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    // Pre-allocate the three-deep loop nest and attach it to any enclosing
    // loop of the original block.
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  // createLoop guarantees the IV PHI is the first instruction of the header.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  // Element indices into the flattened 16x16 dword tiles.
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // Integer flavors: each i32 element is 4 packed i8 values. Multiply with
    // the flavor-specific signedness and accumulate via a v4i32 add-reduce.
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    // s/u in the intrinsic name encodes (LHS, RHS) signedness in that order.
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // BF16 flavor: each i32 element is 2 packed bf16 values. Widen each pair
    // to <2 x float> by shuffling the bf16 bits into the high half of an i32
    // lane, then fmul and fadd-reduce into the f32 accumulator.
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
    // i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = makeArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  // Close all PHI cycles now that the per-iteration results exist.
  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}
469
// Lower one tdp* intrinsic: split the block at the call, emit the scalar
// loop nest in between, rewrite all users, and erase the call. N and K are
// byte counts and are divided by 4 here because the loops iterate at dword
// granularity. Always returns true (the function was changed).
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbsud_internal ||
                            IntrID == Intrinsic::x86_tdpbusd_internal ||
                            IntrID == Intrinsic::x86_tdpbuud_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        bool>::type
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // we cannot assume there always be bitcast after tiledpbssd. So we need to
  // insert one bitcast as required
  Builder.SetInsertPoint(End->getFirstNonPHI());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete TileDP intrinsic and do some clean-up.
  // Users that are bitcasts back to <256 x i32> can take the vector result
  // directly; everything else goes through the x86amx bitcast inserted above.
  // Post-increment inside the loop keeps the iterator valid across erasure.
  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
    Instruction *I = cast<Instruction>((UI++)->getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}
513
// Lower one tileloadd64 (IsTileLoad == true) or tilestored64 intrinsic:
// split the block at the call, emit the row/col scalar loops in between,
// rewrite users (load only), and erase the call. N and Stride are byte
// quantities, shifted right by 2 because the loops work in dwords. Always
// returns true (the function was changed).
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // we cannot assume there always be bitcast after tileload. So we need to
    // insert one bitcast as required
    Builder.SetInsertPoint(End->getFirstNonPHI());
    Value *ResAMX =
        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete the tileloadd64 intrinsic and do some clean-up. Users that are
    // bitcasts back to <256 x i32> take the vector result directly; the rest
    // go through the x86amx bitcast inserted above. Post-increment keeps the
    // use iterator valid while users are erased.
    for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
         UI != UE;) {
      Instruction *I = cast<Instruction>((UI++)->getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}
559
lowerTileZero(Instruction * TileZero)560 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
561 IRBuilder<> Builder(TileZero);
562 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
563 Value *VecZero = Constant::getNullValue(V256I32Ty);
564 for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
565 Instruction *I = cast<Instruction>((UI++)->getUser());
566 Value *Vec;
567 if (match(I, m_BitCast(m_Value(Vec)))) {
568 I->replaceAllUsesWith(VecZero);
569 I->eraseFromParent();
570 }
571 }
572 TileZero->eraseFromParent();
573 return true;
574 }
575
visit()576 bool X86LowerAMXIntrinsics::visit() {
577 bool C = false;
578 SmallVector<IntrinsicInst *, 8> WorkList;
579 for (BasicBlock *BB : depth_first(&Func)) {
580 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
581 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
582 switch (Inst->getIntrinsicID()) {
583 case Intrinsic::x86_tdpbssd_internal:
584 case Intrinsic::x86_tdpbsud_internal:
585 case Intrinsic::x86_tdpbusd_internal:
586 case Intrinsic::x86_tdpbuud_internal:
587 case Intrinsic::x86_tileloadd64_internal:
588 case Intrinsic::x86_tilestored64_internal:
589 case Intrinsic::x86_tilezero_internal:
590 case Intrinsic::x86_tdpbf16ps_internal:
591 WorkList.push_back(Inst);
592 break;
593 default:
594 break;
595 }
596 }
597 }
598 }
599
600 for (auto *Inst : WorkList) {
601 switch (Inst->getIntrinsicID()) {
602 case Intrinsic::x86_tdpbssd_internal:
603 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
604 break;
605 case Intrinsic::x86_tdpbsud_internal:
606 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
607 break;
608 case Intrinsic::x86_tdpbusd_internal:
609 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
610 break;
611 case Intrinsic::x86_tdpbuud_internal:
612 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
613 break;
614 case Intrinsic::x86_tdpbf16ps_internal:
615 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
616 break;
617 case Intrinsic::x86_tileloadd64_internal:
618 C = lowerTileLoadStore<true>(Inst) || C;
619 break;
620 case Intrinsic::x86_tilestored64_internal:
621 C = lowerTileLoadStore<false>(Inst) || C;
622 break;
623 case Intrinsic::x86_tilezero_internal:
624 C = lowerTileZero(Inst) || C;
625 break;
626 default:
627 llvm_unreachable("invalid amx intrinsics!");
628 }
629 }
630
631 return C;
632 }
633
/// Legacy pass-manager wrapper. Gated behind the -enable-x86-scalar-amx
/// flag and only active for -O0 builds or optnone functions, where fast
/// regalloc cannot handle AMX tile registers (see the file header).
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (!X86ScalarizeAMX)
      return false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    // Skip functions that will be optimized: only lower when the function is
    // optnone or the whole compilation is at -O0.
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOpt::None)
      return false;

    // DT and LI are optional here; the lowering keeps them updated (via a
    // lazy DomTreeUpdater) when they happen to be available.
    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};
668
// Legacy pass registration.
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

// Factory used by the X86 target pipeline to instantiate this pass.
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
680