1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11	"errors"
12	"math"
13	"strconv"
14	"strings"
15	"testing"
16	"time"
17
18	"go.mongodb.org/mongo-driver/bson"
19	"go.mongodb.org/mongo-driver/event"
20	"go.mongodb.org/mongo-driver/internal/testutil"
21	"go.mongodb.org/mongo-driver/internal/testutil/assert"
22	"go.mongodb.org/mongo-driver/mongo/options"
23	"go.mongodb.org/mongo-driver/mongo/readpref"
24	"go.mongodb.org/mongo-driver/mongo/writeconcern"
25	"go.mongodb.org/mongo-driver/x/mongo/driver"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
27	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
28)
29
// connsCheckedOut is a running count of pooled connections currently checked out. The pool
// monitor built by setupConvenientTransactions increments it on GetSucceeded and decrements
// it on ConnectionReturned.
// NOTE(review): it is updated from a PoolMonitor callback without synchronization; confirm
// the driver never delivers pool events concurrently, or switch to sync/atomic.
var connsCheckedOut int

// errorInterrupted is the server error code that TestConvenientTransactions tolerates when
// killAllSessions fails during cleanup. It is int32 to match CommandError.Code.
var errorInterrupted int32 = 11601
34
35func TestConvenientTransactions(t *testing.T) {
36	client := setupConvenientTransactions(t)
37	db := client.Database("TestConvenientTransactions")
38	dbAdmin := client.Database("admin")
39
40	defer func() {
41		sessions := client.NumberSessionsInProgress()
42		conns := connsCheckedOut
43
44		err := dbAdmin.RunCommand(bgCtx, bson.D{
45			{"killAllSessions", bson.A{}},
46		}).Err()
47		if err != nil {
48			if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
49				t.Fatalf("killAllSessions error: %v", err)
50			}
51		}
52
53		_ = db.Drop(bgCtx)
54		_ = client.Disconnect(bgCtx)
55
56		assert.Equal(t, 0, sessions, "%v sessions checked out", sessions)
57		assert.Equal(t, 0, conns, "%v connections checked out", conns)
58	}()
59
60	t.Run("callback raises custom error", func(t *testing.T) {
61		coll := db.Collection(t.Name())
62		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
63		assert.Nil(t, err, "InsertOne error: %v", err)
64
65		sess, err := client.StartSession()
66		assert.Nil(t, err, "StartSession error: %v", err)
67		defer sess.EndSession(context.Background())
68
69		testErr := errors.New("test error")
70		_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
71			return nil, testErr
72		})
73		assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err)
74	})
75	t.Run("callback returns value", func(t *testing.T) {
76		coll := db.Collection(t.Name())
77		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
78		assert.Nil(t, err, "InsertOne error: %v", err)
79
80		sess, err := client.StartSession()
81		assert.Nil(t, err, "StartSession error: %v", err)
82		defer sess.EndSession(context.Background())
83
84		res, err := sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
85			return false, nil
86		})
87		assert.Nil(t, err, "WithTransaction error: %v", err)
88		resBool, ok := res.(bool)
89		assert.True(t, ok, "expected result type %T, got %T", false, res)
90		assert.False(t, resBool, "expected result false, got %v", resBool)
91	})
92	t.Run("retry timeout enforced", func(t *testing.T) {
93		withTransactionTimeout = time.Second
94
95		coll := db.Collection(t.Name())
96		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
97		assert.Nil(t, err, "InsertOne error: %v", err)
98
99		t.Run("transient transaction error", func(t *testing.T) {
100			sess, err := client.StartSession()
101			assert.Nil(t, err, "StartSession error: %v", err)
102			defer sess.EndSession(context.Background())
103
104			_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
105				return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}}
106			})
107			assert.NotNil(t, err, "expected WithTransaction error, got nil")
108			cmdErr, ok := err.(CommandError)
109			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
110			assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
111				"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
112		})
113		t.Run("unknown transaction commit result", func(t *testing.T) {
114			//set failpoint
115			failpoint := bson.D{{"configureFailPoint", "failCommand"},
116				{"mode", "alwaysOn"},
117				{"data", bson.D{
118					{"failCommands", bson.A{"commitTransaction"}},
119					{"closeConnection", true},
120				}},
121			}
122			err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
123			assert.Nil(t, err, "error setting failpoint: %v", err)
124			defer func() {
125				err = dbAdmin.RunCommand(bgCtx, bson.D{
126					{"configureFailPoint", "failCommand"},
127					{"mode", "off"},
128				}).Err()
129				assert.Nil(t, err, "error turning off failpoint: %v", err)
130			}()
131
132			sess, err := client.StartSession()
133			assert.Nil(t, err, "StartSession error: %v", err)
134			defer sess.EndSession(context.Background())
135
136			_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
137				_, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}})
138				return nil, err
139			})
140			assert.NotNil(t, err, "expected WithTransaction error, got nil")
141			cmdErr, ok := err.(CommandError)
142			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
143			assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult),
144				"expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr)
145		})
146		t.Run("commit transient transaction error", func(t *testing.T) {
147			//set failpoint
148			failpoint := bson.D{{"configureFailPoint", "failCommand"},
149				{"mode", "alwaysOn"},
150				{"data", bson.D{
151					{"failCommands", bson.A{"commitTransaction"}},
152					{"errorCode", 251},
153				}},
154			}
155			err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
156			assert.Nil(t, err, "error setting failpoint: %v", err)
157			defer func() {
158				err = dbAdmin.RunCommand(bgCtx, bson.D{
159					{"configureFailPoint", "failCommand"},
160					{"mode", "off"},
161				}).Err()
162				assert.Nil(t, err, "error turning off failpoint: %v", err)
163			}()
164
165			sess, err := client.StartSession()
166			assert.Nil(t, err, "StartSession error: %v", err)
167			defer sess.EndSession(context.Background())
168
169			_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
170				_, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}})
171				return nil, err
172			})
173			assert.NotNil(t, err, "expected WithTransaction error, got nil")
174			cmdErr, ok := err.(CommandError)
175			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
176			assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
177				"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
178		})
179	})
180}
181
182func setupConvenientTransactions(t *testing.T) *Client {
183	cs := testutil.ConnString(t)
184	poolMonitor := &event.PoolMonitor{
185		Event: func(evt *event.PoolEvent) {
186			switch evt.Type {
187			case event.GetSucceeded:
188				connsCheckedOut++
189			case event.ConnectionReturned:
190				connsCheckedOut--
191			}
192		},
193	}
194	clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()).
195		SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetPoolMonitor(poolMonitor)
196	client, err := Connect(bgCtx, clientOpts)
197	assert.Nil(t, err, "Connect error: %v", err)
198
199	version, err := getServerVersion(client.Database("admin"))
200	assert.Nil(t, err, "getServerVersion error: %v", err)
201	topoKind := client.deployment.(*topology.Topology).Kind()
202	if compareVersions(t, version, "4.1") < 0 || topoKind == description.Single {
203		t.Skip("skipping standalones and versions < 4.1")
204	}
205
206	// pin to a single mongos if necessary
207	if topoKind != description.Sharded {
208		return client
209	}
210	client, err = Connect(bgCtx, clientOpts.SetHosts([]string{cs.Hosts[0]}))
211	assert.Nil(t, err, "Connect error: %v", err)
212	return client
213}
214
215func getServerVersion(db *Database) (string, error) {
216	serverStatus, err := db.RunCommand(
217		context.Background(),
218		bson.D{{"serverStatus", 1}},
219	).DecodeBytes()
220	if err != nil {
221		return "", err
222	}
223
224	version, err := serverStatus.LookupErr("version")
225	if err != nil {
226		return "", err
227	}
228
229	return version.StringValue(), nil
230}
231
// compareVersions compares two version strings made of period-separated integers, such as
// "3.2" or "4.0.11". Only as many components as the shorter string has are compared, so
// "3.2" is considered equal to "3.2.11" whereas "3.2.0" is considered less than "3.2.11".
//
// The result is positive when v1 is greater than v2, negative when v1 is less than v2, and 0
// when they are equal. For the first differing component it is the numeric difference
// v1 - v2. Components are parsed pairwise, v1's first. The first one that is not an integer
// makes the result 1 if it belongs to v1 and -1 if it belongs to v2. t is not used.
func compareVersions(t *testing.T, v1 string, v2 string) int {
	parts1 := strings.Split(v1, ".")
	parts2 := strings.Split(v2, ".")
	shared := int(math.Min(float64(len(parts1)), float64(len(parts2))))

	for idx := 0; idx < shared; idx++ {
		a, err := strconv.Atoi(parts1[idx])
		if err != nil {
			return 1
		}
		b, err := strconv.Atoi(parts2[idx])
		if err != nil {
			return -1
		}
		if a != b {
			return a - b
		}
	}
	return 0
}
261