1// Copyright 2012 Google, Inc. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style license
4// that can be found in the LICENSE file in the root of the source
5// tree.
6
7// This binary provides an example of connecting up bidirectional streams from
8// the unidirectional streams provided by gopacket/tcpassembly.
9package main
10
11import (
12	"flag"
13	"fmt"
14	"github.com/google/gopacket"
15	"github.com/google/gopacket/examples/util"
16	"github.com/google/gopacket/layers"
17	"github.com/google/gopacket/pcap"
18	"github.com/google/gopacket/tcpassembly"
19	"log"
20	"time"
21)
22
// Command-line flags controlling the capture source, the BPF filter, and
// per-packet logging.
var (
	iface         = flag.String("i", "eth0", "Interface to get packets from")
	snaplen       = flag.Int("s", 16<<10, "SnapLen for pcap packet capture")
	filter        = flag.String("f", "tcp", "BPF filter for pcap")
	logAllPackets = flag.Bool("v", false, "Logs every packet in great detail")
)
27
28// key is used to map bidirectional streams to each other.
29type key struct {
30	net, transport gopacket.Flow
31}
32
33// String prints out the key in a human-readable fashion.
34func (k key) String() string {
35	return fmt.Sprintf("%v:%v", k.net, k.transport)
36}
37
// timeout is how long to wait before flushing idle connections and timing
// out bidirectional stream pairs that are still missing their second side.
const timeout = 5 * time.Minute
41
42// myStream implements tcpassembly.Stream
43type myStream struct {
44	bytes int64 // total bytes seen on this stream.
45	bidi  *bidi // maps to my bidirectional twin.
46	done  bool  // if true, we've seen the last packet we're going to for this stream.
47}
48
49// bidi stores each unidirectional side of a bidirectional stream.
50//
51// When a new stream comes in, if we don't have an opposite stream, a bidi is
52// created with 'a' set to the new stream.  If we DO have an opposite stream,
53// 'b' is set to the new stream.
54type bidi struct {
55	key            key       // Key of the first stream, mostly for logging.
56	a, b           *myStream // the two bidirectional streams.
57	lastPacketSeen time.Time // last time we saw a packet from either stream.
58}
59
60// myFactory implements tcpassmebly.StreamFactory
61type myFactory struct {
62	// bidiMap maps keys to bidirectional stream pairs.
63	bidiMap map[key]*bidi
64}
65
66// New handles creating a new tcpassembly.Stream.
67func (f *myFactory) New(netFlow, tcpFlow gopacket.Flow) tcpassembly.Stream {
68	// Create a new stream.
69	s := &myStream{}
70
71	// Find the bidi bidirectional struct for this stream, creating a new one if
72	// one doesn't already exist in the map.
73	k := key{netFlow, tcpFlow}
74	bd := f.bidiMap[k]
75	if bd == nil {
76		bd = &bidi{a: s, key: k}
77		log.Printf("[%v] created first side of bidirectional stream", bd.key)
78		// Register bidirectional with the reverse key, so the matching stream going
79		// the other direction will find it.
80		f.bidiMap[key{netFlow.Reverse(), tcpFlow.Reverse()}] = bd
81	} else {
82		log.Printf("[%v] found second side of bidirectional stream", bd.key)
83		bd.b = s
84		// Clear out the bidi we're using from the map, just in case.
85		delete(f.bidiMap, k)
86	}
87	s.bidi = bd
88	return s
89}
90
91// emptyStream is used to finish bidi that only have one stream, in
92// collectOldStreams.
93var emptyStream = &myStream{done: true}
94
95// collectOldStreams finds any streams that haven't received a packet within
96// 'timeout', and sets/finishes the 'b' stream inside them.  The 'a' stream may
97// still receive packets after this.
98func (f *myFactory) collectOldStreams() {
99	cutoff := time.Now().Add(-timeout)
100	for k, bd := range f.bidiMap {
101		if bd.lastPacketSeen.Before(cutoff) {
102			log.Printf("[%v] timing out old stream", bd.key)
103			bd.b = emptyStream   // stub out b with an empty stream.
104			delete(f.bidiMap, k) // remove it from our map.
105			bd.maybeFinish()     // if b was the last stream we were waiting for, finish up.
106		}
107	}
108}
109
110// Reassembled handles reassembled TCP stream data.
111func (s *myStream) Reassembled(rs []tcpassembly.Reassembly) {
112	for _, r := range rs {
113		// For now, we'll simply count the bytes on each side of the TCP stream.
114		s.bytes += int64(len(r.Bytes))
115		if r.Skip > 0 {
116			s.bytes += int64(r.Skip)
117		}
118		// Mark that we've received new packet data.
119		// We could just use time.Now, but by using r.Seen we handle the case
120		// where packets are being read from a file and could be very old.
121		if s.bidi.lastPacketSeen.Before(r.Seen) {
122			s.bidi.lastPacketSeen = r.Seen
123		}
124	}
125}
126
127// ReassemblyComplete marks this stream as finished.
128func (s *myStream) ReassemblyComplete() {
129	s.done = true
130	s.bidi.maybeFinish()
131}
132
133// maybeFinish will wait until both directions are complete, then print out
134// stats.
135func (bd *bidi) maybeFinish() {
136	switch {
137	case bd.a == nil:
138		log.Fatalf("[%v] a should always be non-nil, since it's set when bidis are created", bd.key)
139	case !bd.a.done:
140		log.Printf("[%v] still waiting on first stream", bd.key)
141	case bd.b == nil:
142		log.Printf("[%v] no second stream yet", bd.key)
143	case !bd.b.done:
144		log.Printf("[%v] still waiting on second stream", bd.key)
145	default:
146		log.Printf("[%v] FINISHED, bytes: %d tx, %d rx", bd.key, bd.a.bytes, bd.b.bytes)
147	}
148}
149
150func main() {
151	defer util.Run()()
152	log.Printf("starting capture on interface %q", *iface)
153	// Set up pcap packet capture
154	handle, err := pcap.OpenLive(*iface, int32(*snaplen), true, pcap.BlockForever)
155	if err != nil {
156		panic(err)
157	}
158	if err := handle.SetBPFFilter(*filter); err != nil {
159		panic(err)
160	}
161
162	// Set up assembly
163	streamFactory := &myFactory{bidiMap: make(map[key]*bidi)}
164	streamPool := tcpassembly.NewStreamPool(streamFactory)
165	assembler := tcpassembly.NewAssembler(streamPool)
166	// Limit memory usage by auto-flushing connection state if we get over 100K
167	// packets in memory, or over 1000 for a single stream.
168	assembler.MaxBufferedPagesTotal = 100000
169	assembler.MaxBufferedPagesPerConnection = 1000
170
171	log.Println("reading in packets")
172	// Read in packets, pass to assembler.
173	packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
174	packets := packetSource.Packets()
175	ticker := time.Tick(timeout / 4)
176	for {
177		select {
178		case packet := <-packets:
179			if *logAllPackets {
180				log.Println(packet)
181			}
182			if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || packet.TransportLayer().LayerType() != layers.LayerTypeTCP {
183				log.Println("Unusable packet")
184				continue
185			}
186			tcp := packet.TransportLayer().(*layers.TCP)
187			assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, packet.Metadata().Timestamp)
188
189		case <-ticker:
190			// Every minute, flush connections that haven't seen activity in the past minute.
191			log.Println("---- FLUSHING ----")
192			assembler.FlushOlderThan(time.Now().Add(-timeout))
193			streamFactory.collectOldStreams()
194		}
195	}
196}
197