1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
22 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
23 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
24 using Log10OpLowering =
25     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
26 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
27 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
28 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
29 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
30 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
31 
32 // A `expm1` is converted into `exp - 1`.
33 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
34   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
35 
36   LogicalResult
matchAndRewrite__anon7c58be4b0111::ExpM1OpLowering37   matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
38                   ConversionPatternRewriter &rewriter) const override {
39     math::ExpM1Op::Adaptor transformed(operands);
40     auto operandType = transformed.operand().getType();
41 
42     if (!operandType || !LLVM::isCompatibleType(operandType))
43       return failure();
44 
45     auto loc = op.getLoc();
46     auto resultType = op.getResult().getType();
47     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
48     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
49 
50     if (!operandType.isa<LLVM::LLVMArrayType>()) {
51       LLVM::ConstantOp one;
52       if (LLVM::isCompatibleVectorType(operandType)) {
53         one = rewriter.create<LLVM::ConstantOp>(
54             loc, operandType,
55             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
56       } else {
57         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
58       }
59       auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
60       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
61       return success();
62     }
63 
64     auto vectorType = resultType.dyn_cast<VectorType>();
65     if (!vectorType)
66       return rewriter.notifyMatchFailure(op, "expected vector result type");
67 
68     return LLVM::detail::handleMultidimensionalVectors(
69         op.getOperation(), operands, *getTypeConverter(),
70         [&](Type llvm1DVectorTy, ValueRange operands) {
71           auto splatAttr = SplatElementsAttr::get(
72               mlir::VectorType::get(
73                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
74                   floatType),
75               floatOne);
76           auto one =
77               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
78           auto exp =
79               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
80           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
81         },
82         rewriter);
83   }
84 };
85 
86 // A `log1p` is converted into `log(1 + ...)`.
87 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
88   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
89 
90   LogicalResult
matchAndRewrite__anon7c58be4b0111::Log1pOpLowering91   matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
92                   ConversionPatternRewriter &rewriter) const override {
93     math::Log1pOp::Adaptor transformed(operands);
94     auto operandType = transformed.operand().getType();
95 
96     if (!operandType || !LLVM::isCompatibleType(operandType))
97       return rewriter.notifyMatchFailure(op, "unsupported operand type");
98 
99     auto loc = op.getLoc();
100     auto resultType = op.getResult().getType();
101     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
102     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
103 
104     if (!operandType.isa<LLVM::LLVMArrayType>()) {
105       LLVM::ConstantOp one =
106           LLVM::isCompatibleVectorType(operandType)
107               ? rewriter.create<LLVM::ConstantOp>(
108                     loc, operandType,
109                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
110                                            floatOne))
111               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
112 
113       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
114                                                transformed.operand());
115       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
116       return success();
117     }
118 
119     auto vectorType = resultType.dyn_cast<VectorType>();
120     if (!vectorType)
121       return rewriter.notifyMatchFailure(op, "expected vector result type");
122 
123     return LLVM::detail::handleMultidimensionalVectors(
124         op.getOperation(), operands, *getTypeConverter(),
125         [&](Type llvm1DVectorTy, ValueRange operands) {
126           auto splatAttr = SplatElementsAttr::get(
127               mlir::VectorType::get(
128                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
129                   floatType),
130               floatOne);
131           auto one =
132               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
133           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
134                                                    operands[0]);
135           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
136         },
137         rewriter);
138   }
139 };
140 
141 // A `rsqrt` is converted into `1 / sqrt`.
142 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
143   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
144 
145   LogicalResult
matchAndRewrite__anon7c58be4b0111::RsqrtOpLowering146   matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
147                   ConversionPatternRewriter &rewriter) const override {
148     math::RsqrtOp::Adaptor transformed(operands);
149     auto operandType = transformed.operand().getType();
150 
151     if (!operandType || !LLVM::isCompatibleType(operandType))
152       return failure();
153 
154     auto loc = op.getLoc();
155     auto resultType = op.getResult().getType();
156     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
157     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
158 
159     if (!operandType.isa<LLVM::LLVMArrayType>()) {
160       LLVM::ConstantOp one;
161       if (LLVM::isCompatibleVectorType(operandType)) {
162         one = rewriter.create<LLVM::ConstantOp>(
163             loc, operandType,
164             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
165       } else {
166         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
167       }
168       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
169       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
170       return success();
171     }
172 
173     auto vectorType = resultType.dyn_cast<VectorType>();
174     if (!vectorType)
175       return failure();
176 
177     return LLVM::detail::handleMultidimensionalVectors(
178         op.getOperation(), operands, *getTypeConverter(),
179         [&](Type llvm1DVectorTy, ValueRange operands) {
180           auto splatAttr = SplatElementsAttr::get(
181               mlir::VectorType::get(
182                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
183                   floatType),
184               floatOne);
185           auto one =
186               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
187           auto sqrt =
188               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
189           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
190         },
191         rewriter);
192   }
193 };
194 
195 struct ConvertMathToLLVMPass
196     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
197   ConvertMathToLLVMPass() = default;
198 
runOnFunction__anon7c58be4b0111::ConvertMathToLLVMPass199   void runOnFunction() override {
200     RewritePatternSet patterns(&getContext());
201     LLVMTypeConverter converter(&getContext());
202     populateMathToLLVMConversionPatterns(converter, patterns);
203     LLVMConversionTarget target(getContext());
204     if (failed(
205             applyPartialConversion(getFunction(), target, std::move(patterns))))
206       signalPassFailure();
207   }
208 };
209 } // namespace
210 
populateMathToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)211 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
212                                                 RewritePatternSet &patterns) {
213   // clang-format off
214   patterns.add<
215     CosOpLowering,
216     ExpOpLowering,
217     Exp2OpLowering,
218     ExpM1OpLowering,
219     Log10OpLowering,
220     Log1pOpLowering,
221     Log2OpLowering,
222     LogOpLowering,
223     PowFOpLowering,
224     RsqrtOpLowering,
225     SinOpLowering,
226     SqrtOpLowering
227   >(converter);
228   // clang-format on
229 }
230 
createConvertMathToLLVMPass()231 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
232   return std::make_unique<ConvertMathToLLVMPass>();
233 }
234