1package nat
2
3import (
4	"sort"
5	"strings"
6)
7
8type portSorter struct {
9	ports []Port
10	by    func(i, j Port) bool
11}
12
13func (s *portSorter) Len() int {
14	return len(s.ports)
15}
16
17func (s *portSorter) Swap(i, j int) {
18	s.ports[i], s.ports[j] = s.ports[j], s.ports[i]
19}
20
21func (s *portSorter) Less(i, j int) bool {
22	ip := s.ports[i]
23	jp := s.ports[j]
24
25	return s.by(ip, jp)
26}
27
28// Sort sorts a list of ports using the provided predicate
29// This function should compare `i` and `j`, returning true if `i` is
30// considered to be less than `j`
31func Sort(ports []Port, predicate func(i, j Port) bool) {
32	s := &portSorter{ports, predicate}
33	sort.Sort(s)
34}
35
36type portMapEntry struct {
37	port    Port
38	binding PortBinding
39}
40
41type portMapSorter []portMapEntry
42
43func (s portMapSorter) Len() int      { return len(s) }
44func (s portMapSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
45
46// sort the port so that the order is:
47// 1. port with larger specified bindings
48// 2. larger port
49// 3. port with tcp protocol
50func (s portMapSorter) Less(i, j int) bool {
51	pi, pj := s[i].port, s[j].port
52	hpi, hpj := toInt(s[i].binding.HostPort), toInt(s[j].binding.HostPort)
53	return hpi > hpj || pi.Int() > pj.Int() || (pi.Int() == pj.Int() && strings.ToLower(pi.Proto()) == "tcp")
54}
55
56// SortPortMap sorts the list of ports and their respected mapping. The ports
57// will explicit HostPort will be placed first.
58func SortPortMap(ports []Port, bindings PortMap) {
59	s := portMapSorter{}
60	for _, p := range ports {
61		if binding, ok := bindings[p]; ok {
62			for _, b := range binding {
63				s = append(s, portMapEntry{port: p, binding: b})
64			}
65			bindings[p] = []PortBinding{}
66		} else {
67			s = append(s, portMapEntry{port: p})
68		}
69	}
70
71	sort.Sort(s)
72	var (
73		i  int
74		pm = make(map[Port]struct{})
75	)
76	// reorder ports
77	for _, entry := range s {
78		if _, ok := pm[entry.port]; !ok {
79			ports[i] = entry.port
80			pm[entry.port] = struct{}{}
81			i++
82		}
83		// reorder bindings for this port
84		if _, ok := bindings[entry.port]; ok {
85			bindings[entry.port] = append(bindings[entry.port], entry.binding)
86		}
87	}
88}
89
90func toInt(s string) uint64 {
91	i, _, err := ParsePortRange(s)
92	if err != nil {
93		i = 0
94	}
95	return i
96}
97