1 //===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_ATTRIBUTES_H
10 #define MLIR_IR_ATTRIBUTES_H
11 
12 #include "mlir/IR/AttributeSupport.h"
13 #include "llvm/ADT/APFloat.h"
14 #include "llvm/ADT/Sequence.h"
15 
16 namespace mlir {
17 class AffineMap;
18 class Dialect;
19 class FunctionType;
20 class Identifier;
21 class IntegerSet;
22 class Location;
23 class MLIRContext;
24 class ShapedType;
25 class Type;
26 
27 namespace detail {
28 
29 struct AffineMapAttributeStorage;
30 struct ArrayAttributeStorage;
31 struct BoolAttributeStorage;
32 struct DictionaryAttributeStorage;
33 struct IntegerAttributeStorage;
34 struct IntegerSetAttributeStorage;
35 struct FloatAttributeStorage;
36 struct OpaqueAttributeStorage;
37 struct StringAttributeStorage;
38 struct SymbolRefAttributeStorage;
39 struct TypeAttributeStorage;
40 
41 /// Elements Attributes.
42 struct DenseElementsAttributeStorage;
43 struct OpaqueElementsAttributeStorage;
44 struct SparseElementsAttributeStorage;
45 } // namespace detail
46 
47 /// Attributes are known-constant values of operations and functions.
48 ///
49 /// Instances of the Attribute class are references to immutable, uniqued,
50 /// and immortal values owned by MLIRContext. As such, an Attribute is a thin
51 /// wrapper around an underlying storage pointer. Attributes are usually passed
52 /// by value.
53 class Attribute {
54 public:
55   /// Integer identifier for all the concrete attribute kinds.
56   enum Kind {
57   // Reserve attribute kinds for dialect specific extensions.
58 #define DEFINE_SYM_KIND_RANGE(Dialect)                                         \
59   FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
60 #include "DialectSymbolRegistry.def"
61   };
62 
63   /// Utility class for implementing attributes.
64   template <typename ConcreteType, typename BaseType = Attribute,
65             typename StorageType = AttributeStorage>
66   using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
67                                            detail::AttributeUniquer>;
68 
69   using ImplType = AttributeStorage;
70   using ValueType = void;
71 
Attribute()72   constexpr Attribute() : impl(nullptr) {}
Attribute(const ImplType * impl)73   /* implicit */ Attribute(const ImplType *impl)
74       : impl(const_cast<ImplType *>(impl)) {}
75 
76   Attribute(const Attribute &other) = default;
77   Attribute &operator=(const Attribute &other) = default;
78 
79   bool operator==(Attribute other) const { return impl == other.impl; }
80   bool operator!=(Attribute other) const { return !(*this == other); }
81   explicit operator bool() const { return impl; }
82 
83   bool operator!() const { return impl == nullptr; }
84 
85   template <typename U> bool isa() const;
86   template <typename U> U dyn_cast() const;
87   template <typename U> U dyn_cast_or_null() const;
88   template <typename U> U cast() const;
89 
90   // Support dyn_cast'ing Attribute to itself.
classof(Attribute)91   static bool classof(Attribute) { return true; }
92 
93   /// Return the classification for this attribute.
getKind()94   unsigned getKind() const { return impl->getKind(); }
95 
96   /// Return the type of this attribute.
97   Type getType() const;
98 
99   /// Return the context this attribute belongs to.
100   MLIRContext *getContext() const;
101 
102   /// Get the dialect this attribute is registered to.
103   Dialect &getDialect() const;
104 
105   /// Print the attribute.
106   void print(raw_ostream &os) const;
107   void dump() const;
108 
109   /// Get an opaque pointer to the attribute.
getAsOpaquePointer()110   const void *getAsOpaquePointer() const { return impl; }
111   /// Construct an attribute from the opaque pointer representation.
getFromOpaquePointer(const void * ptr)112   static Attribute getFromOpaquePointer(const void *ptr) {
113     return Attribute(reinterpret_cast<const ImplType *>(ptr));
114   }
115 
116   friend ::llvm::hash_code hash_value(Attribute arg);
117 
118 protected:
119   ImplType *impl;
120 };
121 
122 inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
123   attr.print(os);
124   return os;
125 }
126 
127 namespace StandardAttributes {
128 enum Kind {
129   AffineMap = Attribute::FIRST_STANDARD_ATTR,
130   Array,
131   Bool,
132   Dictionary,
133   Float,
134   Integer,
135   IntegerSet,
136   Opaque,
137   String,
138   SymbolRef,
139   Type,
140   Unit,
141 
142   /// Elements Attributes.
143   DenseElements,
144   OpaqueElements,
145   SparseElements,
146   FIRST_ELEMENTS_ATTR = DenseElements,
147   LAST_ELEMENTS_ATTR = SparseElements,
148 
149   /// Locations.
150   CallSiteLocation,
151   FileLineColLocation,
152   FusedLocation,
153   NameLocation,
154   OpaqueLocation,
155   UnknownLocation,
156 
157   // Represents a location as a 'void*' pointer to a front-end's opaque
158   // location information, which must live longer than the MLIR objects that
159   // refer to it.  OpaqueLocation's are never serialized.
160   //
161   // TODO: OpaqueLocation,
162 
163   // Represents a value inlined through a function call.
164   // TODO: InlinedLocation,
165 
166   FIRST_LOCATION_ATTR = CallSiteLocation,
167   LAST_LOCATION_ATTR = UnknownLocation,
168 };
169 } // namespace StandardAttributes
170 
171 //===----------------------------------------------------------------------===//
172 // AffineMapAttr
173 //===----------------------------------------------------------------------===//
174 
175 class AffineMapAttr
176     : public Attribute::AttrBase<AffineMapAttr, Attribute,
177                                  detail::AffineMapAttributeStorage> {
178 public:
179   using Base::Base;
180   using ValueType = AffineMap;
181 
182   static AffineMapAttr get(AffineMap value);
183 
184   AffineMap getValue() const;
185 
186   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)187   static bool kindof(unsigned kind) {
188     return kind == StandardAttributes::AffineMap;
189   }
190 };
191 
192 //===----------------------------------------------------------------------===//
193 // ArrayAttr
194 //===----------------------------------------------------------------------===//
195 
196 /// Array attributes are lists of other attributes.  They are not necessarily
197 /// type homogenous given that attributes don't, in general, carry types.
198 class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
199                                              detail::ArrayAttributeStorage> {
200 public:
201   using Base::Base;
202   using ValueType = ArrayRef<Attribute>;
203 
204   static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
205 
206   ArrayRef<Attribute> getValue() const;
207 
208   /// Support range iteration.
209   using iterator = llvm::ArrayRef<Attribute>::iterator;
begin()210   iterator begin() const { return getValue().begin(); }
end()211   iterator end() const { return getValue().end(); }
size()212   size_t size() const { return getValue().size(); }
213 
214   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)215   static bool kindof(unsigned kind) {
216     return kind == StandardAttributes::Array;
217   }
218 
219 private:
220   /// Class for underlying value iterator support.
221   template <typename AttrTy>
222   class attr_value_iterator final
223       : public llvm::mapped_iterator<ArrayAttr::iterator,
224                                      AttrTy (*)(Attribute)> {
225   public:
attr_value_iterator(ArrayAttr::iterator it)226     explicit attr_value_iterator(ArrayAttr::iterator it)
227         : llvm::mapped_iterator<ArrayAttr::iterator, AttrTy (*)(Attribute)>(
228               it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}
229     AttrTy operator*() { return (*this->I).template cast<AttrTy>(); }
230   };
231 
232 public:
233   template <typename AttrTy>
getAsRange()234   llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
235     return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
236                             attr_value_iterator<AttrTy>(end()));
237   }
238 };
239 
240 //===----------------------------------------------------------------------===//
241 // BoolAttr
242 //===----------------------------------------------------------------------===//
243 
244 class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
245                                             detail::BoolAttributeStorage> {
246 public:
247   using Base::Base;
248   using ValueType = bool;
249 
250   static BoolAttr get(bool value, MLIRContext *context);
251 
252   bool getValue() const;
253 
254   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)255   static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; }
256 };
257 
258 //===----------------------------------------------------------------------===//
259 // DictionaryAttr
260 //===----------------------------------------------------------------------===//
261 
/// NamedAttribute is used for dictionary attributes; it holds an identifier for
/// the name and a value for the attribute. The attribute pointer should always
/// be non-null.
265 using NamedAttribute = std::pair<Identifier, Attribute>;
266 
267 /// Dictionary attribute is an attribute that represents a sorted collection of
268 /// named attribute values. The elements are sorted by name, and each name must
269 /// be unique within the collection.
270 class DictionaryAttr
271     : public Attribute::AttrBase<DictionaryAttr, Attribute,
272                                  detail::DictionaryAttributeStorage> {
273 public:
274   using Base::Base;
275   using ValueType = ArrayRef<NamedAttribute>;
276 
277   static DictionaryAttr get(ArrayRef<NamedAttribute> value,
278                             MLIRContext *context);
279 
280   ArrayRef<NamedAttribute> getValue() const;
281 
282   /// Return the specified attribute if present, null otherwise.
283   Attribute get(StringRef name) const;
284   Attribute get(Identifier name) const;
285 
286   /// Support range iteration.
287   using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
288   iterator begin() const;
289   iterator end() const;
empty()290   bool empty() const { return size() == 0; }
291   size_t size() const;
292 
293   /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)294   static bool kindof(unsigned kind) {
295     return kind == StandardAttributes::Dictionary;
296   }
297 };
298 
299 //===----------------------------------------------------------------------===//
300 // FloatAttr
301 //===----------------------------------------------------------------------===//
302 
303 class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
304                                              detail::FloatAttributeStorage> {
305 public:
306   using Base::Base;
307   using ValueType = APFloat;
308 
309   /// Return a float attribute for the specified value in the specified type.
310   /// These methods should only be used for simple constant values, e.g 1.0/2.0,
311   /// that are known-valid both as host double and the 'type' format.
312   static FloatAttr get(Type type, double value);
313   static FloatAttr getChecked(Type type, double value, Location loc);
314 
315   /// Return a float attribute for the specified value in the specified type.
316   static FloatAttr get(Type type, const APFloat &value);
317   static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
318 
319   APFloat getValue() const;
320 
321   /// This function is used to convert the value to a double, even if it loses
322   /// precision.
323   double getValueAsDouble() const;
324   static double getValueAsDouble(APFloat val);
325 
326   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)327   static bool kindof(unsigned kind) {
328     return kind == StandardAttributes::Float;
329   }
330 
331   /// Verify the construction invariants for a double value.
332   static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
333                                                     MLIRContext *ctx, Type type,
334                                                     double value);
335   static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
336                                                     MLIRContext *ctx, Type type,
337                                                     const APFloat &value);
338 };
339 
340 //===----------------------------------------------------------------------===//
341 // IntegerAttr
342 //===----------------------------------------------------------------------===//
343 
344 class IntegerAttr
345     : public Attribute::AttrBase<IntegerAttr, Attribute,
346                                  detail::IntegerAttributeStorage> {
347 public:
348   using Base::Base;
349   using ValueType = APInt;
350 
351   static IntegerAttr get(Type type, int64_t value);
352   static IntegerAttr get(Type type, const APInt &value);
353 
354   APInt getValue() const;
355   // TODO(jpienaar): Change callers to use getValue instead.
356   int64_t getInt() const;
357 
358   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)359   static bool kindof(unsigned kind) {
360     return kind == StandardAttributes::Integer;
361   }
362 };
363 
364 //===----------------------------------------------------------------------===//
365 // IntegerSetAttr
366 //===----------------------------------------------------------------------===//
367 
368 class IntegerSetAttr
369     : public Attribute::AttrBase<IntegerSetAttr, Attribute,
370                                  detail::IntegerSetAttributeStorage> {
371 public:
372   using Base::Base;
373   using ValueType = IntegerSet;
374 
375   static IntegerSetAttr get(IntegerSet value);
376 
377   IntegerSet getValue() const;
378 
379   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)380   static bool kindof(unsigned kind) {
381     return kind == StandardAttributes::IntegerSet;
382   }
383 };
384 
385 //===----------------------------------------------------------------------===//
386 // OpaqueAttr
387 //===----------------------------------------------------------------------===//
388 
389 /// Opaque attributes represent attributes of non-registered dialects. These are
390 /// attribute represented in their raw string form, and can only usefully be
391 /// tested for attribute equality.
392 class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
393                                               detail::OpaqueAttributeStorage> {
394 public:
395   using Base::Base;
396 
397   /// Get or create a new OpaqueAttr with the provided dialect and string data.
398   static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
399                         MLIRContext *context);
400 
401   /// Get or create a new OpaqueAttr with the provided dialect and string data.
402   /// If the given identifier is not a valid namespace for a dialect, then a
403   /// null attribute is returned.
404   static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
405                                Type type, Location location);
406 
407   /// Returns the dialect namespace of the opaque attribute.
408   Identifier getDialectNamespace() const;
409 
410   /// Returns the raw attribute data of the opaque attribute.
411   StringRef getAttrData() const;
412 
413   /// Verify the construction of an opaque attribute.
414   static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
415                                                     MLIRContext *context,
416                                                     Identifier dialect,
417                                                     StringRef attrData,
418                                                     Type type);
419 
kindof(unsigned kind)420   static bool kindof(unsigned kind) {
421     return kind == StandardAttributes::Opaque;
422   }
423 };
424 
425 //===----------------------------------------------------------------------===//
426 // StringAttr
427 //===----------------------------------------------------------------------===//
428 
429 class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
430                                               detail::StringAttributeStorage> {
431 public:
432   using Base::Base;
433   using ValueType = StringRef;
434 
435   /// Get an instance of a StringAttr with the given string.
436   static StringAttr get(StringRef bytes, MLIRContext *context);
437 
438   /// Get an instance of a StringAttr with the given string and Type.
439   static StringAttr get(StringRef bytes, Type type);
440 
441   StringRef getValue() const;
442 
443   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)444   static bool kindof(unsigned kind) {
445     return kind == StandardAttributes::String;
446   }
447 };
448 
449 //===----------------------------------------------------------------------===//
450 // SymbolRefAttr
451 //===----------------------------------------------------------------------===//
452 
453 class FlatSymbolRefAttr;
454 
455 /// A symbol reference attribute represents a symbolic reference to another
456 /// operation.
457 class SymbolRefAttr
458     : public Attribute::AttrBase<SymbolRefAttr, Attribute,
459                                  detail::SymbolRefAttributeStorage> {
460 public:
461   using Base::Base;
462 
463   /// Construct a symbol reference for the given value name.
464   static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
465 
466   /// Construct a symbol reference for the given value name, and a set of nested
467   /// references that are further resolve to a nested symbol.
468   static SymbolRefAttr get(StringRef value,
469                            ArrayRef<FlatSymbolRefAttr> references,
470                            MLIRContext *ctx);
471 
472   /// Returns the name of the top level symbol reference, i.e. the root of the
473   /// reference path.
474   StringRef getRootReference() const;
475 
476   /// Returns the name of the fully resolved symbol, i.e. the leaf of the
477   /// reference path.
478   StringRef getLeafReference() const;
479 
480   /// Returns the set of nested references representing the path to the symbol
481   /// nested under the root reference.
482   ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
483 
484   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)485   static bool kindof(unsigned kind) {
486     return kind == StandardAttributes::SymbolRef;
487   }
488 };
489 
490 /// A symbol reference with a reference path containing a single element. This
491 /// is used to refer to an operation within the current symbol table.
492 class FlatSymbolRefAttr : public SymbolRefAttr {
493 public:
494   using SymbolRefAttr::SymbolRefAttr;
495   using ValueType = StringRef;
496 
497   /// Construct a symbol reference for the given value name.
get(StringRef value,MLIRContext * ctx)498   static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
499     return SymbolRefAttr::get(value, ctx);
500   }
501 
502   /// Returns the name of the held symbol reference.
getValue()503   StringRef getValue() const { return getRootReference(); }
504 
505   /// Methods for support type inquiry through isa, cast, and dyn_cast.
classof(Attribute attr)506   static bool classof(Attribute attr) {
507     SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
508     return refAttr && refAttr.getNestedReferences().empty();
509   }
510 
511 private:
512   using SymbolRefAttr::get;
513   using SymbolRefAttr::getNestedReferences;
514 };
515 
516 //===----------------------------------------------------------------------===//
// TypeAttr
518 //===----------------------------------------------------------------------===//
519 
520 class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
521                                             detail::TypeAttributeStorage> {
522 public:
523   using Base::Base;
524   using ValueType = Type;
525 
526   static TypeAttr get(Type value);
527 
528   Type getValue() const;
529 
530   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)531   static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
532 };
533 
534 //===----------------------------------------------------------------------===//
535 // UnitAttr
536 //===----------------------------------------------------------------------===//
537 
538 /// Unit attributes are attributes that hold no specific value and are given
539 /// meaning by their existence.
540 class UnitAttr : public Attribute::AttrBase<UnitAttr> {
541 public:
542   using Base::Base;
543 
544   static UnitAttr get(MLIRContext *context);
545 
kindof(unsigned kind)546   static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; }
547 };
548 
549 //===----------------------------------------------------------------------===//
550 // Elements Attributes
551 //===----------------------------------------------------------------------===//
552 
553 namespace detail {
554 template <typename T> class ElementsAttrIterator;
555 template <typename T> class ElementsAttrRange;
556 } // namespace detail
557 
558 /// A base attribute that represents a reference to a static shaped tensor or
559 /// vector constant.
560 class ElementsAttr : public Attribute {
561 public:
562   using Attribute::Attribute;
563   template <typename T> using iterator = detail::ElementsAttrIterator<T>;
564   template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
565 
566   /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
567   /// with static shape.
568   ShapedType getType() const;
569 
570   /// Return the value at the given index. The index is expected to refer to a
571   /// valid element.
572   Attribute getValue(ArrayRef<uint64_t> index) const;
573 
574   /// Return the value of type 'T' at the given index, where 'T' corresponds to
575   /// an Attribute type.
getValue(ArrayRef<uint64_t> index)576   template <typename T> T getValue(ArrayRef<uint64_t> index) const {
577     return getValue(index).template cast<T>();
578   }
579 
580   /// Return the elements of this attribute as a value of type 'T'. Note:
581   /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
582   /// iteration.
583   template <typename T> iterator_range<T> getValues() const;
584 
585   /// Return if the given 'index' refers to a valid element in this attribute.
586   bool isValidIndex(ArrayRef<uint64_t> index) const;
587 
588   /// Returns the number of elements held by this attribute.
589   int64_t getNumElements() const;
590 
591   /// Generates a new ElementsAttr by mapping each int value to a new
592   /// underlying APInt. The new values can represent either a integer or float.
593   /// This ElementsAttr should contain integers.
594   ElementsAttr mapValues(Type newElementType,
595                          function_ref<APInt(const APInt &)> mapping) const;
596 
597   /// Generates a new ElementsAttr by mapping each float value to a new
598   /// underlying APInt. The new values can represent either a integer or float.
599   /// This ElementsAttr should contain floats.
600   ElementsAttr mapValues(Type newElementType,
601                          function_ref<APInt(const APFloat &)> mapping) const;
602 
603   /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)604   static bool classof(Attribute attr) {
605     return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
606            attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
607   }
608 
609 protected:
610   /// Returns the 1 dimensional flattened row-major index from the given
611   /// multi-dimensional index.
612   uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
613 };
614 
615 namespace detail {
/// Pointer traits for the raw data of a DenseElementsAttr. The data is aligned
/// to uint64_t, so this traits class is necessary to interop with
/// PointerIntPair.
class DenseElementDataPointerTypeTraits {
public:
  /// Convert the data pointer to its type-erased form.
  static const void *getAsVoidPointer(const char *ptr) { return ptr; }

  /// Recover the data pointer from its type-erased form.
  static const char *getFromVoidPointer(const void *ptr) {
    const char *data = static_cast<const char *>(ptr);
    return data;
  }

  // Number of low pointer bits advertised as free for packing.
  // Note: We could steal more bits if the need arises.
  enum { NumLowBitsAvailable = 1 };
};
628 
/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat.
630 using DenseIterPtrAndSplat =
631     llvm::PointerIntPair<const char *, 1, bool,
632                          DenseElementDataPointerTypeTraits>;
633 
634 /// Impl iterator for indexed DenseElementAttr iterators that records a data
635 /// pointer and data index that is adjusted for the case of a splat attribute.
636 template <typename ConcreteT, typename T, typename PointerT = T *,
637           typename ReferenceT = T &>
638 class DenseElementIndexedIteratorImpl
639     : public indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
640                                        PointerT, ReferenceT> {
641 protected:
DenseElementIndexedIteratorImpl(const char * data,bool isSplat,size_t dataIndex)642   DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
643                                   size_t dataIndex)
644       : indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, PointerT,
645                                   ReferenceT>({data, isSplat}, dataIndex) {}
646 
647   /// Return the current index for this iterator, adjusted for the case of a
648   /// splat.
getDataIndex()649   ptrdiff_t getDataIndex() const {
650     bool isSplat = this->base.getInt();
651     return isSplat ? 0 : this->index;
652   }
653 
654   /// Return the data base pointer.
getData()655   const char *getData() const { return this->base.getPointer(); }
656 };
657 } // namespace detail
658 
659 /// An attribute that represents a reference to a dense vector or tensor object.
660 ///
661 class DenseElementsAttr
662     : public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
663                                  detail::DenseElementsAttributeStorage> {
664 public:
665   using Base::Base;
666 
667   /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)668   static bool classof(Attribute attr) {
669     return attr.getKind() == StandardAttributes::DenseElements;
670   }
671 
672   /// Constructs a dense elements attribute from an array of element values.
673   /// Each element attribute value is expected to be an element of 'type'.
674   /// 'type' must be a vector or tensor with static shape.
675   static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
676 
677   /// Constructs a dense integer elements attribute from an array of integer
678   /// or floating-point values. Each value is expected to be the same bitwidth
679   /// of the element type of 'type'. 'type' must be a vector or tensor with
680   /// static shape.
681   template <typename T, typename = typename std::enable_if<
682                             std::numeric_limits<T>::is_integer ||
683                             llvm::is_one_of<T, float, double>::value>::type>
get(const ShapedType & type,ArrayRef<T> values)684   static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
685     const char *data = reinterpret_cast<const char *>(values.data());
686     return getRawIntOrFloat(
687         type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
688         /*isInt=*/std::numeric_limits<T>::is_integer);
689   }
690 
691   /// Constructs a dense integer elements attribute from a single element.
692   template <typename T, typename = typename std::enable_if<
693                             std::numeric_limits<T>::is_integer ||
694                             llvm::is_one_of<T, float, double>::value>::type>
get(const ShapedType & type,T value)695   static DenseElementsAttr get(const ShapedType &type, T value) {
696     return get(type, llvm::makeArrayRef(value));
697   }
698 
699   /// Overload of the above 'get' method that is specialized for boolean values.
700   static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
701 
702   /// Constructs a dense integer elements attribute from an array of APInt
703   /// values. Each APInt value is expected to have the same bitwidth as the
704   /// element type of 'type'. 'type' must be a vector or tensor with static
705   /// shape.
706   static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
707 
708   /// Constructs a dense float elements attribute from an array of APFloat
709   /// values. Each APFloat value is expected to have the same bitwidth as the
710   /// element type of 'type'. 'type' must be a vector or tensor with static
711   /// shape.
712   static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
713 
714   /// Construct a dense elements attribute for an initializer_list of values.
715   /// Each value is expected to be the same bitwidth of the element type of
716   /// 'type'. 'type' must be a vector or tensor with static shape.
717   template <typename T>
get(const ShapedType & type,const std::initializer_list<T> & list)718   static DenseElementsAttr get(const ShapedType &type,
719                                const std::initializer_list<T> &list) {
720     return get(type, ArrayRef<T>(list));
721   }
722 
723   //===--------------------------------------------------------------------===//
724   // Iterators
725   //===--------------------------------------------------------------------===//
726 
727   /// A utility iterator that allows walking over the internal Attribute values
728   /// of a DenseElementsAttr.
729   class AttributeElementIterator
730       : public indexed_accessor_iterator<AttributeElementIterator, const void *,
731                                          Attribute, Attribute, Attribute> {
732   public:
733     /// Accesses the Attribute value at this iterator position.
734     Attribute operator*() const;
735 
736   private:
737     friend DenseElementsAttr;
738 
739     /// Constructs a new iterator.
740     AttributeElementIterator(DenseElementsAttr attr, size_t index);
741   };
742 
743   /// Iterator for walking raw element values of the specified type 'T', which
744   /// may be any c++ data type matching the stored representation: int32_t,
745   /// float, etc.
746   template <typename T>
747   class ElementIterator
748       : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
749                                                        const T> {
750   public:
751     /// Accesses the raw value at this iterator position.
752     const T &operator*() const {
753       return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
754     }
755 
756   private:
757     friend DenseElementsAttr;
758 
759     /// Constructs a new iterator.
ElementIterator(const char * data,bool isSplat,size_t dataIndex)760     ElementIterator(const char *data, bool isSplat, size_t dataIndex)
761         : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
762               data, isSplat, dataIndex) {}
763   };
764 
  /// A utility iterator that allows walking over the internal bool values.
  /// Dereferencing produces a bool by value.
  class BoolElementIterator
      : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator,
                                                       bool, bool, bool> {
  public:
    /// Accesses the bool value at this iterator position.
    bool operator*() const;

  private:
    // The constructor is private; DenseElementsAttr is befriended so that it
    // can create iterators.
    friend DenseElementsAttr;

    /// Constructs a new iterator from the attribute being walked and a data
    /// index.
    BoolElementIterator(DenseElementsAttr attr, size_t dataIndex);
  };
