1 //===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_IR_ATTRIBUTES_H
10 #define MLIR_IR_ATTRIBUTES_H
11
12 #include "mlir/IR/AttributeSupport.h"
13 #include "llvm/ADT/APFloat.h"
14 #include "llvm/ADT/Sequence.h"
15
16 namespace mlir {
// Core IR entities referenced by the attribute APIs below.
class Dialect;
class Identifier;
class Location;
class MLIRContext;

// Types referenced by attribute values and accessors.
class FunctionType;
class ShapedType;
class Type;

// Affine structures that attributes may wrap.
class AffineMap;
class IntegerSet;
26
namespace detail {

// Storage classes backing the non-elements attributes declared below.
struct AffineMapAttributeStorage;
struct ArrayAttributeStorage;
struct BoolAttributeStorage;
struct DictionaryAttributeStorage;
struct FloatAttributeStorage;
struct IntegerAttributeStorage;
struct IntegerSetAttributeStorage;
struct OpaqueAttributeStorage;
struct StringAttributeStorage;
struct SymbolRefAttributeStorage;
struct TypeAttributeStorage;

// Storage classes backing the elements attributes.
struct DenseElementsAttributeStorage;
struct OpaqueElementsAttributeStorage;
struct SparseElementsAttributeStorage;

} // namespace detail
46
47 /// Attributes are known-constant values of operations and functions.
48 ///
49 /// Instances of the Attribute class are references to immutable, uniqued,
50 /// and immortal values owned by MLIRContext. As such, an Attribute is a thin
51 /// wrapper around an underlying storage pointer. Attributes are usually passed
52 /// by value.
53 class Attribute {
54 public:
55 /// Integer identifier for all the concrete attribute kinds.
56 enum Kind {
57 // Reserve attribute kinds for dialect specific extensions.
58 #define DEFINE_SYM_KIND_RANGE(Dialect) \
59 FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
60 #include "DialectSymbolRegistry.def"
61 };
62
63 /// Utility class for implementing attributes.
64 template <typename ConcreteType, typename BaseType = Attribute,
65 typename StorageType = AttributeStorage>
66 using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
67 detail::AttributeUniquer>;
68
69 using ImplType = AttributeStorage;
70 using ValueType = void;
71
Attribute()72 constexpr Attribute() : impl(nullptr) {}
Attribute(const ImplType * impl)73 /* implicit */ Attribute(const ImplType *impl)
74 : impl(const_cast<ImplType *>(impl)) {}
75
76 Attribute(const Attribute &other) = default;
77 Attribute &operator=(const Attribute &other) = default;
78
79 bool operator==(Attribute other) const { return impl == other.impl; }
80 bool operator!=(Attribute other) const { return !(*this == other); }
81 explicit operator bool() const { return impl; }
82
83 bool operator!() const { return impl == nullptr; }
84
85 template <typename U> bool isa() const;
86 template <typename U> U dyn_cast() const;
87 template <typename U> U dyn_cast_or_null() const;
88 template <typename U> U cast() const;
89
90 // Support dyn_cast'ing Attribute to itself.
classof(Attribute)91 static bool classof(Attribute) { return true; }
92
93 /// Return the classification for this attribute.
getKind()94 unsigned getKind() const { return impl->getKind(); }
95
96 /// Return the type of this attribute.
97 Type getType() const;
98
99 /// Return the context this attribute belongs to.
100 MLIRContext *getContext() const;
101
102 /// Get the dialect this attribute is registered to.
103 Dialect &getDialect() const;
104
105 /// Print the attribute.
106 void print(raw_ostream &os) const;
107 void dump() const;
108
109 /// Get an opaque pointer to the attribute.
getAsOpaquePointer()110 const void *getAsOpaquePointer() const { return impl; }
111 /// Construct an attribute from the opaque pointer representation.
getFromOpaquePointer(const void * ptr)112 static Attribute getFromOpaquePointer(const void *ptr) {
113 return Attribute(reinterpret_cast<const ImplType *>(ptr));
114 }
115
116 friend ::llvm::hash_code hash_value(Attribute arg);
117
118 protected:
119 ImplType *impl;
120 };
121
122 inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
123 attr.print(os);
124 return os;
125 }
126
127 namespace StandardAttributes {
128 enum Kind {
129 AffineMap = Attribute::FIRST_STANDARD_ATTR,
130 Array,
131 Bool,
132 Dictionary,
133 Float,
134 Integer,
135 IntegerSet,
136 Opaque,
137 String,
138 SymbolRef,
139 Type,
140 Unit,
141
142 /// Elements Attributes.
143 DenseElements,
144 OpaqueElements,
145 SparseElements,
146 FIRST_ELEMENTS_ATTR = DenseElements,
147 LAST_ELEMENTS_ATTR = SparseElements,
148
149 /// Locations.
150 CallSiteLocation,
151 FileLineColLocation,
152 FusedLocation,
153 NameLocation,
154 OpaqueLocation,
155 UnknownLocation,
156
157 // Represents a location as a 'void*' pointer to a front-end's opaque
158 // location information, which must live longer than the MLIR objects that
159 // refer to it. OpaqueLocation's are never serialized.
160 //
161 // TODO: OpaqueLocation,
162
163 // Represents a value inlined through a function call.
164 // TODO: InlinedLocation,
165
166 FIRST_LOCATION_ATTR = CallSiteLocation,
167 LAST_LOCATION_ATTR = UnknownLocation,
168 };
169 } // namespace StandardAttributes
170
171 //===----------------------------------------------------------------------===//
172 // AffineMapAttr
173 //===----------------------------------------------------------------------===//
174
175 class AffineMapAttr
176 : public Attribute::AttrBase<AffineMapAttr, Attribute,
177 detail::AffineMapAttributeStorage> {
178 public:
179 using Base::Base;
180 using ValueType = AffineMap;
181
182 static AffineMapAttr get(AffineMap value);
183
184 AffineMap getValue() const;
185
186 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)187 static bool kindof(unsigned kind) {
188 return kind == StandardAttributes::AffineMap;
189 }
190 };
191
192 //===----------------------------------------------------------------------===//
193 // ArrayAttr
194 //===----------------------------------------------------------------------===//
195
196 /// Array attributes are lists of other attributes. They are not necessarily
197 /// type homogenous given that attributes don't, in general, carry types.
198 class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
199 detail::ArrayAttributeStorage> {
200 public:
201 using Base::Base;
202 using ValueType = ArrayRef<Attribute>;
203
204 static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
205
206 ArrayRef<Attribute> getValue() const;
207
208 /// Support range iteration.
209 using iterator = llvm::ArrayRef<Attribute>::iterator;
begin()210 iterator begin() const { return getValue().begin(); }
end()211 iterator end() const { return getValue().end(); }
size()212 size_t size() const { return getValue().size(); }
213
214 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)215 static bool kindof(unsigned kind) {
216 return kind == StandardAttributes::Array;
217 }
218
219 private:
220 /// Class for underlying value iterator support.
221 template <typename AttrTy>
222 class attr_value_iterator final
223 : public llvm::mapped_iterator<ArrayAttr::iterator,
224 AttrTy (*)(Attribute)> {
225 public:
attr_value_iterator(ArrayAttr::iterator it)226 explicit attr_value_iterator(ArrayAttr::iterator it)
227 : llvm::mapped_iterator<ArrayAttr::iterator, AttrTy (*)(Attribute)>(
228 it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}
229 AttrTy operator*() { return (*this->I).template cast<AttrTy>(); }
230 };
231
232 public:
233 template <typename AttrTy>
getAsRange()234 llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
235 return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
236 attr_value_iterator<AttrTy>(end()));
237 }
238 };
239
240 //===----------------------------------------------------------------------===//
241 // BoolAttr
242 //===----------------------------------------------------------------------===//
243
244 class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
245 detail::BoolAttributeStorage> {
246 public:
247 using Base::Base;
248 using ValueType = bool;
249
250 static BoolAttr get(bool value, MLIRContext *context);
251
252 bool getValue() const;
253
254 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)255 static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; }
256 };
257
258 //===----------------------------------------------------------------------===//
259 // DictionaryAttr
260 //===----------------------------------------------------------------------===//
261
262 /// NamedAttribute is used for dictionary attributes, it holds an identifier for
263 /// the name and a value for the attribute. The attribute pointer should always
264 /// be non-null.
265 using NamedAttribute = std::pair<Identifier, Attribute>;
266
267 /// Dictionary attribute is an attribute that represents a sorted collection of
268 /// named attribute values. The elements are sorted by name, and each name must
269 /// be unique within the collection.
270 class DictionaryAttr
271 : public Attribute::AttrBase<DictionaryAttr, Attribute,
272 detail::DictionaryAttributeStorage> {
273 public:
274 using Base::Base;
275 using ValueType = ArrayRef<NamedAttribute>;
276
277 static DictionaryAttr get(ArrayRef<NamedAttribute> value,
278 MLIRContext *context);
279
280 ArrayRef<NamedAttribute> getValue() const;
281
282 /// Return the specified attribute if present, null otherwise.
283 Attribute get(StringRef name) const;
284 Attribute get(Identifier name) const;
285
286 /// Support range iteration.
287 using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
288 iterator begin() const;
289 iterator end() const;
empty()290 bool empty() const { return size() == 0; }
291 size_t size() const;
292
293 /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)294 static bool kindof(unsigned kind) {
295 return kind == StandardAttributes::Dictionary;
296 }
297 };
298
299 //===----------------------------------------------------------------------===//
300 // FloatAttr
301 //===----------------------------------------------------------------------===//
302
303 class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
304 detail::FloatAttributeStorage> {
305 public:
306 using Base::Base;
307 using ValueType = APFloat;
308
309 /// Return a float attribute for the specified value in the specified type.
310 /// These methods should only be used for simple constant values, e.g 1.0/2.0,
311 /// that are known-valid both as host double and the 'type' format.
312 static FloatAttr get(Type type, double value);
313 static FloatAttr getChecked(Type type, double value, Location loc);
314
315 /// Return a float attribute for the specified value in the specified type.
316 static FloatAttr get(Type type, const APFloat &value);
317 static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
318
319 APFloat getValue() const;
320
321 /// This function is used to convert the value to a double, even if it loses
322 /// precision.
323 double getValueAsDouble() const;
324 static double getValueAsDouble(APFloat val);
325
326 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)327 static bool kindof(unsigned kind) {
328 return kind == StandardAttributes::Float;
329 }
330
331 /// Verify the construction invariants for a double value.
332 static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
333 MLIRContext *ctx, Type type,
334 double value);
335 static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
336 MLIRContext *ctx, Type type,
337 const APFloat &value);
338 };
339
340 //===----------------------------------------------------------------------===//
341 // IntegerAttr
342 //===----------------------------------------------------------------------===//
343
344 class IntegerAttr
345 : public Attribute::AttrBase<IntegerAttr, Attribute,
346 detail::IntegerAttributeStorage> {
347 public:
348 using Base::Base;
349 using ValueType = APInt;
350
351 static IntegerAttr get(Type type, int64_t value);
352 static IntegerAttr get(Type type, const APInt &value);
353
354 APInt getValue() const;
355 // TODO(jpienaar): Change callers to use getValue instead.
356 int64_t getInt() const;
357
358 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)359 static bool kindof(unsigned kind) {
360 return kind == StandardAttributes::Integer;
361 }
362 };
363
364 //===----------------------------------------------------------------------===//
365 // IntegerSetAttr
366 //===----------------------------------------------------------------------===//
367
368 class IntegerSetAttr
369 : public Attribute::AttrBase<IntegerSetAttr, Attribute,
370 detail::IntegerSetAttributeStorage> {
371 public:
372 using Base::Base;
373 using ValueType = IntegerSet;
374
375 static IntegerSetAttr get(IntegerSet value);
376
377 IntegerSet getValue() const;
378
379 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)380 static bool kindof(unsigned kind) {
381 return kind == StandardAttributes::IntegerSet;
382 }
383 };
384
385 //===----------------------------------------------------------------------===//
386 // OpaqueAttr
387 //===----------------------------------------------------------------------===//
388
389 /// Opaque attributes represent attributes of non-registered dialects. These are
390 /// attribute represented in their raw string form, and can only usefully be
391 /// tested for attribute equality.
392 class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
393 detail::OpaqueAttributeStorage> {
394 public:
395 using Base::Base;
396
397 /// Get or create a new OpaqueAttr with the provided dialect and string data.
398 static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
399 MLIRContext *context);
400
401 /// Get or create a new OpaqueAttr with the provided dialect and string data.
402 /// If the given identifier is not a valid namespace for a dialect, then a
403 /// null attribute is returned.
404 static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
405 Type type, Location location);
406
407 /// Returns the dialect namespace of the opaque attribute.
408 Identifier getDialectNamespace() const;
409
410 /// Returns the raw attribute data of the opaque attribute.
411 StringRef getAttrData() const;
412
413 /// Verify the construction of an opaque attribute.
414 static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
415 MLIRContext *context,
416 Identifier dialect,
417 StringRef attrData,
418 Type type);
419
kindof(unsigned kind)420 static bool kindof(unsigned kind) {
421 return kind == StandardAttributes::Opaque;
422 }
423 };
424
425 //===----------------------------------------------------------------------===//
426 // StringAttr
427 //===----------------------------------------------------------------------===//
428
429 class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
430 detail::StringAttributeStorage> {
431 public:
432 using Base::Base;
433 using ValueType = StringRef;
434
435 /// Get an instance of a StringAttr with the given string.
436 static StringAttr get(StringRef bytes, MLIRContext *context);
437
438 /// Get an instance of a StringAttr with the given string and Type.
439 static StringAttr get(StringRef bytes, Type type);
440
441 StringRef getValue() const;
442
443 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)444 static bool kindof(unsigned kind) {
445 return kind == StandardAttributes::String;
446 }
447 };
448
449 //===----------------------------------------------------------------------===//
450 // SymbolRefAttr
451 //===----------------------------------------------------------------------===//
452
453 class FlatSymbolRefAttr;
454
455 /// A symbol reference attribute represents a symbolic reference to another
456 /// operation.
457 class SymbolRefAttr
458 : public Attribute::AttrBase<SymbolRefAttr, Attribute,
459 detail::SymbolRefAttributeStorage> {
460 public:
461 using Base::Base;
462
463 /// Construct a symbol reference for the given value name.
464 static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
465
466 /// Construct a symbol reference for the given value name, and a set of nested
467 /// references that are further resolve to a nested symbol.
468 static SymbolRefAttr get(StringRef value,
469 ArrayRef<FlatSymbolRefAttr> references,
470 MLIRContext *ctx);
471
472 /// Returns the name of the top level symbol reference, i.e. the root of the
473 /// reference path.
474 StringRef getRootReference() const;
475
476 /// Returns the name of the fully resolved symbol, i.e. the leaf of the
477 /// reference path.
478 StringRef getLeafReference() const;
479
480 /// Returns the set of nested references representing the path to the symbol
481 /// nested under the root reference.
482 ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
483
484 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)485 static bool kindof(unsigned kind) {
486 return kind == StandardAttributes::SymbolRef;
487 }
488 };
489
490 /// A symbol reference with a reference path containing a single element. This
491 /// is used to refer to an operation within the current symbol table.
492 class FlatSymbolRefAttr : public SymbolRefAttr {
493 public:
494 using SymbolRefAttr::SymbolRefAttr;
495 using ValueType = StringRef;
496
497 /// Construct a symbol reference for the given value name.
get(StringRef value,MLIRContext * ctx)498 static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
499 return SymbolRefAttr::get(value, ctx);
500 }
501
502 /// Returns the name of the held symbol reference.
getValue()503 StringRef getValue() const { return getRootReference(); }
504
505 /// Methods for support type inquiry through isa, cast, and dyn_cast.
classof(Attribute attr)506 static bool classof(Attribute attr) {
507 SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
508 return refAttr && refAttr.getNestedReferences().empty();
509 }
510
511 private:
512 using SymbolRefAttr::get;
513 using SymbolRefAttr::getNestedReferences;
514 };
515
516 //===----------------------------------------------------------------------===//
// TypeAttr
518 //===----------------------------------------------------------------------===//
519
520 class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
521 detail::TypeAttributeStorage> {
522 public:
523 using Base::Base;
524 using ValueType = Type;
525
526 static TypeAttr get(Type value);
527
528 Type getValue() const;
529
530 /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)531 static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
532 };
533
534 //===----------------------------------------------------------------------===//
535 // UnitAttr
536 //===----------------------------------------------------------------------===//
537
538 /// Unit attributes are attributes that hold no specific value and are given
539 /// meaning by their existence.
540 class UnitAttr : public Attribute::AttrBase<UnitAttr> {
541 public:
542 using Base::Base;
543
544 static UnitAttr get(MLIRContext *context);
545
kindof(unsigned kind)546 static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; }
547 };
548
549 //===----------------------------------------------------------------------===//
550 // Elements Attributes
551 //===----------------------------------------------------------------------===//
552
namespace detail {
// Iterator and range types backing ElementsAttr::getValues.
template <typename T> class ElementsAttrRange;
template <typename T> class ElementsAttrIterator;
} // namespace detail
557
558 /// A base attribute that represents a reference to a static shaped tensor or
559 /// vector constant.
560 class ElementsAttr : public Attribute {
561 public:
562 using Attribute::Attribute;
563 template <typename T> using iterator = detail::ElementsAttrIterator<T>;
564 template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
565
566 /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
567 /// with static shape.
568 ShapedType getType() const;
569
570 /// Return the value at the given index. The index is expected to refer to a
571 /// valid element.
572 Attribute getValue(ArrayRef<uint64_t> index) const;
573
574 /// Return the value of type 'T' at the given index, where 'T' corresponds to
575 /// an Attribute type.
getValue(ArrayRef<uint64_t> index)576 template <typename T> T getValue(ArrayRef<uint64_t> index) const {
577 return getValue(index).template cast<T>();
578 }
579
580 /// Return the elements of this attribute as a value of type 'T'. Note:
581 /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
582 /// iteration.
583 template <typename T> iterator_range<T> getValues() const;
584
585 /// Return if the given 'index' refers to a valid element in this attribute.
586 bool isValidIndex(ArrayRef<uint64_t> index) const;
587
588 /// Returns the number of elements held by this attribute.
589 int64_t getNumElements() const;
590
591 /// Generates a new ElementsAttr by mapping each int value to a new
592 /// underlying APInt. The new values can represent either a integer or float.
593 /// This ElementsAttr should contain integers.
594 ElementsAttr mapValues(Type newElementType,
595 function_ref<APInt(const APInt &)> mapping) const;
596
597 /// Generates a new ElementsAttr by mapping each float value to a new
598 /// underlying APInt. The new values can represent either a integer or float.
599 /// This ElementsAttr should contain floats.
600 ElementsAttr mapValues(Type newElementType,
601 function_ref<APInt(const APFloat &)> mapping) const;
602
603 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)604 static bool classof(Attribute attr) {
605 return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
606 attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
607 }
608
609 protected:
610 /// Returns the 1 dimensional flattened row-major index from the given
611 /// multi-dimensional index.
612 uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
613 };
614
615 namespace detail {
616 /// DenseElementsAttr data is aligned to uint64_t, so this traits class is
617 /// necessary to interop with PointerIntPair.
618 class DenseElementDataPointerTypeTraits {
619 public:
getAsVoidPointer(const char * ptr)620 static inline const void *getAsVoidPointer(const char *ptr) { return ptr; }
getFromVoidPointer(const void * ptr)621 static inline const char *getFromVoidPointer(const void *ptr) {
622 return static_cast<const char *>(ptr);
623 }
624
625 // Note: We could steal more bits if the need arises.
626 enum { NumLowBitsAvailable = 1 };
627 };
628
629 /// Pair of raw pointer and a boolean flag of whether the pointer holds a splat,
630 using DenseIterPtrAndSplat =
631 llvm::PointerIntPair<const char *, 1, bool,
632 DenseElementDataPointerTypeTraits>;
633
634 /// Impl iterator for indexed DenseElementAttr iterators that records a data
635 /// pointer and data index that is adjusted for the case of a splat attribute.
636 template <typename ConcreteT, typename T, typename PointerT = T *,
637 typename ReferenceT = T &>
638 class DenseElementIndexedIteratorImpl
639 : public indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
640 PointerT, ReferenceT> {
641 protected:
DenseElementIndexedIteratorImpl(const char * data,bool isSplat,size_t dataIndex)642 DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
643 size_t dataIndex)
644 : indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, PointerT,
645 ReferenceT>({data, isSplat}, dataIndex) {}
646
647 /// Return the current index for this iterator, adjusted for the case of a
648 /// splat.
getDataIndex()649 ptrdiff_t getDataIndex() const {
650 bool isSplat = this->base.getInt();
651 return isSplat ? 0 : this->index;
652 }
653
654 /// Return the data base pointer.
getData()655 const char *getData() const { return this->base.getPointer(); }
656 };
657 } // namespace detail
658
659 /// An attribute that represents a reference to a dense vector or tensor object.
660 ///
661 class DenseElementsAttr
662 : public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
663 detail::DenseElementsAttributeStorage> {
664 public:
665 using Base::Base;
666
667 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)668 static bool classof(Attribute attr) {
669 return attr.getKind() == StandardAttributes::DenseElements;
670 }
671
672 /// Constructs a dense elements attribute from an array of element values.
673 /// Each element attribute value is expected to be an element of 'type'.
674 /// 'type' must be a vector or tensor with static shape.
675 static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
676
677 /// Constructs a dense integer elements attribute from an array of integer
678 /// or floating-point values. Each value is expected to be the same bitwidth
679 /// of the element type of 'type'. 'type' must be a vector or tensor with
680 /// static shape.
681 template <typename T, typename = typename std::enable_if<
682 std::numeric_limits<T>::is_integer ||
683 llvm::is_one_of<T, float, double>::value>::type>
get(const ShapedType & type,ArrayRef<T> values)684 static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
685 const char *data = reinterpret_cast<const char *>(values.data());
686 return getRawIntOrFloat(
687 type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
688 /*isInt=*/std::numeric_limits<T>::is_integer);
689 }
690
691 /// Constructs a dense integer elements attribute from a single element.
692 template <typename T, typename = typename std::enable_if<
693 std::numeric_limits<T>::is_integer ||
694 llvm::is_one_of<T, float, double>::value>::type>
get(const ShapedType & type,T value)695 static DenseElementsAttr get(const ShapedType &type, T value) {
696 return get(type, llvm::makeArrayRef(value));
697 }
698
699 /// Overload of the above 'get' method that is specialized for boolean values.
700 static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
701
702 /// Constructs a dense integer elements attribute from an array of APInt
703 /// values. Each APInt value is expected to have the same bitwidth as the
704 /// element type of 'type'. 'type' must be a vector or tensor with static
705 /// shape.
706 static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
707
708 /// Constructs a dense float elements attribute from an array of APFloat
709 /// values. Each APFloat value is expected to have the same bitwidth as the
710 /// element type of 'type'. 'type' must be a vector or tensor with static
711 /// shape.
712 static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
713
714 /// Construct a dense elements attribute for an initializer_list of values.
715 /// Each value is expected to be the same bitwidth of the element type of
716 /// 'type'. 'type' must be a vector or tensor with static shape.
717 template <typename T>
get(const ShapedType & type,const std::initializer_list<T> & list)718 static DenseElementsAttr get(const ShapedType &type,
719 const std::initializer_list<T> &list) {
720 return get(type, ArrayRef<T>(list));
721 }
722
723 //===--------------------------------------------------------------------===//
724 // Iterators
725 //===--------------------------------------------------------------------===//
726
727 /// A utility iterator that allows walking over the internal Attribute values
728 /// of a DenseElementsAttr.
729 class AttributeElementIterator
730 : public indexed_accessor_iterator<AttributeElementIterator, const void *,
731 Attribute, Attribute, Attribute> {
732 public:
733 /// Accesses the Attribute value at this iterator position.
734 Attribute operator*() const;
735
736 private:
737 friend DenseElementsAttr;
738
739 /// Constructs a new iterator.
740 AttributeElementIterator(DenseElementsAttr attr, size_t index);
741 };
742
743 /// Iterator for walking raw element values of the specified type 'T', which
744 /// may be any c++ data type matching the stored representation: int32_t,
745 /// float, etc.
746 template <typename T>
747 class ElementIterator
748 : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
749 const T> {
750 public:
751 /// Accesses the raw value at this iterator position.
752 const T &operator*() const {
753 return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
754 }
755
756 private:
757 friend DenseElementsAttr;
758
759 /// Constructs a new iterator.
ElementIterator(const char * data,bool isSplat,size_t dataIndex)760 ElementIterator(const char *data, bool isSplat, size_t dataIndex)
761 : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
762 data, isSplat, dataIndex) {}
763 };
764
765 /// A utility iterator that allows walking over the internal bool values.
766 class BoolElementIterator
767 : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator,
768 bool, bool, bool> {
769 public:
770 /// Accesses the bool value at this iterator position.
771 bool operator*() const;
772
773 private:
774 friend DenseElementsAttr;
775
776 /// Constructs a new iterator.
777 BoolElementIterator(DenseElementsAttr attr, size_t dataIndex);
778 };
779
780 /// A utility iterator that allows walking over the internal raw APInt values.
781 class IntElementIterator
782 : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
783 APInt, APInt, APInt> {
784 public:
785 /// Accesses the raw APInt value at this iterator position.
786 APInt operator*() const;
787
788 private:
789 friend DenseElementsAttr;
790
791 /// Constructs a new iterator.
792 IntElementIterator(DenseElementsAttr attr, size_t dataIndex);
793
794 /// The bitwidth of the element type.
795 size_t bitWidth;
796 };
797
798 /// Iterator for walking over APFloat values.
799 class FloatElementIterator final
800 : public llvm::mapped_iterator<IntElementIterator,
801 std::function<APFloat(const APInt &)>> {
802 friend DenseElementsAttr;
803
804 /// Initializes the float element iterator to the specified iterator.
805 FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
806
807 public:
808 using reference = APFloat;
809 };
810
811 //===--------------------------------------------------------------------===//
812 // Value Querying
813 //===--------------------------------------------------------------------===//
814
815 /// Returns if this attribute corresponds to a splat, i.e. if all element
816 /// values are the same.
817 bool isSplat() const;
818
819 /// Return the splat value for this attribute. This asserts that the attribute
820 /// corresponds to a splat.
getSplatValue()821 Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
822 template <typename T>
823 typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
824 std::is_same<Attribute, T>::value,
825 T>::type
getSplatValue()826 getSplatValue() const {
827 assert(isSplat() && "expected the attribute to be a splat");
828 return *getValues<T>().begin();
829 }
830 /// Return the splat value for derived attribute element types.
831 template <typename T>
832 typename std::enable_if<std::is_base_of<Attribute, T>::value &&
833 !std::is_same<Attribute, T>::value,
834 T>::type
getSplatValue()835 getSplatValue() const {
836 return getSplatValue().template cast<T>();
837 }
838
839 /// Return the value at the given index. The 'index' is expected to refer to a
840 /// valid element.
getValue(ArrayRef<uint64_t> index)841 Attribute getValue(ArrayRef<uint64_t> index) const {
842 return getValue<Attribute>(index);
843 }
getValue(ArrayRef<uint64_t> index)844 template <typename T> T getValue(ArrayRef<uint64_t> index) const {
845 // Skip to the element corresponding to the flattened index.
846 return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
847 }
848
849 /// Return the held element values as a range of integer or floating-point
850 /// values.
851 template <typename T, typename = typename std::enable_if<
852 (!std::is_same<T, bool>::value &&
853 std::numeric_limits<T>::is_integer) ||
854 llvm::is_one_of<T, float, double>::value>::type>
getValues()855 llvm::iterator_range<ElementIterator<T>> getValues() const {
856 assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer));
857 auto rawData = getRawData().data();
858 bool splat = isSplat();
859 return {ElementIterator<T>(rawData, splat, 0),
860 ElementIterator<T>(rawData, splat, getNumElements())};
861 }
862
863 /// Return the held element values as a range of Attributes.
864 llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
865 template <typename T, typename = typename std::enable_if<
866 std::is_same<T, Attribute>::value>::type>
getValues()867 llvm::iterator_range<AttributeElementIterator> getValues() const {
868 return getAttributeValues();
869 }
870 AttributeElementIterator attr_value_begin() const;
871 AttributeElementIterator attr_value_end() const;
872
873 /// Return the held element values a range of T, where T is a derived
874 /// attribute type.
875 template <typename T>
876 using DerivedAttributeElementIterator =
877 llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
878 template <typename T, typename = typename std::enable_if<
879 std::is_base_of<Attribute, T>::value &&
880 !std::is_same<Attribute, T>::value>::type>
getValues()881 llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
882 auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
883 return llvm::map_range(getAttributeValues(),
884 static_cast<T (*)(Attribute)>(castFn));
885 }
886
887 /// Return the held element values as a range of bool. The element type of
888 /// this attribute must be of integer type of bitwidth 1.
889 llvm::iterator_range<BoolElementIterator> getBoolValues() const;
890 template <typename T, typename = typename std::enable_if<
891 std::is_same<T, bool>::value>::type>
getValues()892 llvm::iterator_range<BoolElementIterator> getValues() const {
893 return getBoolValues();
894 }
895
896 /// Return the held element values as a range of APInts. The element type of
897 /// this attribute must be of integer type.
898 llvm::iterator_range<IntElementIterator> getIntValues() const;
899 template <typename T, typename = typename std::enable_if<
900 std::is_same<T, APInt>::value>::type>
getValues()901 llvm::iterator_range<IntElementIterator> getValues() const {
902 return getIntValues();
903 }
904 IntElementIterator int_value_begin() const;
905 IntElementIterator int_value_end() const;
906
907 /// Return the held element values as a range of APFloat. The element type of
908 /// this attribute must be of float type.
909 llvm::iterator_range<FloatElementIterator> getFloatValues() const;
910 template <typename T, typename = typename std::enable_if<
911 std::is_same<T, APFloat>::value>::type>
getValues()912 llvm::iterator_range<FloatElementIterator> getValues() const {
913 return getFloatValues();
914 }
915 FloatElementIterator float_value_begin() const;
916 FloatElementIterator float_value_end() const;
917
918 //===--------------------------------------------------------------------===//
919 // Mutation Utilities
920 //===--------------------------------------------------------------------===//
921
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as the same element type.
DenseElementsAttr reshape(ShapedType newType);

