1package cidrutil
2
3import (
4	"fmt"
5	"net"
6	"strings"
7
8	"github.com/hashicorp/errwrap"
9	sockaddr "github.com/hashicorp/go-sockaddr"
10	"github.com/hashicorp/vault/sdk/helper/strutil"
11)
12
13// RemoteAddrIsOk checks if the given remote address is either:
14//   - OK because there's no CIDR whitelist
15//   - OK because it's in the CIDR whitelist
16func RemoteAddrIsOk(remoteAddr string, boundCIDRs []*sockaddr.SockAddrMarshaler) bool {
17	if len(boundCIDRs) == 0 {
18		// There's no CIDR whitelist.
19		return true
20	}
21	remoteSockAddr, err := sockaddr.NewSockAddr(remoteAddr)
22	if err != nil {
23		// Can't tell, err on the side of less access.
24		return false
25	}
26	for _, cidr := range boundCIDRs {
27		if cidr.Contains(remoteSockAddr) {
28			// Whitelisted.
29			return true
30		}
31	}
32	// Not whitelisted.
33	return false
34}
35
// IPBelongsToCIDR reports whether ipAddr lies within the CIDR block cidr.
//
// It returns an error when ipAddr is empty or is not a valid IP address, and
// returns the net.ParseCIDR error unchanged when cidr cannot be parsed.
func IPBelongsToCIDR(ipAddr string, cidr string) (bool, error) {
	if ipAddr == "" {
		return false, fmt.Errorf("missing IP address")
	}

	parsed := net.ParseIP(ipAddr)
	if parsed == nil {
		return false, fmt.Errorf("invalid IP address")
	}

	_, block, err := net.ParseCIDR(cidr)
	if err != nil {
		return false, err
	}

	return block.Contains(parsed), nil
}
58
// IPBelongsToCIDRBlocksSlice reports whether ipAddr lies within at least one
// of the CIDR blocks in cidrs.
//
// It returns an error when ipAddr is empty, when cidrs is empty, or when
// ipAddr is not a valid IP address (checked in that order). Blocks are tested
// in order and the search stops at the first match. A block that fails to
// parse before a match is found aborts the search with the net.ParseCIDR
// error.
func IPBelongsToCIDRBlocksSlice(ipAddr string, cidrs []string) (bool, error) {
	if ipAddr == "" {
		return false, fmt.Errorf("missing IP address")
	}

	if len(cidrs) == 0 {
		return false, fmt.Errorf("missing CIDR blocks to be checked against")
	}

	// Parse the address once rather than once per CIDR block.
	ip := net.ParseIP(ipAddr)
	if ip == nil {
		return false, fmt.Errorf("invalid IP address")
	}

	for _, cidr := range cidrs {
		_, ipnet, err := net.ParseCIDR(cidr)
		if err != nil {
			return false, err
		}
		if ipnet.Contains(ip) {
			return true, nil
		}
	}

	return false, nil
}
86
87// ValidateCIDRListString checks if the list of CIDR blocks are valid, given
88// that the input is a string composed by joining all the CIDR blocks using a
89// separator. The input is separated based on the given separator and validity
90// of each is checked.
91func ValidateCIDRListString(cidrList string, separator string) (bool, error) {
92	if cidrList == "" {
93		return false, fmt.Errorf("missing CIDR list that needs validation")
94	}
95	if separator == "" {
96		return false, fmt.Errorf("missing separator")
97	}
98
99	return ValidateCIDRListSlice(strutil.ParseDedupLowercaseAndSortStrings(cidrList, separator))
100}
101
// ValidateCIDRListSlice reports whether every entry in cidrBlocks is a valid
// CIDR block. Leading and trailing whitespace around each entry is ignored.
//
// It returns an error if cidrBlocks is empty. Otherwise it returns the
// net.ParseCIDR error of the first invalid entry.
func ValidateCIDRListSlice(cidrBlocks []string) (bool, error) {
	if len(cidrBlocks) == 0 {
		return false, fmt.Errorf("missing CIDR blocks that needs validation")
	}

	for i := 0; i < len(cidrBlocks); i++ {
		trimmed := strings.TrimSpace(cidrBlocks[i])
		if _, _, err := net.ParseCIDR(trimmed); err != nil {
			return false, err
		}
	}
	return true, nil
}
116
117// Subset checks if the IPs belonging to a given CIDR block is a subset of IPs
118// belonging to another CIDR block.
119func Subset(cidr1, cidr2 string) (bool, error) {
120	if cidr1 == "" {
121		return false, fmt.Errorf("missing CIDR to be checked against")
122	}
123
124	if cidr2 == "" {
125		return false, fmt.Errorf("missing CIDR that needs to be checked")
126	}
127
128	ip1, net1, err := net.ParseCIDR(cidr1)
129	if err != nil {
130		return false, errwrap.Wrapf("failed to parse the CIDR to be checked against: {{err}}", err)
131	}
132
133	zeroAddr := false
134	if ip := ip1.To4(); ip != nil && ip.Equal(net.IPv4zero) {
135		zeroAddr = true
136	}
137	if ip := ip1.To16(); ip != nil && ip.Equal(net.IPv6zero) {
138		zeroAddr = true
139	}
140
141	maskLen1, _ := net1.Mask.Size()
142	if !zeroAddr && maskLen1 == 0 {
143		return false, fmt.Errorf("CIDR to be checked against is not in its canonical form")
144	}
145
146	ip2, net2, err := net.ParseCIDR(cidr2)
147	if err != nil {
148		return false, errwrap.Wrapf("failed to parse the CIDR that needs to be checked: {{err}}", err)
149	}
150
151	zeroAddr = false
152	if ip := ip2.To4(); ip != nil && ip.Equal(net.IPv4zero) {
153		zeroAddr = true
154	}
155	if ip := ip2.To16(); ip != nil && ip.Equal(net.IPv6zero) {
156		zeroAddr = true
157	}
158
159	maskLen2, _ := net2.Mask.Size()
160	if !zeroAddr && maskLen2 == 0 {
161		return false, fmt.Errorf("CIDR that needs to be checked is not in its canonical form")
162	}
163
164	// If the mask length of the CIDR that needs to be checked is smaller
165	// then the mask length of the CIDR to be checked against, then the
166	// former will encompass more IPs than the latter, and hence can't be a
167	// subset of the latter.
168	if maskLen2 < maskLen1 {
169		return false, nil
170	}
171
172	belongs, err := IPBelongsToCIDR(net2.IP.String(), cidr1)
173	if err != nil {
174		return false, err
175	}
176
177	return belongs, nil
178}
179
180// SubsetBlocks checks if each CIDR block of a given set of CIDR blocks, is a
181// subset of at least one CIDR block belonging to another set of CIDR blocks.
182// First parameter is the set of CIDR blocks to check against and the second
183// parameter is the set of CIDR blocks that needs to be checked.
184func SubsetBlocks(cidrBlocks1, cidrBlocks2 []string) (bool, error) {
185	if len(cidrBlocks1) == 0 {
186		return false, fmt.Errorf("missing CIDR blocks to be checked against")
187	}
188
189	if len(cidrBlocks2) == 0 {
190		return false, fmt.Errorf("missing CIDR blocks that needs to be checked")
191	}
192
193	// Check if all the elements of cidrBlocks2 is a subset of at least one
194	// element of cidrBlocks1
195	for _, cidrBlock2 := range cidrBlocks2 {
196		isSubset := false
197		for _, cidrBlock1 := range cidrBlocks1 {
198			subset, err := Subset(cidrBlock1, cidrBlock2)
199			if err != nil {
200				return false, err
201			}
202			// If CIDR is a subset of any of the CIDR block, its
203			// good enough. Break out.
204			if subset {
205				isSubset = true
206				break
207			}
208		}
209		// CIDR block was not a subset of any of the CIDR blocks in the
210		// set of blocks to check against
211		if !isSubset {
212			return false, nil
213		}
214	}
215
216	return true, nil
217}
218