1// Copyright 2012 Google, Inc. All rights reserved. 2// 3// Use of this source code is governed by a BSD-style license 4// that can be found in the LICENSE file in the root of the source 5// tree. 6 7package reassembly 8 9import ( 10 "flag" 11 "log" 12 "sync" 13 "time" 14 15 "github.com/google/gopacket/layers" 16) 17 18var memLog = flag.Bool("assembly_memuse_log", defaultDebug, "If true, the github.com/google/gopacket/reassembly library will log information regarding its memory use every once in a while.") 19 20/* 21 * pageCache 22 */ 23// pageCache is a concurrency-unsafe store of page objects we use to avoid 24// memory allocation as much as we can. 25type pageCache struct { 26 free []*page 27 pcSize int 28 size, used int 29 pageRequests int64 30 ops int 31 nextShrink int 32} 33 34const initialAllocSize = 1024 35 36func newPageCache() *pageCache { 37 pc := &pageCache{ 38 free: make([]*page, 0, initialAllocSize), 39 pcSize: initialAllocSize, 40 } 41 pc.grow() 42 return pc 43} 44 45// grow exponentially increases the size of our page cache as much as necessary. 46func (c *pageCache) grow() { 47 pages := make([]page, c.pcSize) 48 c.size += c.pcSize 49 for i := range pages { 50 c.free = append(c.free, &pages[i]) 51 } 52 if *memLog { 53 log.Println("PageCache: created", c.pcSize, "new pages, size:", c.size, "cap:", cap(c.free), "len:", len(c.free)) 54 } 55 // control next shrink attempt 56 c.nextShrink = c.pcSize 57 c.ops = 0 58 // prepare for next alloc 59 c.pcSize *= 2 60} 61 62// Remove references to unused pages to let GC collect them 63// Note: memory used by c.free itself it not collected. 64func (c *pageCache) tryShrink() { 65 var min = c.pcSize / 2 66 if min < initialAllocSize { 67 min = initialAllocSize 68 } 69 if len(c.free) <= min { 70 return 71 } 72 for i := range c.free[min:] { 73 c.free[min+i] = nil 74 } 75 c.size -= len(c.free) - min 76 c.free = c.free[:min] 77 c.pcSize = min 78} 79 80// next returns a clean, ready-to-use page object. 81func (c *pageCache) next(ts time.Time) (p *page) { 82 if *memLog { 83 c.pageRequests++ 84 if c.pageRequests&0xFFFF == 0 { 85 log.Println("PageCache:", c.pageRequests, "requested,", c.used, "used,", len(c.free), "free") 86 } 87 } 88 if len(c.free) == 0 { 89 c.grow() 90 } 91 i := len(c.free) - 1 92 p, c.free = c.free[i], c.free[:i] 93 p.seen = ts 94 p.bytes = p.buf[:0] 95 c.used++ 96 if *memLog { 97 log.Printf("allocator returns %s\n", p) 98 } 99 c.ops++ 100 if c.ops > c.nextShrink { 101 c.ops = 0 102 c.tryShrink() 103 } 104 105 return p 106} 107 108// replace replaces a page into the pageCache. 109func (c *pageCache) replace(p *page) { 110 c.used-- 111 if *memLog { 112 log.Printf("replacing %s\n", p) 113 } 114 p.prev = nil 115 p.next = nil 116 c.free = append(c.free, p) 117} 118 119/* 120 * StreamPool 121 */ 122 123// StreamPool stores all streams created by Assemblers, allowing multiple 124// assemblers to work together on stream processing while enforcing the fact 125// that a single stream receives its data serially. It is safe 126// for concurrency, usable by multiple Assemblers at once. 127// 128// StreamPool handles the creation and storage of Stream objects used by one or 129// more Assembler objects. When a new TCP stream is found by an Assembler, it 130// creates an associated Stream by calling its StreamFactory's New method. 131// Thereafter (until the stream is closed), that Stream object will receive 132// assembled TCP data via Assembler's calls to the stream's Reassembled 133// function. 134// 135// Like the Assembler, StreamPool attempts to minimize allocation. Unlike the 136// Assembler, though, it does have to do some locking to make sure that the 137// connection objects it stores are accessible to multiple Assemblers. 138type StreamPool struct { 139 conns map[key]*connection 140 users int 141 mu sync.RWMutex 142 factory StreamFactory 143 free []*connection 144 all [][]connection 145 nextAlloc int 146 newConnectionCount int64 147} 148 149func (p *StreamPool) grow() { 150 conns := make([]connection, p.nextAlloc) 151 p.all = append(p.all, conns) 152 for i := range conns { 153 p.free = append(p.free, &conns[i]) 154 } 155 if *memLog { 156 log.Println("StreamPool: created", p.nextAlloc, "new connections") 157 } 158 p.nextAlloc *= 2 159} 160 161// Dump logs all connections 162func (p *StreamPool) Dump() { 163 p.mu.Lock() 164 defer p.mu.Unlock() 165 log.Printf("Remaining %d connections: ", len(p.conns)) 166 for _, conn := range p.conns { 167 log.Printf("%v %s", conn.key, conn) 168 } 169} 170 171func (p *StreamPool) remove(conn *connection) { 172 p.mu.Lock() 173 if _, ok := p.conns[conn.key]; ok { 174 delete(p.conns, conn.key) 175 p.free = append(p.free, conn) 176 } 177 p.mu.Unlock() 178} 179 180// NewStreamPool creates a new connection pool. Streams will 181// be created as necessary using the passed-in StreamFactory. 182func NewStreamPool(factory StreamFactory) *StreamPool { 183 return &StreamPool{ 184 conns: make(map[key]*connection, initialAllocSize), 185 free: make([]*connection, 0, initialAllocSize), 186 factory: factory, 187 nextAlloc: initialAllocSize, 188 } 189} 190 191func (p *StreamPool) connections() []*connection { 192 p.mu.RLock() 193 conns := make([]*connection, 0, len(p.conns)) 194 for _, conn := range p.conns { 195 conns = append(conns, conn) 196 } 197 p.mu.RUnlock() 198 return conns 199} 200 201func (p *StreamPool) newConnection(k key, s Stream, ts time.Time) (c *connection, h *halfconnection, r *halfconnection) { 202 if *memLog { 203 p.newConnectionCount++ 204 if p.newConnectionCount&0x7FFF == 0 { 205 log.Println("StreamPool:", p.newConnectionCount, "requests,", len(p.conns), "used,", len(p.free), "free") 206 } 207 } 208 if len(p.free) == 0 { 209 p.grow() 210 } 211 index := len(p.free) - 1 212 c, p.free = p.free[index], p.free[:index] 213 c.reset(k, s, ts) 214 return c, &c.c2s, &c.s2c 215} 216 217func (p *StreamPool) getHalf(k key) (*connection, *halfconnection, *halfconnection) { 218 conn := p.conns[k] 219 if conn != nil { 220 return conn, &conn.c2s, &conn.s2c 221 } 222 rk := k.Reverse() 223 conn = p.conns[rk] 224 if conn != nil { 225 return conn, &conn.s2c, &conn.c2s 226 } 227 return nil, nil, nil 228} 229 230// getConnection returns a connection. If end is true and a connection 231// does not already exist, returns nil. This allows us to check for a 232// connection without actually creating one if it doesn't already exist. 233func (p *StreamPool) getConnection(k key, end bool, ts time.Time, tcp *layers.TCP, ac AssemblerContext) (*connection, *halfconnection, *halfconnection) { 234 p.mu.RLock() 235 conn, half, rev := p.getHalf(k) 236 p.mu.RUnlock() 237 if end || conn != nil { 238 return conn, half, rev 239 } 240 s := p.factory.New(k[0], k[1], tcp, ac) 241 p.mu.Lock() 242 defer p.mu.Unlock() 243 conn, half, rev = p.newConnection(k, s, ts) 244 conn2, half2, rev2 := p.getHalf(k) 245 if conn2 != nil { 246 if conn2.key != k { 247 panic("FIXME: other dir added in the meantime...") 248 } 249 // FIXME: delete s ? 250 return conn2, half2, rev2 251 } 252 p.conns[k] = conn 253 return conn, half, rev 254} 255