1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to transform AMX intrinsics to scalar operations.
/// The pass only does work when -enable-x86-scalar-amx is given, and it skips
/// any function that is compiled above -O0 and lacks the optnone attribute.
/// With -O0 or the optnone attribute, the shape definitions of AMX intrinsics
/// are placed near the AMX intrinsics themselves, so we are not able to find a
/// point that post-dominates all the shapes and dominates all AMX intrinsics.
/// To decouple the dependency on the shapes, we transform AMX intrinsics into
/// scalar operations so that compilation doesn't fail. In the long term, we
/// should improve fast register allocation to allocate AMX registers.
17 //===----------------------------------------------------------------------===//
18 //
19 #include "X86.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/Analysis/DomTreeUpdater.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/CodeGen/Passes.h"
27 #include "llvm/CodeGen/TargetPassConfig.h"
28 #include "llvm/CodeGen/ValueTypes.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/IntrinsicsX86.h"
35 #include "llvm/IR/PatternMatch.h"
36 #include "llvm/InitializePasses.h"
37 #include "llvm/Pass.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Target/TargetMachine.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41 #include "llvm/Transforms/Utils/LoopUtils.h"
42 
43 using namespace llvm;
44 using namespace PatternMatch;
45 
46 #define DEBUG_TYPE "lower-amx-intrinsics"
47 
48 #ifndef NDEBUG
49 static bool isV256I32Ty(Type *Ty) {
50   if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
51     return FVT->getNumElements() == 256 &&
52            FVT->getElementType()->isIntegerTy(32);
53   return false;
54 }
55 #endif
56 
57 static cl::opt<bool>
58     X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
59                     cl::desc("X86: enable AMX scalarizition."));
60 
61 namespace {
62 class X86LowerAMXIntrinsics {
63   Function &Func;
64 
65 public:
66   X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
67       : Func(F), DTU(DomTU), LI(LoopI) {}
68   bool visit();
69 
70 private:
71   DomTreeUpdater &DTU;
72   LoopInfo *LI;
73   BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
74                          Value *Step, StringRef Name, IRBuilderBase &B,
75                          Loop *L);
76   template <bool IsTileLoad>
77   Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
78                                   IRBuilderBase &B, Value *Row, Value *Col,
79                                   Value *Ptr, Value *Stride, Value *Tile);
80   template <Intrinsic::ID IntrID>
81   std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
82                        IntrID == Intrinsic::x86_tdpbsud_internal ||
83                        IntrID == Intrinsic::x86_tdpbusd_internal ||
84                        IntrID == Intrinsic::x86_tdpbuud_internal ||
85                        IntrID == Intrinsic::x86_tdpbf16ps_internal,
86                    Value *>
87   createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
88                     Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
89                     Value *RHS);
90   template <bool IsTileLoad>
91   bool lowerTileLoadStore(Instruction *TileLoadStore);
92   template <Intrinsic::ID IntrID>
93   std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
94                        IntrID == Intrinsic::x86_tdpbsud_internal ||
95                        IntrID == Intrinsic::x86_tdpbusd_internal ||
96                        IntrID == Intrinsic::x86_tdpbuud_internal ||
97                        IntrID == Intrinsic::x86_tdpbf16ps_internal,
98                    bool>
99   lowerTileDP(Instruction *TileDP);
100   bool lowerTileZero(Instruction *TileZero);
101 };
102 } // anonymous namespace
103 
104 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
105                                               BasicBlock *Exit, Value *Bound,
106                                               Value *Step, StringRef Name,
107                                               IRBuilderBase &B, Loop *L) {
108   LLVMContext &Ctx = Preheader->getContext();
109   BasicBlock *Header =
110       BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
111   BasicBlock *Body =
112       BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
113   BasicBlock *Latch =
114       BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
115 
116   Type *I16Ty = Type::getInt16Ty(Ctx);
117   BranchInst::Create(Body, Header);
118   BranchInst::Create(Latch, Body);
119   PHINode *IV =
120       PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
121   IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
122 
123   B.SetInsertPoint(Latch);
124   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
125   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
126   BranchInst::Create(Header, Exit, Cond, Latch);
127   IV->addIncoming(Inc, Latch);
128 
129   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
130   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
131   PreheaderBr->setSuccessor(0, Header);
132   DTU.applyUpdatesPermissive({
133       {DominatorTree::Delete, Preheader, Tmp},
134       {DominatorTree::Insert, Header, Body},
135       {DominatorTree::Insert, Body, Latch},
136       {DominatorTree::Insert, Latch, Header},
137       {DominatorTree::Insert, Latch, Exit},
138       {DominatorTree::Insert, Preheader, Header},
139   });
140   if (LI) {
141     L->addBasicBlockToLoop(Header, *LI);
142     L->addBasicBlockToLoop(Body, *LI);
143     L->addBasicBlockToLoop(Latch, *LI);
144   }
145   return Body;
146 }
147 
148 template <bool IsTileLoad>
149 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
150     BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
151     Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
152   std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
153   Loop *RowLoop = nullptr;
154   Loop *ColLoop = nullptr;
155   if (LI) {
156     RowLoop = LI->AllocateLoop();
157     ColLoop = LI->AllocateLoop();
158     RowLoop->addChildLoop(ColLoop);
159     if (Loop *ParentL = LI->getLoopFor(Start))
160       ParentL->addChildLoop(RowLoop);
161     else
162       LI->addTopLevelLoop(RowLoop);
163   }
164 
165   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
166                                    IntrinName + ".scalarize.rows", B, RowLoop);
167   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
168 
169   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
170                                    IntrinName + ".scalarize.cols", B, ColLoop);
171 
172   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
173   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
174   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
175   Value *CurrentRow = &*RowLoopHeader->begin();
176   Value *CurrentCol = &*ColLoopHeader->begin();
177   Type *EltTy = B.getInt32Ty();
178   FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
179 
180   // Common part for tileload and tilestore
181   // *.scalarize.cols.body:
182   // Calculate %idxmem and %idxvec
183   B.SetInsertPoint(ColBody->getTerminator());
184   Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
185   Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
186   Value *Offset =
187       B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
188   unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
189   Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
190   Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
191   Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
192   if (IsTileLoad) {
193     // tileload.scalarize.rows.header:
194     // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
195     // %tileload.scalarize.rows.latch ]
196     B.SetInsertPoint(RowLoopHeader->getTerminator());
197     Value *VecZero = Constant::getNullValue(V256I32Ty);
198     PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
199     VecCPhiRowLoop->addIncoming(VecZero, Start);
200 
201     // tileload.scalarize.cols.header:
202     // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
203     // ], [ %ResVec, %tileload.scalarize.cols.latch ]
204     B.SetInsertPoint(ColLoopHeader->getTerminator());
205     PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
206     VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
207 
208     // tileload.scalarize.cols.body:
209     // Calculate %idxmem and %idxvec
210     // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
211     // %elt = load i32, i32* %ptr
212     // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
213     B.SetInsertPoint(ColBody->getTerminator());
214     Value *Elt = B.CreateLoad(EltTy, EltPtr);
215     Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
216     VecPhi->addIncoming(ResVec, ColLoopLatch);
217     VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
218 
219     return ResVec;
220   } else {
221     auto *BitCast = cast<BitCastInst>(Tile);
222     Value *Vec = BitCast->getOperand(0);
223     assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
224     // tilestore.scalarize.cols.body:
225     // %mul = mul i16 %row.iv, i16 16
226     // %idx = add i16 %mul, i16 %col.iv
227     // %vec = extractelement <16 x i32> %vec, i16 %idx
228     // store i32 %vec, i32* %ptr
229     B.SetInsertPoint(ColBody->getTerminator());
230     Value *Elt = B.CreateExtractElement(Vec, Idx);
231 
232     B.CreateStore(Elt, EltPtr);
233     return nullptr;
234   }
235 }
236 
/// Emits the loop nest that scalarizes one tile dot-product intrinsic
/// (IntrID) between \p Start and \p End, and returns its result vector.
///
/// \p Row, \p Col and \p K are the i16 trip counts of the row, column and
/// inner loops. lowerTileDP passes the column and K shapes already shifted
/// right by 2. \p Acc, \p LHS and \p RHS must each be a bitcast from a
/// <256 x i32> vector (enforced by cast<> and the asserts below). Tile
/// element (r, c) is vector index r * 16 + c.
///
/// For each (row, col), the inner loop folds dword LHS[row][k] and dword
/// RHS[k][col] into the accumulator element C[row][col]:
///   - integer variants: as four i8 lanes, sign- or zero-extended as
///     selected by IntrID;
///   - tdpbf16ps: as two 16-bit lanes, each placed in the high half of a
///     float.
///
/// The returned <256 x i32> vector (NewVecD) holds the updated accumulator
/// element at every visited (row, col) and zero everywhere else.
237 template <Intrinsic::ID IntrID>
238 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
239                      IntrID == Intrinsic::x86_tdpbsud_internal ||
240                      IntrID == Intrinsic::x86_tdpbusd_internal ||
241                      IntrID == Intrinsic::x86_tdpbuud_internal ||
242                      IntrID == Intrinsic::x86_tdpbf16ps_internal,
243                  Value *>
244 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
245                                          IRBuilderBase &B, Value *Row,
246                                          Value *Col, Value *K, Value *Acc,
247                                          Value *LHS, Value *RHS) {
  // Prefix for the names of the generated blocks.
248   std::string IntrinName;
249   switch (IntrID) {
250   case Intrinsic::x86_tdpbssd_internal:
251     IntrinName = "tiledpbssd";
252     break;
253   case Intrinsic::x86_tdpbsud_internal:
254     IntrinName = "tiledpbsud";
255     break;
256   case Intrinsic::x86_tdpbusd_internal:
257     IntrinName = "tiledpbusd";
258     break;
259   case Intrinsic::x86_tdpbuud_internal:
260     IntrinName = "tiledpbuud";
261     break;
262   case Intrinsic::x86_tdpbf16ps_internal:
263     IntrinName = "tiledpbf16ps";
264     break;
265   }
  // Register the new loops (inner inside col inside row) under whatever loop,
  // if any, already contains Start.
266   Loop *RowLoop = nullptr;
267   Loop *ColLoop = nullptr;
268   Loop *InnerLoop = nullptr;
269   if (LI) {
270     RowLoop = LI->AllocateLoop();
271     ColLoop = LI->AllocateLoop();
272     InnerLoop = LI->AllocateLoop();
273     ColLoop->addChildLoop(InnerLoop);
274     RowLoop->addChildLoop(ColLoop);
275     if (Loop *ParentL = LI->getLoopFor(Start))
276       ParentL->addChildLoop(RowLoop);
277     else
278       LI->addTopLevelLoop(RowLoop);
279   }
280
281   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
282                                    IntrinName + ".scalarize.rows", B, RowLoop);
283   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
284
285   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
286                                    IntrinName + ".scalarize.cols", B, ColLoop);
287
288   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
289
290   B.SetInsertPoint(ColBody->getTerminator());
291   BasicBlock *InnerBody =
292       createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
293                  IntrinName + ".scalarize.inner", B, InnerLoop);
294
295   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
296   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
297   BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
298   BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  // createLoop makes the induction-variable PHI the first instruction of each
  // loop header.
