//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Add a remark summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
                                            cl::init(true));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows),
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//      0      1      2      3
// 0  v_0_0  v_0_1  v_0_2  v_0_3
// 1  v_1_0  v_1_1  v_1_2  v_1_3
// 2  v_2_0  v_2_1  v_2_2  v_2_3
// 3  v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0  |v_1_1 |v_1_2 |v_1_3
//         v_2_0  |v_2_1 |v_2_2 |v_2_3
//         v_3_0  {v_3_1 {v_3_2  v_3_3
//
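// For illustration, the IR emitted for a non-constant column index %c, a
// stride of 4 and 2 rows of double looks roughly like this (names are
// schematic; constant operands would be folded):
//
//   %col.start = mul i32 %c, 4
//   %col.gep   = getelementptr double, double* %base, i32 %col.start
//   %col.cast  = bitcast double* %col.gep to <2 x double>*
//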
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get a pointer to the start of the selected column. Skip GEP creation if
  // we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast the elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
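/// For illustration, given a call like the following (schematic IR):
///
///   %t = call <4 x double> @llvm.matrix.transpose(<4 x double> %a,
///                                                 i32 2, i32 2)
///
/// the pass splits %a into two 2-element column vectors, builds the
/// transposed columns via extract/insertelement, and re-embeds the result
/// into a flat <4 x double> only for users of %t without shape information.
///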
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsic; for those, the
  /// recorded shape is that of the stored value, and it indicates that the
  /// store should be lowered using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; such
  /// instructions become obsolete and are removed after lowering finishes.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
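  ///
  /// For illustration, splitting a flat <4 x double> that holds a 2 x 2
  /// matrix takes one shuffle per column (schematic IR):
  ///
  ///   %split  = shufflevector <4 x double> %m, <4 x double> undef,
  ///                           <2 x i32> <i32 0, i32 1>
  ///   %split1 = shufflevector <4 x double> %m, <4 x double> undef,
  ///                           <2 x i32> <i32 2, i32 3>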
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mismatch, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

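  /// Returns true if \p V is an instruction that is applied elementwise to
  /// its operands, i.e. its result has the same shape as its matrix operands.
  /// Non-instruction values are conservatively treated as shape-preserving.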
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
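  ///
  /// For illustration, shape flows forward from a multiply into an
  /// elementwise user (schematic IR):
  ///
  ///   %m = call <8 x double> @llvm.matrix.multiply(<16 x double> %a,
  ///                                                <8 x double> %b,
  ///                                                i32 4, i32 4, i32 2)
  ///   %s = fadd <8 x double> %m, %m   ; inherits the 4 x 2 shape of %m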
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to know at least one of the
    // operand shapes. Add the shape for this instruction and then add its
    // users to the work list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instructions for which we already know the
  /// shape.
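  ///
  /// For illustration (schematic): for %m = matrix.multiply(%a, %b, i32 4,
  /// i32 4, i32 2) with a known result shape, %a is recorded as 4 x 4 and
  /// %b as 4 x 2, and both are pushed onto the worklist.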
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. For each operand whose shape can be
    // derived from the result shape and is not known yet, set it and add the
    // operand to the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // The result shape (N x M) is flipped; the operand keeps the M x N
        // shape given by the intrinsic arguments.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this store, so we would
        // only backward-propagate to an instruction with an already known
        // shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for instructions in the worklist,
      // we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

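  /// Entry point: propagate shape information through the function (if
  /// enabled), lower all instructions with known shape, and finally erase
  /// the instructions made obsolete by the lowering. Returns true if the
  /// function was changed.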
  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType,
                          IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

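  /// Load a matrix with \p Shape starting at \p Ptr, using \p Stride elements
  /// between the starts of consecutive columns, and lower \p Inst to the
  /// loaded column vectors.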
  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
                 ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ColumnMatrixTy Result;
    // Stride is the distance between the start of one column and the start
    // of the next.
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
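  ///
  /// For illustration, a 4 x 2 load of double lowers to two column loads
  /// (schematic IR; the column addresses come from computeColumnAddr):
  ///
  ///   %col.load  = load <4 x double>, <4 x double>* %col.cast
  ///   %col.load1 = load <4 x double>, <4 x double>* %col.cast1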
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }

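  /// Store matrix \p Matrix with \p Shape to \p Ptr, using \p Stride elements
  /// between the starts of consecutive columns, and mark \p Inst for removal.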
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
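  ///
  /// For illustration, storing a 4 x 2 matrix of double lowers to one
  /// <4 x double> store per column, with each column address computed by
  /// computeColumnAddr from the stride argument.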
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
  /// the matrix \p LM represented as a vector of column vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Insert the elements of \p Block into \p Col, i.e. set elements
  // I..I+BlockNumElts-1 of Col to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

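  /// Compute \p Sum + \p A * \p B, or just \p A * \p B if \p Sum is null.
  /// Emits FP or integer operations depending on \p UseFPOp, and contracts
  /// the FP multiply-add into llvm.fmuladd when \p AllowContraction is set.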
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction) {

    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Value *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axis and accumulate the columns. This
    // way, the adds can be vectorized without reassociation.
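    //
    // Illustrative sketch for R = M = C = 2 with a BlockSize that covers a
    // whole column:
    //   Result[:,J] = Lhs[:,0] * Rhs[0,J] + Lhs[:,1] * Rhs[1,J]
    // where each Rhs element is splatted and folded in via createMulAdd.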
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
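  ///
  /// For illustration, transposing a 2 x 3 matrix with columns c0, c1, c2
  /// produces 2 result columns, where result column j is assembled
  /// elementwise as [c0[j], c1[j], c2[j]].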
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at the input column index, since that is the row index
        // after the transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

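    // Loads of flat matrix values are tightly packed, so the stride between
    // columns equals the number of rows.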
    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

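  /// Lower store instructions, if shape information is available. As with
  /// loads, the stored matrix is tightly packed, so the stride equals the
  /// number of rows.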
  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Build the result matrix by applying the operator to each pair of
    // lowered columns.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst, Result, Builder);
    return true;
  }
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, *TTI);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}