1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the TypeRange, TypeRangeRange, and ValueTypeRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_TYPERANGE_H
14 #define MLIR_IR_TYPERANGE_H
15 
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/PointerUnion.h"
19 #include "llvm/ADT/Sequence.h"
20 
21 namespace mlir {
22 class OperandRange;
23 class ResultRange;
24 class Type;
25 class Value;
26 class ValueRange;
27 template <typename ValueRangeT>
28 class ValueTypeRange;
29 
30 //===----------------------------------------------------------------------===//
31 // TypeRange
32 
33 /// This class provides an abstraction over the various different ranges of
34 /// value types. In many cases, this prevents the need to explicitly materialize
35 /// a SmallVector/std::vector. This class should be used in places that are not
36 /// suitable for a more derived type (e.g. ArrayRef) or a template range
37 /// parameter.
38 class TypeRange : public llvm::detail::indexed_accessor_range_base<
39                       TypeRange,
40                       llvm::PointerUnion<const Value *, const Type *,
41                                          OpOperand *, detail::OpResultImpl *>,
42                       Type, Type, Type> {
43 public:
44   using RangeBaseT::RangeBaseT;
45   TypeRange(ArrayRef<Type> types = llvm::None);
46   explicit TypeRange(OperandRange values);
47   explicit TypeRange(ResultRange values);
48   explicit TypeRange(ValueRange values);
49   explicit TypeRange(ArrayRef<Value> values);
TypeRange(ArrayRef<BlockArgument> values)50   explicit TypeRange(ArrayRef<BlockArgument> values)
51       : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
52   template <typename ValueRangeT>
TypeRange(ValueTypeRange<ValueRangeT> values)53   TypeRange(ValueTypeRange<ValueRangeT> values)
54       : TypeRange(ValueRangeT(values.begin().getCurrent(),
55                               values.end().getCurrent())) {}
56   template <typename Arg,
57             typename = typename std::enable_if_t<
58                 std::is_constructible<ArrayRef<Type>, Arg>::value>>
TypeRange(Arg && arg)59   TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types)60   TypeRange(std::initializer_list<Type> types)
61       : TypeRange(ArrayRef<Type>(types)) {}
62 
63 private:
64   /// The owner of the range is either:
65   /// * A pointer to the first element of an array of values.
66   /// * A pointer to the first element of an array of types.
67   /// * A pointer to the first element of an array of operands.
68   /// * A pointer to the first element of an array of results.
69   using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
70                                     detail::OpResultImpl *>;
71 
72   /// See `llvm::detail::indexed_accessor_range_base` for details.
73   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
74   /// See `llvm::detail::indexed_accessor_range_base` for details.
75   static Type dereference_iterator(OwnerT object, ptrdiff_t index);
76 
77   /// Allow access to `offset_base` and `dereference_iterator`.
78   friend RangeBaseT;
79 };
80 
81 /// Make TypeRange hashable.
hash_value(TypeRange arg)82 inline ::llvm::hash_code hash_value(TypeRange arg) {
83   return ::llvm::hash_combine_range(arg.begin(), arg.end());
84 }
85 
86 /// Emit a type range to the given output stream.
87 inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
88   llvm::interleaveComma(types, os);
89   return os;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // TypeRangeRange
94 
95 using TypeRangeRangeIterator =
96     llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator,
97                           std::function<TypeRange(unsigned)>>;
98 
99 /// This class provides an abstraction for a range of TypeRange. This is useful
100 /// when accessing the types of a range of ranges, such as when using
101 /// OperandRangeRange.
102 class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
103 public:
104   template <typename RangeT>
TypeRangeRange(const RangeT & range)105   TypeRangeRange(const RangeT &range)
106       : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {}
107 
108 private:
109   template <typename RangeT>
TypeRangeRange(llvm::iota_range<unsigned> sizeRange,const RangeT & range)110   TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range)
111       : llvm::iterator_range<TypeRangeRangeIterator>(
112             {sizeRange.begin(), getRangeFn(range)},
113             {sizeRange.end(), nullptr}) {}
114 
115   template <typename RangeT>
getRangeFn(const RangeT & range)116   static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) {
117     return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); };
118   }
119 };
120 
121 //===----------------------------------------------------------------------===//
122 // ValueTypeRange
123 
124 /// This class implements iteration on the types of a given range of values.
125 template <typename ValueIteratorT>
126 class ValueTypeIterator final
127     : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
unwrap(Value value)128   static Type unwrap(Value value) { return value.getType(); }
129 
130 public:
131   /// Provide a const dereference method.
132   Type operator*() const { return unwrap(*this->I); }
133 
134   /// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)135   ValueTypeIterator(ValueIteratorT it)
136       : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
137 };
138 
139 /// This class implements iteration on the types of a given range of values.
140 template <typename ValueRangeT>
141 class ValueTypeRange final
142     : public llvm::iterator_range<
143           ValueTypeIterator<typename ValueRangeT::iterator>> {
144 public:
145   using llvm::iterator_range<
146       ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
147   template <typename Container>
ValueTypeRange(Container && c)148   ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
149 
150   /// Return the type at the given index.
151   Type operator[](size_t index) const {
152     assert(index < size() && "invalid index into type range");
153     return *(this->begin() + index);
154   }
155 
156   /// Return the size of this range.
size()157   size_t size() const { return llvm::size(*this); }
158 
159   /// Return first type in the range.
front()160   Type front() { return (*this)[0]; }
161 
162   /// Compare this range with another.
163   template <typename OtherT>
164   bool operator==(const OtherT &other) const {
165     return llvm::size(*this) == llvm::size(other) &&
166            std::equal(this->begin(), this->end(), other.begin());
167   }
168   template <typename OtherT>
169   bool operator!=(const OtherT &other) const {
170     return !(*this == other);
171   }
172 };
173 
174 template <typename RangeT>
175 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
176   return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
177          std::equal(lhs.begin(), lhs.end(), rhs.begin());
178 }
179 
180 } // namespace mlir
181 
182 namespace llvm {
183 
184 // Provide DenseMapInfo for TypeRange.
185 template <>
186 struct DenseMapInfo<mlir::TypeRange> {
187   static mlir::TypeRange getEmptyKey() {
188     return mlir::TypeRange(getEmptyKeyPointer(), 0);
189   }
190 
191   static mlir::TypeRange getTombstoneKey() {
192     return mlir::TypeRange(getTombstoneKeyPointer(), 0);
193   }
194 
195   static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
196 
197   static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
198     if (isEmptyKey(rhs))
199       return isEmptyKey(lhs);
200     if (isTombstoneKey(rhs))
201       return isTombstoneKey(lhs);
202     return lhs == rhs;
203   }
204 
205 private:
206   static const mlir::Type *getEmptyKeyPointer() {
207     return DenseMapInfo<mlir::Type *>::getEmptyKey();
208   }
209 
210   static const mlir::Type *getTombstoneKeyPointer() {
211     return DenseMapInfo<mlir::Type *>::getTombstoneKey();
212   }
213 
214   static bool isEmptyKey(mlir::TypeRange range) {
215     if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
216       return type == getEmptyKeyPointer();
217     return false;
218   }
219 
220   static bool isTombstoneKey(mlir::TypeRange range) {
221     if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
222       return type == getTombstoneKeyPointer();
223     return false;
224   }
225 };
226 
227 } // namespace llvm
228 
229 #endif // MLIR_IR_TYPERANGE_H
230