1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session"
8
9import (
10	"errors"
11	"time"
12
13	"go.mongodb.org/mongo-driver/bson"
14	"go.mongodb.org/mongo-driver/bson/primitive"
15	"go.mongodb.org/mongo-driver/mongo/description"
16	"go.mongodb.org/mongo-driver/mongo/readconcern"
17	"go.mongodb.org/mongo-driver/mongo/readpref"
18	"go.mongodb.org/mongo-driver/mongo/writeconcern"
19	"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
20)
21
// Sentinel errors reported by client session and transaction operations.
var (
	// ErrSessionEnded is returned when a client session is used after a call to endSession().
	ErrSessionEnded = errors.New("ended session was used")

	// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started.
	ErrNoTransactStarted = errors.New("no transaction started")

	// ErrTransactInProgress is returned if startTransaction() is called when a transaction is in progress.
	ErrTransactInProgress = errors.New("transaction already in progress")

	// ErrAbortAfterCommit is returned when abort is called after a commit.
	ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")

	// ErrAbortTwice is returned if abort is called after the transaction is already aborted.
	ErrAbortTwice = errors.New("cannot call abortTransaction twice")

	// ErrCommitAfterAbort is returned if commit is called after an abort.
	ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")

	// ErrUnackWCUnsupported is returned if an unacknowledged write concern is requested for a transaction.
	ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")
)
42
// Type classifies a client session as either Explicit or Implicit.
type Type uint8

// Valid client session types. Explicit is the zero value of Type.
const (
	Explicit Type = 0
	Implicit Type = 1
)
51
// TransactionState is the current state of a session's transaction state machine.
type TransactionState uint8

// Client session transaction states. None is the zero value.
const (
	None       TransactionState = 0
	Starting   TransactionState = 1
	InProgress TransactionState = 2
	Committed  TransactionState = 3
	Aborted    TransactionState = 4
)

