1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session" 8 9import ( 10 "errors" 11 "time" 12 13 "go.mongodb.org/mongo-driver/bson" 14 "go.mongodb.org/mongo-driver/bson/primitive" 15 "go.mongodb.org/mongo-driver/mongo/description" 16 "go.mongodb.org/mongo-driver/mongo/readconcern" 17 "go.mongodb.org/mongo-driver/mongo/readpref" 18 "go.mongodb.org/mongo-driver/mongo/writeconcern" 19 "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" 20) 21 22// ErrSessionEnded is returned when a client session is used after a call to endSession(). 23var ErrSessionEnded = errors.New("ended session was used") 24 25// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started. 26var ErrNoTransactStarted = errors.New("no transaction started") 27 28// ErrTransactInProgress is returned if startTransaction() is called when a transaction is in progress. 29var ErrTransactInProgress = errors.New("transaction already in progress") 30 31// ErrAbortAfterCommit is returned when abort is called after a commit. 32var ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction") 33 34// ErrAbortTwice is returned if abort is called after transaction is already aborted. 35var ErrAbortTwice = errors.New("cannot call abortTransaction twice") 36 37// ErrCommitAfterAbort is returned if commit is called after an abort. 38var ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction") 39 40// ErrUnackWCUnsupported is returned if an unacknowledged write concern is supported for a transaciton. 41var ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns") 42 43// Type describes the type of the session 44type Type uint8 45 46// These constants are the valid types for a client session. 47const ( 48 Explicit Type = iota 49 Implicit 50) 51 52// TransactionState indicates the state of the transactions FSM. 53type TransactionState uint8 54 55// Client Session states 56const ( 57 None TransactionState = iota 58 Starting 59 InProgress 60 Committed 61 Aborted 62) 63 64// String implements the fmt.Stringer interface. 65func (s TransactionState) String() string { 66 switch s { 67 case None: 68 return "none" 69 case Starting: 70 return "starting" 71 case InProgress: 72 return "in progress" 73 case Committed: 74 return "committed" 75 case Aborted: 76 return "aborted" 77 default: 78 return "unknown" 79 } 80} 81 82// Client is a session for clients to run commands. 83type Client struct { 84 *Server 85 ClientID uuid.UUID 86 ClusterTime bson.Raw 87 Consistent bool // causal consistency 88 OperationTime *primitive.Timestamp 89 SessionType Type 90 Terminated bool 91 RetryingCommit bool 92 Committing bool 93 Aborting bool 94 RetryWrite bool 95 RetryRead bool 96 97 // options for the current transaction 98 // most recently set by transactionopt 99 CurrentRc *readconcern.ReadConcern 100 CurrentRp *readpref.ReadPref 101 CurrentWc *writeconcern.WriteConcern 102 CurrentMct *time.Duration 103 104 // default transaction options 105 transactionRc *readconcern.ReadConcern 106 transactionRp *readpref.ReadPref 107 transactionWc *writeconcern.WriteConcern 108 transactionMaxCommitTime *time.Duration 109 110 pool *Pool 111 TransactionState TransactionState 112 PinnedServer *description.Server 113 RecoveryToken bson.Raw 114} 115 116func getClusterTime(clusterTime bson.Raw) (uint32, uint32) { 117 if clusterTime == nil { 118 return 0, 0 119 } 120 121 clusterTimeVal, err := clusterTime.LookupErr("$clusterTime") 122 if err != nil { 123 return 0, 0 124 } 125 126 timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime") 127 if err != nil { 128 return 0, 0 129 } 130 131 return timestampVal.Timestamp() 132} 133 134// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time. 135func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw { 136 epoch1, ord1 := getClusterTime(ct1) 137 epoch2, ord2 := getClusterTime(ct2) 138 139 if epoch1 > epoch2 { 140 return ct1 141 } else if epoch1 < epoch2 { 142 return ct2 143 } else if ord1 > ord2 { 144 return ct1 145 } else if ord1 < ord2 { 146 return ct2 147 } 148 149 return ct1 150} 151 152// NewClientSession creates a Client. 153func NewClientSession(pool *Pool, clientID uuid.UUID, sessionType Type, opts ...*ClientOptions) (*Client, error) { 154 c := &Client{ 155 Consistent: true, // set default 156 ClientID: clientID, 157 SessionType: sessionType, 158 pool: pool, 159 } 160 161 mergedOpts := mergeClientOptions(opts...) 162 if mergedOpts.CausalConsistency != nil { 163 c.Consistent = *mergedOpts.CausalConsistency 164 } 165 if mergedOpts.DefaultReadPreference != nil { 166 c.transactionRp = mergedOpts.DefaultReadPreference 167 } 168 if mergedOpts.DefaultReadConcern != nil { 169 c.transactionRc = mergedOpts.DefaultReadConcern 170 } 171 if mergedOpts.DefaultWriteConcern != nil { 172 c.transactionWc = mergedOpts.DefaultWriteConcern 173 } 174 if mergedOpts.DefaultMaxCommitTime != nil { 175 c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime 176 } 177 178 servSess, err := pool.GetSession() 179 if err != nil { 180 return nil, err 181 } 182 183 c.Server = servSess 184 185 return c, nil 186} 187 188// AdvanceClusterTime updates the session's cluster time. 189func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error { 190 if c.Terminated { 191 return ErrSessionEnded 192 } 193 c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime) 194 return nil 195} 196 197// AdvanceOperationTime updates the session's operation time. 198func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error { 199 if c.Terminated { 200 return ErrSessionEnded 201 } 202 203 if c.OperationTime == nil { 204 c.OperationTime = opTime 205 return nil 206 } 207 208 if opTime.T > c.OperationTime.T { 209 c.OperationTime = opTime 210 } else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) { 211 c.OperationTime = opTime 212 } 213 214 return nil 215} 216 217// UpdateUseTime sets the session's last used time to the current time. This must be called whenever the session is 218// used to send a command to the server to ensure that the session is not prematurely marked expired in the driver's 219// session pool. If the session has already been ended, this method will return ErrSessionEnded. 220func (c *Client) UpdateUseTime() error { 221 if c.Terminated { 222 return ErrSessionEnded 223 } 224 c.updateUseTime() 225 return nil 226} 227 228// UpdateRecoveryToken updates the session's recovery token from the server response. 229func (c *Client) UpdateRecoveryToken(response bson.Raw) { 230 if c == nil { 231 return 232 } 233 234 token, err := response.LookupErr("recoveryToken") 235 if err != nil { 236 return 237 } 238 239 c.RecoveryToken = token.Document() 240} 241 242// ClearPinnedServer sets the PinnedServer to nil. 243func (c *Client) ClearPinnedServer() { 244 if c != nil { 245 c.PinnedServer = nil 246 } 247} 248 249// EndSession ends the session. 250func (c *Client) EndSession() { 251 if c.Terminated { 252 return 253 } 254 255 c.Terminated = true 256 c.pool.ReturnSession(c.Server) 257 258 return 259} 260 261// TransactionInProgress returns true if the client session is in an active transaction. 262func (c *Client) TransactionInProgress() bool { 263 return c.TransactionState == InProgress 264} 265 266// TransactionStarting returns true if the client session is starting a transaction. 267func (c *Client) TransactionStarting() bool { 268 return c.TransactionState == Starting 269} 270 271// TransactionRunning returns true if the client session has started the transaction 272// and it hasn't been committed or aborted 273func (c *Client) TransactionRunning() bool { 274 return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress) 275} 276 277// TransactionCommitted returns true of the client session just committed a transaciton. 278func (c *Client) TransactionCommitted() bool { 279 return c.TransactionState == Committed 280} 281 282// CheckStartTransaction checks to see if allowed to start transaction and returns 283// an error if not allowed 284func (c *Client) CheckStartTransaction() error { 285 if c.TransactionState == InProgress || c.TransactionState == Starting { 286 return ErrTransactInProgress 287 } 288 return nil 289} 290 291// StartTransaction initializes the transaction options and advances the state machine. 292// It does not contact the server to start the transaction. 293func (c *Client) StartTransaction(opts *TransactionOptions) error { 294 err := c.CheckStartTransaction() 295 if err != nil { 296 return err 297 } 298 299 c.IncrementTxnNumber() 300 c.RetryingCommit = false 301 302 if opts != nil { 303 c.CurrentRc = opts.ReadConcern 304 c.CurrentRp = opts.ReadPreference 305 c.CurrentWc = opts.WriteConcern 306 c.CurrentMct = opts.MaxCommitTime 307 } 308 309 if c.CurrentRc == nil { 310 c.CurrentRc = c.transactionRc 311 } 312 313 if c.CurrentRp == nil { 314 c.CurrentRp = c.transactionRp 315 } 316 317 if c.CurrentWc == nil { 318 c.CurrentWc = c.transactionWc 319 } 320 321 if c.CurrentMct == nil { 322 c.CurrentMct = c.transactionMaxCommitTime 323 } 324 325 if !writeconcern.AckWrite(c.CurrentWc) { 326 c.clearTransactionOpts() 327 return ErrUnackWCUnsupported 328 } 329 330 c.TransactionState = Starting 331 c.PinnedServer = nil 332 return nil 333} 334 335// CheckCommitTransaction checks to see if allowed to commit transaction and returns 336// an error if not allowed. 337func (c *Client) CheckCommitTransaction() error { 338 if c.TransactionState == None { 339 return ErrNoTransactStarted 340 } else if c.TransactionState == Aborted { 341 return ErrCommitAfterAbort 342 } 343 return nil 344} 345 346// CommitTransaction updates the state for a successfully committed transaction and returns 347// an error if not permissible. It does not actually perform the commit. 348func (c *Client) CommitTransaction() error { 349 err := c.CheckCommitTransaction() 350 if err != nil { 351 return err 352 } 353 c.TransactionState = Committed 354 return nil 355} 356 357// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set a 358// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a 359// retryable error or after a successful commit transaction operation. 360func (c *Client) UpdateCommitTransactionWriteConcern() { 361 wc := c.CurrentWc 362 timeout := 10 * time.Second 363 if wc != nil && wc.GetWTimeout() != 0 { 364 timeout = wc.GetWTimeout() 365 } 366 c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout)) 367} 368 369// CheckAbortTransaction checks to see if allowed to abort transaction and returns 370// an error if not allowed. 371func (c *Client) CheckAbortTransaction() error { 372 if c.TransactionState == None { 373 return ErrNoTransactStarted 374 } else if c.TransactionState == Committed { 375 return ErrAbortAfterCommit 376 } else if c.TransactionState == Aborted { 377 return ErrAbortTwice 378 } 379 return nil 380} 381 382// AbortTransaction updates the state for a successfully aborted transaction and returns 383// an error if not permissible. It does not actually perform the abort. 384func (c *Client) AbortTransaction() error { 385 err := c.CheckAbortTransaction() 386 if err != nil { 387 return err 388 } 389 c.TransactionState = Aborted 390 c.clearTransactionOpts() 391 return nil 392} 393 394// ApplyCommand advances the state machine upon command execution. 395func (c *Client) ApplyCommand(desc description.Server) { 396 if c.Committing { 397 // Do not change state if committing after already committed 398 return 399 } 400 if c.TransactionState == Starting { 401 c.TransactionState = InProgress 402 // If this is in a transaction and the server is a mongos, pin it 403 if desc.Kind == description.Mongos { 404 c.PinnedServer = &desc 405 } 406 } else if c.TransactionState == Committed || c.TransactionState == Aborted { 407 c.clearTransactionOpts() 408 c.TransactionState = None 409 } 410} 411 412func (c *Client) clearTransactionOpts() { 413 c.RetryingCommit = false 414 c.Aborting = false 415 c.Committing = false 416 c.CurrentWc = nil 417 c.CurrentRp = nil 418 c.CurrentRc = nil 419 c.PinnedServer = nil 420 c.RecoveryToken = nil 421} 422