1// Copyright 2012 Google, Inc. All rights reserved. 2// 3// Use of this source code is governed by a BSD-style license 4// that can be found in the LICENSE file in the root of the source 5// tree. 6 7// Package reassembly provides TCP stream re-assembly. 8// 9// The reassembly package implements uni-directional TCP reassembly, for use in 10// packet-sniffing applications. The caller reads packets off the wire, then 11// presents them to an Assembler in the form of gopacket layers.TCP packets 12// (github.com/google/gopacket, github.com/google/gopacket/layers). 13// 14// The Assembler uses a user-supplied 15// StreamFactory to create a user-defined Stream interface, then passes packet 16// data in stream order to that object. A concurrency-safe StreamPool keeps 17// track of all current Streams being reassembled, so multiple Assemblers may 18// run at once to assemble packets while taking advantage of multiple cores. 19// 20// TODO: Add simplest example 21package reassembly 22 23import ( 24 "encoding/hex" 25 "flag" 26 "fmt" 27 "log" 28 "sync" 29 "time" 30 31 "github.com/google/gopacket" 32 "github.com/google/gopacket/layers" 33) 34 35// TODO: 36// - push to Stream on Ack 37// - implement chunked (cheap) reads and Reader() interface 38// - better organize file: split files: 'mem', 'misc' (seq + flow) 39 40var defaultDebug = false 41 42var debugLog = flag.Bool("assembly_debug_log", defaultDebug, "If true, the github.com/google/gopacket/reassembly library will log verbose debugging information (at least one line per packet)") 43 44const invalidSequence = -1 45const uint32Max = 0xFFFFFFFF 46 47// Sequence is a TCP sequence number. It provides a few convenience functions 48// for handling TCP wrap-around. The sequence should always be in the range 49// [0,0xFFFFFFFF]... its other bits are simply used in wrap-around calculations 50// and should never be set. 
51type Sequence int64 52 53// Difference defines an ordering for comparing TCP sequences that's safe for 54// roll-overs. It returns: 55// > 0 : if t comes after s 56// < 0 : if t comes before s 57// 0 : if t == s 58// The number returned is the sequence difference, so 4.Difference(8) will 59// return 4. 60// 61// It handles rollovers by considering any sequence in the first quarter of the 62// uint32 space to be after any sequence in the last quarter of that space, thus 63// wrapping the uint32 space. 64func (s Sequence) Difference(t Sequence) int { 65 if s > uint32Max-uint32Max/4 && t < uint32Max/4 { 66 t += uint32Max 67 } else if t > uint32Max-uint32Max/4 && s < uint32Max/4 { 68 s += uint32Max 69 } 70 return int(t - s) 71} 72 73// Add adds an integer to a sequence and returns the resulting sequence. 74func (s Sequence) Add(t int) Sequence { 75 return (s + Sequence(t)) & uint32Max 76} 77 78// TCPAssemblyStats provides some figures for a ScatterGather 79type TCPAssemblyStats struct { 80 // For this ScatterGather 81 Chunks int 82 Packets int 83 // For the half connection, since last call to ReassembledSG() 84 QueuedBytes int 85 QueuedPackets int 86 OverlapBytes int 87 OverlapPackets int 88} 89 90// ScatterGather is used to pass reassembled data and metadata of reassembled 91// packets to a Stream via ReassembledSG 92type ScatterGather interface { 93 // Returns the length of available bytes and saved bytes 94 Lengths() (int, int) 95 // Returns the bytes up to length (shall be <= available bytes) 96 Fetch(length int) []byte 97 // Tell to keep from offset 98 KeepFrom(offset int) 99 // Return CaptureInfo of packet corresponding to given offset 100 CaptureInfo(offset int) gopacket.CaptureInfo 101 // Return some info about the reassembled chunks 102 Info() (direction TCPFlowDirection, start bool, end bool, skip int) 103 // Return some stats regarding the state of the stream 104 Stats() TCPAssemblyStats 105} 106 107// byteContainer is either a page or a livePacket 108type 
byteContainer interface { 109 getBytes() []byte 110 length() int 111 convertToPages(*pageCache, int, AssemblerContext) (*page, *page, int) 112 captureInfo() gopacket.CaptureInfo 113 assemblerContext() AssemblerContext 114 release(*pageCache) int 115 isStart() bool 116 isEnd() bool 117 getSeq() Sequence 118 isPacket() bool 119} 120 121// Implements a ScatterGather 122type reassemblyObject struct { 123 all []byteContainer 124 Skip int 125 Direction TCPFlowDirection 126 saved int 127 toKeep int 128 // stats 129 queuedBytes int 130 queuedPackets int 131 overlapBytes int 132 overlapPackets int 133} 134 135func (rl *reassemblyObject) Lengths() (int, int) { 136 l := 0 137 for _, r := range rl.all { 138 l += r.length() 139 } 140 return l, rl.saved 141} 142 143func (rl *reassemblyObject) Fetch(l int) []byte { 144 if l <= rl.all[0].length() { 145 return rl.all[0].getBytes()[:l] 146 } 147 bytes := make([]byte, 0, l) 148 for _, bc := range rl.all { 149 bytes = append(bytes, bc.getBytes()...) 150 } 151 return bytes[:l] 152} 153 154func (rl *reassemblyObject) KeepFrom(offset int) { 155 rl.toKeep = offset 156} 157 158func (rl *reassemblyObject) CaptureInfo(offset int) gopacket.CaptureInfo { 159 current := 0 160 var r byteContainer 161 for _, r = range rl.all { 162 if current >= offset { 163 return r.captureInfo() 164 } 165 current += r.length() 166 } 167 if r != nil && current >= offset { 168 return r.captureInfo() 169 } 170 // Invalid offset 171 return gopacket.CaptureInfo{} 172} 173 174func (rl *reassemblyObject) Info() (TCPFlowDirection, bool, bool, int) { 175 return rl.Direction, rl.all[0].isStart(), rl.all[len(rl.all)-1].isEnd(), rl.Skip 176} 177 178func (rl *reassemblyObject) Stats() TCPAssemblyStats { 179 packets := int(0) 180 for _, r := range rl.all { 181 if r.isPacket() { 182 packets++ 183 } 184 } 185 return TCPAssemblyStats{ 186 Chunks: len(rl.all), 187 Packets: packets, 188 QueuedBytes: rl.queuedBytes, 189 QueuedPackets: rl.queuedPackets, 190 OverlapBytes: 
rl.overlapBytes, 191 OverlapPackets: rl.overlapPackets, 192 } 193} 194 195const pageBytes = 1900 196 197// TCPFlowDirection distinguish the two half-connections directions. 198// 199// TCPDirClientToServer is assigned to half-connection for the first received 200// packet, hence might be wrong if packets are not received in order. 201// It's up to the caller (e.g. in Accept()) to decide if the direction should 202// be interpretted differently. 203type TCPFlowDirection bool 204 205// Value are not really useful 206const ( 207 TCPDirClientToServer TCPFlowDirection = false 208 TCPDirServerToClient TCPFlowDirection = true 209) 210 211func (dir TCPFlowDirection) String() string { 212 switch dir { 213 case TCPDirClientToServer: 214 return "client->server" 215 case TCPDirServerToClient: 216 return "server->client" 217 } 218 return "" 219} 220 221// Reverse returns the reversed direction 222func (dir TCPFlowDirection) Reverse() TCPFlowDirection { 223 return !dir 224} 225 226/* page: implements a byteContainer */ 227 228// page is used to store TCP data we're not ready for yet (out-of-order 229// packets). Unused pages are stored in and returned from a pageCache, which 230// avoids memory allocation. Used pages are stored in a doubly-linked list in 231// a connection. 
232type page struct { 233 bytes []byte 234 seq Sequence 235 prev, next *page 236 buf [pageBytes]byte 237 ac AssemblerContext // only set for the first page of a packet 238 seen time.Time 239 start, end bool 240} 241 242func (p *page) getBytes() []byte { 243 return p.bytes 244} 245func (p *page) captureInfo() gopacket.CaptureInfo { 246 return p.ac.GetCaptureInfo() 247} 248func (p *page) assemblerContext() AssemblerContext { 249 return p.ac 250} 251func (p *page) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) { 252 if skip != 0 { 253 p.bytes = p.bytes[skip:] 254 p.seq = p.seq.Add(skip) 255 } 256 p.prev, p.next = nil, nil 257 return p, p, 1 258} 259func (p *page) length() int { 260 return len(p.bytes) 261} 262func (p *page) release(pc *pageCache) int { 263 pc.replace(p) 264 return 1 265} 266func (p *page) isStart() bool { 267 return p.start 268} 269func (p *page) isEnd() bool { 270 return p.end 271} 272func (p *page) getSeq() Sequence { 273 return p.seq 274} 275func (p *page) isPacket() bool { 276 return p.ac != nil 277} 278func (p *page) String() string { 279 return fmt.Sprintf("page@%p{seq: %v, bytes:%d, -> nextSeq:%v} (prev:%p, next:%p)", p, p.seq, len(p.bytes), p.seq+Sequence(len(p.bytes)), p.prev, p.next) 280} 281 282/* livePacket: implements a byteContainer */ 283type livePacket struct { 284 bytes []byte 285 start bool 286 end bool 287 ci gopacket.CaptureInfo 288 ac AssemblerContext 289 seq Sequence 290} 291 292func (lp *livePacket) getBytes() []byte { 293 return lp.bytes 294} 295func (lp *livePacket) captureInfo() gopacket.CaptureInfo { 296 return lp.ci 297} 298func (lp *livePacket) assemblerContext() AssemblerContext { 299 return lp.ac 300} 301func (lp *livePacket) length() int { 302 return len(lp.bytes) 303} 304func (lp *livePacket) isStart() bool { 305 return lp.start 306} 307func (lp *livePacket) isEnd() bool { 308 return lp.end 309} 310func (lp *livePacket) getSeq() Sequence { 311 return lp.seq 312} 313func (lp 
*livePacket) isPacket() bool { 314 return true 315} 316 317// Creates a page (or set of pages) from a TCP packet: returns the first and last 318// page in its doubly-linked list of new pages. 319func (lp *livePacket) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) { 320 ts := lp.ci.Timestamp 321 first := pc.next(ts) 322 current := first 323 current.prev = nil 324 first.ac = ac 325 numPages := 1 326 seq, bytes := lp.seq.Add(skip), lp.bytes[skip:] 327 for { 328 length := min(len(bytes), pageBytes) 329 current.bytes = current.buf[:length] 330 copy(current.bytes, bytes) 331 current.seq = seq 332 bytes = bytes[length:] 333 if len(bytes) == 0 { 334 current.end = lp.isEnd() 335 current.next = nil 336 break 337 } 338 seq = seq.Add(length) 339 current.next = pc.next(ts) 340 current.next.prev = current 341 current = current.next 342 current.ac = nil 343 numPages++ 344 } 345 return first, current, numPages 346} 347func (lp *livePacket) estimateNumberOfPages() int { 348 return (len(lp.bytes) + pageBytes + 1) / pageBytes 349} 350 351func (lp *livePacket) release(*pageCache) int { 352 return 0 353} 354 355// Stream is implemented by the caller to handle incoming reassembled 356// TCP data. Callers create a StreamFactory, then StreamPool uses 357// it to create a new Stream for every TCP stream. 358// 359// assembly will, in order: 360// 1) Create the stream via StreamFactory.New 361// 2) Call ReassembledSG 0 or more times, passing in reassembled TCP data in order 362// 3) Call ReassemblyComplete one time, after which the stream is dereferenced by assembly. 363type Stream interface { 364 // Tell whether the TCP packet should be accepted, start could be modified to force a start even if no SYN have been seen 365 Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir TCPFlowDirection, nextSeq Sequence, start *bool, ac AssemblerContext) bool 366 367 // ReassembledSG is called zero or more times. 
368 // ScatterGather is reused after each Reassembled call, 369 // so it's important to copy anything you need out of it, 370 // especially bytes (or use KeepFrom()) 371 ReassembledSG(sg ScatterGather, ac AssemblerContext) 372 373 // ReassemblyComplete is called when assembly decides there is 374 // no more data for this Stream, either because a FIN or RST packet 375 // was seen, or because the stream has timed out without any new 376 // packet data (due to a call to FlushCloseOlderThan). 377 // It should return true if the connection should be removed from the pool 378 // It can return false if it want to see subsequent packets with Accept(), e.g. to 379 // see FIN-ACK, for deeper state-machine analysis. 380 ReassemblyComplete(ac AssemblerContext) bool 381} 382 383// StreamFactory is used by assembly to create a new stream for each 384// new TCP session. 385type StreamFactory interface { 386 // New should return a new stream for the given TCP key. 387 New(netFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac AssemblerContext) Stream 388} 389 390type key [2]gopacket.Flow 391 392func (k *key) String() string { 393 return fmt.Sprintf("%s:%s", k[0], k[1]) 394} 395 396func (k *key) Reverse() key { 397 return key{ 398 k[0].Reverse(), 399 k[1].Reverse(), 400 } 401} 402 403const assemblerReturnValueInitialSize = 16 404 405/* one-way connection, i.e. 
halfconnection */ 406type halfconnection struct { 407 dir TCPFlowDirection 408 pages int // Number of pages used (both in first/last and saved) 409 saved *page // Doubly-linked list of in-order pages (seq < nextSeq) already given to Stream who told us to keep 410 first, last *page // Doubly-linked list of out-of-order pages (seq > nextSeq) 411 nextSeq Sequence // sequence number of in-order received bytes 412 ackSeq Sequence 413 created, lastSeen time.Time 414 stream Stream 415 closed bool 416 // for stats 417 queuedBytes int 418 queuedPackets int 419 overlapBytes int 420 overlapPackets int 421} 422 423func (half *halfconnection) String() string { 424 closed := "" 425 if half.closed { 426 closed = "closed " 427 } 428 return fmt.Sprintf("%screated:%v, last:%v", closed, half.created, half.lastSeen) 429} 430 431// Dump returns a string (crypticly) describing the halfconnction 432func (half *halfconnection) Dump() string { 433 s := fmt.Sprintf("pages: %d\n"+ 434 "nextSeq: %d\n"+ 435 "ackSeq: %d\n"+ 436 "Seen : %s\n"+ 437 "dir: %s\n", half.pages, half.nextSeq, half.ackSeq, half.lastSeen, half.dir) 438 nb := 0 439 for p := half.first; p != nil; p = p.next { 440 s += fmt.Sprintf(" Page[%d] %s len: %d\n", nb, p, len(p.bytes)) 441 nb++ 442 } 443 return s 444} 445 446/* Bi-directionnal connection */ 447 448type connection struct { 449 key key // client->server 450 c2s, s2c halfconnection 451 mu sync.Mutex 452} 453 454func (c *connection) reset(k key, s Stream, ts time.Time) { 455 c.key = k 456 base := halfconnection{ 457 nextSeq: invalidSequence, 458 ackSeq: invalidSequence, 459 created: ts, 460 lastSeen: ts, 461 stream: s, 462 } 463 c.c2s, c.s2c = base, base 464 c.c2s.dir, c.s2c.dir = TCPDirClientToServer, TCPDirServerToClient 465} 466 467func (c *connection) lastSeen() time.Time { 468 if c.c2s.lastSeen.Before(c.s2c.lastSeen) { 469 return c.s2c.lastSeen 470 } 471 472 return c.c2s.lastSeen 473} 474 475func (c *connection) String() string { 476 return fmt.Sprintf("c2s: %s, 
s2c: %s", &c.c2s, &c.s2c) 477} 478 479/* 480 * Assembler 481 */ 482 483// DefaultAssemblerOptions provides default options for an assembler. 484// These options are used by default when calling NewAssembler, so if 485// modified before a NewAssembler call they'll affect the resulting Assembler. 486// 487// Note that the default options can result in ever-increasing memory usage 488// unless one of the Flush* methods is called on a regular basis. 489var DefaultAssemblerOptions = AssemblerOptions{ 490 MaxBufferedPagesPerConnection: 0, // unlimited 491 MaxBufferedPagesTotal: 0, // unlimited 492} 493 494// AssemblerOptions controls the behavior of each assembler. Modify the 495// options of each assembler you create to change their behavior. 496type AssemblerOptions struct { 497 // MaxBufferedPagesTotal is an upper limit on the total number of pages to 498 // buffer while waiting for out-of-order packets. Once this limit is 499 // reached, the assembler will degrade to flushing every connection it 500 // gets a packet for. If <= 0, this is ignored. 501 MaxBufferedPagesTotal int 502 // MaxBufferedPagesPerConnection is an upper limit on the number of pages 503 // buffered for a single connection. Should this limit be reached for a 504 // particular connection, the smallest sequence number will be flushed, along 505 // with any contiguous data. If <= 0, this is ignored. 506 MaxBufferedPagesPerConnection int 507} 508 509// Assembler handles reassembling TCP streams. It is not safe for 510// concurrency... after passing a packet in via the Assemble call, the caller 511// must wait for that call to return before calling Assemble again. Callers can 512// get around this by creating multiple assemblers that share a StreamPool. In 513// that case, each individual stream will still be handled serially (each stream 514// has an individual mutex associated with it), however multiple assemblers can 515// assemble different connections concurrently. 
//
// The Assembler provides (hopefully) fast TCP stream re-assembly for sniffing
// applications written in Go.  The Assembler uses the following methods to be
// as fast as possible, to keep packet processing speedy:
//
// Avoids Lock Contention
//
// Assemblers lock connections, but each connection has an individual lock, and
// rarely will two Assemblers be looking at the same connection.  Assemblers
// lock the StreamPool when looking up connections, but they use Reader
// locks initially, and only force a write lock if they need to create a new
// connection or close one down.  These happen much less frequently than
// individual packet handling.
//
// Each assembler runs in its own goroutine, and the only state shared between
// goroutines is through the StreamPool.  Thus all internal Assembler state
// can be handled without any locking.
//
// NOTE:  If you can guarantee that packets going to a set of Assemblers will
// contain information on different connections per Assembler (for example,
// they're already hashed by PF_RING hashing or some other hashing mechanism),
// then we recommend you use a separate StreamPool per Assembler, thus
// avoiding all lock contention.  Only when different Assemblers could receive
// packets for the same Stream should a StreamPool be shared between them.
//
// Avoids Memory Copying
//
// In the common case, handling of a single TCP packet should result in zero
// memory allocations.  The Assembler will look up the connection, figure out
// that the packet has arrived in order, and immediately pass that packet on to
// the appropriate connection's handling code.  Only if a packet arrives out of
// order is its contents copied and stored in memory for later.
//
// Avoids Memory Allocation
//
// Assemblers try very hard to not use memory allocation unless absolutely
// necessary.  Packet data for sequential packets is passed directly to streams
// with no copying or allocation.  Packet data for out-of-order packets is
// copied into reusable pages, and new pages are only allocated rarely when the
// page cache runs out.  Page caches are Assembler-specific, thus not used
// concurrently and requiring no locking.
//
// Internal representations for connection objects are also reused over time.
// Because of this, the most common memory allocation done by the Assembler is
// generally what's done by the caller in StreamFactory.New.  If no allocation
// is done there, then very little allocation is done ever, mostly to handle
// large increases in bandwidth or numbers of connections.
//
// TODO:  The page caches used by an Assembler will grow to the size necessary
// to handle a workload, and currently will never shrink.  This means that
// traffic spikes can result in large memory usage which isn't garbage
// collected when typical traffic levels return.
type Assembler struct {
	AssemblerOptions
	// ret collects the chunks to hand to the Stream for the current packet;
	// it is truncated at the start of every AssembleWithContext call.
	ret []byteContainer
	// pc is this Assembler's own page cache.
	pc *pageCache
	// connPool is the StreamPool, possibly shared with other Assemblers.
	connPool *StreamPool
	// cacheLP is reused to describe the packet currently being assembled.
	cacheLP livePacket
	// cacheSG is the ScatterGather reused for every ReassembledSG call.
	cacheSG reassemblyObject
	// start is true when the current packet starts its half connection; it
	// is passed by pointer to Stream.Accept, which may modify it.
	start bool
}

// NewAssembler creates a new assembler.  Pass in the StreamPool
// to use, may be shared across assemblers.
//
// This sets some sane defaults for the assembler options,
// see DefaultAssemblerOptions for details.
583func NewAssembler(pool *StreamPool) *Assembler { 584 pool.mu.Lock() 585 pool.users++ 586 pool.mu.Unlock() 587 return &Assembler{ 588 ret: make([]byteContainer, 0, assemblerReturnValueInitialSize), 589 pc: newPageCache(), 590 connPool: pool, 591 AssemblerOptions: DefaultAssemblerOptions, 592 } 593} 594 595// Dump returns a short string describing the page usage of the Assembler 596func (a *Assembler) Dump() string { 597 s := "" 598 s += fmt.Sprintf("pageCache: used: %d, size: %d, free: %d", a.pc.used, a.pc.size, len(a.pc.free)) 599 return s 600} 601 602// AssemblerContext provides method to get metadata 603type AssemblerContext interface { 604 GetCaptureInfo() gopacket.CaptureInfo 605} 606 607// Implements AssemblerContext for Assemble() 608type assemblerSimpleContext gopacket.CaptureInfo 609 610func (asc *assemblerSimpleContext) GetCaptureInfo() gopacket.CaptureInfo { 611 return gopacket.CaptureInfo(*asc) 612} 613 614// Assemble calls AssembleWithContext with the current timestamp, useful for 615// packets being read directly off the wire. 616func (a *Assembler) Assemble(netFlow gopacket.Flow, t *layers.TCP) { 617 ctx := assemblerSimpleContext(gopacket.CaptureInfo{Timestamp: time.Now()}) 618 a.AssembleWithContext(netFlow, t, &ctx) 619} 620 621type assemblerAction struct { 622 nextSeq Sequence 623 queue bool 624} 625 626// AssembleWithContext reassembles the given TCP packet into its appropriate 627// stream. 628// 629// The timestamp passed in must be the timestamp the packet was seen. 630// For packets read off the wire, time.Now() should be fine. For packets read 631// from PCAP files, CaptureInfo.Timestamp should be passed in. This timestamp 632// will affect which streams are flushed by a call to FlushCloseOlderThan. 
//
// Each AssembleWithContext call results in, in order:
//
//	zero or one call to StreamFactory.New, creating a stream
//	zero or one call to ReassembledSG on a single stream
//	zero or one call to ReassemblyComplete on the same stream
//
// Outline: look up the connection, let Stream.Accept veto the packet, then
// decide whether the payload is in order (deliver now) or out of order
// (queue as pages), deliver whatever became contiguous, and finally advance
// half.nextSeq (plus one for a FIN).
func (a *Assembler) AssembleWithContext(netFlow gopacket.Flow, t *layers.TCP, ac AssemblerContext) {
	var conn *connection
	var half *halfconnection
	var rev *halfconnection

	a.ret = a.ret[:0]
	key := key{netFlow, t.TransportFlow()}
	ci := ac.GetCaptureInfo()
	timestamp := ci.Timestamp

	// getConnection is defined elsewhere; it returns the connection plus the
	// half for this packet's direction and the reverse half (nil conn means
	// nothing to do).
	conn, half, rev = a.connPool.getConnection(key, false, timestamp, t, ac)
	if conn == nil {
		if *debugLog {
			log.Printf("%v got empty packet on otherwise empty connection", key)
		}
		return
	}
	conn.mu.Lock()
	defer conn.mu.Unlock()
	if half.lastSeen.Before(timestamp) {
		half.lastSeen = timestamp
	}
	// A SYN on a half whose sequence is still unknown starts the stream;
	// Accept may override this through the pointer.
	a.start = half.nextSeq == invalidSequence && t.SYN
	if *debugLog {
		// NOTE(review): plain '<' is not wrap-safe and is always true while
		// nextSeq is invalidSequence; this only affects debug output.
		if half.nextSeq < rev.ackSeq {
			log.Printf("Delay detected on %v, data is acked but not assembled yet (acked %v, nextSeq %v)", key, rev.ackSeq, half.nextSeq)
		}
	}

	if !half.stream.Accept(t, ci, half.dir, half.nextSeq, &a.start, ac) {
		if *debugLog {
			log.Printf("Ignoring packet")
		}
		return
	}
	if half.closed {
		// this way is closed
		if *debugLog {
			log.Printf("%v got packet on closed half", key)
		}
		return
	}

	seq, ack, bytes := Sequence(t.Seq), Sequence(t.Ack), t.Payload
	if t.ACK {
		half.ackSeq = ack
	}
	// TODO: push when Ack is seen ??
	// Default: queue the data and leave nextSeq unchanged.
	action := assemblerAction{
		nextSeq: Sequence(invalidSequence),
		queue:   true,
	}
	a.dump("AssembleWithContext()", half)
	if half.nextSeq == invalidSequence {
		// Stream start not seen yet.
		if t.SYN {
			if *debugLog {
				log.Printf("%v saw first SYN packet, returning immediately, seq=%v", key, seq)
			}
			// The SYN consumes one sequence number.
			seq = seq.Add(1)
			half.nextSeq = seq
			action.queue = false
		} else if a.start {
			if *debugLog {
				log.Printf("%v start forced", key)
			}
			half.nextSeq = seq
			action.queue = false
		} else {
			if *debugLog {
				log.Printf("%v waiting for start, storing into connection", key)
			}
		}
	} else {
		// diff > 0: the packet starts after nextSeq, i.e. there is a hole.
		diff := half.nextSeq.Difference(seq)
		if diff > 0 {
			if *debugLog {
				log.Printf("%v gap in sequence numbers (%v, %v) diff %v, storing into connection", key, half.nextSeq, seq, diff)
			}
		} else {
			if *debugLog {
				log.Printf("%v found contiguous data (%v, %v), returning immediately: len:%d", key, seq, half.nextSeq, len(bytes))
			}
			action.queue = false
		}
	}

	action = a.handleBytes(bytes, seq, half, ci, t.SYN, t.RST || t.FIN, action, ac)
	if len(a.ret) > 0 {
		action.nextSeq = a.sendToConnection(conn, half, ac)
	}
	if action.nextSeq != invalidSequence {
		half.nextSeq = action.nextSeq
		// The FIN consumes one sequence number.
		if t.FIN {
			half.nextSeq = half.nextSeq.Add(1)
		}
	}
	if *debugLog {
		log.Printf("%v nextSeq:%d", key, half.nextSeq)
	}
}

// Overlap strategies:
//  - new packet overlaps with sent packets:
//	1) discard new overlapping part
//	2) overwrite old overlapped (TODO)
//  - new packet overlaps existing queued packets:
//	a) consider "age" by timestamp (TODO)
//	b) consider "age" by being present
//	Then
//	1) discard new overlapping part
//	2) overwrite queued part

// checkOverlap reconciles the data in a.cacheLP (range [start, end)) with the
// out-of-order pages queued in half, walking the list backwards from
// half.last:
//   - case 5: the page starts at or after end: keep walking;
//   - case 1: the page ends at or before start: stop;
//   - case 3: the page lies inside [start, end): unlink and release it
//     (counted in the overlap stats);
//   - case 2: the page overlaps start: truncate the page's tail, then stop;
//   - case 4: the page overlaps end: trim the page's head;
//   - case 6: the new data lies inside the page: copy it over the queued
//     bytes and drop it from the new data.
//
// Afterwards a.cacheLP.bytes/seq hold what is left of the new data. If queue
// is true and bytes remain, they are converted to pages and linked between
// cur and next in half's list (updating half.first/half.last as needed).
func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerContext) {
	var next *page
	cur := half.last
	bytes := a.cacheLP.bytes
	start := a.cacheLP.seq
	end := start.Add(len(bytes))

	a.dump("before checkOverlap", half)

	//          [s6           :           e6]
	//   [s1:e1][s2:e2] -- [s3:e3] -- [s4:e4][s5:e5]
	//             [s <--ds-- : --de--> e]
	for cur != nil {

		if *debugLog {
			log.Printf("cur = %p (%s)\n", cur, cur)
		}

		// end < cur.start: continue (5)
		if end.Difference(cur.seq) > 0 {
			if *debugLog {
				log.Printf("case 5\n")
			}
			next = cur
			cur = cur.prev
			continue
		}

		curEnd := cur.seq.Add(len(cur.bytes))
		// start > cur.end: stop (1)
		if start.Difference(curEnd) <= 0 {
			if *debugLog {
				log.Printf("case 1\n")
			}
			break
		}

		diffStart := start.Difference(cur.seq)
		diffEnd := end.Difference(curEnd)

		// end > cur.end && start < cur.start: drop (3)
		if diffEnd <= 0 && diffStart >= 0 {
			if *debugLog {
				log.Printf("case 3\n")
			}
			if cur.isPacket() {
				half.overlapPackets++
			}
			half.overlapBytes += len(cur.bytes)
			// update links
			if cur.prev != nil {
				cur.prev.next = cur.next
			} else {
				half.first = cur.next
			}
			if cur.next != nil {
				cur.next.prev = cur.prev
			} else {
				half.last = cur.prev
			}
			tmp := cur.prev
			half.pages -= cur.release(a.pc)
			cur = tmp
			continue
		}

		// end > cur.end && start < cur.end: drop cur's end (2)
		if diffEnd < 0 && start.Difference(curEnd) > 0 {
			if *debugLog {
				log.Printf("case 2\n")
			}
			// Keep only the bytes of cur that precede start.
			cur.bytes = cur.bytes[:-start.Difference(cur.seq)]
			break
		} else if diffStart > 0 && end.Difference(cur.seq) < 0 {
			// start < cur.start && end > cur.start: drop cur's start (4)
			if *debugLog {
				log.Printf("case 4\n")
			}
			// Drop the bytes of cur that precede end; cur now starts at end.
			cur.bytes = cur.bytes[-end.Difference(cur.seq):]
			cur.seq = cur.seq.Add(-end.Difference(cur.seq))
			next = cur
		} else if diffEnd > 0 && diffStart < 0 {
			// end < cur.end && start > cur.start: replace bytes inside cur (6)
			if *debugLog {
				log.Printf("case 6\n")
			}
			copy(cur.bytes[-diffStart:-diffStart+len(bytes)], bytes)
			bytes = bytes[:0]
		} else {
			if *debugLog {
				log.Printf("no overlap\n")
			}
			next = cur
		}
		cur = cur.prev
	}

	// Split bytes into pages, and insert in queue
	a.cacheLP.bytes = bytes
	a.cacheLP.seq = start
	if len(bytes) > 0 && queue {
		p, p2, numPages := a.cacheLP.convertToPages(a.pc, 0, ac)
		half.queuedPackets++
		half.queuedBytes += len(bytes)
		half.pages += numPages
		if cur != nil {
			if *debugLog {
				log.Printf("adding %s after %s", p, cur)
			}
			cur.next = p
			p.prev = cur
		} else {
			if *debugLog {
				log.Printf("adding %s as first", p)
			}
			half.first = p
		}
		if next != nil {
			if *debugLog {
				log.Printf("setting %s as next of new %s", next, p2)
			}
			p2.next = next
			next.prev = p2
		} else {
			if *debugLog {
				log.Printf("setting %s as last", p2)
			}
			half.last = p2
		}
	}
	a.dump("After checkOverlap", half)
}