// String implements the fmt.Stringer interface. It returns a lowercase name
// for each defined state and "unknown" for any other value.
func (s TransactionState) String() string {
	names := [...]string{
		None:       "none",
		Starting:   "starting",
		InProgress: "in progress",
		Committed:  "committed",
		Aborted:    "aborted",
	}
	if int(s) < len(names) {
		return names[s]
	}
	return "unknown"
}
81
// Client is a session for clients to run commands.
//
// It embeds the *Server session checked out from a Pool by NewClientSession and
// adds client-side state on top of it: the cluster and operation times used for
// causal consistency, and the transaction state machine driven by
// StartTransaction, CommitTransaction, AbortTransaction and ApplyCommand.
type Client struct {
	*Server
	// ClientID is the identifier passed to NewClientSession.
	ClientID       uuid.UUID
	// ClusterTime is the highest $clusterTime document seen so far; it only
	// moves forward via AdvanceClusterTime.
	ClusterTime    bson.Raw
	Consistent     bool // causal consistency; defaults to true in NewClientSession
	// OperationTime is the latest operation time seen; AdvanceOperationTime only
	// replaces it with a later (T, then I) timestamp.
	OperationTime  *primitive.Timestamp
	// SessionType is Explicit or Implicit, fixed at construction.
	SessionType    Type
	// Terminated is set by EndSession; once true, AdvanceClusterTime,
	// AdvanceOperationTime and UpdateUseTime return ErrSessionEnded.
	Terminated     bool
	// RetryingCommit is reset by StartTransaction and clearTransactionOpts.
	// NOTE(review): it is set outside this file, presumably while a
	// commitTransaction is being retried — confirm against callers.
	RetryingCommit bool
	// Committing makes ApplyCommand leave the transaction state untouched while
	// true; reset by clearTransactionOpts. Presumably set by callers around the
	// commitTransaction command.
	Committing     bool
	// Aborting is reset by clearTransactionOpts; it is not otherwise read in
	// this file (presumably set by callers around the abortTransaction command).
	Aborting       bool
	// RetryWrite and RetryRead are not referenced in this file; presumably
	// retryability flags managed by the operation layer — TODO confirm.
	RetryWrite     bool
	RetryRead      bool

	// options for the current transaction
	// most recently set by transactionopt
	// (StartTransaction fills them from its options, falling back to the
	// session defaults below when an option is nil)
	CurrentRc  *readconcern.ReadConcern
	CurrentRp  *readpref.ReadPref
	CurrentWc  *writeconcern.WriteConcern
	CurrentMct *time.Duration

	// default transaction options
	// (copied from the ClientOptions Default* fields in NewClientSession)
	transactionRc            *readconcern.ReadConcern
	transactionRp            *readpref.ReadPref
	transactionWc            *writeconcern.WriteConcern
	transactionMaxCommitTime *time.Duration

	// pool is where the embedded Server came from; EndSession returns it there.
	pool             *Pool
	// TransactionState is the current state of the transaction state machine.
	TransactionState TransactionState
	// PinnedServer is the mongos that ApplyCommand pins when the first command of
	// a transaction runs against one; cleared when a transaction starts or ends.
	PinnedServer     *description.Server
	// RecoveryToken is the last recoveryToken document found in a server
	// response (see UpdateRecoveryToken); cleared by clearTransactionOpts.
	RecoveryToken    bson.Raw
}
115
116func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
117	if clusterTime == nil {
118		return 0, 0
119	}
120
121	clusterTimeVal, err := clusterTime.LookupErr("$clusterTime")
122	if err != nil {
123		return 0, 0
124	}
125
126	timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime")
127	if err != nil {
128		return 0, 0
129	}
130
131	return timestampVal.Timestamp()
132}
133
134// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time.
135func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw {
136	epoch1, ord1 := getClusterTime(ct1)
137	epoch2, ord2 := getClusterTime(ct2)
138
139	if epoch1 > epoch2 {
140		return ct1
141	} else if epoch1 < epoch2 {
142		return ct2
143	} else if ord1 > ord2 {
144		return ct1
145	} else if ord1 < ord2 {
146		return ct2
147	}
148
149	return ct1
150}
151
152// NewClientSession creates a Client.
153func NewClientSession(pool *Pool, clientID uuid.UUID, sessionType Type, opts ...*ClientOptions) (*Client, error) {
154	c := &Client{
155		Consistent:  true, // set default
156		ClientID:    clientID,
157		SessionType: sessionType,
158		pool:        pool,
159	}
160
161	mergedOpts := mergeClientOptions(opts...)
162	if mergedOpts.CausalConsistency != nil {
163		c.Consistent = *mergedOpts.CausalConsistency
164	}
165	if mergedOpts.DefaultReadPreference != nil {
166		c.transactionRp = mergedOpts.DefaultReadPreference
167	}
168	if mergedOpts.DefaultReadConcern != nil {
169		c.transactionRc = mergedOpts.DefaultReadConcern
170	}
171	if mergedOpts.DefaultWriteConcern != nil {
172		c.transactionWc = mergedOpts.DefaultWriteConcern
173	}
174	if mergedOpts.DefaultMaxCommitTime != nil {
175		c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime
176	}
177
178	servSess, err := pool.GetSession()
179	if err != nil {
180		return nil, err
181	}
182
183	c.Server = servSess
184
185	return c, nil
186}
187
188// AdvanceClusterTime updates the session's cluster time.
189func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error {
190	if c.Terminated {
191		return ErrSessionEnded
192	}
193	c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime)
194	return nil
195}
196
197// AdvanceOperationTime updates the session's operation time.
198func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error {
199	if c.Terminated {
200		return ErrSessionEnded
201	}
202
203	if c.OperationTime == nil {
204		c.OperationTime = opTime
205		return nil
206	}
207
208	if opTime.T > c.OperationTime.T {
209		c.OperationTime = opTime
210	} else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) {
211		c.OperationTime = opTime
212	}
213
214	return nil
215}
216
217// UpdateUseTime sets the session's last used time to the current time. This must be called whenever the session is
218// used to send a command to the server to ensure that the session is not prematurely marked expired in the driver's
219// session pool. If the session has already been ended, this method will return ErrSessionEnded.
220func (c *Client) UpdateUseTime() error {
221	if c.Terminated {
222		return ErrSessionEnded
223	}
224	c.updateUseTime()
225	return nil
226}
227
228// UpdateRecoveryToken updates the session's recovery token from the server response.
229func (c *Client) UpdateRecoveryToken(response bson.Raw) {
230	if c == nil {
231		return
232	}
233
234	token, err := response.LookupErr("recoveryToken")
235	if err != nil {
236		return
237	}
238
239	c.RecoveryToken = token.Document()
240}
241
242// ClearPinnedServer sets the PinnedServer to nil.
243func (c *Client) ClearPinnedServer() {
244	if c != nil {
245		c.PinnedServer = nil
246	}
247}
248
249// EndSession ends the session.
250func (c *Client) EndSession() {
251	if c.Terminated {
252		return
253	}
254
255	c.Terminated = true
256	c.pool.ReturnSession(c.Server)
257
258	return
259}
260
261// TransactionInProgress returns true if the client session is in an active transaction.
262func (c *Client) TransactionInProgress() bool {
263	return c.TransactionState == InProgress
264}
265
266// TransactionStarting returns true if the client session is starting a transaction.
267func (c *Client) TransactionStarting() bool {
268	return c.TransactionState == Starting
269}
270
271// TransactionRunning returns true if the client session has started the transaction
272// and it hasn't been committed or aborted
273func (c *Client) TransactionRunning() bool {
274	return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress)
275}
276
277// TransactionCommitted returns true of the client session just committed a transaciton.
278func (c *Client) TransactionCommitted() bool {
279	return c.TransactionState == Committed
280}
281
282// CheckStartTransaction checks to see if allowed to start transaction and returns
283// an error if not allowed
284func (c *Client) CheckStartTransaction() error {
285	if c.TransactionState == InProgress || c.TransactionState == Starting {
286		return ErrTransactInProgress
287	}
288	return nil
289}
290
291// StartTransaction initializes the transaction options and advances the state machine.
292// It does not contact the server to start the transaction.
293func (c *Client) StartTransaction(opts *TransactionOptions) error {
294	err := c.CheckStartTransaction()
295	if err != nil {
296		return err
297	}
298
299	c.IncrementTxnNumber()
300	c.RetryingCommit = false
301
302	if opts != nil {
303		c.CurrentRc = opts.ReadConcern
304		c.CurrentRp = opts.ReadPreference
305		c.CurrentWc = opts.WriteConcern
306		c.CurrentMct = opts.MaxCommitTime
307	}
308
309	if c.CurrentRc == nil {
310		c.CurrentRc = c.transactionRc
311	}
312
313	if c.CurrentRp == nil {
314		c.CurrentRp = c.transactionRp
315	}
316
317	if c.CurrentWc == nil {
318		c.CurrentWc = c.transactionWc
319	}
320
321	if c.CurrentMct == nil {
322		c.CurrentMct = c.transactionMaxCommitTime
323	}
324
325	if !writeconcern.AckWrite(c.CurrentWc) {
326		c.clearTransactionOpts()
327		return ErrUnackWCUnsupported
328	}
329
330	c.TransactionState = Starting
331	c.PinnedServer = nil
332	return nil
333}
334
335// CheckCommitTransaction checks to see if allowed to commit transaction and returns
336// an error if not allowed.
337func (c *Client) CheckCommitTransaction() error {
338	if c.TransactionState == None {
339		return ErrNoTransactStarted
340	} else if c.TransactionState == Aborted {
341		return ErrCommitAfterAbort
342	}
343	return nil
344}
345
346// CommitTransaction updates the state for a successfully committed transaction and returns
347// an error if not permissible.  It does not actually perform the commit.
348func (c *Client) CommitTransaction() error {
349	err := c.CheckCommitTransaction()
350	if err != nil {
351		return err
352	}
353	c.TransactionState = Committed
354	return nil
355}
356
357// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set  a
358// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a
359// retryable error or after a successful commit transaction operation.
360func (c *Client) UpdateCommitTransactionWriteConcern() {
361	wc := c.CurrentWc
362	timeout := 10 * time.Second
363	if wc != nil && wc.GetWTimeout() != 0 {
364		timeout = wc.GetWTimeout()
365	}
366	c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout))
367}
368
369// CheckAbortTransaction checks to see if allowed to abort transaction and returns
370// an error if not allowed.
371func (c *Client) CheckAbortTransaction() error {
372	if c.TransactionState == None {
373		return ErrNoTransactStarted
374	} else if c.TransactionState == Committed {
375		return ErrAbortAfterCommit
376	} else if c.TransactionState == Aborted {
377		return ErrAbortTwice
378	}
379	return nil
380}
381
382// AbortTransaction updates the state for a successfully aborted transaction and returns
383// an error if not permissible.  It does not actually perform the abort.
384func (c *Client) AbortTransaction() error {
385	err := c.CheckAbortTransaction()
386	if err != nil {
387		return err
388	}
389	c.TransactionState = Aborted
390	c.clearTransactionOpts()
391	return nil
392}
393
394// ApplyCommand advances the state machine upon command execution.
395func (c *Client) ApplyCommand(desc description.Server) {
396	if c.Committing {
397		// Do not change state if committing after already committed
398		return
399	}
400	if c.TransactionState == Starting {
401		c.TransactionState = InProgress
402		// If this is in a transaction and the server is a mongos, pin it
403		if desc.Kind == description.Mongos {
404			c.PinnedServer = &desc
405		}
406	} else if c.TransactionState == Committed || c.TransactionState == Aborted {
407		c.clearTransactionOpts()
408		c.TransactionState = None
409	}
410}
411
412func (c *Client) clearTransactionOpts() {
413	c.RetryingCommit = false
414	c.Aborting = false
415	c.Committing = false
416	c.CurrentWc = nil
417	c.CurrentRp = nil
418	c.CurrentRc = nil
419	c.PinnedServer = nil
420	c.RecoveryToken = nil
421}
422