1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MatrixBuilder class, which is used as a convenient way
10 // to lower matrix operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_MATRIXBUILDER_H
15 #define LLVM_IR_MATRIXBUILDER_H
16 
17 #include "llvm/IR/Constant.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instruction.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Alignment.h"
26 
27 namespace llvm {
28 
29 class Function;
30 class Twine;
31 class Module;
32 
33 class MatrixBuilder {
34   IRBuilderBase &B;
getModule()35   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36 
splatScalarOperandIfNeeded(Value * LHS,Value * RHS)37   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38                                                          Value *RHS) {
39     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40            "One of the operands must be a matrix (embedded in a vector)");
41     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42       assert(!isa<ScalableVectorType>(LHS->getType()) &&
43              "LHS Assumed to be fixed width");
44       RHS = B.CreateVectorSplat(
45           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46           "scalar.splat");
47     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48       assert(!isa<ScalableVectorType>(RHS->getType()) &&
49              "RHS Assumed to be fixed width");
50       LHS = B.CreateVectorSplat(
51           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52           "scalar.splat");
53     }
54     return {LHS, RHS};
55   }
56 
57 public:
MatrixBuilder(IRBuilderBase & Builder)58   MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
59 
60   /// Create a column major, strided matrix load.
61   /// \p EltTy   - Matrix element type
62   /// \p DataPtr - Start address of the matrix read
63   /// \p Rows    - Number of rows in matrix (must be a constant)
64   /// \p Columns - Number of columns in matrix (must be a constant)
65   /// \p Stride  - Space between columns
66   CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
67                                   Value *Stride, bool IsVolatile, unsigned Rows,
68                                   unsigned Columns, const Twine &Name = "") {
69     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
70 
71     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
72                     B.getInt32(Columns)};
73     Type *OverloadedTypes[] = {RetType, Stride->getType()};
74 
75     Function *TheFn = Intrinsic::getDeclaration(
76         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
77 
78     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
79     Attribute AlignAttr =
80         Attribute::getWithAlignment(Call->getContext(), Alignment);
81     Call->addParamAttr(0, AlignAttr);
82     return Call;
83   }
84 
85   /// Create a column major, strided matrix store.
86   /// \p Matrix  - Matrix to store
87   /// \p Ptr     - Pointer to write back to
88   /// \p Stride  - Space between columns
89   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
90                                    Value *Stride, bool IsVolatile,
91                                    unsigned Rows, unsigned Columns,
92                                    const Twine &Name = "") {
93     Value *Ops[] = {Matrix,           Ptr,
94                     Stride,           B.getInt1(IsVolatile),
95                     B.getInt32(Rows), B.getInt32(Columns)};
96     Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
97 
98     Function *TheFn = Intrinsic::getDeclaration(
99         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
100 
101     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
102     Attribute AlignAttr =
103         Attribute::getWithAlignment(Call->getContext(), Alignment);
104     Call->addParamAttr(1, AlignAttr);
105     return Call;
106   }
107 
108   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
109   /// rows and \p Columns columns.
110   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
111                                   unsigned Columns, const Twine &Name = "") {
112     auto *OpType = cast<VectorType>(Matrix->getType());
113     auto *ReturnType =
114         FixedVectorType::get(OpType->getElementType(), Rows * Columns);
115 
116     Type *OverloadedTypes[] = {ReturnType};
117     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
118     Function *TheFn = Intrinsic::getDeclaration(
119         getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
120 
121     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
122   }
123 
124   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
125   /// RHS.
126   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
127                                  unsigned LHSColumns, unsigned RHSColumns,
128                                  const Twine &Name = "") {
129     auto *LHSType = cast<VectorType>(LHS->getType());
130     auto *RHSType = cast<VectorType>(RHS->getType());
131 
132     auto *ReturnType =
133         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
134 
135     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136                     B.getInt32(RHSColumns)};
137     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138 
139     Function *TheFn = Intrinsic::getDeclaration(
140         getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
141     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
142   }
143 
144   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
145   /// ColumnIdx).
CreateMatrixInsert(Value * Matrix,Value * NewVal,Value * RowIdx,Value * ColumnIdx,unsigned NumRows)146   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
147                             Value *ColumnIdx, unsigned NumRows) {
148     return B.CreateInsertElement(
149         Matrix, NewVal,
150         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
151                                                ColumnIdx->getType(), NumRows)),
152                     RowIdx));
153   }
154 
155   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
156   /// matrixes.
CreateAdd(Value * LHS,Value * RHS)157   Value *CreateAdd(Value *LHS, Value *RHS) {
158     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
159     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
160       assert(!isa<ScalableVectorType>(LHS->getType()) &&
161              "LHS Assumed to be fixed width");
162       RHS = B.CreateVectorSplat(
163           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
164           "scalar.splat");
165     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
166       assert(!isa<ScalableVectorType>(RHS->getType()) &&
167              "RHS Assumed to be fixed width");
168       LHS = B.CreateVectorSplat(
169           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
170           "scalar.splat");
171     }
172 
173     return cast<VectorType>(LHS->getType())
174                    ->getElementType()
175                    ->isFloatingPointTy()
176                ? B.CreateFAdd(LHS, RHS)
177                : B.CreateAdd(LHS, RHS);
178   }
179 
180   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
181   /// point matrixes.
CreateSub(Value * LHS,Value * RHS)182   Value *CreateSub(Value *LHS, Value *RHS) {
183     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
184     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
185       assert(!isa<ScalableVectorType>(LHS->getType()) &&
186              "LHS Assumed to be fixed width");
187       RHS = B.CreateVectorSplat(
188           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
189           "scalar.splat");
190     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
191       assert(!isa<ScalableVectorType>(RHS->getType()) &&
192              "RHS Assumed to be fixed width");
193       LHS = B.CreateVectorSplat(
194           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
195           "scalar.splat");
196     }
197 
198     return cast<VectorType>(LHS->getType())
199                    ->getElementType()
200                    ->isFloatingPointTy()
201                ? B.CreateFSub(LHS, RHS)
202                : B.CreateSub(LHS, RHS);
203   }
204 
205   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
206   /// RHS.
CreateScalarMultiply(Value * LHS,Value * RHS)207   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
208     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
209     if (LHS->getType()->getScalarType()->isFloatingPointTy())
210       return B.CreateFMul(LHS, RHS);
211     return B.CreateMul(LHS, RHS);
212   }
213 
214   /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
215   /// IsUnsigned indicates whether UDiv or SDiv should be used.
CreateScalarDiv(Value * LHS,Value * RHS,bool IsUnsigned)216   Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
217     assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
218     assert(!isa<ScalableVectorType>(LHS->getType()) &&
219            "LHS Assumed to be fixed width");
220     RHS =
221         B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
222                             RHS, "scalar.splat");
223     return cast<VectorType>(LHS->getType())
224                    ->getElementType()
225                    ->isFloatingPointTy()
226                ? B.CreateFDiv(LHS, RHS)
227                : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
228   }
229 
230   /// Create an assumption that \p Idx is less than \p NumElements.
231   void CreateIndexAssumption(Value *Idx, unsigned NumElements,
232                              Twine const &Name = "") {
233     Value *NumElts =
234         B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
235     auto *Cmp = B.CreateICmpULT(Idx, NumElts);
236     if (isa<ConstantInt>(Cmp))
237       assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
238     else
239       B.CreateAssumption(Cmp);
240   }
241 
242   /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
243   /// a matrix with \p NumRows embedded in a vector.
244   Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
245                      Twine const &Name = "") {
246     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
247                                  ColumnIdx->getType()->getScalarSizeInBits());
248     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
249     RowIdx = B.CreateZExt(RowIdx, IntTy);
250     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
251     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
252     return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
253   }
254 };
255 
256 } // end namespace llvm
257 
258 #endif // LLVM_IR_MATRIXBUILDER_H
259