779 
  /// A utility iterator that allows walking over the internal raw APInt values.
  /// Dereferencing produces an APInt by value.
  class IntElementIterator
      : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
                                                       APInt, APInt, APInt> {
  public:
    /// Accesses the raw APInt value at this iterator position.
    APInt operator*() const;

  private:
    // The constructor is private; DenseElementsAttr is befriended so that it
    // can create iterators.
    friend DenseElementsAttr;

    /// Constructs a new iterator from the attribute being walked and a data
    /// index.
    IntElementIterator(DenseElementsAttr attr, size_t dataIndex);

    /// The bitwidth of the element type.
    size_t bitWidth;
  };
797 
  /// Iterator for walking over APFloat values. This is a mapped iterator over
  /// IntElementIterator whose mapping function has the signature
  /// `APFloat(const APInt &)`.
  class FloatElementIterator final
      : public llvm::mapped_iterator<IntElementIterator,
                                     std::function<APFloat(const APInt &)>> {
    // The constructor is private; DenseElementsAttr is befriended so that it
    // can create iterators.
    friend DenseElementsAttr;

    /// Initializes the float element iterator to the specified iterator.
    /// NOTE(review): 'smt' is presumably the float semantics used to interpret
    /// each raw APInt -- confirm against the out-of-line definition.
    FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);

  public:
    /// The reference type is an APFloat value rather than a true reference.
    using reference = APFloat;
  };
