1 //===- Generalization.cpp - linalg named ops to generic ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26
27 #define DEBUG_TYPE "linalg-generalization"
28
29 using namespace mlir;
30 using namespace mlir::linalg;
31
32 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
33 // the given `namedOp` does not have a region builder.
createGenericOpFromNamedOp(LinalgOp namedOp,PatternRewriter & rewriter)34 static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
35 PatternRewriter &rewriter) {
36 SmallVector<Value> inputOperands = namedOp.getInputOperands();
37 SmallVector<Value> outputOperands = namedOp.getOutputOperands();
38 SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
39 SmallVector<StringRef> iterators = llvm::to_vector<4>(
40 namedOp.iterator_types().getAsValueRange<StringAttr>());
41 SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
42 SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
43
44 // Inline the existing region if the named operation has a region attached.
45 if (namedOp->getNumRegions() == 1) {
46 GenericOp genericOp =
47 rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands,
48 outputOperands, indexingMaps, iterators);
49 rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
50 genericOp.region().begin());
51 return genericOp;
52 }
53
54 // Otherwise use the region builder to generate a new region.
55 // TODO: Remove this path once all linag operations have a region attached.
56 auto regionBuilder = namedOp.getRegionBuilder();
57 if (!regionBuilder) {
58 LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
59 return nullptr;
60 }
61 return rewriter.create<GenericOp>(
62 namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
63 iterators,
64 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
65 ImplicitLocOpBuilder b(loc, bodyBuilder);
66 regionBuilder(b, *bodyBuilder.getBlock());
67 });
68 }
69
70 namespace {
71
72 /// Base class for all linalg generalization patterns. A subclass must provide
73 /// the following method:
74 /// GenericOp createGenericOp(RootOp, PatternRewriter &)
75 /// for creating the generic op.
76 // TODO: remove this pattern after migrating all manually-written named ops
77 // into auto-generated ones.
78 template <typename ConcretePattern, typename RootOp>
79 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
LinalgGeneralizationPattern__anon8d3e0c9c0211::LinalgGeneralizationPattern80 LinalgGeneralizationPattern(MLIRContext *context,
81 LinalgTransformationFilter marker,
82 PatternBenefit benefit = 1)
83 : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
84
matchAndRewrite__anon8d3e0c9c0211::LinalgGeneralizationPattern85 LogicalResult matchAndRewrite(RootOp rootOp,
86 PatternRewriter &rewriter) const override {
87 auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
88 if (!linalgOp)
89 return failure();
90 if (failed(marker.checkAndNotify(rewriter, linalgOp)))
91 return failure();
92
93 auto *pattern = static_cast<const ConcretePattern *>(this);
94 GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
95 if (!genericOp)
96 return failure();
97
98 rewriter.replaceOp(rootOp, genericOp.getResults());
99 marker.replaceLinalgTransformationFilter(rewriter,
100 genericOp.getOperation());
101 return success();
102 }
103
104 private:
105 LinalgTransformationFilter marker;
106 };
107
108 struct GeneralizeConvOp
109 : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
110 using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
111
112 GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
113 };
114
115 /// Catch-all pattern for converting all named ops with a region builder into
116 /// linalg.generic.
117 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern__anon8d3e0c9c0211::LinalgNamedOpGeneralizationPattern118 LinalgNamedOpGeneralizationPattern(MLIRContext *context,
119 LinalgTransformationFilter marker,
120 PatternBenefit benefit = 1)
121 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
122 marker(std::move(marker)) {}
123
matchAndRewrite__anon8d3e0c9c0211::LinalgNamedOpGeneralizationPattern124 LogicalResult matchAndRewrite(Operation *rootOp,
125 PatternRewriter &rewriter) const override {
126 auto linalgOp = dyn_cast<LinalgOp>(rootOp);
127 if (!linalgOp)
128 return failure();
129 if (failed(marker.checkAndNotify(rewriter, linalgOp)))
130 return failure();
131
132 // No nothing to do for linalg.generic.
133 if (isa<GenericOp>(rootOp))
134 return failure();
135
136 GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
137 if (!genericOp)
138 return failure();
139
140 rewriter.replaceOp(rootOp, genericOp.getResults());
141 marker.replaceLinalgTransformationFilter(rewriter,
142 genericOp.getOperation());
143 return success();
144 }
145
146 private:
147 LinalgTransformationFilter marker;
148 };
149
150 struct LinalgGeneralizationPass
151 : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
152 void runOnFunction() override;
153 };
154
155 } // namespace
156
runOnFunction()157 void LinalgGeneralizationPass::runOnFunction() {
158 FuncOp func = getFunction();
159 RewritePatternSet patterns(&getContext());
160 populateLinalgConvGeneralizationPatterns(patterns);
161 populateLinalgNamedOpsGeneralizationPatterns(patterns);
162 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
163 }
164
createGenericOp(ConvOp convOp,OpBuilder & builder) const165 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
166 OpBuilder &builder) const {
167 SmallVector<AffineMap> indexingMaps = convOp.getIndexingMaps();
168 auto iterators =
169 llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
170 SmallVector<Value> inputBuffers = convOp.getInputBufferOperands();
171 SmallVector<Value> outputBuffers = convOp.getOutputBufferOperands();
172 return builder.create<GenericOp>(
173 convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), inputBuffers,
174 outputBuffers, indexingMaps, iterators,
175 [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
176 Value mul =
177 bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
178 Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
179 bodyBuilder.create<YieldOp>(bodyLoc, add);
180 });
181 }
182
populateLinalgConvGeneralizationPatterns(RewritePatternSet & patterns,LinalgTransformationFilter marker)183 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
184 RewritePatternSet &patterns, LinalgTransformationFilter marker) {
185 patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
186 }
187
populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet & patterns,LinalgTransformationFilter marker)188 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
189 RewritePatternSet &patterns, LinalgTransformationFilter marker) {
190 patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
191 marker);
192 }
193
createLinalgGeneralizationPass()194 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
195 return std::make_unique<LinalgGeneralizationPass>();
196 }
197