1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package unified 8 9import ( 10 "context" 11 "fmt" 12 13 "go.mongodb.org/mongo-driver/bson" 14 "go.mongodb.org/mongo-driver/mongo" 15 "go.mongodb.org/mongo-driver/mongo/integration/mtest" 16 "go.mongodb.org/mongo-driver/x/mongo/driver/session" 17) 18 19func executeTestRunnerOperation(ctx context.Context, operation *Operation) error { 20 args := operation.Arguments 21 22 switch operation.Name { 23 case "failPoint": 24 clientID := LookupString(args, "client") 25 client, err := Entities(ctx).Client(clientID) 26 if err != nil { 27 return err 28 } 29 30 fpDoc := args.Lookup("failPoint").Document() 31 if err := mtest.SetRawFailPoint(fpDoc, client.Client); err != nil { 32 return err 33 } 34 return AddFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), client.Client) 35 case "targetedFailPoint": 36 sessID := LookupString(args, "session") 37 sess, err := Entities(ctx).Session(sessID) 38 if err != nil { 39 return err 40 } 41 42 clientSession := extractClientSession(sess) 43 if clientSession.PinnedServer == nil { 44 return fmt.Errorf("session is not pinned to a server") 45 } 46 47 targetHost := clientSession.PinnedServer.Addr.String() 48 fpDoc := args.Lookup("failPoint").Document() 49 commandFn := func(ctx context.Context, client *mongo.Client) error { 50 return mtest.SetRawFailPoint(fpDoc, client) 51 } 52 53 if err := RunCommandOnHost(ctx, targetHost, commandFn); err != nil { 54 return err 55 } 56 return AddTargetedFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), targetHost) 57 case "assertSessionTransactionState": 58 sessID := LookupString(args, "session") 59 sess, err := Entities(ctx).Session(sessID) 60 if err != nil { 61 return err 62 } 63 64 var expectedState session.TransactionState 65 switch stateStr := LookupString(args, "state"); stateStr { 66 case "none": 67 expectedState = session.None 68 case "starting": 69 expectedState = session.Starting 70 case "in_progress": 71 expectedState = session.InProgress 72 case "committed": 73 expectedState = session.Committed 74 case "aborted": 75 expectedState = session.Aborted 76 default: 77 return fmt.Errorf("unrecognized session state type %q", stateStr) 78 } 79 80 if actualState := extractClientSession(sess).TransactionState; actualState != expectedState { 81 return fmt.Errorf("expected session state %q does not match actual state %q", expectedState, actualState) 82 } 83 return nil 84 case "assertSessionPinned": 85 return verifySessionPinnedState(ctx, LookupString(args, "session"), true) 86 case "assertSessionUnpinned": 87 return verifySessionPinnedState(ctx, LookupString(args, "session"), false) 88 case "assertSameLsidOnLastTwoCommands": 89 return verifyLastTwoLsidsEqual(ctx, LookupString(args, "client"), true) 90 case "assertDifferentLsidOnLastTwoCommands": 91 return verifyLastTwoLsidsEqual(ctx, LookupString(args, "client"), false) 92 case "assertSessionDirty": 93 return verifySessionDirtyState(ctx, LookupString(args, "session"), true) 94 case "assertSessionNotDirty": 95 return verifySessionDirtyState(ctx, LookupString(args, "session"), false) 96 case "assertCollectionExists": 97 db := LookupString(args, "databaseName") 98 coll := LookupString(args, "collectionName") 99 return verifyCollectionExists(ctx, db, coll, true) 100 case "assertCollectionNotExists": 101 db := LookupString(args, "databaseName") 102 coll := LookupString(args, "collectionName") 103 return verifyCollectionExists(ctx, db, coll, false) 104 case "assertIndexExists": 105 db := LookupString(args, "databaseName") 106 coll := LookupString(args, "collectionName") 107 index := LookupString(args, "indexName") 108 return verifyIndexExists(ctx, db, coll, index, true) 109 case "assertIndexNotExists": 110 db := LookupString(args, "databaseName") 111 coll := LookupString(args, "collectionName") 112 index := LookupString(args, "indexName") 113 return verifyIndexExists(ctx, db, coll, index, false) 114 default: 115 return fmt.Errorf("unrecognized testRunner operation %q", operation.Name) 116 } 117} 118 119func extractClientSession(sess mongo.Session) *session.Client { 120 return sess.(mongo.XSession).ClientSession() 121} 122 123func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error { 124 sess, err := Entities(ctx).Session(sessionID) 125 if err != nil { 126 return err 127 } 128 129 if isPinned := extractClientSession(sess).PinnedServer != nil; expectedPinned != isPinned { 130 return fmt.Errorf("session pinned state mismatch; expected to be pinned: %v, is pinned: %v", expectedPinned, isPinned) 131 } 132 return nil 133} 134 135func verifyLastTwoLsidsEqual(ctx context.Context, clientID string, expectedEqual bool) error { 136 client, err := Entities(ctx).Client(clientID) 137 if err != nil { 138 return err 139 } 140 141 allEvents := client.StartedEvents() 142 if len(allEvents) < 2 { 143 return fmt.Errorf("client has recorded fewer than two command started events") 144 } 145 lastTwoEvents := allEvents[len(allEvents)-2:] 146 147 firstID, err := lastTwoEvents[0].Command.LookupErr("lsid") 148 if err != nil { 149 return fmt.Errorf("first command has no 'lsid' field: %v", client.started[0].Command) 150 } 151 secondID, err := lastTwoEvents[1].Command.LookupErr("lsid") 152 if err != nil { 153 return fmt.Errorf("first command has no 'lsid' field: %v", client.started[1].Command) 154 } 155 156 areEqual := firstID.Equal(secondID) 157 if expectedEqual && !areEqual { 158 return fmt.Errorf("expected last two lsids to be equal, but got %s and %s", firstID, secondID) 159 } 160 if !expectedEqual && areEqual { 161 return fmt.Errorf("expected last two lsids to be different but both were %s", firstID) 162 } 163 return nil 164} 165 166func verifySessionDirtyState(ctx context.Context, sessionID string, expectedDirty bool) error { 167 sess, err := Entities(ctx).Session(sessionID) 168 if err != nil { 169 return err 170 } 171 172 if isDirty := extractClientSession(sess).Dirty; expectedDirty != isDirty { 173 return fmt.Errorf("session dirty state mismatch; expected to be dirty: %v, is dirty: %v", expectedDirty, isDirty) 174 } 175 return nil 176} 177 178func verifyCollectionExists(ctx context.Context, dbName, collName string, expectedExists bool) error { 179 db := mtest.GlobalClient().Database(dbName) 180 collections, err := db.ListCollectionNames(ctx, bson.M{"name": collName}) 181 if err != nil { 182 return fmt.Errorf("error running ListCollectionNames: %v", err) 183 } 184 185 if exists := len(collections) == 1; expectedExists != exists { 186 ns := fmt.Sprintf("%s.%s", dbName, collName) 187 return fmt.Errorf("collection existence mismatch; expected namespace %q to exist: %v, exists: %v", ns, 188 expectedExists, exists) 189 } 190 return nil 191} 192 193func verifyIndexExists(ctx context.Context, dbName, collName, indexName string, expectedExists bool) error { 194 iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes() 195 cursor, err := iv.List(ctx) 196 if err != nil { 197 return fmt.Errorf("error running IndexView.List: %v", err) 198 } 199 defer cursor.Close(ctx) 200 201 var exists bool 202 for cursor.Next(ctx) { 203 if LookupString(cursor.Current, "name") == indexName { 204 exists = true 205 break 206 } 207 } 208 if expectedExists != exists { 209 ns := fmt.Sprintf("%s.%s", dbName, collName) 210 return fmt.Errorf("index existence mismatch: expected index %q to exist in namespace %q: %v, exists: %v", 211 indexName, ns, expectedExists, exists) 212 } 213 return nil 214} 215