810 
811   //===--------------------------------------------------------------------===//
812   // Value Querying
813   //===--------------------------------------------------------------------===//
814 
  /// Returns true if this attribute corresponds to a splat, i.e. if all
  /// element values are the same.
  bool isSplat() const;
818 
819   /// Return the splat value for this attribute. This asserts that the attribute
820   /// corresponds to a splat.
getSplatValue()821   Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
822   template <typename T>
823   typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
824                               std::is_same<Attribute, T>::value,
825                           T>::type
getSplatValue()826   getSplatValue() const {
827     assert(isSplat() && "expected the attribute to be a splat");
828     return *getValues<T>().begin();
829   }
830   /// Return the splat value for derived attribute element types.
831   template <typename T>
832   typename std::enable_if<std::is_base_of<Attribute, T>::value &&
833                               !std::is_same<Attribute, T>::value,
834                           T>::type
getSplatValue()835   getSplatValue() const {
836     return getSplatValue().template cast<T>();
837   }
838 
839   /// Return the value at the given index. The 'index' is expected to refer to a
840   /// valid element.
getValue(ArrayRef<uint64_t> index)841   Attribute getValue(ArrayRef<uint64_t> index) const {
842     return getValue<Attribute>(index);
843   }
getValue(ArrayRef<uint64_t> index)844   template <typename T> T getValue(ArrayRef<uint64_t> index) const {
845     // Skip to the element corresponding to the flattened index.
846     return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
847   }
848 
849   /// Return the held element values as a range of integer or floating-point
850   /// values.
851   template <typename T, typename = typename std::enable_if<
852                             (!std::is_same<T, bool>::value &&
853                              std::numeric_limits<T>::is_integer) ||
854                             llvm::is_one_of<T, float, double>::value>::type>
getValues()855   llvm::iterator_range<ElementIterator<T>> getValues() const {
856     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
857     auto rawData = getRawData().data();
858     bool splat = isSplat();
859     return {ElementIterator<T>(rawData, splat, 0),
860             ElementIterator<T>(rawData, splat, getNumElements())};
861   }
862 
863   /// Return the held element values as a range of Attributes.
864   llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
865   template <typename T, typename = typename std::enable_if<
866                             std::is_same<T, Attribute>::value>::type>
getValues()867   llvm::iterator_range<AttributeElementIterator> getValues() const {
868     return getAttributeValues();
869   }
870   AttributeElementIterator attr_value_begin() const;
871   AttributeElementIterator attr_value_end() const;
872 
873   /// Return the held element values a range of T, where T is a derived
874   /// attribute type.
875   template <typename T>
876   using DerivedAttributeElementIterator =
877       llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
878   template <typename T, typename = typename std::enable_if<
879                             std::is_base_of<Attribute, T>::value &&
880                             !std::is_same<Attribute, T>::value>::type>
getValues()881   llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
882     auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
883     return llvm::map_range(getAttributeValues(),
884                            static_cast<T (*)(Attribute)>(castFn));
885   }
886 
887   /// Return the held element values as a range of bool. The element type of
888   /// this attribute must be of integer type of bitwidth 1.
889   llvm::iterator_range<BoolElementIterator> getBoolValues() const;
890   template <typename T, typename = typename std::enable_if<
891                             std::is_same<T, bool>::value>::type>
getValues()892   llvm::iterator_range<BoolElementIterator> getValues() const {
893     return getBoolValues();
894   }
895 
896   /// Return the held element values as a range of APInts. The element type of
897   /// this attribute must be of integer type.
898   llvm::iterator_range<IntElementIterator> getIntValues() const;
899   template <typename T, typename = typename std::enable_if<
900                             std::is_same<T, APInt>::value>::type>
getValues()901   llvm::iterator_range<IntElementIterator> getValues() const {
902     return getIntValues();
903   }
904   IntElementIterator int_value_begin() const;
905   IntElementIterator int_value_end() const;
906 
907   /// Return the held element values as a range of APFloat. The element type of
908   /// this attribute must be of float type.
909   llvm::iterator_range<FloatElementIterator> getFloatValues() const;
910   template <typename T, typename = typename std::enable_if<
911                             std::is_same<T, APFloat>::value>::type>
getValues()912   llvm::iterator_range<FloatElementIterator> getValues() const {
913     return getFloatValues();
914   }
915   FloatElementIterator float_value_begin() const;
916   FloatElementIterator float_value_end() const;
917 
918   //===--------------------------------------------------------------------===//
919   // Mutation Utilities
920   //===--------------------------------------------------------------------===//
921 
  /// Return a new DenseElementsAttr that has the same data as the current
  /// attribute, but has been reshaped to 'newType'. The new type must have the
  /// same total number of elements and the same element type as the current
  /// one.
  DenseElementsAttr reshape(ShapedType newType);
