// Copyright 2012 Google, Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree.

package reassembly

import (
	"flag"
	"log"
	"sync"
	"time"

	"github.com/google/gopacket/layers"
)

var memLog = flag.Bool("assembly_memuse_log", defaultDebug, "If true, the github.com/google/gopacket/reassembly library will log information regarding its memory use every once in a while.")
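
// The memory-usage logging above is controlled by the -assembly_memuse_log
// command-line flag, whose default is the package's defaultDebug value.  As
// an illustrative sketch (nothing in this file does this itself), it can
// also be switched on programmatically before assembly starts:
//
//	flag.Set("assembly_memuse_log", "true")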

/*
 * pageCache
 */
// pageCache is a concurrency-unsafe store of page objects we use to avoid
// memory allocation as much as we can.
type pageCache struct {
	free         []*page // pages ready to be handed out again
	pcSize       int     // number of pages to allocate on the next grow
	size, used   int     // pages owned in total / pages currently handed out
	pageRequests int64   // counter used only for the optional memory-use log
	ops          int     // next() calls since the last grow or shrink attempt
	nextShrink   int     // ops threshold at which tryShrink is attempted
}

const initialAllocSize = 1024

func newPageCache() *pageCache {
	pc := &pageCache{
		free:   make([]*page, 0, initialAllocSize),
		pcSize: initialAllocSize,
	}
	pc.grow()
	return pc
}

// grow exponentially increases the size of our page cache as much as necessary.
func (c *pageCache) grow() {
	pages := make([]page, c.pcSize)
	c.size += c.pcSize
	for i := range pages {
		c.free = append(c.free, &pages[i])
	}
	if *memLog {
		log.Println("PageCache: created", c.pcSize, "new pages, size:", c.size, "cap:", cap(c.free), "len:", len(c.free))
	}
	// control next shrink attempt
	c.nextShrink = c.pcSize
	c.ops = 0
	// prepare for next alloc
	c.pcSize *= 2
}

// tryShrink removes references to unused pages so the GC can collect them.
// Note: the memory used by the c.free slice itself is not collected.
func (c *pageCache) tryShrink() {
	var min = c.pcSize / 2
	if min < initialAllocSize {
		min = initialAllocSize
	}
	if len(c.free) <= min {
		return
	}
	for i := range c.free[min:] {
		c.free[min+i] = nil
	}
	c.size -= len(c.free) - min
	c.free = c.free[:min]
	c.pcSize = min
}

// next returns a clean, ready-to-use page object.
func (c *pageCache) next(ts time.Time) (p *page) {
	if *memLog {
		c.pageRequests++
		if c.pageRequests&0xFFFF == 0 {
			log.Println("PageCache:", c.pageRequests, "requested,", c.used, "used,", len(c.free), "free")
		}
	}
	if len(c.free) == 0 {
		c.grow()
	}
	i := len(c.free) - 1
	p, c.free = c.free[i], c.free[:i]
	p.seen = ts
	p.bytes = p.buf[:0]
	c.used++
	if *memLog {
		log.Printf("allocator returns %s\n", p)
	}
	c.ops++
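	// After roughly nextShrink calls to next() since the last grow or
	// shrink attempt, try to hand unused pages back to the GC.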
	if c.ops > c.nextShrink {
		c.ops = 0
		c.tryShrink()
	}

	return p
}

// replace returns a page to the pageCache so it can be reused.
func (c *pageCache) replace(p *page) {
	c.used--
	if *memLog {
		log.Printf("replacing %s\n", p)
	}
	p.prev = nil
	p.next = nil
	c.free = append(c.free, p)
}
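
// As an illustrative sketch of how the allocator above is meant to be driven
// (no single piece of code in this file does exactly this): a page is taken
// with next(), filled with reassembled bytes, and later handed back with
// replace() so it can be reused without a fresh allocation.  Here payload
// stands in for whatever data is being buffered.
//
//	pc := newPageCache()
//	p := pc.next(time.Now())              // clean page, pc.used goes up
//	p.bytes = append(p.bytes, payload...) // fill with reassembled bytes
//	pc.replace(p)                         // back on the free list, pc.used goes down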

/*
 * StreamPool
 */

// StreamPool stores all streams created by Assemblers, allowing multiple
// assemblers to work together on stream processing while enforcing the fact
// that a single stream receives its data serially.  It is safe for
// concurrent use by multiple Assemblers.
//
// StreamPool handles the creation and storage of Stream objects used by one or
// more Assembler objects.  When a new TCP stream is found by an Assembler, it
// creates an associated Stream by calling its StreamFactory's New method.
// Thereafter (until the stream is closed), that Stream object will receive
// assembled TCP data via the Assembler's calls to the stream's ReassembledSG
// function.
//
// Like the Assembler, StreamPool attempts to minimize allocation.  Unlike the
// Assembler, though, it does have to do some locking to make sure that the
// connection objects it stores are accessible to multiple Assemblers.
type StreamPool struct {
	conns              map[key]*connection
	users              int
	mu                 sync.RWMutex
	factory            StreamFactory
	free               []*connection
	all                [][]connection
	nextAlloc          int
	newConnectionCount int64
}

func (p *StreamPool) grow() {
	conns := make([]connection, p.nextAlloc)
	p.all = append(p.all, conns)
	for i := range conns {
		p.free = append(p.free, &conns[i])
	}
	if *memLog {
		log.Println("StreamPool: created", p.nextAlloc, "new connections")
	}
	p.nextAlloc *= 2
}

// Dump logs all connections
func (p *StreamPool) Dump() {
	p.mu.Lock()
	defer p.mu.Unlock()
	log.Printf("Remaining %d connections: ", len(p.conns))
	for _, conn := range p.conns {
		log.Printf("%v %s", conn.key, conn)
	}
}

func (p *StreamPool) remove(conn *connection) {
	p.mu.Lock()
	if _, ok := p.conns[conn.key]; ok {
		delete(p.conns, conn.key)
		p.free = append(p.free, conn)
	}
	p.mu.Unlock()
}

// NewStreamPool creates a new connection pool.  Streams will
// be created as necessary using the passed-in StreamFactory.
func NewStreamPool(factory StreamFactory) *StreamPool {
	return &StreamPool{
		conns:     make(map[key]*connection, initialAllocSize),
		free:      make([]*connection, 0, initialAllocSize),
		factory:   factory,
		nextAlloc: initialAllocSize,
	}
}
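
// Construction sketch from a user's point of view (illustrative only;
// myFactory stands in for some StreamFactory implementation defined by the
// caller, not by this package):
//
//	pool := reassembly.NewStreamPool(&myFactory{})
//	assembler := reassembly.NewAssembler(pool)
//	// Each packet handed to the assembler that starts a new TCP stream
//	// triggers a call to myFactory.New, and the resulting Stream is kept
//	// in the pool until the stream completes.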

func (p *StreamPool) connections() []*connection {
	p.mu.RLock()
	conns := make([]*connection, 0, len(p.conns))
	for _, conn := range p.conns {
		conns = append(conns, conn)
	}
	p.mu.RUnlock()
	return conns
}

func (p *StreamPool) newConnection(k key, s Stream, ts time.Time) (c *connection, h *halfconnection, r *halfconnection) {
	if *memLog {
		p.newConnectionCount++
		if p.newConnectionCount&0x7FFF == 0 {
			log.Println("StreamPool:", p.newConnectionCount, "requests,", len(p.conns), "used,", len(p.free), "free")
		}
	}
	if len(p.free) == 0 {
		p.grow()
	}
	index := len(p.free) - 1
	c, p.free = p.free[index], p.free[:index]
	c.reset(k, s, ts)
	return c, &c.c2s, &c.s2c
}

func (p *StreamPool) getHalf(k key) (*connection, *halfconnection, *halfconnection) {
	conn := p.conns[k]
	if conn != nil {
		return conn, &conn.c2s, &conn.s2c
	}
	rk := k.Reverse()
	conn = p.conns[rk]
	if conn != nil {
		return conn, &conn.s2c, &conn.c2s
	}
	return nil, nil, nil
}

// getConnection returns a connection.  If end is true and a connection
// does not already exist, returns nil.  This allows us to check for a
// connection without actually creating one if it doesn't already exist.
func (p *StreamPool) getConnection(k key, end bool, ts time.Time, tcp *layers.TCP, ac AssemblerContext) (*connection, *halfconnection, *halfconnection) {
	p.mu.RLock()
	conn, half, rev := p.getHalf(k)
	p.mu.RUnlock()
	if end || conn != nil {
		return conn, half, rev
	}
	s := p.factory.New(k[0], k[1], tcp, ac)
	p.mu.Lock()
	defer p.mu.Unlock()
	conn, half, rev = p.newConnection(k, s, ts)
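	// Re-check under the write lock: another goroutine may have created
	// this connection (possibly keyed on the reverse flow) between the
	// RUnlock above and the Lock just taken.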
	conn2, half2, rev2 := p.getHalf(k)
	if conn2 != nil {
		if conn2.key != k {
			panic("FIXME: other dir added in the meantime...")
		}
		// FIXME: delete s ?
		return conn2, half2, rev2
	}
	p.conns[k] = conn
	return conn, half, rev
}