299   Value *CurrentRow = &*RowLoopHeader->begin();
300   Value *CurrentCol = &*ColLoopHeader->begin();
301   Value *CurrentInner = &*InnerLoopHeader->begin();
302
303   FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
304   auto *BitCastAcc = cast<BitCastInst>(Acc);
305   Value *VecC = BitCastAcc->getOperand(0);
306   assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
307   // TODO else create BitCast from x86amx to v256i32.
308   // Store x86amx to memory, and reload from memory
309   // to vector. However with -O0, it doesn't happen.
310   auto *BitCastLHS = cast<BitCastInst>(LHS);
311   Value *VecA = BitCastLHS->getOperand(0);
312   assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
313   auto *BitCastRHS = cast<BitCastInst>(RHS);
314   Value *VecB = BitCastRHS->getOperand(0);
315   assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
316
317   // tiledpbssd.scalarize.rows.header:
318   // %vec.c.phi.row = phi <256 x i32> [ %VecC, %start ], [ %NewVecC,
319   // %tiledpbssd.scalarize.rows.latch ]
320
321   // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %start ], [
322   // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  // (%start is Start, the block that held the intrinsic before the split.)
323   B.SetInsertPoint(RowLoopHeader->getTerminator());
324   PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
325   VecCPhiRowLoop->addIncoming(VecC, Start);
326   Value *VecZero = Constant::getNullValue(V256I32Ty);
327   PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
328   VecDPhiRowLoop->addIncoming(VecZero, Start);
329
330   // tiledpbssd.scalarize.cols.header:
331   // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
332   // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
333   // %tiledpbssd.scalarize.cols.latch ]
334
335   // %vec.d.phi.col = phi <256 x i32> [
336   // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
337   // %tiledpbssd.scalarize.cols.latch ]
338
339   // calculate %idxc = row * 16 + col.
340   B.SetInsertPoint(ColLoopHeader->getTerminator());
341   PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
342   VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
343   PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
344   VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
345   Value *IdxC =
346       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
347
348   // tiledpbssd.scalarize.inner.header:
349   // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
350   // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
351   // %tiledpbssd.scalarize.inner.latch ]
352
353   B.SetInsertPoint(InnerLoopHeader->getTerminator());
354   PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
355   VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
356
  // %idxa = row * 16 + inner, %idxb = inner * 16 + col.
