1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 
17 using namespace mlir;
18 using namespace mlir::LLVM;
19 
20 //===----------------------------------------------------------------------===//
21 // ComplexStructBuilder implementation.
22 //===----------------------------------------------------------------------===//
23 
/// Index of the field holding the real part in the LLVM struct that models a
/// complex number.
static constexpr unsigned kRealPosInComplexNumberStruct = 0u;
/// Index of the field holding the imaginary part in the LLVM struct that
/// models a complex number.
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1u;
26 
undef(OpBuilder & builder,Location loc,Type type)27 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
28                                                  Location loc, Type type) {
29   Value val = builder.create<LLVM::UndefOp>(loc, type);
30   return ComplexStructBuilder(val);
31 }
32 
setReal(OpBuilder & builder,Location loc,Value real)33 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
34                                    Value real) {
35   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
36 }
37 
real(OpBuilder & builder,Location loc)38 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
39   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
40 }
41 
setImaginary(OpBuilder & builder,Location loc,Value imaginary)42 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
43                                         Value imaginary) {
44   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
45 }
46 
imaginary(OpBuilder & builder,Location loc)47 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
48   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // Conversion patterns.
53 //===----------------------------------------------------------------------===//
54 
55 namespace {
56 
57 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
58   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
59 
60   LogicalResult
matchAndRewrite__anon33076c320111::AbsOpConversion61   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
62                   ConversionPatternRewriter &rewriter) const override {
63     auto loc = op.getLoc();
64 
65     ComplexStructBuilder complexStruct(adaptor.complex());
66     Value real = complexStruct.real(rewriter, op.getLoc());
67     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
68 
69     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
70     Value sqNorm = rewriter.create<LLVM::FAddOp>(
71         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
72         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
73 
74     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
75     return success();
76   }
77 };
78 
79 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
80   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
81 
82   LogicalResult
matchAndRewrite__anon33076c320111::CreateOpConversion83   matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
84                   ConversionPatternRewriter &rewriter) const override {
85     // Pack real and imaginary part in a complex number struct.
86     auto loc = complexOp.getLoc();
87     auto structType = typeConverter->convertType(complexOp.getType());
88     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
89     complexStruct.setReal(rewriter, loc, adaptor.real());
90     complexStruct.setImaginary(rewriter, loc, adaptor.imaginary());
91 
92     rewriter.replaceOp(complexOp, {complexStruct});
93     return success();
94   }
95 };
96 
97 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
98   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
99 
100   LogicalResult
matchAndRewrite__anon33076c320111::ReOpConversion101   matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
102                   ConversionPatternRewriter &rewriter) const override {
103     // Extract real part from the complex number struct.
104     ComplexStructBuilder complexStruct(adaptor.complex());
105     Value real = complexStruct.real(rewriter, op.getLoc());
106     rewriter.replaceOp(op, real);
107 
108     return success();
109   }
110 };
111 
112 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
113   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
114 
115   LogicalResult
matchAndRewrite__anon33076c320111::ImOpConversion116   matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
117                   ConversionPatternRewriter &rewriter) const override {
118     // Extract imaginary part from the complex number struct.
119     ComplexStructBuilder complexStruct(adaptor.complex());
120     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
121     rewriter.replaceOp(op, imaginary);
122 
123     return success();
124   }
125 };
126 
127 struct BinaryComplexOperands {
128   std::complex<Value> lhs;
129   std::complex<Value> rhs;
130 };
131 
132 template <typename OpTy>
133 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter)134 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
135                             ConversionPatternRewriter &rewriter) {
136   auto loc = op.getLoc();
137 
138   // Extract real and imaginary values from operands.
139   BinaryComplexOperands unpacked;
140   ComplexStructBuilder lhs(adaptor.lhs());
141   unpacked.lhs.real(lhs.real(rewriter, loc));
142   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
143   ComplexStructBuilder rhs(adaptor.rhs());
144   unpacked.rhs.real(rhs.real(rewriter, loc));
145   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
146 
147   return unpacked;
148 }
149 
150 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
151   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
152 
153   LogicalResult
matchAndRewrite__anon33076c320111::AddOpConversion154   matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
155                   ConversionPatternRewriter &rewriter) const override {
156     auto loc = op.getLoc();
157     BinaryComplexOperands arg =
158         unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
159 
160     // Initialize complex number struct for result.
161     auto structType = typeConverter->convertType(op.getType());
162     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
163 
164     // Emit IR to add complex numbers.
165     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
166     Value real =
167         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
168     Value imag =
169         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
170     result.setReal(rewriter, loc, real);
171     result.setImaginary(rewriter, loc, imag);
172 
173     rewriter.replaceOp(op, {result});
174     return success();
175   }
176 };
177 
178 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
179   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
180 
181   LogicalResult
matchAndRewrite__anon33076c320111::DivOpConversion182   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
183                   ConversionPatternRewriter &rewriter) const override {
184     auto loc = op.getLoc();
185     BinaryComplexOperands arg =
186         unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
187 
188     // Initialize complex number struct for result.
189     auto structType = typeConverter->convertType(op.getType());
190     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
191 
192     // Emit IR to add complex numbers.
193     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
194     Value rhsRe = arg.rhs.real();
195     Value rhsIm = arg.rhs.imag();
196     Value lhsRe = arg.lhs.real();
197     Value lhsIm = arg.lhs.imag();
198 
199     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
200         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
201         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
202 
203     Value resultReal = rewriter.create<LLVM::FAddOp>(
204         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
205         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
206 
207     Value resultImag = rewriter.create<LLVM::FSubOp>(
208         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
209         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
210 
211     result.setReal(
212         rewriter, loc,
213         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
214     result.setImaginary(
215         rewriter, loc,
216         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
217 
218     rewriter.replaceOp(op, {result});
219     return success();
220   }
221 };
222 
223 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
224   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
225 
226   LogicalResult
matchAndRewrite__anon33076c320111::MulOpConversion227   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
228                   ConversionPatternRewriter &rewriter) const override {
229     auto loc = op.getLoc();
230     BinaryComplexOperands arg =
231         unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
232 
233     // Initialize complex number struct for result.
234     auto structType = typeConverter->convertType(op.getType());
235     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
236 
237     // Emit IR to add complex numbers.
238     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
239     Value rhsRe = arg.rhs.real();
240     Value rhsIm = arg.rhs.imag();
241     Value lhsRe = arg.lhs.real();
242     Value lhsIm = arg.lhs.imag();
243 
244     Value real = rewriter.create<LLVM::FSubOp>(
245         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
246         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
247 
248     Value imag = rewriter.create<LLVM::FAddOp>(
249         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
250         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
251 
252     result.setReal(rewriter, loc, real);
253     result.setImaginary(rewriter, loc, imag);
254 
255     rewriter.replaceOp(op, {result});
256     return success();
257   }
258 };
259 
260 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
261   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
262 
263   LogicalResult
matchAndRewrite__anon33076c320111::SubOpConversion264   matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
265                   ConversionPatternRewriter &rewriter) const override {
266     auto loc = op.getLoc();
267     BinaryComplexOperands arg =
268         unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
269 
270     // Initialize complex number struct for result.
271     auto structType = typeConverter->convertType(op.getType());
272     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
273 
274     // Emit IR to substract complex numbers.
275     auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
276     Value real =
277         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
278     Value imag =
279         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
280     result.setReal(rewriter, loc, real);
281     result.setImaginary(rewriter, loc, imag);
282 
283     rewriter.replaceOp(op, {result});
284     return success();
285   }
286 };
287 } // namespace
288 
populateComplexToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)289 void mlir::populateComplexToLLVMConversionPatterns(
290     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
291   // clang-format off
292   patterns.add<
293       AbsOpConversion,
294       AddOpConversion,
295       CreateOpConversion,
296       DivOpConversion,
297       ImOpConversion,
298       MulOpConversion,
299       ReOpConversion,
300       SubOpConversion
301     >(converter);
302   // clang-format on
303 }
304 
305 namespace {
306 struct ConvertComplexToLLVMPass
307     : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
308   void runOnOperation() override;
309 };
310 } // namespace
311 
runOnOperation()312 void ConvertComplexToLLVMPass::runOnOperation() {
313   auto module = getOperation();
314 
315   // Convert to the LLVM IR dialect using the converter defined above.
316   RewritePatternSet patterns(&getContext());
317   LLVMTypeConverter converter(&getContext());
318   populateComplexToLLVMConversionPatterns(converter, patterns);
319 
320   LLVMConversionTarget target(getContext());
321   target.addLegalOp<ModuleOp, FuncOp>();
322   target.addIllegalDialect<complex::ComplexDialect>();
323   if (failed(applyPartialConversion(module, target, std::move(patterns))))
324     signalPassFailure();
325 }
326 
327 std::unique_ptr<OperationPass<ModuleOp>>
createConvertComplexToLLVMPass()328 mlir::createConvertComplexToLLVMPass() {
329   return std::make_unique<ConvertComplexToLLVMPass>();
330 }
331