1// Copyright (C) MongoDB, Inc. 2014-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongoreplay
8
9import (
10	"container/heap"
11	"fmt"
12	"sync/atomic"
13	"time"
14
15	"github.com/google/gopacket"
16	"github.com/google/gopacket/tcpassembly"
17)
18
// OpStreamSettings stores settings for any command which may listen to an
// opstream.
//
// Each field's struct tag carries the command-line flag name(s) and the help
// text for that option; the descriptions in the tags are the authoritative
// documentation for the individual settings.
// NOTE(review): the tag layout (short/long/description) looks like it is
// consumed by a flag-parsing library elsewhere in the project — confirm.
type OpStreamSettings struct {
	PcapFile         string `short:"f" description:"path to the pcap file to be read"`
	PacketBufSize    int    `short:"b" description:"Size of heap used to merge separate streams together"`
	CaptureBufSize   int    `long:"capSize" description:"Size in KiB of the PCAP capture buffer"`
	Expression       string `short:"e" long:"expr" description:"BPF filter expression to apply to packets"`
	NetworkInterface string `short:"i" description:"network interface to listen on"`
	MaxBufferedPages int    `long:"maxBufferedPages" description:"maximum number of memory pages to store when buffering packets. The cache size is unlimited if not set"`
}
29
// stream is the tcpassembly.Stream implementation. It represents one
// direction of a bidirectional connection (see bidi) and holds the parsing
// state for the protocol messages travelling in that direction.
type stream struct {
	// bidi is the connection this stream belongs to; bidi.close sets it to nil.
	bidi             *bidi
	// reassembled hands batches of reassembled data from Reassembled to
	// bidi.streamOps.
	reassembled      chan []tcpassembly.Reassembly
	// reassembly is the Reassembly currently being parsed by bidi.streamOps;
	// the state handlers consume its Bytes as they go.
	reassembly       tcpassembly.Reassembly
	// done is signalled by bidi.streamOps once it has finished with a batch,
	// which unblocks Reassembled.
	done             chan interface{}
	// op is the raw op currently being accumulated from the wire.
	op               *RawOp
	// opTimeStamp is the Seen time of the reassembly in which the most
	// recently started message began.
	opTimeStamp      time.Time
	// state is the parser state that selects which handler bidi.streamOps runs.
	state            streamState
	// netFlow and tcpFlow identify this direction of the connection; netFlow
	// supplies the source and destination endpoints recorded on each op.
	netFlow, tcpFlow gopacket.Flow
}
41
// Reassembled receives a new slice of reassembled data and forwards it to the
// bidi.streamOps goroutine, which turns it into protocol messages.
// Since the tcpassembler reuses the tcpassembly.Reassembly buffers, we wait
// for streamOps to signal us that it's done with them before returning.
func (stream *stream) Reassembled(reassembly []tcpassembly.Reassembly) {
	// Hand the batch over, then block until streamOps has consumed it.
	stream.reassembled <- reassembly
	<-stream.done
}
51
52// ReassemblyComplete receives from the tcpassembler the fact that the stream is
53// now finished. Because our streamOps function may be reading from more then
54// one stream, we only shut down the bidi once all the streams are finished.
55func (stream *stream) ReassemblyComplete() {
56
57	count := atomic.AddInt32(&stream.bidi.openStreamCount, -1)
58	if count < 0 {
59		panic("negative openStreamCount")
60	}
61	if count == 0 {
62		stream.bidi.handleStreamCompleted()
63		stream.bidi.close()
64	}
65}
66
// Supporting types for the tcpassembly.StreamFactory implementation
// (MongoOpStream.New).

