1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17
18 using namespace mlir;
19
20 namespace {
21 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
22 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
23 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
24 using Log10OpLowering =
25 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
26 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
27 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
28 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
29 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
30 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
31
32 // A `expm1` is converted into `exp - 1`.
33 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
34 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
35
36 LogicalResult
matchAndRewrite__anon7c58be4b0111::ExpM1OpLowering37 matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
38 ConversionPatternRewriter &rewriter) const override {
39 math::ExpM1Op::Adaptor transformed(operands);
40 auto operandType = transformed.operand().getType();
41
42 if (!operandType || !LLVM::isCompatibleType(operandType))
43 return failure();
44
45 auto loc = op.getLoc();
46 auto resultType = op.getResult().getType();
47 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
48 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
49
50 if (!operandType.isa<LLVM::LLVMArrayType>()) {
51 LLVM::ConstantOp one;
52 if (LLVM::isCompatibleVectorType(operandType)) {
53 one = rewriter.create<LLVM::ConstantOp>(
54 loc, operandType,
55 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
56 } else {
57 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
58 }
59 auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
60 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
61 return success();
62 }
63
64 auto vectorType = resultType.dyn_cast<VectorType>();
65 if (!vectorType)
66 return rewriter.notifyMatchFailure(op, "expected vector result type");
67
68 return LLVM::detail::handleMultidimensionalVectors(
69 op.getOperation(), operands, *getTypeConverter(),
70 [&](Type llvm1DVectorTy, ValueRange operands) {
71 auto splatAttr = SplatElementsAttr::get(
72 mlir::VectorType::get(
73 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
74 floatType),
75 floatOne);
76 auto one =
77 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
78 auto exp =
79 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
80 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
81 },
82 rewriter);
83 }
84 };
85
86 // A `log1p` is converted into `log(1 + ...)`.
87 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
88 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
89
90 LogicalResult
matchAndRewrite__anon7c58be4b0111::Log1pOpLowering91 matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
92 ConversionPatternRewriter &rewriter) const override {
93 math::Log1pOp::Adaptor transformed(operands);
94 auto operandType = transformed.operand().getType();
95
96 if (!operandType || !LLVM::isCompatibleType(operandType))
97 return rewriter.notifyMatchFailure(op, "unsupported operand type");
98
99 auto loc = op.getLoc();
100 auto resultType = op.getResult().getType();
101 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
102 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
103
104 if (!operandType.isa<LLVM::LLVMArrayType>()) {
105 LLVM::ConstantOp one =
106 LLVM::isCompatibleVectorType(operandType)
107 ? rewriter.create<LLVM::ConstantOp>(
108 loc, operandType,
109 SplatElementsAttr::get(resultType.cast<ShapedType>(),
110 floatOne))
111 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
112
113 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
114 transformed.operand());
115 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
116 return success();
117 }
118
119 auto vectorType = resultType.dyn_cast<VectorType>();
120 if (!vectorType)
121 return rewriter.notifyMatchFailure(op, "expected vector result type");
122
123 return LLVM::detail::handleMultidimensionalVectors(
124 op.getOperation(), operands, *getTypeConverter(),
125 [&](Type llvm1DVectorTy, ValueRange operands) {
126 auto splatAttr = SplatElementsAttr::get(
127 mlir::VectorType::get(
128 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
129 floatType),
130 floatOne);
131 auto one =
132 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
133 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
134 operands[0]);
135 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
136 },
137 rewriter);
138 }
139 };
140
141 // A `rsqrt` is converted into `1 / sqrt`.
142 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
143 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
144
145 LogicalResult
matchAndRewrite__anon7c58be4b0111::RsqrtOpLowering146 matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
147 ConversionPatternRewriter &rewriter) const override {
148 math::RsqrtOp::Adaptor transformed(operands);
149 auto operandType = transformed.operand().getType();
150
151 if (!operandType || !LLVM::isCompatibleType(operandType))
152 return failure();
153
154 auto loc = op.getLoc();
155 auto resultType = op.getResult().getType();
156 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
157 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
158
159 if (!operandType.isa<LLVM::LLVMArrayType>()) {
160 LLVM::ConstantOp one;
161 if (LLVM::isCompatibleVectorType(operandType)) {
162 one = rewriter.create<LLVM::ConstantOp>(
163 loc, operandType,
164 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
165 } else {
166 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
167 }
168 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
169 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
170 return success();
171 }
172
173 auto vectorType = resultType.dyn_cast<VectorType>();
174 if (!vectorType)
175 return failure();
176
177 return LLVM::detail::handleMultidimensionalVectors(
178 op.getOperation(), operands, *getTypeConverter(),
179 [&](Type llvm1DVectorTy, ValueRange operands) {
180 auto splatAttr = SplatElementsAttr::get(
181 mlir::VectorType::get(
182 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
183 floatType),
184 floatOne);
185 auto one =
186 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
187 auto sqrt =
188 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
189 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
190 },
191 rewriter);
192 }
193 };
194
195 struct ConvertMathToLLVMPass
196 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
197 ConvertMathToLLVMPass() = default;
198
runOnFunction__anon7c58be4b0111::ConvertMathToLLVMPass199 void runOnFunction() override {
200 RewritePatternSet patterns(&getContext());
201 LLVMTypeConverter converter(&getContext());
202 populateMathToLLVMConversionPatterns(converter, patterns);
203 LLVMConversionTarget target(getContext());
204 if (failed(
205 applyPartialConversion(getFunction(), target, std::move(patterns))))
206 signalPassFailure();
207 }
208 };
209 } // namespace
210
populateMathToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)211 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
212 RewritePatternSet &patterns) {
213 // clang-format off
214 patterns.add<
215 CosOpLowering,
216 ExpOpLowering,
217 Exp2OpLowering,
218 ExpM1OpLowering,
219 Log10OpLowering,
220 Log1pOpLowering,
221 Log2OpLowering,
222 LogOpLowering,
223 PowFOpLowering,
224 RsqrtOpLowering,
225 SinOpLowering,
226 SqrtOpLowering
227 >(converter);
228 // clang-format on
229 }
230
createConvertMathToLLVMPass()231 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
232 return std::make_unique<ConvertMathToLLVMPass>();
233 }
234