1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package auth_test
8
9import (
10	"testing"
11
12	"github.com/google/go-cmp/cmp"
13	"github.com/stretchr/testify/require"
14	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
15	. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
16	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
17)
18
19func TestCreateAuthenticator(t *testing.T) {
20
21	tests := []struct {
22		name   string
23		source string
24		auther Authenticator
25	}{
26		{name: "", auther: &DefaultAuthenticator{}},
27		{name: "SCRAM-SHA-1", auther: &ScramAuthenticator{}},
28		{name: "SCRAM-SHA-256", auther: &ScramAuthenticator{}},
29		{name: "MONGODB-CR", auther: &MongoDBCRAuthenticator{}},
30		{name: "PLAIN", auther: &PlainAuthenticator{}},
31		{name: "MONGODB-X509", auther: &MongoDBX509Authenticator{}},
32	}
33
34	for _, test := range tests {
35		t.Run(test.name, func(t *testing.T) {
36			cred := &Cred{
37				Username:    "user",
38				Password:    "pencil",
39				PasswordSet: true,
40			}
41
42			a, err := CreateAuthenticator(test.name, cred)
43			require.NoError(t, err)
44			require.IsType(t, test.auther, a)
45		})
46	}
47}
48
49func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document, dbName string) {
50	_, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm)
51	if !ok {
52		t.Fatalf("wiremessage is too short to unmarshal")
53	}
54	var actualPayload bsoncore.Document
55	switch opcode {
56	case wiremessage.OpQuery:
57		_, wm, ok := wiremessage.ReadQueryFlags(wm)
58		if !ok {
59			t.Fatalf("wiremessage is too short to unmarshal")
60		}
61		_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
62		if !ok {
63			t.Fatalf("wiremessage is too short to unmarshal")
64		}
65		_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
66		if !ok {
67			t.Fatalf("wiremessage is too short to unmarshal")
68		}
69		_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
70		if !ok {
71			t.Fatalf("wiremessage is too short to unmarshal")
72		}
73		actualPayload, _, ok = wiremessage.ReadQueryQuery(wm)
74		if !ok {
75			t.Fatalf("wiremessage is too short to unmarshal")
76		}
77	case wiremessage.OpMsg:
78		// Append the $db field.
79		elems, err := expectedPayload.Elements()
80		if err != nil {
81			t.Fatalf("expectedPayload is not valid: %v", err)
82		}
83		elems = append(elems, bsoncore.AppendStringElement(nil, "$db", dbName))
84		elems = append(elems, bsoncore.AppendDocumentElement(nil,
85			"$readPreference",
86			bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred")),
87		))
88		bslc := make([][]byte, 0, len(elems)) // BuildDocumentFromElements takes a [][]byte, not a []bsoncore.Element.
89		for _, elem := range elems {
90			bslc = append(bslc, elem)
91		}
92		expectedPayload = bsoncore.BuildDocumentFromElements(nil, bslc...)
93
94		_, wm, ok := wiremessage.ReadMsgFlags(wm)
95		if !ok {
96			t.Fatalf("wiremessage is too short to unmarshal")
97		}
98	loop:
99		for {
100			var stype wiremessage.SectionType
101			stype, wm, ok = wiremessage.ReadMsgSectionType(wm)
102			if !ok {
103				t.Fatalf("wiremessage is too short to unmarshal")
104				break
105			}
106			switch stype {
107			case wiremessage.DocumentSequence:
108				_, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm)
109				if !ok {
110					t.Fatalf("wiremessage is too short to unmarshal")
111					break loop
112				}
113			case wiremessage.SingleDocument:
114				actualPayload, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm)
115				if !ok {
116					t.Fatalf("wiremessage is too short to unmarshal")
117				}
118				break loop
119			}
120		}
121	}
122
123	if !cmp.Equal(actualPayload, expectedPayload) {
124		t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload)
125	}
126}
127