1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11	"errors"
12	"math"
13	"strconv"
14	"strings"
15	"testing"
16	"time"
17
18	"go.mongodb.org/mongo-driver/bson"
19	"go.mongodb.org/mongo-driver/event"
20	"go.mongodb.org/mongo-driver/internal/testutil"
21	"go.mongodb.org/mongo-driver/internal/testutil/assert"
22	"go.mongodb.org/mongo-driver/mongo/options"
23	"go.mongodb.org/mongo-driver/mongo/readpref"
24	"go.mongodb.org/mongo-driver/mongo/writeconcern"
25	"go.mongodb.org/mongo-driver/x/mongo/driver"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
27	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
28)
29
// connsCheckedOut is the running count of pool connections currently checked
// out: the pool monitor built in setupConvenientTransactions increments it on
// GetSucceeded events and decrements it on ConnectionReturned events.
// NOTE(review): pool events may be delivered from multiple goroutines; this is
// a plain, unsynchronized int — consider an atomic counter (TODO confirm).
var connsCheckedOut int
33
34func TestConvenientTransactions(t *testing.T) {
35	client := setupConvenientTransactions(t)
36	db := client.Database("TestConvenientTransactions")
37	dbAdmin := client.Database("admin")
38
39	defer func() {
40		sessions := client.NumberSessionsInProgress()
41		conns := connsCheckedOut
42
43		err := dbAdmin.RunCommand(bgCtx, bson.D{
44			{"killAllSessions", bson.A{}},
45		}).Err()
46		if err != nil {
47			if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
48				t.Fatalf("killAllSessions error: %v", err)
49			}
50		}
51
52		_ = db.Drop(bgCtx)
53		_ = client.Disconnect(bgCtx)
54
55		assert.Equal(t, 0, sessions, "%v sessions checked out", sessions)
56		assert.Equal(t, 0, conns, "%v connections checked out", conns)
57	}()
58
59	t.Run("callback raises custom error", func(t *testing.T) {
60		coll := db.Collection(t.Name())
61		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
62		assert.Nil(t, err, "InsertOne error: %v", err)
63
64		sess, err := client.StartSession()
65		assert.Nil(t, err, "StartSession error: %v", err)
66		defer sess.EndSession(context.Background())
67
68		testErr := errors.New("test error")
69		_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
70			return nil, testErr
71		})
72		assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err)
73	})
74	t.Run("callback returns value", func(t *testing.T) {
75		coll := db.Collection(t.Name())
76		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
77		assert.Nil(t, err, "InsertOne error: %v", err)
78
79		sess, err := client.StartSession()
80		assert.Nil(t, err, "StartSession error: %v", err)
81		defer sess.EndSession(context.Background())
82
83		res, err := sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
84			return false, nil
85		})
86		assert.Nil(t, err, "WithTransaction error: %v", err)
87		resBool, ok := res.(bool)
88		assert.True(t, ok, "expected result type %T, got %T", false, res)
89		assert.False(t, resBool, "expected result false, got %v", resBool)
90	})
91	t.Run("retry timeout enforced", func(t *testing.T) {
92		withTransactionTimeout = time.Second
93
94		coll := db.Collection(t.Name())
95		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
96		assert.Nil(t, err, "InsertOne error: %v", err)
97
98		t.Run("transient transaction error", func(t *testing.T) {
99			sess, err := client.StartSession()
100			assert.Nil(t, err, "StartSession error: %v", err)
101			defer sess.EndSession(context.Background())
102
103			_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
104				return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}}
105			})
106			assert.NotNil(t, err, "expected WithTransaction error, got nil")
107			cmdErr, ok := err.(CommandError)
108			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
109			assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
110				"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
111		})
112		t.Run("unknown transaction commit result", func(t *testing.T) {
113			//set failpoint
114			failpoint := bson.D{{"configureFailPoint", "failCommand"},
115				{"mode", "alwaysOn"},
116				{"data", bson.D{
117					{"failCommands", bson.A{"commitTransaction"}},
118					{"closeConnection", true},
119				}},
120			}
121			err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
122			assert.Nil(t, err, "error setting failpoint: %v", err)
123			defer func() {
124				err = dbAdmin.RunCommand(bgCtx, bson.D{
125					{"configureFailPoint", "failCommand"},
126					{"mode", "off"},
127				}).Err()
128				assert.Nil(t, err, "error turning off failpoint: %v", err)
129			}()
130
131			sess, err := client.StartSession()
132			assert.Nil(t, err, "StartSession error: %v", err)
133			defer sess.EndSession(context.Background())
134
135			_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
136				_, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}})
137				return nil, err
138			})
139			assert.NotNil(t, err, "expected WithTransaction error, got nil")
140			cmdErr, ok := err.(CommandError)
141			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
142			assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult),
143				"expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr)
144		})
145		t.Run("commit transient transaction error", func(t *testing.T) {
146			//set failpoint
147			failpoint := bson.D{{"configureFailPoint", "failCommand"},
148				{"mode", "alwaysOn"},
149				{"data", bson.D{
150					{"failCommands", bson.A{"commitTransaction"}},
151					{"errorCode", 251},
152				}},
153			}
154			err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
155			assert.Nil(t, err, "error setting failpoint: %v", err)
156			defer func() {
157				err = dbAdmin.RunCommand(bgCtx, bson.D{
158					{"configureFailPoint", "failCommand"},
159					{"mode", "off"},
160				}).Err()
161				assert.Nil(t, err, "error turning off failpoint: %v", err)
162			}()
163
164			sess, err := client.StartSession()
165			assert.Nil(t, err, "StartSession error: %v", err)
166			defer sess.EndSession(context.Background())
167
168			_, err = sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
169				_, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}})
170				return nil, err
171			})
172			assert.NotNil(t, err, "expected WithTransaction error, got nil")
173			cmdErr, ok := err.(CommandError)
174			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
175			assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
176				"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
177		})
178	})
179}
180
181func setupConvenientTransactions(t *testing.T) *Client {
182	cs := testutil.ConnString(t)
183	poolMonitor := &event.PoolMonitor{
184		Event: func(evt *event.PoolEvent) {
185			switch evt.Type {
186			case event.GetSucceeded:
187				connsCheckedOut++
188			case event.ConnectionReturned:
189				connsCheckedOut--
190			}
191		},
192	}
193	clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()).
194		SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetPoolMonitor(poolMonitor)
195	client, err := Connect(bgCtx, clientOpts)
196	assert.Nil(t, err, "Connect error: %v", err)
197
198	version, err := getServerVersion(client.Database("admin"))
199	assert.Nil(t, err, "getServerVersion error: %v", err)
200	topoKind := client.deployment.(*topology.Topology).Kind()
201	if compareVersions(t, version, "4.1") < 0 || topoKind == description.Single {
202		t.Skip("skipping standalones and versions < 4.1")
203	}
204
205	// pin to a single mongos if necessary
206	if topoKind != description.Sharded {
207		return client
208	}
209	client, err = Connect(bgCtx, clientOpts.SetHosts([]string{cs.Hosts[0]}))
210	assert.Nil(t, err, "Connect error: %v", err)
211	return client
212}
213
214func getServerVersion(db *Database) (string, error) {
215	serverStatus, err := db.RunCommand(
216		context.Background(),
217		bson.D{{"serverStatus", 1}},
218	).DecodeBytes()
219	if err != nil {
220		return "", err
221	}
222
223	version, err := serverStatus.LookupErr("version")
224	if err != nil {
225		return "", err
226	}
227
228	return version.StringValue(), nil
229}
230
// compareVersions orders two dotted version strings whose components are
// integers (for example "4.0.12"). Only the leading components present in both
// strings are examined, so "3.2" and "3.2.11" compare equal, while "3.2.0"
// sorts before "3.2.11".
//
// Components are examined left to right and the first difference decides.
// The result is the difference of the first unequal pair: positive when v1 is
// greater, negative when v1 is smaller, and 0 when no compared component
// differs. A component of v1 that is not an integer yields 1; otherwise a
// corresponding component of v2 that is not an integer yields -1. The t
// parameter is not used.
func compareVersions(t *testing.T, v1 string, v2 string) int {
	parts1 := strings.Split(v1, ".")
	parts2 := strings.Split(v2, ".")
	shared := int(math.Min(float64(len(parts1)), float64(len(parts2))))

	for idx := 0; idx < shared; idx++ {
		left, leftErr := strconv.Atoi(parts1[idx])
		if leftErr != nil {
			return 1
		}
		right, rightErr := strconv.Atoi(parts2[idx])
		if rightErr != nil {
			return -1
		}
		if left != right {
			return left - right
		}
	}
	return 0
}
260