/// Generates a new DenseElementsAttr by mapping each integer value to a new
/// underlying APInt using 'mapping'. The new values can represent either an
/// integer or a float, as described by 'newElementType'.
/// This attribute must be a DenseIntElementsAttr.
DenseElementsAttr mapValues(Type newElementType,
                            function_ref<APInt(const APInt &)> mapping) const;

/// Generates a new DenseElementsAttr by mapping each float value to a new
/// underlying APInt using 'mapping'. The new values can represent either an
/// integer or a float, as described by 'newElementType'.
/// This attribute must be a DenseFPElementsAttr.
DenseElementsAttr
mapValues(Type newElementType,
          function_ref<APInt(const APFloat &)> mapping) const;
939
940 protected:
941 /// Return the raw storage data held by this attribute.
942 ArrayRef<char> getRawData() const;
943
944 /// Get iterators to the raw APInt values for each element in this attribute.
raw_int_begin()945 IntElementIterator raw_int_begin() const {
946 return IntElementIterator(*this, 0);
947 }
raw_int_end()948 IntElementIterator raw_int_end() const {
949 return IntElementIterator(*this, getNumElements());
950 }
951
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);

/// Get or create a new dense elements attribute instance with the given raw
/// data buffer. 'type' must be a vector or tensor with static shape.
/// NOTE(review): 'isSplat' presumably marks 'data' as holding a single value
/// shared by every element — confirm against the definition.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
                                bool isSplat);

