1 //===- Traits.cpp - Common op traits shared by dialects -------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Traits.h"
10 #include "mlir/IR/StandardTypes.h"
11 #include "llvm/Support/FormatVariadic.h"
12
13 using namespace mlir;
14
getBroadcastedShape(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2,SmallVectorImpl<int64_t> & resultShape)15 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
16 ArrayRef<int64_t> shape2,
17 SmallVectorImpl<int64_t> &resultShape) {
18 // To compute the result broadcasted shape, we compare operand shapes
19 // element-wise: starting with the trailing dimensions, and working the
20 // way backward. Two dimensions are compatible when
21 // 1. they are equal, or
22 // 2. one of them is 1
23 // The result shape has the maximum among the two inputs at every
24 // dimension index.
25
26 resultShape.clear();
27 if (shape1.size() > shape2.size()) {
28 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
29 } else {
30 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
31 }
32
33 auto i1 = shape1.rbegin(), e1 = shape1.rend();
34 auto i2 = shape2.rbegin(), e2 = shape2.rend();
35 auto iR = resultShape.rbegin();
36
37 // Check each dimension is consistent.
38 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
39 if (*i1 == -1 || *i2 == -1) {
40 // One or both dimensions is unknown. Follow TensorFlow behavior:
41 // - If either dimension is greater than 1, we assume that the program is
42 // correct, and the other dimension will be broadcast to match it.
43 // - If either dimension is 1, the other dimension is the output.
44 if (*i1 > 1) {
45 *iR = *i1;
46 } else if (*i2 > 1) {
47 *iR = *i2;
48 } else if (*i1 == 1) {
49 *iR = *i2;
50 } else if (*i2 == 1) {
51 *iR = *i1;
52 } else {
53 *iR = -1;
54 }
55 } else {
56 if (*i1 == *i2 || *i2 == 1) {
57 *iR = *i1;
58 } else if (*i1 == 1) {
59 *iR = *i2;
60 } else {
61 // This dimension of the two operand types is incompatible.
62 resultShape.clear();
63 return false;
64 }
65 }
66 }
67
68 return true;
69 }
70
71 /// Returns the shape of the given type. Scalars will be considered as having a
72 /// shape with zero dimensions.
getShape(Type type)73 static ArrayRef<int64_t> getShape(Type type) {
74 if (auto sType = type.dyn_cast<ShapedType>())
75 return sType.getShape();
76 return {};
77 }
78
79 /// Returns the result broadcast composition type from the two given types by
80 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
81 /// either of the input types has dynamic shape. Returns null type if the two
82 /// given types are not broadcast-compatible.
getBroadcastedType(Type type1,Type type2)83 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
84 // Returns the scalar type out of the given type.
85 auto getScalarType = [](Type type) -> Type {
86 if (auto shapedType = type.dyn_cast<ShapedType>())
87 return shapedType.getElementType();
88 return type;
89 };
90
91 // Make sure underlying scalar type is the same.
92 auto scalarType = getScalarType(type1);
93 if (scalarType != getScalarType(type2))
94 return {};
95
96 // If one of the types is unranked tensor, then the other type shouldn't be
97 // vector and the result should have unranked tensor type.
98 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
99 if (type1.isa<VectorType>() || type2.isa<VectorType>())
100 return {};
101 return UnrankedTensorType::get(scalarType);
102 }
103
104 // Returns the type kind if the given type is a vector or ranked tensor type.
105 // Returns llvm::None otherwise.
106 auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
107 if (type.isa<VectorType>() || type.isa<RankedTensorType>())
108 return static_cast<StandardTypes::Kind>(type.getKind());
109 return llvm::None;
110 };
111
112 // Make sure the composite type, if has, is consistent.
113 auto compositeKind1 = getCompositeTypeKind(type1);
114 auto compositeKind2 = getCompositeTypeKind(type2);
115 Optional<StandardTypes::Kind> resultCompositeKind;
116
117 if (compositeKind1 && compositeKind2) {
118 // Disallow mixing vector and tensor.
119 if (compositeKind1 != compositeKind2)
120 return {};
121 resultCompositeKind = compositeKind1;
122 } else if (compositeKind1) {
123 resultCompositeKind = compositeKind1;
124 } else if (compositeKind2) {
125 resultCompositeKind = compositeKind2;
126 }
127
128 // Get the shape of each type.
129 SmallVector<int64_t, 4> resultShape;
130 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
131 return {};
132
133 // Compose the final broadcasted type
134 if (resultCompositeKind == StandardTypes::Vector)
135 return VectorType::get(resultShape, scalarType);
136 if (resultCompositeKind == StandardTypes::RankedTensor)
137 return RankedTensorType::get(resultShape, scalarType);
138 return scalarType;
139 }
140
141 /// Returns true if the given types has both vector types and tensor types.
hasBothVectorAndTensorType(ArrayRef<Type> types)142 static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
143 return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
144 llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
145 }
146
areCompatibleShapes(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)147 static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
148 ArrayRef<int64_t> shape2) {
149 auto isCompatible = [](int64_t dim1, int64_t dim2) {
150 return dim1 == dim2 || dim1 == -1 || dim2 == -1;
151 };
152 if (shape1.size() != shape2.size())
153 return false;
154 for (auto p : llvm::zip(shape1, shape2))
155 if (!isCompatible(std::get<0>(p), std::get<1>(p)))
156 return false;
157 return true;
158 }
159
verifyCompatibleOperandBroadcast(Operation * op)160 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
161 assert(op->getNumOperands() == 2 &&
162 "only support broadcast check on two operands");
163 assert(op->getNumResults() == 1 &&
164 "only support broadcast check on one result");
165
166 auto type1 = op->getOperand(0).getType();
167 auto type2 = op->getOperand(1).getType();
168 auto retType = op->getResult(0).getType();
169
170 // We forbid broadcasting vector and tensor.
171 if (hasBothVectorAndTensorType({type1, type2, retType}))
172 return op->emitError("cannot broadcast vector with tensor");
173
174 if (retType.isa<UnrankedTensorType>())
175 return success();
176
177 bool isUnranked1 = type1.isa<UnrankedTensorType>();
178 bool isUnranked2 = type2.isa<UnrankedTensorType>();
179
180 // If both operands are unranked, then all result shapes are possible.
181 if (isUnranked1 && isUnranked2)
182 return success();
183
184 // If one of the operands is unranked, then the known dimensions in the result
185 // should be compatible with the other shaped operand.
186 if (isUnranked1 || isUnranked2) {
187 // Result should have higher rank than the shaped operand's rank and then
188 // the result's trailing dimensions should be compatible with the operand
189 // shape.
190 ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
191 ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
192 if (!areCompatibleShapes(actualSuffix, shape))
193 return op->emitOpError()
194 << "result type " << retType
195 << " has shape incompatible with a ranked operand type";
196 return success();
197 }
198
199 // If both operands are shaped, then the computed broadcasted shape should be
200 // compatible with the result shape.
201 SmallVector<int64_t, 4> resultShape;
202 if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
203 return op->emitOpError("operands don't have broadcast-compatible shapes");
204
205 if (!areCompatibleShapes(resultShape, getShape(retType)))
206 return op->emitOpError() << "result type " << retType
207 << " does not have shape compatible with the one "
208 "computed from the operand types";
209
210 return success();
211 }
212