1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16
17 using namespace mlir;
18 using namespace mlir::LLVM;
19
20 //===----------------------------------------------------------------------===//
21 // ComplexStructBuilder implementation.
22 //===----------------------------------------------------------------------===//
23
// Field indices of the real and imaginary components inside the struct value
// wrapped by ComplexStructBuilder; these are the positions passed to
// setPtr/extractPtr by the accessors that follow.
constexpr unsigned kRealPosInComplexNumberStruct = 0;
constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
26
undef(OpBuilder & builder,Location loc,Type type)27 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
28 Location loc, Type type) {
29 Value val = builder.create<LLVM::UndefOp>(loc, type);
30 return ComplexStructBuilder(val);
31 }
32
setReal(OpBuilder & builder,Location loc,Value real)33 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
34 Value real) {
35 setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
36 }
37
real(OpBuilder & builder,Location loc)38 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
39 return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
40 }
41
setImaginary(OpBuilder & builder,Location loc,Value imaginary)42 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
43 Value imaginary) {
44 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
45 }
46
imaginary(OpBuilder & builder,Location loc)47 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
48 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
49 }
50
51 //===----------------------------------------------------------------------===//
52 // Conversion patterns.
53 //===----------------------------------------------------------------------===//
54
55 namespace {
56
57 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
58 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
59
60 LogicalResult
matchAndRewrite__anon33076c320111::AbsOpConversion61 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter) const override {
63 auto loc = op.getLoc();
64
65 ComplexStructBuilder complexStruct(adaptor.complex());
66 Value real = complexStruct.real(rewriter, op.getLoc());
67 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
68
69 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
70 Value sqNorm = rewriter.create<LLVM::FAddOp>(
71 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
72 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
73
74 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
75 return success();
76 }
77 };
78
79 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
80 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
81
82 LogicalResult
matchAndRewrite__anon33076c320111::CreateOpConversion83 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter) const override {
85 // Pack real and imaginary part in a complex number struct.
86 auto loc = complexOp.getLoc();
87 auto structType = typeConverter->convertType(complexOp.getType());
88 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
89 complexStruct.setReal(rewriter, loc, adaptor.real());
90 complexStruct.setImaginary(rewriter, loc, adaptor.imaginary());
91
92 rewriter.replaceOp(complexOp, {complexStruct});
93 return success();
94 }
95 };
96
97 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
98 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
99
100 LogicalResult
matchAndRewrite__anon33076c320111::ReOpConversion101 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const override {
103 // Extract real part from the complex number struct.
104 ComplexStructBuilder complexStruct(adaptor.complex());
105 Value real = complexStruct.real(rewriter, op.getLoc());
106 rewriter.replaceOp(op, real);
107
108 return success();
109 }
110 };
111
112 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
113 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
114
115 LogicalResult
matchAndRewrite__anon33076c320111::ImOpConversion116 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
117 ConversionPatternRewriter &rewriter) const override {
118 // Extract imaginary part from the complex number struct.
119 ComplexStructBuilder complexStruct(adaptor.complex());
120 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
121 rewriter.replaceOp(op, imaginary);
122
123 return success();
124 }
125 };
126
127 struct BinaryComplexOperands {
128 std::complex<Value> lhs;
129 std::complex<Value> rhs;
130 };
131
132 template <typename OpTy>
133 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter)134 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
135 ConversionPatternRewriter &rewriter) {
136 auto loc = op.getLoc();
137
138 // Extract real and imaginary values from operands.
139 BinaryComplexOperands unpacked;
140 ComplexStructBuilder lhs(adaptor.lhs());
141 unpacked.lhs.real(lhs.real(rewriter, loc));
142 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
143 ComplexStructBuilder rhs(adaptor.rhs());
144 unpacked.rhs.real(rhs.real(rewriter, loc));
145 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
146
147 return unpacked;
148 }
149
150 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
151 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
152
153 LogicalResult
matchAndRewrite__anon33076c320111::AddOpConversion154 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
155 ConversionPatternRewriter &rewriter) const override {
156 auto loc = op.getLoc();
157 BinaryComplexOperands arg =
158 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
159
160 // Initialize complex number struct for result.
161 auto structType = typeConverter->convertType(op.getType());
162 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
163
164 // Emit IR to add complex numbers.
165 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
166 Value real =
167 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
168 Value imag =
169 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
170 result.setReal(rewriter, loc, real);
171 result.setImaginary(rewriter, loc, imag);
172
173 rewriter.replaceOp(op, {result});
174 return success();
175 }
176 };
177
178 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
179 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
180
181 LogicalResult
matchAndRewrite__anon33076c320111::DivOpConversion182 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 auto loc = op.getLoc();
185 BinaryComplexOperands arg =
186 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
187
188 // Initialize complex number struct for result.
189 auto structType = typeConverter->convertType(op.getType());
190 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
191
192 // Emit IR to add complex numbers.
193 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
194 Value rhsRe = arg.rhs.real();
195 Value rhsIm = arg.rhs.imag();
196 Value lhsRe = arg.lhs.real();
197 Value lhsIm = arg.lhs.imag();
198
199 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
200 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
201 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
202
203 Value resultReal = rewriter.create<LLVM::FAddOp>(
204 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
205 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
206
207 Value resultImag = rewriter.create<LLVM::FSubOp>(
208 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
209 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
210
211 result.setReal(
212 rewriter, loc,
213 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
214 result.setImaginary(
215 rewriter, loc,
216 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
217
218 rewriter.replaceOp(op, {result});
219 return success();
220 }
221 };
222
223 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
224 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
225
226 LogicalResult
matchAndRewrite__anon33076c320111::MulOpConversion227 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override {
229 auto loc = op.getLoc();
230 BinaryComplexOperands arg =
231 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
232
233 // Initialize complex number struct for result.
234 auto structType = typeConverter->convertType(op.getType());
235 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
236
237 // Emit IR to add complex numbers.
238 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
239 Value rhsRe = arg.rhs.real();
240 Value rhsIm = arg.rhs.imag();
241 Value lhsRe = arg.lhs.real();
242 Value lhsIm = arg.lhs.imag();
243
244 Value real = rewriter.create<LLVM::FSubOp>(
245 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
246 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
247
248 Value imag = rewriter.create<LLVM::FAddOp>(
249 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
250 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
251
252 result.setReal(rewriter, loc, real);
253 result.setImaginary(rewriter, loc, imag);
254
255 rewriter.replaceOp(op, {result});
256 return success();
257 }
258 };
259
260 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
261 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
262
263 LogicalResult
matchAndRewrite__anon33076c320111::SubOpConversion264 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const override {
266 auto loc = op.getLoc();
267 BinaryComplexOperands arg =
268 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
269
270 // Initialize complex number struct for result.
271 auto structType = typeConverter->convertType(op.getType());
272 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
273
274 // Emit IR to substract complex numbers.
275 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
276 Value real =
277 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
278 Value imag =
279 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
280 result.setReal(rewriter, loc, real);
281 result.setImaginary(rewriter, loc, imag);
282
283 rewriter.replaceOp(op, {result});
284 return success();
285 }
286 };
287 } // namespace
288
populateComplexToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)289 void mlir::populateComplexToLLVMConversionPatterns(
290 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
291 // clang-format off
292 patterns.add<
293 AbsOpConversion,
294 AddOpConversion,
295 CreateOpConversion,
296 DivOpConversion,
297 ImOpConversion,
298 MulOpConversion,
299 ReOpConversion,
300 SubOpConversion
301 >(converter);
302 // clang-format on
303 }
304
305 namespace {
306 struct ConvertComplexToLLVMPass
307 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
308 void runOnOperation() override;
309 };
310 } // namespace
311
runOnOperation()312 void ConvertComplexToLLVMPass::runOnOperation() {
313 auto module = getOperation();
314
315 // Convert to the LLVM IR dialect using the converter defined above.
316 RewritePatternSet patterns(&getContext());
317 LLVMTypeConverter converter(&getContext());
318 populateComplexToLLVMConversionPatterns(converter, patterns);
319
320 LLVMConversionTarget target(getContext());
321 target.addLegalOp<ModuleOp, FuncOp>();
322 target.addIllegalDialect<complex::ComplexDialect>();
323 if (failed(applyPartialConversion(module, target, std::move(patterns))))
324 signalPassFailure();
325 }
326
327 std::unique_ptr<OperationPass<ModuleOp>>
createConvertComplexToLLVMPass()328 mlir::createConvertComplexToLLVMPass() {
329 return std::make_unique<ConvertComplexToLLVMPass>();
330 }
331