1 //===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Builders.h"
10 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
12 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
13 #include "mlir/Dialect/SCF/EDSC/Builders.h"
14 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
16 #include "mlir/IR/AffineExpr.h"
17
18 using namespace mlir;
19 using namespace mlir::edsc;
20 using namespace mlir::edsc::intrinsics;
21 using namespace mlir::linalg;
22 using namespace mlir::scf;
23
makeGenericLinalgOp(ArrayRef<IteratorType> iteratorTypes,ArrayRef<StructuredIndexed> inputs,ArrayRef<StructuredIndexed> outputs,TypeRange resultTensorTypes,function_ref<void (ValueRange)> regionBuilder,ArrayRef<Value> otherValues,ArrayRef<Attribute> otherAttributes)24 Operation *mlir::edsc::makeGenericLinalgOp(
25 ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
26 ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
27 function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
28 ArrayRef<Attribute> otherAttributes) {
29 OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
30
31 // Build maps
32 SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
33 exprsList.reserve(inputs.size() + outputs.size());
34
35 for (auto container : {inputs, outputs})
36 for (const StructuredIndexed &s : container)
37 exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end());
38 auto maps = AffineMap::inferFromExprList(exprsList);
39
40 SmallVector<Value, 4> inputValues, outputValues;
41 inputValues.reserve(inputs.size());
42 outputValues.reserve(outputs.size());
43 std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues));
44 std::copy(outputs.begin(), outputs.end(), std::back_inserter(outputValues));
45
46 auto iteratorStrTypes =
47 llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
48 // clang-format off
49 auto *op =
50 edsc::ScopedContext::getBuilderRef()
51 .create<linalg::GenericOp>(
52 edsc::ScopedContext::getLocation(),
53 resultTensorTypes,
54 inputValues,
55 outputValues,
56 builder.getAffineMapArrayAttr(maps),
57 builder.getStrArrayAttr(iteratorStrTypes),
58 StringAttr() /*doc*/,
59 StringAttr() /*library_call*/,
60 ArrayAttr() /*sparse*/
61 /* TODO: other attributes in op */
62 )
63 .getOperation();
64 // clang-format on
65
66 using namespace edsc;
67 SmallVector<Type, 4> blockTypes;
68 blockTypes.reserve(inputs.size() + outputs.size());
69 for (auto container : {inputs, outputs})
70 for (const StructuredIndexed &s : container)
71 blockTypes.push_back(getElementTypeOrSelf(s.getType()));
72
73 assert(op->getNumRegions() == 1);
74 assert(op->getRegion(0).empty());
75 OpBuilder opBuilder(op);
76 ScopedContext scope(opBuilder, op->getLoc());
77 buildInNewBlock(op->getRegion(0), blockTypes, regionBuilder);
78 assert(llvm::hasSingleElement(op->getRegion(0)));
79 return op;
80 }
81
mulRegionBuilder(ValueRange args)82 void mlir::edsc::ops::mulRegionBuilder(ValueRange args) {
83 using edsc::op::operator+;
84 using edsc::op::operator*;
85 assert(args.size() == 2 && "expected 2 block arguments");
86 Value a(args[0]), b(args[1]);
87 linalg_yield(a * b);
88 }
89
macRegionBuilder(ValueRange args)90 void mlir::edsc::ops::macRegionBuilder(ValueRange args) {
91 using edsc::op::operator+;
92 using edsc::op::operator*;
93 assert(args.size() == 3 && "expected 3 block arguments");
94 Value a(args[0]), b(args[1]), c(args[2]);
95 linalg_yield(c + a * b);
96 }
97
linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,StructuredIndexed I,StructuredIndexed O)98 Operation *mlir::edsc::ops::linalg_generic_pointwise(
99 UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
100 SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
101 IteratorType::Parallel);
102 auto fun = [&unaryOp](ValueRange args) {
103 assert(!args.empty() && "expected >= 1 block arguments");
104 Value a(args[0]);
105 linalg_yield(unaryOp(a));
106 };
107 if (O.getType().isa<RankedTensorType>())
108 return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O},
109 /*resultTensorTypes=*/{O}, fun);
110 return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O},
111 /*resultTensorTypes=*/{}, fun);
112 }
113
linalg_generic_pointwise_tanh(StructuredIndexed I,StructuredIndexed O)114 Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
115 StructuredIndexed O) {
116 UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); });
117 return linalg_generic_pointwise(unOp, I, O);
118 }
119
120 /// Binary pointwise operation (with broadcast) entry point.
linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)121 Operation *mlir::edsc::ops::linalg_generic_pointwise(
122 BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1,
123 StructuredIndexed I2, StructuredIndexed O) {
124 SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
125 IteratorType::Parallel);
126 auto fun = [&binaryOp](ValueRange args) {
127 assert(args.size() >= 2 && "expected >= 2 block arguments");
128 Value a(args[0]), b(args[1]);
129 linalg_yield(binaryOp(a, b));
130 };
131 if (O.getType().isa<RankedTensorType>())
132 return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2}, /*outputs=*/{O},
133 /*resultTensorTypes=*/{O}, fun);
134 return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2},
135 /*outputs=*/{O}, /*resultTensorTypes=*/{}, fun);
136 }
137
linalg_generic_pointwise_add(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)138 Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
139 StructuredIndexed I2,
140 StructuredIndexed O) {
141 using edsc::op::operator+;
142 BinaryPointwiseOpBuilder binOp(
143 [](Value a, Value b) -> Value { return a + b; });
144 return linalg_generic_pointwise(binOp, I1, I2, O);
145 }
146
linalg_generic_pointwise_max(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)147 Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
148 StructuredIndexed I2,
149 StructuredIndexed O) {
150 BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
151 using edsc::op::sgt;
152 return std_select(sgt(a, b), a, b);
153 });
154 return linalg_generic_pointwise(binOp, I1, I2, O);
155 }
156
157 Operation *
linalg_generic_matmul(Value vA,Value vB,Value vC,MatmulRegionBuilder regionBuilder)158 mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
159 MatmulRegionBuilder regionBuilder) {
160 // clang-format off
161 AffineExpr m, n, k;
162 bindDims(ScopedContext::getContext(), m, n, k);
163 StructuredIndexed A(vA), B(vB), C(vC);
164 return makeGenericLinalgOp(
165 {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
166 /*inputs=*/{A({m, k}), B({k, n})},
167 /*outputs=*/{C({m, n})},
168 /*resultTensorTypes=*/{},
169 regionBuilder);
170 // clang-format on
171 }
172
173 Operation *
linalg_generic_matmul(Value vA,Value vB,Value vC,RankedTensorType tD,MatmulRegionBuilder regionBuilder)174 mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
175 RankedTensorType tD,
176 MatmulRegionBuilder regionBuilder) {
177 // clang-format off
178 AffineExpr m, n, k;
179 bindDims(ScopedContext::getContext(), m, n, k);
180 StructuredIndexed A(vA), B(vB), C(vC), D(tD);
181 return makeGenericLinalgOp(
182 {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
183 /*inputs=*/{A({m, k}), B({k, n})},
184 /*outputs=*/{C({m, n})},
185 /*resultTensorTypes=*/{D({m, n})},
186 regionBuilder);
187 // clang-format on
188 }
189
linalg_generic_conv_nhwc(Value vI,Value vW,Value vO,ArrayRef<int> strides,ArrayRef<int> dilations)190 Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
191 Value vO,
192 ArrayRef<int> strides,
193 ArrayRef<int> dilations) {
194 MLIRContext *ctx = ScopedContext::getContext();
195 // TODO: some template magic to make everything rank-polymorphic.
196 assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
197 assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
198
199 // Some short names.
200 auto par = IteratorType::Parallel;
201 auto red = IteratorType::Reduction;
202 auto s = strides;
203 auto d = dilations;
204
205 AffineExpr b, f, h, w, kh, kw, c;
206 bindDims(ctx, b, f, h, w, kh, kw, c);
207 unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
208 StructuredIndexed I(vI), W(vW), O(vO);
209 // clang-format off
210 return makeGenericLinalgOp(
211 {par, par, par, par, red, red, red},
212 /*inputs=*/{
213 I({b,
214 // Roundtrip to flattened form to serve as canonicalization and ensure
215 // consistent ordering of subexpressions.
216 simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
217 simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
218 c}),
219 W({kh, kw, c, f}) },
220 /*outputs=*/{ O({b, h, w, f}) },
221 /*resultTensorTypes=*/{},
222 macRegionBuilder);
223 // clang-format on
224 }
225
linalg_generic_dilated_conv_nhwc(Value vI,Value vW,Value vO,int depth_multiplier,ArrayRef<int> strides,ArrayRef<int> dilations)226 Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
227 Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef<int> strides,
228 ArrayRef<int> dilations) {
229 MLIRContext *ctx = ScopedContext::getContext();
230 // TODO: some template magic to make everything rank-polymorphic.
231 assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
232 assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
233
234 // Some short names.
235 auto par = IteratorType::Parallel;
236 auto red = IteratorType::Reduction;
237 auto s = strides;
238 auto d = dilations;
239
240 // clang-format off
241 AffineExpr b, dm, c, h, w, kh, kw;
242 bindDims(ctx, b, dm, c, h, w, kh, kw);
243 unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
244 StructuredIndexed I(vI), W(vW), O(vO);
245 return makeGenericLinalgOp(
246 {par, par, par, par, par, red, red},
247 /*inputs=*/{
248 I({b,
249 // Roundtrip to flattened form to serve as canonicalization and ensure
250 // consistent ordering of subexpressions.
251 simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
252 simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
253 c}),
254 W({kh, kw, c, dm})},
255 /*outputs=*/{
256 O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
257 /*resultTensorTypes=*/{},
258 macRegionBuilder);
259 // clang-format on
260 }
261