1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the TypeRange and ValueTypeRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_TYPERANGE_H
14 #define MLIR_IR_TYPERANGE_H
15
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/PointerUnion.h"
19 #include "llvm/ADT/Sequence.h"
20
21 namespace mlir {
22 class OperandRange;
23 class ResultRange;
24 class Type;
25 class Value;
26 class ValueRange;
27 template <typename ValueRangeT>
28 class ValueTypeRange;
29
30 //===----------------------------------------------------------------------===//
31 // TypeRange
32
33 /// This class provides an abstraction over the various different ranges of
34 /// value types. In many cases, this prevents the need to explicitly materialize
35 /// a SmallVector/std::vector. This class should be used in places that are not
36 /// suitable for a more derived type (e.g. ArrayRef) or a template range
37 /// parameter.
38 class TypeRange : public llvm::detail::indexed_accessor_range_base<
39 TypeRange,
40 llvm::PointerUnion<const Value *, const Type *,
41 OpOperand *, detail::OpResultImpl *>,
42 Type, Type, Type> {
43 public:
44 using RangeBaseT::RangeBaseT;
45 TypeRange(ArrayRef<Type> types = llvm::None);
46 explicit TypeRange(OperandRange values);
47 explicit TypeRange(ResultRange values);
48 explicit TypeRange(ValueRange values);
49 explicit TypeRange(ArrayRef<Value> values);
TypeRange(ArrayRef<BlockArgument> values)50 explicit TypeRange(ArrayRef<BlockArgument> values)
51 : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
52 template <typename ValueRangeT>
TypeRange(ValueTypeRange<ValueRangeT> values)53 TypeRange(ValueTypeRange<ValueRangeT> values)
54 : TypeRange(ValueRangeT(values.begin().getCurrent(),
55 values.end().getCurrent())) {}
56 template <typename Arg,
57 typename = typename std::enable_if_t<
58 std::is_constructible<ArrayRef<Type>, Arg>::value>>
TypeRange(Arg && arg)59 TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types)60 TypeRange(std::initializer_list<Type> types)
61 : TypeRange(ArrayRef<Type>(types)) {}
62
63 private:
64 /// The owner of the range is either:
65 /// * A pointer to the first element of an array of values.
66 /// * A pointer to the first element of an array of types.
67 /// * A pointer to the first element of an array of operands.
68 /// * A pointer to the first element of an array of results.
69 using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
70 detail::OpResultImpl *>;
71
72 /// See `llvm::detail::indexed_accessor_range_base` for details.
73 static OwnerT offset_base(OwnerT object, ptrdiff_t index);
74 /// See `llvm::detail::indexed_accessor_range_base` for details.
75 static Type dereference_iterator(OwnerT object, ptrdiff_t index);
76
77 /// Allow access to `offset_base` and `dereference_iterator`.
78 friend RangeBaseT;
79 };
80
81 /// Make TypeRange hashable.
hash_value(TypeRange arg)82 inline ::llvm::hash_code hash_value(TypeRange arg) {
83 return ::llvm::hash_combine_range(arg.begin(), arg.end());
84 }
85
86 /// Emit a type range to the given output stream.
87 inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
88 llvm::interleaveComma(types, os);
89 return os;
90 }
91
92 //===----------------------------------------------------------------------===//
93 // TypeRangeRange
94
95 using TypeRangeRangeIterator =
96 llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator,
97 std::function<TypeRange(unsigned)>>;
98
99 /// This class provides an abstraction for a range of TypeRange. This is useful
100 /// when accessing the types of a range of ranges, such as when using
101 /// OperandRangeRange.
102 class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
103 public:
104 template <typename RangeT>
TypeRangeRange(const RangeT & range)105 TypeRangeRange(const RangeT &range)
106 : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {}
107
108 private:
109 template <typename RangeT>
TypeRangeRange(llvm::iota_range<unsigned> sizeRange,const RangeT & range)110 TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range)
111 : llvm::iterator_range<TypeRangeRangeIterator>(
112 {sizeRange.begin(), getRangeFn(range)},
113 {sizeRange.end(), nullptr}) {}
114
115 template <typename RangeT>
getRangeFn(const RangeT & range)116 static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) {
117 return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); };
118 }
119 };
120
121 //===----------------------------------------------------------------------===//
122 // ValueTypeRange
123
124 /// This class implements iteration on the types of a given range of values.
125 template <typename ValueIteratorT>
126 class ValueTypeIterator final
127 : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
unwrap(Value value)128 static Type unwrap(Value value) { return value.getType(); }
129
130 public:
131 /// Provide a const dereference method.
132 Type operator*() const { return unwrap(*this->I); }
133
134 /// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)135 ValueTypeIterator(ValueIteratorT it)
136 : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
137 };
138
139 /// This class implements iteration on the types of a given range of values.
140 template <typename ValueRangeT>
141 class ValueTypeRange final
142 : public llvm::iterator_range<
143 ValueTypeIterator<typename ValueRangeT::iterator>> {
144 public:
145 using llvm::iterator_range<
146 ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
147 template <typename Container>
ValueTypeRange(Container && c)148 ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
149
150 /// Return the type at the given index.
151 Type operator[](size_t index) const {
152 assert(index < size() && "invalid index into type range");
153 return *(this->begin() + index);
154 }
155
156 /// Return the size of this range.
size()157 size_t size() const { return llvm::size(*this); }
158
159 /// Return first type in the range.
front()160 Type front() { return (*this)[0]; }
161
162 /// Compare this range with another.
163 template <typename OtherT>
164 bool operator==(const OtherT &other) const {
165 return llvm::size(*this) == llvm::size(other) &&
166 std::equal(this->begin(), this->end(), other.begin());
167 }
168 template <typename OtherT>
169 bool operator!=(const OtherT &other) const {
170 return !(*this == other);
171 }
172 };
173
174 template <typename RangeT>
175 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
176 return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
177 std::equal(lhs.begin(), lhs.end(), rhs.begin());
178 }
179
180 } // namespace mlir
181
182 namespace llvm {
183
184 // Provide DenseMapInfo for TypeRange.
185 template <>
186 struct DenseMapInfo<mlir::TypeRange> {
187 static mlir::TypeRange getEmptyKey() {
188 return mlir::TypeRange(getEmptyKeyPointer(), 0);
189 }
190
191 static mlir::TypeRange getTombstoneKey() {
192 return mlir::TypeRange(getTombstoneKeyPointer(), 0);
193 }
194
195 static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
196
197 static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
198 if (isEmptyKey(rhs))
199 return isEmptyKey(lhs);
200 if (isTombstoneKey(rhs))
201 return isTombstoneKey(lhs);
202 return lhs == rhs;
203 }
204
205 private:
206 static const mlir::Type *getEmptyKeyPointer() {
207 return DenseMapInfo<mlir::Type *>::getEmptyKey();
208 }
209
210 static const mlir::Type *getTombstoneKeyPointer() {
211 return DenseMapInfo<mlir::Type *>::getTombstoneKey();
212 }
213
214 static bool isEmptyKey(mlir::TypeRange range) {
215 if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
216 return type == getEmptyKeyPointer();
217 return false;
218 }
219
220 static bool isTombstoneKey(mlir::TypeRange range) {
221 if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
222 return type == getTombstoneKeyPointer();
223 return false;
224 }
225 };
226
227 } // namespace llvm
228
229 #endif // MLIR_IR_TYPERANGE_H
230