1 //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-generalization"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
33 // the given `namedOp` does not have a region builder.
createGenericOpFromNamedOp(LinalgOp namedOp,PatternRewriter & rewriter)34 static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
35                                             PatternRewriter &rewriter) {
36   SmallVector<Value> inputOperands = namedOp.getInputOperands();
37   SmallVector<Value> outputOperands = namedOp.getOutputOperands();
38   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
39   SmallVector<StringRef> iterators = llvm::to_vector<4>(
40       namedOp.iterator_types().getAsValueRange<StringAttr>());
41   SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
42   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
43 
44   // Inline the existing region if the named operation has a region attached.
45   if (namedOp->getNumRegions() == 1) {
46     GenericOp genericOp =
47         rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands,
48                                    outputOperands, indexingMaps, iterators);
49     rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
50                                 genericOp.region().begin());
51     return genericOp;
52   }
53 
54   // Otherwise use the region builder to generate a new region.
55   // TODO: Remove this path once all linag operations have a region attached.
56   auto regionBuilder = namedOp.getRegionBuilder();
57   if (!regionBuilder) {
58     LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
59     return nullptr;
60   }
61   return rewriter.create<GenericOp>(
62       namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
63       iterators,
64       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
65         ImplicitLocOpBuilder b(loc, bodyBuilder);
66         regionBuilder(b, *bodyBuilder.getBlock());
67       });
68 }
69 
70 namespace {
71 
72 /// Base class for all linalg generalization patterns. A subclass must provide
73 /// the following method:
74 ///   GenericOp createGenericOp(RootOp, PatternRewriter &)
75 /// for creating the generic op.
76 // TODO: remove this pattern after migrating all manually-written named ops
77 // into auto-generated ones.
78 template <typename ConcretePattern, typename RootOp>
79 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
LinalgGeneralizationPattern__anon8d3e0c9c0211::LinalgGeneralizationPattern80   LinalgGeneralizationPattern(MLIRContext *context,
81                               LinalgTransformationFilter marker,
82                               PatternBenefit benefit = 1)
83       : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
84 
matchAndRewrite__anon8d3e0c9c0211::LinalgGeneralizationPattern85   LogicalResult matchAndRewrite(RootOp rootOp,
86                                 PatternRewriter &rewriter) const override {
87     auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
88     if (!linalgOp)
89       return failure();
90     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
91       return failure();
92 
93     auto *pattern = static_cast<const ConcretePattern *>(this);
94     GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
95     if (!genericOp)
96       return failure();
97 
98     rewriter.replaceOp(rootOp, genericOp.getResults());
99     marker.replaceLinalgTransformationFilter(rewriter,
100                                              genericOp.getOperation());
101     return success();
102   }
103 
104 private:
105   LinalgTransformationFilter marker;
106 };
107 
108 struct GeneralizeConvOp
109     : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
110   using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
111 
112   GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
113 };
114 
115 /// Catch-all pattern for converting all named ops with a region builder into
116 /// linalg.generic.
117 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern__anon8d3e0c9c0211::LinalgNamedOpGeneralizationPattern118   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
119                                      LinalgTransformationFilter marker,
120                                      PatternBenefit benefit = 1)
121       : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
122         marker(std::move(marker)) {}
123 
matchAndRewrite__anon8d3e0c9c0211::LinalgNamedOpGeneralizationPattern124   LogicalResult matchAndRewrite(Operation *rootOp,
125                                 PatternRewriter &rewriter) const override {
126     auto linalgOp = dyn_cast<LinalgOp>(rootOp);
127     if (!linalgOp)
128       return failure();
129     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
130       return failure();
131 
132     // No nothing to do for linalg.generic.
133     if (isa<GenericOp>(rootOp))
134       return failure();
135 
136     GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
137     if (!genericOp)
138       return failure();
139 
140     rewriter.replaceOp(rootOp, genericOp.getResults());
141     marker.replaceLinalgTransformationFilter(rewriter,
142                                              genericOp.getOperation());
143     return success();
144   }
145 
146 private:
147   LinalgTransformationFilter marker;
148 };
149 
150 struct LinalgGeneralizationPass
151     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
152   void runOnFunction() override;
153 };
154 
155 } // namespace
156 
runOnFunction()157 void LinalgGeneralizationPass::runOnFunction() {
158   FuncOp func = getFunction();
159   RewritePatternSet patterns(&getContext());
160   populateLinalgConvGeneralizationPatterns(patterns);
161   populateLinalgNamedOpsGeneralizationPatterns(patterns);
162   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
163 }
164 
createGenericOp(ConvOp convOp,OpBuilder & builder) const165 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
166                                             OpBuilder &builder) const {
167   SmallVector<AffineMap> indexingMaps = convOp.getIndexingMaps();
168   auto iterators =
169       llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
170   SmallVector<Value> inputBuffers = convOp.getInputBufferOperands();
171   SmallVector<Value> outputBuffers = convOp.getOutputBufferOperands();
172   return builder.create<GenericOp>(
173       convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), inputBuffers,
174       outputBuffers, indexingMaps, iterators,
175       [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
176         Value mul =
177             bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
178         Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
179         bodyBuilder.create<YieldOp>(bodyLoc, add);
180       });
181 }
182 
populateLinalgConvGeneralizationPatterns(RewritePatternSet & patterns,LinalgTransformationFilter marker)183 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
184     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
185   patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
186 }
187 
populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet & patterns,LinalgTransformationFilter marker)188 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
189     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
190   patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
191                                                    marker);
192 }
193 
createLinalgGeneralizationPass()194 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
195   return std::make_unique<LinalgGeneralizationPass>();
196 }
197