//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///   ^bb0(%arg2: f32, %arg3: f32):
///     %3 = addf %arg2, %arg3 : f32
///     linalg.yield %3 : f32
///   } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///   ^bb0(%arg2: f32, %arg3: f32):
///     %3 = addf %arg2, %arg3 : f32
///     linalg.yield %3 : f32
///   } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to have
/// a trip count of one (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
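/// For example (an illustrative sketch): with `unitDims = {1}` on a
/// three-dimensional iteration space, the unit dim is replaced by the
/// constant 0 and the remaining dims are renumbered, so
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>
/// becomes
///   affine_map<(d0, d1) -> (d0, 0, d1)>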
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
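/// For example (an illustrative sketch): with `unitDims = {1}`, a
/// `linalg.index 1` in the body is replaced by `constant 0 : index`, and a
/// `linalg.index 2` is renumbered to `linalg.index 1` because one dimension
/// before it was dropped.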
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
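/// For example (an illustrative sketch): a generic op whose operands make its
/// first loop unit-trip (e.g. an operand of type `tensor<1x5xf32>` indexed by
/// the identity map) is updated in place to iterate only over the remaining
/// dimension, with iterator_types shrinking from ["parallel", "parallel"] to
/// ["parallel"] and the indexing maps rewritten via replaceUnitDims.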
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
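/// For example (an illustrative sketch): an operand of type
/// `tensor<1x5x1xf32>` accessed through `affine_map<(d0, d1, d2) -> (0, d1, 0)>`
/// yields:
/// - type: `tensor<5xf32>`
/// - index map: `affine_map<(d0, d1, d2) -> (d1)>`
/// - reassociation: the single group [d0, d1, d2], folding all three original
///   dimensions into the one remaining dimension.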
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with affine maps to represent that we will always
  // leave them unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getAffineMaps().empty())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(
        AffineMap::get(origRank, /*symbolCount=*/0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
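/// For example (an illustrative sketch): an input of type `tensor<1x5xf32>`
/// accessed as `(d0, d1) -> (0, d1)` is collapsed to `tensor<5xf32>` with a
/// linalg.tensor_collapse_shape before the new generic op, and any result
/// whose type changed is expanded back to its original type with a
/// linalg.tensor_expand_shape afterwards.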
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  }

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  }

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type, indexing map, and create a set of mappings representing an
        // identity matrix.
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount=*/0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changed, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changed, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size
/// is 1 and the offset is 0. Strictly speaking, offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
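/// For example (an illustrative sketch): `mixedSizes = [1, 4, 1, 8]` yields
/// the reassociation [[0, 1], [2, 3]], folding each unit dimension into the
/// next non-unit dimension; a trailing unit dimension, as in [4, 1], is
/// folded into the last group, giving [[0, 1]].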
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // If the reassociation is not empty, fold the remaining unit dimensions
  // into the last group. If it is empty, leave it empty; this will fold
  // everything into a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
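/// For example (an illustrative sketch):
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<8x4x2xf32> to tensor<1x4x1xf32>
/// becomes a rank-reduced slice followed by an expand_shape that restores the
/// original result type:
///   %s = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<8x4x2xf32> to tensor<4xf32>
///   %0 = linalg.tensor_expand_shape %s [[0, 1, 2]]
///       : tensor<4xf32> into tensor<1x4x1xf32>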
struct UseRankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
                               reassociation->size(), sliceOp.getSourceType(),
                               offsets, sizes, strides)
                               .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(sliceOp, resultType,
                                                     newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
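/// For example (an illustrative sketch):
///   %0 = tensor.insert_slice %s into %d[0, 0] [1, 4] [1, 1]
///       : tensor<1x4xf32> into tensor<8x4xf32>
/// becomes an insert of the rank-reduced source:
///   %c = linalg.tensor_collapse_shape %s [[0, 1]]
///       : tensor<1x4xf32> into tensor<4xf32>
///   %0 = tensor.insert_slice %c into %d[0, 0] [1, 4] [1, 1]
///       : tensor<4xf32> into tensor<8x4xf32>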
struct UseRankReducedInsertSliceOp
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
      context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}