1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package types
16
17import (
18	"reflect"
19	"sort"
20	"sync"
21)
22
// Set is a collection of distinct string values. This file provides two
// implementations: unsafeSet, which does no locking, and tsafeSet, which
// guards an unsafeSet with a sync.RWMutex.
type Set interface {
	// Add inserts a value; adding a value that is already a member is a no-op.
	Add(string)
	// Remove deletes a value if it is a member.
	Remove(string)

	// Contains reports whether a value is a member.
	Contains(string) bool
	// Length returns the number of members.
	Length() int
	// Values returns the members in an unspecified order.
	Values() []string
	// Equals reports whether both sets hold exactly the same members.
	Equals(Set) bool

	// Copy returns a new Set holding the same members.
	Copy() Set
	// Sub returns a new Set holding the members of the receiver that are
	// not members of the argument; the receiver itself is left unchanged.
	Sub(Set) Set
}
33
34func NewUnsafeSet(values ...string) *unsafeSet {
35	set := &unsafeSet{make(map[string]struct{})}
36	for _, v := range values {
37		set.Add(v)
38	}
39	return set
40}
41
42func NewThreadsafeSet(values ...string) *tsafeSet {
43	us := NewUnsafeSet(values...)
44	return &tsafeSet{us, sync.RWMutex{}}
45}
46
// unsafeSet is a set of strings stored as the keys of a map with empty-struct
// values. It does no locking of its own; tsafeSet (built by NewThreadsafeSet)
// wraps one behind a sync.RWMutex for concurrent use.
type unsafeSet struct {
	d map[string]struct{} // members are the map's keys; values carry no data
}
50
51// Add adds a new value to the set (no-op if the value is already present)
52func (us *unsafeSet) Add(value string) {
53	us.d[value] = struct{}{}
54}
55
56// Remove removes the given value from the set
57func (us *unsafeSet) Remove(value string) {
58	delete(us.d, value)
59}
60
61// Contains returns whether the set contains the given value
62func (us *unsafeSet) Contains(value string) (exists bool) {
63	_, exists = us.d[value]
64	return exists
65}
66
67// ContainsAll returns whether the set contains all given values
68func (us *unsafeSet) ContainsAll(values []string) bool {
69	for _, s := range values {
70		if !us.Contains(s) {
71			return false
72		}
73	}
74	return true
75}
76
77// Equals returns whether the contents of two sets are identical
78func (us *unsafeSet) Equals(other Set) bool {
79	v1 := sort.StringSlice(us.Values())
80	v2 := sort.StringSlice(other.Values())
81	v1.Sort()
82	v2.Sort()
83	return reflect.DeepEqual(v1, v2)
84}
85
86// Length returns the number of elements in the set
87func (us *unsafeSet) Length() int {
88	return len(us.d)
89}
90
91// Values returns the values of the Set in an unspecified order.
92func (us *unsafeSet) Values() (values []string) {
93	values = make([]string, 0)
94	for val := range us.d {
95		values = append(values, val)
96	}
97	return values
98}
99
100// Copy creates a new Set containing the values of the first
101func (us *unsafeSet) Copy() Set {
102	cp := NewUnsafeSet()
103	for val := range us.d {
104		cp.Add(val)
105	}
106
107	return cp
108}
109
110// Sub removes all elements in other from the set
111func (us *unsafeSet) Sub(other Set) Set {
112	oValues := other.Values()
113	result := us.Copy().(*unsafeSet)
114
115	for _, val := range oValues {
116		if _, ok := result.d[val]; !ok {
117			continue
118		}
119		delete(result.d, val)
120	}
121
122	return result
123}
124
// tsafeSet wraps an unsafeSet and guards access to it with m: the mutating
// methods (Add, Remove) take the write lock and the read-only methods take
// the read lock.
type tsafeSet struct {
	us *unsafeSet   // underlying storage; accessed only while holding m
	m  sync.RWMutex // guards us
}
129
130func (ts *tsafeSet) Add(value string) {
131	ts.m.Lock()
132	defer ts.m.Unlock()
133	ts.us.Add(value)
134}
135
136func (ts *tsafeSet) Remove(value string) {
137	ts.m.Lock()
138	defer ts.m.Unlock()
139	ts.us.Remove(value)
140}
141
142func (ts *tsafeSet) Contains(value string) (exists bool) {
143	ts.m.RLock()
144	defer ts.m.RUnlock()
145	return ts.us.Contains(value)
146}
147
148func (ts *tsafeSet) Equals(other Set) bool {
149	ts.m.RLock()
150	defer ts.m.RUnlock()
151
152	// If ts and other represent the same variable, avoid calling
153	// ts.us.Equals(other), to avoid double RLock bug
154	if _other, ok := other.(*tsafeSet); ok {
155		if _other == ts {
156			return true
157		}
158	}
159	return ts.us.Equals(other)
160}
161
162func (ts *tsafeSet) Length() int {
163	ts.m.RLock()
164	defer ts.m.RUnlock()
165	return ts.us.Length()
166}
167
168func (ts *tsafeSet) Values() (values []string) {
169	ts.m.RLock()
170	defer ts.m.RUnlock()
171	return ts.us.Values()
172}
173
174func (ts *tsafeSet) Copy() Set {
175	ts.m.RLock()
176	defer ts.m.RUnlock()
177	usResult := ts.us.Copy().(*unsafeSet)
178	return &tsafeSet{usResult, sync.RWMutex{}}
179}
180
181func (ts *tsafeSet) Sub(other Set) Set {
182	ts.m.RLock()
183	defer ts.m.RUnlock()
184
185	// If ts and other represent the same variable, avoid calling
186	// ts.us.Sub(other), to avoid double RLock bug
187	if _other, ok := other.(*tsafeSet); ok {
188		if _other == ts {
189			usResult := NewUnsafeSet()
190			return &tsafeSet{usResult, sync.RWMutex{}}
191		}
192	}
193	usResult := ts.us.Sub(other).(*unsafeSet)
194	return &tsafeSet{usResult, sync.RWMutex{}}
195}
196