1package netlink
2
import (
	"errors"
	"fmt"
	"net"
	"syscall"

	"github.com/vishvananda/netlink/nl"
	"golang.org/x/sys/unix"
)
11
// Wire sizes, in bytes, of the socket diag payloads handled in this file.
//
//   - sizeofSocketID: source and destination ports (2 bytes each), source and
//     destination address slots (16 bytes each), interface index (4 bytes)
//     and a two-word cookie (8 bytes).
//   - sizeofSocketRequest: family/protocol/ext/pad bytes and a 4-byte States
//     word, followed by the socket ID.
//   - sizeofSocket: family/state/timer/retrans bytes, the socket ID, then five
//     4-byte counters (expires, rqueue, wqueue, uid, inode).
const (
	sizeofSocketID      = 2 + 2 + 16 + 16 + 4 + 8
	sizeofSocketRequest = sizeofSocketID + 4 + 4
	sizeofSocket        = sizeofSocketID + 4 + 5*4
)
17
18type socketRequest struct {
19	Family   uint8
20	Protocol uint8
21	Ext      uint8
22	pad      uint8
23	States   uint32
24	ID       SocketID
25}
26
// writeBuffer fills a pre-allocated byte slice front to back. pos is the
// offset of the next byte to be written. Writing past the end of Bytes panics.
type writeBuffer struct {
	Bytes []byte
	pos   int
}

// Write stores c at the current offset and advances the offset by one.
func (b *writeBuffer) Write(c byte) {
	at := b.pos
	b.Bytes[at] = c
	b.pos = at + 1
}

// Next returns the n-byte window starting at the current offset, for the
// caller to fill in place, and advances the offset past it.
func (b *writeBuffer) Next(n int) []byte {
	end := b.pos + n
	window := b.Bytes[b.pos:end]
	b.pos = end
	return window
}
42
43func (r *socketRequest) Serialize() []byte {
44	b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)}
45	b.Write(r.Family)
46	b.Write(r.Protocol)
47	b.Write(r.Ext)
48	b.Write(r.pad)
49	native.PutUint32(b.Next(4), r.States)
50	networkOrder.PutUint16(b.Next(2), r.ID.SourcePort)
51	networkOrder.PutUint16(b.Next(2), r.ID.DestinationPort)
52	copy(b.Next(4), r.ID.Source.To4())
53	b.Next(12)
54	copy(b.Next(4), r.ID.Destination.To4())
55	b.Next(12)
56	native.PutUint32(b.Next(4), r.ID.Interface)
57	native.PutUint32(b.Next(4), r.ID.Cookie[0])
58	native.PutUint32(b.Next(4), r.ID.Cookie[1])
59	return b.Bytes
60}
61
62func (r *socketRequest) Len() int { return sizeofSocketRequest }
63
// readBuffer consumes a byte slice front to back. pos is the offset of the
// next unread byte. Reading past the end of Bytes panics.
type readBuffer struct {
	Bytes []byte
	pos   int
}

// Read returns the byte at the current offset and advances past it.
func (b *readBuffer) Read() byte {
	at := b.pos
	c := b.Bytes[at]
	b.pos = at + 1
	return c
}

// Next returns the n bytes starting at the current offset and advances past
// them. The returned slice aliases Bytes.
func (b *readBuffer) Next(n int) []byte {
	end := b.pos + n
	window := b.Bytes[b.pos:end]
	b.pos = end
	return window
}
80
81func (s *Socket) deserialize(b []byte) error {
82	if len(b) < sizeofSocket {
83		return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket)
84	}
85	rb := readBuffer{Bytes: b}
86	s.Family = rb.Read()
87	s.State = rb.Read()
88	s.Timer = rb.Read()
89	s.Retrans = rb.Read()
90	s.ID.SourcePort = networkOrder.Uint16(rb.Next(2))
91	s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2))
92	s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
93	rb.Next(12)
94	s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
95	rb.Next(12)
96	s.ID.Interface = native.Uint32(rb.Next(4))
97	s.ID.Cookie[0] = native.Uint32(rb.Next(4))
98	s.ID.Cookie[1] = native.Uint32(rb.Next(4))
99	s.Expires = native.Uint32(rb.Next(4))
100	s.RQueue = native.Uint32(rb.Next(4))
101	s.WQueue = native.Uint32(rb.Next(4))
102	s.UID = native.Uint32(rb.Next(4))
103	s.INode = native.Uint32(rb.Next(4))
104	return nil
105}
106
107// SocketGet returns the Socket identified by its local and remote addresses.
108func SocketGet(local, remote net.Addr) (*Socket, error) {
109	localTCP, ok := local.(*net.TCPAddr)
110	if !ok {
111		return nil, ErrNotImplemented
112	}
113	remoteTCP, ok := remote.(*net.TCPAddr)
114	if !ok {
115		return nil, ErrNotImplemented
116	}
117	localIP := localTCP.IP.To4()
118	if localIP == nil {
119		return nil, ErrNotImplemented
120	}
121	remoteIP := remoteTCP.IP.To4()
122	if remoteIP == nil {
123		return nil, ErrNotImplemented
124	}
125
126	s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
127	if err != nil {
128		return nil, err
129	}
130	defer s.Close()
131	req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
132	req.AddData(&socketRequest{
133		Family:   unix.AF_INET,
134		Protocol: unix.IPPROTO_TCP,
135		ID: SocketID{
136			SourcePort:      uint16(localTCP.Port),
137			DestinationPort: uint16(remoteTCP.Port),
138			Source:          localIP,
139			Destination:     remoteIP,
140			Cookie:          [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
141		},
142	})
143	s.Send(req)
144	msgs, from, err := s.Receive()
145	if err != nil {
146		return nil, err
147	}
148	if from.Pid != nl.PidKernel {
149		return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
150	}
151	if len(msgs) == 0 {
152		return nil, errors.New("no message nor error from netlink")
153	}
154	if len(msgs) > 2 {
155		return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))
156	}
157	sock := &Socket{}
158	if err := sock.deserialize(msgs[0].Data); err != nil {
159		return nil, err
160	}
161	return sock, nil
162}
163