1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package unified
8
9import (
10	"context"
11	"fmt"
12
13	"go.mongodb.org/mongo-driver/bson"
14	"go.mongodb.org/mongo-driver/mongo"
15	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
16	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
17)
18
19func executeTestRunnerOperation(ctx context.Context, operation *Operation) error {
20	args := operation.Arguments
21
22	switch operation.Name {
23	case "failPoint":
24		clientID := LookupString(args, "client")
25		client, err := Entities(ctx).Client(clientID)
26		if err != nil {
27			return err
28		}
29
30		fpDoc := args.Lookup("failPoint").Document()
31		if err := mtest.SetRawFailPoint(fpDoc, client.Client); err != nil {
32			return err
33		}
34		return AddFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), client.Client)
35	case "targetedFailPoint":
36		sessID := LookupString(args, "session")
37		sess, err := Entities(ctx).Session(sessID)
38		if err != nil {
39			return err
40		}
41
42		clientSession := extractClientSession(sess)
43		if clientSession.PinnedServer == nil {
44			return fmt.Errorf("session is not pinned to a server")
45		}
46
47		targetHost := clientSession.PinnedServer.Addr.String()
48		fpDoc := args.Lookup("failPoint").Document()
49		commandFn := func(ctx context.Context, client *mongo.Client) error {
50			return mtest.SetRawFailPoint(fpDoc, client)
51		}
52
53		if err := RunCommandOnHost(ctx, targetHost, commandFn); err != nil {
54			return err
55		}
56		return AddTargetedFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), targetHost)
57	case "assertSessionTransactionState":
58		sessID := LookupString(args, "session")
59		sess, err := Entities(ctx).Session(sessID)
60		if err != nil {
61			return err
62		}
63
64		var expectedState session.TransactionState
65		switch stateStr := LookupString(args, "state"); stateStr {
66		case "none":
67			expectedState = session.None
68		case "starting":
69			expectedState = session.Starting
70		case "in_progress":
71			expectedState = session.InProgress
72		case "committed":
73			expectedState = session.Committed
74		case "aborted":
75			expectedState = session.Aborted
76		default:
77			return fmt.Errorf("unrecognized session state type %q", stateStr)
78		}
79
80		if actualState := extractClientSession(sess).TransactionState; actualState != expectedState {
81			return fmt.Errorf("expected session state %q does not match actual state %q", expectedState, actualState)
82		}
83		return nil
84	case "assertSessionPinned":
85		return verifySessionPinnedState(ctx, LookupString(args, "session"), true)
86	case "assertSessionUnpinned":
87		return verifySessionPinnedState(ctx, LookupString(args, "session"), false)
88	case "assertSameLsidOnLastTwoCommands":
89		return verifyLastTwoLsidsEqual(ctx, LookupString(args, "client"), true)
90	case "assertDifferentLsidOnLastTwoCommands":
91		return verifyLastTwoLsidsEqual(ctx, LookupString(args, "client"), false)
92	case "assertSessionDirty":
93		return verifySessionDirtyState(ctx, LookupString(args, "session"), true)
94	case "assertSessionNotDirty":
95		return verifySessionDirtyState(ctx, LookupString(args, "session"), false)
96	case "assertCollectionExists":
97		db := LookupString(args, "databaseName")
98		coll := LookupString(args, "collectionName")
99		return verifyCollectionExists(ctx, db, coll, true)
100	case "assertCollectionNotExists":
101		db := LookupString(args, "databaseName")
102		coll := LookupString(args, "collectionName")
103		return verifyCollectionExists(ctx, db, coll, false)
104	case "assertIndexExists":
105		db := LookupString(args, "databaseName")
106		coll := LookupString(args, "collectionName")
107		index := LookupString(args, "indexName")
108		return verifyIndexExists(ctx, db, coll, index, true)
109	case "assertIndexNotExists":
110		db := LookupString(args, "databaseName")
111		coll := LookupString(args, "collectionName")
112		index := LookupString(args, "indexName")
113		return verifyIndexExists(ctx, db, coll, index, false)
114	default:
115		return fmt.Errorf("unrecognized testRunner operation %q", operation.Name)
116	}
117}
118
119func extractClientSession(sess mongo.Session) *session.Client {
120	return sess.(mongo.XSession).ClientSession()
121}
122
123func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error {
124	sess, err := Entities(ctx).Session(sessionID)
125	if err != nil {
126		return err
127	}
128
129	if isPinned := extractClientSession(sess).PinnedServer != nil; expectedPinned != isPinned {
130		return fmt.Errorf("session pinned state mismatch; expected to be pinned: %v, is pinned: %v", expectedPinned, isPinned)
131	}
132	return nil
133}
134
135func verifyLastTwoLsidsEqual(ctx context.Context, clientID string, expectedEqual bool) error {
136	client, err := Entities(ctx).Client(clientID)
137	if err != nil {
138		return err
139	}
140
141	allEvents := client.StartedEvents()
142	if len(allEvents) < 2 {
143		return fmt.Errorf("client has recorded fewer than two command started events")
144	}
145	lastTwoEvents := allEvents[len(allEvents)-2:]
146
147	firstID, err := lastTwoEvents[0].Command.LookupErr("lsid")
148	if err != nil {
149		return fmt.Errorf("first command has no 'lsid' field: %v", client.started[0].Command)
150	}
151	secondID, err := lastTwoEvents[1].Command.LookupErr("lsid")
152	if err != nil {
153		return fmt.Errorf("first command has no 'lsid' field: %v", client.started[1].Command)
154	}
155
156	areEqual := firstID.Equal(secondID)
157	if expectedEqual && !areEqual {
158		return fmt.Errorf("expected last two lsids to be equal, but got %s and %s", firstID, secondID)
159	}
160	if !expectedEqual && areEqual {
161		return fmt.Errorf("expected last two lsids to be different but both were %s", firstID)
162	}
163	return nil
164}
165
166func verifySessionDirtyState(ctx context.Context, sessionID string, expectedDirty bool) error {
167	sess, err := Entities(ctx).Session(sessionID)
168	if err != nil {
169		return err
170	}
171
172	if isDirty := extractClientSession(sess).Dirty; expectedDirty != isDirty {
173		return fmt.Errorf("session dirty state mismatch; expected to be dirty: %v, is dirty: %v", expectedDirty, isDirty)
174	}
175	return nil
176}
177
178func verifyCollectionExists(ctx context.Context, dbName, collName string, expectedExists bool) error {
179	db := mtest.GlobalClient().Database(dbName)
180	collections, err := db.ListCollectionNames(ctx, bson.M{"name": collName})
181	if err != nil {
182		return fmt.Errorf("error running ListCollectionNames: %v", err)
183	}
184
185	if exists := len(collections) == 1; expectedExists != exists {
186		ns := fmt.Sprintf("%s.%s", dbName, collName)
187		return fmt.Errorf("collection existence mismatch; expected namespace %q to exist: %v, exists: %v", ns,
188			expectedExists, exists)
189	}
190	return nil
191}
192
193func verifyIndexExists(ctx context.Context, dbName, collName, indexName string, expectedExists bool) error {
194	iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
195	cursor, err := iv.List(ctx)
196	if err != nil {
197		return fmt.Errorf("error running IndexView.List: %v", err)
198	}
199	defer cursor.Close(ctx)
200
201	var exists bool
202	for cursor.Next(ctx) {
203		if LookupString(cursor.Current, "name") == indexName {
204			exists = true
205			break
206		}
207	}
208	if expectedExists != exists {
209		ns := fmt.Sprintf("%s.%s", dbName, collName)
210		return fmt.Errorf("index existence mismatch: expected index %q to exist in namespace %q: %v, exists: %v",
211			indexName, ns, expectedExists, exists)
212	}
213	return nil
214}
215