1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session" 8 9import ( 10 "errors" 11 "time" 12 13 "go.mongodb.org/mongo-driver/bson" 14 "go.mongodb.org/mongo-driver/bson/primitive" 15 "go.mongodb.org/mongo-driver/mongo/readconcern" 16 "go.mongodb.org/mongo-driver/mongo/readpref" 17 "go.mongodb.org/mongo-driver/mongo/writeconcern" 18 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 19 "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" 20) 21 22// ErrSessionEnded is returned when a client session is used after a call to endSession(). 23var ErrSessionEnded = errors.New("ended session was used") 24 25// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started. 26var ErrNoTransactStarted = errors.New("no transaction started") 27 28// ErrTransactInProgress is returned if startTransaction() is called when a transaction is in progress. 29var ErrTransactInProgress = errors.New("transaction already in progress") 30 31// ErrAbortAfterCommit is returned when abort is called after a commit. 32var ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction") 33 34// ErrAbortTwice is returned if abort is called after transaction is already aborted. 35var ErrAbortTwice = errors.New("cannot call abortTransaction twice") 36 37// ErrCommitAfterAbort is returned if commit is called after an abort. 38var ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction") 39 40// ErrUnackWCUnsupported is returned if an unacknowledged write concern is supported for a transaciton. 41var ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns") 42 43// Type describes the type of the session 44type Type uint8 45 46// These constants are the valid types for a client session. 47const ( 48 Explicit Type = iota 49 Implicit 50) 51 52// State indicates the state of the FSM. 53type state uint8 54 55// Client Session states 56const ( 57 None state = iota 58 Starting 59 InProgress 60 Committed 61 Aborted 62) 63 64// Client is a session for clients to run commands. 65type Client struct { 66 *Server 67 ClientID uuid.UUID 68 ClusterTime bson.Raw 69 Consistent bool // causal consistency 70 OperationTime *primitive.Timestamp 71 SessionType Type 72 Terminated bool 73 RetryingCommit bool 74 Committing bool 75 Aborting bool 76 RetryWrite bool 77 RetryRead bool 78 79 // options for the current transaction 80 // most recently set by transactionopt 81 CurrentRc *readconcern.ReadConcern 82 CurrentRp *readpref.ReadPref 83 CurrentWc *writeconcern.WriteConcern 84 CurrentMct *time.Duration 85 86 // default transaction options 87 transactionRc *readconcern.ReadConcern 88 transactionRp *readpref.ReadPref 89 transactionWc *writeconcern.WriteConcern 90 transactionMaxCommitTime *time.Duration 91 92 pool *Pool 93 state state 94 PinnedServer *description.Server 95 RecoveryToken bson.Raw 96} 97 98func getClusterTime(clusterTime bson.Raw) (uint32, uint32) { 99 if clusterTime == nil { 100 return 0, 0 101 } 102 103 clusterTimeVal, err := clusterTime.LookupErr("$clusterTime") 104 if err != nil { 105 return 0, 0 106 } 107 108 timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime") 109 if err != nil { 110 return 0, 0 111 } 112 113 return timestampVal.Timestamp() 114} 115 116// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time. 117func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw { 118 epoch1, ord1 := getClusterTime(ct1) 119 epoch2, ord2 := getClusterTime(ct2) 120 121 if epoch1 > epoch2 { 122 return ct1 123 } else if epoch1 < epoch2 { 124 return ct2 125 } else if ord1 > ord2 { 126 return ct1 127 } else if ord1 < ord2 { 128 return ct2 129 } 130 131 return ct1 132} 133 134// NewClientSession creates a Client. 135func NewClientSession(pool *Pool, clientID uuid.UUID, sessionType Type, opts ...*ClientOptions) (*Client, error) { 136 c := &Client{ 137 Consistent: true, // set default 138 ClientID: clientID, 139 SessionType: sessionType, 140 pool: pool, 141 } 142 143 mergedOpts := mergeClientOptions(opts...) 144 if mergedOpts.CausalConsistency != nil { 145 c.Consistent = *mergedOpts.CausalConsistency 146 } 147 if mergedOpts.DefaultReadPreference != nil { 148 c.transactionRp = mergedOpts.DefaultReadPreference 149 } 150 if mergedOpts.DefaultReadConcern != nil { 151 c.transactionRc = mergedOpts.DefaultReadConcern 152 } 153 if mergedOpts.DefaultWriteConcern != nil { 154 c.transactionWc = mergedOpts.DefaultWriteConcern 155 } 156 if mergedOpts.DefaultMaxCommitTime != nil { 157 c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime 158 } 159 160 servSess, err := pool.GetSession() 161 if err != nil { 162 return nil, err 163 } 164 165 c.Server = servSess 166 167 return c, nil 168} 169 170// AdvanceClusterTime updates the session's cluster time. 171func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error { 172 if c.Terminated { 173 return ErrSessionEnded 174 } 175 c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime) 176 return nil 177} 178 179// AdvanceOperationTime updates the session's operation time. 180func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error { 181 if c.Terminated { 182 return ErrSessionEnded 183 } 184 185 if c.OperationTime == nil { 186 c.OperationTime = opTime 187 return nil 188 } 189 190 if opTime.T > c.OperationTime.T { 191 c.OperationTime = opTime 192 } else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) { 193 c.OperationTime = opTime 194 } 195 196 return nil 197} 198 199// UpdateUseTime updates the session's last used time. 200// Must be called whenver this session is used to send a command to the server. 201func (c *Client) UpdateUseTime() error { 202 if c.Terminated { 203 return ErrSessionEnded 204 } 205 c.updateUseTime() 206 return nil 207} 208 209// UpdateRecoveryToken updates the session's recovery token from the server response. 210func (c *Client) UpdateRecoveryToken(response bson.Raw) { 211 if c == nil { 212 return 213 } 214 215 token, err := response.LookupErr("recoveryToken") 216 if err != nil { 217 return 218 } 219 220 c.RecoveryToken = token.Document() 221} 222 223// ClearPinnedServer sets the PinnedServer to nil. 224func (c *Client) ClearPinnedServer() { 225 if c != nil { 226 c.PinnedServer = nil 227 } 228} 229 230// EndSession ends the session. 231func (c *Client) EndSession() { 232 if c.Terminated { 233 return 234 } 235 236 c.Terminated = true 237 c.pool.ReturnSession(c.Server) 238 239 return 240} 241 242// TransactionInProgress returns true if the client session is in an active transaction. 243func (c *Client) TransactionInProgress() bool { 244 return c.state == InProgress 245} 246 247// TransactionStarting returns true if the client session is starting a transaction. 248func (c *Client) TransactionStarting() bool { 249 return c.state == Starting 250} 251 252// TransactionRunning returns true if the client session has started the transaction 253// and it hasn't been committed or aborted 254func (c *Client) TransactionRunning() bool { 255 return c != nil && (c.state == Starting || c.state == InProgress) 256} 257 258// TransactionCommitted returns true of the client session just committed a transaciton. 259func (c *Client) TransactionCommitted() bool { 260 return c.state == Committed 261} 262 263// CheckStartTransaction checks to see if allowed to start transaction and returns 264// an error if not allowed 265func (c *Client) CheckStartTransaction() error { 266 if c.state == InProgress || c.state == Starting { 267 return ErrTransactInProgress 268 } 269 return nil 270} 271 272// StartTransaction initializes the transaction options and advances the state machine. 273// It does not contact the server to start the transaction. 274func (c *Client) StartTransaction(opts *TransactionOptions) error { 275 err := c.CheckStartTransaction() 276 if err != nil { 277 return err 278 } 279 280 c.IncrementTxnNumber() 281 c.RetryingCommit = false 282 283 if opts != nil { 284 c.CurrentRc = opts.ReadConcern 285 c.CurrentRp = opts.ReadPreference 286 c.CurrentWc = opts.WriteConcern 287 c.CurrentMct = opts.MaxCommitTime 288 } 289 290 if c.CurrentRc == nil { 291 c.CurrentRc = c.transactionRc 292 } 293 294 if c.CurrentRp == nil { 295 c.CurrentRp = c.transactionRp 296 } 297 298 if c.CurrentWc == nil { 299 c.CurrentWc = c.transactionWc 300 } 301 302 if c.CurrentMct == nil { 303 c.CurrentMct = c.transactionMaxCommitTime 304 } 305 306 if !writeconcern.AckWrite(c.CurrentWc) { 307 c.clearTransactionOpts() 308 return ErrUnackWCUnsupported 309 } 310 311 c.state = Starting 312 c.PinnedServer = nil 313 return nil 314} 315 316// CheckCommitTransaction checks to see if allowed to commit transaction and returns 317// an error if not allowed. 318func (c *Client) CheckCommitTransaction() error { 319 if c.state == None { 320 return ErrNoTransactStarted 321 } else if c.state == Aborted { 322 return ErrCommitAfterAbort 323 } 324 return nil 325} 326 327// CommitTransaction updates the state for a successfully committed transaction and returns 328// an error if not permissible. It does not actually perform the commit. 329func (c *Client) CommitTransaction() error { 330 err := c.CheckCommitTransaction() 331 if err != nil { 332 return err 333 } 334 c.state = Committed 335 return nil 336} 337 338// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set a 339// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a 340// retryable error or after a successful commit transaction operation. 341func (c *Client) UpdateCommitTransactionWriteConcern() { 342 wc := c.CurrentWc 343 timeout := 10 * time.Second 344 if wc != nil && wc.GetWTimeout() != 0 { 345 timeout = wc.GetWTimeout() 346 } 347 c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout)) 348} 349 350// CheckAbortTransaction checks to see if allowed to abort transaction and returns 351// an error if not allowed. 352func (c *Client) CheckAbortTransaction() error { 353 if c.state == None { 354 return ErrNoTransactStarted 355 } else if c.state == Committed { 356 return ErrAbortAfterCommit 357 } else if c.state == Aborted { 358 return ErrAbortTwice 359 } 360 return nil 361} 362 363// AbortTransaction updates the state for a successfully aborted transaction and returns 364// an error if not permissible. It does not actually perform the abort. 365func (c *Client) AbortTransaction() error { 366 err := c.CheckAbortTransaction() 367 if err != nil { 368 return err 369 } 370 c.state = Aborted 371 c.clearTransactionOpts() 372 return nil 373} 374 375// ApplyCommand advances the state machine upon command execution. 376func (c *Client) ApplyCommand(desc description.Server) { 377 if c.Committing { 378 // Do not change state if committing after already committed 379 return 380 } 381 if c.state == Starting { 382 c.state = InProgress 383 // If this is in a transaction and the server is a mongos, pin it 384 if desc.Kind == description.Mongos { 385 c.PinnedServer = &desc 386 } 387 } else if c.state == Committed || c.state == Aborted { 388 c.clearTransactionOpts() 389 c.state = None 390 } 391} 392 393func (c *Client) clearTransactionOpts() { 394 c.RetryingCommit = false 395 c.Aborting = false 396 c.Committing = false 397 c.CurrentWc = nil 398 c.CurrentRp = nil 399 c.CurrentRc = nil 400 c.PinnedServer = nil 401 c.RecoveryToken = nil 402} 403