1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 
17 using namespace mlir;
18 using namespace mlir::LLVM;
19 
20 //===----------------------------------------------------------------------===//
21 // ComplexStructBuilder implementation.
22 //===----------------------------------------------------------------------===//
23 
24 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
25 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
26 
undef(OpBuilder & builder,Location loc,Type type)27 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
28                                                  Location loc, Type type) {
29   Value val = builder.create<LLVM::UndefOp>(loc, type);
30   return ComplexStructBuilder(val);
31 }
32 
setReal(OpBuilder & builder,Location loc,Value real)33 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
34                                    Value real) {
35   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
36 }
37 
real(OpBuilder & builder,Location loc)38 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
39   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
40 }
41 
setImaginary(OpBuilder & builder,Location loc,Value imaginary)42 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
43                                         Value imaginary) {
44   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
45 }
46 
imaginary(OpBuilder & builder,Location loc)47 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
48   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // Conversion patterns.
53 //===----------------------------------------------------------------------===//
54 
55 namespace {
56 
57 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
58   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
59 
60   LogicalResult
matchAndRewrite__anon199bdef30111::AbsOpConversion61   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override {
63     complex::AbsOp::Adaptor transformed(operands);
64     auto loc = op.getLoc();
65 
66     ComplexStructBuilder complexStruct(transformed.complex());
67     Value real = complexStruct.real(rewriter, op.getLoc());
68     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
69 
70     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
71     Value sqNorm = rewriter.create<LLVM::FAddOp>(
72         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
73         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
74 
75     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
76     return success();
77   }
78 };
79 
80 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
81   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
82 
83   LogicalResult
matchAndRewrite__anon199bdef30111::CreateOpConversion84   matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
85                   ConversionPatternRewriter &rewriter) const override {
86     complex::CreateOp::Adaptor transformed(operands);
87 
88     // Pack real and imaginary part in a complex number struct.
89     auto loc = complexOp.getLoc();
90     auto structType = typeConverter->convertType(complexOp.getType());
91     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
92     complexStruct.setReal(rewriter, loc, transformed.real());
93     complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
94 
95     rewriter.replaceOp(complexOp, {complexStruct});
96     return success();
97   }
98 };
99 
100 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
101   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
102 
103   LogicalResult
matchAndRewrite__anon199bdef30111::ReOpConversion104   matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
105                   ConversionPatternRewriter &rewriter) const override {
106     complex::ReOp::Adaptor transformed(operands);
107 
108     // Extract real part from the complex number struct.
109     ComplexStructBuilder complexStruct(transformed.complex());
110     Value real = complexStruct.real(rewriter, op.getLoc());
111     rewriter.replaceOp(op, real);
112 
113     return success();
114   }
115 };
116 
117 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
118   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
119 
120   LogicalResult
matchAndRewrite__anon199bdef30111::ImOpConversion121   matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
122                   ConversionPatternRewriter &rewriter) const override {
123     complex::ImOp::Adaptor transformed(operands);
124 
125     // Extract imaginary part from the complex number struct.
126     ComplexStructBuilder complexStruct(transformed.complex());
127     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
128     rewriter.replaceOp(op, imaginary);
129 
130     return success();
131   }
132 };
133 
134 struct BinaryComplexOperands {
135   std::complex<Value> lhs;
136   std::complex<Value> rhs;
137 };
138 
139 template <typename OpTy>
140 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)141 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
142                             ConversionPatternRewriter &rewriter) {
143   auto loc = op.getLoc();
144   typename OpTy::Adaptor transformed(operands);
145 
146   // Extract real and imaginary values from operands.
147   BinaryComplexOperands unpacked;
148   ComplexStructBuilder lhs(transformed.lhs());
149   unpacked.lhs.real(lhs.real(rewriter, loc));
150   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
151   ComplexStructBuilder rhs(transformed.rhs());
152   unpacked.rhs.real(rhs.real(rewriter, loc));
153   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
154 
155   return unpacked;
156 }
157 
158 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
159   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
160 
161   LogicalResult
matchAndRewrite__anon199bdef30111::AddOpConversion162   matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
163                   ConversionPatternRewriter &rewriter) const override {
164     auto loc = op.getLoc();
165     BinaryComplexOperands arg =
166         unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
167 
168     // Initialize complex number struct for result.
169     auto structType = typeConverter->convertType(op.getType());
170     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
171 
172     // Emit IR to add complex numbers.
173     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
174     Value real =
175         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
176     Value imag =
177         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
178     result.setReal(rewriter, loc, real);
179     result.setImaginary(rewriter, loc, imag);
180 
181     rewriter.replaceOp(op, {result});
182     return success();
183   }
184 };
185 
186 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
187   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
188 
189   LogicalResult
matchAndRewrite__anon199bdef30111::DivOpConversion190   matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
191                   ConversionPatternRewriter &rewriter) const override {
192     auto loc = op.getLoc();
193     BinaryComplexOperands arg =
194         unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
195 
196     // Initialize complex number struct for result.
197     auto structType = typeConverter->convertType(op.getType());
198     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
199 
200     // Emit IR to add complex numbers.
201     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
202     Value rhsRe = arg.rhs.real();
203     Value rhsIm = arg.rhs.imag();
204     Value lhsRe = arg.lhs.real();
205     Value lhsIm = arg.lhs.imag();
206 
207     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
208         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
209         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
210 
211     Value resultReal = rewriter.create<LLVM::FAddOp>(
212         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
213         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
214 
215     Value resultImag = rewriter.create<LLVM::FSubOp>(
216         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
217         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
218 
219     result.setReal(
220         rewriter, loc,
221         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
222     result.setImaginary(
223         rewriter, loc,
224         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
225 
226     rewriter.replaceOp(op, {result});
227     return success();
228   }
229 };
230 
231 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
232   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
233 
234   LogicalResult
matchAndRewrite__anon199bdef30111::MulOpConversion235   matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
236                   ConversionPatternRewriter &rewriter) const override {
237     auto loc = op.getLoc();
238     BinaryComplexOperands arg =
239         unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
240 
241     // Initialize complex number struct for result.
242     auto structType = typeConverter->convertType(op.getType());
243     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
244 
245     // Emit IR to add complex numbers.
246     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
247     Value rhsRe = arg.rhs.real();
248     Value rhsIm = arg.rhs.imag();
249     Value lhsRe = arg.lhs.real();
250     Value lhsIm = arg.lhs.imag();
251 
252     Value real = rewriter.create<LLVM::FSubOp>(
253         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
254         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
255 
256     Value imag = rewriter.create<LLVM::FAddOp>(
257         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
258         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
259 
260     result.setReal(rewriter, loc, real);
261     result.setImaginary(rewriter, loc, imag);
262 
263     rewriter.replaceOp(op, {result});
264     return success();
265   }
266 };
267 
268 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
269   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
270 
271   LogicalResult
matchAndRewrite__anon199bdef30111::SubOpConversion272   matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
273                   ConversionPatternRewriter &rewriter) const override {
274     auto loc = op.getLoc();
275     BinaryComplexOperands arg =
276         unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
277 
278     // Initialize complex number struct for result.
279     auto structType = typeConverter->convertType(op.getType());
280     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
281 
282     // Emit IR to substract complex numbers.
283     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
284     Value real =
285         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
286     Value imag =
287         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
288     result.setReal(rewriter, loc, real);
289     result.setImaginary(rewriter, loc, imag);
290 
291     rewriter.replaceOp(op, {result});
292     return success();
293   }
294 };
295 } // namespace
296 
populateComplexToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)297 void mlir::populateComplexToLLVMConversionPatterns(
298     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
299   // clang-format off
300   patterns.add<
301       AbsOpConversion,
302       AddOpConversion,
303       CreateOpConversion,
304       DivOpConversion,
305       ImOpConversion,
306       MulOpConversion,
307       ReOpConversion,
308       SubOpConversion
309     >(converter);
310   // clang-format on
311 }
312 
313 namespace {
314 struct ConvertComplexToLLVMPass
315     : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
316   void runOnOperation() override;
317 };
318 } // namespace
319 
runOnOperation()320 void ConvertComplexToLLVMPass::runOnOperation() {
321   auto module = getOperation();
322 
323   // Convert to the LLVM IR dialect using the converter defined above.
324   RewritePatternSet patterns(&getContext());
325   LLVMTypeConverter converter(&getContext());
326   populateComplexToLLVMConversionPatterns(converter, patterns);
327 
328   LLVMConversionTarget target(getContext());
329   target.addLegalOp<ModuleOp, FuncOp>();
330   target.addIllegalDialect<complex::ComplexDialect>();
331   if (failed(applyPartialConversion(module, target, std::move(patterns))))
332     signalPassFailure();
333 }
334 
335 std::unique_ptr<OperationPass<ModuleOp>>
createConvertComplexToLLVMPass()336 mlir::createConvertComplexToLLVMPass() {
337   return std::make_unique<ConvertComplexToLLVMPass>();
338 }
339