1// Copyright 2012 Google, Inc. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style license
4// that can be found in the LICENSE file in the root of the source
5// tree.
6
7// Package reassembly provides TCP stream re-assembly.
8//
9// The reassembly package implements uni-directional TCP reassembly, for use in
10// packet-sniffing applications.  The caller reads packets off the wire, then
11// presents them to an Assembler in the form of gopacket layers.TCP packets
12// (github.com/google/gopacket, github.com/google/gopacket/layers).
13//
14// The Assembler uses a user-supplied
15// StreamFactory to create a user-defined Stream interface, then passes packet
16// data in stream order to that object.  A concurrency-safe StreamPool keeps
17// track of all current Streams being reassembled, so multiple Assemblers may
18// run at once to assemble packets while taking advantage of multiple cores.
19//
20// TODO: Add simplest example
21package reassembly
22
23import (
24	"encoding/hex"
25	"flag"
26	"fmt"
27	"log"
28	"sync"
29	"time"
30
31	"github.com/google/gopacket"
32	"github.com/google/gopacket/layers"
33)
34
35// TODO:
36// - push to Stream on Ack
37// - implement chunked (cheap) reads and Reader() interface
38// - better organize file: split files: 'mem', 'misc' (seq + flow)
39
// defaultDebug is the default value of the assembly_debug_log flag.
var defaultDebug = false

// debugLog reports whether verbose debug logging is enabled.  It is bound to
// the assembly_debug_log command-line flag and is dereferenced before every
// debug log statement in this file.
var debugLog = flag.Bool("assembly_debug_log", defaultDebug, "If true, the github.com/google/gopacket/reassembly library will log verbose debugging information (at least one line per packet)")
43
// invalidSequence marks a sequence number that has not been set yet.
const invalidSequence = -1

// uint32Max is the largest valid TCP sequence number.
const uint32Max = 0xFFFFFFFF

// Sequence is a TCP sequence number.  It provides a few convenience functions
// for handling TCP wrap-around.  The sequence should always be in the range
// [0,0xFFFFFFFF]... its other bits are simply used in wrap-around calculations
// and should never be set.
type Sequence int64

// Difference defines an ordering for comparing TCP sequences that's safe for
// roll-overs.  It returns:
//    > 0 : if t comes after s
//    < 0 : if t comes before s
//      0 : if t == s
// The number returned is the sequence difference, so 4.Difference(8) will
// return 4.
//
// It handles rollovers by considering any sequence in the first quarter of the
// uint32 space to be after any sequence in the last quarter of that space, thus
// wrapping the uint32 space.  The wrap is done modulo 2^32 so that it agrees
// with Add: s.Difference(s.Add(n)) == n even when the addition rolls over.
func (s Sequence) Difference(t Sequence) int {
	// Size of the sequence space.  Shifting by uint32Max (2^32-1) instead
	// would make every difference across the roll-over off by one.
	const wrap = uint32Max + 1
	if s > uint32Max-uint32Max/4 && t < uint32Max/4 {
		t += wrap
	} else if t > uint32Max-uint32Max/4 && s < uint32Max/4 {
		s += wrap
	}
	return int(t - s)
}

// Add adds an integer to a sequence and returns the resulting sequence,
// wrapped into the range [0,0xFFFFFFFF].
func (s Sequence) Add(t int) Sequence {
	return (s + Sequence(t)) & uint32Max
}
77
// TCPAssemblyStats provides some figures for a ScatterGather.
type TCPAssemblyStats struct {
	// For this ScatterGather

	// Chunks is the number of byte containers in the ScatterGather.
	Chunks int
	// Packets is the number of those containers that report isPacket().
	Packets int

	// For the half connection, since last call to ReassembledSG()

	// QueuedBytes is the number of payload bytes that were queued.
	QueuedBytes int
	// QueuedPackets is the number of packets that were queued.
	QueuedPackets int
	// OverlapBytes is the number of bytes found to overlap other data.
	OverlapBytes int
	// OverlapPackets is the number of packets found to overlap other data.
	OverlapPackets int
}
89
// ScatterGather is used to pass reassembled data and metadata of reassembled
// packets to a Stream via ReassembledSG.
type ScatterGather interface {
	// Lengths returns the number of available bytes and the number of saved
	// bytes.
	Lengths() (int, int)
	// Fetch returns the first length bytes (length shall be <= available
	// bytes).
	Fetch(length int) []byte
	// KeepFrom asks to keep the bytes starting at offset.
	KeepFrom(offset int)
	// CaptureInfo returns the CaptureInfo of the packet corresponding to the
	// given offset.
	CaptureInfo(offset int) gopacket.CaptureInfo
	// Info returns some information about the reassembled chunks: the
	// direction, whether the first chunk is a start, whether the last chunk
	// is an end, and the skip count.
	Info() (direction TCPFlowDirection, start bool, end bool, skip int)
	// Stats returns some statistics regarding the state of the stream.
	Stats() TCPAssemblyStats
}
106
// byteContainer is either a page or a livePacket.
type byteContainer interface {
	// getBytes returns the payload held by the container.
	getBytes() []byte
	// length returns the number of payload bytes.
	length() int
	// convertToPages turns the container into a list of pages, dropping the
	// first skip bytes; it returns the first page, the last page and the
	// number of pages.
	convertToPages(*pageCache, int, AssemblerContext) (*page, *page, int)
	// captureInfo returns the capture metadata attached to the container.
	captureInfo() gopacket.CaptureInfo
	// assemblerContext returns the AssemblerContext attached to the container.
	assemblerContext() AssemblerContext
	// release gives the container back to the page cache and returns the
	// number of pages released.
	release(*pageCache) int
	// isStart reports whether the container is flagged as a stream start.
	isStart() bool
	// isEnd reports whether the container is flagged as a stream end.
	isEnd() bool
	// getSeq returns the sequence number of the first payload byte.
	getSeq() Sequence
	// isPacket reports whether the container counts as a packet in the stats.
	isPacket() bool
}
120
// reassemblyObject implements a ScatterGather.
type reassemblyObject struct {
	all       []byteContainer  // chunks, in stream order
	Skip      int              // value reported as skip by Info()
	Direction TCPFlowDirection // value reported as direction by Info()
	saved     int              // value reported as saved bytes by Lengths()
	toKeep    int              // offset requested through KeepFrom()
	// stats
	queuedBytes    int
	queuedPackets  int
	overlapBytes   int
	overlapPackets int
}
134
135func (rl *reassemblyObject) Lengths() (int, int) {
136	l := 0
137	for _, r := range rl.all {
138		l += r.length()
139	}
140	return l, rl.saved
141}
142
143func (rl *reassemblyObject) Fetch(l int) []byte {
144	if l <= rl.all[0].length() {
145		return rl.all[0].getBytes()[:l]
146	}
147	bytes := make([]byte, 0, l)
148	for _, bc := range rl.all {
149		bytes = append(bytes, bc.getBytes()...)
150	}
151	return bytes[:l]
152}
153
// KeepFrom records the offset from which the data should be kept.
func (rl *reassemblyObject) KeepFrom(offset int) {
	rl.toKeep = offset
}
157
158func (rl *reassemblyObject) CaptureInfo(offset int) gopacket.CaptureInfo {
159	current := 0
160	var r byteContainer
161	for _, r = range rl.all {
162		if current >= offset {
163			return r.captureInfo()
164		}
165		current += r.length()
166	}
167	if r != nil && current >= offset {
168		return r.captureInfo()
169	}
170	// Invalid offset
171	return gopacket.CaptureInfo{}
172}
173
174func (rl *reassemblyObject) Info() (TCPFlowDirection, bool, bool, int) {
175	return rl.Direction, rl.all[0].isStart(), rl.all[len(rl.all)-1].isEnd(), rl.Skip
176}
177
178func (rl *reassemblyObject) Stats() TCPAssemblyStats {
179	packets := int(0)
180	for _, r := range rl.all {
181		if r.isPacket() {
182			packets++
183		}
184	}
185	return TCPAssemblyStats{
186		Chunks:         len(rl.all),
187		Packets:        packets,
188		QueuedBytes:    rl.queuedBytes,
189		QueuedPackets:  rl.queuedPackets,
190		OverlapBytes:   rl.overlapBytes,
191		OverlapPackets: rl.overlapPackets,
192	}
193}
194
// pageBytes is the size of the buffer embedded in each page, i.e. the maximum
// number of payload bytes a single page can hold.
const pageBytes = 1900
196
// TCPFlowDirection distinguishes the two half-connection directions.
//
// TCPDirClientToServer is assigned to the half-connection of the first
// received packet, hence it might be wrong if packets are not received in
// order.  It's up to the caller (e.g. in Accept()) to decide if the direction
// should be interpreted differently.
type TCPFlowDirection bool

