1package sockaddr
2
3import (
4	"bytes"
5	"sort"
6)
7
8// SockAddrs is a slice of SockAddrs
9type SockAddrs []SockAddr
10
11func (s SockAddrs) Len() int      { return len(s) }
12func (s SockAddrs) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
13
14// CmpAddrFunc is the function signature that must be met to be used in the
15// OrderedAddrBy multiAddrSorter
16type CmpAddrFunc func(p1, p2 *SockAddr) int
17
18// multiAddrSorter implements the Sort interface, sorting the SockAddrs within.
19type multiAddrSorter struct {
20	addrs SockAddrs
21	cmp   []CmpAddrFunc
22}
23
24// Sort sorts the argument slice according to the Cmp functions passed to
25// OrderedAddrBy.
26func (ms *multiAddrSorter) Sort(sockAddrs SockAddrs) {
27	ms.addrs = sockAddrs
28	sort.Sort(ms)
29}
30
31// OrderedAddrBy sorts SockAddr by the list of sort function pointers.
32func OrderedAddrBy(cmpFuncs ...CmpAddrFunc) *multiAddrSorter {
33	return &multiAddrSorter{
34		cmp: cmpFuncs,
35	}
36}
37
38// Len is part of sort.Interface.
39func (ms *multiAddrSorter) Len() int {
40	return len(ms.addrs)
41}
42
43// Less is part of sort.Interface. It is implemented by looping along the
44// Cmp() functions until it finds a comparison that is either less than,
45// equal to, or greater than.
46func (ms *multiAddrSorter) Less(i, j int) bool {
47	p, q := &ms.addrs[i], &ms.addrs[j]
48	// Try all but the last comparison.
49	var k int
50	for k = 0; k < len(ms.cmp)-1; k++ {
51		cmp := ms.cmp[k]
52		x := cmp(p, q)
53		switch x {
54		case -1:
55			// p < q, so we have a decision.
56			return true
57		case 1:
58			// p > q, so we have a decision.
59			return false
60		}
61		// p == q; try the next comparison.
62	}
63	// All comparisons to here said "equal", so just return whatever the
64	// final comparison reports.
65	switch ms.cmp[k](p, q) {
66	case -1:
67		return true
68	case 1:
69		return false
70	default:
71		// Still a tie! Now what?
72		return false
73	}
74}
75
76// Swap is part of sort.Interface.
77func (ms *multiAddrSorter) Swap(i, j int) {
78	ms.addrs[i], ms.addrs[j] = ms.addrs[j], ms.addrs[i]
79}
80
// Named results for the Cmp*/Asc* comparison functions.  They exist only to
// make the comparison logic (mixed-type comparisons in particular) easier to
// read than bare -1/0/1 literals.  "Receiver" is the first operand and "arg"
// the second.
const (
	sortReceiverBeforeArg = -1 // the first operand sorts first
	sortDeferDecision     = 0  // tie: let the next comparison decide
	sortArgBeforeReceiver = 1  // the second operand sorts first
)
91
92// AscAddress is a sorting function to sort SockAddrs by their respective
93// address type.  Non-equal types are deferred in the sort.
94func AscAddress(p1Ptr, p2Ptr *SockAddr) int {
95	p1 := *p1Ptr
96	p2 := *p2Ptr
97
98	switch v := p1.(type) {
99	case IPv4Addr:
100		return v.CmpAddress(p2)
101	case IPv6Addr:
102		return v.CmpAddress(p2)
103	case UnixSock:
104		return v.CmpAddress(p2)
105	default:
106		return sortDeferDecision
107	}
108}
109
110// AscPort is a sorting function to sort SockAddrs by their respective address
111// type.  Non-equal types are deferred in the sort.
112func AscPort(p1Ptr, p2Ptr *SockAddr) int {
113	p1 := *p1Ptr
114	p2 := *p2Ptr
115
116	switch v := p1.(type) {
117	case IPv4Addr:
118		return v.CmpPort(p2)
119	case IPv6Addr:
120		return v.CmpPort(p2)
121	default:
122		return sortDeferDecision
123	}
124}
125
126// AscPrivate is a sorting function to sort "more secure" private values before
127// "more public" values.  Both IPv4 and IPv6 are compared against RFC6890
128// (RFC6890 includes, and is not limited to, RFC1918 and RFC6598 for IPv4, and
129// IPv6 includes RFC4193).
130func AscPrivate(p1Ptr, p2Ptr *SockAddr) int {
131	p1 := *p1Ptr
132	p2 := *p2Ptr
133
134	switch v := p1.(type) {
135	case IPv4Addr, IPv6Addr:
136		return v.CmpRFC(6890, p2)
137	default:
138		return sortDeferDecision
139	}
140}
141
142// AscNetworkSize is a sorting function to sort SockAddrs based on their network
143// size.  Non-equal types are deferred in the sort.
144func AscNetworkSize(p1Ptr, p2Ptr *SockAddr) int {
145	p1 := *p1Ptr
146	p2 := *p2Ptr
147	p1Type := p1.Type()
148	p2Type := p2.Type()
149
150	// Network size operations on non-IP types make no sense
151	if p1Type != p2Type && p1Type != TypeIP {
152		return sortDeferDecision
153	}
154
155	ipA := p1.(IPAddr)
156	ipB := p2.(IPAddr)
157
158	return bytes.Compare([]byte(*ipA.NetIPMask()), []byte(*ipB.NetIPMask()))
159}
160
161// AscType is a sorting function to sort "more secure" types before
162// "less-secure" types.
163func AscType(p1Ptr, p2Ptr *SockAddr) int {
164	p1 := *p1Ptr
165	p2 := *p2Ptr
166	p1Type := p1.Type()
167	p2Type := p2.Type()
168	switch {
169	case p1Type < p2Type:
170		return sortReceiverBeforeArg
171	case p1Type == p2Type:
172		return sortDeferDecision
173	case p1Type > p2Type:
174		return sortArgBeforeReceiver
175	default:
176		return sortDeferDecision
177	}
178}
179
180// FilterByType returns two lists: a list of matched and unmatched SockAddrs
181func (sas SockAddrs) FilterByType(type_ SockAddrType) (matched, excluded SockAddrs) {
182	matched = make(SockAddrs, 0, len(sas))
183	excluded = make(SockAddrs, 0, len(sas))
184
185	for _, sa := range sas {
186		if sa.Type()&type_ != 0 {
187			matched = append(matched, sa)
188		} else {
189			excluded = append(excluded, sa)
190		}
191	}
192	return matched, excluded
193}
194