1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MatrixBuilder class, which is used as a convenient way 10 // to lower matrix operations to LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_IR_MATRIXBUILDER_H 15 #define LLVM_IR_MATRIXBUILDER_H 16 17 #include "llvm/IR/Constant.h" 18 #include "llvm/IR/Constants.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InstrTypes.h" 21 #include "llvm/IR/Instruction.h" 22 #include "llvm/IR/IntrinsicInst.h" 23 #include "llvm/IR/Type.h" 24 #include "llvm/IR/Value.h" 25 #include "llvm/Support/Alignment.h" 26 27 namespace llvm { 28 29 class Function; 30 class Twine; 31 class Module; 32 33 class MatrixBuilder { 34 IRBuilderBase &B; getModule()35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } 36 splatScalarOperandIfNeeded(Value * LHS,Value * RHS)37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, 38 Value *RHS) { 39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && 40 "One of the operands must be a matrix (embedded in a vector)"); 41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 42 assert(!isa<ScalableVectorType>(LHS->getType()) && 43 "LHS Assumed to be fixed width"); 44 RHS = B.CreateVectorSplat( 45 cast<VectorType>(LHS->getType())->getElementCount(), RHS, 46 "scalar.splat"); 47 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 48 assert(!isa<ScalableVectorType>(RHS->getType()) && 49 "RHS Assumed to be fixed width"); 50 LHS = B.CreateVectorSplat( 51 cast<VectorType>(RHS->getType())->getElementCount(), LHS, 52 "scalar.splat"); 53 } 54 return {LHS, RHS}; 55 } 56 57 public: MatrixBuilder(IRBuilderBase & Builder)58 MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {} 59 60 /// Create a column major, strided matrix load. 61 /// \p EltTy - Matrix element type 62 /// \p DataPtr - Start address of the matrix read 63 /// \p Rows - Number of rows in matrix (must be a constant) 64 /// \p Columns - Number of columns in matrix (must be a constant) 65 /// \p Stride - Space between columns 66 CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, 67 Value *Stride, bool IsVolatile, unsigned Rows, 68 unsigned Columns, const Twine &Name = "") { 69 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); 70 71 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), 72 B.getInt32(Columns)}; 73 Type *OverloadedTypes[] = {RetType, Stride->getType()}; 74 75 Function *TheFn = Intrinsic::getDeclaration( 76 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); 77 78 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 79 Attribute AlignAttr = 80 Attribute::getWithAlignment(Call->getContext(), Alignment); 81 Call->addParamAttr(0, AlignAttr); 82 return Call; 83 } 84 85 /// Create a column major, strided matrix store. 86 /// \p Matrix - Matrix to store 87 /// \p Ptr - Pointer to write back to 88 /// \p Stride - Space between columns 89 CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, 90 Value *Stride, bool IsVolatile, 91 unsigned Rows, unsigned Columns, 92 const Twine &Name = "") { 93 Value *Ops[] = {Matrix, Ptr, 94 Stride, B.getInt1(IsVolatile), 95 B.getInt32(Rows), B.getInt32(Columns)}; 96 Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()}; 97 98 Function *TheFn = Intrinsic::getDeclaration( 99 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); 100 101 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 102 Attribute AlignAttr = 103 Attribute::getWithAlignment(Call->getContext(), Alignment); 104 Call->addParamAttr(1, AlignAttr); 105 return Call; 106 } 107 108 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows 109 /// rows and \p Columns columns. 110 CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, 111 unsigned Columns, const Twine &Name = "") { 112 auto *OpType = cast<VectorType>(Matrix->getType()); 113 auto *ReturnType = 114 FixedVectorType::get(OpType->getElementType(), Rows * Columns); 115 116 Type *OverloadedTypes[] = {ReturnType}; 117 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; 118 Function *TheFn = Intrinsic::getDeclaration( 119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes); 120 121 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 122 } 123 124 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p 125 /// RHS. 126 CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, 127 unsigned LHSColumns, unsigned RHSColumns, 128 const Twine &Name = "") { 129 auto *LHSType = cast<VectorType>(LHS->getType()); 130 auto *RHSType = cast<VectorType>(RHS->getType()); 131 132 auto *ReturnType = 133 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); 134 135 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), 136 B.getInt32(RHSColumns)}; 137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; 138 139 Function *TheFn = Intrinsic::getDeclaration( 140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes); 141 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 142 } 143 144 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p 145 /// ColumnIdx). CreateMatrixInsert(Value * Matrix,Value * NewVal,Value * RowIdx,Value * ColumnIdx,unsigned NumRows)146 Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, 147 Value *ColumnIdx, unsigned NumRows) { 148 return B.CreateInsertElement( 149 Matrix, NewVal, 150 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( 151 ColumnIdx->getType(), NumRows)), 152 RowIdx)); 153 } 154 155 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point 156 /// matrixes. CreateAdd(Value * LHS,Value * RHS)157 Value *CreateAdd(Value *LHS, Value *RHS) { 158 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 159 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 160 assert(!isa<ScalableVectorType>(LHS->getType()) && 161 "LHS Assumed to be fixed width"); 162 RHS = B.CreateVectorSplat( 163 cast<VectorType>(LHS->getType())->getElementCount(), RHS, 164 "scalar.splat"); 165 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 166 assert(!isa<ScalableVectorType>(RHS->getType()) && 167 "RHS Assumed to be fixed width"); 168 LHS = B.CreateVectorSplat( 169 cast<VectorType>(RHS->getType())->getElementCount(), LHS, 170 "scalar.splat"); 171 } 172 173 return cast<VectorType>(LHS->getType()) 174 ->getElementType() 175 ->isFloatingPointTy() 176 ? B.CreateFAdd(LHS, RHS) 177 : B.CreateAdd(LHS, RHS); 178 } 179 180 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating 181 /// point matrixes. CreateSub(Value * LHS,Value * RHS)182 Value *CreateSub(Value *LHS, Value *RHS) { 183 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 184 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 185 assert(!isa<ScalableVectorType>(LHS->getType()) && 186 "LHS Assumed to be fixed width"); 187 RHS = B.CreateVectorSplat( 188 cast<VectorType>(LHS->getType())->getElementCount(), RHS, 189 "scalar.splat"); 190 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 191 assert(!isa<ScalableVectorType>(RHS->getType()) && 192 "RHS Assumed to be fixed width"); 193 LHS = B.CreateVectorSplat( 194 cast<VectorType>(RHS->getType())->getElementCount(), LHS, 195 "scalar.splat"); 196 } 197 198 return cast<VectorType>(LHS->getType()) 199 ->getElementType() 200 ->isFloatingPointTy() 201 ? B.CreateFSub(LHS, RHS) 202 : B.CreateSub(LHS, RHS); 203 } 204 205 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p 206 /// RHS. CreateScalarMultiply(Value * LHS,Value * RHS)207 Value *CreateScalarMultiply(Value *LHS, Value *RHS) { 208 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); 209 if (LHS->getType()->getScalarType()->isFloatingPointTy()) 210 return B.CreateFMul(LHS, RHS); 211 return B.CreateMul(LHS, RHS); 212 } 213 214 /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p 215 /// IsUnsigned indicates whether UDiv or SDiv should be used. CreateScalarDiv(Value * LHS,Value * RHS,bool IsUnsigned)216 Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { 217 assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); 218 assert(!isa<ScalableVectorType>(LHS->getType()) && 219 "LHS Assumed to be fixed width"); 220 RHS = 221 B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(), 222 RHS, "scalar.splat"); 223 return cast<VectorType>(LHS->getType()) 224 ->getElementType() 225 ->isFloatingPointTy() 226 ? B.CreateFDiv(LHS, RHS) 227 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); 228 } 229 230 /// Create an assumption that \p Idx is less than \p NumElements. 231 void CreateIndexAssumption(Value *Idx, unsigned NumElements, 232 Twine const &Name = "") { 233 Value *NumElts = 234 B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements); 235 auto *Cmp = B.CreateICmpULT(Idx, NumElts); 236 if (isa<ConstantInt>(Cmp)) 237 assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!"); 238 else 239 B.CreateAssumption(Cmp); 240 } 241 242 /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from 243 /// a matrix with \p NumRows embedded in a vector. 244 Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, 245 Twine const &Name = "") { 246 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), 247 ColumnIdx->getType()->getScalarSizeInBits()); 248 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); 249 RowIdx = B.CreateZExt(RowIdx, IntTy); 250 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); 251 Value *NumRowsV = B.getIntN(MaxWidth, NumRows); 252 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx); 253 } 254 }; 255 256 } // end namespace llvm 257 258 #endif // LLVM_IR_MATRIXBUILDER_H 259