1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16
17 using namespace mlir;
18 using namespace mlir::LLVM;
19
20 //===----------------------------------------------------------------------===//
21 // ComplexStructBuilder implementation.
22 //===----------------------------------------------------------------------===//
23
// Field indices used by ComplexStructBuilder to address the real and the
// imaginary component inside the lowered complex-number struct.
static constexpr unsigned kRealPosInComplexNumberStruct{0};
static constexpr unsigned kImaginaryPosInComplexNumberStruct{
    kRealPosInComplexNumberStruct + 1};
26
undef(OpBuilder & builder,Location loc,Type type)27 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
28 Location loc, Type type) {
29 Value val = builder.create<LLVM::UndefOp>(loc, type);
30 return ComplexStructBuilder(val);
31 }
32
setReal(OpBuilder & builder,Location loc,Value real)33 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
34 Value real) {
35 setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
36 }
37
real(OpBuilder & builder,Location loc)38 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
39 return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
40 }
41
setImaginary(OpBuilder & builder,Location loc,Value imaginary)42 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
43 Value imaginary) {
44 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
45 }
46
imaginary(OpBuilder & builder,Location loc)47 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
48 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
49 }
50
51 //===----------------------------------------------------------------------===//
52 // Conversion patterns.
53 //===----------------------------------------------------------------------===//
54
55 namespace {
56
57 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
58 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
59
60 LogicalResult
matchAndRewrite__anon199bdef30111::AbsOpConversion61 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
62 ConversionPatternRewriter &rewriter) const override {
63 complex::AbsOp::Adaptor transformed(operands);
64 auto loc = op.getLoc();
65
66 ComplexStructBuilder complexStruct(transformed.complex());
67 Value real = complexStruct.real(rewriter, op.getLoc());
68 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
69
70 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
71 Value sqNorm = rewriter.create<LLVM::FAddOp>(
72 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
73 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
74
75 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
76 return success();
77 }
78 };
79
80 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
81 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
82
83 LogicalResult
matchAndRewrite__anon199bdef30111::CreateOpConversion84 matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
85 ConversionPatternRewriter &rewriter) const override {
86 complex::CreateOp::Adaptor transformed(operands);
87
88 // Pack real and imaginary part in a complex number struct.
89 auto loc = complexOp.getLoc();
90 auto structType = typeConverter->convertType(complexOp.getType());
91 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
92 complexStruct.setReal(rewriter, loc, transformed.real());
93 complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
94
95 rewriter.replaceOp(complexOp, {complexStruct});
96 return success();
97 }
98 };
99
100 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
101 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
102
103 LogicalResult
matchAndRewrite__anon199bdef30111::ReOpConversion104 matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
105 ConversionPatternRewriter &rewriter) const override {
106 complex::ReOp::Adaptor transformed(operands);
107
108 // Extract real part from the complex number struct.
109 ComplexStructBuilder complexStruct(transformed.complex());
110 Value real = complexStruct.real(rewriter, op.getLoc());
111 rewriter.replaceOp(op, real);
112
113 return success();
114 }
115 };
116
117 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
118 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
119
120 LogicalResult
matchAndRewrite__anon199bdef30111::ImOpConversion121 matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
122 ConversionPatternRewriter &rewriter) const override {
123 complex::ImOp::Adaptor transformed(operands);
124
125 // Extract imaginary part from the complex number struct.
126 ComplexStructBuilder complexStruct(transformed.complex());
127 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
128 rewriter.replaceOp(op, imaginary);
129
130 return success();
131 }
132 };
133
134 struct BinaryComplexOperands {
135 std::complex<Value> lhs;
136 std::complex<Value> rhs;
137 };
138
139 template <typename OpTy>
140 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)141 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
142 ConversionPatternRewriter &rewriter) {
143 auto loc = op.getLoc();
144 typename OpTy::Adaptor transformed(operands);
145
146 // Extract real and imaginary values from operands.
147 BinaryComplexOperands unpacked;
148 ComplexStructBuilder lhs(transformed.lhs());
149 unpacked.lhs.real(lhs.real(rewriter, loc));
150 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
151 ComplexStructBuilder rhs(transformed.rhs());
152 unpacked.rhs.real(rhs.real(rewriter, loc));
153 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
154
155 return unpacked;
156 }
157
158 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
159 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
160
161 LogicalResult
matchAndRewrite__anon199bdef30111::AddOpConversion162 matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
163 ConversionPatternRewriter &rewriter) const override {
164 auto loc = op.getLoc();
165 BinaryComplexOperands arg =
166 unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
167
168 // Initialize complex number struct for result.
169 auto structType = typeConverter->convertType(op.getType());
170 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
171
172 // Emit IR to add complex numbers.
173 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
174 Value real =
175 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
176 Value imag =
177 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
178 result.setReal(rewriter, loc, real);
179 result.setImaginary(rewriter, loc, imag);
180
181 rewriter.replaceOp(op, {result});
182 return success();
183 }
184 };
185
186 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
187 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
188
189 LogicalResult
matchAndRewrite__anon199bdef30111::DivOpConversion190 matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
191 ConversionPatternRewriter &rewriter) const override {
192 auto loc = op.getLoc();
193 BinaryComplexOperands arg =
194 unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
195
196 // Initialize complex number struct for result.
197 auto structType = typeConverter->convertType(op.getType());
198 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
199
200 // Emit IR to add complex numbers.
201 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
202 Value rhsRe = arg.rhs.real();
203 Value rhsIm = arg.rhs.imag();
204 Value lhsRe = arg.lhs.real();
205 Value lhsIm = arg.lhs.imag();
206
207 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
208 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
209 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
210
211 Value resultReal = rewriter.create<LLVM::FAddOp>(
212 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
213 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
214
215 Value resultImag = rewriter.create<LLVM::FSubOp>(
216 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
217 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
218
219 result.setReal(
220 rewriter, loc,
221 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
222 result.setImaginary(
223 rewriter, loc,
224 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
225
226 rewriter.replaceOp(op, {result});
227 return success();
228 }
229 };
230
231 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
232 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
233
234 LogicalResult
matchAndRewrite__anon199bdef30111::MulOpConversion235 matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
236 ConversionPatternRewriter &rewriter) const override {
237 auto loc = op.getLoc();
238 BinaryComplexOperands arg =
239 unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
240
241 // Initialize complex number struct for result.
242 auto structType = typeConverter->convertType(op.getType());
243 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
244
245 // Emit IR to add complex numbers.
246 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
247 Value rhsRe = arg.rhs.real();
248 Value rhsIm = arg.rhs.imag();
249 Value lhsRe = arg.lhs.real();
250 Value lhsIm = arg.lhs.imag();
251
252 Value real = rewriter.create<LLVM::FSubOp>(
253 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
254 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
255
256 Value imag = rewriter.create<LLVM::FAddOp>(
257 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
258 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
259
260 result.setReal(rewriter, loc, real);
261 result.setImaginary(rewriter, loc, imag);
262
263 rewriter.replaceOp(op, {result});
264 return success();
265 }
266 };
267
268 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
269 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
270
271 LogicalResult
matchAndRewrite__anon199bdef30111::SubOpConversion272 matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
273 ConversionPatternRewriter &rewriter) const override {
274 auto loc = op.getLoc();
275 BinaryComplexOperands arg =
276 unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
277
278 // Initialize complex number struct for result.
279 auto structType = typeConverter->convertType(op.getType());
280 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
281
282 // Emit IR to substract complex numbers.
283 auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
284 Value real =
285 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
286 Value imag =
287 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
288 result.setReal(rewriter, loc, real);
289 result.setImaginary(rewriter, loc, imag);
290
291 rewriter.replaceOp(op, {result});
292 return success();
293 }
294 };
295 } // namespace
296
populateComplexToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)297 void mlir::populateComplexToLLVMConversionPatterns(
298 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
299 // clang-format off
300 patterns.add<
301 AbsOpConversion,
302 AddOpConversion,
303 CreateOpConversion,
304 DivOpConversion,
305 ImOpConversion,
306 MulOpConversion,
307 ReOpConversion,
308 SubOpConversion
309 >(converter);
310 // clang-format on
311 }
312
313 namespace {
314 struct ConvertComplexToLLVMPass
315 : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
316 void runOnOperation() override;
317 };
318 } // namespace
319
runOnOperation()320 void ConvertComplexToLLVMPass::runOnOperation() {
321 auto module = getOperation();
322
323 // Convert to the LLVM IR dialect using the converter defined above.
324 RewritePatternSet patterns(&getContext());
325 LLVMTypeConverter converter(&getContext());
326 populateComplexToLLVMConversionPatterns(converter, patterns);
327
328 LLVMConversionTarget target(getContext());
329 target.addLegalOp<ModuleOp, FuncOp>();
330 target.addIllegalDialect<complex::ComplexDialect>();
331 if (failed(applyPartialConversion(module, target, std::move(patterns))))
332 signalPassFailure();
333 }
334
335 std::unique_ptr<OperationPass<ModuleOp>>
createConvertComplexToLLVMPass()336 mlir::createConvertComplexToLLVMPass() {
337 return std::make_unique<ConvertComplexToLLVMPass>();
338 }
339