1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11	"errors"
12	"time"
13
14	"go.mongodb.org/mongo-driver/bson"
15	"go.mongodb.org/mongo-driver/bson/primitive"
16	"go.mongodb.org/mongo-driver/mongo/options"
17	"go.mongodb.org/mongo-driver/x/bsonx"
18	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
19	"go.mongodb.org/mongo-driver/x/mongo/driver"
20	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
21	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
22	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
23)
24
var (
	// ErrWrongClient is returned when a user attempts to pass in a session created by a different client than
	// the method call is using.
	ErrWrongClient = errors.New("session was not created by this client")

	// withTransactionTimeout bounds how long WithTransaction keeps retrying a transaction before it gives up and
	// returns the last error it observed.
	withTransactionTimeout = 120 * time.Second
)
30
31// SessionContext combines the context.Context and mongo.Session interfaces. It should be used as the Context arguments
32// to operations that should be executed in a session. This type is not goroutine safe and must not be used concurrently
33// by multiple goroutines.
34type SessionContext interface {
35	context.Context
36	Session
37}
38
// sessionContext is the concrete SessionContext value built by contextWithSession. It embeds both a context.Context
// and the Session, so a single value satisfies both interfaces by method promotion.
type sessionContext struct {
	context.Context
	Session
}
43
// sessionKey is the unexported context key under which contextWithSession stores a Session. Using a private,
// zero-size type keeps it from colliding with keys set by other packages.
type sessionKey struct{}
46
47// Session is an interface that represents a MongoDB logical session. Sessions can be used to enable causal consistency
48// for a group of operations or to execute operations in an ACID transaction. A new Session can be created from a Client
49// instance. A Session created from a Client must only be used to execute operations using that Client or a Database or
50// Collection created from that Client. Custom implementations of this interface should not be used in production. For
51// more information about sessions, and their use cases, see https://docs.mongodb.com/manual/reference/server-sessions/,
52// https://docs.mongodb.com/manual/core/read-isolation-consistency-recency/#causal-consistency, and
53// https://docs.mongodb.com/manual/core/transactions/.
54//
55// StartTransaction starts a new transaction, configured with the given options, on this session. This method will
56// return an error if there is already a transaction in-progress for this session.
57//
58// CommitTransaction commits the active transaction for this session. This method will return an error if there is no
59// active transaction for this session or the transaction has been aborted.
60//
61// AbortTransaction aborts the active transaction for this session. This method will return an error if there is no
62// active transaction for this session or the transaction has been committed or aborted.
63//
64// WithTransaction starts a transaction on this session and runs the fn callback. Errors with the
65// TransientTransactionError and UnknownTransactionCommitResult labels are retried for up to 120 seconds. Inside the
66// callback, sessCtx must be used as the Context parameter for any operations that should be part of the transaction. If
67// the ctx parameter already has a Session attached to it, it will be replaced by this session. The fn callback may be
68// run multiple times during WithTransaction due to retry attempts, so it must be idempotent. Non-retryable operation
69// errors or any operation errors that occur after the timeout expires will be returned without retrying. For a usage
70// example, see the Client.StartSession method documentation.
71//
72// ClusterTime, OperationTime, and Client return the session's current operation time, the session's current cluster
73// time, and the Client associated with the session, respectively.
74//
75// EndSession method should abort any existing transactions and close the session.
76//
77// AdvanceClusterTime and AdvanceOperationTime are for internal use only and must not be called.
78type Session interface {
79	StartTransaction(...*options.TransactionOptions) error
80	AbortTransaction(context.Context) error
81	CommitTransaction(context.Context) error
82	WithTransaction(ctx context.Context, fn func(sessCtx SessionContext) (interface{}, error),
83		opts ...*options.TransactionOptions) (interface{}, error)
84	ClusterTime() bson.Raw
85	OperationTime() *primitive.Timestamp
86	Client() *Client
87	EndSession(context.Context)
88
89	AdvanceClusterTime(bson.Raw) error
90	AdvanceOperationTime(*primitive.Timestamp) error
91
92	session()
93}
94
95// XSession is an unstable interface for internal use only. This interface is deprecated and is not part of the
96// stability guarantee. It may be removed at any time.
97type XSession interface {
98	ClientSession() *session.Client
99	ID() bsonx.Doc
100}
101
// sessionImpl represents a set of sequential operations executed by an application that are related in some way.
// It is this package's implementation of both the Session and XSession interfaces.
type sessionImpl struct {
	// clientSession is the core session holding transaction state, cluster/operation times, and the session ID.
	clientSession       *session.Client
	// client is the Client that owns this session; its clock and command monitor are used for commit/abort.
	client              *Client
	// deployment is where commitTransaction and abortTransaction commands are sent.
	deployment          driver.Deployment
	didCommitAfterStart bool // true if commit was called after start with no other operations
}
109
110var _ Session = &sessionImpl{}
111var _ XSession = &sessionImpl{}
112
113// ClientSession implements the XSession interface.
114func (s *sessionImpl) ClientSession() *session.Client {
115	return s.clientSession
116}
117
118// ID implements the XSession interface.
119func (s *sessionImpl) ID() bsonx.Doc {
120	return s.clientSession.SessionID
121}
122
123// EndSession implements the Session interface.
124func (s *sessionImpl) EndSession(ctx context.Context) {
125	if s.clientSession.TransactionInProgress() {
126		// ignore all errors aborting during an end session
127		_ = s.AbortTransaction(ctx)
128	}
129	s.clientSession.EndSession()
130}
131
132// WithTransaction implements the Session interface.
133func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx SessionContext) (interface{}, error),
134	opts ...*options.TransactionOptions) (interface{}, error) {
135	timeout := time.NewTimer(withTransactionTimeout)
136	defer timeout.Stop()
137	var err error
138	for {
139		err = s.StartTransaction(opts...)
140		if err != nil {
141			return nil, err
142		}
143
144		res, err := fn(contextWithSession(ctx, s))
145		if err != nil {
146			if s.clientSession.TransactionRunning() {
147				_ = s.AbortTransaction(ctx)
148			}
149
150			select {
151			case <-timeout.C:
152				return nil, err
153			default:
154			}
155
156			if cerr, ok := err.(CommandError); ok {
157				if cerr.HasErrorLabel(driver.TransientTransactionError) {
158					continue
159				}
160			}
161			return res, err
162		}
163
164		err = s.clientSession.CheckAbortTransaction()
165		if err != nil {
166			return res, nil
167		}
168
169	CommitLoop:
170		for {
171			err = s.CommitTransaction(ctx)
172			if err == nil {
173				return res, nil
174			}
175
176			select {
177			case <-timeout.C:
178				return res, err
179			default:
180			}
181
182			if cerr, ok := err.(CommandError); ok {
183				if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() {
184					continue
185				}
186				if cerr.HasErrorLabel(driver.TransientTransactionError) {
187					break CommitLoop
188				}
189			}
190			return res, err
191		}
192	}
193}
194
195// StartTransaction implements the Session interface.
196func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error {
197	err := s.clientSession.CheckStartTransaction()
198	if err != nil {
199		return err
200	}
201
202	s.didCommitAfterStart = false
203
204	topts := options.MergeTransactionOptions(opts...)
205	coreOpts := &session.TransactionOptions{
206		ReadConcern:    topts.ReadConcern,
207		ReadPreference: topts.ReadPreference,
208		WriteConcern:   topts.WriteConcern,
209		MaxCommitTime:  topts.MaxCommitTime,
210	}
211
212	return s.clientSession.StartTransaction(coreOpts)
213}
214
215// AbortTransaction implements the Session interface.
216func (s *sessionImpl) AbortTransaction(ctx context.Context) error {
217	err := s.clientSession.CheckAbortTransaction()
218	if err != nil {
219		return err
220	}
221
222	// Do not run the abort command if the transaction is in starting state
223	if s.clientSession.TransactionStarting() || s.didCommitAfterStart {
224		return s.clientSession.AbortTransaction()
225	}
226
227	selector := makePinnedSelector(s.clientSession, description.WriteSelector())
228
229	s.clientSession.Aborting = true
230	_ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").
231		Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).
232		Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor).
233		RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).Execute(ctx)
234
235	s.clientSession.Aborting = false
236	_ = s.clientSession.AbortTransaction()
237
238	return nil
239}
240
241// CommitTransaction implements the Session interface.
242func (s *sessionImpl) CommitTransaction(ctx context.Context) error {
243	err := s.clientSession.CheckCommitTransaction()
244	if err != nil {
245		return err
246	}
247
248	// Do not run the commit command if the transaction is in started state
249	if s.clientSession.TransactionStarting() || s.didCommitAfterStart {
250		s.didCommitAfterStart = true
251		return s.clientSession.CommitTransaction()
252	}
253
254	if s.clientSession.TransactionCommitted() {
255		s.clientSession.RetryingCommit = true
256	}
257
258	selector := makePinnedSelector(s.clientSession, description.WriteSelector())
259
260	s.clientSession.Committing = true
261	op := operation.NewCommitTransaction().
262		Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment).
263		WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand).
264		CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken))
265	if s.clientSession.CurrentMct != nil {
266		op.MaxTimeMS(int64(*s.clientSession.CurrentMct / time.Millisecond))
267	}
268
269	err = op.Execute(ctx)
270	s.clientSession.Committing = false
271	commitErr := s.clientSession.CommitTransaction()
272
273	// We set the write concern to majority for subsequent calls to CommitTransaction.
274	s.clientSession.UpdateCommitTransactionWriteConcern()
275
276	if err != nil {
277		return replaceErrors(err)
278	}
279	return commitErr
280}
281
282// ClusterTime implements the Session interface.
283func (s *sessionImpl) ClusterTime() bson.Raw {
284	return s.clientSession.ClusterTime
285}
286
287// AdvanceClusterTime implements the Session interface.
288func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error {
289	return s.clientSession.AdvanceClusterTime(d)
290}
291
292// OperationTime implements the Session interface.
293func (s *sessionImpl) OperationTime() *primitive.Timestamp {
294	return s.clientSession.OperationTime
295}
296
297// AdvanceOperationTime implements the Session interface.
298func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error {
299	return s.clientSession.AdvanceOperationTime(ts)
300}
301
302// Client implements the Session interface.
303func (s *sessionImpl) Client() *Client {
304	return s.client
305}
306
307// session implements the Session interface.
308func (*sessionImpl) session() {
309}
310
311// sessionFromContext checks for a sessionImpl in the argued context and returns the session if it
312// exists
313func sessionFromContext(ctx context.Context) *session.Client {
314	s := ctx.Value(sessionKey{})
315	if ses, ok := s.(*sessionImpl); ses != nil && ok {
316		return ses.clientSession
317	}
318
319	return nil
320}
321
322// contextWithSession creates a new SessionContext associated with the given Context and Session parameters.
323func contextWithSession(ctx context.Context, sess Session) SessionContext {
324	return &sessionContext{
325		Context: context.WithValue(ctx, sessionKey{}, sess),
326		Session: sess,
327	}
328}
329