// dump logs, when debug logging is enabled, the queued and saved pages of
// half together with a.ret and a.cacheSG.all.
// Warning: this is a low-level dumper, i.e. a.ret or a.cacheSG might
// be strange, but it could be ok.
890func (a *Assembler) dump(text string, half *halfconnection) { 891 if !*debugLog { 892 return 893 } 894 log.Printf("%s: dump\n", text) 895 if half != nil { 896 p := half.first 897 if p == nil { 898 log.Printf(" * half.first = %p, no chunks queued\n", p) 899 } else { 900 s := 0 901 nb := 0 902 log.Printf(" * half.first = %p, queued chunks:", p) 903 for p != nil { 904 log.Printf("\t%s bytes:%s\n", p, hex.EncodeToString(p.bytes)) 905 s += len(p.bytes) 906 nb++ 907 p = p.next 908 } 909 log.Printf("\t%d chunks for %d bytes", nb, s) 910 } 911 log.Printf(" * half.last = %p\n", half.last) 912 log.Printf(" * half.saved = %p\n", half.saved) 913 p = half.saved 914 for p != nil { 915 log.Printf("\tseq:%d %s bytes:%s\n", p.getSeq(), p, hex.EncodeToString(p.bytes)) 916 p = p.next 917 } 918 } 919 log.Printf(" * a.ret\n") 920 for i, r := range a.ret { 921 log.Printf("\t%d: %v b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes())) 922 } 923 log.Printf(" * a.cacheSG.all\n") 924 for i, r := range a.cacheSG.all { 925 log.Printf("\t%d: %v b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes())) 926 } 927} 928 929func (a *Assembler) overlapExisting(half *halfconnection, start, end Sequence, bytes []byte) ([]byte, Sequence) { 930 if half.nextSeq == invalidSequence { 931 // no start yet 932 return bytes, start 933 } 934 diff := start.Difference(half.nextSeq) 935 if diff == 0 { 936 return bytes, start 937 } 938 s := 0 939 e := len(bytes) 940 // TODO: depending on strategy, we might want to shrink half.saved if possible 941 if e != 0 { 942 if *debugLog { 943 log.Printf("Overlap detected: ignoring current packet's first %d bytes", diff) 944 } 945 half.overlapPackets++ 946 half.overlapBytes += diff 947 } 948 s += diff 949 if s >= e { 950 // Completely included in sent 951 s = e 952 } 953 bytes = bytes[s:] 954 return bytes, half.nextSeq 955} 956 957// Prepare send or queue 958func (a *Assembler) handleBytes(bytes []byte, seq Sequence, half *halfconnection, ci 
gopacket.CaptureInfo, start bool, end bool, action assemblerAction, ac AssemblerContext) assemblerAction { 959 a.cacheLP.bytes = bytes 960 a.cacheLP.start = start 961 a.cacheLP.end = end 962 a.cacheLP.seq = seq 963 a.cacheLP.ci = ci 964 a.cacheLP.ac = ac 965 966 if action.queue { 967 a.checkOverlap(half, true, ac) 968 if (a.MaxBufferedPagesPerConnection > 0 && half.pages >= a.MaxBufferedPagesPerConnection) || 969 (a.MaxBufferedPagesTotal > 0 && a.pc.used >= a.MaxBufferedPagesTotal) { 970 if *debugLog { 971 log.Printf("hit max buffer size: %+v, %v, %v", a.AssemblerOptions, half.pages, a.pc.used) 972 } 973 action.queue = false 974 a.addNextFromConn(half) 975 } 976 a.dump("handleBytes after queue", half) 977 } else { 978 a.cacheLP.bytes, a.cacheLP.seq = a.overlapExisting(half, seq, seq.Add(len(bytes)), a.cacheLP.bytes) 979 a.checkOverlap(half, false, ac) 980 if len(a.cacheLP.bytes) != 0 || end || start { 981 a.ret = append(a.ret, &a.cacheLP) 982 } 983 a.dump("handleBytes after no queue", half) 984 } 985 return action 986} 987 988func (a *Assembler) setStatsToSG(half *halfconnection) { 989 a.cacheSG.queuedBytes = half.queuedBytes 990 half.queuedBytes = 0 991 a.cacheSG.queuedPackets = half.queuedPackets 992 half.queuedPackets = 0 993 a.cacheSG.overlapBytes = half.overlapBytes 994 half.overlapBytes = 0 995 a.cacheSG.overlapPackets = half.overlapPackets 996 half.overlapPackets = 0 997} 998 999// Build the ScatterGather object, i.e. prepend saved bytes and 1000// append continuous bytes. 
1001func (a *Assembler) buildSG(half *halfconnection) (bool, Sequence) { 1002 // find if there are skipped bytes 1003 skip := -1 1004 if half.nextSeq != invalidSequence { 1005 skip = half.nextSeq.Difference(a.ret[0].getSeq()) 1006 } 1007 last := a.ret[0].getSeq().Add(a.ret[0].length()) 1008 // Prepend saved bytes 1009 saved := a.addPending(half, a.ret[0].getSeq()) 1010 // Append continuous bytes 1011 nextSeq := a.addContiguous(half, last) 1012 a.cacheSG.all = a.ret 1013 a.cacheSG.Direction = half.dir 1014 a.cacheSG.Skip = skip 1015 a.cacheSG.saved = saved 1016 a.cacheSG.toKeep = -1 1017 a.setStatsToSG(half) 1018 a.dump("after buildSG", half) 1019 return a.ret[len(a.ret)-1].isEnd(), nextSeq 1020} 1021 1022func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) { 1023 cur := 0 1024 ndx := 0 1025 skip := 0 1026 1027 a.dump("cleanSG(start)", half) 1028 1029 var r byteContainer 1030 // Find first page to keep 1031 if a.cacheSG.toKeep < 0 { 1032 ndx = len(a.cacheSG.all) 1033 } else { 1034 skip = a.cacheSG.toKeep 1035 found := false 1036 for ndx, r = range a.cacheSG.all { 1037 if a.cacheSG.toKeep < cur+r.length() { 1038 found = true 1039 break 1040 } 1041 cur += r.length() 1042 if skip >= r.length() { 1043 skip -= r.length() 1044 } 1045 } 1046 if !found { 1047 ndx++ 1048 } 1049 } 1050 // Release consumed pages 1051 for _, r := range a.cacheSG.all[:ndx] { 1052 if r == half.saved { 1053 if half.saved.next != nil { 1054 half.saved.next.prev = nil 1055 } 1056 half.saved = half.saved.next 1057 } else if r == half.first { 1058 if half.first.next != nil { 1059 half.first.next.prev = nil 1060 } 1061 if half.first == half.last { 1062 half.first, half.last = nil, nil 1063 } else { 1064 half.first = half.first.next 1065 } 1066 } 1067 half.pages -= r.release(a.pc) 1068 } 1069 a.dump("after consumed release", half) 1070 // Keep un-consumed pages 1071 nbKept := 0 1072 half.saved = nil 1073 var saved *page 1074 for _, r := range a.cacheSG.all[ndx:] { 1075 first, last, nb 
:= r.convertToPages(a.pc, skip, ac) 1076 if half.saved == nil { 1077 half.saved = first 1078 } else { 1079 saved.next = first 1080 first.prev = saved 1081 } 1082 saved = last 1083 nbKept += nb 1084 } 1085 if *debugLog { 1086 log.Printf("Remaining %d chunks in SG\n", nbKept) 1087 log.Printf("%s\n", a.Dump()) 1088 a.dump("after cleanSG()", half) 1089 } 1090} 1091 1092// sendToConnection sends the current values in a.ret to the connection, closing 1093// the connection if the last thing sent had End set. 1094func (a *Assembler) sendToConnection(conn *connection, half *halfconnection, ac AssemblerContext) Sequence { 1095 if *debugLog { 1096 log.Printf("sendToConnection\n") 1097 } 1098 end, nextSeq := a.buildSG(half) 1099 half.stream.ReassembledSG(&a.cacheSG, ac) 1100 a.cleanSG(half, ac) 1101 if end { 1102 a.closeHalfConnection(conn, half) 1103 } 1104 if *debugLog { 1105 log.Printf("after sendToConnection: nextSeq: %d\n", nextSeq) 1106 } 1107 return nextSeq 1108} 1109 1110// 1111func (a *Assembler) addPending(half *halfconnection, firstSeq Sequence) int { 1112 if half.saved == nil { 1113 return 0 1114 } 1115 s := 0 1116 ret := []byteContainer{} 1117 for p := half.saved; p != nil; p = p.next { 1118 if *debugLog { 1119 log.Printf("adding pending @%p %s (%s)\n", p, p, hex.EncodeToString(p.bytes)) 1120 } 1121 ret = append(ret, p) 1122 s += len(p.bytes) 1123 } 1124 if half.saved.seq.Add(s) != firstSeq { 1125 // non-continuous saved: drop them 1126 var next *page 1127 for p := half.saved; p != nil; p = next { 1128 next = p.next 1129 p.release(a.pc) 1130 } 1131 half.saved = nil 1132 ret = []byteContainer{} 1133 s = 0 1134 } 1135 1136 a.ret = append(ret, a.ret...) 1137 return s 1138} 1139 1140// addContiguous adds contiguous byte-sets to a connection. 
1141func (a *Assembler) addContiguous(half *halfconnection, lastSeq Sequence) Sequence { 1142 page := half.first 1143 if page == nil { 1144 if *debugLog { 1145 log.Printf("addContiguous(%d): no pages\n", lastSeq) 1146 } 1147 return lastSeq 1148 } 1149 if lastSeq == invalidSequence { 1150 lastSeq = page.seq 1151 } 1152 for page != nil && lastSeq.Difference(page.seq) == 0 { 1153 if *debugLog { 1154 log.Printf("addContiguous: lastSeq: %d, first.seq=%d, page.seq=%d\n", half.nextSeq, half.first.seq, page.seq) 1155 } 1156 lastSeq = lastSeq.Add(len(page.bytes)) 1157 a.ret = append(a.ret, page) 1158 half.first = page.next 1159 if half.first == nil { 1160 half.last = nil 1161 } 1162 if page.next != nil { 1163 page.next.prev = nil 1164 } 1165 page = page.next 1166 } 1167 return lastSeq 1168} 1169 1170// skipFlush skips the first set of bytes we're waiting for and returns the 1171// first set of bytes we have. If we have no bytes saved, it closes the 1172// connection. 1173func (a *Assembler) skipFlush(conn *connection, half *halfconnection) { 1174 if *debugLog { 1175 log.Printf("skipFlush %v\n", half.nextSeq) 1176 } 1177 // Well, it's embarassing it there is still something in half.saved 1178 // FIXME: change API to give back saved + new/no packets 1179 if half.first == nil { 1180 a.closeHalfConnection(conn, half) 1181 return 1182 } 1183 a.ret = a.ret[:0] 1184 a.addNextFromConn(half) 1185 nextSeq := a.sendToConnection(conn, half, a.ret[0].assemblerContext()) 1186 if nextSeq != invalidSequence { 1187 half.nextSeq = nextSeq 1188 } 1189} 1190 1191func (a *Assembler) closeHalfConnection(conn *connection, half *halfconnection) { 1192 if *debugLog { 1193 log.Printf("%v closing", conn) 1194 } 1195 half.closed = true 1196 for p := half.first; p != nil; p = p.next { 1197 // FIXME: it should be already empty 1198 a.pc.replace(p) 1199 half.pages-- 1200 } 1201 if conn.s2c.closed && conn.c2s.closed { 1202 if half.stream.ReassemblyComplete(nil) { //FIXME: which context to pass ? 
1203 a.connPool.remove(conn) 1204 } 1205 } 1206} 1207 1208// addNextFromConn pops the first page from a connection off and adds it to the 1209// return array. 1210func (a *Assembler) addNextFromConn(conn *halfconnection) { 1211 if conn.first == nil { 1212 return 1213 } 1214 if *debugLog { 1215 log.Printf(" adding from conn (%v, %v) %v (%d)\n", conn.first.seq, conn.nextSeq, conn.nextSeq-conn.first.seq, len(conn.first.bytes)) 1216 } 1217 a.ret = append(a.ret, conn.first) 1218 conn.first = conn.first.next 1219 if conn.first != nil { 1220 conn.first.prev = nil 1221 } else { 1222 conn.last = nil 1223 } 1224} 1225 1226// FlushOptions provide options for flushing connections. 1227type FlushOptions struct { 1228 T time.Time // If nonzero, only connections with data older than T are flushed 1229 TC time.Time // If nonzero, only connections with data older than TC are closed (if no FIN/RST received) 1230} 1231 1232// FlushWithOptions finds any streams waiting for packets older than 1233// the given time T, and pushes through the data they have (IE: tells 1234// them to stop waiting and skip the data they're waiting for). 1235// 1236// It also closes streams older than TC (that can be set to zero, to keep 1237// long-lived stream alive, but to flush data anyway). 1238// 1239// Each Stream maintains a list of zero or more sets of bytes it has received 1240// out-of-order. For example, if it has processed up through sequence number 1241// 10, it might have bytes [15-20), [20-25), [30,50) in its list. Each set of 1242// bytes also has the timestamp it was originally viewed. A flush call will 1243// look at the smallest subsequent set of bytes, in this case [15-20), and if 1244// its timestamp is older than the passed-in time, it will push it and all 1245// contiguous byte-sets out to the Stream's Reassembled function. In this case, 1246// it will push [15-20), but also [20-25), since that's contiguous. 
It will 1247// only push [30-50) if its timestamp is also older than the passed-in time, 1248// otherwise it will wait until the next FlushCloseOlderThan to see if bytes 1249// [25-30) come in. 1250// 1251// Returns the number of connections flushed, and of those, the number closed 1252// because of the flush. 1253func (a *Assembler) FlushWithOptions(opt FlushOptions) (flushed, closed int) { 1254 conns := a.connPool.connections() 1255 closes := 0 1256 flushes := 0 1257 for _, conn := range conns { 1258 remove := false 1259 conn.mu.Lock() 1260 for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} { 1261 flushed, closed := a.flushClose(conn, half, opt.T, opt.TC) 1262 if flushed { 1263 flushes++ 1264 } 1265 if closed { 1266 closes++ 1267 } 1268 } 1269 if conn.s2c.closed && conn.c2s.closed && conn.s2c.lastSeen.Before(opt.TC) && conn.c2s.lastSeen.Before(opt.TC) { 1270 remove = true 1271 } 1272 conn.mu.Unlock() 1273 if remove { 1274 a.connPool.remove(conn) 1275 } 1276 } 1277 return flushes, closes 1278} 1279 1280// FlushCloseOlderThan flushes and closes streams older than given time 1281func (a *Assembler) FlushCloseOlderThan(t time.Time) (flushed, closed int) { 1282 return a.FlushWithOptions(FlushOptions{T: t, TC: t}) 1283} 1284 1285func (a *Assembler) flushClose(conn *connection, half *halfconnection, t time.Time, tc time.Time) (bool, bool) { 1286 flushed, closed := false, false 1287 if half.closed { 1288 return flushed, closed 1289 } 1290 for half.first != nil && half.first.seen.Before(t) { 1291 flushed = true 1292 a.skipFlush(conn, half) 1293 if half.closed { 1294 closed = true 1295 return flushed, closed 1296 } 1297 } 1298 // Close the connection only if both halfs of the connection last seen before tc. 
1299 if !half.closed && half.first == nil && conn.lastSeen().Before(tc) { 1300 a.closeHalfConnection(conn, half) 1301 closed = true 1302 } 1303 return flushed, closed 1304} 1305 1306// FlushAll flushes all remaining data into all remaining connections and closes 1307// those connections. It returns the total number of connections flushed/closed 1308// by the call. 1309func (a *Assembler) FlushAll() (closed int) { 1310 conns := a.connPool.connections() 1311 closed = len(conns) 1312 for _, conn := range conns { 1313 conn.mu.Lock() 1314 for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} { 1315 for !half.closed { 1316 a.skipFlush(conn, half) 1317 } 1318 if !half.closed { 1319 a.closeHalfConnection(conn, half) 1320 } 1321 } 1322 conn.mu.Unlock() 1323 } 1324 return 1325} 1326 1327func min(a, b int) int { 1328 if a < b { 1329 return a 1330 } 1331 return b 1332} 1333