1 //===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_LINALG_UTILS_H_
10 #define MLIR_DIALECT_LINALG_UTILS_H_
11
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "llvm/ADT/SetVector.h"

#include <cassert>
#include <utility>
20
21 using mlir::edsc::intrinsics::AffineIndexedValue;
22 using mlir::edsc::intrinsics::StdIndexedValue;
23
24 namespace mlir {
25 class AffineExpr;
26 class AffineForOp;
27 class AffineMap;
28 class OperationFolder;
29 class PatternRewriter;
30
31 namespace linalg {
32 class LinalgDependenceGraph;
33
34 struct FusionInfo {
35 LinalgOp originalProducer;
36 LinalgOp fusedProducer;
37 };
38
39 /// A struct containing common matchers over linalg op's region.
40 struct RegionMatcher {
41 enum class BinaryOpKind {
42 IAdd,
43 };
44
45 /// Matches the given linalg op if its body is performing binary operation on
46 /// int or float scalar values and returns the binary op kind.
47 ///
48 /// The linalg op's region is expected to be
49 /// ```
50 /// {
51 /// ^bb(%a: <scalar-type>, %b: <scalar-type>):
52 /// %0 = <binary-op> %a, %b: <scalar-type>
53 /// linalg.yield %0: <scalar-type>
54 /// }
55 /// ```
56 static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
57 };
58
59 /// Checks if an iterator_type attribute is parallel.
60 bool isParallelIteratorType(Attribute attr);
61
62 /// Checks if an iterator_type attribute is parallel.
63 bool isReductionIteratorType(Attribute attr);
64
65 /// Checks if an iterator_type attribute is parallel.
66 bool isWindowIteratorType(Attribute attr);
67
68 /// Checks whether the specific `producer` is the last write to exactly the
69 /// whole `consumedView`. This checks structural dominance, that the dependence
70 /// is a RAW without any interleaved write to any piece of `consumedView`.
71 bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
72 LinalgOp consumer, Value consumedView,
73 LinalgOp producer);
74
75 /// Checks whether fusing the specific `producer` of the `consumedView` is
76 /// feasible. This checks `producer` is the last write of `consumedView` and
77 /// that no interleaved dependence would be violated (RAW, WAR or WAW).
78 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
79 Value consumedView, LinalgOp producer);
80
81 /// Fuses producer into consumer if the producer is structurally feasible and
82 /// the fusion would not violate dependencies.
83 /// When non-null, the optional pointer `folder` is used to call into the
84 /// `createAndFold` builder method. If `folder` is null, the regular `create`
85 /// method is called.
86 Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
87 unsigned consumerIdx,
88 const LinalgDependenceGraph &graph,
89 OperationFolder *folder = nullptr);
90
91 /// Fuse linalg operation on tensors, with the producer of the operand at
92 /// position `consumerIdx` of the consumer.
93 Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
94 unsigned consumerIdx,
95 OperationFolder *folder = nullptr);
96
97 /// Returns the linearized list of all view dimensions in a linalgOp. Applying
98 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the
99 /// derivation of loop ranges for any linalgOp.
100 template <typename ConcreteOp>
getViewSizes(OpBuilder & builder,ConcreteOp linalgOp)101 SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) {
102 auto loc = linalgOp.getLoc();
103 SmallVector<Value, 8> res;
104 for (auto v : linalgOp.getInputsAndOutputBuffers()) {
105 MemRefType t = v.getType().template cast<MemRefType>();
106 for (unsigned i = 0; i < t.getRank(); ++i)
107 res.push_back(builder.create<DimOp>(loc, v, i));
108 }
109 return res;
110 }
111
112 /// Returns the values obtained by applying `map` to the list of values.
113 /// When non-null, the optional pointer `folder` is used to call into the
114 /// `createAndFold` builder method. If `folder` is null, the regular `create`
115 /// method is called.
116 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
117 AffineMap map, ArrayRef<Value> values,
118 OperationFolder *folder = nullptr);
119
120 /// Returns all the operands of `linalgOp` that are not views.
121 /// Asserts that these operands are value types to allow transformations like
122 /// tiling to just use the values when cloning `linalgOp`.
123 SmallVector<Value, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
124
125 /// Apply the permutation defined by `permutation` to `inVec`.
126 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
127 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
128 /// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
129 template <typename T, unsigned N>
applyPermutationToVector(SmallVector<T,N> & inVec,ArrayRef<unsigned> permutation)130 void applyPermutationToVector(SmallVector<T, N> &inVec,
131 ArrayRef<unsigned> permutation) {
132 SmallVector<T, N> auxVec(inVec.size());
133 for (unsigned i = 0; i < permutation.size(); ++i)
134 auxVec[i] = inVec[permutation[i]];
135 inVec = auxVec;
136 }
137
138 /// Utility class used to generate nested loops with ranges described by
139 /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
140 /// is used to generate the body of the innermost loop. It is passed a range
141 /// of loop induction variables.
142 template <typename LoopTy>
143 struct GenerateLoopNest {
144 using IndexedValueTy =
145 typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
146 AffineIndexedValue, StdIndexedValue>::type;
147
148 static void doit(ArrayRef<SubViewOp::Range> loopRanges,
149 ArrayRef<Attribute> iteratorTypes,
150 function_ref<void(ValueRange)> bodyBuilderFn);
151 };
152
153 } // namespace linalg
154 } // namespace mlir
155
156 #endif // MLIR_DIALECT_LINALG_UTILS_H_
157