1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mongo 8 9import ( 10 "context" 11 "errors" 12 "math" 13 "strconv" 14 "strings" 15 "testing" 16 "time" 17 18 "go.mongodb.org/mongo-driver/bson" 19 "go.mongodb.org/mongo-driver/event" 20 "go.mongodb.org/mongo-driver/internal/testutil" 21 "go.mongodb.org/mongo-driver/internal/testutil/assert" 22 "go.mongodb.org/mongo-driver/mongo/options" 23 "go.mongodb.org/mongo-driver/mongo/readpref" 24 "go.mongodb.org/mongo-driver/mongo/writeconcern" 25 "go.mongodb.org/mongo-driver/x/mongo/driver" 26 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 27 "go.mongodb.org/mongo-driver/x/mongo/driver/topology" 28) 29 30var ( 31 connsCheckedOut int 32) 33 34func TestConvenientTransactions(t *testing.T) { 35 client := setupConvenientTransactions(t) 36 db := client.Database("TestConvenientTransactions") 37 dbAdmin := client.Database("admin") 38 39 defer func() { 40 sessions := client.NumberSessionsInProgress() 41 conns := connsCheckedOut 42 43 err := dbAdmin.RunCommand(bgCtx, bson.D{ 44 {"killAllSessions", bson.A{}}, 45 }).Err() 46 if err != nil { 47 if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { 48 t.Fatalf("killAllSessions error: %v", err) 49 } 50 } 51 52 _ = db.Drop(bgCtx) 53 _ = client.Disconnect(bgCtx) 54 55 assert.Equal(t, 0, sessions, "%v sessions checked out", sessions) 56 assert.Equal(t, 0, conns, "%v connections checked out", conns) 57 }() 58 59 t.Run("callback raises custom error", func(t *testing.T) { 60 coll := db.Collection(t.Name()) 61 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}}) 62 assert.Nil(t, err, "InsertOne error: %v", err) 63 64 sess, err := client.StartSession() 65 assert.Nil(t, err, "StartSession error: %v", err) 66 defer 
sess.EndSession(context.Background()) 67 68 testErr := errors.New("test error") 69 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 70 return nil, testErr 71 }) 72 assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err) 73 }) 74 t.Run("callback returns value", func(t *testing.T) { 75 coll := db.Collection(t.Name()) 76 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}}) 77 assert.Nil(t, err, "InsertOne error: %v", err) 78 79 sess, err := client.StartSession() 80 assert.Nil(t, err, "StartSession error: %v", err) 81 defer sess.EndSession(context.Background()) 82 83 res, err := sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 84 return false, nil 85 }) 86 assert.Nil(t, err, "WithTransaction error: %v", err) 87 resBool, ok := res.(bool) 88 assert.True(t, ok, "expected result type %T, got %T", false, res) 89 assert.False(t, resBool, "expected result false, got %v", resBool) 90 }) 91 t.Run("retry timeout enforced", func(t *testing.T) { 92 withTransactionTimeout = time.Second 93 94 coll := db.Collection(t.Name()) 95 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}}) 96 assert.Nil(t, err, "InsertOne error: %v", err) 97 98 t.Run("transient transaction error", func(t *testing.T) { 99 sess, err := client.StartSession() 100 assert.Nil(t, err, "StartSession error: %v", err) 101 defer sess.EndSession(context.Background()) 102 103 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 104 return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}} 105 }) 106 assert.NotNil(t, err, "expected WithTransaction error, got nil") 107 cmdErr, ok := err.(CommandError) 108 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) 109 assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), 110 "expected error with label %v, got %v", 
driver.TransientTransactionError, cmdErr) 111 }) 112 t.Run("unknown transaction commit result", func(t *testing.T) { 113 //set failpoint 114 failpoint := bson.D{{"configureFailPoint", "failCommand"}, 115 {"mode", "alwaysOn"}, 116 {"data", bson.D{ 117 {"failCommands", bson.A{"commitTransaction"}}, 118 {"closeConnection", true}, 119 }}, 120 } 121 err = dbAdmin.RunCommand(bgCtx, failpoint).Err() 122 assert.Nil(t, err, "error setting failpoint: %v", err) 123 defer func() { 124 err = dbAdmin.RunCommand(bgCtx, bson.D{ 125 {"configureFailPoint", "failCommand"}, 126 {"mode", "off"}, 127 }).Err() 128 assert.Nil(t, err, "error turning off failpoint: %v", err) 129 }() 130 131 sess, err := client.StartSession() 132 assert.Nil(t, err, "StartSession error: %v", err) 133 defer sess.EndSession(context.Background()) 134 135 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 136 _, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}}) 137 return nil, err 138 }) 139 assert.NotNil(t, err, "expected WithTransaction error, got nil") 140 cmdErr, ok := err.(CommandError) 141 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) 142 assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult), 143 "expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr) 144 }) 145 t.Run("commit transient transaction error", func(t *testing.T) { 146 //set failpoint 147 failpoint := bson.D{{"configureFailPoint", "failCommand"}, 148 {"mode", "alwaysOn"}, 149 {"data", bson.D{ 150 {"failCommands", bson.A{"commitTransaction"}}, 151 {"errorCode", 251}, 152 }}, 153 } 154 err = dbAdmin.RunCommand(bgCtx, failpoint).Err() 155 assert.Nil(t, err, "error setting failpoint: %v", err) 156 defer func() { 157 err = dbAdmin.RunCommand(bgCtx, bson.D{ 158 {"configureFailPoint", "failCommand"}, 159 {"mode", "off"}, 160 }).Err() 161 assert.Nil(t, err, "error turning off failpoint: %v", err) 162 }() 163 164 sess, err := 
client.StartSession() 165 assert.Nil(t, err, "StartSession error: %v", err) 166 defer sess.EndSession(context.Background()) 167 168 _, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) { 169 _, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}}) 170 return nil, err 171 }) 172 assert.NotNil(t, err, "expected WithTransaction error, got nil") 173 cmdErr, ok := err.(CommandError) 174 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) 175 assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), 176 "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr) 177 }) 178 }) 179} 180 181func setupConvenientTransactions(t *testing.T) *Client { 182 cs := testutil.ConnString(t) 183 poolMonitor := &event.PoolMonitor{ 184 Event: func(evt *event.PoolEvent) { 185 switch evt.Type { 186 case event.GetSucceeded: 187 connsCheckedOut++ 188 case event.ConnectionReturned: 189 connsCheckedOut-- 190 } 191 }, 192 } 193 clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()). 
194 SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetPoolMonitor(poolMonitor) 195 client, err := Connect(bgCtx, clientOpts) 196 assert.Nil(t, err, "Connect error: %v", err) 197 198 version, err := getServerVersion(client.Database("admin")) 199 assert.Nil(t, err, "getServerVersion error: %v", err) 200 topoKind := client.deployment.(*topology.Topology).Kind() 201 if compareVersions(t, version, "4.1") < 0 || topoKind == description.Single { 202 t.Skip("skipping standalones and versions < 4.1") 203 } 204 205 // pin to a single mongos if necessary 206 if topoKind != description.Sharded { 207 return client 208 } 209 client, err = Connect(bgCtx, clientOpts.SetHosts([]string{cs.Hosts[0]})) 210 assert.Nil(t, err, "Connect error: %v", err) 211 return client 212} 213 214func getServerVersion(db *Database) (string, error) { 215 serverStatus, err := db.RunCommand( 216 context.Background(), 217 bson.D{{"serverStatus", 1}}, 218 ).DecodeBytes() 219 if err != nil { 220 return "", err 221 } 222 223 version, err := serverStatus.LookupErr("version") 224 if err != nil { 225 return "", err 226 } 227 228 return version.StringValue(), nil 229} 230 231// compareVersions compares two version number strings (i.e. positive integers separated by 232// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is 233// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. 234// 235// Returns a positive int if version1 is greater than version2, a negative int if version1 is less 236// than version2, and 0 if version1 is equal to version2. 
func compareVersions(t *testing.T, v1 string, v2 string) int {
	parts1 := strings.Split(v1, ".")
	parts2 := strings.Split(v2, ".")

	// Only the leading components that both versions have are compared.
	shared := int(math.Min(float64(len(parts1)), float64(len(parts2))))
	for idx := 0; idx < shared; idx++ {
		a, err := strconv.Atoi(parts1[idx])
		if err != nil {
			return 1
		}
		b, err := strconv.Atoi(parts2[idx])
		if err != nil {
			return -1
		}
		if a != b {
			// The signed gap doubles as the ordering result.
			return a - b
		}
	}
	return 0
}