357   B.SetInsertPoint(InnerBody->getTerminator());
358   Value *IdxA =
359       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
360   Value *IdxB =
361       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
362   Value *NewVecC = nullptr;
363
364   if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
365     // tiledpbssd.scalarize.inner.body:
366     // calculate idxa, idxb
367     // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
368     // %elta = extractelement <256 x i32> %veca, i16 %idxa
369     // %eltav4i8 = bitcast i32 %elta to <4 x i8>
370     // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
371     // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
372     // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
373     // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // (sext or zext of each operand depends on IntrID; see the switch below.)
374     // %mulab = mul <4 x i32> %eltav4i32, %eltbv4i32
375     // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
376     // %neweltc = add i32 %eltc, %acc
377     // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
378     // i16 %idxc
379     FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
380     FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
381     Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
382     Value *EltA = B.CreateExtractElement(VecA, IdxA);
383     Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
384     Value *EltB = B.CreateExtractElement(VecB, IdxB);
385     Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    // Widened operands. Despite the "SEXT" names, unsigned operands are
    // zero-extended.
386     Value *SEXTSubVecB = nullptr;
387     Value *SEXTSubVecA = nullptr;
388     switch (IntrID) {
389     case Intrinsic::x86_tdpbssd_internal:
390       SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
391       SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
392       break;
393     case Intrinsic::x86_tdpbsud_internal:
394       SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
395       SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
396       break;
397     case Intrinsic::x86_tdpbusd_internal:
398       SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
399       SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
400       break;
401     case Intrinsic::x86_tdpbuud_internal:
402       SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
403       SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
404       break;
405     default:
406       llvm_unreachable("Invalid intrinsic ID!");
407     }
408     Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
409     Value *ResElt = B.CreateAdd(EltC, SubVecR);
410     NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
411   } else {
412     // tiledpbf16ps.scalarize.inner.body:
413     // idxa and idxb are calculated above; idxc in the column header.
414     // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
415     // %eltcf32 = bitcast i32 %eltc to float
416     // %elta = extractelement <256 x i32> %veca, i16 %idxa
417     // %eltav2i16 = bitcast i32 %elta to <2 x i16>
418     // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
419     // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
420     // %shufflea = shufflevector <2 x i16> %eltav2i16, <2 x i16> zeroinitializer,
421     // <4 x i32> <i32 2, i32 0, i32 3, i32 1>
422     // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
423     // %shuffleb = shufflevector <2 x i16> %eltbv2i16, <2 x i16> zeroinitializer,
424     // <4 x i32> <i32 2, i32 0, i32 3, i32 1>
425     // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
426     // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
427     // %acc = call float
428     // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
429     // %neweltc = bitcast float %acc to i32
430     // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
431     // i16 %idxc
432     // (The D vector is not updated here; %NewVecD is built from %NewVecC in
433     // the column latch below.)
434     FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
435     FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
436     Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
437     Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
438     Value *EltA = B.CreateExtractElement(VecA, IdxA);
439     Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
440     Value *EltB = B.CreateExtractElement(VecB, IdxB);
441     Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
442     Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    // The mask interleaves a zero i16 below each source i16 (lanes
    // [0, a0, 0, a1]). On little-endian x86, each 16-bit value therefore
    // becomes the high half of a float.
443     int ShuffleMask[4] = {2, 0, 3, 1};
444     auto ShuffleArray = ArrayRef(ShuffleMask);
445     Value *AV2F32 = B.CreateBitCast(
446         B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
447     Value *BV2F32 = B.CreateBitCast(
448         B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
449     Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
450     Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
451     NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
452   }
453
454   // tiledpbssd.scalarize.cols.latch:
455   // %NewEltC = extractelement <256 x i32> %NewVecC, i16 %idxc
456   // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
457   // i16 %idxc
  // D only receives the finished element of each visited (row, col); all
  // other D elements keep their zero initial value.
458   B.SetInsertPoint(ColLoopLatch->getTerminator());
459   Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
460   Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
461
  // Close the loop-carried PHIs with the values coming from each latch.
462   VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
463   VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
464   VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
465   VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
466   VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
467
468   return NewVecD;
469 }
470 
471 template <Intrinsic::ID IntrID>
472 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
473                      IntrID == Intrinsic::x86_tdpbsud_internal ||
474                      IntrID == Intrinsic::x86_tdpbusd_internal ||
475                      IntrID == Intrinsic::x86_tdpbuud_internal ||
476                      IntrID == Intrinsic::x86_tdpbf16ps_internal,
477                  bool>
478 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
479   Value *M, *N, *K, *C, *A, *B;
480   match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
481                                     m_Value(C), m_Value(A), m_Value(B)));
482   Instruction *InsertI = TileDP;
483   IRBuilder<> PreBuilder(TileDP);
484   PreBuilder.SetInsertPoint(TileDP);
485   // We visit the loop with (m, n/4, k/4):
486   // %n_dword = lshr i16 %n, 2
487   // %k_dword = lshr i16 %k, 2
488   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
489   Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
490   BasicBlock *Start = InsertI->getParent();
491   BasicBlock *End =
492       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
493   IRBuilder<> Builder(TileDP);
494   Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
495                                             KDWord, C, A, B);
496   // we cannot assume there always be bitcast after tiledpbssd. So we need to
497   // insert one bitcast as required
498   Builder.SetInsertPoint(End->getFirstNonPHI());
499   Value *ResAMX =
500       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
501   // Delete TileDP intrinsic and do some clean-up.
502   for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
503     Instruction *I = cast<Instruction>(U.getUser());
504     Value *Vec;
505     if (match(I, m_BitCast(m_Value(Vec)))) {
506       I->replaceAllUsesWith(ResVec);
507       I->eraseFromParent();
508     }
509   }
510   TileDP->replaceAllUsesWith(ResAMX);
511   TileDP->eraseFromParent();
512   return true;
513 }
514 
515 template <bool IsTileLoad>
516 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
517   Value *M, *N, *Ptr, *Stride, *Tile;
518   if (IsTileLoad)
519     match(TileLoadStore,
520           m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
521               m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
522   else
523     match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
524                              m_Value(M), m_Value(N), m_Value(Ptr),
525                              m_Value(Stride), m_Value(Tile)));
526 
527   Instruction *InsertI = TileLoadStore;
528   IRBuilder<> PreBuilder(TileLoadStore);
529   PreBuilder.SetInsertPoint(TileLoadStore);
530   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
531   Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
532   BasicBlock *Start = InsertI->getParent();
533   BasicBlock *End =
534       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
535   IRBuilder<> Builder(TileLoadStore);
536   Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
537       Start, End, Builder, M, NDWord, Ptr, StrideDWord,
538       IsTileLoad ? nullptr : Tile);
539   if (IsTileLoad) {
540     // we cannot assume there always be bitcast after tileload. So we need to
541     // insert one bitcast as required
542     Builder.SetInsertPoint(End->getFirstNonPHI());
543     Value *ResAMX =
544         Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
545     // Delete tileloadd6 intrinsic and do some clean-up
546     for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
547       Instruction *I = cast<Instruction>(U.getUser());
548       Value *Vec;
549       if (match(I, m_BitCast(m_Value(Vec)))) {
550         I->replaceAllUsesWith(ResVec);
551         I->eraseFromParent();
552       }
553     }
554     TileLoadStore->replaceAllUsesWith(ResAMX);
555   }
556   TileLoadStore->eraseFromParent();
557   return true;
558 }
559 
560 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
561   IRBuilder<> Builder(TileZero);
562   FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
563   Value *VecZero = Constant::getNullValue(V256I32Ty);
564   for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
565     Instruction *I = cast<Instruction>(U.getUser());
566     Value *Vec;
567     if (match(I, m_BitCast(m_Value(Vec)))) {
568       I->replaceAllUsesWith(VecZero);
569       I->eraseFromParent();
570     }
571   }
572   TileZero->eraseFromParent();
573   return true;
574 }
575 
576 bool X86LowerAMXIntrinsics::visit() {
577   bool C = false;
578   SmallVector<IntrinsicInst *, 8> WorkList;
579   for (BasicBlock *BB : depth_first(&Func)) {
580     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
581       if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
582         switch (Inst->getIntrinsicID()) {
583         case Intrinsic::x86_tdpbssd_internal:
584         case Intrinsic::x86_tdpbsud_internal:
585         case Intrinsic::x86_tdpbusd_internal:
586         case Intrinsic::x86_tdpbuud_internal:
587         case Intrinsic::x86_tileloadd64_internal:
588         case Intrinsic::x86_tilestored64_internal:
589         case Intrinsic::x86_tilezero_internal:
590         case Intrinsic::x86_tdpbf16ps_internal:
591           WorkList.push_back(Inst);
592           break;
593         default:
594           break;
595         }
596       }
597     }
598   }
599 
600   for (auto *Inst : WorkList) {
601     switch (Inst->getIntrinsicID()) {
602     case Intrinsic::x86_tdpbssd_internal:
603       C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
604       break;
605     case Intrinsic::x86_tdpbsud_internal:
606       C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
607       break;
608     case Intrinsic::x86_tdpbusd_internal:
609       C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
610       break;
611     case Intrinsic::x86_tdpbuud_internal:
612       C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
613       break;
614     case Intrinsic::x86_tdpbf16ps_internal:
615       C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
616       break;
617     case Intrinsic::x86_tileloadd64_internal:
618       C = lowerTileLoadStore<true>(Inst) || C;
619       break;
620     case Intrinsic::x86_tilestored64_internal:
621       C = lowerTileLoadStore<false>(Inst) || C;
622       break;
623     case Intrinsic::x86_tilezero_internal:
624       C = lowerTileZero(Inst) || C;
625       break;
626     default:
627       llvm_unreachable("invalid amx intrinsics!");
628     }
629   }
630 
631   return C;
632 }
633 
634 namespace {
635 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
636 public:
637   static char ID;
638 
639   X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
640     initializeX86LowerAMXIntrinsicsLegacyPassPass(
641         *PassRegistry::getPassRegistry());
642   }
643 
644   bool runOnFunction(Function &F) override {
645     if (!X86ScalarizeAMX)
646       return false;
647     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
648     if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
649         TM->getOptLevel() != CodeGenOpt::None)
650       return false;
651 
652     auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
653     auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
654     auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
655     auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
656     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
657 
658     X86LowerAMXIntrinsics LAT(F, DTU, LI);
659     return LAT.visit();
660   }
661   StringRef getPassName() const override { return "Lower AMX intrinsics"; }
662 
663   void getAnalysisUsage(AnalysisUsage &AU) const override {
664     AU.addPreserved<DominatorTreeWrapperPass>();
665     AU.addPreserved<LoopInfoWrapperPass>();
666     AU.addRequired<TargetPassConfig>();
667   }
668 };
669 } // namespace
670 
671 static const char PassName[] = "Lower AMX intrinsics";
672 char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
673 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
674                       false, false)
675 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
676 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
677                     false, false)
678 
679 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
680   return new X86LowerAMXIntrinsicsLegacyPass();
681 }
682