1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Complex/IR/Complex.h"
10 #include "mlir/IR/Builders.h"
11 
12 using namespace mlir;
13 using namespace mlir::complex;
14 
15 //===----------------------------------------------------------------------===//
16 // TableGen'd op method definitions
17 //===----------------------------------------------------------------------===//
18 
19 #define GET_OP_CLASSES
20 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
21 
fold(ArrayRef<Attribute> operands)22 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
23   assert(operands.size() == 2 && "binary op takes two operands");
24   // Fold complex.create(complex.re(op), complex.im(op)).
25   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
26     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
27       if (reOp.getOperand() == imOp.getOperand()) {
28         return reOp.getOperand();
29       }
30     }
31   }
32   return {};
33 }
34 
fold(ArrayRef<Attribute> operands)35 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
36   assert(operands.size() == 1 && "unary op takes 1 operand");
37   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
38   if (arrayAttr && arrayAttr.size() == 2)
39     return arrayAttr[1];
40   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
41     return createOp.getOperand(1);
42   return {};
43 }
44 
fold(ArrayRef<Attribute> operands)45 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
46   assert(operands.size() == 1 && "unary op takes 1 operand");
47   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
48   if (arrayAttr && arrayAttr.size() == 2)
49     return arrayAttr[0];
50   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
51     return createOp.getOperand(0);
52   return {};
53 }
54