1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session"
8
9import (
10	"errors"
11	"time"
12
13	"go.mongodb.org/mongo-driver/bson"
14	"go.mongodb.org/mongo-driver/bson/primitive"
15	"go.mongodb.org/mongo-driver/mongo/readconcern"
16	"go.mongodb.org/mongo-driver/mongo/readpref"
17	"go.mongodb.org/mongo-driver/mongo/writeconcern"
18	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
19	"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
20)
21
// Sentinel errors returned by Client methods. Each is a distinct value created
// with errors.New, so callers can compare against them with errors.Is.
var (
	// ErrSessionEnded is returned when a client session is used after a call to endSession()
	// (for example by AdvanceClusterTime, AdvanceOperationTime and UpdateUseTime).
	ErrSessionEnded = errors.New("ended session was used")

	// ErrNoTransactStarted is returned if a commit or abort is requested when no transaction has started.
	ErrNoTransactStarted = errors.New("no transaction started")

	// ErrTransactInProgress is returned if startTransaction() is called while a transaction is
	// already starting or in progress.
	ErrTransactInProgress = errors.New("transaction already in progress")

	// ErrAbortAfterCommit is returned when abort is called after a commit.
	ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")

	// ErrAbortTwice is returned if abort is called after the transaction is already aborted.
	ErrAbortTwice = errors.New("cannot call abortTransaction twice")

	// ErrCommitAfterAbort is returned if commit is called after an abort.
	ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")

	// ErrUnackWCUnsupported is returned if an unacknowledged write concern is used for a transaction.
	ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")
)
42
// Type records how a client session was created; its only valid values are
// Explicit and Implicit.
type Type uint8

// Valid values for Type. Explicit is the zero value.
const (
	Explicit Type = 0
	Implicit Type = 1
)
51
// state is a Client's position in its transaction finite-state machine.
type state uint8

