//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions and working our way
  // backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape has the maximum of the two inputs at every dimension
  // index.
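  //
  // For example, broadcasting shapes [3, 1, 5] and [4, 1] yields [3, 4, 5]:
  // the shorter shape behaves as if extended with leading 1s, and each
  // dimension pair resolves to the non-1 (or common) extent.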

  resultShape.clear();
  // Seed the result with the longer of the two shapes; the loop below then
  // rewrites the overlapping trailing dimensions.
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both of the dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program
      //   is correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
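      // For example, broadcasting [-1] with [3] yields [3], while
      // broadcasting [-1] with [1] yields [-1].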
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}

/// Returns the shape of the given type. Scalars will be considered as having
/// a shape with zero dimensions.
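/// For example, tensor<2x3xf32> yields [2, 3], while f32 yields the empty
/// shape.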
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = type.dyn_cast<ShapedType>())
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. The returned type may have a dynamic
/// shape if either of the input types has a dynamic shape. Returns a null
/// type if the two given types are not broadcast-compatible.
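/// For example, broadcasting tensor<?x2xf32> with tensor<3x1xf32> yields
/// tensor<3x2xf32>, while broadcasting vector<2xf32> with tensor<2xf32>
/// yields a null type because vectors and tensors may not be mixed.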
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
  // Returns the scalar type out of the given type.
  auto getScalarType = [](Type type) -> Type {
    if (auto shapedType = type.dyn_cast<ShapedType>())
      return shapedType.getElementType();
    return type;
  };

  // Make sure the underlying scalar types are the same.
  auto scalarType = getScalarType(type1);
  if (scalarType != getScalarType(type2))
    return {};

  // If one of the types is an unranked tensor, then the other type must not
  // be a vector, and the result is an unranked tensor type.
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
      return {};
    return UnrankedTensorType::get(scalarType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor
  // type; returns llvm::None otherwise.
  auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
      return static_cast<StandardTypes::Kind>(type.getKind());
    return llvm::None;
  };

  // Make sure the composite type kind, if present, is consistent.
  auto compositeKind1 = getCompositeTypeKind(type1);
  auto compositeKind2 = getCompositeTypeKind(type2);
  Optional<StandardTypes::Kind> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Compute the broadcasted shape of the two operand shapes.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == StandardTypes::Vector)
    return VectorType::get(resultShape, scalarType);
  if (resultCompositeKind == StandardTypes::RankedTensor)
    return RankedTensorType::get(resultShape, scalarType);
  return scalarType;
}

/// Returns true if the given list of types contains both vector types and
/// tensor types.
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
}

/// Returns true if the two given shapes have the same rank and each pair of
/// dimensions is either equal or has at least one dynamic (-1) dimension.
/// For example, [2, -1] and [2, 3] are compatible, while [2, 3] and [2, 4]
/// are not.
static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
                                ArrayRef<int64_t> shape2) {
  auto isCompatible = [](int64_t dim1, int64_t dim2) {
    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
  };
  if (shape1.size() != shape2.size())
    return false;
  for (auto p : llvm::zip(shape1, shape2))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  assert(op->getNumOperands() == 2 &&
         "only support broadcast check on two operands");
  assert(op->getNumResults() == 1 &&
         "only support broadcast check on one result");

  auto type1 = op->getOperand(0).getType();
  auto type2 = op->getOperand(1).getType();
  auto retType = op->getResult(0).getType();

  // We forbid broadcasting vector and tensor.
  if (hasBothVectorAndTensorType({type1, type2, retType}))
    return op->emitError("cannot broadcast vector with tensor");

  // An unranked result can accommodate any operand shapes.
  if (retType.isa<UnrankedTensorType>())
    return success();

  bool isUnranked1 = type1.isa<UnrankedTensorType>();
  bool isUnranked2 = type2.isa<UnrankedTensorType>();

  // If both operands are unranked, then all result shapes are possible.
  if (isUnranked1 && isUnranked2)
    return success();

  // If exactly one of the operands is unranked, then the known dimensions in
  // the result should be compatible with the ranked operand.
  if (isUnranked1 || isUnranked2) {
    // The result's rank should be at least as large as the ranked operand's,
    // and the result's trailing dimensions should be compatible with that
    // operand's shape.
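    // For example, with a ranked operand of shape [3, 4], a result type of
    // tensor<?x3x4xf32> is accepted since its suffix [3, 4] matches, while
    // tensor<3x5xf32> is rejected.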
    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
    if (!areCompatibleShapes(actualSuffix, shape))
      return op->emitOpError()
             << "result type " << retType
             << " has shape incompatible with a ranked operand type";
    return success();
  }

  // Both operands are ranked (a scalar has the empty shape) at this point,
  // so the computed broadcasted shape should be compatible with the result
  // shape.
  SmallVector<int64_t, 4> resultShape;
  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return op->emitOpError("operands don't have broadcast-compatible shapes");
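  // For example, operand shapes [2] and [3] fail the check above: the
  // dimensions are unequal and neither is 1.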

  if (!areCompatibleShapes(resultShape, getShape(retType)))
    return op->emitOpError() << "result type " << retType
                             << " does not have shape compatible with the one "
                                "computed from the operand types";

  return success();
}