1 //===- Traits.cpp - Common op traits shared by dialects -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Traits.h"
10 #include "mlir/IR/StandardTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 using namespace mlir;
15 
staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)16 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
17                                                  ArrayRef<int64_t> shape2) {
18   // Two dimensions are compatible when
19   //   1. they are defined and equal, or
20   //   2. one of them is 1
21   return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
22                       [](auto dimensions) {
23                         auto dim1 = std::get<0>(dimensions);
24                         auto dim2 = std::get<1>(dimensions);
25                         if (dim1 == 1 || dim2 == 1)
26                           return true;
27                         if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
28                           return true;
29                         return false;
30                       });
31 }
32 
getBroadcastedShape(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2,SmallVectorImpl<int64_t> & resultShape)33 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
34                                         ArrayRef<int64_t> shape2,
35                                         SmallVectorImpl<int64_t> &resultShape) {
36   // To compute the result broadcasted shape, we compare operand shapes
37   // element-wise: starting with the trailing dimensions, and working the
38   // way backward. Two dimensions are compatible when
39   //   1. they are equal, or
40   //   2. one of them is 1
41   // The result shape has the maximum among the two inputs at every
42   // dimension index.
43 
44   resultShape.clear();
45   if (shape1.size() > shape2.size()) {
46     std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
47   } else {
48     std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
49   }
50 
51   auto i1 = shape1.rbegin(), e1 = shape1.rend();
52   auto i2 = shape2.rbegin(), e2 = shape2.rend();
53   auto iR = resultShape.rbegin();
54 
55   // Check each dimension is consistent.
56   for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
57     if (*i1 == -1 || *i2 == -1) {
58       // One or both dimensions is unknown. Follow TensorFlow behavior:
59       // - If either dimension is greater than 1, we assume that the program is
60       //   correct, and the other dimension will be broadcast to match it.
61       // - If either dimension is 1, the other dimension is the output.
62       if (*i1 > 1) {
63         *iR = *i1;
64       } else if (*i2 > 1) {
65         *iR = *i2;
66       } else if (*i1 == 1) {
67         *iR = *i2;
68       } else if (*i2 == 1) {
69         *iR = *i1;
70       } else {
71         *iR = -1;
72       }
73     } else {
74       if (*i1 == *i2 || *i2 == 1) {
75         *iR = *i1;
76       } else if (*i1 == 1) {
77         *iR = *i2;
78       } else {
79         // This dimension of the two operand types is incompatible.
80         resultShape.clear();
81         return false;
82       }
83     }
84   }
85 
86   return true;
87 }
88 
89 /// Returns the shape of the given type. Scalars will be considered as having a
90 /// shape with zero dimensions.
getShape(Type type)91 static ArrayRef<int64_t> getShape(Type type) {
92   if (auto sType = type.dyn_cast<ShapedType>())
93     return sType.getShape();
94   return {};
95 }
96 
97 /// Returns the result broadcast composition type from the two given types by
98 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
99 /// either of the input types has dynamic shape. Returns null type if the two
100 /// given types are not broadcast-compatible.
101 ///
102 /// elementType, if specified, will be used as the element type of the
103 /// broadcasted result type. Otherwise it is required that the element type of
104 /// type1 and type2 is the same and this element type will be used as the
105 /// resultant element type.
getBroadcastedType(Type type1,Type type2,Type elementType)106 Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
107                                        Type elementType) {
108   // If the elementType is not specified, then the use the common element type
109   // of the inputs or fail if there is no common element type.
110   if (!elementType) {
111     elementType = getElementTypeOrSelf(type1);
112     if (elementType != getElementTypeOrSelf(type2))
113       return {};
114   }
115 
116   // If one of the types is unranked tensor, then the other type shouldn't be
117   // vector and the result should have unranked tensor type.
118   if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
119     if (type1.isa<VectorType>() || type2.isa<VectorType>())
120       return {};
121     return UnrankedTensorType::get(elementType);
122   }
123 
124   // Returns the type kind if the given type is a vector or ranked tensor type.
125   // Returns llvm::None otherwise.
126   auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
127     if (type.isa<VectorType, RankedTensorType>())
128       return static_cast<StandardTypes::Kind>(type.getKind());
129     return llvm::None;
130   };
131 
132   // Make sure the composite type, if has, is consistent.
133   auto compositeKind1 = getCompositeTypeKind(type1);
134   auto compositeKind2 = getCompositeTypeKind(type2);
135   Optional<StandardTypes::Kind> resultCompositeKind;
136 
137   if (compositeKind1 && compositeKind2) {
138     // Disallow mixing vector and tensor.
139     if (compositeKind1 != compositeKind2)
140       return {};
141     resultCompositeKind = compositeKind1;
142   } else if (compositeKind1) {
143     resultCompositeKind = compositeKind1;
144   } else if (compositeKind2) {
145     resultCompositeKind = compositeKind2;
146   }
147 
148   // Get the shape of each type.
149   SmallVector<int64_t, 4> resultShape;
150   if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
151     return {};
152 
153   // Compose the final broadcasted type
154   if (resultCompositeKind == StandardTypes::Vector)
155     return VectorType::get(resultShape, elementType);
156   if (resultCompositeKind == StandardTypes::RankedTensor)
157     return RankedTensorType::get(resultShape, elementType);
158   return elementType;
159 }
160 
161 /// Returns a tuple corresponding to whether range has tensor or vector type.
162 template <typename iterator_range>
hasTensorOrVectorType(iterator_range types)163 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
164   return std::make_tuple(
165       llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
166       llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
167 }
168 
areCompatibleShapes(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)169 static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
170                                 ArrayRef<int64_t> shape2) {
171   auto isCompatible = [](int64_t dim1, int64_t dim2) {
172     return dim1 == dim2 || dim1 == -1 || dim2 == -1;
173   };
174   if (shape1.size() != shape2.size())
175     return false;
176   for (auto p : llvm::zip(shape1, shape2))
177     if (!isCompatible(std::get<0>(p), std::get<1>(p)))
178       return false;
179   return true;
180 }
181 
getShapeString(ArrayRef<int64_t> shape)182 static std::string getShapeString(ArrayRef<int64_t> shape) {
183   // TODO: should replace with printing shape more uniformly across here and
184   // when in type.
185   return std::string(
186       formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end())));
187 }
188 
verifyCompatibleOperandBroadcast(Operation * op)189 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
190   // Ensure broadcasting only tensor or only vector types.
191   auto operandsHasTensorVectorType =
192       hasTensorOrVectorType(op->getOperandTypes());
193   auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
194   if ((std::get<0>(operandsHasTensorVectorType) ||
195        std::get<0>(resultsHasTensorVectorType)) &&
196       (std::get<1>(operandsHasTensorVectorType) ||
197        std::get<1>(resultsHasTensorVectorType)))
198     return op->emitError("cannot broadcast vector with tensor");
199 
200   auto rankedOperands = make_filter_range(
201       op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
202 
203   // If all operands are unranked, then all result shapes are possible.
204   if (rankedOperands.empty())
205     return success();
206 
207   // Compute broadcasted shape of operands (which requires that operands are
208   // broadcast compatible). The results need to be broadcast compatible with
209   // this result shape.
210   SmallVector<int64_t, 4> resultShape;
211   (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
212                                   resultShape);
213   for (auto other : make_early_inc_range(rankedOperands)) {
214     SmallVector<int64_t, 4> temp = resultShape;
215     if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
216       return op->emitOpError("operands don't have broadcast-compatible shapes");
217   }
218 
219   auto rankedResults = make_filter_range(
220       op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
221 
222   // If all of the results are unranked then no further verification.
223   if (rankedResults.empty())
224     return success();
225 
226   for (auto type : rankedResults) {
227     ArrayRef<int64_t> actualSuffix =
228         getShape(type).take_back(resultShape.size());
229     if (!areCompatibleShapes(actualSuffix, resultShape))
230       return op->emitOpError()
231              << "result type " << getShapeString(getShape(type))
232              << " not broadcast compatible with broadcasted operands's shapes "
233              << getShapeString(resultShape);
234   }
235   return success();
236 }
237