1package dag
2
3import (
4	"sync"
5)
6
// Set is an unordered collection of distinct values. Membership is decided
// by each value's hash code (see Hashable), so two values that produce the
// same hash code occupy a single slot.
//
// The zero value is an empty set ready for use: the backing map is
// allocated lazily by Add, Delete and Include. Because a Set embeds a
// sync.Once, it must not be copied after first use.
type Set struct {
	// m maps each element's hash code to the element itself.
	m map[interface{}]interface{}

	// once guards the one-time lazy allocation of m.
	once sync.Once
}
12
// Hashable lets a value choose its own identity inside a Set. When a value
// implements Hashable, the result of Hashcode is used as its key; otherwise
// the value itself is the key.
//
// The returned value is used directly as a map key, so it must be
// comparable. It should also stay stable while the element is stored,
// because lookups and deletions recompute it.
type Hashable interface {
	// Hashcode returns the key under which the value is stored in a Set.
	Hashcode() interface{}
}
19
20// hashcode returns the hashcode used for set elements.
21func hashcode(v interface{}) interface{} {
22	if h, ok := v.(Hashable); ok {
23		return h.Hashcode()
24	}
25
26	return v
27}
28
29// Add adds an item to the set
30func (s *Set) Add(v interface{}) {
31	s.once.Do(s.init)
32	s.m[hashcode(v)] = v
33}
34
35// Delete removes an item from the set.
36func (s *Set) Delete(v interface{}) {
37	s.once.Do(s.init)
38	delete(s.m, hashcode(v))
39}
40
41// Include returns true/false of whether a value is in the set.
42func (s *Set) Include(v interface{}) bool {
43	s.once.Do(s.init)
44	_, ok := s.m[hashcode(v)]
45	return ok
46}
47
48// Intersection computes the set intersection with other.
49func (s *Set) Intersection(other *Set) *Set {
50	result := new(Set)
51	if s == nil {
52		return result
53	}
54	if other != nil {
55		for _, v := range s.m {
56			if other.Include(v) {
57				result.Add(v)
58			}
59		}
60	}
61
62	return result
63}
64
65// Difference returns a set with the elements that s has but
66// other doesn't.
67func (s *Set) Difference(other *Set) *Set {
68	result := new(Set)
69	if s != nil {
70		for k, v := range s.m {
71			var ok bool
72			if other != nil {
73				_, ok = other.m[k]
74			}
75			if !ok {
76				result.Add(v)
77			}
78		}
79	}
80
81	return result
82}
83
84// Filter returns a set that contains the elements from the receiver
85// where the given callback returns true.
86func (s *Set) Filter(cb func(interface{}) bool) *Set {
87	result := new(Set)
88
89	for _, v := range s.m {
90		if cb(v) {
91			result.Add(v)
92		}
93	}
94
95	return result
96}
97
98// Len is the number of items in the set.
99func (s *Set) Len() int {
100	if s == nil {
101		return 0
102	}
103
104	return len(s.m)
105}
106
107// List returns the list of set elements.
108func (s *Set) List() []interface{} {
109	if s == nil {
110		return nil
111	}
112
113	r := make([]interface{}, 0, len(s.m))
114	for _, v := range s.m {
115		r = append(r, v)
116	}
117
118	return r
119}
120
121func (s *Set) init() {
122	s.m = make(map[interface{}]interface{})
123}
124