1 //===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the TypeRange and ValueTypeRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_TYPERANGE_H
14 #define MLIR_IR_TYPERANGE_H
15 
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/PointerUnion.h"
19 
20 namespace mlir {
21 class OperandRange;
22 class ResultRange;
23 class Type;
24 class Value;
25 class ValueRange;
26 template <typename ValueRangeT>
27 class ValueTypeRange;
28 
29 //===----------------------------------------------------------------------===//
30 // TypeRange
31 
32 /// This class provides an abstraction over the various different ranges of
33 /// value types. In many cases, this prevents the need to explicitly materialize
34 /// a SmallVector/std::vector. This class should be used in places that are not
35 /// suitable for a more derived type (e.g. ArrayRef) or a template range
36 /// parameter.
37 class TypeRange : public llvm::detail::indexed_accessor_range_base<
38                       TypeRange,
39                       llvm::PointerUnion<const Value *, const Type *,
40                                          OpOperand *, detail::OpResultImpl *>,
41                       Type, Type, Type> {
42 public:
43   using RangeBaseT::RangeBaseT;
44   TypeRange(ArrayRef<Type> types = llvm::None);
45   explicit TypeRange(OperandRange values);
46   explicit TypeRange(ResultRange values);
47   explicit TypeRange(ValueRange values);
48   explicit TypeRange(ArrayRef<Value> values);
TypeRange(ArrayRef<BlockArgument> values)49   explicit TypeRange(ArrayRef<BlockArgument> values)
50       : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
51   template <typename ValueRangeT>
TypeRange(ValueTypeRange<ValueRangeT> values)52   TypeRange(ValueTypeRange<ValueRangeT> values)
53       : TypeRange(ValueRangeT(values.begin().getCurrent(),
54                               values.end().getCurrent())) {}
55   template <typename Arg,
56             typename = typename std::enable_if_t<
57                 std::is_constructible<ArrayRef<Type>, Arg>::value>>
TypeRange(Arg && arg)58   TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types)59   TypeRange(std::initializer_list<Type> types)
60       : TypeRange(ArrayRef<Type>(types)) {}
61 
62 private:
63   /// The owner of the range is either:
64   /// * A pointer to the first element of an array of values.
65   /// * A pointer to the first element of an array of types.
66   /// * A pointer to the first element of an array of operands.
67   /// * A pointer to the first element of an array of results.
68   using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
69                                     detail::OpResultImpl *>;
70 
71   /// See `llvm::detail::indexed_accessor_range_base` for details.
72   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
73   /// See `llvm::detail::indexed_accessor_range_base` for details.
74   static Type dereference_iterator(OwnerT object, ptrdiff_t index);
75 
76   /// Allow access to `offset_base` and `dereference_iterator`.
77   friend RangeBaseT;
78 };
79 
80 /// Make TypeRange hashable.
hash_value(TypeRange arg)81 inline ::llvm::hash_code hash_value(TypeRange arg) {
82   return ::llvm::hash_combine_range(arg.begin(), arg.end());
83 }
84 
85 /// Emit a type range to the given output stream.
86 inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
87   llvm::interleaveComma(types, os);
88   return os;
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // ValueTypeRange
93 
94 /// This class implements iteration on the types of a given range of values.
95 template <typename ValueIteratorT>
96 class ValueTypeIterator final
97     : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
unwrap(Value value)98   static Type unwrap(Value value) { return value.getType(); }
99 
100 public:
101   using reference = Type;
102 
103   /// Provide a const dereference method.
104   Type operator*() const { return unwrap(*this->I); }
105 
106   /// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)107   ValueTypeIterator(ValueIteratorT it)
108       : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
109 };
110 
111 /// This class implements iteration on the types of a given range of values.
112 template <typename ValueRangeT>
113 class ValueTypeRange final
114     : public llvm::iterator_range<
115           ValueTypeIterator<typename ValueRangeT::iterator>> {
116 public:
117   using llvm::iterator_range<
118       ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
119   template <typename Container>
ValueTypeRange(Container && c)120   ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
121 
122   /// Return the type at the given index.
123   Type operator[](size_t index) const {
124     assert(index < size() && "invalid index into type range");
125     return *(this->begin() + index);
126   }
127 
128   /// Return the size of this range.
size()129   size_t size() const { return llvm::size(*this); }
130 
131   /// Return first type in the range.
front()132   Type front() { return (*this)[0]; }
133 
134   /// Compare this range with another.
135   template <typename OtherT>
136   bool operator==(const OtherT &other) const {
137     return llvm::size(*this) == llvm::size(other) &&
138            std::equal(this->begin(), this->end(), other.begin());
139   }
140   template <typename OtherT>
141   bool operator!=(const OtherT &other) const {
142     return !(*this == other);
143   }
144 };
145 
146 template <typename RangeT>
147 inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
148   return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
149          std::equal(lhs.begin(), lhs.end(), rhs.begin());
150 }
151 
152 } // namespace mlir
153 
154 namespace llvm {
155 
156 // Provide DenseMapInfo for TypeRange.
157 template <>
158 struct DenseMapInfo<mlir::TypeRange> {
159   static mlir::TypeRange getEmptyKey() {
160     return mlir::TypeRange(getEmptyKeyPointer(), 0);
161   }
162 
163   static mlir::TypeRange getTombstoneKey() {
164     return mlir::TypeRange(getTombstoneKeyPointer(), 0);
165   }
166 
167   static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
168 
169   static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
170     if (isEmptyKey(rhs))
171       return isEmptyKey(lhs);
172     if (isTombstoneKey(rhs))
173       return isTombstoneKey(lhs);
174     return lhs == rhs;
175   }
176 
177 private:
178   static const mlir::Type *getEmptyKeyPointer() {
179     return DenseMapInfo<mlir::Type *>::getEmptyKey();
180   }
181 
182   static const mlir::Type *getTombstoneKeyPointer() {
183     return DenseMapInfo<mlir::Type *>::getTombstoneKey();
184   }
185 
186   static bool isEmptyKey(mlir::TypeRange range) {
187     if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
188       return type == getEmptyKeyPointer();
189     return false;
190   }
191 
192   static bool isTombstoneKey(mlir::TypeRange range) {
193     if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
194       return type == getTombstoneKeyPointer();
195     return false;
196   }
197 };
198 
199 } // namespace llvm
200 
201 #endif // MLIR_IR_TYPERANGE_H
202