926 
  /// Generates a new DenseElementsAttr by mapping each int value to a new
  /// underlying APInt. The new values can represent either an integer or a
  /// float. The underlying attribute must be a DenseIntElementsAttr.
  DenseElementsAttr mapValues(Type newElementType,
                              function_ref<APInt(const APInt &)> mapping) const;

  /// Generates a new DenseElementsAttr by mapping each float value to a new
  /// underlying APInt. The new values can represent either an integer or a
  /// float. The underlying attribute must be a DenseFPElementsAttr.
  DenseElementsAttr
  mapValues(Type newElementType,
            function_ref<APInt(const APFloat &)> mapping) const;
939 
940 protected:
  /// Return the raw storage data held by this attribute, as a view over a
  /// char buffer.
  ArrayRef<char> getRawData() const;
943 
944   /// Get iterators to the raw APInt values for each element in this attribute.
raw_int_begin()945   IntElementIterator raw_int_begin() const {
946     return IntElementIterator(*this, 0);
947   }
raw_int_end()948   IntElementIterator raw_int_end() const {
949     return IntElementIterator(*this, getNumElements());
950   }
951 
  /// Constructs a dense elements attribute from an array of raw APInt values.
  /// Each APInt value is expected to have the same bitwidth as the element type
  /// of 'type'. 'type' must be a vector or tensor with static shape.
  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);

  /// Get or create a new dense elements attribute instance with the given raw
  /// data buffer. 'type' must be a vector or tensor with static shape.
  /// NOTE(review): 'isSplat' presumably states whether 'data' holds a single
  /// splat element -- confirm against the out-of-line definition.
  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
                                  bool isSplat);
