//===- Utils.cpp - Utilities to support the Standard dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Standard dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
14 
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 
17 using namespace mlir;
18 
19 /// Matches a ConstantIndexOp.
20 /// TODO: This should probably just be a general matcher that uses matchConstant
21 /// and checks the operation for an index type.
matchConstantIndex()22 detail::op_matcher<ConstantIndexOp> mlir::matchConstantIndex() {
23   return detail::op_matcher<ConstantIndexOp>();
24 }
25 
26 /// Detects the `values` produced by a ConstantIndexOp and places the new
27 /// constant in place of the corresponding sentinel value.
canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> & values,llvm::function_ref<bool (int64_t)> isDynamic)28 void mlir::canonicalizeSubViewPart(
29     SmallVectorImpl<OpFoldResult> &values,
30     llvm::function_ref<bool(int64_t)> isDynamic) {
31   for (OpFoldResult &ofr : values) {
32     if (ofr.is<Attribute>())
33       continue;
34     // Newly static, move from Value to constant.
35     if (auto cstOp = ofr.dyn_cast<Value>().getDefiningOp<ConstantIndexOp>())
36       ofr = OpBuilder(cstOp).getIndexAttr(cstOp.getValue());
37   }
38 }
39 
getPositionsOfShapeOne(unsigned rank,ArrayRef<int64_t> shape,llvm::SmallDenseSet<unsigned> & dimsToProject)40 void mlir::getPositionsOfShapeOne(
41     unsigned rank, ArrayRef<int64_t> shape,
42     llvm::SmallDenseSet<unsigned> &dimsToProject) {
43   dimsToProject.reserve(rank);
44   for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
45     if (shape[pos] == 1) {
46       dimsToProject.insert(pos);
47       --rank;
48     }
49   }
50 }
51 
getValueOrCreateConstantIndexOp(OpBuilder & b,Location loc,OpFoldResult ofr)52 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
53                                             OpFoldResult ofr) {
54   if (auto value = ofr.dyn_cast<Value>())
55     return value;
56   auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
57   assert(attr && "expect the op fold result casts to an integer attribute");
58   return b.create<ConstantIndexOp>(loc, attr.getValue().getSExtValue());
59 }
60 
61 SmallVector<Value>
getValueOrCreateConstantIndexOp(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> valueOrAttrVec)62 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
63                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
64   return llvm::to_vector<4>(
65       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
66         return getValueOrCreateConstantIndexOp(b, loc, value);
67       }));
68 }
69 
_and(Value lhs,Value rhs)70 Value ArithBuilder::_and(Value lhs, Value rhs) {
71   return b.create<AndOp>(loc, lhs, rhs);
72 }
add(Value lhs,Value rhs)73 Value ArithBuilder::add(Value lhs, Value rhs) {
74   if (lhs.getType().isa<IntegerType>())
75     return b.create<AddIOp>(loc, lhs, rhs);
76   return b.create<AddFOp>(loc, lhs, rhs);
77 }
mul(Value lhs,Value rhs)78 Value ArithBuilder::mul(Value lhs, Value rhs) {
79   if (lhs.getType().isa<IntegerType>())
80     return b.create<MulIOp>(loc, lhs, rhs);
81   return b.create<MulFOp>(loc, lhs, rhs);
82 }
sgt(Value lhs,Value rhs)83 Value ArithBuilder::sgt(Value lhs, Value rhs) {
84   if (lhs.getType().isa<IndexType, IntegerType>())
85     return b.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
86   return b.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
87 }
slt(Value lhs,Value rhs)88 Value ArithBuilder::slt(Value lhs, Value rhs) {
89   if (lhs.getType().isa<IndexType, IntegerType>())
90     return b.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
91   return b.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
92 }
select(Value cmp,Value lhs,Value rhs)93 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
94   return b.create<SelectOp>(loc, cmp, lhs, rhs);
95 }
96