1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mongo 8 9import ( 10 "context" 11 "errors" 12 "math" 13 "strconv" 14 "strings" 15 "testing" 16 "time" 17 18 "go.mongodb.org/mongo-driver/bson" 19 "go.mongodb.org/mongo-driver/event" 20 "go.mongodb.org/mongo-driver/internal/testutil" 21 "go.mongodb.org/mongo-driver/internal/testutil/assert" 22 "go.mongodb.org/mongo-driver/mongo/options" 23 "go.mongodb.org/mongo-driver/mongo/readpref" 24 "go.mongodb.org/mongo-driver/mongo/writeconcern" 25 "go.mongodb.org/mongo-driver/x/mongo/driver" 26 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 27 "go.mongodb.org/mongo-driver/x/mongo/driver/topology" 28) 29 30var ( 31 connsCheckedOut int 32 errorInterrupted int32 = 11601 33) 34 35func TestConvenientTransactions(t *testing.T) { 36 client := setupConvenientTransactions(t) 37 db := client.Database("TestConvenientTransactions") 38 dbAdmin := client.Database("admin") 39 40 defer func() { 41 sessions := client.NumberSessionsInProgress() 42 conns := connsCheckedOut 43 44 err := dbAdmin.RunCommand(bgCtx, bson.D{ 45 {"killAllSessions", bson.A{}}, 46 }).Err() 47 if err != nil { 48 if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { 49 t.Fatalf("killAllSessions error: %v", err) 50 } 51 } 52 53 _ = db.Drop(bgCtx) 54 _ = client.Disconnect(bgCtx) 55 56 assert.Equal(t, 0, sessions, "%v sessions checked out", sessions) 57 assert.Equal(t, 0, conns, "%v connections checked out", conns) 58 }() 59 60 t.Run("callback raises custom error", func(t *testing.T) { 61 coll := db.Collection(t.Name()) 62 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}}) 63 assert.Nil(t, err, "InsertOne error: %v", err) 64 65 sess, err := client.StartSession() 66 assert.Nil(t, err, "StartSession error: %v", err) 67 
defer sess.EndSession(context.Background()) 68 69 testErr := errors.New("test error") 70 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 71 return nil, testErr 72 }) 73 assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err) 74 }) 75 t.Run("callback returns value", func(t *testing.T) { 76 coll := db.Collection(t.Name()) 77 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}}) 78 assert.Nil(t, err, "InsertOne error: %v", err) 79 80 sess, err := client.StartSession() 81 assert.Nil(t, err, "StartSession error: %v", err) 82 defer sess.EndSession(context.Background()) 83 84 res, err := sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 85 return false, nil 86 }) 87 assert.Nil(t, err, "WithTransaction error: %v", err) 88 resBool, ok := res.(bool) 89 assert.True(t, ok, "expected result type %T, got %T", false, res) 90 assert.False(t, resBool, "expected result false, got %v", resBool) 91 }) 92 t.Run("retry timeout enforced", func(t *testing.T) { 93 withTransactionTimeout = time.Second 94 95 coll := db.Collection(t.Name()) 96 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}}) 97 assert.Nil(t, err, "InsertOne error: %v", err) 98 99 t.Run("transient transaction error", func(t *testing.T) { 100 sess, err := client.StartSession() 101 assert.Nil(t, err, "StartSession error: %v", err) 102 defer sess.EndSession(context.Background()) 103 104 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 105 return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}} 106 }) 107 assert.NotNil(t, err, "expected WithTransaction error, got nil") 108 cmdErr, ok := err.(CommandError) 109 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) 110 assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), 111 "expected error with label %v, got %v", 
driver.TransientTransactionError, cmdErr) 112 }) 113 t.Run("unknown transaction commit result", func(t *testing.T) { 114 //set failpoint 115 failpoint := bson.D{{"configureFailPoint", "failCommand"}, 116 {"mode", "alwaysOn"}, 117 {"data", bson.D{ 118 {"failCommands", bson.A{"commitTransaction"}}, 119 {"closeConnection", true}, 120 }}, 121 } 122 err = dbAdmin.RunCommand(bgCtx, failpoint).Err() 123 assert.Nil(t, err, "error setting failpoint: %v", err) 124 defer func() { 125 err = dbAdmin.RunCommand(bgCtx, bson.D{ 126 {"configureFailPoint", "failCommand"}, 127 {"mode", "off"}, 128 }).Err() 129 assert.Nil(t, err, "error turning off failpoint: %v", err) 130 }() 131 132 sess, err := client.StartSession() 133 assert.Nil(t, err, "StartSession error: %v", err) 134 defer sess.EndSession(context.Background()) 135 136 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 137 _, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}}) 138 return nil, err 139 }) 140 assert.NotNil(t, err, "expected WithTransaction error, got nil") 141 cmdErr, ok := err.(CommandError) 142 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) 143 assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult), 144 "expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr) 145 }) 146 t.Run("commit transient transaction error", func(t *testing.T) { 147 //set failpoint 148 failpoint := bson.D{{"configureFailPoint", "failCommand"}, 149 {"mode", "alwaysOn"}, 150 {"data", bson.D{ 151 {"failCommands", bson.A{"commitTransaction"}}, 152 {"errorCode", 251}, 153 }}, 154 } 155 err = dbAdmin.RunCommand(bgCtx, failpoint).Err() 156 assert.Nil(t, err, "error setting failpoint: %v", err) 157 defer func() { 158 err = dbAdmin.RunCommand(bgCtx, bson.D{ 159 {"configureFailPoint", "failCommand"}, 160 {"mode", "off"}, 161 }).Err() 162 assert.Nil(t, err, "error turning off failpoint: %v", err) 163 }() 164 165 sess, err := 
client.StartSession() 166 assert.Nil(t, err, "StartSession error: %v", err) 167 defer sess.EndSession(context.Background()) 168 169 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 170 _, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}}) 171 return nil, err 172 }) 173 assert.NotNil(t, err, "expected WithTransaction error, got nil") 174 cmdErr, ok := err.(CommandError) 175 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) 176 assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), 177 "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr) 178 }) 179 }) 180} 181 182func setupConvenientTransactions(t *testing.T) *Client { 183 cs := testutil.ConnString(t) 184 poolMonitor := &event.PoolMonitor{ 185 Event: func(evt *event.PoolEvent) { 186 switch evt.Type { 187 case event.GetSucceeded: 188 connsCheckedOut++ 189 case event.ConnectionReturned: 190 connsCheckedOut-- 191 } 192 }, 193 } 194 clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()). 
195 SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetPoolMonitor(poolMonitor) 196 client, err := Connect(bgCtx, clientOpts) 197 assert.Nil(t, err, "Connect error: %v", err) 198 199 version, err := getServerVersion(client.Database("admin")) 200 assert.Nil(t, err, "getServerVersion error: %v", err) 201 topoKind := client.deployment.(*topology.Topology).Kind() 202 if compareVersions(t, version, "4.1") < 0 || topoKind == description.Single { 203 t.Skip("skipping standalones and versions < 4.1") 204 } 205 206 // pin to a single mongos if necessary 207 if topoKind != description.Sharded { 208 return client 209 } 210 client, err = Connect(bgCtx, clientOpts.SetHosts([]string{cs.Hosts[0]})) 211 assert.Nil(t, err, "Connect error: %v", err) 212 return client 213} 214 215func getServerVersion(db *Database) (string, error) { 216 serverStatus, err := db.RunCommand( 217 context.Background(), 218 bson.D{{"serverStatus", 1}}, 219 ).DecodeBytes() 220 if err != nil { 221 return "", err 222 } 223 224 version, err := serverStatus.LookupErr("version") 225 if err != nil { 226 return "", err 227 } 228 229 return version.StringValue(), nil 230} 231 232// compareVersions compares two version number strings (i.e. positive integers separated by 233// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is 234// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. 235// 236// Returns a positive int if version1 is greater than version2, a negative int if version1 is less 237// than version2, and 0 if version1 is equal to version2. 
func compareVersions(t *testing.T, v1 string, v2 string) int {
	parts1 := strings.Split(v1, ".")
	parts2 := strings.Split(v2, ".")
	// Only the components both versions have are compared.
	shared := int(math.Min(float64(len(parts1)), float64(len(parts2))))

	for idx := 0; idx < shared; idx++ {
		a, errA := strconv.Atoi(parts1[idx])
		if errA != nil {
			// An unparsable component in v1 ranks v1 higher.
			return 1
		}

		b, errB := strconv.Atoi(parts2[idx])
		if errB != nil {
			// An unparsable component in v2 ranks v1 lower.
			return -1
		}

		if a != b {
			return a - b
		}
	}

	return 0
}