// The two directions.  The underlying values carry no meaning of their own.
const (
	TCPDirClientToServer TCPFlowDirection = false
	TCPDirServerToClient TCPFlowDirection = true
)

// String returns a human-readable name for the direction.
func (dir TCPFlowDirection) String() string {
	if dir == TCPDirServerToClient {
		return "server->client"
	}
	return "client->server"
}

// Reverse returns the reversed direction.
func (dir TCPFlowDirection) Reverse() TCPFlowDirection {
	if dir == TCPDirClientToServer {
		return TCPDirServerToClient
	}
	return TCPDirClientToServer
}
225
226/* page: implements a byteContainer */
227
228// page is used to store TCP data we're not ready for yet (out-of-order
229// packets).  Unused pages are stored in and returned from a pageCache, which
230// avoids memory allocation.  Used pages are stored in a doubly-linked list in
231// a connection.
232type page struct {
233	bytes      []byte
234	seq        Sequence
235	prev, next *page
236	buf        [pageBytes]byte
237	ac         AssemblerContext // only set for the first page of a packet
238	seen       time.Time
239	start, end bool
240}
241
242func (p *page) getBytes() []byte {
243	return p.bytes
244}
245func (p *page) captureInfo() gopacket.CaptureInfo {
246	return p.ac.GetCaptureInfo()
247}
248func (p *page) assemblerContext() AssemblerContext {
249	return p.ac
250}
251func (p *page) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
252	if skip != 0 {
253		p.bytes = p.bytes[skip:]
254		p.seq = p.seq.Add(skip)
255	}
256	p.prev, p.next = nil, nil
257	return p, p, 1
258}
259func (p *page) length() int {
260	return len(p.bytes)
261}
262func (p *page) release(pc *pageCache) int {
263	pc.replace(p)
264	return 1
265}
266func (p *page) isStart() bool {
267	return p.start
268}
269func (p *page) isEnd() bool {
270	return p.end
271}
272func (p *page) getSeq() Sequence {
273	return p.seq
274}
275func (p *page) isPacket() bool {
276	return p.ac != nil
277}
278func (p *page) String() string {
279	return fmt.Sprintf("page@%p{seq: %v, bytes:%d, -> nextSeq:%v} (prev:%p, next:%p)", p, p.seq, len(p.bytes), p.seq+Sequence(len(p.bytes)), p.prev, p.next)
280}
281
282/* livePacket: implements a byteContainer */
283type livePacket struct {
284	bytes []byte
285	start bool
286	end   bool
287	ci    gopacket.CaptureInfo
288	ac    AssemblerContext
289	seq   Sequence
290}
291
292func (lp *livePacket) getBytes() []byte {
293	return lp.bytes
294}
295func (lp *livePacket) captureInfo() gopacket.CaptureInfo {
296	return lp.ci
297}
298func (lp *livePacket) assemblerContext() AssemblerContext {
299	return lp.ac
300}
301func (lp *livePacket) length() int {
302	return len(lp.bytes)
303}
304func (lp *livePacket) isStart() bool {
305	return lp.start
306}
307func (lp *livePacket) isEnd() bool {
308	return lp.end
309}
310func (lp *livePacket) getSeq() Sequence {
311	return lp.seq
312}
313func (lp *livePacket) isPacket() bool {
314	return true
315}
316
317// Creates a page (or set of pages) from a TCP packet: returns the first and last
318// page in its doubly-linked list of new pages.
319func (lp *livePacket) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
320	ts := lp.ci.Timestamp
321	first := pc.next(ts)
322	current := first
323	current.prev = nil
324	first.ac = ac
325	numPages := 1
326	seq, bytes := lp.seq.Add(skip), lp.bytes[skip:]
327	for {
328		length := min(len(bytes), pageBytes)
329		current.bytes = current.buf[:length]
330		copy(current.bytes, bytes)
331		current.seq = seq
332		bytes = bytes[length:]
333		if len(bytes) == 0 {
334			current.end = lp.isEnd()
335			current.next = nil
336			break
337		}
338		seq = seq.Add(length)
339		current.next = pc.next(ts)
340		current.next.prev = current
341		current = current.next
342		current.ac = nil
343		numPages++
344	}
345	return first, current, numPages
346}
347func (lp *livePacket) estimateNumberOfPages() int {
348	return (len(lp.bytes) + pageBytes + 1) / pageBytes
349}
350
351func (lp *livePacket) release(*pageCache) int {
352	return 0
353}
354
// Stream is implemented by the caller to handle incoming reassembled
// TCP data.  Callers create a StreamFactory, then StreamPool uses
// it to create a new Stream for every TCP stream.
//
// assembly will, in order:
//    1) Create the stream via StreamFactory.New
//    2) Call ReassembledSG 0 or more times, passing in reassembled TCP data in order
//    3) Call ReassemblyComplete one time, after which the stream is dereferenced by assembly.
type Stream interface {
	// Accept tells whether the TCP packet should be accepted.  The value
	// pointed to by start can be modified to force a start even if no SYN
	// has been seen.
	Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir TCPFlowDirection, nextSeq Sequence, start *bool, ac AssemblerContext) bool

	// ReassembledSG is called zero or more times.
	// ScatterGather is reused after each Reassembled call,
	// so it's important to copy anything you need out of it,
	// especially bytes (or use KeepFrom())
	ReassembledSG(sg ScatterGather, ac AssemblerContext)

	// ReassemblyComplete is called when assembly decides there is
	// no more data for this Stream, either because a FIN or RST packet
	// was seen, or because the stream has timed out without any new
	// packet data (due to a call to FlushCloseOlderThan).
	// It should return true if the connection should be removed from the
	// pool.  It can return false if it wants to see subsequent packets with
	// Accept(), e.g. to see FIN-ACK, for deeper state-machine analysis.
	ReassemblyComplete(ac AssemblerContext) bool
}
382
// StreamFactory is used by assembly to create a new stream for each
// new TCP session.
type StreamFactory interface {
	// New should return a new stream for the given TCP key, i.e. the
	// network-layer flow and the transport-layer flow.  It also receives the
	// TCP layer and the AssemblerContext of the packet being handled.
	New(netFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac AssemblerContext) Stream
}
389
390type key [2]gopacket.Flow
391
392func (k *key) String() string {
393	return fmt.Sprintf("%s:%s", k[0], k[1])
394}
395
396func (k *key) Reverse() key {
397	return key{
398		k[0].Reverse(),
399		k[1].Reverse(),
400	}
401}
402
// assemblerReturnValueInitialSize is the initial capacity of the slice an
// Assembler uses to collect the chunks to send to a stream.
const assemblerReturnValueInitialSize = 16
404
405/* one-way connection, i.e. halfconnection */
406type halfconnection struct {
407	dir               TCPFlowDirection
408	pages             int      // Number of pages used (both in first/last and saved)
409	saved             *page    // Doubly-linked list of in-order pages (seq < nextSeq) already given to Stream who told us to keep
410	first, last       *page    // Doubly-linked list of out-of-order pages (seq > nextSeq)
411	nextSeq           Sequence // sequence number of in-order received bytes
412	ackSeq            Sequence
413	created, lastSeen time.Time
414	stream            Stream
415	closed            bool
416	// for stats
417	queuedBytes    int
418	queuedPackets  int
419	overlapBytes   int
420	overlapPackets int
421}
422
423func (half *halfconnection) String() string {
424	closed := ""
425	if half.closed {
426		closed = "closed "
427	}
428	return fmt.Sprintf("%screated:%v, last:%v", closed, half.created, half.lastSeen)
429}
430
431// Dump returns a string (crypticly) describing the halfconnction
432func (half *halfconnection) Dump() string {
433	s := fmt.Sprintf("pages: %d\n"+
434		"nextSeq: %d\n"+
435		"ackSeq: %d\n"+
436		"Seen :  %s\n"+
437		"dir:    %s\n", half.pages, half.nextSeq, half.ackSeq, half.lastSeen, half.dir)
438	nb := 0
439	for p := half.first; p != nil; p = p.next {
440		s += fmt.Sprintf("	Page[%d] %s len: %d\n", nb, p, len(p.bytes))
441		nb++
442	}
443	return s
444}
445
446/* Bi-directionnal connection */
447
448type connection struct {
449	key      key // client->server
450	c2s, s2c halfconnection
451	mu       sync.Mutex
452}
453
454func (c *connection) reset(k key, s Stream, ts time.Time) {
455	c.key = k
456	base := halfconnection{
457		nextSeq:  invalidSequence,
458		ackSeq:   invalidSequence,
459		created:  ts,
460		lastSeen: ts,
461		stream:   s,
462	}
463	c.c2s, c.s2c = base, base
464	c.c2s.dir, c.s2c.dir = TCPDirClientToServer, TCPDirServerToClient
465}
466
467func (c *connection) lastSeen() time.Time {
468	if c.c2s.lastSeen.Before(c.s2c.lastSeen) {
469		return c.s2c.lastSeen
470	}
471
472	return c.c2s.lastSeen
473}
474
475func (c *connection) String() string {
476	return fmt.Sprintf("c2s: %s, s2c: %s", &c.c2s, &c.s2c)
477}
478
479/*
480 * Assembler
481 */
482
// DefaultAssemblerOptions provides default options for an assembler.
// These options are used by default when calling NewAssembler, so if
// modified before a NewAssembler call they'll affect the resulting Assembler.
//
// Note that the default options can result in ever-increasing memory usage
// unless one of the Flush* methods is called on a regular basis.
var DefaultAssemblerOptions = AssemblerOptions{
	MaxBufferedPagesPerConnection: 0, // unlimited
	MaxBufferedPagesTotal:         0, // unlimited
}
493
// AssemblerOptions controls the behavior of each assembler.  Modify the
// options of each assembler you create to change their behavior.
type AssemblerOptions struct {
	// MaxBufferedPagesTotal is an upper limit on the total number of pages to
	// buffer while waiting for out-of-order packets.  Once this limit is
	// reached, the assembler will degrade to flushing every connection it
	// gets a packet for.  If <= 0, this is ignored.
	MaxBufferedPagesTotal int
	// MaxBufferedPagesPerConnection is an upper limit on the number of pages
	// buffered for a single connection.  Should this limit be reached for a
	// particular connection, the smallest sequence number will be flushed, along
	// with any contiguous data.  If <= 0, this is ignored.
	MaxBufferedPagesPerConnection int
}
508
509// Assembler handles reassembling TCP streams.  It is not safe for
510// concurrency... after passing a packet in via the Assemble call, the caller
511// must wait for that call to return before calling Assemble again.  Callers can
512// get around this by creating multiple assemblers that share a StreamPool.  In
513// that case, each individual stream will still be handled serially (each stream
514// has an individual mutex associated with it), however multiple assemblers can
515// assemble different connections concurrently.
516//
517// The Assembler provides (hopefully) fast TCP stream re-assembly for sniffing
518// applications written in Go.  The Assembler uses the following methods to be
519// as fast as possible, to keep packet processing speedy:
520//
521// Avoids Lock Contention
522//
523// Assemblers locks connections, but each connection has an individual lock, and
524// rarely will two Assemblers be looking at the same connection.  Assemblers
525// lock the StreamPool when looking up connections, but they use Reader
526// locks initially, and only force a write lock if they need to create a new
527// connection or close one down.  These happen much less frequently than
528// individual packet handling.
529//
530// Each assembler runs in its own goroutine, and the only state shared between
531// goroutines is through the StreamPool.  Thus all internal Assembler state
532// can be handled without any locking.
533//
534// NOTE:  If you can guarantee that packets going to a set of Assemblers will
535// contain information on different connections per Assembler (for example,
536// they're already hashed by PF_RING hashing or some other hashing mechanism),
// then we recommend you use a separate StreamPool per Assembler, thus
538// avoiding all lock contention.  Only when different Assemblers could receive
539// packets for the same Stream should a StreamPool be shared between them.
540//
541// Avoids Memory Copying
542//
543// In the common case, handling of a single TCP packet should result in zero
544// memory allocations.  The Assembler will look up the connection, figure out
545// that the packet has arrived in order, and immediately pass that packet on to
546// the appropriate connection's handling code.  Only if a packet arrives out of
547// order is its contents copied and stored in memory for later.
548//
549// Avoids Memory Allocation
550//
551// Assemblers try very hard to not use memory allocation unless absolutely
552// necessary.  Packet data for sequential packets is passed directly to streams
553// with no copying or allocation.  Packet data for out-of-order packets is
554// copied into reusable pages, and new pages are only allocated rarely when the
555// page cache runs out.  Page caches are Assembler-specific, thus not used
556// concurrently and requiring no locking.
557//
558// Internal representations for connection objects are also reused over time.
559// Because of this, the most common memory allocation done by the Assembler is
560// generally what's done by the caller in StreamFactory.New.  If no allocation
561// is done there, then very little allocation is done ever, mostly to handle
562// large increases in bandwidth or numbers of connections.
563//
564// TODO:  The page caches used by an Assembler will grow to the size necessary
565// to handle a workload, and currently will never shrink.  This means that
566// traffic spikes can result in large memory usage which isn't garbage
567// collected when typical traffic levels return.
type Assembler struct {
	AssemblerOptions
	ret      []byteContainer  // chunks collected for the packet being handled; truncated at the start of each AssembleWithContext call
	pc       *pageCache       // page cache owned by this assembler
	connPool *StreamPool      // pool used to look up connections; may be shared
	cacheLP  livePacket       // reused wrapper around the packet being handled
	cacheSG  reassemblyObject // reused ScatterGather object
	start    bool             // whether the packet being handled starts its stream; a pointer to it is passed to Stream.Accept
}
577
578// NewAssembler creates a new assembler.  Pass in the StreamPool
579// to use, may be shared across assemblers.
580//
581// This sets some sane defaults for the assembler options,
582// see DefaultAssemblerOptions for details.
583func NewAssembler(pool *StreamPool) *Assembler {
584	pool.mu.Lock()
585	pool.users++
586	pool.mu.Unlock()
587	return &Assembler{
588		ret:              make([]byteContainer, 0, assemblerReturnValueInitialSize),
589		pc:               newPageCache(),
590		connPool:         pool,
591		AssemblerOptions: DefaultAssemblerOptions,
592	}
593}
594
595// Dump returns a short string describing the page usage of the Assembler
596func (a *Assembler) Dump() string {
597	s := ""
598	s += fmt.Sprintf("pageCache: used: %d, size: %d, free: %d", a.pc.used, a.pc.size, len(a.pc.free))
599	return s
600}
601
// AssemblerContext provides a method to get the metadata of a packet.
type AssemblerContext interface {
	// GetCaptureInfo returns the capture metadata of the packet.
	GetCaptureInfo() gopacket.CaptureInfo
}
606
// assemblerSimpleContext implements AssemblerContext for Assemble(): it is a
// plain CaptureInfo.
type assemblerSimpleContext gopacket.CaptureInfo

