1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/GPU/GPUDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/Ops.h" 15 #include "mlir/IR/Builders.h" 16 17 namespace mlir { 18 19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` 20 /// depending on the element type that Op operates upon. The function 21 /// declaration is added in case it was not added before. 22 /// 23 /// Example with NVVM: 24 /// %exp_f32 = std.exp %arg_f32 : f32 25 /// 26 /// will be transformed into 27 /// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float 28 template <typename SourceOp> 29 struct OpToFuncCallLowering : public LLVMOpLowering { 30 public: OpToFuncCallLoweringOpToFuncCallLowering31 explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, 32 StringRef f64Func) 33 : LLVMOpLowering(SourceOp::getOperationName(), 34 lowering_.getDialect()->getContext(), lowering_), 35 f32Func(f32Func), f64Func(f64Func) {} 36 37 PatternMatchResult matchAndRewriteOpToFuncCallLowering38 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 39 ConversionPatternRewriter &rewriter) const override { 40 using LLVM::LLVMFuncOp; 41 using LLVM::LLVMType; 42 43 static_assert( 44 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 45 "expected single result op"); 46 47 LLVMType resultType = lowering.convertType(op->getResult(0).getType()) 48 .template cast<LLVM::LLVMType>(); 49 LLVMType funcType = getFunctionType(resultType, operands); 50 StringRef funcName = getFunctionName(resultType); 51 if (funcName.empty()) 52 return matchFailure(); 53 54 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); 55 auto callOp = rewriter.create<LLVM::CallOp>( 56 op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands); 57 rewriter.replaceOp(op, {callOp.getResult(0)}); 58 return matchSuccess(); 59 } 60 61 private: getFunctionTypeOpToFuncCallLowering62 LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, 63 ArrayRef<Value> operands) const { 64 using LLVM::LLVMType; 65 SmallVector<LLVMType, 1> operandTypes; 66 for (Value operand : operands) { 67 operandTypes.push_back(operand.getType().cast<LLVMType>()); 68 } 69 return LLVMType::getFunctionTy(resultType, operandTypes, 70 /*isVarArg=*/false); 71 } 72 getFunctionNameOpToFuncCallLowering73 StringRef getFunctionName(LLVM::LLVMType type) const { 74 if (type.isFloatTy()) 75 return f32Func; 76 if (type.isDoubleTy()) 77 return f64Func; 78 return ""; 79 } 80 appendOrGetFuncOpOpToFuncCallLowering81 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, 82 LLVM::LLVMType funcType, 83 Operation *op) const { 84 using LLVM::LLVMFuncOp; 85 86 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName); 87 if (funcOp) 88 return cast<LLVMFuncOp>(*funcOp); 89 90 mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>()); 91 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); 92 } 93 94 const std::string f32Func; 95 const std::string f64Func; 96 }; 97 98 } // namespace mlir 99 100 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 101