1// Copyright 2016-2017 VMware, Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package ip
16
17import (
18	"bytes"
19	"fmt"
20	"math"
21	"net"
22	"strconv"
23	"strings"
24)
25
// Range is an inclusive span of IP addresses running from FirstIP through LastIP.
//
// The endpoints may be stored as 4-byte or 16-byte net.IP values depending on
// how the Range was built: ParseRange yields 4-byte addresses for IPv4 CIDR
// input and 16-byte addresses for "first-last" input, so code comparing ranges
// should not assume a single byte representation.
// NOTE(review): the vic/scope/key struct tags are presumably consumed by the
// project's configuration (de)serialization code; confirm before changing them.
type Range struct {
	FirstIP net.IP `vic:"0.1" scope:"read-only" key:"first"`
	LastIP  net.IP `vic:"0.1" scope:"read-only" key:"last"`
}
30
31func NewRange(first, last net.IP) *Range {
32	return &Range{FirstIP: first, LastIP: last}
33}
34
35func (i *Range) Overlaps(other Range) bool {
36	if (bytes.Compare(i.FirstIP, other.FirstIP) <= 0 && bytes.Compare(other.FirstIP, i.LastIP) <= 0) ||
37		(bytes.Compare(i.FirstIP, other.LastIP) <= 0 && bytes.Compare(other.FirstIP, i.LastIP) <= 0) {
38		return true
39	}
40
41	return false
42}
43
44func (i *Range) String() string {
45	n := i.Network()
46	if n == nil {
47		return fmt.Sprintf("%s-%s", i.FirstIP, i.LastIP)
48	}
49
50	return n.String()
51}
52
53func (i *Range) Equal(other *Range) bool {
54	return i.FirstIP.Equal(other.FirstIP) && i.LastIP.Equal(other.LastIP)
55}
56
57// Network returns the network that this range represents, if any
58func (i *Range) Network() *net.IPNet {
59	// only works for ipv4
60	first := i.FirstIP.To4()
61	last := i.LastIP.To4()
62	diff := net.IPv4(0, 0, 0, 0).To4()
63	for j := 0; j < net.IPv4len; j++ {
64		diff[j] = first[j] ^ last[j]
65	}
66
67	var m uint
68	for j := net.IPv4len - 1; j >= 0; j-- {
69		var k uint
70		for ; k < 8; k++ {
71			if diff[j]>>k == 0 {
72				break
73			}
74		}
75
76		m += k
77		if k < 8 {
78			break
79		}
80	}
81
82	if m == 0 {
83		return nil
84	}
85
86	mask := net.CIDRMask(32-int(m), 32)
87	for j, f := range first {
88		l := f | ^mask[j]
89		if l != last[j] {
90			return nil
91		}
92		if f != (f & mask[j]) {
93			return nil
94		}
95	}
96
97	return &net.IPNet{IP: first, Mask: mask}
98}
99
100func ParseRange(r string) *Range {
101	var first, last net.IP
102	addr, ipnet, err := net.ParseCIDR(r)
103	if err == nil && addr != nil && ipnet != nil {
104		// normalize to IPv4 or it's 16 bytes by default
105		first = addr.To4()
106		if first == nil {
107			first = addr
108		}
109
110		last := make(net.IP, len(first))
111
112		// IPv6 - don't know if we'll have a mask available and it's not currently covered or supported.
113		for i, f := range first {
114			last[i] = f | ^ipnet.Mask[i]
115		}
116
117		return &Range{
118			FirstIP: first,
119			LastIP:  last,
120		}
121	}
122
123	comps := strings.Split(r, "-")
124	if len(comps) != 2 {
125		return nil
126	}
127
128	first = net.ParseIP(comps[0])
129	if first == nil {
130		return nil
131	}
132
133	last = net.ParseIP(comps[1])
134	if last == nil {
135		var end int
136		end, err := strconv.Atoi(comps[1])
137		if err != nil || end <= int(first[15]) || end > math.MaxUint8 {
138			return nil
139		}
140
141		last = net.IPv4(first[12], first[13], first[14], byte(end))
142	}
143
144	if bytes.Compare(first, last) > 0 {
145		return nil
146	}
147
148	return &Range{
149		FirstIP: first,
150		LastIP:  last,
151	}
152}
153
154// MarshalText implements the encoding.TextMarshaler interface
155func (i *Range) MarshalText() ([]byte, error) {
156	return []byte(i.String()), nil
157}
158
159// UmarshalText implements the encoding.TextUnmarshaler interface
160func (i *Range) UnmarshalText(text []byte) error {
161	s := string(text)
162	r := ParseRange(s)
163	if r == nil {
164		return fmt.Errorf("parse error: %s", s)
165	}
166
167	*i = *r
168	return nil
169}
170
// ParseIPandMask parses a CIDR string such as "1.1.1.1/8".
//
// It returns the parsed address itself, with host bits preserved (unlike the
// network that net.ParseCIDR returns), paired with the prefix mask. On error it
// returns the zero IPNet together with the error from net.ParseCIDR.
func ParseIPandMask(s string) (net.IPNet, error) {
	addr, network, err := net.ParseCIDR(s)
	if err != nil {
		return net.IPNet{}, err
	}
	return net.IPNet{IP: addr, Mask: network.Mask}, nil
}
183
// Empty reports whether i is the zero net.IPNet, meaning both IP and Mask are
// nil. Non-nil but zero-length slices do not count as empty.
func Empty(i net.IPNet) bool {
	if i.IP != nil || i.Mask != nil {
		return false
	}
	return true
}
188
// IsUnspecifiedIP reports whether ip is missing (nil or zero length) or is an
// unspecified address as defined by net.IP.IsUnspecified (0.0.0.0 or ::).
func IsUnspecifiedIP(ip net.IP) bool {
	if len(ip) == 0 {
		return true
	}
	return ip.IsUnspecified()
}
192
193func IsUnspecifiedSubnet(n *net.IPNet) bool {
194	if n == nil || IsUnspecifiedIP(n.IP) {
195		return true
196	}
197
198	ones, bits := n.Mask.Size()
199	return bits == 0 || ones == 0
200}
201
// AllZerosAddr returns the all-zeros address of subnet: its IP with every host
// bit cleared, as computed by net.IP.Mask.
func AllZerosAddr(subnet *net.IPNet) net.IP {
	ip, mask := subnet.IP, subnet.Mask
	return ip.Mask(mask)
}
206
// AllOnesAddr returns the all-ones address of subnet: its network prefix
// followed by host bits that are all set (for IPv4, the broadcast address).
//
// The address is aligned with the mask by the mask's length:
//   - A 4-byte mask requires an IPv4 address; the result is returned in 16-byte
//     form, as before.
//   - A 16-byte mask is applied across all 128 bits of the 16-byte address.
//     This covers IPv6 subnets and IPv4-mapped ones.
//
// nil is returned when the two cannot be aligned: a missing IP, an IPv6 address
// with a 4-byte mask, or a mask of any other length. The original code assumed
// IPv4 with a 4-byte mask and produced garbage, or panicked, otherwise.
func AllOnesAddr(subnet *net.IPNet) net.IP {
	mask := subnet.Mask

	switch len(mask) {
	case net.IPv4len:
		ip4 := subnet.IP.To4()
		if ip4 == nil {
			return nil
		}
		// net.IPv4 yields the 16-byte form; the last four bytes hold the address.
		out := net.IPv4(0, 0, 0, 0)
		offset := net.IPv6len - net.IPv4len
		for i, b := range ip4 {
			out[offset+i] = b | ^mask[i]
		}
		return out

	case net.IPv6len:
		ip16 := subnet.IP.To16()
		if ip16 == nil {
			return nil
		}
		out := make(net.IP, net.IPv6len)
		for i, b := range ip16 {
			out[i] = b | ^mask[i]
		}
		return out

	default:
		return nil
	}
}
217
218func IsRoutableIP(ip net.IP, subnet *net.IPNet) bool {
219	return subnet.Contains(ip) && !ip.Equal(AllZerosAddr(subnet)) && !ip.Equal(AllOnesAddr(subnet))
220}
221