1 //===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
10 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/EDSC/Builders.h"
13 #include "mlir/EDSC/Intrinsics.h"
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/Support/Functional.h"
17 
18 using namespace mlir;
19 using namespace mlir::edsc;
20 using namespace mlir::edsc::intrinsics;
21 using namespace mlir::edsc::ops;
22 
getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,unsigned & pos)23 static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
24                            unsigned &pos) {
25   for (auto sidx : structuredIndices) {
26     for (auto expr : sidx.getExprs()) {
27       expr.walk([&pos](AffineExpr e) {
28         if (auto d = e.dyn_cast<AffineDimExpr>())
29           pos = std::max(pos, d.getPosition());
30       });
31     }
32   }
33 }
34 
makeGenericLinalgOp(ArrayRef<IterType> iteratorTypes,ArrayRef<StructuredIndexed> inputs,ArrayRef<StructuredIndexed> outputs,function_ref<void (ArrayRef<BlockArgument>)> regionBuilder,ArrayRef<Value> otherValues,ArrayRef<Attribute> otherAttributes)35 Operation *mlir::edsc::makeGenericLinalgOp(
36     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
37     ArrayRef<StructuredIndexed> outputs,
38     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
39     ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
40   auto &builder = edsc::ScopedContext::getBuilder();
41   auto *ctx = builder.getContext();
42   unsigned nInputs = inputs.size();
43   unsigned nOutputs = outputs.size();
44   unsigned maxPos = 0;
45   getMaxDimIndex(inputs, maxPos);
46   getMaxDimIndex(outputs, maxPos);
47   // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
48   unsigned nDims = maxPos + 1;
49 
50   SmallVector<AffineMap, 4> maps;
51   maps.reserve(nInputs + nOutputs);
52   for (auto in : inputs)
53     maps.push_back(
54         AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
55   for (auto out : outputs)
56     maps.push_back(
57         AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
58 
59   unsigned nViews = nInputs + nOutputs;
60   SmallVector<Value, 4> values;
61   values.reserve(nViews);
62   values.append(inputs.begin(), inputs.end());
63   values.append(outputs.begin(), outputs.end());
64 
65   auto iteratorStrTypes = functional::map(toString, iteratorTypes);
66   // clang-format off
67   auto *op =
68       edsc::ScopedContext::getBuilder()
69           .create<linalg::GenericOp>(
70               edsc::ScopedContext::getLocation(),
71               ArrayRef<Type>{}, // TODO(ntv): support tensors
72               values,
73               IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
74               IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
75               builder.getAffineMapArrayAttr(maps),
76               builder.getStrArrayAttr(iteratorStrTypes),
77               StringAttr() /*doc*/,
78               FlatSymbolRefAttr() /*fun*/,
79               StringAttr() /*library_call*/
80               /* TODO: other attributes in op */
81               )
82           .getOperation();
83   // clang-format on
84 
85   using namespace edsc;
86   SmallVector<Type, 4> blockTypes;
87   blockTypes.reserve(values.size());
88   for (auto it : llvm::enumerate(values))
89     blockTypes.push_back((it.index() < nViews)
90                              ? getElementTypeOrSelf(it.value())
91                              : it.value().getType());
92 
93   assert(op->getRegions().front().empty());
94   op->getRegions().front().push_front(new Block);
95   OpBuilder bb(op->getRegions().front());
96   ScopedContext scope(bb, op->getLoc());
97   BlockHandle b;
98   auto handles = makeValueHandles(blockTypes);
99   BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
100       [&] { regionBuilder(b.getBlock()->getArguments()); });
101   return op;
102 }
103 
macRegionBuilder(ArrayRef<BlockArgument> args)104 void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
105   using edsc::op::operator+;
106   using edsc::op::operator*;
107   assert(args.size() == 3 && "expected 3 block arguments");
108   ValueHandle a(args[0]), b(args[1]), c(args[2]);
109   linalg_yield((c + a * b).getValue());
110 }
111 
linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,StructuredIndexed I,StructuredIndexed O)112 Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
113                                              StructuredIndexed I,
114                                              StructuredIndexed O) {
115   SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
116                                            edsc::IterType::Parallel);
117   auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
118     assert(args.size() == 2 && "expected 2 block arguments");
119     ValueHandle a(args[0]);
120     linalg_yield(unaryOp(a));
121   };
122   return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
123 }
124 
linalg_pointwise_tanh(StructuredIndexed I,StructuredIndexed O)125 Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
126                                                   StructuredIndexed O) {
127   ;
128   using edsc::intrinsics::tanh;
129   UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
130   return linalg_pointwise(unOp, I, O);
131 }
132 
133 /// Binary pointwise operation (with broadcast) entry point.
linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)134 Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
135                                              StructuredIndexed I1,
136                                              StructuredIndexed I2,
137                                              StructuredIndexed O) {
138   SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
139                                            edsc::IterType::Parallel);
140   auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
141     assert(args.size() == 3 && "expected 3 block arguments");
142     ValueHandle a(args[0]), b(args[1]);
143     linalg_yield(binaryOp(a, b));
144   };
145   return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
146 }
147 
linalg_pointwise_add(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)148 Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
149                                                  StructuredIndexed I2,
150                                                  StructuredIndexed O) {
151   using edsc::op::operator+;
152   BinaryPointwiseOpBuilder binOp(
153       [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
154   return linalg_pointwise(binOp, I1, I2, O);
155 }
156 
linalg_pointwise_max(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)157 Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
158                                                  StructuredIndexed I2,
159                                                  StructuredIndexed O) {
160   BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
161     using edsc::intrinsics::select;
162     using edsc::op::operator>;
163     return select(a > b, a, b).getValue();
164   });
165   return linalg_pointwise(binOp, I1, I2, O);
166 }
167 
linalg_matmul(ValueHandle vA,ValueHandle vB,ValueHandle vC)168 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
169                                           ValueHandle vC) {
170   // clang-format off
171   AffineExpr m, n, k;
172   bindDims(ScopedContext::getContext(), m, n, k);
173   StructuredIndexed A(vA), B(vB), C(vC);
174   return makeGenericLinalgOp(
175     {IterType::Parallel, IterType::Parallel, IterType::Reduction},
176     {A({m, k}), B({k, n})},
177     {C({m, n})},
178     macRegionBuilder);
179   // clang-format on
180 }
181 
linalg_conv_nhwc(ValueHandle vI,ValueHandle vW,ValueHandle vO,ArrayRef<int> strides,ArrayRef<int> dilations)182 Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
183                                              ValueHandle vO,
184                                              ArrayRef<int> strides,
185                                              ArrayRef<int> dilations) {
186   MLIRContext *ctx = ScopedContext::getContext();
187   // TODO(ntv) some template magic to make everything rank-polymorphic.
188   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
189   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
190 
191   // Some short names.
192   auto par = IterType::Parallel;
193   auto red = IterType::Reduction;
194   auto s = strides;
195   auto d = dilations;
196 
197   AffineExpr b, f, h, w, kh, kw, c;
198   bindDims(ctx, b, f, h, w, kh, kw, c);
199   unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
200   StructuredIndexed I(vI), W(vW), O(vO);
201   // clang-format off
202   return makeGenericLinalgOp(
203     {par, par, par, par, red, red, red}, {
204       I({b,
205          // Roundtrip to flattened form to serve as canonicalization and ensure
206          // consistent ordering of subexpressions.
207          simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
208          simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
209          c}),
210       W({kh, kw, c, f})}, {
211       O({b, h, w, f})},
212     macRegionBuilder);
213   // clang-format on
214 }
215 
linalg_dilated_conv_nhwc(ValueHandle vI,ValueHandle vW,ValueHandle vO,int depth_multiplier,ArrayRef<int> strides,ArrayRef<int> dilations)216 Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
217     ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
218     ArrayRef<int> strides, ArrayRef<int> dilations) {
219   MLIRContext *ctx = ScopedContext::getContext();
220   // TODO(ntv) some template magic to make everything rank-polymorphic.
221   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
222   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
223 
224   // Some short names.
225   auto par = IterType::Parallel;
226   auto red = IterType::Reduction;
227   auto s = strides;
228   auto d = dilations;
229 
230   // clang-format off
231   AffineExpr b, dm, c, h, w, kh, kw;
232   bindDims(ctx, b, dm, c, h, w, kh, kw);
233   unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
234   StructuredIndexed I(vI), W(vW), O(vO);
235   return makeGenericLinalgOp(
236     {par, par, par, par, par, red, red}, {
237       I({b,
238          // Roundtrip to flattened form to serve as canonicalization and ensure
239          // consistent ordering of subexpressions.
240          simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
241          simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
242          c}),
243       W({kh, kw, c, dm})}, {
244       O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
245     macRegionBuilder);
246   // clang-format on
247 }
248