1/*
2Open Source Initiative OSI - The MIT License (MIT):Licensing
3
4The MIT License (MIT)
5Copyright (c) 2013 Ralph Caraveo (deckarep@gmail.com)
6
7Permission is hereby granted, free of charge, to any person obtaining a copy of
8this software and associated documentation files (the "Software"), to deal in
9the Software without restriction, including without limitation the rights to
10use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
11of the Software, and to permit persons to whom the Software is furnished to do
12so, subject to the following conditions:
13
14The above copyright notice and this permission notice shall be included in all
15copies or substantial portions of the Software.
16
17THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23SOFTWARE.
24*/
25
26package mapset
27
28import (
29	"bytes"
30	"encoding/json"
31	"fmt"
32	"reflect"
33	"strings"
34)
35
// threadUnsafeSet is a set implemented directly on a Go map: each element is
// stored as a key and the empty-struct value takes no storage. Nothing here
// takes a lock, so concurrent use must be synchronized by the caller.
type threadUnsafeSet map[interface{}]struct{}

// An OrderedPair represents a 2-tuple of values; CartesianProduct uses it as
// the element type of its result.
type OrderedPair struct {
	First, Second interface{}
}

// newThreadUnsafeSet returns an empty, non-nil threadUnsafeSet that is ready
// for insertion.
func newThreadUnsafeSet() threadUnsafeSet {
	return threadUnsafeSet{}
}
47
48// Equal says whether two 2-tuples contain the same values in the same order.
49func (pair *OrderedPair) Equal(other OrderedPair) bool {
50	if pair.First == other.First &&
51		pair.Second == other.Second {
52		return true
53	}
54
55	return false
56}
57
58func (set *threadUnsafeSet) Add(i interface{}) bool {
59	_, found := (*set)[i]
60	(*set)[i] = struct{}{}
61	return !found //False if it existed already
62}
63
64func (set *threadUnsafeSet) Contains(i ...interface{}) bool {
65	for _, val := range i {
66		if _, ok := (*set)[val]; !ok {
67			return false
68		}
69	}
70	return true
71}
72
73func (set *threadUnsafeSet) IsSubset(other Set) bool {
74	_ = other.(*threadUnsafeSet)
75	for elem := range *set {
76		if !other.Contains(elem) {
77			return false
78		}
79	}
80	return true
81}
82
83func (set *threadUnsafeSet) IsProperSubset(other Set) bool {
84	return set.IsSubset(other) && !set.Equal(other)
85}
86
87func (set *threadUnsafeSet) IsSuperset(other Set) bool {
88	return other.IsSubset(set)
89}
90
91func (set *threadUnsafeSet) IsProperSuperset(other Set) bool {
92	return set.IsSuperset(other) && !set.Equal(other)
93}
94
95func (set *threadUnsafeSet) Union(other Set) Set {
96	o := other.(*threadUnsafeSet)
97
98	unionedSet := newThreadUnsafeSet()
99
100	for elem := range *set {
101		unionedSet.Add(elem)
102	}
103	for elem := range *o {
104		unionedSet.Add(elem)
105	}
106	return &unionedSet
107}
108
109func (set *threadUnsafeSet) Intersect(other Set) Set {
110	o := other.(*threadUnsafeSet)
111
112	intersection := newThreadUnsafeSet()
113	// loop over smaller set
114	if set.Cardinality() < other.Cardinality() {
115		for elem := range *set {
116			if other.Contains(elem) {
117				intersection.Add(elem)
118			}
119		}
120	} else {
121		for elem := range *o {
122			if set.Contains(elem) {
123				intersection.Add(elem)
124			}
125		}
126	}
127	return &intersection
128}
129
130func (set *threadUnsafeSet) Difference(other Set) Set {
131	_ = other.(*threadUnsafeSet)
132
133	difference := newThreadUnsafeSet()
134	for elem := range *set {
135		if !other.Contains(elem) {
136			difference.Add(elem)
137		}
138	}
139	return &difference
140}
141
142func (set *threadUnsafeSet) SymmetricDifference(other Set) Set {
143	_ = other.(*threadUnsafeSet)
144
145	aDiff := set.Difference(other)
146	bDiff := other.Difference(set)
147	return aDiff.Union(bDiff)
148}
149
150func (set *threadUnsafeSet) Clear() {
151	*set = newThreadUnsafeSet()
152}
153
154func (set *threadUnsafeSet) Remove(i interface{}) {
155	delete(*set, i)
156}
157
158func (set *threadUnsafeSet) Cardinality() int {
159	return len(*set)
160}
161
162func (set *threadUnsafeSet) Each(cb func(interface{}) bool) {
163	for elem := range *set {
164		if cb(elem) {
165			break
166		}
167	}
168}
169
170func (set *threadUnsafeSet) Iter() <-chan interface{} {
171	ch := make(chan interface{})
172	go func() {
173		for elem := range *set {
174			ch <- elem
175		}
176		close(ch)
177	}()
178
179	return ch
180}
181
182func (set *threadUnsafeSet) Iterator() *Iterator {
183	iterator, ch, stopCh := newIterator()
184
185	go func() {
186	L:
187		for elem := range *set {
188			select {
189			case <-stopCh:
190				break L
191			case ch <- elem:
192			}
193		}
194		close(ch)
195	}()
196
197	return iterator
198}
199
200func (set *threadUnsafeSet) Equal(other Set) bool {
201	_ = other.(*threadUnsafeSet)
202
203	if set.Cardinality() != other.Cardinality() {
204		return false
205	}
206	for elem := range *set {
207		if !other.Contains(elem) {
208			return false
209		}
210	}
211	return true
212}
213
214func (set *threadUnsafeSet) Clone() Set {
215	clonedSet := newThreadUnsafeSet()
216	for elem := range *set {
217		clonedSet.Add(elem)
218	}
219	return &clonedSet
220}
221
222func (set *threadUnsafeSet) String() string {
223	items := make([]string, 0, len(*set))
224
225	for elem := range *set {
226		items = append(items, fmt.Sprintf("%v", elem))
227	}
228	return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
229}
230
231// String outputs a 2-tuple in the form "(A, B)".
232func (pair OrderedPair) String() string {
233	return fmt.Sprintf("(%v, %v)", pair.First, pair.Second)
234}
235
236func (set *threadUnsafeSet) PowerSet() Set {
237	powSet := NewThreadUnsafeSet()
238	nullset := newThreadUnsafeSet()
239	powSet.Add(&nullset)
240
241	for es := range *set {
242		u := newThreadUnsafeSet()
243		j := powSet.Iter()
244		for er := range j {
245			p := newThreadUnsafeSet()
246			if reflect.TypeOf(er).Name() == "" {
247				k := er.(*threadUnsafeSet)
248				for ek := range *(k) {
249					p.Add(ek)
250				}
251			} else {
252				p.Add(er)
253			}
254			p.Add(es)
255			u.Add(&p)
256		}
257
258		powSet = powSet.Union(&u)
259	}
260
261	return powSet
262}
263
264func (set *threadUnsafeSet) CartesianProduct(other Set) Set {
265	o := other.(*threadUnsafeSet)
266	cartProduct := NewThreadUnsafeSet()
267
268	for i := range *set {
269		for j := range *o {
270			elem := OrderedPair{First: i, Second: j}
271			cartProduct.Add(elem)
272		}
273	}
274
275	return cartProduct
276}
277
278func (set *threadUnsafeSet) ToSlice() []interface{} {
279	keys := make([]interface{}, 0, set.Cardinality())
280	for elem := range *set {
281		keys = append(keys, elem)
282	}
283
284	return keys
285}
286
287// MarshalJSON creates a JSON array from the set, it marshals all elements
288func (set *threadUnsafeSet) MarshalJSON() ([]byte, error) {
289	items := make([]string, 0, set.Cardinality())
290
291	for elem := range *set {
292		b, err := json.Marshal(elem)
293		if err != nil {
294			return nil, err
295		}
296
297		items = append(items, string(b))
298	}
299
300	return []byte(fmt.Sprintf("[%s]", strings.Join(items, ","))), nil
301}
302
303// UnmarshalJSON recreates a set from a JSON array, it only decodes
304// primitive types. Numbers are decoded as json.Number.
305func (set *threadUnsafeSet) UnmarshalJSON(b []byte) error {
306	var i []interface{}
307
308	d := json.NewDecoder(bytes.NewReader(b))
309	d.UseNumber()
310	err := d.Decode(&i)
311	if err != nil {
312		return err
313	}
314
315	for _, v := range i {
316		switch t := v.(type) {
317		case []interface{}, map[string]interface{}:
318			continue
319		default:
320			set.Add(t)
321		}
322	}
323
324	return nil
325}
326