// bidi is a bidirectional connection: it pairs the two stream values that
// carry the opposite directions of one TCP connection.
type bidi struct {
	// streams holds both directions. streams[0] is returned by
	// MongoOpStream.New for the direction seen first, streams[1] for the
	// reverse direction.
	streams          [2]*stream
	// openStreamCount is updated atomically: incremented in
	// MongoOpStream.New and decremented in stream.ReassemblyComplete. The
	// bidi is closed when it reaches zero.
	openStreamCount  int32
	// opStream is the MongoOpStream that completed ops are sent to.
	opStream         *MongoOpStream
	// NOTE(review): responseStream is not referenced in this part of the
	// file — confirm whether it is still used.
	responseStream   bool
	// sawStart records that the start of the connection has been observed;
	// observing it a second time triggers a panic.
	sawStart         bool
	// connectionNumber is the number assigned to this connection when it was
	// created; it is used in log output and recorded on each op.
	connectionNumber int64
}
78
79func newBidi(netFlow, tcpFlow gopacket.Flow, opStream *MongoOpStream, num int64) *bidi {
80	bidi := &bidi{connectionNumber: num}
81	bidi.streams[0] = &stream{
82		bidi:        bidi,
83		reassembled: make(chan []tcpassembly.Reassembly),
84		done:        make(chan interface{}),
85		op:          &RawOp{},
86		netFlow:     netFlow,
87		tcpFlow:     tcpFlow,
88	}
89	bidi.streams[1] = &stream{
90		bidi:        bidi,
91		reassembled: make(chan []tcpassembly.Reassembly),
92		done:        make(chan interface{}),
93		op:          &RawOp{},
94		netFlow:     netFlow.Reverse(),
95		tcpFlow:     tcpFlow.Reverse(),
96	}
97	bidi.opStream = opStream
98	return bidi
99}
100
101func (bidi *bidi) logvf(minVerb int, format string, a ...interface{}) {
102	userInfoLogger.Logvf(minVerb, "stream %v %v", bidi.connectionNumber, fmt.Sprintf(format, a...))
103}
104
105// close closes the channels used to communicate between the
106// stream's and bidi.streamOps,
107// and removes the bidi from the bidiMap.
108func (bidi *bidi) close() {
109	close(bidi.streams[0].reassembled)
110	close(bidi.streams[0].done)
111	close(bidi.streams[1].reassembled)
112	close(bidi.streams[1].done)
113	delete(bidi.opStream.bidiMap, bidiKey{bidi.streams[1].netFlow, bidi.streams[1].tcpFlow})
114	delete(bidi.opStream.bidiMap, bidiKey{bidi.streams[0].netFlow, bidi.streams[0].tcpFlow})
115	// probably not important, just trying to make the garbage collection easier.
116	bidi.streams[0].bidi = nil
117	bidi.streams[1].bidi = nil
118}
119
// bidiKey is the key type of MongoOpStream.bidiMap. It identifies one
// direction of a connection by its network-layer and transport-layer flows.
type bidiKey struct {
	net, transport gopacket.Flow
}
123
// MongoOpStream is the opstream which yields RecordedOps
type MongoOpStream struct {
	// Ops delivers the recorded ops after they have passed through the
	// ordering heap. handleOps closes it once everything has been emitted.
	Ops chan *RecordedOp

	// FirstSeen is whatever was last passed to SetFirstSeen.
	FirstSeen         time.Time
	// unorderedOps receives ops as the streams complete them; handleOps
	// drains it and Close closes it.
	unorderedOps      chan RecordedOp
	// opHeap is the heap handleOps uses to buffer and order ops before they
	// are sent on Ops.
	opHeap            *orderedOps
	// bidiMap holds connections for which only one direction has been seen
	// so far, keyed by the flows of the direction still expected.
	bidiMap           map[bidiKey]*bidi
	// connectionCounter supplies increasing connection numbers; it is fed by
	// a goroutine started in NewMongoOpStream and read in New.
	connectionCounter chan int64
	// NOTE(review): connectionNumber is not referenced in this part of the
	// file — confirm whether it is still used.
	connectionNumber  int64
}
135
136// NewMongoOpStream initializes a new MongoOpStream
137func NewMongoOpStream(heapBufSize int) *MongoOpStream {
138	h := make(orderedOps, 0, heapBufSize)
139	os := &MongoOpStream{
140		Ops:               make(chan *RecordedOp), // ordered
141		unorderedOps:      make(chan RecordedOp),  // unordered
142		opHeap:            &h,
143		bidiMap:           make(map[bidiKey]*bidi),
144		connectionCounter: make(chan int64, 1024),
145	}
146	heap.Init(os.opHeap)
147	go func() {
148		var counter int64
149		for {
150			os.connectionCounter <- counter
151			counter++
152		}
153	}()
154	go os.handleOps()
155	return os
156}
157
158// New is the factory method called by the tcpassembly to generate new tcpassembly.Stream.
159func (os *MongoOpStream) New(netFlow, tcpFlow gopacket.Flow) tcpassembly.Stream {
160	key := bidiKey{netFlow, tcpFlow}
161	rkey := bidiKey{netFlow.Reverse(), tcpFlow.Reverse()}
162	if bidi, ok := os.bidiMap[key]; ok {
163		atomic.AddInt32(&bidi.openStreamCount, 1)
164		delete(os.bidiMap, key)
165		return bidi.streams[1]
166	}
167	bidi := newBidi(netFlow, tcpFlow, os, <-os.connectionCounter)
168	os.bidiMap[rkey] = bidi
169	atomic.AddInt32(&bidi.openStreamCount, 1)
170	go bidi.streamOps()
171	return bidi.streams[0]
172}
173
174// Close is called by the tcpassembly to indicate that all of the packets
175// have been processed.
176func (os *MongoOpStream) Close() error {
177	close(os.unorderedOps)
178	os.unorderedOps = nil
179	return nil
180}
181
// SetFirstSeen sets the time for the first message on the MongoOpStream.
// All of this SetFirstSeen/FirstSeen stuff can go away (from here and from
// packet_handler.go); it is cruft left over from an attempt to work around
// the fact that the tcpassembly tcpreader library throws away all of the
// metadata about the stream.
func (os *MongoOpStream) SetFirstSeen(t time.Time) {
	os.FirstSeen = t
}
190
191// handleOps runs all of the ops read from the unorderedOps through a heapsort
192// and then runs them out on the Ops channel.
193func (os *MongoOpStream) handleOps() {
194	defer close(os.Ops)
195	var counter int64
196	for op := range os.unorderedOps {
197		heap.Push(os.opHeap, op)
198		if len(*os.opHeap) == cap(*os.opHeap) {
199			nextOp := heap.Pop(os.opHeap).(RecordedOp)
200			counter++
201			nextOp.Order = counter
202			os.Ops <- &nextOp
203		}
204	}
205	for len(*os.opHeap) > 0 {
206		nextOp := heap.Pop(os.opHeap).(RecordedOp)
207		counter++
208		nextOp.Order = counter
209		os.Ops <- &nextOp
210	}
211}
212
// streamState is the parser state of one stream.
type streamState int

