1package capture
2
import (
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/pcap"
	"golang.org/x/sys/unix"
)
16
17const (
18	// ETHALL htons(ETH_P_ALL)
19	ETHALL uint16 = unix.ETH_P_ALL<<8 | unix.ETH_P_ALL>>8
20	// BLOCKSIZE ring buffer block_size
21	BLOCKSIZE = 64 << 10
22	// BLOCKNR ring buffer block_nr
23	BLOCKNR = (2 << 20) / BLOCKSIZE // 2mb / 64kb
24	// FRAMESIZE ring buffer frame_size
25	FRAMESIZE = BLOCKSIZE
26	// FRAMENR ring buffer frame_nr
27	FRAMENR = BLOCKNR * BLOCKSIZE / FRAMESIZE
28	// MAPHUGE2MB 2mb huge map
29	MAPHUGE2MB = 21 << unix.MAP_HUGE_SHIFT
30)
31
32var tpacket2hdrlen = tpAlign(int(unsafe.Sizeof(unix.Tpacket2Hdr{})))
33
// SockRaw is a Linux mmap'ed AF_PACKET socket (TPACKET_V2) that receives
// packets through a ring buffer shared with the kernel. Create it with
// NewSocket; the zero value is not usable.
//
// Every method except WritePacketData holds mu for its whole duration.
type SockRaw struct {
	mu          sync.Mutex // serializes the methods (see type comment)
	fd          int        // socket descriptor; -1 once Close has run
	ifindex     int        // interface index the socket was bound to
	snaplen     int        // max capture length (<= FRAMESIZE); used when compiling BPF filters
	pollTimeout uintptr    // timeout handed to poll(2); ^uintptr(0) is read by the kernel as -1 (block forever)
	frame       uint32     // current frame: index of the next ring frame to inspect
	buf         []byte     // points to the memory space of the ring buffer shared with the kernel.
	loopIndex   int32      // loopback ifindex; must be set to avoid reading each packet twice on a loopback device
}
45
46// NewSocket returns new M'maped sock_raw on packet version 2.
47func NewSocket(ifi net.Interface) (*SockRaw, error) {
48	// sock create
49	fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(ETHALL))
50	if err != nil {
51		return nil, err
52	}
53	sock := &SockRaw{
54		fd:          fd,
55		ifindex:     ifi.Index,
56		snaplen:     FRAMESIZE,
57		pollTimeout: ^uintptr(0),
58	}
59
60	// set packet version
61	err = unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_VERSION, unix.TPACKET_V2)
62	if err != nil {
63		unix.Close(fd)
64		return nil, fmt.Errorf("setsockopt packet_version: %v", err)
65	}
66
67	// bind to interface
68	addr := unix.RawSockaddrLinklayer{
69		Family:   unix.AF_PACKET,
70		Protocol: ETHALL,
71		Ifindex:  int32(ifi.Index),
72	}
73	_, _, e := unix.Syscall(
74		unix.SYS_BIND,
75		uintptr(fd),
76		uintptr(unsafe.Pointer(&addr)),
77		uintptr(unix.SizeofSockaddrLinklayer),
78	)
79	if e != 0 {
80		unix.Close(fd)
81		return nil, e
82	}
83
84	// create shared-memory ring buffer
85	tp := &unix.TpacketReq{
86		Block_size: BLOCKSIZE,
87		Block_nr:   BLOCKNR,
88		Frame_size: FRAMESIZE,
89		Frame_nr:   FRAMENR,
90	}
91	err = unix.SetsockoptTpacketReq(sock.fd, unix.SOL_PACKET, unix.PACKET_RX_RING, tp)
92	if err != nil {
93		unix.Close(fd)
94		return nil, fmt.Errorf("setsockopt packet_rx_ring: %v", err)
95	}
96	sock.buf, err = unix.Mmap(
97		sock.fd,
98		0,
99		BLOCKSIZE*BLOCKNR,
100		unix.PROT_READ|unix.PROT_WRITE,
101		unix.MAP_SHARED|MAPHUGE2MB,
102	)
103	if err != nil {
104		unix.Close(fd)
105		return nil, fmt.Errorf("socket mmap error: %v", err)
106	}
107	return sock, nil
108}
109
110// ReadPacketData implements gopacket.PacketDataSource.
111func (sock *SockRaw) ReadPacketData() (buf []byte, ci gopacket.CaptureInfo, err error) {
112	sock.mu.Lock()
113	defer sock.mu.Unlock()
114	var tpHdr *unix.Tpacket2Hdr
115	poll := &unix.PollFd{
116		Fd:     int32(sock.fd),
117		Events: unix.POLLIN,
118	}
119	var i int
120read:
121	i = int(sock.frame * FRAMESIZE)
122	tpHdr = (*unix.Tpacket2Hdr)(unsafe.Pointer(&sock.buf[i]))
123	sock.frame = (sock.frame + 1) % FRAMENR
124
125	if tpHdr.Status&unix.TP_STATUS_USER == 0 {
126		_, _, e := unix.Syscall(unix.SYS_POLL, uintptr(unsafe.Pointer(poll)), 1, sock.pollTimeout)
127		if e != 0 && e != unix.EINTR {
128			return buf, ci, e
129		}
130		// it might be some other frame with data!
131		if tpHdr.Status&unix.TP_STATUS_USER == 0 {
132			goto read
133		}
134	}
135	tpHdr.Status = unix.TP_STATUS_KERNEL
136	sockAddr := (*unix.RawSockaddrLinklayer)(unsafe.Pointer(&sock.buf[i+tpacket2hdrlen]))
137
138	// parse out repeating packets on loopback
139	if sockAddr.Ifindex == sock.loopIndex && sock.frame%2 != 0 {
140		goto read
141	}
142
143	ci.Length = int(tpHdr.Len)
144	ci.Timestamp = time.Unix(int64(tpHdr.Sec), int64(tpHdr.Nsec))
145	ci.InterfaceIndex = int(sockAddr.Ifindex)
146	buf = make([]byte, tpHdr.Snaplen)
147	ci.CaptureLength = copy(buf, sock.buf[i+int(tpHdr.Mac):])
148
149	return
150}
151
152// Close closes the underlying socket
153func (sock *SockRaw) Close() (err error) {
154	sock.mu.Lock()
155	defer sock.mu.Unlock()
156	if sock.fd != -1 {
157		unix.Munmap(sock.buf)
158		sock.buf = nil
159		err = unix.Close(sock.fd)
160		sock.fd = -1
161	}
162	return
163}
164
165// SetSnapLen sets the maximum capture length to the given value.
166// for this to take effects on the kernel level SetBPFilter should be called too.
167func (sock *SockRaw) SetSnapLen(snap int) error {
168	sock.mu.Lock()
169	defer sock.mu.Unlock()
170	if snap < 0 {
171		return fmt.Errorf("expected %d snap length to be at least 0", snap)
172	}
173	if snap > FRAMESIZE {
174		snap = FRAMESIZE
175	}
176	sock.snaplen = snap
177	return nil
178}
179
180// SetTimeout sets poll wait timeout for the socket.
181// negative value will block forever
182func (sock *SockRaw) SetTimeout(t time.Duration) error {
183	sock.mu.Lock()
184	defer sock.mu.Unlock()
185	sock.pollTimeout = uintptr(t)
186	return nil
187}
188
189// GetSnapLen returns the maximum capture length
190func (sock *SockRaw) GetSnapLen() int {
191	sock.mu.Lock()
192	defer sock.mu.Unlock()
193	return sock.snaplen
194}
195
196// SetBPFFilter compiles and sets a BPF filter for the socket handle.
197func (sock *SockRaw) SetBPFFilter(expr string) error {
198	sock.mu.Lock()
199	defer sock.mu.Unlock()
200	if len(expr) == 0 {
201		return unix.SetsockoptInt(sock.fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0)
202	}
203	filter, err := pcap.CompileBPFFilter(layers.LinkTypeEthernet, sock.snaplen, expr)
204	if err != nil {
205		return err
206	}
207	if len(filter) > int(^uint16(0)) {
208		return fmt.Errorf("filters out of range 0-%d", ^uint16(0))
209	}
210	if len(filter) == 0 {
211		return unix.SetsockoptInt(sock.fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0)
212	}
213	fprog := &unix.SockFprog{
214		Len:    uint16(len(filter)),
215		Filter: &(*(*[]unix.SockFilter)(unsafe.Pointer(&filter)))[0],
216	}
217	return unix.SetsockoptSockFprog(sock.fd, unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, fprog)
218}
219
220// SetPromiscuous sets promiscous mode to the required value. for better result capture on all interfaces instead.
221// If it is enabled, traffic not destined for the interface will also be captured.
222func (sock *SockRaw) SetPromiscuous(b bool) error {
223	sock.mu.Lock()
224	defer sock.mu.Unlock()
225	mreq := unix.PacketMreq{
226		Ifindex: int32(sock.ifindex),
227		Type:    unix.PACKET_MR_PROMISC,
228	}
229
230	opt := unix.PACKET_ADD_MEMBERSHIP
231	if !b {
232		opt = unix.PACKET_DROP_MEMBERSHIP
233	}
234
235	return unix.SetsockoptPacketMreq(sock.fd, unix.SOL_PACKET, opt, &mreq)
236}
237
238// Stats returns number of packets and dropped packets. This will be the number of packets/dropped packets since the last call to stats (not the cummulative sum!).
239func (sock *SockRaw) Stats() (*unix.TpacketStats, error) {
240	sock.mu.Lock()
241	defer sock.mu.Unlock()
242	return unix.GetsockoptTpacketStats(sock.fd, unix.SOL_PACKET, unix.PACKET_STATISTICS)
243}
244
245// SetLoopbackIndex necessary to avoid reading packet twice on a loopback device
246func (sock *SockRaw) SetLoopbackIndex(i int32) {
247	sock.mu.Lock()
248	defer sock.mu.Unlock()
249	sock.loopIndex = i
250}
251
252// WritePacketData transmits a raw packet.
253func (sock *SockRaw) WritePacketData(pkt []byte) error {
254	_, err := unix.Write(sock.fd, pkt)
255	return err
256}
257
258func tpAlign(x int) int {
259	return int((uint(x) + unix.TPACKET_ALIGNMENT - 1) &^ (unix.TPACKET_ALIGNMENT - 1))
260}
261