1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17 
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 namespace mlir {
24 /// Performs constant folding `calculate` with element-wise behavior on the two
25 /// attributes in `operands` and returns the result if possible.
26 template <class AttrElementT,
27           class ElementValueT = typename AttrElementT::ValueType,
28           class CalculationT =
29               function_ref<ElementValueT(ElementValueT, ElementValueT)>>
constFoldBinaryOp(ArrayRef<Attribute> operands,const CalculationT & calculate)30 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
31                             const CalculationT &calculate) {
32   assert(operands.size() == 2 && "binary op takes two operands");
33   if (!operands[0] || !operands[1])
34     return {};
35   if (operands[0].getType() != operands[1].getType())
36     return {};
37 
38   if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
39     auto lhs = operands[0].cast<AttrElementT>();
40     auto rhs = operands[1].cast<AttrElementT>();
41 
42     return AttrElementT::get(lhs.getType(),
43                              calculate(lhs.getValue(), rhs.getValue()));
44   } else if (operands[0].isa<SplatElementsAttr>() &&
45              operands[1].isa<SplatElementsAttr>()) {
46     // Both operands are splats so we can avoid expanding the values out and
47     // just fold based on the splat value.
48     auto lhs = operands[0].cast<SplatElementsAttr>();
49     auto rhs = operands[1].cast<SplatElementsAttr>();
50 
51     auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
52                                    rhs.getSplatValue<ElementValueT>());
53     return DenseElementsAttr::get(lhs.getType(), elementResult);
54   } else if (operands[0].isa<ElementsAttr>() &&
55              operands[1].isa<ElementsAttr>()) {
56     // Operands are ElementsAttr-derived; perform an element-wise fold by
57     // expanding the values.
58     auto lhs = operands[0].cast<ElementsAttr>();
59     auto rhs = operands[1].cast<ElementsAttr>();
60 
61     auto lhsIt = lhs.getValues<ElementValueT>().begin();
62     auto rhsIt = rhs.getValues<ElementValueT>().begin();
63     SmallVector<ElementValueT, 4> elementResults;
64     elementResults.reserve(lhs.getNumElements());
65     for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
66       elementResults.push_back(calculate(*lhsIt, *rhsIt));
67     return DenseElementsAttr::get(lhs.getType(), elementResults);
68   }
69   return {};
70 }
71 } // namespace mlir
72 
73 #endif // MLIR_DIALECT_COMMONFOLDERS_H
74