const (
	// streamStateBeforeMessage: the next bytes are expected to start a
	// message header.
	streamStateBeforeMessage streamState = iota
	// streamStateInMessage: a header has been read and the body is being
	// accumulated.
	streamStateInMessage
	// streamStateOutOfSync: the position in the byte stream is unknown and
	// the parser is looking for something that resembles a header.
	streamStateOutOfSync
)

// String returns a human-readable name for the state, or "Unknown" for any
// value that is not one of the declared states.
func (st streamState) String() string {
	names := [...]string{
		streamStateBeforeMessage: "Before Message",
		streamStateInMessage:     "In Message",
		streamStateOutOfSync:     "Out Of Sync",
	}
	if st < 0 || int(st) >= len(names) {
		return "Unknown"
	}
	return names[st]
}
232
233func (bidi *bidi) handleStreamStateBeforeMessage(stream *stream) {
234	if stream.reassembly.Start {
235		if bidi.sawStart {
236			panic("apparently saw the beginning of a connection twice")
237		}
238		bidi.sawStart = true
239	}
240	// TODO deal with the situation that the first packet doesn't contain a
241	// whole MessageHeader of an otherwise valid protocol message.  The
242	// following code erroneously assumes that all packets will have at least 16
243	// bytes of data
244	if len(stream.reassembly.Bytes) < 16 {
245		stream.state = streamStateOutOfSync
246		stream.reassembly.Bytes = stream.reassembly.Bytes[:0]
247		return
248	}
249	stream.op.Header.FromWire(stream.reassembly.Bytes)
250	if !stream.op.Header.LooksReal() {
251		// When we're here and stream.reassembly.Start is true we may be able to
252		// know that we're actually not looking at mongodb traffic and that this
253		// whole stream should be discarded.
254		bidi.logvf(DebugLow, "not a good header %#v", stream.op.Header)
255		bidi.logvf(Info, "Expected to, but didn't see a valid protocol message")
256		stream.state = streamStateOutOfSync
257		stream.reassembly.Bytes = stream.reassembly.Bytes[:0]
258		return
259	}
260	stream.op.Body = make([]byte, 16, stream.op.Header.MessageLength)
261	stream.state = streamStateInMessage
262	stream.opTimeStamp = stream.reassembly.Seen
263	copy(stream.op.Body, stream.reassembly.Bytes)
264	stream.reassembly.Bytes = stream.reassembly.Bytes[16:]
265	return
266}
267func (bidi *bidi) handleStreamStateInMessage(stream *stream) {
268	var copySize int
269	bodyLen := len(stream.op.Body)
270	if bodyLen+len(stream.reassembly.Bytes) > int(stream.op.Header.MessageLength) {
271		copySize = int(stream.op.Header.MessageLength) - bodyLen
272	} else {
273		copySize = len(stream.reassembly.Bytes)
274	}
275	stream.op.Body = stream.op.Body[:bodyLen+copySize]
276	copied := copy(stream.op.Body[bodyLen:], stream.reassembly.Bytes)
277	if copied != copySize {
278		panic("copied != copySize")
279	}
280	stream.reassembly.Bytes = stream.reassembly.Bytes[copySize:]
281	if len(stream.op.Body) == int(stream.op.Header.MessageLength) {
282		// TODO maybe remember if we were recently in streamStateOutOfSync,
283		// and if so, parse the raw op here.
284
285		bidi.opStream.unorderedOps <- RecordedOp{
286			RawOp:             *stream.op,
287			Seen:              &PreciseTime{stream.opTimeStamp},
288			SrcEndpoint:       stream.netFlow.Src().String(),
289			DstEndpoint:       stream.netFlow.Dst().String(),
290			SeenConnectionNum: bidi.connectionNumber,
291		}
292
293		stream.op = &RawOp{}
294		stream.state = streamStateBeforeMessage
295		if len(stream.reassembly.Bytes) > 0 {
296			// parse the remainder of the stream.reassembly as a new message.
297			return
298		}
299	}
300	return
301}
302func (bidi *bidi) handleStreamStateOutOfSync(stream *stream) {
303	bidi.logvf(DebugHigh, "out of sync")
304	if len(stream.reassembly.Bytes) < 16 {
305		stream.reassembly.Bytes = stream.reassembly.Bytes[:0]
306		return
307	}
308	stream.op.Header.FromWire(stream.reassembly.Bytes)
309	bidi.logvf(DebugHigh, "possible message header %#v", stream.op.Header)
310	if stream.op.Header.LooksReal() {
311		stream.state = streamStateBeforeMessage
312		bidi.logvf(DebugLow, "synchronized")
313		return
314	}
315	stream.reassembly.Bytes = stream.reassembly.Bytes[:0]
316	return
317}
318func (bidi *bidi) handleStreamCompleted() {
319	var lastOpTimeStamp time.Time
320	if bidi.streams[0].opTimeStamp.After(bidi.streams[1].opTimeStamp) {
321		lastOpTimeStamp = bidi.streams[0].opTimeStamp
322	} else {
323		lastOpTimeStamp = bidi.streams[1].opTimeStamp
324	}
325	if !lastOpTimeStamp.IsZero() {
326		bidi.opStream.unorderedOps <- RecordedOp{
327			Seen:              &PreciseTime{lastOpTimeStamp.Add(time.Nanosecond)},
328			SeenConnectionNum: bidi.connectionNumber,
329			EOF:               true,
330		}
331	}
332	bidi.logvf(Info, "Connection %v: finishing", bidi.connectionNumber)
333}
334
335// streamOps reads tcpassembly.Reassembly[] blocks from the
336// stream's and tries to create whole protocol messages from them.
337func (bidi *bidi) streamOps() {
338	bidi.logvf(Info, "starting")
339	for {
340		var reassemblies []tcpassembly.Reassembly
341		var reassembliesStream int
342		var ok bool
343		select {
344		case reassemblies, ok = <-bidi.streams[0].reassembled:
345			reassembliesStream = 0
346		case reassemblies, ok = <-bidi.streams[1].reassembled:
347			reassembliesStream = 1
348		}
349		if !ok {
350			return
351		}
352		stream := bidi.streams[reassembliesStream]
353
354		for _, stream.reassembly = range reassemblies {
355			// Skip > 0 means that we've missed something, and we have
356			// incomplete packets in hand.
357			if stream.reassembly.Skip > 0 {
358				// TODO, we may want to do more state specific reporting here.
359				stream.state = streamStateOutOfSync
360				//when we have skip, we destroy this buffer
361				stream.op.Body = stream.op.Body[:0]
362				bidi.logvf(Info, "Connection %v state '%v': ignoring incomplete packet (skip: %v)", bidi.connectionNumber, stream.state, stream.reassembly.Skip)
363				continue
364			}
365			// Skip < 0 means that we're picking up a stream mid-stream, and we
366			// don't really know the state of what's in hand, we need to
367			// synchronize.
368			if stream.reassembly.Skip < 0 {
369				bidi.logvf(Info, "Connection %v state '%v': capture started in the middle of stream", bidi.connectionNumber, stream.state)
370				stream.state = streamStateOutOfSync
371			}
372
373			for len(stream.reassembly.Bytes) > 0 {
374				bidi.logvf(DebugHigh, "Connection %v: state '%v'", bidi.connectionNumber, stream.state)
375				switch stream.state {
376				case streamStateBeforeMessage:
377					bidi.handleStreamStateBeforeMessage(stream)
378				case streamStateInMessage:
379					bidi.handleStreamStateInMessage(stream)
380				case streamStateOutOfSync:
381					bidi.handleStreamStateOutOfSync(stream)
382				}
383			}
384		}
385		// inform the tcpassembly that we've finished with the reassemblies.
386		stream.done <- nil
387	}
388}
389