1// Copyright 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
2// Copyright 2009 The Go Authors
3// Released under the MIT license
4
5package flatip
6
7import (
8	"bytes"
9	"errors"
10	"net"
11)
12
// v4InV6Prefix is the 12-byte header (ten zero bytes followed by 0xff, 0xff)
// that marks an IPv4 address stored in 4-in-6 (IPv4-mapped, ::ffff:a.b.c.d)
// form. The keyed literal gives the slice length 12, with bytes 0-9 zero.
var v4InV6Prefix = []byte{10: 0xff, 11: 0xff}

// IPv6loopback is the IPv6 loopback address, ::1.
var IPv6loopback = IP{15: 1}

// IPv6zero is ::, the IPv6 unspecified address. It equals the zero value of IP.
var IPv6zero = IP{}

// IPv4zero is the IPv4 unspecified address 0.0.0.0, in 4-in-6 form
// (::ffff:0.0.0.0).
var IPv4zero = IP{10: 0xff, 11: 0xff}

// ErrInvalidIPString is the error ParseIP returns when its input cannot be
// parsed as an IP address.
var ErrInvalidIPString = errors.New("String could not be interpreted as an IP address")

// IP and IPNet are packed counterparts of net.IP and net.IPNet. Unlike those
// slice-based types they are plain values, so they can be compared with ==
// and used as map keys.

// IP is a 128-bit IP address. IPv4 addresses are represented in the 4-in-6
// (IPv4-mapped) form, i.e. prefixed by v4InV6Prefix.
type IP [16]byte

// IPNet is an IP network: an address plus a prefix length counted out of all
// 128 bits (so an IPv4 /24 network has PrefixLen 120). In a valid value, every
// bit of IP after the first PrefixLen bits is zero.
type IPNet struct {
	IP
	PrefixLen uint8
}
35
36// NetIP converts an IP into a net.IP.
37func (ip IP) NetIP() (result net.IP) {
38	result = make(net.IP, 16)
39	copy(result[:], ip[:])
40	return
41}
42
43// FromNetIP converts a net.IP into an IP.
44func FromNetIP(ip net.IP) (result IP) {
45	if len(ip) == 16 {
46		copy(result[:], ip[:])
47	} else {
48		result[10] = 0xff
49		result[11] = 0xff
50		copy(result[12:], ip[:])
51	}
52	return
53}
54
55// IPv4 returns the IP address representation of a.b.c.d
56func IPv4(a, b, c, d byte) (result IP) {
57	copy(result[:12], v4InV6Prefix)
58	result[12] = a
59	result[13] = b
60	result[14] = c
61	result[15] = d
62	return
63}
64
65// ParseIP parses a string representation of an IP address into an IP.
66// Unlike net.ParseIP, it returns an error instead of a zero value on failure,
67// since the zero value of `IP` is a representation of a valid IP (::0, the
68// IPv6 "unspecified address").
69func ParseIP(ipstr string) (ip IP, err error) {
70	// TODO reimplement this without net.ParseIP
71	netip := net.ParseIP(ipstr)
72	if netip == nil {
73		err = ErrInvalidIPString
74		return
75	}
76	netip = netip.To16()
77	copy(ip[:], netip)
78	return
79}
80
81// String returns the string representation of an IP
82func (ip IP) String() string {
83	// TODO reimplement this without using (net.IP).String()
84	return (net.IP)(ip[:]).String()
85}
86
87// IsIPv4 returns whether the IP is an IPv4 address.
88func (ip IP) IsIPv4() bool {
89	return bytes.Equal(ip[:12], v4InV6Prefix)
90}
91
92// IsLoopback returns whether the IP is a loopback address.
93func (ip IP) IsLoopback() bool {
94	if ip.IsIPv4() {
95		return ip[12] == 127
96	} else {
97		return ip == IPv6loopback
98	}
99}
100
101func (ip IP) IsUnspecified() bool {
102	return ip == IPv4zero || ip == IPv6zero
103}
104
105func rawCidrMask(length int) (m IP) {
106	n := uint(length)
107	for i := 0; i < 16; i++ {
108		if n >= 8 {
109			m[i] = 0xff
110			n -= 8
111			continue
112		}
113		m[i] = ^byte(0xff >> n)
114		return
115	}
116	return
117}
118
119func (ip IP) applyMask(mask IP) (result IP) {
120	for i := 0; i < 16; i += 1 {
121		result[i] = ip[i] & mask[i]
122	}
123	return
124}
125
126func cidrMask(ones, bits int) (result IP) {
127	switch bits {
128	case 32:
129		return rawCidrMask(96 + ones)
130	case 128:
131		return rawCidrMask(ones)
132	default:
133		return
134	}
135}
136
137// Mask returns the result of masking ip with the CIDR mask of
138// length 'ones', out of a total of 'bits' (which must be either
139// 32 for an IPv4 subnet or 128 for an IPv6 subnet).
140func (ip IP) Mask(ones, bits int) (result IP) {
141	return ip.applyMask(cidrMask(ones, bits))
142}
143
144// ToNetIPNet converts an IPNet into a net.IPNet.
145func (cidr IPNet) ToNetIPNet() (result net.IPNet) {
146	return net.IPNet{
147		IP:   cidr.IP.NetIP(),
148		Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
149	}
150}
151
152// Contains retuns whether the network contains `ip`.
153func (cidr IPNet) Contains(ip IP) bool {
154	maskedIP := ip.Mask(int(cidr.PrefixLen), 128)
155	return cidr.IP == maskedIP
156}
157
158// FromNetIPnet converts a net.IPNet into an IPNet.
159func FromNetIPNet(network net.IPNet) (result IPNet) {
160	ones, _ := network.Mask.Size()
161	if len(network.IP) == 16 {
162		copy(result.IP[:], network.IP[:])
163	} else {
164		result.IP[10] = 0xff
165		result.IP[11] = 0xff
166		copy(result.IP[12:], network.IP[:])
167		ones += 96
168	}
169	// perform masking so that equal CIDRs are ==
170	result.IP = result.IP.Mask(ones, 128)
171	result.PrefixLen = uint8(ones)
172	return
173}
174
175// String returns a string representation of an IPNet.
176func (cidr IPNet) String() string {
177	ip := make(net.IP, 16)
178	copy(ip[:], cidr.IP[:])
179	ipnet := net.IPNet{
180		IP:   ip,
181		Mask: net.CIDRMask(int(cidr.PrefixLen), 128),
182	}
183	return ipnet.String()
184}
185
186// HumanReadableString returns a string representation of an IPNet;
187// if the network contains only a single IP address, it returns
188// a representation of that address.
189func (cidr IPNet) HumanReadableString() string {
190	if cidr.PrefixLen == 128 {
191		return cidr.IP.String()
192	}
193	return cidr.String()
194}
195
196// IsZero tests whether ipnet is the zero value of an IPNet, 0::0/0.
197// Although this is a valid subnet, it can still be used as a sentinel
198// value in some contexts.
199func (ipnet IPNet) IsZero() bool {
200	return ipnet == IPNet{}
201}
202
203// ParseCIDR parses a string representation of an IP network in CIDR notation,
204// then returns it as an IPNet (along with the original, unmasked address).
205func ParseCIDR(netstr string) (ip IP, ipnet IPNet, err error) {
206	// TODO reimplement this without net.ParseCIDR
207	nip, nipnet, err := net.ParseCIDR(netstr)
208	if err != nil {
209		return
210	}
211	return FromNetIP(nip), FromNetIPNet(*nipnet), nil
212}
213