961 
  /// Overload of the raw 'get' method that asserts that the given type is of
  /// integer or floating-point type. This method is used to verify type
  /// invariants that the templatized 'get' method cannot.
  /// NOTE(review): 'dataEltSize' and 'isInt' presumably describe the C++
  /// element type of 'data', as they do for isValidIntOrFloat -- confirm
  /// against the callers.
  static DenseElementsAttr getRawIntOrFloat(ShapedType type,
                                            ArrayRef<char> data,
                                            int64_t dataEltSize, bool isInt);
968 
  /// Given the description of a c++ data type, check if this type is valid for
  /// the current attribute. 'dataEltSize' is the size of the c++ type (the
  /// templatized 'getValues' passes `sizeof(T)`) and 'isInt' states whether it
  /// is an integer type. This method is used to verify specific type
  /// invariants that the templatized 'getValues' method cannot.
  bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
973 };
974 
975 /// An attribute that represents a reference to a dense float vector or tensor
976 /// object. Each element is stored as a double.
977 class DenseFPElementsAttr : public DenseElementsAttr {
978 public:
979   using iterator = DenseElementsAttr::FloatElementIterator;
980 
981   using DenseElementsAttr::DenseElementsAttr;
982 
983   /// Get an instance of a DenseFPElementsAttr with the given arguments. This
984   /// simply wraps the DenseElementsAttr::get calls.
985   template <typename Arg>
get(const ShapedType & type,Arg && arg)986   static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
987     return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
988         .template cast<DenseFPElementsAttr>();
989   }
990   template <typename T>
get(const ShapedType & type,const std::initializer_list<T> & list)991   static DenseFPElementsAttr get(const ShapedType &type,
992                                  const std::initializer_list<T> &list) {
993     return DenseElementsAttr::get(type, list)
994         .template cast<DenseFPElementsAttr>();
995   }
996 
997   /// Generates a new DenseElementsAttr by mapping each value attribute, and
998   /// constructing the DenseElementsAttr given the new element type.
999   DenseElementsAttr
1000   mapValues(Type newElementType,
1001             function_ref<APInt(const APFloat &)> mapping) const;
1002 
1003   /// Iterator access to the float element values.
begin()1004   iterator begin() const { return float_value_begin(); }
end()1005   iterator end() const { return float_value_end(); }
1006 
1007   /// Method for supporting type inquiry through isa, cast and dyn_cast.
1008   static bool classof(Attribute attr);
1009 };
1010 
1011 /// An attribute that represents a reference to a dense integer vector or tensor
1012 /// object.
1013 class DenseIntElementsAttr : public DenseElementsAttr {
1014 public:
1015   /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
1016   /// iterator directly.
1017   using iterator = DenseElementsAttr::IntElementIterator;
1018 
1019   using DenseElementsAttr::DenseElementsAttr;
1020 
1021   /// Get an instance of a DenseIntElementsAttr with the given arguments. This
1022   /// simply wraps the DenseElementsAttr::get calls.
1023   template <typename Arg>
get(const ShapedType & type,Arg && arg)1024   static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
1025     return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
1026         .template cast<DenseIntElementsAttr>();
1027   }
1028   template <typename T>
get(const ShapedType & type,const std::initializer_list<T> & list)1029   static DenseIntElementsAttr get(const ShapedType &type,
1030                                   const std::initializer_list<T> &list) {
1031     return DenseElementsAttr::get(type, list)
1032         .template cast<DenseIntElementsAttr>();
1033   }
1034 
1035   /// Generates a new DenseElementsAttr by mapping each value attribute, and
1036   /// constructing the DenseElementsAttr given the new element type.
1037   DenseElementsAttr mapValues(Type newElementType,
1038                               function_ref<APInt(const APInt &)> mapping) const;
1039 
1040   /// Iterator access to the integer element values.
begin()1041   iterator begin() const { return raw_int_begin(); }
end()1042   iterator end() const { return raw_int_end(); }
1043 
1044   /// Method for supporting type inquiry through isa, cast and dyn_cast.
1045   static bool classof(Attribute attr);
1046 };
1047 
1048 /// An opaque attribute that represents a reference to a vector or tensor
1049 /// constant with opaque content. This representation is for tensor constants
1050 /// which the compiler may not need to interpret. This attribute is always
1051 /// associated with a particular dialect, which provides a method to convert
1052 /// tensor representation to a non-opaque format.
1053 class OpaqueElementsAttr
1054     : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
1055                                  detail::OpaqueElementsAttributeStorage> {
1056 public:
1057   using Base::Base;
1058   using ValueType = StringRef;
1059 
1060   static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
1061                                 StringRef bytes);
1062 
1063   StringRef getValue() const;
1064 
1065   /// Return the value at the given index. The 'index' is expected to refer to a
1066   /// valid element.
1067   Attribute getValue(ArrayRef<uint64_t> index) const;
1068 
1069   /// Decodes the attribute value using dialect-specific decoding hook.
1070   /// Returns false if decoding is successful. If not, returns true and leaves
1071   /// 'result' argument unspecified.
1072   bool decode(ElementsAttr &result);
1073 
1074   /// Returns dialect associated with this opaque constant.
1075   Dialect *getDialect() const;
1076 
1077   /// Method for support type inquiry through isa, cast and dyn_cast.
kindof(unsigned kind)1078   static bool kindof(unsigned kind) {
1079     return kind == StandardAttributes::OpaqueElements;
1080   }
1081 };
1082 
1083 /// An attribute that represents a reference to a sparse vector or tensor
1084 /// object.
1085 ///
1086 /// This class uses COO (coordinate list) encoding to represent the sparse
1087 /// elements in an element attribute. Specifically, the sparse vector/tensor
1088 /// stores the indices and values as two separate dense elements attributes of
1089 /// tensor type (even if the sparse attribute is of vector type, in order to
1090 /// support empty lists). The dense elements attribute indices is a 2-D tensor
1091 /// of 64-bit integer elements with shape [N, ndims], which specifies the
1092 /// indices of the elements in the sparse tensor that contains nonzero values.
1093 /// The dense elements attribute values is a 1-D tensor with shape [N], and it
1094 /// supplies the corresponding values for the indices.
1095 ///
1096 /// For example,
1097 /// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
1098 /// [[1, 0, 0, 0],
1099 ///  [0, 0, 5, 0],
1100 ///  [0, 0, 0, 0]].
1101 class SparseElementsAttr
1102     : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
1103                                  detail::SparseElementsAttributeStorage> {
1104 public:
1105   using Base::Base;
1106 
1107   template <typename T>
1108   using iterator =
1109       llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
1110                             std::function<T(ptrdiff_t)>>;
1111 
1112   /// 'type' must be a vector or tensor with static shape.
1113   static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
1114                                 DenseElementsAttr values);
1115 
1116   DenseIntElementsAttr getIndices() const;
1117 
1118   DenseElementsAttr getValues() const;
1119 
1120   /// Return the values of this attribute in the form of the given type 'T'. 'T'
1121   /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
getValues()1122   template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
1123     auto zeroValue = getZeroValue<T>();
1124     auto valueIt = getValues().getValues<T>().begin();
1125     const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
1126     // TODO(riverriddle): Move-capture flatSparseIndices when c++14 is
1127     // available.
1128     std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
1129       // Try to map the current index to one of the sparse indices.
1130       for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
1131         if (flatSparseIndices[i] == index)
1132           return *std::next(valueIt, i);
1133       // Otherwise, return the zero value.
1134       return zeroValue;
1135     };
1136     return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
1137   }
1138 
1139   /// Return the value of the element at the given index. The 'index' is
1140   /// expected to refer to a valid element.
1141   Attribute getValue(ArrayRef<uint64_t> index) const;
1142 
1143   /// Method for support type inquiry through isa, cast and dyn_cast.
kindof(unsigned kind)1144   static bool kindof(unsigned kind) {
1145     return kind == StandardAttributes::SparseElements;
1146   }
1147 
1148 private:
1149   /// Get a zero APFloat for the given sparse attribute.
1150   APFloat getZeroAPFloat() const;
1151 
1152   /// Get a zero APInt for the given sparse attribute.
1153   APInt getZeroAPInt() const;
1154 
1155   /// Get a zero attribute for the given sparse attribute.
1156   Attribute getZeroAttr() const;
1157 
1158   /// Utility methods to generate a zero value of some type 'T'. This is used by
1159   /// the 'iterator' class.
1160   /// Get a zero for a given attribute type.
1161   template <typename T>
1162   typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
getZeroValue()1163   getZeroValue() const {
1164     return getZeroAttr().template cast<T>();
1165   }
1166   /// Get a zero for an APInt.
1167   template <typename T>
1168   typename std::enable_if<std::is_same<APInt, T>::value, T>::type
getZeroValue()1169   getZeroValue() const {
1170     return getZeroAPInt();
1171   }
1172   /// Get a zero for an APFloat.
1173   template <typename T>
1174   typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
getZeroValue()1175   getZeroValue() const {
1176     return getZeroAPFloat();
1177   }
1178   /// Get a zero for an C++ integer or float type.
1179   template <typename T>
1180   typename std::enable_if<std::numeric_limits<T>::is_integer ||
1181                               llvm::is_one_of<T, float, double>::value,
1182                           T>::type
getZeroValue()1183   getZeroValue() const {
1184     return T(0);
1185   }
1186 
1187   /// Flatten, and return, all of the sparse indices in this attribute in
1188   /// row-major order.
1189   std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
1190 };
1191 
1192 /// An attribute that represents a reference to a splat vector or tensor
1193 /// constant, meaning all of the elements have the same value.
1194 class SplatElementsAttr : public DenseElementsAttr {
1195 public:
1196   using DenseElementsAttr::DenseElementsAttr;
1197 
1198   /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1199   static bool classof(Attribute attr) {
1200     auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
1201     return denseAttr && denseAttr.isSplat();
1202   }
1203 };
1204 
1205 namespace detail {
1206 /// This class represents a general iterator over the values of an ElementsAttr.
1207 /// It supports all subclasses aside from OpaqueElementsAttr.
1208 template <typename T>
1209 class ElementsAttrIterator
1210     : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
1211                                         std::random_access_iterator_tag, T,
1212                                         std::ptrdiff_t, T, T> {
1213   // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
1214   // inside of a conversion operator.
1215   using DenseIteratorT = typename std::enable_if<
1216       true,
1217       decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
1218   using SparseIteratorT = SparseElementsAttr::iterator<T>;
1219 
1220   /// A union containing the specific iterators for each derived attribute kind.
1221   union Iterator {
Iterator(DenseIteratorT && it)1222     Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {}
Iterator(SparseIteratorT && it)1223     Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {}
Iterator()1224     Iterator() {}
~Iterator()1225     ~Iterator() {}
1226 
1227     operator const DenseIteratorT &() const { return denseIt; }
1228     operator const SparseIteratorT &() const { return sparseIt; }
1229     operator DenseIteratorT &() { return denseIt; }
1230     operator SparseIteratorT &() { return sparseIt; }
1231 
1232     /// An instance of a dense elements iterator.
1233     DenseIteratorT denseIt;
1234     /// An instance of a sparse elements iterator.
1235     SparseIteratorT sparseIt;
1236   };
1237 
1238   /// Utility method to process a functor on each of the internal iterator
1239   /// types.
1240   template <typename RetT, template <typename> class ProcessFn,
1241             typename... Args>
process(Args &...args)1242   RetT process(Args &... args) const {
1243     switch (attrKind) {
1244     case StandardAttributes::DenseElements:
1245       return ProcessFn<DenseIteratorT>()(args...);
1246     case StandardAttributes::SparseElements:
1247       return ProcessFn<SparseIteratorT>()(args...);
1248     }
1249     llvm_unreachable("unexpected attribute kind");
1250   }
1251 
1252   /// Utility functors used to generically implement the iterators methods.
1253   template <typename ItT> struct PlusAssign {
operatorPlusAssign1254     void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
1255   };
1256   template <typename ItT> struct Minus {
operatorMinus1257     ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
1258   };
1259   template <typename ItT> struct MinusAssign {
operatorMinusAssign1260     void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
1261   };
1262   template <typename ItT> struct Dereference {
operatorDereference1263     T operator()(ItT &it) { return *it; }
1264   };
1265   template <typename ItT> struct ConstructIter {
operatorConstructIter1266     void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
1267   };
1268   template <typename ItT> struct DestructIter {
operatorDestructIter1269     void operator()(ItT &it) { it.~ItT(); }
1270   };
1271 
1272 public:
ElementsAttrIterator(const ElementsAttrIterator<T> & rhs)1273   ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
1274       : attrKind(rhs.attrKind) {
1275     process<void, ConstructIter>(it, rhs.it);
1276   }
~ElementsAttrIterator()1277   ~ElementsAttrIterator() { process<void, DestructIter>(it); }
1278 
1279   /// Methods necessary to support random access iteration.
1280   ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
1281     assert(attrKind == rhs.attrKind && "incompatible iterators");
1282     return process<ptrdiff_t, Minus>(it, rhs.it);
1283   }
1284   bool operator==(const ElementsAttrIterator<T> &rhs) const {
1285     return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
1286   }
1287   bool operator<(const ElementsAttrIterator<T> &rhs) const {
1288     assert(attrKind == rhs.attrKind && "incompatible iterators");
1289     return process<bool, std::less>(it, rhs.it);
1290   }
1291   ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
1292     process<void, PlusAssign>(it, offset);
1293     return *this;
1294   }
1295   ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
1296     process<void, MinusAssign>(it, offset);
1297     return *this;
1298   }
1299 
1300   /// Dereference the iterator at the current index.
1301   T operator*() { return process<T, Dereference>(it); }
1302 
1303 private:
1304   template <typename IteratorT>
ElementsAttrIterator(unsigned attrKind,IteratorT && it)1305   ElementsAttrIterator(unsigned attrKind, IteratorT &&it)
1306       : attrKind(attrKind), it(std::forward<IteratorT>(it)) {}
1307 
1308   /// Allow accessing the constructor.
1309   friend ElementsAttr;
1310 
1311   /// The kind of derived elements attribute.
1312   unsigned attrKind;
1313 
1314   /// A union containing the specific iterators for each derived kind.
1315   Iterator it;
1316 };
1317 
/// A range of ElementsAttrIterator<T>. This is an llvm::iterator_range that
/// adds nothing beyond inheriting the base class constructors.
template <typename T>
class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
  using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
};
1322 } // namespace detail
1323 
1324 /// Return the elements of this attribute as a value of type 'T'.
1325 template <typename T>
1326 auto ElementsAttr::getValues() const -> iterator_range<T> {
1327   if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
1328     auto values = denseAttr.getValues<T>();
1329     return {iterator<T>(getKind(), values.begin()),
1330             iterator<T>(getKind(), values.end())};
1331   }
1332   if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
1333     auto values = sparseAttr.getValues<T>();
1334     return {iterator<T>(getKind(), values.begin()),
1335             iterator<T>(getKind(), values.end())};
1336   }
1337   llvm_unreachable("unexpected attribute kind");
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // Attributes Utils
1342 //===----------------------------------------------------------------------===//
1343 
isa()1344 template <typename U> bool Attribute::isa() const {
1345   assert(impl && "isa<> used on a null attribute.");
1346   return U::classof(*this);
1347 }
dyn_cast()1348 template <typename U> U Attribute::dyn_cast() const {
1349   return isa<U>() ? U(impl) : U(nullptr);
1350 }
dyn_cast_or_null()1351 template <typename U> U Attribute::dyn_cast_or_null() const {
1352   return (impl && isa<U>()) ? U(impl) : U(nullptr);
1353 }
cast()1354 template <typename U> U Attribute::cast() const {
1355   assert(isa<U>());
1356   return U(impl);
1357 }
1358 
1359 // Make Attribute hashable.
hash_value(Attribute arg)1360 inline ::llvm::hash_code hash_value(Attribute arg) {
1361   return ::llvm::hash_value(arg.impl);
1362 }
1363 
1364 //===----------------------------------------------------------------------===//
1365 // NamedAttributeList
1366 //===----------------------------------------------------------------------===//
1367 
1368 /// A NamedAttributeList is used to manage a list of named attributes. This
1369 /// provides simple interfaces for adding/removing/finding attributes from
1370 /// within a DictionaryAttr.
1371 ///
1372 /// We assume there will be relatively few attributes on a given operation
1373 /// (maybe a dozen or so, but not hundreds or thousands) so we use linear
1374 /// searches for everything.
1375 class NamedAttributeList {
1376 public:
1377   NamedAttributeList(DictionaryAttr attrs = nullptr)
1378       : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
1379   NamedAttributeList(ArrayRef<NamedAttribute> attributes);
1380 
1381   bool operator!=(const NamedAttributeList &other) const {
1382     return !(*this == other);
1383   }
1384   bool operator==(const NamedAttributeList &other) const {
1385     return attrs == other.attrs;
1386   }
1387 
1388   /// Return the underlying dictionary attribute. This may be null, if this list
1389   /// has no attributes.
getDictionary()1390   DictionaryAttr getDictionary() const { return attrs; }
1391 
1392   /// Return all of the attributes on this operation.
1393   ArrayRef<NamedAttribute> getAttrs() const;
1394 
1395   /// Replace the held attributes with ones provided in 'newAttrs'.
1396   void setAttrs(ArrayRef<NamedAttribute> attributes);
1397 
1398   /// Return the specified attribute if present, null otherwise.
1399   Attribute get(StringRef name) const;
1400   Attribute get(Identifier name) const;
1401 
1402   /// If the an attribute exists with the specified name, change it to the new
1403   /// value.  Otherwise, add a new attribute with the specified name/value.
1404   void set(Identifier name, Attribute value);
1405 
1406   enum class RemoveResult { Removed, NotFound };
1407 
1408   /// Remove the attribute with the specified name if it exists.  The return
1409   /// value indicates whether the attribute was present or not.
1410   RemoveResult remove(Identifier name);
1411 
1412 private:
1413   DictionaryAttr attrs;
1414 };
1415 
1416 } // end namespace mlir.
1417 
1418 namespace llvm {
1419 
1420 // Attribute hash just like pointers.
1421 template <> struct DenseMapInfo<mlir::Attribute> {
1422   static mlir::Attribute getEmptyKey() {
1423     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1424     return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
1425   }
1426   static mlir::Attribute getTombstoneKey() {
1427     auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1428     return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
1429   }
1430   static unsigned getHashValue(mlir::Attribute val) {
1431     return mlir::hash_value(val);
1432   }
1433   static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
1434     return LHS == RHS;
1435   }
1436 };
1437 
1438 /// Allow LLVM to steal the low bits of Attributes.
1439 template <> struct PointerLikeTypeTraits<mlir::Attribute> {
1440   static inline void *getAsVoidPointer(mlir::Attribute attr) {
1441     return const_cast<void *>(attr.getAsOpaquePointer());
1442   }
1443   static inline mlir::Attribute getFromVoidPointer(void *ptr) {
1444     return mlir::Attribute::getFromOpaquePointer(ptr);
1445   }
1446   enum { NumLowBitsAvailable = 3 };
1447 };
1448 
1449 template <>
1450 struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
1451     : public PointerLikeTypeTraits<mlir::Attribute> {
1452   static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) {
1453     return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr)
1454         .cast<mlir::SymbolRefAttr>();
1455   }
1456 };
1457 
1458 } // namespace llvm
1459 
1460 #endif
1461