//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//    * Improve cost-modeling, e.g. choose different numbers of rows/columns
//      for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results due to reduced rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p
// NumElements), assuming \p Stride elements between the starts of two
// consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
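
// For illustration only: with VecIdx = 2, Stride = 4 and EltType = double,
// the helper above emits IR roughly along these lines (a sketch; constant
// operands would typically be folded):
//
//   %vec.start = mul i64 2, 4
//   %vec.gep   = getelementptr double, double* %base, i64 %vec.start
//   %vec.cast  = bitcast double* %vec.gep to <2 x double>*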

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major
///    layout). The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
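///
/// As a small example (a sketch, not verbatim output): for a 2x2 double
/// multiply whose shapes are known, each operand is split into two
/// <2 x double> column vectors via shufflevector, the product is computed
/// with vector fmul/fadd (or @llvm.fmuladd.* when contraction is allowed),
/// and users without shape information receive the result re-embedded into
/// a flat <4 x double>.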
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis &AA;
  DominatorTree &DT;
  LoopInfo &LI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy()
        : Vectors(),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 &&
               "Cannot call getNumColumns without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      Value *Undef = UndefValue::get(Vec->getType());
      return Builder.CreateShuffleVector(
          Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }
  };
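
  // Illustrative example (a sketch, not tied to a particular test): a call
  //   @llvm.matrix.multiply(%A, %B, i32 2, i32 4, i32 3)
  // produces a result of shape {NumRows = 2, NumColumns = 3}; in column-major
  // mode getStride() is 2 and getNumVectors() is 3.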

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being
  /// store instructions and the matrix_column_major_store intrinsics. For
  /// those, the shape information indicates that those instructions should be
  /// lowered using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// instructions need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }
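
  // Worked example (assuming 128-bit vector registers): for ST = double and
  // N = 4 the operation covers 64 * 4 = 256 bits, so the estimate is
  // ceil(256 / 128) = 2 vector ops; with 256-bit registers it would be 1.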

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to know at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(),
                           m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                           m_Value(), m_Value(), m_Value(), m_Value(M),
                           m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
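
  // Forward step example (sketch): given
  //   %t = call @llvm.matrix.transpose(%a, i32 2, i32 3)
  //   store <6 x double> %t, <6 x double>* %p
  // the transpose gets shape 3 x 2 (dimensions flipped), and the store later
  // inherits that shape from its stored operand.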

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. For each operand whose shape derives
    // from the result shape and is still unknown, set the shape and add the
    // operand to the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
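
  // Backward step example (sketch): for
  //   %c = call @llvm.matrix.multiply(%a, %b, i32 2, i32 4, i32 3)
  // the operands receive shapes %a: 2 x 4 and %b: 4 x 3, and the
  // instructions defining %a and %b are queued for further propagation.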

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_column_major_load:
          case Intrinsic::matrix_column_major_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
    RemarkGen.emitRemarks();

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
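
  // Worked example: for double elements (8 bytes), initial alignment 16 and
  // a constant stride of 6 elements, column 1 starts at byte offset 48, so
  // commonAlignment(16, 48) keeps alignment 16; with a stride of 5 the
  // offset is 40 bytes and the alignment drops to 8.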

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                     Shape.getStride(), VType->getElementType(),
                                     Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }
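
  // For illustration, a column-major load of a 4x2 float matrix with stride 4
  // produces IR roughly like the following (a sketch, names abbreviated):
  //
  //   %cast      = bitcast float* %ptr to <4 x float>*
  //   %col.load  = load <4 x float>, <4 x float>* %cast, align 4
  //   %vec.gep   = getelementptr float, float* %ptr, i64 4
  //   %cast.1    = bitcast float* %vec.gep to <4 x float>*
  //   %col.load1 = load <4 x float>, <4 x float>* %cast.1, align 4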

  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  /// starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at
  /// \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                     Stride, StoreVal.getStride(),
                                     VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }

  // Set elements I..I+NumElts-1 to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(
        Block, Undef,
        createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
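
  // Sketch of the contracted case: with AllowContraction set, each
  // accumulation step becomes a single intrinsic call such as
  //   %sum = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
  //                                               <4 x float> %b,
  //                                               <4 x float> %acc)
  // instead of a separate fmul followed by an fadd.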

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, bool AllowContraction,
                          IRBuilder<> &Builder, bool isTiled) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(true) /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;
    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axis and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                               Result.getElementType()->isFloatingPointTy(),
                               Builder, AllowContraction, NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axis and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                               IsFP, Builder, AllowContraction, NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
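
  // Remainder handling example: for R = 6 rows and VF = 4, the inner loop
  // first processes rows 0..3 with BlockSize 4, then halves BlockSize to 2
  // to cover rows 4..5, so no scalar epilogue is needed for this shape.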

  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. The new location, or the original one if
  /// no copy is necessary, is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);

    // If we can statically determine noalias we're good.
    if (!LdAliased)
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the two parts of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT.Delete, Check0, Succ});

    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
                                    nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
                                    nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they
    // are guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    AllocaInst *NewLd =
        Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
    Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
                         Load->getPointerOperand(), Load->getAlign(),
                         LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(NewLd, Copy);

    // Adjust DT.
    DTUpdates.push_back({DT.Insert, Check0, Check1});
    DTUpdates.push_back({DT.Insert, Check0, Fusion});
    DTUpdates.push_back({DT.Insert, Check1, Copy});
    DTUpdates.push_back({DT.Insert, Check1, Fusion});
    DT.applyUpdates(DTUpdates);
    return PHI;
  }
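
  // Resulting control flow of the runtime alias check (sketch):
  //
  //   check0:     load.begin < store.end ? br alias_cont : br no_alias
  //   alias_cont: store.begin < load.end ? br copy       : br no_alias
  //   copy:       alloca + memcpy of the load operand
  //   no_alias:   phi [ptr, check0], [ptr, alias_cont], [alloca, copy]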

  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF =
        std::max<unsigned>(TTI.getRegisterBitWidth(true) /
                               EltType->getPrimitiveSizeInBits().getFixedSize(),
                           1U);

    // Cost model for tiling
    //
    // For tiling to be beneficial, we need reuse either along the R or
    // the C axis. We vectorize along the R axis so that means at least
    // 3 elements.
    // TODO: Also consider cost of copying if operands alias.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector
    // registers we have. Note that this is an oversimplification since
    // fusing also takes some extra loads which may exceed the number of
    // reloads necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
  }
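
  // Worked example (assuming 128-bit registers and double elements, so
  // VF = 2, with a register file of 32 vector registers): for R = C = M = 8,
  // Op0Regs = ceil(8 / 2) * 8 = 32 and Op1Regs = 32; 64 > 32, so fusion is
  // considered profitable for this shape.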

  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumnType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumnType));
    return Res;
  }

  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    IRBuilder<> Builder(Store);
    for (unsigned J = 0; J < C; J += TileSize)
      for (unsigned I = 0; I < R; I += TileSize) {
        const unsigned TileR = std::min(R - I, unsigned(TileSize));
        const unsigned TileC = std::min(C - J, unsigned(TileSize));
        MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

        for (unsigned K = 0; K < M; K += TileSize) {
          const unsigned TileM = std::min(M - K, unsigned(TileSize));
          MatrixTy A =
              loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                         LShape, Builder.getInt64(I), Builder.getInt64(K),
                         {TileR, TileM}, EltType, Builder);
          MatrixTy B =
              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                         RShape, Builder.getInt64(K), Builder.getInt64(J),
                         {TileM, TileC}, EltType, Builder);
          emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
        }
        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                    Builder.getInt64(I), Builder.getInt64(J), EltType, Builder);
      }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Currently we only lower {ld, ld} -> matmul -> st chains.
  ///
  /// No need to return a MatrixTy object for the result of the operation,
  /// since the single store user will be lowered as part of this.
  /// Instructions that are completely eliminated by fusion are added to
  /// \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !MatMul->hasOneUse() ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
    auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      // FIXME: See if we can hoist the store address computation.
      auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
      if (AddrI && (!DT.dominates(AddrI, MatMul)))
        return;

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output.
    MatrixTy Result(R, C, EltType);

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = UndefValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting
      // vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
        Builder);
  }
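
  // Shape sketch: transposing a column-major 2x3 matrix (three columns of
  // length 2) yields two result vectors of length 3; result vector I is
  // built from element I of each input column, so row/column indices swap.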

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
1377
1378 /// Lower binary operators, if shape information is available.
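  ///
  /// E.g. (a sketch with hypothetical value names) an fadd of two operands
  /// with a known 2x2 column-major shape is split into one fadd per column:
  ///   %r.col0 = fadd <2 x double> %a.col0, %b.col0
  ///   %r.col1 = fadd <2 x double> %a.col1, %b.col1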
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    // Helper to perform the binary operation on a pair of vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
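  ///
  /// A hedged sketch of the output (the exact shape suffixes depend on the
  /// expression): a multiply of two loaded matrices stored back to memory
  /// linearizes to something like
  ///   store(
  ///    multiply.2x6.6x2.double(
  ///     load(addr %A),
  ///     load(addr %B)),
  ///    addr %C)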
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrices. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrices, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName();
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else if (isMatrix(V))
        TmpStream << "matrix";
      else
        TmpStream << "scalar";
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrices from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
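  ///
  /// A hedged sketch of the resulting remark (counts and source location are
  /// illustrative, not definitive):
  ///   remark: test.cpp:12:10: Lowered with 6 stores, 6 loads, 24 compute ops
  /// followed by the linearized expression produced by ExprLinearizer above.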
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p
    /// Leaf to all visited expressions in \p Shared. Limit the matrix
    /// operations to the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for the
    /// expression starting at \p Root. Expressions used multiple times are
    /// counted once. Limit the matrix operations to the ones in
    /// \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto It =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            It.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto It = Subprog2Exprs.insert({nullptr, {}});
          It.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
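
// A minimal sketch of exercising this pass from the command line (assuming a
// matching opt binary). With the new pass manager:
//   opt -passes=lower-matrix-intrinsics -S input.ll
// With the legacy pass manager, under the name registered via DEBUG_TYPE:
//   opt -lower-matrix-intrinsics -S input.ll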