1 //===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
10 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/EDSC/Builders.h"
13 #include "mlir/EDSC/Intrinsics.h"
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/Support/Functional.h"
17
18 using namespace mlir;
19 using namespace mlir::edsc;
20 using namespace mlir::edsc::intrinsics;
21 using namespace mlir::edsc::ops;
22
getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,unsigned & pos)23 static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
24 unsigned &pos) {
25 for (auto sidx : structuredIndices) {
26 for (auto expr : sidx.getExprs()) {
27 expr.walk([&pos](AffineExpr e) {
28 if (auto d = e.dyn_cast<AffineDimExpr>())
29 pos = std::max(pos, d.getPosition());
30 });
31 }
32 }
33 }
34
makeGenericLinalgOp(ArrayRef<IterType> iteratorTypes,ArrayRef<StructuredIndexed> inputs,ArrayRef<StructuredIndexed> outputs,function_ref<void (ArrayRef<BlockArgument>)> regionBuilder,ArrayRef<Value> otherValues,ArrayRef<Attribute> otherAttributes)35 Operation *mlir::edsc::makeGenericLinalgOp(
36 ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
37 ArrayRef<StructuredIndexed> outputs,
38 function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
39 ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
40 auto &builder = edsc::ScopedContext::getBuilder();
41 auto *ctx = builder.getContext();
42 unsigned nInputs = inputs.size();
43 unsigned nOutputs = outputs.size();
44 unsigned maxPos = 0;
45 getMaxDimIndex(inputs, maxPos);
46 getMaxDimIndex(outputs, maxPos);
47 // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
48 unsigned nDims = maxPos + 1;
49
50 SmallVector<AffineMap, 4> maps;
51 maps.reserve(nInputs + nOutputs);
52 for (auto in : inputs)
53 maps.push_back(
54 AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
55 for (auto out : outputs)
56 maps.push_back(
57 AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
58
59 unsigned nViews = nInputs + nOutputs;
60 SmallVector<Value, 4> values;
61 values.reserve(nViews);
62 values.append(inputs.begin(), inputs.end());
63 values.append(outputs.begin(), outputs.end());
64
65 auto iteratorStrTypes = functional::map(toString, iteratorTypes);
66 // clang-format off
67 auto *op =
68 edsc::ScopedContext::getBuilder()
69 .create<linalg::GenericOp>(
70 edsc::ScopedContext::getLocation(),
71 ArrayRef<Type>{}, // TODO(ntv): support tensors
72 values,
73 IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
74 IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
75 builder.getAffineMapArrayAttr(maps),
76 builder.getStrArrayAttr(iteratorStrTypes),
77 StringAttr() /*doc*/,
78 FlatSymbolRefAttr() /*fun*/,
79 StringAttr() /*library_call*/
80 /* TODO: other attributes in op */
81 )
82 .getOperation();
83 // clang-format on
84
85 using namespace edsc;
86 SmallVector<Type, 4> blockTypes;
87 blockTypes.reserve(values.size());
88 for (auto it : llvm::enumerate(values))
89 blockTypes.push_back((it.index() < nViews)
90 ? getElementTypeOrSelf(it.value())
91 : it.value().getType());
92
93 assert(op->getRegions().front().empty());
94 op->getRegions().front().push_front(new Block);
95 OpBuilder bb(op->getRegions().front());
96 ScopedContext scope(bb, op->getLoc());
97 BlockHandle b;
98 auto handles = makeValueHandles(blockTypes);
99 BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
100 [&] { regionBuilder(b.getBlock()->getArguments()); });
101 return op;
102 }
103
macRegionBuilder(ArrayRef<BlockArgument> args)104 void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
105 using edsc::op::operator+;
106 using edsc::op::operator*;
107 assert(args.size() == 3 && "expected 3 block arguments");
108 ValueHandle a(args[0]), b(args[1]), c(args[2]);
109 linalg_yield((c + a * b).getValue());
110 }
111
linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,StructuredIndexed I,StructuredIndexed O)112 Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
113 StructuredIndexed I,
114 StructuredIndexed O) {
115 SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
116 edsc::IterType::Parallel);
117 auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
118 assert(args.size() == 2 && "expected 2 block arguments");
119 ValueHandle a(args[0]);
120 linalg_yield(unaryOp(a));
121 };
122 return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
123 }
124
linalg_pointwise_tanh(StructuredIndexed I,StructuredIndexed O)125 Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
126 StructuredIndexed O) {
127 ;
128 using edsc::intrinsics::tanh;
129 UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
130 return linalg_pointwise(unOp, I, O);
131 }
132
133 /// Binary pointwise operation (with broadcast) entry point.
linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)134 Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
135 StructuredIndexed I1,
136 StructuredIndexed I2,
137 StructuredIndexed O) {
138 SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
139 edsc::IterType::Parallel);
140 auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
141 assert(args.size() == 3 && "expected 3 block arguments");
142 ValueHandle a(args[0]), b(args[1]);
143 linalg_yield(binaryOp(a, b));
144 };
145 return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
146 }
147
linalg_pointwise_add(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)148 Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
149 StructuredIndexed I2,
150 StructuredIndexed O) {
151 using edsc::op::operator+;
152 BinaryPointwiseOpBuilder binOp(
153 [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
154 return linalg_pointwise(binOp, I1, I2, O);
155 }
156
linalg_pointwise_max(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)157 Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
158 StructuredIndexed I2,
159 StructuredIndexed O) {
160 BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
161 using edsc::intrinsics::select;
162 using edsc::op::operator>;
163 return select(a > b, a, b).getValue();
164 });
165 return linalg_pointwise(binOp, I1, I2, O);
166 }
167
linalg_matmul(ValueHandle vA,ValueHandle vB,ValueHandle vC)168 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
169 ValueHandle vC) {
170 // clang-format off
171 AffineExpr m, n, k;
172 bindDims(ScopedContext::getContext(), m, n, k);
173 StructuredIndexed A(vA), B(vB), C(vC);
174 return makeGenericLinalgOp(
175 {IterType::Parallel, IterType::Parallel, IterType::Reduction},
176 {A({m, k}), B({k, n})},
177 {C({m, n})},
178 macRegionBuilder);
179 // clang-format on
180 }
181
linalg_conv_nhwc(ValueHandle vI,ValueHandle vW,ValueHandle vO,ArrayRef<int> strides,ArrayRef<int> dilations)182 Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
183 ValueHandle vO,
184 ArrayRef<int> strides,
185 ArrayRef<int> dilations) {
186 MLIRContext *ctx = ScopedContext::getContext();
187 // TODO(ntv) some template magic to make everything rank-polymorphic.
188 assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
189 assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
190
191 // Some short names.
192 auto par = IterType::Parallel;
193 auto red = IterType::Reduction;
194 auto s = strides;
195 auto d = dilations;
196
197 AffineExpr b, f, h, w, kh, kw, c;
198 bindDims(ctx, b, f, h, w, kh, kw, c);
199 unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
200 StructuredIndexed I(vI), W(vW), O(vO);
201 // clang-format off
202 return makeGenericLinalgOp(
203 {par, par, par, par, red, red, red}, {
204 I({b,
205 // Roundtrip to flattened form to serve as canonicalization and ensure
206 // consistent ordering of subexpressions.
207 simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
208 simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
209 c}),
210 W({kh, kw, c, f})}, {
211 O({b, h, w, f})},
212 macRegionBuilder);
213 // clang-format on
214 }
215
linalg_dilated_conv_nhwc(ValueHandle vI,ValueHandle vW,ValueHandle vO,int depth_multiplier,ArrayRef<int> strides,ArrayRef<int> dilations)216 Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
217 ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
218 ArrayRef<int> strides, ArrayRef<int> dilations) {
219 MLIRContext *ctx = ScopedContext::getContext();
220 // TODO(ntv) some template magic to make everything rank-polymorphic.
221 assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
222 assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
223
224 // Some short names.
225 auto par = IterType::Parallel;
226 auto red = IterType::Reduction;
227 auto s = strides;
228 auto d = dilations;
229
230 // clang-format off
231 AffineExpr b, dm, c, h, w, kh, kw;
232 bindDims(ctx, b, dm, c, h, w, kh, kw);
233 unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
234 StructuredIndexed I(vI), W(vW), O(vO);
235 return makeGenericLinalgOp(
236 {par, par, par, par, par, red, red}, {
237 I({b,
238 // Roundtrip to flattened form to serve as canonicalization and ensure
239 // consistent ordering of subexpressions.
240 simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
241 simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
242 c}),
243 W({kh, kw, c, dm})}, {
244 O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
245 macRegionBuilder);
246 // clang-format on
247 }
248