1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the TypeRange and ValueTypeRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_TYPERANGE_H
14 #define MLIR_IR_TYPERANGE_H
15
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/PointerUnion.h"
19
20 namespace mlir {
// Forward declarations of the range and IR handle types named by the
// constructors declared in this header.
class OperandRange;
class ResultRange;
class Type;
class Value;
class ValueRange;

// Defined further down in this header.
template <typename ValueRangeT>
class ValueTypeRange;
28
29 //===----------------------------------------------------------------------===//
30 // TypeRange
31
32 /// This class provides an abstraction over the various different ranges of
33 /// value types. In many cases, this prevents the need to explicitly materialize
34 /// a SmallVector/std::vector. This class should be used in places that are not
35 /// suitable for a more derived type (e.g. ArrayRef) or a template range
36 /// parameter.
37 class TypeRange : public llvm::detail::indexed_accessor_range_base<
38 TypeRange,
39 llvm::PointerUnion<const Value *, const Type *,
40 OpOperand *, detail::OpResultImpl *>,
41 Type, Type, Type> {
42 public:
43 using RangeBaseT::RangeBaseT;
44 TypeRange(ArrayRef<Type> types = llvm::None);
45 explicit TypeRange(OperandRange values);
46 explicit TypeRange(ResultRange values);
47 explicit TypeRange(ValueRange values);
48 explicit TypeRange(ArrayRef<Value> values);
TypeRange(ArrayRef<BlockArgument> values)49 explicit TypeRange(ArrayRef<BlockArgument> values)
50 : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
51 template <typename ValueRangeT>
TypeRange(ValueTypeRange<ValueRangeT> values)52 TypeRange(ValueTypeRange<ValueRangeT> values)
53 : TypeRange(ValueRangeT(values.begin().getCurrent(),
54 values.end().getCurrent())) {}
55 template <typename Arg,
56 typename = typename std::enable_if_t<
57 std::is_constructible<ArrayRef<Type>, Arg>::value>>
TypeRange(Arg && arg)58 TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types)59 TypeRange(std::initializer_list<Type> types)
60 : TypeRange(ArrayRef<Type>(types)) {}
61
62 private:
63 /// The owner of the range is either:
64 /// * A pointer to the first element of an array of values.
65 /// * A pointer to the first element of an array of types.
66 /// * A pointer to the first element of an array of operands.
67 /// * A pointer to the first element of an array of results.
68 using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
69 detail::OpResultImpl *>;
70
71 /// See `llvm::detail::indexed_accessor_range_base` for details.
72 static OwnerT offset_base(OwnerT object, ptrdiff_t index);
73 /// See `llvm::detail::indexed_accessor_range_base` for details.
74 static Type dereference_iterator(OwnerT object, ptrdiff_t index);
75
76 /// Allow access to `offset_base` and `dereference_iterator`.
77 friend RangeBaseT;
78 };
79
80 /// Make TypeRange hashable.
hash_value(TypeRange arg)81 inline ::llvm::hash_code hash_value(TypeRange arg) {
82 return ::llvm::hash_combine_range(arg.begin(), arg.end());
83 }
84
85 /// Emit a type range to the given output stream.
86 inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
87 llvm::interleaveComma(types, os);
88 return os;
89 }
90
91 //===----------------------------------------------------------------------===//
92 // ValueTypeRange
93
94 /// This class implements iteration on the types of a given range of values.
95 template <typename ValueIteratorT>
96 class ValueTypeIterator final
97 : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
unwrap(Value value)98 static Type unwrap(Value value) { return value.getType(); }
99
100 public:
101 using reference = Type;
102
103 /// Provide a const dereference method.
104 Type operator*() const { return unwrap(*this->I); }
105
106 /// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)107 ValueTypeIterator(ValueIteratorT it)
108 : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
109 };
110
111 /// This class implements iteration on the types of a given range of values.
112 template <typename ValueRangeT>
113 class ValueTypeRange final
114 : public llvm::iterator_range<
115 ValueTypeIterator<typename ValueRangeT::iterator>> {
116 public:
117 using llvm::iterator_range<
118 ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
119 template <typename Container>
ValueTypeRange(Container && c)120 ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
121
122 /// Return the type at the given index.
123 Type operator[](size_t index) const {
124 assert(index < size() && "invalid index into type range");
125 return *(this->begin() + index);
126 }
127
128 /// Return the size of this range.
size()129 size_t size() const { return llvm::size(*this); }
130
131 /// Return first type in the range.
front()132 Type front() { return (*this)[0]; }
133
134 /// Compare this range with another.
135 template <typename OtherT>
136 bool operator==(const OtherT &other) const {
137 return llvm::size(*this) == llvm::size(other) &&
138 std::equal(this->begin(), this->end(), other.begin());
139 }
140 template <typename OtherT>
141 bool operator!=(const OtherT &other) const {
142 return !(*this == other);
143 }
144 };
145
146 template <typename RangeT>
147 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
148 return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
149 std::equal(lhs.begin(), lhs.end(), rhs.begin());
150 }
151
152 } // namespace mlir
153
154 namespace llvm {
155
156 // Provide DenseMapInfo for TypeRange.
157 template <>
158 struct DenseMapInfo<mlir::TypeRange> {
159 static mlir::TypeRange getEmptyKey() {
160 return mlir::TypeRange(getEmptyKeyPointer(), 0);
161 }
162
163 static mlir::TypeRange getTombstoneKey() {
164 return mlir::TypeRange(getTombstoneKeyPointer(), 0);
165 }
166
167 static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
168
169 static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
170 if (isEmptyKey(rhs))
171 return isEmptyKey(lhs);
172 if (isTombstoneKey(rhs))
173 return isTombstoneKey(lhs);
174 return lhs == rhs;
175 }
176
177 private:
178 static const mlir::Type *getEmptyKeyPointer() {
179 return DenseMapInfo<mlir::Type *>::getEmptyKey();
180 }
181
182 static const mlir::Type *getTombstoneKeyPointer() {
183 return DenseMapInfo<mlir::Type *>::getTombstoneKey();
184 }
185
186 static bool isEmptyKey(mlir::TypeRange range) {
187 if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
188 return type == getEmptyKeyPointer();
189 return false;
190 }
191
192 static bool isTombstoneKey(mlir::TypeRange range) {
193 if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
194 return type == getTombstoneKeyPointer();
195 return false;
196 }
197 };
198
199 } // namespace llvm
200
201 #endif // MLIR_IR_TYPERANGE_H
202