1 //===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
12 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
13 #include "mlir/Dialect/SCF/EDSC/Builders.h"
14 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
16 #include "mlir/IR/AffineExpr.h"
17 
18 using namespace mlir;
19 using namespace mlir::edsc;
20 using namespace mlir::edsc::intrinsics;
21 using namespace mlir::linalg;
22 using namespace mlir::scf;
23 
makeGenericLinalgOp(ArrayRef<IteratorType> iteratorTypes,ArrayRef<StructuredIndexed> inputs,ArrayRef<StructuredIndexed> outputs,TypeRange resultTensorTypes,function_ref<void (ValueRange)> regionBuilder,ArrayRef<Value> otherValues,ArrayRef<Attribute> otherAttributes)24 Operation *mlir::edsc::makeGenericLinalgOp(
25     ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
26     ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
27     function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
28     ArrayRef<Attribute> otherAttributes) {
29   OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
30 
31   // Build maps
32   SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
33   exprsList.reserve(inputs.size() + outputs.size());
34 
35   for (auto container : {inputs, outputs})
36     for (const StructuredIndexed &s : container)
37       exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end());
38   auto maps = AffineMap::inferFromExprList(exprsList);
39 
40   SmallVector<Value, 4> inputValues, outputValues;
41   inputValues.reserve(inputs.size());
42   outputValues.reserve(outputs.size());
43   std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues));
44   std::copy(outputs.begin(), outputs.end(), std::back_inserter(outputValues));
45 
46   auto iteratorStrTypes =
47       llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
48   // clang-format off
49   auto *op =
50       edsc::ScopedContext::getBuilderRef()
51           .create<linalg::GenericOp>(
52               edsc::ScopedContext::getLocation(),
53               resultTensorTypes,
54               inputValues,
55               outputValues,
56               builder.getAffineMapArrayAttr(maps),
57               builder.getStrArrayAttr(iteratorStrTypes),
58               StringAttr() /*doc*/,
59               StringAttr() /*library_call*/,
60               ArrayAttr() /*sparse*/
61               /* TODO: other attributes in op */
62               )
63           .getOperation();
64   // clang-format on
65 
66   using namespace edsc;
67   SmallVector<Type, 4> blockTypes;
68   blockTypes.reserve(inputs.size() + outputs.size());
69   for (auto container : {inputs, outputs})
70     for (const StructuredIndexed &s : container)
71       blockTypes.push_back(getElementTypeOrSelf(s.getType()));
72 
73   assert(op->getNumRegions() == 1);
74   assert(op->getRegion(0).empty());
75   OpBuilder opBuilder(op);
76   ScopedContext scope(opBuilder, op->getLoc());
77   buildInNewBlock(op->getRegion(0), blockTypes, regionBuilder);
78   assert(llvm::hasSingleElement(op->getRegion(0)));
79   return op;
80 }
81 
mulRegionBuilder(ValueRange args)82 void mlir::edsc::ops::mulRegionBuilder(ValueRange args) {
83   using edsc::op::operator+;
84   using edsc::op::operator*;
85   assert(args.size() == 2 && "expected 2 block arguments");
86   Value a(args[0]), b(args[1]);
87   linalg_yield(a * b);
88 }
89 
macRegionBuilder(ValueRange args)90 void mlir::edsc::ops::macRegionBuilder(ValueRange args) {
91   using edsc::op::operator+;
92   using edsc::op::operator*;
93   assert(args.size() == 3 && "expected 3 block arguments");
94   Value a(args[0]), b(args[1]), c(args[2]);
95   linalg_yield(c + a * b);
96 }
97 
linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,StructuredIndexed I,StructuredIndexed O)98 Operation *mlir::edsc::ops::linalg_generic_pointwise(
99     UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
100   SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
101                                          IteratorType::Parallel);
102   auto fun = [&unaryOp](ValueRange args) {
103     assert(!args.empty() && "expected >= 1 block arguments");
104     Value a(args[0]);
105     linalg_yield(unaryOp(a));
106   };
107   if (O.getType().isa<RankedTensorType>())
108     return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O},
109                                /*resultTensorTypes=*/{O}, fun);
110   return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O},
111                              /*resultTensorTypes=*/{}, fun);
112 }
113 
linalg_generic_pointwise_tanh(StructuredIndexed I,StructuredIndexed O)114 Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
115                                                           StructuredIndexed O) {
116   UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); });
117   return linalg_generic_pointwise(unOp, I, O);
118 }
119 
120 /// Binary pointwise operation (with broadcast) entry point.
linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)121 Operation *mlir::edsc::ops::linalg_generic_pointwise(
122     BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1,
123     StructuredIndexed I2, StructuredIndexed O) {
124   SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
125                                          IteratorType::Parallel);
126   auto fun = [&binaryOp](ValueRange args) {
127     assert(args.size() >= 2 && "expected >= 2 block arguments");
128     Value a(args[0]), b(args[1]);
129     linalg_yield(binaryOp(a, b));
130   };
131   if (O.getType().isa<RankedTensorType>())
132     return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2}, /*outputs=*/{O},
133                                /*resultTensorTypes=*/{O}, fun);
134   return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2},
135                              /*outputs=*/{O}, /*resultTensorTypes=*/{}, fun);
136 }
137 
linalg_generic_pointwise_add(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)138 Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
139                                                          StructuredIndexed I2,
140                                                          StructuredIndexed O) {
141   using edsc::op::operator+;
142   BinaryPointwiseOpBuilder binOp(
143       [](Value a, Value b) -> Value { return a + b; });
144   return linalg_generic_pointwise(binOp, I1, I2, O);
145 }
146 
linalg_generic_pointwise_max(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)147 Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
148                                                          StructuredIndexed I2,
149                                                          StructuredIndexed O) {
150   BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
151     using edsc::op::sgt;
152     return std_select(sgt(a, b), a, b);
153   });
154   return linalg_generic_pointwise(binOp, I1, I2, O);
155 }
156 
157 Operation *
linalg_generic_matmul(Value vA,Value vB,Value vC,MatmulRegionBuilder regionBuilder)158 mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
159                                        MatmulRegionBuilder regionBuilder) {
160   // clang-format off
161   AffineExpr m, n, k;
162   bindDims(ScopedContext::getContext(), m, n, k);
163   StructuredIndexed A(vA), B(vB), C(vC);
164   return makeGenericLinalgOp(
165     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
166     /*inputs=*/{A({m, k}), B({k, n})},
167     /*outputs=*/{C({m, n})},
168     /*resultTensorTypes=*/{},
169     regionBuilder);
170   // clang-format on
171 }
172 
173 Operation *
linalg_generic_matmul(Value vA,Value vB,Value vC,RankedTensorType tD,MatmulRegionBuilder regionBuilder)174 mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
175                                        RankedTensorType tD,
176                                        MatmulRegionBuilder regionBuilder) {
177   // clang-format off
178   AffineExpr m, n, k;
179   bindDims(ScopedContext::getContext(), m, n, k);
180   StructuredIndexed A(vA), B(vB), C(vC), D(tD);
181   return makeGenericLinalgOp(
182     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
183     /*inputs=*/{A({m, k}), B({k, n})},
184     /*outputs=*/{C({m, n})},
185     /*resultTensorTypes=*/{D({m, n})},
186     regionBuilder);
187   // clang-format on
188 }
189 
linalg_generic_conv_nhwc(Value vI,Value vW,Value vO,ArrayRef<int> strides,ArrayRef<int> dilations)190 Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
191                                                      Value vO,
192                                                      ArrayRef<int> strides,
193                                                      ArrayRef<int> dilations) {
194   MLIRContext *ctx = ScopedContext::getContext();
195   // TODO: some template magic to make everything rank-polymorphic.
196   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
197   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
198 
199   // Some short names.
200   auto par = IteratorType::Parallel;
201   auto red = IteratorType::Reduction;
202   auto s = strides;
203   auto d = dilations;
204 
205   AffineExpr b, f, h, w, kh, kw, c;
206   bindDims(ctx, b, f, h, w, kh, kw, c);
207   unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
208   StructuredIndexed I(vI), W(vW), O(vO);
209   // clang-format off
210   return makeGenericLinalgOp(
211     {par, par, par, par, red, red, red},
212     /*inputs=*/{
213       I({b,
214          // Roundtrip to flattened form to serve as canonicalization and ensure
215          // consistent ordering of subexpressions.
216          simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
217          simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
218          c}),
219       W({kh, kw, c, f}) },
220     /*outputs=*/{ O({b, h, w, f}) },
221     /*resultTensorTypes=*/{},
222     macRegionBuilder);
223   // clang-format on
224 }
225 
linalg_generic_dilated_conv_nhwc(Value vI,Value vW,Value vO,int depth_multiplier,ArrayRef<int> strides,ArrayRef<int> dilations)226 Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
227     Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef<int> strides,
228     ArrayRef<int> dilations) {
229   MLIRContext *ctx = ScopedContext::getContext();
230   // TODO: some template magic to make everything rank-polymorphic.
231   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
232   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
233 
234   // Some short names.
235   auto par = IteratorType::Parallel;
236   auto red = IteratorType::Reduction;
237   auto s = strides;
238   auto d = dilations;
239 
240   // clang-format off
241   AffineExpr b, dm, c, h, w, kh, kw;
242   bindDims(ctx, b, dm, c, h, w, kh, kw);
243   unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
244   StructuredIndexed I(vI), W(vW), O(vO);
245   return makeGenericLinalgOp(
246     {par, par, par, par, par, red, red},
247     /*inputs=*/{
248       I({b,
249          // Roundtrip to flattened form to serve as canonicalization and ensure
250          // consistent ordering of subexpressions.
251          simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
252          simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
253          c}),
254       W({kh, kw, c, dm})},
255     /*outputs=*/{
256       O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
257     /*resultTensorTypes=*/{},
258     macRegionBuilder);
259   // clang-format on
260 }
261