1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package auth_test 8 9import ( 10 "testing" 11 12 "github.com/google/go-cmp/cmp" 13 "github.com/stretchr/testify/require" 14 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 15 . "go.mongodb.org/mongo-driver/x/mongo/driver/auth" 16 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" 17) 18 19func TestCreateAuthenticator(t *testing.T) { 20 21 tests := []struct { 22 name string 23 source string 24 auther Authenticator 25 }{ 26 {name: "", auther: &DefaultAuthenticator{}}, 27 {name: "SCRAM-SHA-1", auther: &ScramAuthenticator{}}, 28 {name: "SCRAM-SHA-256", auther: &ScramAuthenticator{}}, 29 {name: "MONGODB-CR", auther: &MongoDBCRAuthenticator{}}, 30 {name: "PLAIN", auther: &PlainAuthenticator{}}, 31 {name: "MONGODB-X509", auther: &MongoDBX509Authenticator{}}, 32 } 33 34 for _, test := range tests { 35 t.Run(test.name, func(t *testing.T) { 36 cred := &Cred{ 37 Username: "user", 38 Password: "pencil", 39 PasswordSet: true, 40 } 41 42 a, err := CreateAuthenticator(test.name, cred) 43 require.NoError(t, err) 44 require.IsType(t, test.auther, a) 45 }) 46 } 47} 48 49func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document, dbName string) { 50 _, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm) 51 if !ok { 52 t.Fatalf("wiremessage is too short to unmarshal") 53 } 54 var actualPayload bsoncore.Document 55 switch opcode { 56 case wiremessage.OpQuery: 57 _, wm, ok := wiremessage.ReadQueryFlags(wm) 58 if !ok { 59 t.Fatalf("wiremessage is too short to unmarshal") 60 } 61 _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) 62 if !ok { 63 t.Fatalf("wiremessage is too short to unmarshal") 64 } 65 _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) 66 if !ok { 67 t.Fatalf("wiremessage is too short to unmarshal") 68 } 69 _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) 70 if !ok { 71 t.Fatalf("wiremessage is too short to unmarshal") 72 } 73 actualPayload, _, ok = wiremessage.ReadQueryQuery(wm) 74 if !ok { 75 t.Fatalf("wiremessage is too short to unmarshal") 76 } 77 case wiremessage.OpMsg: 78 // Append the $db field. 79 elems, err := expectedPayload.Elements() 80 if err != nil { 81 t.Fatalf("expectedPayload is not valid: %v", err) 82 } 83 elems = append(elems, bsoncore.AppendStringElement(nil, "$db", dbName)) 84 elems = append(elems, bsoncore.AppendDocumentElement(nil, 85 "$readPreference", 86 bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred")), 87 )) 88 bslc := make([][]byte, 0, len(elems)) // BuildDocumentFromElements takes a [][]byte, not a []bsoncore.Element. 89 for _, elem := range elems { 90 bslc = append(bslc, elem) 91 } 92 expectedPayload = bsoncore.BuildDocumentFromElements(nil, bslc...) 93 94 _, wm, ok := wiremessage.ReadMsgFlags(wm) 95 if !ok { 96 t.Fatalf("wiremessage is too short to unmarshal") 97 } 98 loop: 99 for { 100 var stype wiremessage.SectionType 101 stype, wm, ok = wiremessage.ReadMsgSectionType(wm) 102 if !ok { 103 t.Fatalf("wiremessage is too short to unmarshal") 104 break 105 } 106 switch stype { 107 case wiremessage.DocumentSequence: 108 _, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm) 109 if !ok { 110 t.Fatalf("wiremessage is too short to unmarshal") 111 break loop 112 } 113 case wiremessage.SingleDocument: 114 actualPayload, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm) 115 if !ok { 116 t.Fatalf("wiremessage is too short to unmarshal") 117 } 118 break loop 119 } 120 } 121 } 122 123 if !cmp.Equal(actualPayload, expectedPayload) { 124 t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload) 125 } 126} 127