// GetCaptureInfo returns the context converted back to a CaptureInfo.
func (asc *assemblerSimpleContext) GetCaptureInfo() gopacket.CaptureInfo {
	return gopacket.CaptureInfo(*asc)
}
613
614// Assemble calls AssembleWithContext with the current timestamp, useful for
615// packets being read directly off the wire.
616func (a *Assembler) Assemble(netFlow gopacket.Flow, t *layers.TCP) {
617	ctx := assemblerSimpleContext(gopacket.CaptureInfo{Timestamp: time.Now()})
618	a.AssembleWithContext(netFlow, t, &ctx)
619}
620
// assemblerAction carries the decision taken for the packet being handled.
type assemblerAction struct {
	nextSeq Sequence // next expected sequence, or invalidSequence when unchanged
	queue   bool     // true to queue the packet, false to send it right away
}
625
// AssembleWithContext reassembles the given TCP packet into its appropriate
// stream.
//
// The timestamp passed in must be the timestamp the packet was seen.
// For packets read off the wire, time.Now() should be fine.  For packets read
// from PCAP files, CaptureInfo.Timestamp should be passed in.  This timestamp
// will affect which streams are flushed by a call to FlushCloseOlderThan.
//
// Each AssembleWithContext call results in, in order:
//
//    zero or one call to StreamFactory.New, creating a stream
//    zero or one call to ReassembledSG on a single stream
//    zero or one call to ReassemblyComplete on the same stream
//
// The connection's mutex is held from the connection lookup until the
// function returns.
func (a *Assembler) AssembleWithContext(netFlow gopacket.Flow, t *layers.TCP, ac AssemblerContext) {
	var conn *connection
	var half *halfconnection
	var rev *halfconnection

	// Start from an empty list of chunks to send, keeping the backing array.
	a.ret = a.ret[:0]
	key := key{netFlow, t.TransportFlow()}
	ci := ac.GetCaptureInfo()
	timestamp := ci.Timestamp

	// half is the direction this packet belongs to, rev the opposite one.
	conn, half, rev = a.connPool.getConnection(key, false, timestamp, t, ac)
	if conn == nil {
		if *debugLog {
			log.Printf("%v got empty packet on otherwise empty connection", key)
		}
		return
	}
	conn.mu.Lock()
	defer conn.mu.Unlock()
	// Never move lastSeen backwards.
	if half.lastSeen.Before(timestamp) {
		half.lastSeen = timestamp
	}
	// A SYN on a half connection without a sequence yet is a start; Accept
	// below may change this through the pointer it receives.
	a.start = half.nextSeq == invalidSequence && t.SYN
	if *debugLog {
		if half.nextSeq < rev.ackSeq {
			log.Printf("Delay detected on %v, data is acked but not assembled yet (acked %v, nextSeq %v)", key, rev.ackSeq, half.nextSeq)
		}
	}

	// Accept is consulted before the closed check, so the stream sees
	// packets arriving on an already closed half.
	if !half.stream.Accept(t, ci, half.dir, half.nextSeq, &a.start, ac) {
		if *debugLog {
			log.Printf("Ignoring packet")
		}
		return
	}
	if half.closed {
		// this way is closed
		if *debugLog {
			log.Printf("%v got packet on closed half", key)
		}
		return
	}

	seq, ack, bytes := Sequence(t.Seq), Sequence(t.Ack), t.Payload
	if t.ACK {
		half.ackSeq = ack
	}
	// TODO: push when Ack is seen ??
	// Default decision: queue the packet and leave nextSeq untouched.
	action := assemblerAction{
		nextSeq: Sequence(invalidSequence),
		queue:   true,
	}
	a.dump("AssembleWithContext()", half)
	if half.nextSeq == invalidSequence {
		// No sequence known yet for this direction.
		if t.SYN {
			if *debugLog {
				log.Printf("%v saw first SYN packet, returning immediately, seq=%v", key, seq)
			}
			// The SYN consumes one sequence number.
			seq = seq.Add(1)
			half.nextSeq = seq
			action.queue = false
		} else if a.start {
			if *debugLog {
				log.Printf("%v start forced", key)
			}
			half.nextSeq = seq
			action.queue = false
		} else {
			if *debugLog {
				log.Printf("%v waiting for start, storing into connection", key)
			}
		}
	} else {
		// diff > 0 means the packet starts after the next expected byte.
		diff := half.nextSeq.Difference(seq)
		if diff > 0 {
			if *debugLog {
				log.Printf("%v gap in sequence numbers (%v, %v) diff %v, storing into connection", key, half.nextSeq, seq, diff)
			}
		} else {
			if *debugLog {
				log.Printf("%v found contiguous data (%v, %v), returning immediately: len:%d", key, seq, half.nextSeq, len(bytes))
			}
			action.queue = false
		}
	}

	// RST and FIN both mark the packet as an end.
	action = a.handleBytes(bytes, seq, half, ci, t.SYN, t.RST || t.FIN, action, ac)
	if len(a.ret) > 0 {
		action.nextSeq = a.sendToConnection(conn, half, ac)
	}
	if action.nextSeq != invalidSequence {
		half.nextSeq = action.nextSeq
		if t.FIN {
			// The FIN consumes one sequence number.
			half.nextSeq = half.nextSeq.Add(1)
		}
	}
	if *debugLog {
		log.Printf("%v nextSeq:%d", key, half.nextSeq)
	}
}
739
740// Overlap strategies:
741//  - new packet overlaps with sent packets:
742//	1) discard new overlapping part
743//	2) overwrite old overlapped (TODO)
744//  - new packet overlaps existing queued packets:
745//	a) consider "age" by timestamp (TODO)
746//	b) consider "age" by being present
747//	Then
748//      1) discard new overlapping part
749//      2) overwrite queued part
750
// checkOverlap resolves overlaps between the current packet (a.cacheLP) and
// the pages queued on half, then, when queue is true and bytes remain,
// converts the packet to pages and links them into the queue.
//
// The queue is walked backwards from half.last.  Queued data covered by the
// packet is dropped or trimmed; a packet lying strictly inside a queued page
// overwrites that part of the page and is left empty.  On return
// a.cacheLP.bytes holds what remains of the packet.  The overlap and queue
// counters of half are updated along the way.
func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerContext) {
	var next *page
	cur := half.last
	bytes := a.cacheLP.bytes
	start := a.cacheLP.seq
	end := start.Add(len(bytes))

	a.dump("before checkOverlap", half)

	//          [s6           :           e6]
	//   [s1:e1][s2:e2] -- [s3:e3] -- [s4:e4][s5:e5]
	//             [s <--ds-- : --de--> e]
	for cur != nil {

		if *debugLog {
			log.Printf("cur = %p (%s)\n", cur, cur)
		}

		// cur starts after the packet's end: keep walking backwards (5)
		if end.Difference(cur.seq) > 0 {
			if *debugLog {
				log.Printf("case 5\n")
			}
			next = cur
			cur = cur.prev
			continue
		}

		curEnd := cur.seq.Add(len(cur.bytes))
		// cur ends at or before the packet's start: stop, cur is the
		// page the packet goes after (1)
		if start.Difference(curEnd) <= 0 {
			if *debugLog {
				log.Printf("case 1\n")
			}
			break
		}

		// diffStart > 0: cur starts after the packet's start.
		// diffEnd > 0: cur ends after the packet's end.
		diffStart := start.Difference(cur.seq)
		diffEnd := end.Difference(curEnd)

		// cur lies entirely within the packet: drop cur (3)
		if diffEnd <= 0 && diffStart >= 0 {
			if *debugLog {
				log.Printf("case 3\n")
			}
			if cur.isPacket() {
				half.overlapPackets++
			}
			half.overlapBytes += len(cur.bytes)
			// update links
			if cur.prev != nil {
				cur.prev.next = cur.next
			} else {
				half.first = cur.next
			}
			if cur.next != nil {
				cur.next.prev = cur.prev
			} else {
				half.last = cur.prev
			}
			// Remember the predecessor before cur goes back to the cache.
			tmp := cur.prev
			half.pages -= cur.release(a.pc)
			cur = tmp
			continue
		}

		// the packet starts inside cur and ends after it: drop cur's end (2)
		if diffEnd < 0 && start.Difference(curEnd) > 0 {
			if *debugLog {
				log.Printf("case 2\n")
			}
			cur.bytes = cur.bytes[:-start.Difference(cur.seq)]
			break
		} else

		// the packet starts before cur and ends inside it: drop cur's start (4)
		if diffStart > 0 && end.Difference(cur.seq) < 0 {
			if *debugLog {
				log.Printf("case 4\n")
			}
			cur.bytes = cur.bytes[-end.Difference(cur.seq):]
			cur.seq = cur.seq.Add(-end.Difference(cur.seq))
			next = cur
		} else

		// the packet lies strictly inside cur: replace bytes inside cur (6)
		if diffEnd > 0 && diffStart < 0 {
			if *debugLog {
				log.Printf("case 6\n")
			}
			copy(cur.bytes[-diffStart:-diffStart+len(bytes)], bytes)
			// Everything has been written into cur: nothing left to queue.
			bytes = bytes[:0]
		} else {
			if *debugLog {
				log.Printf("no overlap\n")
			}
			next = cur
		}
		cur = cur.prev
	}

	// Split bytes into pages, and insert in queue
	a.cacheLP.bytes = bytes
	a.cacheLP.seq = start
	if len(bytes) > 0 && queue {
		// p is the first new page, p2 the last one.
		p, p2, numPages := a.cacheLP.convertToPages(a.pc, 0, ac)
		half.queuedPackets++
		half.queuedBytes += len(bytes)
		half.pages += numPages
		// Link the new pages after cur, or at the head of the queue.
		if cur != nil {
			if *debugLog {
				log.Printf("adding %s after %s", p, cur)
			}
			cur.next = p
			p.prev = cur
		} else {
			if *debugLog {
				log.Printf("adding %s as first", p)
			}
			half.first = p
		}
		// Link the new pages before next, or at the tail of the queue.
		if next != nil {
			if *debugLog {
				log.Printf("setting %s as next of new %s", next, p2)
			}
			p2.next = next
			next.prev = p2
		} else {
			if *debugLog {
				log.Printf("setting %s as last", p2)
			}
			half.last = p2
		}
	}
	a.dump("After checkOverlap", half)
}
887
888// Warning: this is a low-level dumper, i.e. a.ret or a.cacheSG might
889// be strange, but it could be ok.
890func (a *Assembler) dump(text string, half *halfconnection) {
891	if !*debugLog {
892		return
893	}
894	log.Printf("%s: dump\n", text)
895	if half != nil {
896		p := half.first
897		if p == nil {
898			log.Printf(" * half.first = %p, no chunks queued\n", p)
899		} else {
900			s := 0
901			nb := 0
902			log.Printf(" * half.first = %p, queued chunks:", p)
903			for p != nil {
904				log.Printf("\t%s bytes:%s\n", p, hex.EncodeToString(p.bytes))
905				s += len(p.bytes)
906				nb++
907				p = p.next
908			}
909			log.Printf("\t%d chunks for %d bytes", nb, s)
910		}
911		log.Printf(" * half.last = %p\n", half.last)
912		log.Printf(" * half.saved = %p\n", half.saved)
913		p = half.saved
914		for p != nil {
915			log.Printf("\tseq:%d %s bytes:%s\n", p.getSeq(), p, hex.EncodeToString(p.bytes))
916			p = p.next
917		}
918	}
919	log.Printf(" * a.ret\n")
920	for i, r := range a.ret {
921		log.Printf("\t%d: %v b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
922	}
923	log.Printf(" * a.cacheSG.all\n")
924	for i, r := range a.cacheSG.all {
925		log.Printf("\t%d: %v b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
926	}
927}
928
929func (a *Assembler) overlapExisting(half *halfconnection, start, end Sequence, bytes []byte) ([]byte, Sequence) {
930	if half.nextSeq == invalidSequence {
931		// no start yet
932		return bytes, start
933	}
934	diff := start.Difference(half.nextSeq)
935	if diff == 0 {
936		return bytes, start
937	}
938	s := 0
939	e := len(bytes)
940	// TODO: depending on strategy, we might want to shrink half.saved if possible
941	if e != 0 {
942		if *debugLog {
943			log.Printf("Overlap detected: ignoring current packet's first %d bytes", diff)
944		}
945		half.overlapPackets++
946		half.overlapBytes += diff
947	}
948	s += diff
949	if s >= e {
950		// Completely included in sent
951		s = e
952	}
953	bytes = bytes[s:]
954	return bytes, half.nextSeq
955}
956
957// Prepare send or queue
958func (a *Assembler) handleBytes(bytes []byte, seq Sequence, half *halfconnection, ci gopacket.CaptureInfo, start bool, end bool, action assemblerAction, ac AssemblerContext) assemblerAction {
959	a.cacheLP.bytes = bytes
960	a.cacheLP.start = start
961	a.cacheLP.end = end
962	a.cacheLP.seq = seq
963	a.cacheLP.ci = ci
964	a.cacheLP.ac = ac
965
966	if action.queue {
967		a.checkOverlap(half, true, ac)
968		if (a.MaxBufferedPagesPerConnection > 0 && half.pages >= a.MaxBufferedPagesPerConnection) ||
969			(a.MaxBufferedPagesTotal > 0 && a.pc.used >= a.MaxBufferedPagesTotal) {
970			if *debugLog {
971				log.Printf("hit max buffer size: %+v, %v, %v", a.AssemblerOptions, half.pages, a.pc.used)
972			}
973			action.queue = false
974			a.addNextFromConn(half)
975		}
976		a.dump("handleBytes after queue", half)
977	} else {
978		a.cacheLP.bytes, a.cacheLP.seq = a.overlapExisting(half, seq, seq.Add(len(bytes)), a.cacheLP.bytes)
979		a.checkOverlap(half, false, ac)
980		if len(a.cacheLP.bytes) != 0 || end || start {
981			a.ret = append(a.ret, &a.cacheLP)
982		}
983		a.dump("handleBytes after no queue", half)
984	}
985	return action
986}
987
988func (a *Assembler) setStatsToSG(half *halfconnection) {
989	a.cacheSG.queuedBytes = half.queuedBytes
990	half.queuedBytes = 0
991	a.cacheSG.queuedPackets = half.queuedPackets
992	half.queuedPackets = 0
993	a.cacheSG.overlapBytes = half.overlapBytes
994	half.overlapBytes = 0
995	a.cacheSG.overlapPackets = half.overlapPackets
996	half.overlapPackets = 0
997}
998
999// Build the ScatterGather object, i.e. prepend saved bytes and
1000// append continuous bytes.
1001func (a *Assembler) buildSG(half *halfconnection) (bool, Sequence) {
1002	// find if there are skipped bytes
1003	skip := -1
1004	if half.nextSeq != invalidSequence {
1005		skip = half.nextSeq.Difference(a.ret[0].getSeq())
1006	}
1007	last := a.ret[0].getSeq().Add(a.ret[0].length())
1008	// Prepend saved bytes
1009	saved := a.addPending(half, a.ret[0].getSeq())
1010	// Append continuous bytes
1011	nextSeq := a.addContiguous(half, last)
1012	a.cacheSG.all = a.ret
1013	a.cacheSG.Direction = half.dir
1014	a.cacheSG.Skip = skip
1015	a.cacheSG.saved = saved
1016	a.cacheSG.toKeep = -1
1017	a.setStatsToSG(half)
1018	a.dump("after buildSG", half)
1019	return a.ret[len(a.ret)-1].isEnd(), nextSeq
1020}
1021
1022func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) {
1023	cur := 0
1024	ndx := 0
1025	skip := 0
1026
1027	a.dump("cleanSG(start)", half)
1028
1029	var r byteContainer
1030	// Find first page to keep
1031	if a.cacheSG.toKeep < 0 {
1032		ndx = len(a.cacheSG.all)
1033	} else {
1034		skip = a.cacheSG.toKeep
1035		found := false
1036		for ndx, r = range a.cacheSG.all {
1037			if a.cacheSG.toKeep < cur+r.length() {
1038				found = true
1039				break
1040			}
1041			cur += r.length()
1042			if skip >= r.length() {
1043				skip -= r.length()
1044			}
1045		}
1046		if !found {
1047			ndx++
1048		}
1049	}
1050	// Release consumed pages
1051	for _, r := range a.cacheSG.all[:ndx] {
1052		if r == half.saved {
1053			if half.saved.next != nil {
1054				half.saved.next.prev = nil
1055			}
1056			half.saved = half.saved.next
1057		} else if r == half.first {
1058			if half.first.next != nil {
1059				half.first.next.prev = nil
1060			}
1061			if half.first == half.last {
1062				half.first, half.last = nil, nil
1063			} else {
1064				half.first = half.first.next
1065			}
1066		}
1067		half.pages -= r.release(a.pc)
1068	}
1069	a.dump("after consumed release", half)
1070	// Keep un-consumed pages
1071	nbKept := 0
1072	half.saved = nil
1073	var saved *page
1074	for _, r := range a.cacheSG.all[ndx:] {
1075		first, last, nb := r.convertToPages(a.pc, skip, ac)
1076		if half.saved == nil {
1077			half.saved = first
1078		} else {
1079			saved.next = first
1080			first.prev = saved
1081		}
1082		saved = last
1083		nbKept += nb
1084	}
1085	if *debugLog {
1086		log.Printf("Remaining %d chunks in SG\n", nbKept)
1087		log.Printf("%s\n", a.Dump())
1088		a.dump("after cleanSG()", half)
1089	}
1090}
1091
1092// sendToConnection sends the current values in a.ret to the connection, closing
1093// the connection if the last thing sent had End set.
1094func (a *Assembler) sendToConnection(conn *connection, half *halfconnection, ac AssemblerContext) Sequence {
1095	if *debugLog {
1096		log.Printf("sendToConnection\n")
1097	}
1098	end, nextSeq := a.buildSG(half)
1099	half.stream.ReassembledSG(&a.cacheSG, ac)
1100	a.cleanSG(half, ac)
1101	if end {
1102		a.closeHalfConnection(conn, half)
1103	}
1104	if *debugLog {
1105		log.Printf("after sendToConnection: nextSeq: %d\n", nextSeq)
1106	}
1107	return nextSeq
1108}
1109
1110//
1111func (a *Assembler) addPending(half *halfconnection, firstSeq Sequence) int {
1112	if half.saved == nil {
1113		return 0
1114	}
1115	s := 0
1116	ret := []byteContainer{}
1117	for p := half.saved; p != nil; p = p.next {
1118		if *debugLog {
1119			log.Printf("adding pending @%p %s (%s)\n", p, p, hex.EncodeToString(p.bytes))
1120		}
1121		ret = append(ret, p)
1122		s += len(p.bytes)
1123	}
1124	if half.saved.seq.Add(s) != firstSeq {
1125		// non-continuous saved: drop them
1126		var next *page
1127		for p := half.saved; p != nil; p = next {
1128			next = p.next
1129			p.release(a.pc)
1130		}
1131		half.saved = nil
1132		ret = []byteContainer{}
1133		s = 0
1134	}
1135
1136	a.ret = append(ret, a.ret...)
1137	return s
1138}
1139
1140// addContiguous adds contiguous byte-sets to a connection.
1141func (a *Assembler) addContiguous(half *halfconnection, lastSeq Sequence) Sequence {
1142	page := half.first
1143	if page == nil {
1144		if *debugLog {
1145			log.Printf("addContiguous(%d): no pages\n", lastSeq)
1146		}
1147		return lastSeq
1148	}
1149	if lastSeq == invalidSequence {
1150		lastSeq = page.seq
1151	}
1152	for page != nil && lastSeq.Difference(page.seq) == 0 {
1153		if *debugLog {
1154			log.Printf("addContiguous: lastSeq: %d, first.seq=%d, page.seq=%d\n", half.nextSeq, half.first.seq, page.seq)
1155		}
1156		lastSeq = lastSeq.Add(len(page.bytes))
1157		a.ret = append(a.ret, page)
1158		half.first = page.next
1159		if half.first == nil {
1160			half.last = nil
1161		}
1162		if page.next != nil {
1163			page.next.prev = nil
1164		}
1165		page = page.next
1166	}
1167	return lastSeq
1168}
1169
1170// skipFlush skips the first set of bytes we're waiting for and returns the
1171// first set of bytes we have.  If we have no bytes saved, it closes the
1172// connection.
1173func (a *Assembler) skipFlush(conn *connection, half *halfconnection) {
1174	if *debugLog {
1175		log.Printf("skipFlush %v\n", half.nextSeq)
1176	}
1177	// Well, it's embarassing it there is still something in half.saved
1178	// FIXME: change API to give back saved + new/no packets
1179	if half.first == nil {
1180		a.closeHalfConnection(conn, half)
1181		return
1182	}
1183	a.ret = a.ret[:0]
1184	a.addNextFromConn(half)
1185	nextSeq := a.sendToConnection(conn, half, a.ret[0].assemblerContext())
1186	if nextSeq != invalidSequence {
1187		half.nextSeq = nextSeq
1188	}
1189}
1190
1191func (a *Assembler) closeHalfConnection(conn *connection, half *halfconnection) {
1192	if *debugLog {
1193		log.Printf("%v closing", conn)
1194	}
1195	half.closed = true
1196	for p := half.first; p != nil; p = p.next {
1197		// FIXME: it should be already empty
1198		a.pc.replace(p)
1199		half.pages--
1200	}
1201	if conn.s2c.closed && conn.c2s.closed {
1202		if half.stream.ReassemblyComplete(nil) { //FIXME: which context to pass ?
1203			a.connPool.remove(conn)
1204		}
1205	}
1206}
1207
1208// addNextFromConn pops the first page from a connection off and adds it to the
1209// return array.
1210func (a *Assembler) addNextFromConn(conn *halfconnection) {
1211	if conn.first == nil {
1212		return
1213	}
1214	if *debugLog {
1215		log.Printf("   adding from conn (%v, %v) %v (%d)\n", conn.first.seq, conn.nextSeq, conn.nextSeq-conn.first.seq, len(conn.first.bytes))
1216	}
1217	a.ret = append(a.ret, conn.first)
1218	conn.first = conn.first.next
1219	if conn.first != nil {
1220		conn.first.prev = nil
1221	} else {
1222		conn.last = nil
1223	}
1224}
1225
// FlushOptions provide options for flushing connections.
type FlushOptions struct {
	// T is the flush cutoff: buffered data first seen before T is pushed to
	// the stream. With the zero value no data is expected to be older than
	// T, so nothing gets flushed.
	T time.Time

	// TC is the close cutoff: a half-connection without buffered data is
	// closed (if no FIN/RST was received) when its connection was last seen
	// before TC. With the zero value nothing is expected to be older than
	// TC, so long-lived streams are kept open.
	TC time.Time
}
1231
1232// FlushWithOptions finds any streams waiting for packets older than
1233// the given time T, and pushes through the data they have (IE: tells
1234// them to stop waiting and skip the data they're waiting for).
1235//
1236// It also closes streams older than TC (that can be set to zero, to keep
1237// long-lived stream alive, but to flush data anyway).
1238//
1239// Each Stream maintains a list of zero or more sets of bytes it has received
1240// out-of-order.  For example, if it has processed up through sequence number
1241// 10, it might have bytes [15-20), [20-25), [30,50) in its list.  Each set of
1242// bytes also has the timestamp it was originally viewed.  A flush call will
1243// look at the smallest subsequent set of bytes, in this case [15-20), and if
1244// its timestamp is older than the passed-in time, it will push it and all
1245// contiguous byte-sets out to the Stream's Reassembled function.  In this case,
1246// it will push [15-20), but also [20-25), since that's contiguous.  It will
1247// only push [30-50) if its timestamp is also older than the passed-in time,
1248// otherwise it will wait until the next FlushCloseOlderThan to see if bytes
1249// [25-30) come in.
1250//
1251// Returns the number of connections flushed, and of those, the number closed
1252// because of the flush.
1253func (a *Assembler) FlushWithOptions(opt FlushOptions) (flushed, closed int) {
1254	conns := a.connPool.connections()
1255	closes := 0
1256	flushes := 0
1257	for _, conn := range conns {
1258		remove := false
1259		conn.mu.Lock()
1260		for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} {
1261			flushed, closed := a.flushClose(conn, half, opt.T, opt.TC)
1262			if flushed {
1263				flushes++
1264			}
1265			if closed {
1266				closes++
1267			}
1268		}
1269		if conn.s2c.closed && conn.c2s.closed && conn.s2c.lastSeen.Before(opt.TC) && conn.c2s.lastSeen.Before(opt.TC) {
1270			remove = true
1271		}
1272		conn.mu.Unlock()
1273		if remove {
1274			a.connPool.remove(conn)
1275		}
1276	}
1277	return flushes, closes
1278}
1279
1280// FlushCloseOlderThan flushes and closes streams older than given time
1281func (a *Assembler) FlushCloseOlderThan(t time.Time) (flushed, closed int) {
1282	return a.FlushWithOptions(FlushOptions{T: t, TC: t})
1283}
1284
1285func (a *Assembler) flushClose(conn *connection, half *halfconnection, t time.Time, tc time.Time) (bool, bool) {
1286	flushed, closed := false, false
1287	if half.closed {
1288		return flushed, closed
1289	}
1290	for half.first != nil && half.first.seen.Before(t) {
1291		flushed = true
1292		a.skipFlush(conn, half)
1293		if half.closed {
1294			closed = true
1295			return flushed, closed
1296		}
1297	}
1298	// Close the connection only if both halfs of the connection last seen before tc.
1299	if !half.closed && half.first == nil && conn.lastSeen().Before(tc) {
1300		a.closeHalfConnection(conn, half)
1301		closed = true
1302	}
1303	return flushed, closed
1304}
1305
1306// FlushAll flushes all remaining data into all remaining connections and closes
1307// those connections. It returns the total number of connections flushed/closed
1308// by the call.
1309func (a *Assembler) FlushAll() (closed int) {
1310	conns := a.connPool.connections()
1311	closed = len(conns)
1312	for _, conn := range conns {
1313		conn.mu.Lock()
1314		for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} {
1315			for !half.closed {
1316				a.skipFlush(conn, half)
1317			}
1318			if !half.closed {
1319				a.closeHalfConnection(conn, half)
1320			}
1321		}
1322		conn.mu.Unlock()
1323	}
1324	return
1325}
1326
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
1333