1 /*
2  * Copyright (C) 2020 Emeric Poupon
3  *
4  * This file is part of LMS.
5  *
6  * LMS is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * LMS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with LMS.  If not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #pragma once
21 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <type_traits>
24 
// Compact set of enumerators of the enum type T, stored as one bit per
// enumerator inside a single unsigned integer of type underlying_type
// (32 or 64 bits).
//
// Every enumerator stored must have a value in [0, bit width of underlying_type);
// this precondition is only checked with assert().
// Iteration yields the contained enumerators in ascending order of their values.
template <typename T, typename underlying_type = std::uint32_t>
class EnumSet
{
	static_assert(std::is_enum<T>::value);
	static_assert(std::is_same<underlying_type, std::uint64_t>::value || std::is_same<underlying_type, std::uint32_t>::value);

	using index_type = std::uint_fast8_t;

	public:
		EnumSet() = default;

		// Builds a set containing every enumerator listed in values.
		constexpr EnumSet(std::initializer_list<T> values)
		{
			for (T value : values)
				insert(value);
		}

		// Builds a set containing every enumerator in the range [begin, end).
		template <typename It>
		constexpr EnumSet(It begin, It end)
		{
			for (It it {begin}; it != end; ++it)
				insert(*it);
		}

		// Adds value to the set (no-op if it is already present).
		constexpr void insert(T value)
		{
			_bitfield |= bitMask(value);
		}

		// Removes value from the set (no-op if it is not present).
		constexpr void erase(T value)
		{
			_bitfield &= ~bitMask(value);
		}

		// Returns true if the set holds no enumerator.
		[[nodiscard]] constexpr bool empty() const
		{
			return _bitfield == 0;
		}

		// Returns true if value is in the set.
		[[nodiscard]] constexpr bool contains(T value) const
		{
			return (_bitfield & bitMask(value)) != 0;
		}

		// Read-only input iterator over the contained enumerators, in ascending
		// order of their values. Dereferencing yields the enumerator by value.
		class iterator
		{
			public:
				using iterator_category = std::input_iterator_tag;
				using value_type = T;
				using difference_type = std::ptrdiff_t;
				using pointer = void;
				using reference = T;

				constexpr value_type operator*() const
				{
					return static_cast<value_type>(_index);
				}

				// Iterators compare equal only if they refer to the same set and position.
				constexpr bool operator==(const iterator& _other) const
				{
					return _container == _other._container && _index == _other._index;
				}

				constexpr bool operator!=(const iterator& _other) const
				{
					return !(*this == _other);
				}

				// Advances to the next set bit, or to end() if there is none.
				constexpr iterator& operator++()
				{
					_index = _container->getFirstBitSetIndex(_index + 1);
					return *this;
				}

				constexpr iterator operator++(int)
				{
					iterator previous {*this};
					++*this;
					return previous;
				}

			private:
				friend class EnumSet;

				constexpr iterator(const EnumSet& container, index_type index)
					: _container {&container}
					, _index {index}
				{
				}

				// Pointer rather than reference so the iterator stays copy-assignable.
				const EnumSet* _container;
				index_type _index;
		};

		constexpr iterator begin() const
		{
			return iterator {*this, getFirstBitSetIndex()};
		}

		constexpr iterator end() const
		{
			return iterator {*this, npos};
		}

	private:
		static_assert(std::numeric_limits<index_type>::max() >= sizeof(underlying_type) * 8);
		// Bit width of underlying_type; also used as the "no bit found" / end() index.
		enum : index_type { npos = sizeof(underlying_type) * 8 };

		// Single-bit mask for value; asserts that value fits in the bitfield.
		static constexpr underlying_type bitMask(T value)
		{
			assert(static_cast<std::size_t>(value) < sizeof(underlying_type) * 8);
			return underlying_type {1} << static_cast<underlying_type>(value);
		}

		// Index of the first set bit at or after position start, or npos if none.
		// start == npos is accepted (it happens when advancing past the highest
		// possible bit) and yields npos.
		constexpr index_type getFirstBitSetIndex(index_type start = {}) const
		{
			assert(start <= npos);

			// Shifting by the full type width is undefined behaviour, and there is
			// nothing left to scan anyway.
			if (start >= npos)
				return npos;

			const index_type offset {countTrailingZero(static_cast<underlying_type>(_bitfield >> start))};
			if (offset == npos)
				return npos;

			return static_cast<index_type>(offset + start);
		}

		// Number of trailing zero bits in bitField, or npos if bitField is zero.
		static constexpr index_type countTrailingZero(underlying_type bitField)
		{
			if (bitField == 0)
				return npos;

			index_type res {};
			while ((bitField & 1) == 0)
			{
				++res;
				bitField >>= 1;
			}

			return res;
		}

		underlying_type _bitfield{};
};
154 
155 
156