1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MatrixBuilder class, which is used as a convenient way
10 // to lower matrix operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_MATRIXBUILDER_H
15 #define LLVM_IR_MATRIXBUILDER_H
16 
17 #include "llvm/IR/Constant.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instruction.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Alignment.h"
26 
27 namespace llvm {
28 
29 class Function;
30 class Twine;
31 class Module;
32 
33 template <class IRBuilderTy> class MatrixBuilder {
34   IRBuilderTy &B;
35   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36 
37   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38                                                          Value *RHS) {
39     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40            "One of the operands must be a matrix (embedded in a vector)");
41     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
42       RHS = B.CreateVectorSplat(
43           cast<VectorType>(LHS->getType())->getNumElements(), RHS,
44           "scalar.splat");
45     else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
46       LHS = B.CreateVectorSplat(
47           cast<VectorType>(RHS->getType())->getNumElements(), LHS,
48           "scalar.splat");
49     return {LHS, RHS};
50   }
51 
52 public:
53   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
54 
55   /// Create a column major, strided matrix load.
56   /// \p DataPtr - Start address of the matrix read
57   /// \p Rows    - Number of rows in matrix (must be a constant)
58   /// \p Columns - Number of columns in matrix (must be a constant)
59   /// \p Stride  - Space between columns
60   CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
61                                   Value *Stride, bool IsVolatile, unsigned Rows,
62                                   unsigned Columns, const Twine &Name = "") {
63 
64     // Deal with the pointer
65     PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
66     Type *EltTy = PtrTy->getElementType();
67 
68     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
69 
70     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
71                     B.getInt32(Columns)};
72     Type *OverloadedTypes[] = {RetType};
73 
74     Function *TheFn = Intrinsic::getDeclaration(
75         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
76 
77     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
78     Attribute AlignAttr =
79         Attribute::getWithAlignment(Call->getContext(), Alignment);
80     Call->addAttribute(1, AlignAttr);
81     return Call;
82   }
83 
84   /// Create a column major, strided matrix store.
85   /// \p Matrix  - Matrix to store
86   /// \p Ptr     - Pointer to write back to
87   /// \p Stride  - Space between columns
88   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
89                                    Value *Stride, bool IsVolatile,
90                                    unsigned Rows, unsigned Columns,
91                                    const Twine &Name = "") {
92     Value *Ops[] = {Matrix,           Ptr,
93                     Stride,           B.getInt1(IsVolatile),
94                     B.getInt32(Rows), B.getInt32(Columns)};
95     Type *OverloadedTypes[] = {Matrix->getType()};
96 
97     Function *TheFn = Intrinsic::getDeclaration(
98         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
99 
100     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
101     Attribute AlignAttr =
102         Attribute::getWithAlignment(Call->getContext(), Alignment);
103     Call->addAttribute(2, AlignAttr);
104     return Call;
105   }
106 
107   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
108   /// rows and \p Columns columns.
109   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
110                                   unsigned Columns, const Twine &Name = "") {
111     auto *OpType = cast<VectorType>(Matrix->getType());
112     auto *ReturnType =
113         FixedVectorType::get(OpType->getElementType(), Rows * Columns);
114 
115     Type *OverloadedTypes[] = {ReturnType};
116     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
117     Function *TheFn = Intrinsic::getDeclaration(
118         getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
119 
120     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
121   }
122 
123   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
124   /// RHS.
125   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
126                                  unsigned LHSColumns, unsigned RHSColumns,
127                                  const Twine &Name = "") {
128     auto *LHSType = cast<VectorType>(LHS->getType());
129     auto *RHSType = cast<VectorType>(RHS->getType());
130 
131     auto *ReturnType =
132         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
133 
134     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
135                     B.getInt32(RHSColumns)};
136     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
137 
138     Function *TheFn = Intrinsic::getDeclaration(
139         getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
140     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
141   }
142 
143   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
144   /// ColumnIdx).
145   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
146                             Value *ColumnIdx, unsigned NumRows) {
147     return B.CreateInsertElement(
148         Matrix, NewVal,
149         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
150                                                ColumnIdx->getType(), NumRows)),
151                     RowIdx));
152   }
153 
154   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
155   /// matrixes.
156   Value *CreateAdd(Value *LHS, Value *RHS) {
157     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
158     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
159       RHS = B.CreateVectorSplat(
160           cast<VectorType>(LHS->getType())->getNumElements(), RHS,
161           "scalar.splat");
162     else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
163       LHS = B.CreateVectorSplat(
164           cast<VectorType>(RHS->getType())->getNumElements(), LHS,
165           "scalar.splat");
166 
167     return cast<VectorType>(LHS->getType())
168                    ->getElementType()
169                    ->isFloatingPointTy()
170                ? B.CreateFAdd(LHS, RHS)
171                : B.CreateAdd(LHS, RHS);
172   }
173 
174   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
175   /// point matrixes.
176   Value *CreateSub(Value *LHS, Value *RHS) {
177     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
178     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
179       RHS = B.CreateVectorSplat(
180           cast<VectorType>(LHS->getType())->getNumElements(), RHS,
181           "scalar.splat");
182     else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
183       LHS = B.CreateVectorSplat(
184           cast<VectorType>(RHS->getType())->getNumElements(), LHS,
185           "scalar.splat");
186 
187     return cast<VectorType>(LHS->getType())
188                    ->getElementType()
189                    ->isFloatingPointTy()
190                ? B.CreateFSub(LHS, RHS)
191                : B.CreateSub(LHS, RHS);
192   }
193 
194   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
195   /// RHS.
196   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
197     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
198     if (LHS->getType()->getScalarType()->isFloatingPointTy())
199       return B.CreateFMul(LHS, RHS);
200     return B.CreateMul(LHS, RHS);
201   }
202 
203   /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
204   Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx,
205                               unsigned NumRows, Twine const &Name = "") {
206 
207     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
208                                  ColumnIdx->getType()->getScalarSizeInBits());
209     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
210     RowIdx = B.CreateZExt(RowIdx, IntTy);
211     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
212     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
213     return B.CreateExtractElement(
214         Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
215         "matext");
216   }
217 };
218 
219 } // end namespace llvm
220 
221 #endif // LLVM_IR_MATRIXBUILDER_H
222