// Client Session states. None is the zero value, so a new Client starts there.
const (
	None       state = 0 // no transaction; also re-entered by ApplyCommand after a commit/abort
	Starting   state = 1 // StartTransaction succeeded but no command has been applied yet
	InProgress state = 2 // ApplyCommand has run at least once since Starting
	Committed  state = 3 // CommitTransaction succeeded
	Aborted    state = 4 // AbortTransaction succeeded
)
63
// Client is a session for clients to run commands.
//
// It embeds the server session taken from a Pool. It also tracks the
// causal-consistency bookkeeping (ClusterTime, OperationTime) and the
// client-side transaction state machine (state, the Current* options,
// PinnedServer, RecoveryToken).
type Client struct {
	// Server is the server session obtained from pool.GetSession in
	// NewClientSession and handed back to pool by EndSession.
	*Server
	// ClientID is the ID supplied to NewClientSession.
	ClientID       uuid.UUID
	// ClusterTime is the greatest cluster-time document seen so far; it is only
	// ever advanced, via AdvanceClusterTime / MaxClusterTime.
	ClusterTime    bson.Raw
	Consistent     bool // causal consistency; true unless ClientOptions.CausalConsistency says otherwise
	// OperationTime is the greatest operation time seen so far; see AdvanceOperationTime.
	OperationTime  *primitive.Timestamp
	// SessionType records whether the session is Explicit or Implicit.
	SessionType    Type
	// Terminated is set by EndSession; afterwards AdvanceClusterTime,
	// AdvanceOperationTime and UpdateUseTime return ErrSessionEnded.
	Terminated     bool
	// RetryingCommit, Committing and Aborting are reset by clearTransactionOpts,
	// and StartTransaction also resets RetryingCommit. While Committing is true,
	// ApplyCommand leaves the transaction state untouched.
	RetryingCommit bool
	Committing     bool
	Aborting       bool
	// NOTE(review): RetryWrite and RetryRead are not read by the code shown
	// here; presumably they toggle retryable writes/reads for operations run on
	// this session — confirm against callers.
	RetryWrite     bool
	RetryRead      bool

	// options for the current transaction
	// Set by StartTransaction from the TransactionOptions passed to it, with the
	// session defaults below as fallbacks. UpdateCommitTransactionWriteConcern
	// later replaces CurrentWc with a w:majority variant for commit retries.
	CurrentRc  *readconcern.ReadConcern
	CurrentRp  *readpref.ReadPref
	CurrentWc  *writeconcern.WriteConcern
	CurrentMct *time.Duration

	// default transaction options, copied in NewClientSession from the merged
	// ClientOptions (DefaultReadConcern, DefaultReadPreference,
	// DefaultWriteConcern, DefaultMaxCommitTime)
	transactionRc            *readconcern.ReadConcern
	transactionRp            *readpref.ReadPref
	transactionWc            *writeconcern.WriteConcern
	transactionMaxCommitTime *time.Duration

	// pool is the Pool that Server came from and is returned to by EndSession.
	pool          *Pool
	// state is the current position in the transaction state machine.
	state         state
	// PinnedServer is the mongos description recorded by ApplyCommand for the
	// first command of a transaction. It is cleared by StartTransaction,
	// ClearPinnedServer and clearTransactionOpts.
	PinnedServer  *description.Server
	// RecoveryToken is the "recoveryToken" document taken from a server response
	// by UpdateRecoveryToken; it is cleared by clearTransactionOpts.
	RecoveryToken bson.Raw
}
97
98func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
99	if clusterTime == nil {
100		return 0, 0
101	}
102
103	clusterTimeVal, err := clusterTime.LookupErr("$clusterTime")
104	if err != nil {
105		return 0, 0
106	}
107
108	timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime")
109	if err != nil {
110		return 0, 0
111	}
112
113	return timestampVal.Timestamp()
114}
115
116// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time.
117func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw {
118	epoch1, ord1 := getClusterTime(ct1)
119	epoch2, ord2 := getClusterTime(ct2)
120
121	if epoch1 > epoch2 {
122		return ct1
123	} else if epoch1 < epoch2 {
124		return ct2
125	} else if ord1 > ord2 {
126		return ct1
127	} else if ord1 < ord2 {
128		return ct2
129	}
130
131	return ct1
132}
133
134// NewClientSession creates a Client.
135func NewClientSession(pool *Pool, clientID uuid.UUID, sessionType Type, opts ...*ClientOptions) (*Client, error) {
136	c := &Client{
137		Consistent:  true, // set default
138		ClientID:    clientID,
139		SessionType: sessionType,
140		pool:        pool,
141	}
142
143	mergedOpts := mergeClientOptions(opts...)
144	if mergedOpts.CausalConsistency != nil {
145		c.Consistent = *mergedOpts.CausalConsistency
146	}
147	if mergedOpts.DefaultReadPreference != nil {
148		c.transactionRp = mergedOpts.DefaultReadPreference
149	}
150	if mergedOpts.DefaultReadConcern != nil {
151		c.transactionRc = mergedOpts.DefaultReadConcern
152	}
153	if mergedOpts.DefaultWriteConcern != nil {
154		c.transactionWc = mergedOpts.DefaultWriteConcern
155	}
156	if mergedOpts.DefaultMaxCommitTime != nil {
157		c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime
158	}
159
160	servSess, err := pool.GetSession()
161	if err != nil {
162		return nil, err
163	}
164
165	c.Server = servSess
166
167	return c, nil
168}
169
170// AdvanceClusterTime updates the session's cluster time.
171func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error {
172	if c.Terminated {
173		return ErrSessionEnded
174	}
175	c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime)
176	return nil
177}
178
179// AdvanceOperationTime updates the session's operation time.
180func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error {
181	if c.Terminated {
182		return ErrSessionEnded
183	}
184
185	if c.OperationTime == nil {
186		c.OperationTime = opTime
187		return nil
188	}
189
190	if opTime.T > c.OperationTime.T {
191		c.OperationTime = opTime
192	} else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) {
193		c.OperationTime = opTime
194	}
195
196	return nil
197}
198
199// UpdateUseTime updates the session's last used time.
200// Must be called whenver this session is used to send a command to the server.
201func (c *Client) UpdateUseTime() error {
202	if c.Terminated {
203		return ErrSessionEnded
204	}
205	c.updateUseTime()
206	return nil
207}
208
209// UpdateRecoveryToken updates the session's recovery token from the server response.
210func (c *Client) UpdateRecoveryToken(response bson.Raw) {
211	if c == nil {
212		return
213	}
214
215	token, err := response.LookupErr("recoveryToken")
216	if err != nil {
217		return
218	}
219
220	c.RecoveryToken = token.Document()
221}
222
223// ClearPinnedServer sets the PinnedServer to nil.
224func (c *Client) ClearPinnedServer() {
225	if c != nil {
226		c.PinnedServer = nil
227	}
228}
229
230// EndSession ends the session.
231func (c *Client) EndSession() {
232	if c.Terminated {
233		return
234	}
235
236	c.Terminated = true
237	c.pool.ReturnSession(c.Server)
238
239	return
240}
241
242// TransactionInProgress returns true if the client session is in an active transaction.
243func (c *Client) TransactionInProgress() bool {
244	return c.state == InProgress
245}
246
247// TransactionStarting returns true if the client session is starting a transaction.
248func (c *Client) TransactionStarting() bool {
249	return c.state == Starting
250}
251
252// TransactionRunning returns true if the client session has started the transaction
253// and it hasn't been committed or aborted
254func (c *Client) TransactionRunning() bool {
255	return c != nil && (c.state == Starting || c.state == InProgress)
256}
257
258// TransactionCommitted returns true of the client session just committed a transaciton.
259func (c *Client) TransactionCommitted() bool {
260	return c.state == Committed
261}
262
263// CheckStartTransaction checks to see if allowed to start transaction and returns
264// an error if not allowed
265func (c *Client) CheckStartTransaction() error {
266	if c.state == InProgress || c.state == Starting {
267		return ErrTransactInProgress
268	}
269	return nil
270}
271
272// StartTransaction initializes the transaction options and advances the state machine.
273// It does not contact the server to start the transaction.
274func (c *Client) StartTransaction(opts *TransactionOptions) error {
275	err := c.CheckStartTransaction()
276	if err != nil {
277		return err
278	}
279
280	c.IncrementTxnNumber()
281	c.RetryingCommit = false
282
283	if opts != nil {
284		c.CurrentRc = opts.ReadConcern
285		c.CurrentRp = opts.ReadPreference
286		c.CurrentWc = opts.WriteConcern
287		c.CurrentMct = opts.MaxCommitTime
288	}
289
290	if c.CurrentRc == nil {
291		c.CurrentRc = c.transactionRc
292	}
293
294	if c.CurrentRp == nil {
295		c.CurrentRp = c.transactionRp
296	}
297
298	if c.CurrentWc == nil {
299		c.CurrentWc = c.transactionWc
300	}
301
302	if c.CurrentMct == nil {
303		c.CurrentMct = c.transactionMaxCommitTime
304	}
305
306	if !writeconcern.AckWrite(c.CurrentWc) {
307		c.clearTransactionOpts()
308		return ErrUnackWCUnsupported
309	}
310
311	c.state = Starting
312	c.PinnedServer = nil
313	return nil
314}
315
316// CheckCommitTransaction checks to see if allowed to commit transaction and returns
317// an error if not allowed.
318func (c *Client) CheckCommitTransaction() error {
319	if c.state == None {
320		return ErrNoTransactStarted
321	} else if c.state == Aborted {
322		return ErrCommitAfterAbort
323	}
324	return nil
325}
326
327// CommitTransaction updates the state for a successfully committed transaction and returns
328// an error if not permissible.  It does not actually perform the commit.
329func (c *Client) CommitTransaction() error {
330	err := c.CheckCommitTransaction()
331	if err != nil {
332		return err
333	}
334	c.state = Committed
335	return nil
336}
337
338// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set  a
339// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a
340// retryable error or after a successful commit transaction operation.
341func (c *Client) UpdateCommitTransactionWriteConcern() {
342	wc := c.CurrentWc
343	timeout := 10 * time.Second
344	if wc != nil && wc.GetWTimeout() != 0 {
345		timeout = wc.GetWTimeout()
346	}
347	c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout))
348}
349
350// CheckAbortTransaction checks to see if allowed to abort transaction and returns
351// an error if not allowed.
352func (c *Client) CheckAbortTransaction() error {
353	if c.state == None {
354		return ErrNoTransactStarted
355	} else if c.state == Committed {
356		return ErrAbortAfterCommit
357	} else if c.state == Aborted {
358		return ErrAbortTwice
359	}
360	return nil
361}
362
363// AbortTransaction updates the state for a successfully aborted transaction and returns
364// an error if not permissible.  It does not actually perform the abort.
365func (c *Client) AbortTransaction() error {
366	err := c.CheckAbortTransaction()
367	if err != nil {
368		return err
369	}
370	c.state = Aborted
371	c.clearTransactionOpts()
372	return nil
373}
374
375// ApplyCommand advances the state machine upon command execution.
376func (c *Client) ApplyCommand(desc description.Server) {
377	if c.Committing {
378		// Do not change state if committing after already committed
379		return
380	}
381	if c.state == Starting {
382		c.state = InProgress
383		// If this is in a transaction and the server is a mongos, pin it
384		if desc.Kind == description.Mongos {
385			c.PinnedServer = &desc
386		}
387	} else if c.state == Committed || c.state == Aborted {
388		c.clearTransactionOpts()
389		c.state = None
390	}
391}
392
393func (c *Client) clearTransactionOpts() {
394	c.RetryingCommit = false
395	c.Aborting = false
396	c.Committing = false
397	c.CurrentWc = nil
398	c.CurrentRp = nil
399	c.CurrentRc = nil
400	c.PinnedServer = nil
401	c.RecoveryToken = nil
402}
403