1package capture
2
3import (
4	"context"
5	"encoding/binary"
6	"io/ioutil"
7	"net"
8	"os"
9	"testing"
10	"time"
11
12	"github.com/google/gopacket"
13	"github.com/google/gopacket/layers"
14	"github.com/google/gopacket/pcap"
15)
16
// LoopBack is the first interface reported with the loopback flag. If no
// interface carries that flag it falls back to the first interface listed.
// If the interface list cannot be read or is empty, it is the zero
// net.Interface rather than a panic during package initialisation.
var LoopBack = func() net.Interface {
	ifis, err := net.Interfaces()
	if err != nil || len(ifis) == 0 {
		return net.Interface{}
	}
	for _, v := range ifis {
		if v.Flags&net.FlagLoopback != 0 {
			return v
		}
	}
	return ifis[0]
}()
26
27func TestSetInterfaces(t *testing.T) {
28	l := &Listener{}
29	l.host = "127.0.0.1"
30	l.setInterfaces()
31	if len(l.Interfaces) != 1 {
32		t.Error("expected a single interface")
33	}
34	l.host = LoopBack.HardwareAddr.String()
35	l.setInterfaces()
36	if l.Interfaces[0].Name != LoopBack.Name && len(l.Interfaces) != 1 {
37		t.Error("interface should be loop back interface")
38	}
39	l.host = ""
40	l.setInterfaces()
41	if len(l.Interfaces) < 1 {
42		t.Error("should get all interfaces")
43	}
44}
45
46func TestBPFFilter(t *testing.T) {
47	l := &Listener{}
48	l.host = "127.0.0.1"
49	l.Transport = "tcp"
50	l.setInterfaces()
51	filter := l.Filter(l.Interfaces[0])
52	if filter != "(tcp dst portrange 0-65535 and host 127.0.0.1)" {
53		t.Error("wrong filter", filter)
54	}
55	l.port = 8000
56	l.trackResponse = true
57	filter = l.Filter(l.Interfaces[0])
58	if filter != "(tcp port 8000 and host 127.0.0.1)" {
59		t.Error("wrong filter")
60	}
61}
62
63var decodeOpts = gopacket.DecodeOptions{Lazy: true, NoCopy: true}
64
65func generateHeaders(seq uint32, length uint16) (headers [44]byte) {
66	// set ethernet headers
67	binary.BigEndian.PutUint32(headers[0:4], uint32(layers.ProtocolFamilyIPv4))
68
69	// set ip header
70	ip := headers[4:]
71	copy(ip[0:2], []byte{4<<4 | 5, 0x28<<2 | 0x00})
72	binary.BigEndian.PutUint16(ip[2:4], length+54)
73	ip[9] = uint8(layers.IPProtocolTCP)
74	copy(ip[12:16], []byte{127, 0, 0, 1})
75	copy(ip[16:], []byte{127, 0, 0, 1})
76
77	// set tcp header
78	tcp := ip[20:]
79	binary.BigEndian.PutUint16(tcp[0:2], 45678)
80	binary.BigEndian.PutUint16(tcp[2:4], 8000)
81	tcp[12] = 5 << 4
82	return
83}
84
85func randomPackets(start uint32, _len int, length uint16) []gopacket.Packet {
86	var packets = make([]gopacket.Packet, _len)
87	for i := start; i < start+uint32(_len); i++ {
88		h := generateHeaders(i, length)
89		d := make([]byte, int(length)+len(h))
90		copy(d, h[0:])
91		packet := gopacket.NewPacket(d, layers.LinkTypeLoop, decodeOpts)
92		packets[i-start] = packet
93		inf := packets[i-start].Metadata()
94		_len := len(d)
95		inf.CaptureInfo = gopacket.CaptureInfo{CaptureLength: _len, Length: _len, Timestamp: time.Now()}
96	}
97	return packets
98}
99
100func TestPcapDump(t *testing.T) {
101	f, err := ioutil.TempFile("", "pcap_file")
102	if err != nil {
103		t.Error(err)
104	}
105	waiter := make(chan bool, 1)
106	h, _ := PcapDumpHandler(f, layers.LinkTypeLoop, func(level int, a ...interface{}) {
107		if level != 3 {
108			t.Errorf("expected debug level to be 3, got %d", level)
109		}
110		waiter <- true
111	})
112	packets := randomPackets(1, 5, 5)
113	for i := 0; i < len(packets); i++ {
114		if i == 1 {
115			tcp := packets[i].Data()[4:][20:]
116			// change dst port
117			binary.BigEndian.PutUint16(tcp[2:], 8001)
118		}
119		if i == 4 {
120			inf := packets[i].Metadata()
121			inf.CaptureLength = 40
122		}
123		h(packets[i])
124	}
125	<-waiter
126	name := f.Name()
127	f.Close()
128	testPcapDumpEngine(name, t)
129}
130
131func testPcapDumpEngine(f string, t *testing.T) {
132	defer os.Remove(f)
133	l, err := NewListener(f, 8000, "", EnginePcapFile, true)
134	err = l.Activate()
135	if err != nil {
136		t.Errorf("expected error to be nil, got %q", err)
137		return
138	}
139	pckts := 0
140	ctx, cancel := context.WithCancel(context.Background())
141	defer cancel()
142	err = l.Listen(ctx, func(packet gopacket.Packet) {
143		if packet.Metadata().CaptureLength != 49 {
144			t.Errorf("expected packet length to be %d, got %d", 49, packet.Metadata().CaptureLength)
145		}
146		pckts++
147	})
148
149	if err != nil {
150		t.Errorf("expected error to be nil, got %q", err)
151	}
152	if pckts != 3 {
153		t.Errorf("expected %d packets, got %d packets", 3, pckts)
154	}
155}
156
157func TestPcapHandler(t *testing.T) {
158	l, err := NewListener(LoopBack.Name, 8000, "", EnginePcap, true)
159	if err != nil {
160		t.Errorf("expected error to be nil, got %v", err)
161		return
162	}
163	err = l.Activate()
164	if err != nil {
165		t.Errorf("expected error to be nil, got %v", err)
166		return
167	}
168	defer l.Handles[LoopBack.Name].(*pcap.Handle).Close()
169	if err != nil {
170		t.Errorf("expected error to be nil, got %v", err)
171		return
172	}
173	for i := 0; i < 5; i++ {
174		_, _ = net.Dial("tcp", "127.0.0.1:8000")
175	}
176	sts, _ := l.Handles[LoopBack.Name].(*pcap.Handle).Stats()
177	if sts.PacketsReceived < 5 {
178		t.Errorf("expected >=5 packets got %d", sts.PacketsReceived)
179	}
180}
181
182func TestSocketHandler(t *testing.T) {
183	l, err := NewListener(LoopBack.Name, 8000, "", EngineRawSocket, true)
184	err = l.Activate()
185	if err != nil {
186		return
187	}
188	defer l.Handles[LoopBack.Name].(*SockRaw).Close()
189	if err != nil {
190		t.Errorf("expected error to be nil, got %v", err)
191		return
192	}
193	for i := 0; i < 5; i++ {
194		_, _ = net.Dial("tcp", "127.0.0.1:8000")
195	}
196	sts, _ := l.Handles[LoopBack.Name].(*SockRaw).Stats()
197	if sts.Packets < 5 {
198		t.Errorf("expected >=5 packets got %d", sts.Packets)
199	}
200}
201
202func BenchmarkPcapDump(b *testing.B) {
203	f, err := ioutil.TempFile("", "pcap_file")
204	if err != nil {
205		b.Error(err)
206		return
207	}
208	now := time.Now()
209	defer os.Remove(f.Name())
210	h, _ := PcapDumpHandler(f, layers.LinkTypeLoop, nil)
211	packets := randomPackets(1, b.N, 5)
212	for i := 0; i < len(packets); i++ {
213		h(packets[i])
214	}
215	f.Close()
216	b.Logf("%d packets in %s", b.N, time.Since(now))
217}
218
219func BenchmarkPcapFile(b *testing.B) {
220	f, err := ioutil.TempFile("", "pcap_file")
221	if err != nil {
222		b.Error(err)
223		return
224	}
225	defer os.Remove(f.Name())
226	h, _ := PcapDumpHandler(f, layers.LinkTypeLoop, nil)
227	packets := randomPackets(1, b.N, 5)
228	for i := 0; i < len(packets); i++ {
229		h(packets[i])
230	}
231	name := f.Name()
232	f.Close()
233	b.ResetTimer()
234	var l *Listener
235	l, err = NewListener(name, 8000, "", EnginePcapFile, true)
236	if err != nil {
237		b.Error(err)
238		return
239	}
240	err = l.Activate()
241	if err != nil {
242		b.Error(err)
243		return
244	}
245	now := time.Now()
246	pckts := 0
247	ctx, cancel := context.WithCancel(context.Background())
248	defer cancel()
249	if err = l.Listen(ctx, func(packet gopacket.Packet) {
250		if packet.Metadata().CaptureLength != 49 {
251			b.Errorf("expected packet length to be %d, got %d", 49, packet.Metadata().CaptureLength)
252		}
253		pckts++
254	}); err != nil {
255		b.Error(err)
256	}
257	b.Logf("%d/%d packets in %s", pckts, b.N, time.Since(now))
258}
259
260func BenchmarkPcap(b *testing.B) {
261	now := time.Now()
262	var err error
263
264	l, err := NewListener(LoopBack.Name, 8000, "", EnginePcap, true)
265	if err != nil {
266		b.Errorf("expected error to be nil, got %v", err)
267		return
268	}
269	err = l.Activate()
270	if err != nil {
271		b.Errorf("expected error to be nil, got %v", err)
272		return
273	}
274	defer l.Handles[LoopBack.Name].(*pcap.Handle).Close()
275	for i := 0; i < b.N; i++ {
276		_, _ = net.Dial("tcp", "127.0.0.1:8000")
277	}
278	sts, _ := l.Handles[LoopBack.Name].(*pcap.Handle).Stats()
279	b.Logf("%d packets in %s", sts.PacketsReceived, time.Since(now))
280}
281
282func BenchmarkRawSocket(b *testing.B) {
283	now := time.Now()
284	var err error
285
286	l, err := NewListener(LoopBack.Name, 0, "", EngineRawSocket, true)
287	err = l.Activate()
288	if err != nil {
289		return
290	}
291	defer l.Handles[LoopBack.Name].(*SockRaw).Close()
292	for i := 0; i < b.N; i++ {
293		_, _ = net.Dial("tcp", "127.0.0.1:8000")
294	}
295	sts, _ := l.Handles[LoopBack.Name].(*SockRaw).Stats()
296	b.Logf("%d packets in %s", sts.Packets, time.Since(now))
297}
298