//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
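/// Converts `shape.any` by replacing the op with one of its operands; any
/// operand is a valid substitution, so the first one is taken.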
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
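/// Converts binary shape ops such as `shape.add` and `shape.mul` to their
/// standard-dialect arithmetic counterparts on `index` operands. Schematically,
/// `%r = shape.add %a, %b` on indices becomes `%r = addi %a, %b : index`.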
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
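/// Converts `shape.broadcast` on `tensor<?xindex>` operands. The lowering
/// picks the operand of greater rank, keeps its leading extents as-is, and
/// combines every overlapping extent pair according to the broadcasting rules.
/// Schematically, the result is materialized as
///
///   %result = dynamic_tensor_from_elements %greater_rank {
///   ^bb0(%dim : index):
///     // ... yield the broadcasted extent for output dimension %dim ...
///   } : tensor<?xindex>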
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);

  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value outputDimension = args[0];
        Value isUnchallengedDimension = b.create<CmpIOp>(
            loc, CmpIPredicate::ult, outputDimension, rankDiff);
        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
            loc, greaterRankOperand, outputDimension);
        // The initial dimensions of the greater-rank operand are unchallenged,
        // so we can take them as-is. Otherwise, we need to do a comparison.
        // We need an actual branch here (instead of a select) because the
        // lesser-rank operand might be rank 0, so any extract_element would be
        // invalid.
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, isUnchallengedDimension,
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              // The broadcasting logic is:
              // - if one extent (here we arbitrarily choose the extent from
              //   the greater-rank operand) is equal to 1, then take the
              //   extent from the other operand
              // - otherwise, take the extent as-is.
              // Note that this logic remains correct in the presence of
              // dimensions of zero extent.
              Value lesserRankOperandDimension =
                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
                  loc, lesserRankOperand,
                  ValueRange{lesserRankOperandDimension});
              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
              Value broadcastedExtent = b.create<SelectOp>(
                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
                  greaterRankOperandExtent);
              b.create<scf::YieldOp>(loc, broadcastedExtent);
            });
        b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
      });
  return success();
}

namespace {
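/// Converts `shape.const_shape` to a `tensor_from_elements` of constant
/// indices, followed by a `tensor_cast` to the result's `tensor<?xindex>`
/// type. Schematically:
///
///   %shape = shape.const_shape [1, 2] : tensor<?xindex>
///
/// becomes
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %0 = tensor_from_elements %c1, %c2 : tensor<2xindex>
///   %shape = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>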
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
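/// Converts `shape.const_size` to a `constant` of index type.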
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
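/// Converts `shape.is_broadcastable` on `tensor<?xindex>` operands to an
/// `scf.for` loop that reduces a boolean over the overlapping extents. A pair
/// of extents is broadcastable iff at least one of them equals 1 or they are
/// equal, so starting from `true` the loop accumulates (schematically)
///
///   %acc' = and %acc, (or (or %greater_is_1, %lesser_is_1), %extents_eq)
///
/// for every overlapping dimension.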
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  IsBroadcastableOp::Adaptor transformed(operands);
  if (transformed.lhs().getType().isa<ShapeType>() ||
      transformed.rhs().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  Type i1Ty = rewriter.getI1Type();
  Value init =
      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Determine if all overlapping extents are broadcastable.
  auto reduceResult = rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, ValueRange{init},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        Value greaterRankOperandExtent =
            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
        Value lesserRankOperandExtent = b.create<ExtractElementOp>(
            loc, lesserRankOperand, ValueRange{ivShifted});
        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
        Value extentsAreEqual =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
                             lesserRankOperandExtent);
        Value broadcastableExtents = b.create<AndOp>(
            loc, iterArgs[0],
            b.create<OrOp>(loc,
                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                          lesserRankOperandExtentIsOne),
                           extentsAreEqual));
        b.create<scf::YieldOp>(loc, broadcastableExtents);
      });

  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
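/// Converts `shape.get_extent`. When the shape is produced by `shape.shape_of`
/// of a shaped value, the extent is taken directly from that value via `dim`,
/// avoiding materialization of the extent tensor. Otherwise, it lowers to an
/// `extract_element` on the extent tensor.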
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
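/// Converts `shape.rank` to `dim` of the extent tensor's first (and only)
/// dimension.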
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
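/// The loop iterates over the shape's rank, extracts each extent, and clones
/// the reduction body into the loop, threading the accumulated values through
/// `iter_args`. Schematically:
///
///   %r = shape.reduce(%shape, %init) : tensor<?xindex> -> index { ... }
///
/// becomes
///
///   %rank = dim %shape, %c0 : tensor<?xindex>
///   %r = scf.for %i = %c0 to %rank step %c1
///       iter_args(%acc = %init) -> (index) {
///     %extent = extract_element %shape[%i] : tensor<?xindex>
///     // ... cloned reduction body, yielding the new accumulator ...
///   }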
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
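/// Converts `shape.shape_of`. For ranked tensor arguments, each extent is
/// either known statically or obtained with `dim`, so the shape is built with
/// `tensor_from_elements` and cast to `tensor<?xindex>`; for unranked
/// arguments, it lowers to `dynamic_tensor_from_elements` over the argument's
/// rank. Schematically, for a `tensor<2x?xf32>` argument:
///
///   %c2 = constant 2 : index
///   %d1 = dim %arg, %c1 : tensor<2x?xf32>
///   %0 = tensor_from_elements %c2, %d1 : tensor<2xindex>
///   %shape = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>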
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
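/// Converts `shape.to_extent_tensor` to a `tensor_cast`, as the converted
/// operand is already an extent tensor; inputs that are not ranked tensors
/// are rejected.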
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Setup conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  populateWithGenerated(ctx, patterns);
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}