1package sockaddr
2
3import (
4	"bytes"
5	"encoding/binary"
6	"fmt"
7	"math/big"
8	"net"
9)
10
// IPv6Address is a named type representing an IPv6 address, held as the
// integer value of the address bits.
type IPv6Address *big.Int

// IPv6Network is a named type representing an IPv6 network address, i.e. an
// address whose host bits have been cleared (see IPv6Addr.NetworkAddress).
type IPv6Network *big.Int

// IPv6Mask is a named type representing an IPv6 network mask, held as the
// integer value of the mask bits.
type IPv6Mask *big.Int
21
22// IPv6HostPrefix is a constant represents a /128 IPv6 Prefix.
23const IPv6HostPrefix = IPPrefixLen(128)
24
25// ipv6HostMask is an unexported big.Int representing a /128 IPv6 address.
26// This value must be a constant and always set to all ones.
27var ipv6HostMask IPv6Mask
28
29// ipv6AddrAttrMap is a map of the IPv6Addr type-specific attributes.
30var ipv6AddrAttrMap map[AttrName]func(IPv6Addr) string
31var ipv6AddrAttrs []AttrName
32
33func init() {
34	biMask := new(big.Int)
35	biMask.SetBytes([]byte{
36		0xff, 0xff,
37		0xff, 0xff,
38		0xff, 0xff,
39		0xff, 0xff,
40		0xff, 0xff,
41		0xff, 0xff,
42		0xff, 0xff,
43		0xff, 0xff,
44	},
45	)
46	ipv6HostMask = IPv6Mask(biMask)
47
48	ipv6AddrInit()
49}
50
// IPv6Addr implements a convenience wrapper around the union of Go's
// built-in net.IP and net.IPNet types.  In UNIX-speak, IPv6Addr implements
// `sockaddr` when the address family is set to AF_INET6
// (i.e. `sockaddr_in6`).
type IPv6Addr struct {
	// IPAddr is the embedded IPAddr interface.  NOTE(review): the
	// constructors and methods in this file never set it, so it is
	// normally nil — confirm whether anything relies on it.
	IPAddr
	// Address holds the IP address as an integer.
	Address IPv6Address
	// Mask holds the network mask as an integer; ipv6HostMask (all ones)
	// denotes a single /128 host.
	Mask    IPv6Mask
	// Port is the port number; String() omits it when it is 0.
	Port    IPPort
}
61
62// NewIPv6Addr creates an IPv6Addr from a string.  String can be in the form of
63// an an IPv6:port (e.g. `[2001:4860:0:2001::68]:80`, in which case the mask is
64// assumed to be a /128), an IPv6 address (e.g. `2001:4860:0:2001::68`, also
65// with a `/128` mask), an IPv6 CIDR (e.g. `2001:4860:0:2001::68/64`, which has
66// its IP port initialized to zero).  ipv6Str can not be a hostname.
67//
68// NOTE: Many net.*() routines will initialize and return an IPv4 address.
69// Always test to make sure the address returned cannot be converted to a 4 byte
70// array using To4().
71func NewIPv6Addr(ipv6Str string) (IPv6Addr, error) {
72	v6Addr := false
73LOOP:
74	for i := 0; i < len(ipv6Str); i++ {
75		switch ipv6Str[i] {
76		case '.':
77			break LOOP
78		case ':':
79			v6Addr = true
80			break LOOP
81		}
82	}
83
84	if !v6Addr {
85		return IPv6Addr{}, fmt.Errorf("Unable to resolve %+q as an IPv6 address, appears to be an IPv4 address", ipv6Str)
86	}
87
88	// Attempt to parse ipv6Str as a /128 host with a port number.
89	tcpAddr, err := net.ResolveTCPAddr("tcp6", ipv6Str)
90	if err == nil {
91		ipv6 := tcpAddr.IP.To16()
92		if ipv6 == nil {
93			return IPv6Addr{}, fmt.Errorf("Unable to resolve %+q as a 16byte IPv6 address", ipv6Str)
94		}
95
96		ipv6BigIntAddr := new(big.Int)
97		ipv6BigIntAddr.SetBytes(ipv6)
98
99		ipv6BigIntMask := new(big.Int)
100		ipv6BigIntMask.Set(ipv6HostMask)
101
102		ipv6Addr := IPv6Addr{
103			Address: IPv6Address(ipv6BigIntAddr),
104			Mask:    IPv6Mask(ipv6BigIntMask),
105			Port:    IPPort(tcpAddr.Port),
106		}
107
108		return ipv6Addr, nil
109	}
110
111	// Parse as a naked IPv6 address.  Trim square brackets if present.
112	if len(ipv6Str) > 2 && ipv6Str[0] == '[' && ipv6Str[len(ipv6Str)-1] == ']' {
113		ipv6Str = ipv6Str[1 : len(ipv6Str)-1]
114	}
115	ip := net.ParseIP(ipv6Str)
116	if ip != nil {
117		ipv6 := ip.To16()
118		if ipv6 == nil {
119			return IPv6Addr{}, fmt.Errorf("Unable to string convert %+q to a 16byte IPv6 address", ipv6Str)
120		}
121
122		ipv6BigIntAddr := new(big.Int)
123		ipv6BigIntAddr.SetBytes(ipv6)
124
125		ipv6BigIntMask := new(big.Int)
126		ipv6BigIntMask.Set(ipv6HostMask)
127
128		return IPv6Addr{
129			Address: IPv6Address(ipv6BigIntAddr),
130			Mask:    IPv6Mask(ipv6BigIntMask),
131		}, nil
132	}
133
134	// Parse as an IPv6 CIDR
135	ipAddr, network, err := net.ParseCIDR(ipv6Str)
136	if err == nil {
137		ipv6 := ipAddr.To16()
138		if ipv6 == nil {
139			return IPv6Addr{}, fmt.Errorf("Unable to convert %+q to a 16byte IPv6 address", ipv6Str)
140		}
141
142		ipv6BigIntAddr := new(big.Int)
143		ipv6BigIntAddr.SetBytes(ipv6)
144
145		ipv6BigIntMask := new(big.Int)
146		ipv6BigIntMask.SetBytes(network.Mask)
147
148		ipv6Addr := IPv6Addr{
149			Address: IPv6Address(ipv6BigIntAddr),
150			Mask:    IPv6Mask(ipv6BigIntMask),
151		}
152		return ipv6Addr, nil
153	}
154
155	return IPv6Addr{}, fmt.Errorf("Unable to parse %+q to an IPv6 address: %v", ipv6Str, err)
156}
157
158// AddressBinString returns a string with the IPv6Addr's Address represented
159// as a sequence of '0' and '1' characters.  This method is useful for
160// debugging or by operators who want to inspect an address.
161func (ipv6 IPv6Addr) AddressBinString() string {
162	bi := big.Int(*ipv6.Address)
163	return fmt.Sprintf("%0128s", bi.Text(2))
164}
165
166// AddressHexString returns a string with the IPv6Addr address represented as
167// a sequence of hex characters.  This method is useful for debugging or by
168// operators who want to inspect an address.
169func (ipv6 IPv6Addr) AddressHexString() string {
170	bi := big.Int(*ipv6.Address)
171	return fmt.Sprintf("%032s", bi.Text(16))
172}
173
174// CmpAddress follows the Cmp() standard protocol and returns:
175//
176// - -1 If the receiver should sort first because its address is lower than arg
177// - 0 if the SockAddr arg equal to the receiving IPv6Addr or the argument is of a
178//   different type.
179// - 1 If the argument should sort first.
180func (ipv6 IPv6Addr) CmpAddress(sa SockAddr) int {
181	ipv6b, ok := sa.(IPv6Addr)
182	if !ok {
183		return sortDeferDecision
184	}
185
186	ipv6aBigInt := new(big.Int)
187	ipv6aBigInt.Set(ipv6.Address)
188	ipv6bBigInt := new(big.Int)
189	ipv6bBigInt.Set(ipv6b.Address)
190
191	return ipv6aBigInt.Cmp(ipv6bBigInt)
192}
193
194// CmpPort follows the Cmp() standard protocol and returns:
195//
196// - -1 If the receiver should sort first because its port is lower than arg
197// - 0 if the SockAddr arg's port number is equal to the receiving IPv6Addr,
198//   regardless of type.
199// - 1 If the argument should sort first.
200func (ipv6 IPv6Addr) CmpPort(sa SockAddr) int {
201	var saPort IPPort
202	switch v := sa.(type) {
203	case IPv4Addr:
204		saPort = v.Port
205	case IPv6Addr:
206		saPort = v.Port
207	default:
208		return sortDeferDecision
209	}
210
211	switch {
212	case ipv6.Port == saPort:
213		return sortDeferDecision
214	case ipv6.Port < saPort:
215		return sortReceiverBeforeArg
216	default:
217		return sortArgBeforeReceiver
218	}
219}
220
221// CmpRFC follows the Cmp() standard protocol and returns:
222//
223// - -1 If the receiver should sort first because it belongs to the RFC and its
224//   arg does not
225// - 0 if the receiver and arg both belong to the same RFC or neither do.
226// - 1 If the arg belongs to the RFC but receiver does not.
227func (ipv6 IPv6Addr) CmpRFC(rfcNum uint, sa SockAddr) int {
228	recvInRFC := IsRFC(rfcNum, ipv6)
229	ipv6b, ok := sa.(IPv6Addr)
230	if !ok {
231		// If the receiver is part of the desired RFC and the SockAddr
232		// argument is not, sort receiver before the non-IPv6 SockAddr.
233		// Conversely, if the receiver is not part of the RFC, punt on
234		// sorting and leave it for the next sorter.
235		if recvInRFC {
236			return sortReceiverBeforeArg
237		} else {
238			return sortDeferDecision
239		}
240	}
241
242	argInRFC := IsRFC(rfcNum, ipv6b)
243	switch {
244	case (recvInRFC && argInRFC), (!recvInRFC && !argInRFC):
245		// If a and b both belong to the RFC, or neither belong to
246		// rfcNum, defer sorting to the next sorter.
247		return sortDeferDecision
248	case recvInRFC && !argInRFC:
249		return sortReceiverBeforeArg
250	default:
251		return sortArgBeforeReceiver
252	}
253}
254
255// Contains returns true if the SockAddr is contained within the receiver.
256func (ipv6 IPv6Addr) Contains(sa SockAddr) bool {
257	ipv6b, ok := sa.(IPv6Addr)
258	if !ok {
259		return false
260	}
261
262	return ipv6.ContainsNetwork(ipv6b)
263}
264
265// ContainsAddress returns true if the IPv6Address is contained within the
266// receiver.
267func (ipv6 IPv6Addr) ContainsAddress(x IPv6Address) bool {
268	xAddr := IPv6Addr{
269		Address: x,
270		Mask:    ipv6HostMask,
271	}
272
273	{
274		xIPv6 := xAddr.FirstUsable().(IPv6Addr)
275		yIPv6 := ipv6.FirstUsable().(IPv6Addr)
276		if xIPv6.CmpAddress(yIPv6) >= 1 {
277			return false
278		}
279	}
280
281	{
282		xIPv6 := xAddr.LastUsable().(IPv6Addr)
283		yIPv6 := ipv6.LastUsable().(IPv6Addr)
284		if xIPv6.CmpAddress(yIPv6) <= -1 {
285			return false
286		}
287	}
288	return true
289}
290
291// ContainsNetwork returns true if the network from IPv6Addr is contained within
292// the receiver.
293func (x IPv6Addr) ContainsNetwork(y IPv6Addr) bool {
294	{
295		xIPv6 := x.FirstUsable().(IPv6Addr)
296		yIPv6 := y.FirstUsable().(IPv6Addr)
297		if ret := xIPv6.CmpAddress(yIPv6); ret >= 1 {
298			return false
299		}
300	}
301
302	{
303		xIPv6 := x.LastUsable().(IPv6Addr)
304		yIPv6 := y.LastUsable().(IPv6Addr)
305		if ret := xIPv6.CmpAddress(yIPv6); ret <= -1 {
306			return false
307		}
308	}
309	return true
310}
311
312// DialPacketArgs returns the arguments required to be passed to
313// net.DialUDP().  If the Mask of ipv6 is not a /128 or the Port is 0,
314// DialPacketArgs() will fail.  See Host() to create an IPv6Addr with its
315// mask set to /128.
316func (ipv6 IPv6Addr) DialPacketArgs() (network, dialArgs string) {
317	ipv6Mask := big.Int(*ipv6.Mask)
318	if ipv6Mask.Cmp(ipv6HostMask) != 0 || ipv6.Port == 0 {
319		return "udp6", ""
320	}
321	return "udp6", fmt.Sprintf("[%s]:%d", ipv6.NetIP().String(), ipv6.Port)
322}
323
324// DialStreamArgs returns the arguments required to be passed to
325// net.DialTCP().  If the Mask of ipv6 is not a /128 or the Port is 0,
326// DialStreamArgs() will fail.  See Host() to create an IPv6Addr with its
327// mask set to /128.
328func (ipv6 IPv6Addr) DialStreamArgs() (network, dialArgs string) {
329	ipv6Mask := big.Int(*ipv6.Mask)
330	if ipv6Mask.Cmp(ipv6HostMask) != 0 || ipv6.Port == 0 {
331		return "tcp6", ""
332	}
333	return "tcp6", fmt.Sprintf("[%s]:%d", ipv6.NetIP().String(), ipv6.Port)
334}
335
336// Equal returns true if a SockAddr is equal to the receiving IPv4Addr.
337func (ipv6a IPv6Addr) Equal(sa SockAddr) bool {
338	ipv6b, ok := sa.(IPv6Addr)
339	if !ok {
340		return false
341	}
342
343	if ipv6a.NetIP().String() != ipv6b.NetIP().String() {
344		return false
345	}
346
347	if ipv6a.NetIPNet().String() != ipv6b.NetIPNet().String() {
348		return false
349	}
350
351	if ipv6a.Port != ipv6b.Port {
352		return false
353	}
354
355	return true
356}
357
358// FirstUsable returns an IPv6Addr set to the first address following the
359// network prefix.  The first usable address in a network is normally the
360// gateway and should not be used except by devices forwarding packets
361// between two administratively distinct networks (i.e. a router).  This
362// function does not discriminate against first usable vs "first address that
363// should be used."  For example, FirstUsable() on "2001:0db8::0003/64" would
364// return "2001:0db8::00011".
365func (ipv6 IPv6Addr) FirstUsable() IPAddr {
366	return IPv6Addr{
367		Address: IPv6Address(ipv6.NetworkAddress()),
368		Mask:    ipv6HostMask,
369	}
370}
371
372// Host returns a copy of ipv6 with its mask set to /128 so that it can be
373// used by DialPacketArgs(), DialStreamArgs(), ListenPacketArgs(), or
374// ListenStreamArgs().
375func (ipv6 IPv6Addr) Host() IPAddr {
376	// Nothing should listen on a broadcast address.
377	return IPv6Addr{
378		Address: ipv6.Address,
379		Mask:    ipv6HostMask,
380		Port:    ipv6.Port,
381	}
382}
383
384// IPPort returns the Port number attached to the IPv6Addr
385func (ipv6 IPv6Addr) IPPort() IPPort {
386	return ipv6.Port
387}
388
389// LastUsable returns the last address in a given network.
390func (ipv6 IPv6Addr) LastUsable() IPAddr {
391	addr := new(big.Int)
392	addr.Set(ipv6.Address)
393
394	mask := new(big.Int)
395	mask.Set(ipv6.Mask)
396
397	negMask := new(big.Int)
398	negMask.Xor(ipv6HostMask, mask)
399
400	lastAddr := new(big.Int)
401	lastAddr.And(addr, mask)
402	lastAddr.Or(lastAddr, negMask)
403
404	return IPv6Addr{
405		Address: IPv6Address(lastAddr),
406		Mask:    ipv6HostMask,
407	}
408}
409
410// ListenPacketArgs returns the arguments required to be passed to
411// net.ListenUDP().  If the Mask of ipv6 is not a /128, ListenPacketArgs()
412// will fail.  See Host() to create an IPv6Addr with its mask set to /128.
413func (ipv6 IPv6Addr) ListenPacketArgs() (network, listenArgs string) {
414	ipv6Mask := big.Int(*ipv6.Mask)
415	if ipv6Mask.Cmp(ipv6HostMask) != 0 {
416		return "udp6", ""
417	}
418	return "udp6", fmt.Sprintf("[%s]:%d", ipv6.NetIP().String(), ipv6.Port)
419}
420
421// ListenStreamArgs returns the arguments required to be passed to
422// net.ListenTCP().  If the Mask of ipv6 is not a /128, ListenStreamArgs()
423// will fail.  See Host() to create an IPv6Addr with its mask set to /128.
424func (ipv6 IPv6Addr) ListenStreamArgs() (network, listenArgs string) {
425	ipv6Mask := big.Int(*ipv6.Mask)
426	if ipv6Mask.Cmp(ipv6HostMask) != 0 {
427		return "tcp6", ""
428	}
429	return "tcp6", fmt.Sprintf("[%s]:%d", ipv6.NetIP().String(), ipv6.Port)
430}
431
432// Maskbits returns the number of network mask bits in a given IPv6Addr.  For
433// example, the Maskbits() of "2001:0db8::0003/64" would return 64.
434func (ipv6 IPv6Addr) Maskbits() int {
435	maskOnes, _ := ipv6.NetIPNet().Mask.Size()
436
437	return maskOnes
438}
439
440// MustIPv6Addr is a helper method that must return an IPv6Addr or panic on
441// invalid input.
442func MustIPv6Addr(addr string) IPv6Addr {
443	ipv6, err := NewIPv6Addr(addr)
444	if err != nil {
445		panic(fmt.Sprintf("Unable to create an IPv6Addr from %+q: %v", addr, err))
446	}
447	return ipv6
448}
449
450// NetIP returns the address as a net.IP.
451func (ipv6 IPv6Addr) NetIP() *net.IP {
452	return bigIntToNetIPv6(ipv6.Address)
453}
454
455// NetIPMask create a new net.IPMask from the IPv6Addr.
456func (ipv6 IPv6Addr) NetIPMask() *net.IPMask {
457	ipv6Mask := make(net.IPMask, IPv6len)
458	m := big.Int(*ipv6.Mask)
459	copy(ipv6Mask, m.Bytes())
460	return &ipv6Mask
461}
462
463// Network returns a pointer to the net.IPNet within IPv4Addr receiver.
464func (ipv6 IPv6Addr) NetIPNet() *net.IPNet {
465	ipv6net := &net.IPNet{}
466	ipv6net.IP = make(net.IP, IPv6len)
467	copy(ipv6net.IP, *ipv6.NetIP())
468	ipv6net.Mask = *ipv6.NetIPMask()
469	return ipv6net
470}
471
472// Network returns the network prefix or network address for a given network.
473func (ipv6 IPv6Addr) Network() IPAddr {
474	return IPv6Addr{
475		Address: IPv6Address(ipv6.NetworkAddress()),
476		Mask:    ipv6.Mask,
477	}
478}
479
480// NetworkAddress returns an IPv6Network of the IPv6Addr's network address.
481func (ipv6 IPv6Addr) NetworkAddress() IPv6Network {
482	addr := new(big.Int)
483	addr.SetBytes((*ipv6.Address).Bytes())
484
485	mask := new(big.Int)
486	mask.SetBytes(*ipv6.NetIPMask())
487
488	netAddr := new(big.Int)
489	netAddr.And(addr, mask)
490
491	return IPv6Network(netAddr)
492}
493
494// Octets returns a slice of the 16 octets in an IPv6Addr's Address.  The
495// order of the bytes is big endian.
496func (ipv6 IPv6Addr) Octets() []int {
497	x := make([]int, IPv6len)
498	for i, b := range *bigIntToNetIPv6(ipv6.Address) {
499		x[i] = int(b)
500	}
501
502	return x
503}
504
505// String returns a string representation of the IPv6Addr
506func (ipv6 IPv6Addr) String() string {
507	if ipv6.Port != 0 {
508		return fmt.Sprintf("[%s]:%d", ipv6.NetIP().String(), ipv6.Port)
509	}
510
511	if ipv6.Maskbits() == 128 {
512		return ipv6.NetIP().String()
513	}
514
515	return fmt.Sprintf("%s/%d", ipv6.NetIP().String(), ipv6.Maskbits())
516}
517
518// Type is used as a type switch and returns TypeIPv6
519func (IPv6Addr) Type() SockAddrType {
520	return TypeIPv6
521}
522
523// IPv6Attrs returns a list of attributes supported by the IPv6Addr type
524func IPv6Attrs() []AttrName {
525	return ipv6AddrAttrs
526}
527
528// IPv6AddrAttr returns a string representation of an attribute for the given
529// IPv6Addr.
530func IPv6AddrAttr(ipv6 IPv6Addr, selector AttrName) string {
531	fn, found := ipv6AddrAttrMap[selector]
532	if !found {
533		return ""
534	}
535
536	return fn(ipv6)
537}
538
539// ipv6AddrInit is called once at init()
540func ipv6AddrInit() {
541	// Sorted for human readability
542	ipv6AddrAttrs = []AttrName{
543		"size", // Same position as in IPv6 for output consistency
544		"uint128",
545	}
546
547	ipv6AddrAttrMap = map[AttrName]func(ipv6 IPv6Addr) string{
548		"size": func(ipv6 IPv6Addr) string {
549			netSize := big.NewInt(1)
550			netSize = netSize.Lsh(netSize, uint(IPv6len*8-ipv6.Maskbits()))
551			return netSize.Text(10)
552		},
553		"uint128": func(ipv6 IPv6Addr) string {
554			b := big.Int(*ipv6.Address)
555			return b.Text(10)
556		},
557	}
558}
559
560// bigIntToNetIPv6 is a helper function that correctly returns a net.IP with the
561// correctly padded values.
562func bigIntToNetIPv6(bi *big.Int) *net.IP {
563	x := make(net.IP, IPv6len)
564	ipv6Bytes := bi.Bytes()
565
566	// It's possibe for ipv6Bytes to be less than IPv6len bytes in size.  If
567	// they are different sizes we to pad the size of response.
568	if len(ipv6Bytes) < IPv6len {
569		buf := new(bytes.Buffer)
570		buf.Grow(IPv6len)
571
572		for i := len(ipv6Bytes); i < IPv6len; i++ {
573			if err := binary.Write(buf, binary.BigEndian, byte(0)); err != nil {
574				panic(fmt.Sprintf("Unable to pad byte %d of input %v: %v", i, bi, err))
575			}
576		}
577
578		for _, b := range ipv6Bytes {
579			if err := binary.Write(buf, binary.BigEndian, b); err != nil {
580				panic(fmt.Sprintf("Unable to preserve endianness of input %v: %v", bi, err))
581			}
582		}
583
584		ipv6Bytes = buf.Bytes()
585	}
586	i := copy(x, ipv6Bytes)
587	if i != IPv6len {
588		panic("IPv6 wrong size")
589	}
590	return &x
591}
592