/*
 * Copyright (C) 2020 Emeric Poupon
 *
 * This file is part of LMS.
 *
 * LMS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * LMS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with LMS.  If not, see <http://www.gnu.org/licenses/>.
 */

#pragma once

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <type_traits>

/// Set of enumerators of @p T stored as one bit per enumerator inside a
/// single @p underlying_type word.
///
/// The numeric value of an enumerator is used directly as its bit position,
/// so every stored value must be lower than the bit width of
/// @p underlying_type (checked with `assert`).
///
/// @tparam T               enumeration type stored in the set
/// @tparam underlying_type word holding the bits (std::uint32_t or std::uint64_t)
template <typename T, typename underlying_type = std::uint32_t>
class EnumSet
{
    static_assert(std::is_enum<T>::value);
    static_assert(std::is_same<underlying_type, std::uint64_t>::value || std::is_same<underlying_type, std::uint32_t>::value);

    using index_type = std::uint_fast8_t;

public:
    /// Builds an empty set.
    EnumSet() = default;

    /// Builds a set holding every enumerator listed in @p values.
    constexpr EnumSet(std::initializer_list<T> values)
    {
        for (T value : values)
            insert(value);
    }

    /// Builds a set holding every element of the range [begin, end).
    template <typename It>
    constexpr EnumSet(It begin, It end)
    {
        for (; begin != end; ++begin)
            insert(*begin);
    }

    /// Adds @p value to the set (no effect if already present).
    constexpr void insert(T value)
    {
        _bitfield |= maskOf(value);
    }

    /// Removes @p value from the set (no effect if absent).
    constexpr void erase(T value)
    {
        _bitfield &= ~maskOf(value);
    }

    /// @return true when no enumerator is stored.
    constexpr bool empty() const
    {
        return _bitfield == 0;
    }

    /// @return true when @p value is stored in the set.
    constexpr bool contains(T value) const
    {
        return (_bitfield & maskOf(value)) != 0;
    }

    /// Read-only iterator visiting the stored enumerators by increasing
    /// numeric value. Dereferencing yields the enumerator by value.
    class iterator
    {
    public:
        using iterator_category = std::input_iterator_tag;
        using value_type = T;
        using difference_type = std::ptrdiff_t;
        using pointer = const T*;
        using reference = T;

        constexpr value_type operator*() const
        {
            return static_cast<value_type>(_index);
        }

        /// Two iterators are equal when they refer to the same set object
        /// and to the same bit position.
        constexpr bool operator==(const iterator& _other) const
        {
            return _container == _other._container && _index == _other._index;
        }

        constexpr bool operator!=(const iterator& _other) const
        {
            return !(*this == _other);
        }

        /// Moves to the next stored enumerator, or to end() if there is none.
        constexpr iterator& operator++()
        {
            // An exhausted iterator stays on end(); this also keeps
            // _index + 1 from wrapping around index_type.
            if (_index < npos)
                _index = _container->getFirstBitSetIndex(static_cast<index_type>(_index + 1));

            return *this;
        }

        constexpr iterator operator++(int)
        {
            iterator previous {*this};
            ++(*this);
            return previous;
        }

    private:
        friend class EnumSet;

        constexpr iterator(const EnumSet* container, index_type index)
            : _container {container}
            , _index {index}
        {
        }

        // Stored as a pointer (not a reference) so that the iterator stays
        // copy-assignable.
        const EnumSet* _container;
        index_type _index;
    };

    /// @return iterator on the stored enumerator with the lowest value,
    ///         or end() when the set is empty.
    constexpr iterator begin() const
    {
        return iterator {this, getFirstBitSetIndex()};
    }

    /// @return past-the-end iterator.
    constexpr iterator end() const
    {
        return iterator {this, npos};
    }

private:
    static_assert(std::numeric_limits<index_type>::max() >= sizeof(underlying_type) * 8);

    // One past the highest valid bit position; doubles as the "not found" marker.
    static constexpr index_type npos = sizeof(underlying_type) * 8;

    // Returns the single-bit mask associated with value.
    static constexpr underlying_type maskOf(T value)
    {
        assert(static_cast<std::size_t>(value) < sizeof(underlying_type) * 8);
        return underlying_type {1} << static_cast<underlying_type>(value);
    }

    // Returns the position of the first set bit at or after start,
    // or npos if there is no such bit.
    constexpr index_type getFirstBitSetIndex(index_type start = {}) const
    {
        // Reached when advancing past the highest bit: shifting by the full
        // width of the type is undefined behaviour, and nothing is left to scan.
        if (start >= npos)
            return npos;

        const index_type offset {countTrailingZero(static_cast<underlying_type>(_bitfield >> start))};
        if (offset == npos)
            return npos;

        return static_cast<index_type>(offset + start);
    }

    // Returns the number of zero bits below the lowest set bit of bitField,
    // or npos when bitField has no bit set.
    static constexpr index_type countTrailingZero(underlying_type bitField)
    {
        if (bitField == 0)
            return npos;

        index_type res {};
        while ((bitField & 1) == 0)
        {
            ++res;
            bitField >>= 1;
        }

        return res;
    }

    underlying_type _bitfield {};
};