1// Copyright (C) MongoDB, Inc. 2014-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongoreplay
8
9import (
10	"fmt"
11	"sync"
12	"time"
13
14	mgo "github.com/10gen/llmgo"
15	"github.com/patrickmn/go-cache"
16)
17
18// ReplyPair contains both a live reply and a recorded reply when fully
19// occupied.
20type ReplyPair struct {
21	ops [2]Replyable
22}
23
// Slot indices into ReplyPair.ops identifying which half of a pair a reply
// occupies.
const (
	// ReplyFromWire is the ReplyPair index for live replies.
	ReplyFromWire = iota
	// ReplyFromFile is the ReplyPair index for recorded replies.
	ReplyFromFile
)
30
31// ExecutionContext maintains information for a mongoreplay execution.
32type ExecutionContext struct {
33	// IncompleteReplies holds half complete ReplyPairs, which contains either a
34	// live reply or a recorded reply when one arrives before the other.
35	IncompleteReplies *cache.Cache
36
37	// CompleteReplies contains ReplyPairs that have been competed by the
38	// arrival of the missing half of.
39	CompleteReplies map[string]*ReplyPair
40
41	// CursorIDMap contains the mapping between recorded cursorIDs and live
42	// cursorIDs
43	CursorIDMap cursorManager
44
45	// lock synchronizes access to all of the caches and maps in the
46	// ExecutionContext
47	sync.Mutex
48
49	ConnectionChansWaitGroup sync.WaitGroup
50
51	*StatCollector
52
53	// fullSpeed is a control to indicate whether the tool will sleep to synchronize
54	// the playback of operations or if it will play back all operations as fast
55	// as possible.
56	fullSpeed bool
57
58	driverOpsFiltered bool
59
60	session *mgo.Session
61}
62
// ExecutionOptions groups the optional settings that NewExecutionContext
// copies into a fresh ExecutionContext.
type ExecutionOptions struct {
	// fullSpeed, when true, disables the sleeps that pace playback to the
	// recorded timing.
	fullSpeed bool

	// driverOpsFiltered is copied verbatim into
	// ExecutionContext.driverOpsFiltered.
	driverOpsFiltered bool
}
69
70// NewExecutionContext initializes a new ExecutionContext.
71func NewExecutionContext(statColl *StatCollector, session *mgo.Session, options *ExecutionOptions) *ExecutionContext {
72	return &ExecutionContext{
73		IncompleteReplies: cache.New(60*time.Second, 60*time.Second),
74		CompleteReplies:   map[string]*ReplyPair{},
75		CursorIDMap:       newCursorCache(),
76		StatCollector:     statColl,
77		fullSpeed:         options.fullSpeed,
78		driverOpsFiltered: options.driverOpsFiltered,
79		session:           session,
80	}
81}
82
83// AddFromWire adds a from-wire reply to its IncompleteReplies ReplyPair and
84// moves that ReplyPair to CompleteReplies if it's complete.  The index is based
85// on the src/dest of the recordedOp which should be the op that this ReplyOp is
86// a reply to.
87func (context *ExecutionContext) AddFromWire(reply Replyable, recordedOp *RecordedOp) {
88	if cursorID, _ := reply.getCursorID(); cursorID == 0 {
89		return
90	}
91	key := cacheKey(recordedOp, false)
92	toolDebugLogger.Logvf(DebugHigh, "Adding live reply with key %v", key)
93	context.completeReply(key, reply, ReplyFromWire)
94}
95
96// AddFromFile adds a from-file reply to its IncompleteReplies ReplyPair and
97// moves that ReplyPair to CompleteReplies if it's complete.  The index is based
98// on the reversed src/dest of the recordedOp which should the RecordedOp that
99// this ReplyOp was unmarshaled out of.
100func (context *ExecutionContext) AddFromFile(reply Replyable, recordedOp *RecordedOp) {
101	if cursorID, _ := reply.getCursorID(); cursorID == 0 {
102		return
103	}
104	key := cacheKey(recordedOp, true)
105	toolDebugLogger.Logvf(DebugHigh, "Adding recorded reply with key %v", key)
106	context.completeReply(key, reply, ReplyFromFile)
107}
108
109func (context *ExecutionContext) completeReply(key string, reply Replyable, opSource int) {
110	context.Lock()
111	if cacheValue, ok := context.IncompleteReplies.Get(key); !ok {
112		rp := &ReplyPair{}
113		rp.ops[opSource] = reply
114		context.IncompleteReplies.Set(key, rp, cache.DefaultExpiration)
115	} else {
116		rp := cacheValue.(*ReplyPair)
117		rp.ops[opSource] = reply
118		if rp.ops[1-opSource] != nil {
119			context.CompleteReplies[key] = rp
120			context.IncompleteReplies.Delete(key)
121		}
122	}
123	context.Unlock()
124}
125
126func (context *ExecutionContext) rewriteCursors(rewriteable cursorsRewriteable, connectionNum int64) (bool, error) {
127	cursorIDs, err := rewriteable.getCursorIDs()
128
129	index := 0
130	for _, cursorID := range cursorIDs {
131		userInfoLogger.Logvf(DebugLow, "Rewriting cursorID : %v", cursorID)
132		liveCursorID, ok := context.CursorIDMap.GetCursor(cursorID, connectionNum)
133		if ok {
134			cursorIDs[index] = liveCursorID
135			index++
136		} else {
137			userInfoLogger.Logvf(DebugLow, "Missing mapped cursorID for raw cursorID : %v", cursorID)
138		}
139	}
140	newCursors := cursorIDs[0:index]
141	err = rewriteable.setCursorIDs(newCursors)
142	if err != nil {
143		return false, err
144	}
145	return len(newCursors) != 0, nil
146}
147
148func (context *ExecutionContext) handleCompletedReplies() error {
149	context.Lock()
150	for key, rp := range context.CompleteReplies {
151		userInfoLogger.Logvf(DebugHigh, "Completed reply: %v, %v", rp.ops[ReplyFromFile], rp.ops[ReplyFromWire])
152		cursorFromFile, err := rp.ops[ReplyFromFile].getCursorID()
153		if err != nil {
154			return err
155		}
156		cursorFromWire, err := rp.ops[ReplyFromWire].getCursorID()
157		if err != nil {
158			return err
159		}
160		if cursorFromFile != 0 {
161			context.CursorIDMap.SetCursor(cursorFromFile, cursorFromWire)
162		}
163
164		delete(context.CompleteReplies, key)
165	}
166
167	context.Unlock()
168	return nil
169}
170
171func (context *ExecutionContext) newExecutionConnection(start time.Time, connectionNum int64) chan<- *RecordedOp {
172	ch := make(chan *RecordedOp, 10000)
173	context.ConnectionChansWaitGroup.Add(1)
174
175	go func() {
176		now := time.Now()
177		var connected bool
178		time.Sleep(start.Add(-5 * time.Second).Sub(now)) // Sleep until five seconds before the start time
179		socket, err := context.session.AcquireSocketDirect()
180		if err == nil {
181			userInfoLogger.Logvf(Info, "(Connection %v) New connection CREATED.", connectionNum)
182			connected = true
183			defer socket.Close()
184		} else {
185			userInfoLogger.Logvf(Info, "(Connection %v) New Connection FAILED: %v", connectionNum, err)
186		}
187		for recordedOp := range ch {
188			var parsedOp Op
189			var reply Replyable
190			var err error
191			msg := ""
192			if connected {
193				// Populate the op with the connection num it's being played on.
194				// This allows it to be used for downstream reporting of stats.
195				recordedOp.PlayedConnectionNum = connectionNum
196				t := time.Now()
197
198				if !context.fullSpeed && recordedOp.RawOp.Header.OpCode != OpCodeReply {
199					if t.Before(recordedOp.PlayAt.Time) {
200						time.Sleep(recordedOp.PlayAt.Sub(t))
201					}
202				}
203				userInfoLogger.Logvf(DebugHigh, "(Connection %v) op %v", connectionNum, recordedOp.String())
204				parsedOp, reply, err = context.Execute(recordedOp, socket)
205				if err != nil {
206					toolDebugLogger.Logvf(Always, "context.Execute error: %v", err)
207				}
208			} else {
209				parsedOp, err = recordedOp.Parse()
210				if err != nil {
211					toolDebugLogger.Logvf(Always, "Execution Connection error: %v", err)
212				}
213
214				msg = fmt.Sprintf("Skipped on non-connected socket (Connection %v)", connectionNum)
215				toolDebugLogger.Logv(Always, msg)
216			}
217			if shouldCollectOp(parsedOp, context.driverOpsFiltered) {
218				context.Collect(recordedOp, parsedOp, reply, msg)
219			}
220		}
221		userInfoLogger.Logvf(Info, "(Connection %v) Connection ENDED.", connectionNum)
222		context.ConnectionChansWaitGroup.Done()
223	}()
224	return ch
225}
226
227// Execute plays a particular command on an mgo socket.
228func (context *ExecutionContext) Execute(op *RecordedOp, socket *mgo.MongoSocket) (Op, Replyable, error) {
229	opToExec, err := op.RawOp.Parse()
230	var reply Replyable
231
232	if err != nil {
233		return nil, nil, fmt.Errorf("ParseOpRawError: %v", err)
234	}
235	if opToExec == nil {
236		toolDebugLogger.Logvf(Always, "Skipping incomplete op: %v", op.RawOp.Header.OpCode)
237		return nil, nil, nil
238	}
239	switch replyable := opToExec.(type) {
240	case *ReplyOp:
241		context.AddFromFile(replyable, op)
242	case *CommandReplyOp:
243		context.AddFromFile(replyable, op)
244	case *MsgOpReply:
245		context.AddFromFile(replyable, op)
246	default:
247		if !context.driverOpsFiltered && IsDriverOp(opToExec) {
248			return opToExec, nil, nil
249		}
250		if rewriteable, ok1 := opToExec.(cursorsRewriteable); ok1 {
251			ok2, err := context.rewriteCursors(rewriteable, op.SeenConnectionNum)
252			if err != nil {
253				return opToExec, nil, err
254			}
255			if !ok2 {
256				return opToExec, nil, nil
257			}
258		}
259		// check if the op has a function to preprocess its data given the current
260		// set of options
261		if op, ok := opToExec.(Preprocessable); ok {
262			op.Preprocess()
263		}
264
265		op.PlayedAt = &PreciseTime{time.Now()}
266
267		reply, err = opToExec.Execute(socket)
268
269		if err != nil {
270			context.CursorIDMap.MarkFailed(op)
271			return opToExec, reply, fmt.Errorf("error executing op: %v", err)
272		}
273		if reply != nil {
274			context.AddFromWire(reply, op)
275		}
276	}
277	context.handleCompletedReplies()
278	return opToExec, reply, nil
279}
280