1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/Transforms/DialectConversion.h"
17
18 using namespace mlir;
19 using namespace mlir::shape;
20 using namespace mlir::scf;
21
22 /// Conversion patterns.
23 namespace {
24 class AnyOpConversion : public OpConversionPattern<AnyOp> {
25 public:
26 using OpConversionPattern<AnyOp>::OpConversionPattern;
27
28 LogicalResult
29 matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
30 ConversionPatternRewriter &rewriter) const override;
31 };
32 } // namespace
33
34 LogicalResult
matchAndRewrite(AnyOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const35 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
36 ConversionPatternRewriter &rewriter) const {
37 AnyOp::Adaptor transformed(operands);
38
39 // Replace `any` with its first operand.
40 // Any operand would be a valid substitution.
41 rewriter.replaceOp(op, {transformed.inputs().front()});
42 return success();
43 }
44
45 namespace {
46 template <typename SrcOpTy, typename DstOpTy>
47 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
48 public:
49 using OpConversionPattern<SrcOpTy>::OpConversionPattern;
50
51 LogicalResult
matchAndRewrite(SrcOpTy op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const52 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
53 ConversionPatternRewriter &rewriter) const override {
54 typename SrcOpTy::Adaptor transformed(operands);
55
56 // For now, only error-free types are supported by this lowering.
57 if (op.getType().template isa<SizeType>())
58 return failure();
59
60 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
61 transformed.rhs());
62 return success();
63 }
64 };
65 } // namespace
66
67 namespace {
68 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
69 using OpConversionPattern<BroadcastOp>::OpConversionPattern;
70
71 LogicalResult
72 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
73 ConversionPatternRewriter &rewriter) const override;
74 };
75 } // namespace
76
matchAndRewrite(BroadcastOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const77 LogicalResult BroadcastOpConverter::matchAndRewrite(
78 BroadcastOp op, ArrayRef<Value> operands,
79 ConversionPatternRewriter &rewriter) const {
80 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
81 // on shapes.
82 if (op.getType().isa<ShapeType>())
83 return failure();
84
85 assert(!op.lhs().getType().isa<ShapeType>() &&
86 !op.rhs().getType().isa<ShapeType>());
87 auto loc = op.getLoc();
88 BroadcastOp::Adaptor transformed(operands);
89 Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
90 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
91
92 // Find smaller and greater rank and extent tensor.
93 Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
94 Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
95 Value lhsRankULE =
96 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
97 Type indexTy = rewriter.getIndexType();
98 Value lesserRank =
99 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
100 Value greaterRank =
101 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
102 auto erasedRankType =
103 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
104 Value rankErasedLhs =
105 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
106 Value rankErasedRhs =
107 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
108 Value lesserRankOperand =
109 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
110 Value greaterRankOperand =
111 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
112
113 Value rankDiff =
114 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
115 rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
116 op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
117 [&](OpBuilder &b, Location loc, ValueRange args) {
118 Value outputDimension = args[0];
119 Value isUnchallengedDimension = b.create<CmpIOp>(
120 loc, CmpIPredicate::ult, outputDimension, rankDiff);
121 Value greaterRankOperandExtent = b.create<ExtractElementOp>(
122 loc, greaterRankOperand, outputDimension);
123 // The initial dimensions of the greater-rank operand are unchallenged,
124 // so we can take them as-is. Otherwise, we need to do a comparison.
125 // We need an actual branch here (instead of a select) because the
126 // lesser-rank operand might be rank 0, so any extract_element would be
127 // invalid.
128 auto ifOp = b.create<IfOp>(
129 loc, TypeRange{indexTy}, isUnchallengedDimension,
130 [&](OpBuilder &b, Location loc) {
131 b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
132 },
133 [&](OpBuilder &b, Location loc) {
134 // The broadcasting logic is:
135 // - if one extent (here we arbitrarily choose the extent from
136 // the greater-rank operand) is equal to 1, then take the extent
137 // from the other operand
138 // - otherwise, take the extent as-is.
139 // Note that this logic remains correct in the presence of
140 // dimensions of zero extent.
141 Value lesserRankOperandDimension =
142 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
143 Value lesserRankOperandExtent = b.create<ExtractElementOp>(
144 loc, lesserRankOperand,
145 ValueRange{lesserRankOperandDimension});
146 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
147 loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
148 Value broadcastedExtent = b.create<SelectOp>(
149 loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
150 greaterRankOperandExtent);
151 b.create<scf::YieldOp>(loc, broadcastedExtent);
152 });
153 b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
154 });
155 return success();
156 }
157
158 namespace {
159 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
160 public:
161 using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
162
163 LogicalResult
164 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
165 ConversionPatternRewriter &rewriter) const override;
166 };
167 } // namespace
168
matchAndRewrite(ConstShapeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const169 LogicalResult ConstShapeOpConverter::matchAndRewrite(
170 ConstShapeOp op, ArrayRef<Value> operands,
171 ConversionPatternRewriter &rewriter) const {
172
173 // For now, this lowering supports only extent tensors, not `shape.shape`
174 // types.
175 if (op.getType().isa<ShapeType>())
176 return failure();
177
178 auto loc = op.getLoc();
179 SmallVector<Value, 4> extentOperands;
180 for (auto extent : op.shape()) {
181 extentOperands.push_back(
182 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
183 }
184 Type indexTy = rewriter.getIndexType();
185 Value tensor =
186 rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
187 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
188 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
189 return success();
190 }
191
192 namespace {
193 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
194 public:
195 using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
196
197 LogicalResult
198 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
199 ConversionPatternRewriter &rewriter) const override;
200 };
201 } // namespace
202
matchAndRewrite(ConstSizeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const203 LogicalResult ConstSizeOpConversion::matchAndRewrite(
204 ConstSizeOp op, ArrayRef<Value> operands,
205 ConversionPatternRewriter &rewriter) const {
206 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
207 return success();
208 }
209
210 namespace {
211 struct IsBroadcastableOpConverter
212 : public OpConversionPattern<IsBroadcastableOp> {
213 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
214
215 LogicalResult
216 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
217 ConversionPatternRewriter &rewriter) const override;
218 };
219 } // namespace
220
matchAndRewrite(IsBroadcastableOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const221 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
222 IsBroadcastableOp op, ArrayRef<Value> operands,
223 ConversionPatternRewriter &rewriter) const {
224 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
225 // on shapes.
226 IsBroadcastableOp::Adaptor transformed(operands);
227 if (transformed.lhs().getType().isa<ShapeType>() ||
228 transformed.rhs().getType().isa<ShapeType>())
229 return failure();
230
231 auto loc = op.getLoc();
232 Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
233 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
234
235 // Find smaller and greater rank and extent tensor.
236 Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
237 Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
238 Value lhsRankULE =
239 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
240 Type indexTy = rewriter.getIndexType();
241 Value lesserRank =
242 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
243 Value greaterRank =
244 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
245 auto erasedRankType =
246 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
247 Value rankErasedLhs =
248 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
249 Value rankErasedRhs =
250 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
251 Value lesserRankOperand =
252 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
253 Value greaterRankOperand =
254 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
255 Value rankDiff =
256 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
257 Type i1Ty = rewriter.getI1Type();
258 Value init =
259 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
260
261 // Determine if all overlapping extents are broadcastable.
262 auto reduceResult = rewriter.create<ForOp>(
263 loc, rankDiff, greaterRank, one, ValueRange{init},
264 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
265 Value greaterRankOperandExtent =
266 b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
267 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
268 loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
269 Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
270 Value lesserRankOperandExtent = b.create<ExtractElementOp>(
271 loc, lesserRankOperand, ValueRange{ivShifted});
272 Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
273 loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
274 Value extentsAreEqual =
275 b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
276 lesserRankOperandExtent);
277 Value broadcastableExtents = b.create<AndOp>(
278 loc, iterArgs[0],
279 b.create<OrOp>(loc,
280 b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
281 lesserRankOperandExtentIsOne),
282 extentsAreEqual));
283 b.create<scf::YieldOp>(loc, broadcastableExtents);
284 });
285
286 rewriter.replaceOp(op, reduceResult.results().front());
287 return success();
288 }
289
290 namespace {
291 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
292 using OpConversionPattern<GetExtentOp>::OpConversionPattern;
293
294 LogicalResult
295 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
296 ConversionPatternRewriter &rewriter) const override;
297 };
298 } // namespace
299
matchAndRewrite(GetExtentOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const300 LogicalResult GetExtentOpConverter::matchAndRewrite(
301 GetExtentOp op, ArrayRef<Value> operands,
302 ConversionPatternRewriter &rewriter) const {
303 GetExtentOp::Adaptor transformed(operands);
304
305 // For now, only error-free types are supported by this lowering.
306 if (op.getType().isa<SizeType>())
307 return failure();
308
309 // Derive shape extent directly from shape origin if possible. This
310 // circumvents the necessity to materialize the shape in memory.
311 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
312 if (shapeOfOp.arg().getType().isa<ShapedType>()) {
313 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
314 transformed.dim());
315 return success();
316 }
317 }
318
319 rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
320 transformed.shape(),
321 ValueRange{transformed.dim()});
322 return success();
323 }
324
325 namespace {
326 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
327 public:
328 using OpConversionPattern<shape::RankOp>::OpConversionPattern;
329
330 LogicalResult
331 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
332 ConversionPatternRewriter &rewriter) const override;
333 };
334 } // namespace
335
336 LogicalResult
matchAndRewrite(shape::RankOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const337 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
338 ConversionPatternRewriter &rewriter) const {
339 // For now, this lowering supports only error-free types.
340 if (op.getType().isa<SizeType>())
341 return failure();
342
343 shape::RankOp::Adaptor transformed(operands);
344 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
345 return success();
346 }
347
348 namespace {
349 /// Converts `shape.reduce` to `scf.for`.
350 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
351 public:
352 using OpConversionPattern::OpConversionPattern;
353
354 LogicalResult
355 matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
356 ConversionPatternRewriter &rewriter) const final;
357 };
358 } // namespace
359
360 LogicalResult
matchAndRewrite(shape::ReduceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const361 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
362 ConversionPatternRewriter &rewriter) const {
363 // For now, this lowering is only defined on `tensor<?xindex>` operands.
364 if (op.shape().getType().isa<ShapeType>())
365 return failure();
366
367 auto loc = op.getLoc();
368 shape::ReduceOp::Adaptor transformed(operands);
369
370 Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
371 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
372 Type indexTy = rewriter.getIndexType();
373 Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
374
375 auto loop = rewriter.create<scf::ForOp>(
376 loc, zero, rank, one, op.initVals(),
377 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
378 Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
379
380 SmallVector<Value, 2> mappedValues{iv, extent};
381 mappedValues.append(args.begin(), args.end());
382
383 BlockAndValueMapping mapping;
384 Block *reduceBody = op.getBody();
385 mapping.map(reduceBody->getArguments(), mappedValues);
386 for (auto &nested : reduceBody->without_terminator())
387 b.clone(nested, mapping);
388
389 SmallVector<Value, 2> mappedResults;
390 for (auto result : reduceBody->getTerminator()->getOperands())
391 mappedResults.push_back(mapping.lookup(result));
392 b.create<scf::YieldOp>(loc, mappedResults);
393 });
394
395 rewriter.replaceOp(op, loop.getResults());
396 return success();
397 }
398
399 namespace {
400 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
401 /// only defined on `tensor<?xindex>` operands. The test for equality first
402 /// compares their size and, if equal, checks every extent for equality.
403 ///
404 /// Example:
405 ///
406 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
407 ///
408 /// becomes
409 ///
410 /// %c0 = constant 0 : index
411 /// %0 = dim %arg0, %c0 : tensor<?xindex>
412 /// %1 = dim %arg1, %c0 : tensor<?xindex>
413 /// %2 = cmpi "eq", %0, %1 : index
414 /// %result = scf.if %2 -> (i1) {
415 /// %c1 = constant 1 : index
416 /// %true = constant true
417 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
418 /// %5 = extract_element %arg0[%arg2] : tensor<?xindex>
419 /// %6 = extract_element %arg1[%arg2] : tensor<?xindex>
420 /// %7 = cmpi "eq", %5, %6 : index
421 /// %8 = and %arg3, %7 : i1
422 /// scf.yield %8 : i1
423 /// }
424 /// scf.yield %4 : i1
425 /// } else {
426 /// %false = constant false
427 /// scf.yield %false : i1
428 /// }
429 ///
430 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
431 using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
432
433 LogicalResult
434 matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
435 ConversionPatternRewriter &rewriter) const override;
436 };
437 } // namespace
438
439 LogicalResult
matchAndRewrite(ShapeEqOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const440 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
441 ConversionPatternRewriter &rewriter) const {
442 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
443 // on shapes.
444 if (op.lhs().getType().isa<ShapeType>() ||
445 op.rhs().getType().isa<ShapeType>()) {
446 return failure();
447 }
448
449 ShapeEqOp::Adaptor transformed(operands);
450 auto loc = op.getLoc();
451 Type indexTy = rewriter.getIndexType();
452 Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
453 Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
454 Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
455 Value eqRank =
456 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
457 Type i1Ty = rewriter.getI1Type();
458 rewriter.replaceOpWithNewOp<IfOp>(
459 op, i1Ty, eqRank,
460 [&](OpBuilder &b, Location loc) {
461 Value one = b.create<ConstantIndexOp>(loc, 1);
462 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
463 auto loop = b.create<scf::ForOp>(
464 loc, zero, lhsRank, one, ValueRange{init},
465 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
466 Value conj = args[0];
467 Value lhsExtent =
468 b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
469 Value rhsExtent =
470 b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
471 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
472 lhsExtent, rhsExtent);
473 Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
474 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
475 });
476 b.create<scf::YieldOp>(loc, loop.getResults());
477 },
478 [&](OpBuilder &b, Location loc) {
479 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
480 b.create<scf::YieldOp>(loc, result);
481 });
482 return success();
483 }
484
485 namespace {
486 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
487 public:
488 using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
489
490 LogicalResult
491 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
492 ConversionPatternRewriter &rewriter) const override;
493 };
494 } // namespace
495
matchAndRewrite(ShapeOfOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const496 LogicalResult ShapeOfOpConversion::matchAndRewrite(
497 ShapeOfOp op, ArrayRef<Value> operands,
498 ConversionPatternRewriter &rewriter) const {
499
500 // For now, only error-free types are supported by this lowering.
501 if (op.getType().isa<ShapeType>())
502 return failure();
503
504 // For ranked tensor arguments, lower to `tensor_from_elements`.
505 auto loc = op.getLoc();
506 ShapeOfOp::Adaptor transformed(operands);
507 Value tensor = transformed.arg();
508 Type tensorTy = tensor.getType();
509 if (tensorTy.isa<RankedTensorType>()) {
510
511 // Build values for individual extents.
512 SmallVector<Value, 8> extentValues;
513 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
514 int64_t rank = rankedTensorTy.getRank();
515 for (int64_t i = 0; i < rank; i++) {
516 if (rankedTensorTy.isDynamicDim(i)) {
517 Value extent = rewriter.create<DimOp>(loc, tensor, i);
518 extentValues.push_back(extent);
519 } else {
520 Value extent =
521 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
522 extentValues.push_back(extent);
523 }
524 }
525
526 // Materialize extent tensor.
527 Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
528 loc, rewriter.getIndexType(), extentValues);
529 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
530 op.getType());
531 return success();
532 }
533
534 // Lower to `dynamic_tensor_from_elements` otherwise.
535 auto *ctx = rewriter.getContext();
536 Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
537 rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
538 op, getExtentTensorType(ctx), ValueRange{rank},
539 [&](OpBuilder &b, Location loc, ValueRange args) {
540 Value dim = args.front();
541 Value extent = b.create<DimOp>(loc, tensor, dim);
542 b.create<mlir::YieldOp>(loc, extent);
543 });
544
545 return success();
546 }
547
548 namespace {
549 class ToExtentTensorOpConversion
550 : public OpConversionPattern<ToExtentTensorOp> {
551 public:
552 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
553
554 LogicalResult
matchAndRewrite(ToExtentTensorOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const555 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
556 ConversionPatternRewriter &rewriter) const override {
557 ToExtentTensorOpAdaptor adaptor(operands);
558
559 if (!adaptor.input().getType().isa<RankedTensorType>())
560 return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
561
562 rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
563 op.getType());
564 return success();
565 }
566 };
567 } // namespace
568
569 namespace {
570 /// Import the Shape Ops to Std Patterns.
571 #include "ShapeToStandard.cpp.inc"
572 } // namespace
573
574 namespace {
575 /// Conversion pass.
576 class ConvertShapeToStandardPass
577 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
578
579 void runOnOperation() override;
580 };
581 } // namespace
582
runOnOperation()583 void ConvertShapeToStandardPass::runOnOperation() {
584 // Setup target legality.
585 MLIRContext &ctx = getContext();
586 ConversionTarget target(ctx);
587 target.addLegalDialect<StandardOpsDialect, SCFDialect>();
588 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
589
590 // Setup conversion patterns.
591 OwningRewritePatternList patterns;
592 populateShapeToStandardConversionPatterns(patterns, &ctx);
593
594 // Apply conversion.
595 auto module = getOperation();
596 if (failed(applyPartialConversion(module, target, std::move(patterns))))
597 signalPassFailure();
598 }
599
populateShapeToStandardConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)600 void mlir::populateShapeToStandardConversionPatterns(
601 OwningRewritePatternList &patterns, MLIRContext *ctx) {
602 // clang-format off
603 populateWithGenerated(ctx, patterns);
604 patterns.insert<
605 AnyOpConversion,
606 BinaryOpConversion<AddOp, AddIOp>,
607 BinaryOpConversion<MulOp, MulIOp>,
608 BroadcastOpConverter,
609 ConstShapeOpConverter,
610 ConstSizeOpConversion,
611 IsBroadcastableOpConverter,
612 GetExtentOpConverter,
613 RankOpConverter,
614 ReduceOpConverter,
615 ShapeEqOpConverter,
616 ShapeOfOpConversion,
617 ToExtentTensorOpConversion>(ctx);
618 // clang-format on
619 }
620
621 std::unique_ptr<OperationPass<ModuleOp>>
createConvertShapeToStandardPass()622 mlir::createConvertShapeToStandardPass() {
623 return std::make_unique<ConvertShapeToStandardPass>();
624 }
625