/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
/// NOTE(review): 'dataEltSize'/'isInt' presumably describe the C++ element
/// type stored in 'data', mirroring isValidIntOrFloat — confirm.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
                                          ArrayRef<char> data,
                                          int64_t dataEltSize, bool isInt);

/// Given the size ('dataEltSize') and integer-ness ('isInt') of a C++ data
/// type, check whether that type is valid for the element type of this
/// attribute. This method is used to verify type invariants that the
/// templatized 'getValues' method cannot (it is called there with
/// 'sizeof(T)' and 'std::numeric_limits<T>::is_integer').
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const;
973 };
974
975 /// An attribute that represents a reference to a dense float vector or tensor
976 /// object. Each element is stored as a double.
977 class DenseFPElementsAttr : public DenseElementsAttr {
978 public:
979 using iterator = DenseElementsAttr::FloatElementIterator;
980
981 using DenseElementsAttr::DenseElementsAttr;
982
983 /// Get an instance of a DenseFPElementsAttr with the given arguments. This
984 /// simply wraps the DenseElementsAttr::get calls.
985 template <typename Arg>
get(const ShapedType & type,Arg && arg)986 static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
987 return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
988 .template cast<DenseFPElementsAttr>();
989 }
990 template <typename T>
get(const ShapedType & type,const std::initializer_list<T> & list)991 static DenseFPElementsAttr get(const ShapedType &type,
992 const std::initializer_list<T> &list) {
993 return DenseElementsAttr::get(type, list)
994 .template cast<DenseFPElementsAttr>();
995 }
996
997 /// Generates a new DenseElementsAttr by mapping each value attribute, and
998 /// constructing the DenseElementsAttr given the new element type.
999 DenseElementsAttr
1000 mapValues(Type newElementType,
1001 function_ref<APInt(const APFloat &)> mapping) const;
1002
1003 /// Iterator access to the float element values.
begin()1004 iterator begin() const { return float_value_begin(); }
end()1005 iterator end() const { return float_value_end(); }
1006
1007 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1008 static bool classof(Attribute attr);
1009 };
1010
1011 /// An attribute that represents a reference to a dense integer vector or tensor
1012 /// object.
1013 class DenseIntElementsAttr : public DenseElementsAttr {
1014 public:
1015 /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
1016 /// iterator directly.
1017 using iterator = DenseElementsAttr::IntElementIterator;
1018
1019 using DenseElementsAttr::DenseElementsAttr;
1020
1021 /// Get an instance of a DenseIntElementsAttr with the given arguments. This
1022 /// simply wraps the DenseElementsAttr::get calls.
1023 template <typename Arg>
get(const ShapedType & type,Arg && arg)1024 static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
1025 return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
1026 .template cast<DenseIntElementsAttr>();
1027 }
1028 template <typename T>
get(const ShapedType & type,const std::initializer_list<T> & list)1029 static DenseIntElementsAttr get(const ShapedType &type,
1030 const std::initializer_list<T> &list) {
1031 return DenseElementsAttr::get(type, list)
1032 .template cast<DenseIntElementsAttr>();
1033 }
1034
1035 /// Generates a new DenseElementsAttr by mapping each value attribute, and
1036 /// constructing the DenseElementsAttr given the new element type.
1037 DenseElementsAttr mapValues(Type newElementType,
1038 function_ref<APInt(const APInt &)> mapping) const;
1039
1040 /// Iterator access to the integer element values.
begin()1041 iterator begin() const { return raw_int_begin(); }
end()1042 iterator end() const { return raw_int_end(); }
1043
1044 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1045 static bool classof(Attribute attr);
1046 };
1047
1048 /// An opaque attribute that represents a reference to a vector or tensor
1049 /// constant with opaque content. This representation is for tensor constants
1050 /// which the compiler may not need to interpret. This attribute is always
1051 /// associated with a particular dialect, which provides a method to convert
1052 /// tensor representation to a non-opaque format.
1053 class OpaqueElementsAttr
1054 : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
1055 detail::OpaqueElementsAttributeStorage> {
1056 public:
1057 using Base::Base;
1058 using ValueType = StringRef;
1059
1060 static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
1061 StringRef bytes);
1062
1063 StringRef getValue() const;
1064
1065 /// Return the value at the given index. The 'index' is expected to refer to a
1066 /// valid element.
1067 Attribute getValue(ArrayRef<uint64_t> index) const;
1068
1069 /// Decodes the attribute value using dialect-specific decoding hook.
1070 /// Returns false if decoding is successful. If not, returns true and leaves
1071 /// 'result' argument unspecified.
1072 bool decode(ElementsAttr &result);
1073
1074 /// Returns dialect associated with this opaque constant.
1075 Dialect *getDialect() const;
1076
1077 /// Method for support type inquiry through isa, cast and dyn_cast.
kindof(unsigned kind)1078 static bool kindof(unsigned kind) {
1079 return kind == StandardAttributes::OpaqueElements;
1080 }
1081 };
1082
1083 /// An attribute that represents a reference to a sparse vector or tensor
1084 /// object.
1085 ///
1086 /// This class uses COO (coordinate list) encoding to represent the sparse
1087 /// elements in an element attribute. Specifically, the sparse vector/tensor
1088 /// stores the indices and values as two separate dense elements attributes of
1089 /// tensor type (even if the sparse attribute is of vector type, in order to
1090 /// support empty lists). The dense elements attribute indices is a 2-D tensor
1091 /// of 64-bit integer elements with shape [N, ndims], which specifies the
1092 /// indices of the elements in the sparse tensor that contains nonzero values.
1093 /// The dense elements attribute values is a 1-D tensor with shape [N], and it
1094 /// supplies the corresponding values for the indices.
1095 ///
1096 /// For example,
1097 /// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
1098 /// [[1, 0, 0, 0],
1099 /// [0, 0, 5, 0],
1100 /// [0, 0, 0, 0]].
1101 class SparseElementsAttr
1102 : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
1103 detail::SparseElementsAttributeStorage> {
1104 public:
1105 using Base::Base;
1106
1107 template <typename T>
1108 using iterator =
1109 llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
1110 std::function<T(ptrdiff_t)>>;
1111
1112 /// 'type' must be a vector or tensor with static shape.
1113 static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
1114 DenseElementsAttr values);
1115
1116 DenseIntElementsAttr getIndices() const;
1117
1118 DenseElementsAttr getValues() const;
1119
1120 /// Return the values of this attribute in the form of the given type 'T'. 'T'
1121 /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
getValues()1122 template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
1123 auto zeroValue = getZeroValue<T>();
1124 auto valueIt = getValues().getValues<T>().begin();
1125 const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
1126 // TODO(riverriddle): Move-capture flatSparseIndices when c++14 is
1127 // available.
1128 std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
1129 // Try to map the current index to one of the sparse indices.
1130 for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
1131 if (flatSparseIndices[i] == index)
1132 return *std::next(valueIt, i);
1133 // Otherwise, return the zero value.
1134 return zeroValue;
1135 };
1136 return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
1137 }
1138
1139 /// Return the value of the element at the given index. The 'index' is
1140 /// expected to refer to a valid element.
1141 Attribute getValue(ArrayRef<uint64_t> index) const;
1142
1143 /// Method for support type inquiry through isa, cast and dyn_cast.
kindof(unsigned kind)1144 static bool kindof(unsigned kind) {
1145 return kind == StandardAttributes::SparseElements;
1146 }
1147
1148 private:
1149 /// Get a zero APFloat for the given sparse attribute.
1150 APFloat getZeroAPFloat() const;
1151
1152 /// Get a zero APInt for the given sparse attribute.
1153 APInt getZeroAPInt() const;
1154
1155 /// Get a zero attribute for the given sparse attribute.
1156 Attribute getZeroAttr() const;
1157
1158 /// Utility methods to generate a zero value of some type 'T'. This is used by
1159 /// the 'iterator' class.
1160 /// Get a zero for a given attribute type.
1161 template <typename T>
1162 typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
getZeroValue()1163 getZeroValue() const {
1164 return getZeroAttr().template cast<T>();
1165 }
1166 /// Get a zero for an APInt.
1167 template <typename T>
1168 typename std::enable_if<std::is_same<APInt, T>::value, T>::type
getZeroValue()1169 getZeroValue() const {
1170 return getZeroAPInt();
1171 }
1172 /// Get a zero for an APFloat.
1173 template <typename T>
1174 typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
getZeroValue()1175 getZeroValue() const {
1176 return getZeroAPFloat();
1177 }
1178 /// Get a zero for an C++ integer or float type.
1179 template <typename T>
1180 typename std::enable_if<std::numeric_limits<T>::is_integer ||
1181 llvm::is_one_of<T, float, double>::value,
1182 T>::type
getZeroValue()1183 getZeroValue() const {
1184 return T(0);
1185 }
1186
1187 /// Flatten, and return, all of the sparse indices in this attribute in
1188 /// row-major order.
1189 std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
1190 };
1191
1192 /// An attribute that represents a reference to a splat vector or tensor
1193 /// constant, meaning all of the elements have the same value.
1194 class SplatElementsAttr : public DenseElementsAttr {
1195 public:
1196 using DenseElementsAttr::DenseElementsAttr;
1197
1198 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1199 static bool classof(Attribute attr) {
1200 auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
1201 return denseAttr && denseAttr.isSplat();
1202 }
1203 };
1204
1205 namespace detail {
1206 /// This class represents a general iterator over the values of an ElementsAttr.
1207 /// It supports all subclasses aside from OpaqueElementsAttr.
1208 template <typename T>
1209 class ElementsAttrIterator
1210 : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
1211 std::random_access_iterator_tag, T,
1212 std::ptrdiff_t, T, T> {
1213 // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
1214 // inside of a conversion operator.
1215 using DenseIteratorT = typename std::enable_if<
1216 true,
1217 decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
1218 using SparseIteratorT = SparseElementsAttr::iterator<T>;
1219
1220 /// A union containing the specific iterators for each derived attribute kind.
1221 union Iterator {
Iterator(DenseIteratorT && it)1222 Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {}
Iterator(SparseIteratorT && it)1223 Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {}
Iterator()1224 Iterator() {}
~Iterator()1225 ~Iterator() {}
1226
1227 operator const DenseIteratorT &() const { return denseIt; }
1228 operator const SparseIteratorT &() const { return sparseIt; }
1229 operator DenseIteratorT &() { return denseIt; }
1230 operator SparseIteratorT &() { return sparseIt; }
1231
1232 /// An instance of a dense elements iterator.
1233 DenseIteratorT denseIt;
1234 /// An instance of a sparse elements iterator.
1235 SparseIteratorT sparseIt;
1236 };
1237
1238 /// Utility method to process a functor on each of the internal iterator
1239 /// types.
1240 template <typename RetT, template <typename> class ProcessFn,
1241 typename... Args>
process(Args &...args)1242 RetT process(Args &... args) const {
1243 switch (attrKind) {
1244 case StandardAttributes::DenseElements:
1245 return ProcessFn<DenseIteratorT>()(args...);
1246 case StandardAttributes::SparseElements:
1247 return ProcessFn<SparseIteratorT>()(args...);
1248 }
1249 llvm_unreachable("unexpected attribute kind");
1250 }
1251
1252 /// Utility functors used to generically implement the iterators methods.
1253 template <typename ItT> struct PlusAssign {
operatorPlusAssign1254 void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
1255 };
1256 template <typename ItT> struct Minus {
operatorMinus1257 ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
1258 };
1259 template <typename ItT> struct MinusAssign {
operatorMinusAssign1260 void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
1261 };
1262 template <typename ItT> struct Dereference {
operatorDereference1263 T operator()(ItT &it) { return *it; }
1264 };
1265 template <typename ItT> struct ConstructIter {
operatorConstructIter1266 void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
1267 };
1268 template <typename ItT> struct DestructIter {
operatorDestructIter1269 void operator()(ItT &it) { it.~ItT(); }
1270 };
1271
1272 public:
ElementsAttrIterator(const ElementsAttrIterator<T> & rhs)1273 ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
1274 : attrKind(rhs.attrKind) {
1275 process<void, ConstructIter>(it, rhs.it);
1276 }
~ElementsAttrIterator()1277 ~ElementsAttrIterator() { process<void, DestructIter>(it); }
1278
1279 /// Methods necessary to support random access iteration.
1280 ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
1281 assert(attrKind == rhs.attrKind && "incompatible iterators");
1282 return process<ptrdiff_t, Minus>(it, rhs.it);
1283 }
1284 bool operator==(const ElementsAttrIterator<T> &rhs) const {
1285 return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
1286 }
1287 bool operator<(const ElementsAttrIterator<T> &rhs) const {
1288 assert(attrKind == rhs.attrKind && "incompatible iterators");
1289 return process<bool, std::less>(it, rhs.it);
1290 }
1291 ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
1292 process<void, PlusAssign>(it, offset);
1293 return *this;
1294 }
1295 ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
1296 process<void, MinusAssign>(it, offset);
1297 return *this;
1298 }
1299
1300 /// Dereference the iterator at the current index.
1301 T operator*() { return process<T, Dereference>(it); }
1302
1303 private:
1304 template <typename IteratorT>
ElementsAttrIterator(unsigned attrKind,IteratorT && it)1305 ElementsAttrIterator(unsigned attrKind, IteratorT &&it)
1306 : attrKind(attrKind), it(std::forward<IteratorT>(it)) {}
1307
1308 /// Allow accessing the constructor.
1309 friend ElementsAttr;
1310
1311 /// The kind of derived elements attribute.
1312 unsigned attrKind;
1313
1314 /// A union containing the specific iterators for each derived kind.
1315 Iterator it;
1316 };
1317
1318 template <typename T>
1319 class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
1320 using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
1321 };
1322 } // namespace detail
1323
1324 /// Return the elements of this attribute as a value of type 'T'.
1325 template <typename T>
1326 auto ElementsAttr::getValues() const -> iterator_range<T> {
1327 if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
1328 auto values = denseAttr.getValues<T>();
1329 return {iterator<T>(getKind(), values.begin()),
1330 iterator<T>(getKind(), values.end())};
1331 }
1332 if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
1333 auto values = sparseAttr.getValues<T>();
1334 return {iterator<T>(getKind(), values.begin()),
1335 iterator<T>(getKind(), values.end())};
1336 }
1337 llvm_unreachable("unexpected attribute kind");
1338 }
1339
1340 //===----------------------------------------------------------------------===//
1341 // Attributes Utils
1342 //===----------------------------------------------------------------------===//
1343
isa()1344 template <typename U> bool Attribute::isa() const {
1345 assert(impl && "isa<> used on a null attribute.");
1346 return U::classof(*this);
1347 }
dyn_cast()1348 template <typename U> U Attribute::dyn_cast() const {
1349 return isa<U>() ? U(impl) : U(nullptr);
1350 }
dyn_cast_or_null()1351 template <typename U> U Attribute::dyn_cast_or_null() const {
1352 return (impl && isa<U>()) ? U(impl) : U(nullptr);
1353 }
cast()1354 template <typename U> U Attribute::cast() const {
1355 assert(isa<U>());
1356 return U(impl);
1357 }
1358
1359 // Make Attribute hashable.
hash_value(Attribute arg)1360 inline ::llvm::hash_code hash_value(Attribute arg) {
1361 return ::llvm::hash_value(arg.impl);
1362 }
1363
1364 //===----------------------------------------------------------------------===//
1365 // NamedAttributeList
1366 //===----------------------------------------------------------------------===//
1367
1368 /// A NamedAttributeList is used to manage a list of named attributes. This
1369 /// provides simple interfaces for adding/removing/finding attributes from
1370 /// within a DictionaryAttr.
1371 ///
1372 /// We assume there will be relatively few attributes on a given operation
1373 /// (maybe a dozen or so, but not hundreds or thousands) so we use linear
1374 /// searches for everything.
1375 class NamedAttributeList {
1376 public:
1377 NamedAttributeList(DictionaryAttr attrs = nullptr)
1378 : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
1379 NamedAttributeList(ArrayRef<NamedAttribute> attributes);
1380
1381 bool operator!=(const NamedAttributeList &other) const {
1382 return !(*this == other);
1383 }
1384 bool operator==(const NamedAttributeList &other) const {
1385 return attrs == other.attrs;
1386 }
1387
1388 /// Return the underlying dictionary attribute. This may be null, if this list
1389 /// has no attributes.
getDictionary()1390 DictionaryAttr getDictionary() const { return attrs; }
1391
1392 /// Return all of the attributes on this operation.
1393 ArrayRef<NamedAttribute> getAttrs() const;
1394
1395 /// Replace the held attributes with ones provided in 'newAttrs'.
1396 void setAttrs(ArrayRef<NamedAttribute> attributes);
1397
1398 /// Return the specified attribute if present, null otherwise.
1399 Attribute get(StringRef name) const;
1400 Attribute get(Identifier name) const;
1401
1402 /// If the an attribute exists with the specified name, change it to the new
1403 /// value. Otherwise, add a new attribute with the specified name/value.
1404 void set(Identifier name, Attribute value);
1405
1406 enum class RemoveResult { Removed, NotFound };
1407
1408 /// Remove the attribute with the specified name if it exists. The return
1409 /// value indicates whether the attribute was present or not.
1410 RemoveResult remove(Identifier name);
1411
1412 private:
1413 DictionaryAttr attrs;
1414 };
1415
1416 } // end namespace mlir.
1417
1418 namespace llvm {
1419
1420 // Attribute hash just like pointers.
1421 template <> struct DenseMapInfo<mlir::Attribute> {
1422 static mlir::Attribute getEmptyKey() {
1423 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1424 return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
1425 }
1426 static mlir::Attribute getTombstoneKey() {
1427 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1428 return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
1429 }
1430 static unsigned getHashValue(mlir::Attribute val) {
1431 return mlir::hash_value(val);
1432 }
1433 static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
1434 return LHS == RHS;
1435 }
1436 };
1437
1438 /// Allow LLVM to steal the low bits of Attributes.
1439 template <> struct PointerLikeTypeTraits<mlir::Attribute> {
1440 static inline void *getAsVoidPointer(mlir::Attribute attr) {
1441 return const_cast<void *>(attr.getAsOpaquePointer());
1442 }
1443 static inline mlir::Attribute getFromVoidPointer(void *ptr) {
1444 return mlir::Attribute::getFromOpaquePointer(ptr);
1445 }
1446 enum { NumLowBitsAvailable = 3 };
1447 };
1448
1449 template <>
1450 struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
1451 : public PointerLikeTypeTraits<mlir::Attribute> {
1452 static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) {
1453 return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr)
1454 .cast<mlir::SymbolRefAttr>();
1455 }
1456 };
1457
1458 } // namespace llvm
1459
1460 #endif
1461