1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MatrixBuilder class, which is used as a convenient way 10 // to lower matrix operations to LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_IR_MATRIXBUILDER_H 15 #define LLVM_IR_MATRIXBUILDER_H 16 17 #include "llvm/IR/Constant.h" 18 #include "llvm/IR/Constants.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InstrTypes.h" 21 #include "llvm/IR/Instruction.h" 22 #include "llvm/IR/IntrinsicInst.h" 23 #include "llvm/IR/Type.h" 24 #include "llvm/IR/Value.h" 25 #include "llvm/Support/Alignment.h" 26 27 namespace llvm { 28 29 class Function; 30 class Twine; 31 class Module; 32 33 template <class IRBuilderTy> class MatrixBuilder { 34 IRBuilderTy &B; 35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } 36 37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, 38 Value *RHS) { 39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && 40 "One of the operands must be a matrix (embedded in a vector)"); 41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) 42 RHS = B.CreateVectorSplat( 43 cast<VectorType>(LHS->getType())->getNumElements(), RHS, 44 "scalar.splat"); 45 else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) 46 LHS = B.CreateVectorSplat( 47 cast<VectorType>(RHS->getType())->getNumElements(), LHS, 48 "scalar.splat"); 49 return {LHS, RHS}; 50 } 51 52 public: 53 MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} 54 55 /// Create a column major, strided matrix load. 56 /// \p DataPtr - Start address of the matrix read 57 /// \p Rows - Number of rows in matrix (must be a constant) 58 /// \p Columns - Number of columns in matrix (must be a constant) 59 /// \p Stride - Space between columns 60 CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment, 61 Value *Stride, bool IsVolatile, unsigned Rows, 62 unsigned Columns, const Twine &Name = "") { 63 64 // Deal with the pointer 65 PointerType *PtrTy = cast<PointerType>(DataPtr->getType()); 66 Type *EltTy = PtrTy->getElementType(); 67 68 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); 69 70 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), 71 B.getInt32(Columns)}; 72 Type *OverloadedTypes[] = {RetType}; 73 74 Function *TheFn = Intrinsic::getDeclaration( 75 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); 76 77 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 78 Attribute AlignAttr = 79 Attribute::getWithAlignment(Call->getContext(), Alignment); 80 Call->addAttribute(1, AlignAttr); 81 return Call; 82 } 83 84 /// Create a column major, strided matrix store. 85 /// \p Matrix - Matrix to store 86 /// \p Ptr - Pointer to write back to 87 /// \p Stride - Space between columns 88 CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, 89 Value *Stride, bool IsVolatile, 90 unsigned Rows, unsigned Columns, 91 const Twine &Name = "") { 92 Value *Ops[] = {Matrix, Ptr, 93 Stride, B.getInt1(IsVolatile), 94 B.getInt32(Rows), B.getInt32(Columns)}; 95 Type *OverloadedTypes[] = {Matrix->getType()}; 96 97 Function *TheFn = Intrinsic::getDeclaration( 98 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); 99 100 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 101 Attribute AlignAttr = 102 Attribute::getWithAlignment(Call->getContext(), Alignment); 103 Call->addAttribute(2, AlignAttr); 104 return Call; 105 } 106 107 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows 108 /// rows and \p Columns columns. 109 CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, 110 unsigned Columns, const Twine &Name = "") { 111 auto *OpType = cast<VectorType>(Matrix->getType()); 112 auto *ReturnType = 113 FixedVectorType::get(OpType->getElementType(), Rows * Columns); 114 115 Type *OverloadedTypes[] = {ReturnType}; 116 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; 117 Function *TheFn = Intrinsic::getDeclaration( 118 getModule(), Intrinsic::matrix_transpose, OverloadedTypes); 119 120 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 121 } 122 123 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p 124 /// RHS. 125 CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, 126 unsigned LHSColumns, unsigned RHSColumns, 127 const Twine &Name = "") { 128 auto *LHSType = cast<VectorType>(LHS->getType()); 129 auto *RHSType = cast<VectorType>(RHS->getType()); 130 131 auto *ReturnType = 132 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); 133 134 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), 135 B.getInt32(RHSColumns)}; 136 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; 137 138 Function *TheFn = Intrinsic::getDeclaration( 139 getModule(), Intrinsic::matrix_multiply, OverloadedTypes); 140 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 141 } 142 143 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p 144 /// ColumnIdx). 145 Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, 146 Value *ColumnIdx, unsigned NumRows) { 147 return B.CreateInsertElement( 148 Matrix, NewVal, 149 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( 150 ColumnIdx->getType(), NumRows)), 151 RowIdx)); 152 } 153 154 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point 155 /// matrixes. 156 Value *CreateAdd(Value *LHS, Value *RHS) { 157 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 158 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) 159 RHS = B.CreateVectorSplat( 160 cast<VectorType>(LHS->getType())->getNumElements(), RHS, 161 "scalar.splat"); 162 else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) 163 LHS = B.CreateVectorSplat( 164 cast<VectorType>(RHS->getType())->getNumElements(), LHS, 165 "scalar.splat"); 166 167 return cast<VectorType>(LHS->getType()) 168 ->getElementType() 169 ->isFloatingPointTy() 170 ? B.CreateFAdd(LHS, RHS) 171 : B.CreateAdd(LHS, RHS); 172 } 173 174 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating 175 /// point matrixes. 176 Value *CreateSub(Value *LHS, Value *RHS) { 177 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 178 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) 179 RHS = B.CreateVectorSplat( 180 cast<VectorType>(LHS->getType())->getNumElements(), RHS, 181 "scalar.splat"); 182 else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) 183 LHS = B.CreateVectorSplat( 184 cast<VectorType>(RHS->getType())->getNumElements(), LHS, 185 "scalar.splat"); 186 187 return cast<VectorType>(LHS->getType()) 188 ->getElementType() 189 ->isFloatingPointTy() 190 ? B.CreateFSub(LHS, RHS) 191 : B.CreateSub(LHS, RHS); 192 } 193 194 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p 195 /// RHS. 196 Value *CreateScalarMultiply(Value *LHS, Value *RHS) { 197 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); 198 if (LHS->getType()->getScalarType()->isFloatingPointTy()) 199 return B.CreateFMul(LHS, RHS); 200 return B.CreateMul(LHS, RHS); 201 } 202 203 /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. 204 Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, 205 unsigned NumRows, Twine const &Name = "") { 206 207 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), 208 ColumnIdx->getType()->getScalarSizeInBits()); 209 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); 210 RowIdx = B.CreateZExt(RowIdx, IntTy); 211 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); 212 Value *NumRowsV = B.getIntN(MaxWidth, NumRows); 213 return B.CreateExtractElement( 214 Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx), 215 "matext"); 216 } 217 }; 218 219 } // end namespace llvm 220 221 #endif // LLVM_IR_MATRIXBUILDER_H 222