1 //===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_UTILS_H_
10 #define MLIR_DIALECT_LINALG_UTILS_H_
11 
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "llvm/ADT/SetVector.h"

#include <type_traits>
#include <utility>
20 
21 using mlir::edsc::intrinsics::AffineIndexedValue;
22 using mlir::edsc::intrinsics::StdIndexedValue;
23 
24 namespace mlir {
25 class AffineExpr;
26 class AffineForOp;
27 class AffineMap;
28 class OperationFolder;
29 class PatternRewriter;
30 
31 namespace linalg {
32 class LinalgDependenceGraph;
33 
/// Pair of producer operations describing the outcome of a fusion. It is the
/// payload of the `Optional` returned by `fuseProducerOf`.
struct FusionInfo {
  /// NOTE(review): presumably the producer op as it existed before fusion;
  /// confirm against the implementation.
  LinalgOp originalProducer;
  /// NOTE(review): presumably the producer op created by the fusion; confirm
  /// against the implementation.
  LinalgOp fusedProducer;
};
38 
/// A struct containing common matchers over a linalg op's region.
struct RegionMatcher {
  /// Kinds of binary operation that `matchAsScalarBinaryOp` can report.
  enum class BinaryOpKind {
    /// NOTE(review): presumably integer addition; confirm against the
    /// matcher implementation.
    IAdd,
  };

  /// Matches the given linalg op if its body is performing a binary operation
  /// on int or float scalar values and returns the binary op kind.
  ///
  /// The linalg op's region is expected to be
  /// ```
  /// {
  ///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
  ///     %0 = <binary-op> %a, %b: <scalar-type>
  ///     linalg.yield %0: <scalar-type>
  /// }
  /// ```
  ///
  /// NOTE(review): the result is presumably empty when the region does not
  /// have this shape or the binary op is not a known `BinaryOpKind`; confirm.
  static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
};
58 
/// Checks if an iterator_type attribute is parallel.
/// Returns true if `attr` is a parallel iterator type, false otherwise.
bool isParallelIteratorType(Attribute attr);
61 
/// Checks if an iterator_type attribute is reduction.
/// Returns true if `attr` is a reduction iterator type, false otherwise.
bool isReductionIteratorType(Attribute attr);
64 
/// Checks if an iterator_type attribute is window.
/// Returns true if `attr` is a window iterator type, false otherwise.
bool isWindowIteratorType(Attribute attr);
67 
/// Checks whether the specific `producer` is the last write to exactly the
/// whole `consumedView`. This checks structural dominance, that the dependence
/// is a RAW without any interleaved write to any piece of `consumedView`.
/// Dependences are looked up in `graph`; `consumer` is the op reading
/// `consumedView`.
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                               LinalgOp consumer, Value consumedView,
                               LinalgOp producer);
74 
/// Checks whether fusing the specific `producer` of the `consumedView` is
/// feasible. This checks `producer` is the last write of `consumedView` and
/// that no interleaved dependence would be violated (RAW, WAR or WAW).
/// Returns true when the fusion into `consumer` is feasible under `graph`.
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                   Value consumedView, LinalgOp producer);
80 
/// Fuses producer into consumer if the producer is structurally feasible and
/// the fusion would not violate dependencies.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
/// NOTE(review): `consumerIdx` presumably selects the operand of `consumer`
/// whose producer is fused, and the returned Optional is presumably empty
/// when no fusion takes place; confirm against the implementation.
Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
                                    unsigned consumerIdx,
                                    const LinalgDependenceGraph &graph,
                                    OperationFolder *folder = nullptr);
90 
/// Fuses a linalg operation on tensors with the producer of the operand at
/// position `consumerIdx` of the consumer.
/// NOTE(review): `folder` is presumably used as in `fuseProducerOf`, and the
/// returned pointer is presumably the fused operation, or null when fusion
/// does not apply; confirm against the implementation.
Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
                         unsigned consumerIdx,
                         OperationFolder *folder = nullptr);
96 
97 /// Returns the linearized list of all view dimensions in a linalgOp. Applying
98 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the
99 /// derivation of loop ranges for any linalgOp.
100 template <typename ConcreteOp>
getViewSizes(OpBuilder & builder,ConcreteOp linalgOp)101 SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) {
102   auto loc = linalgOp.getLoc();
103   SmallVector<Value, 8> res;
104   for (auto v : linalgOp.getInputsAndOutputBuffers()) {
105     MemRefType t = v.getType().template cast<MemRefType>();
106     for (unsigned i = 0; i < t.getRank(); ++i)
107       res.push_back(builder.create<DimOp>(loc, v, i));
108   }
109   return res;
110 }
111 
/// Returns the values obtained by applying `map` to the list of values.
/// Any operations needed are built with `b` at location `loc`.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                       AffineMap map, ArrayRef<Value> values,
                                       OperationFolder *folder = nullptr);
119 
/// Returns all the operands of `linalgOp` that are not views.
/// Asserts that these operands are value types to allow transformations like
/// tiling to just use the values when cloning `linalgOp`.
/// The operands are returned by value in a new vector.
SmallVector<Value, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
124 
125 /// Apply the permutation defined by `permutation` to `inVec`.
126 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
127 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
128 /// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
129 template <typename T, unsigned N>
applyPermutationToVector(SmallVector<T,N> & inVec,ArrayRef<unsigned> permutation)130 void applyPermutationToVector(SmallVector<T, N> &inVec,
131                               ArrayRef<unsigned> permutation) {
132   SmallVector<T, N> auxVec(inVec.size());
133   for (unsigned i = 0; i < permutation.size(); ++i)
134     auxVec[i] = inVec[permutation[i]];
135   inVec = auxVec;
136 }
137 
/// Utility class used to generate nested loops with ranges described by
/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
/// is used to generate the body of the innermost loop. It is passed a range
/// of loop induction variables.
///
/// `LoopTy` selects the kind of loop operation to emit.
template <typename LoopTy>
struct GenerateLoopNest {
  /// Indexed-value helper matching the loop kind: `AffineIndexedValue` when
  /// `LoopTy` is `AffineForOp`, `StdIndexedValue` for every other `LoopTy`.
  using IndexedValueTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineIndexedValue, StdIndexedValue>::type;

  /// Emits the loop nest. Only declared here; the definition lives outside
  /// this header.
  static void doit(ArrayRef<SubViewOp::Range> loopRanges,
                   ArrayRef<Attribute> iteratorTypes,
                   function_ref<void(ValueRange)> bodyBuilderFn);
};
152 
153 } // namespace linalg
154 } // namespace mlir
155 
156 #endif // MLIR_DIALECT_LINALG_UTILS_H_
157