1 
2 /**
3  *    Copyright (C) 2018-present MongoDB, Inc.
4  *
5  *    This program is free software: you can redistribute it and/or modify
6  *    it under the terms of the Server Side Public License, version 1,
7  *    as published by MongoDB, Inc.
8  *
9  *    This program is distributed in the hope that it will be useful,
10  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *    Server Side Public License for more details.
13  *
14  *    You should have received a copy of the Server Side Public License
15  *    along with this program. If not, see
16  *    <http://www.mongodb.com/licensing/server-side-public-license>.
17  *
18  *    As a special exception, the copyright holders give permission to link the
19  *    code of portions of this program with the OpenSSL library under certain
20  *    conditions as described in each individual source file and distribute
21  *    linked combinations including the program with the OpenSSL library. You
22  *    must comply with the Server Side Public License in all respects for
23  *    all of the code used other than as permitted herein. If you modify file(s)
24  *    with this exception, you may extend this exception to your version of the
25  *    file(s), but you are not obligated to do so. If you do not wish to do so,
26  *    delete this exception statement from your version. If you delete this
27  *    exception statement from all source files in the program, then also delete
28  *    it in the license file.
29  */
30 
#include "mongo/platform/basic.h"

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "mongo/client/native_sasl_client_session.h"
#include "mongo/client/scram_sha1_client_cache.h"
#include "mongo/crypto/mechanism_scram.h"
#include "mongo/db/auth/authorization_manager.h"
#include "mongo/db/auth/authz_manager_external_state_mock.h"
#include "mongo/db/auth/authz_session_external_state_mock.h"
#include "mongo/db/auth/native_sasl_authentication_session.h"
#include "mongo/db/auth/sasl_scramsha1_server_conversation.h"
#include "mongo/db/service_context_noop.h"
#include "mongo/stdx/memory.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/base64.h"
#include "mongo/util/password_digest.h"
46 
47 namespace mongo {
48 
generateSCRAMUserDocument(StringData username,StringData password)49 BSONObj generateSCRAMUserDocument(StringData username, StringData password) {
50     const size_t scramIterationCount = 10000;
51     auto database = "test"_sd;
52 
53     std::string digested = createPasswordDigest(username, password);
54     BSONObj scramCred = scram::generateCredentials(digested, scramIterationCount);
55     return BSON("_id" << (str::stream() << database << "." << username).operator StringData()
56                       << AuthorizationManager::USER_NAME_FIELD_NAME
57                       << username
58                       << AuthorizationManager::USER_DB_FIELD_NAME
59                       << database
60                       << "credentials"
61                       << BSON("SCRAM-SHA-1" << scramCred)
62                       << "roles"
63                       << BSONArray()
64                       << "privileges"
65                       << BSONArray());
66 }
67 
generateMONGODBCRUserDocument(StringData username,StringData password)68 BSONObj generateMONGODBCRUserDocument(StringData username, StringData password) {
69     auto database = "test"_sd;
70 
71     std::string digested = createPasswordDigest(username, password);
72     return BSON("_id" << (str::stream() << database << "." << username).operator StringData()
73                       << AuthorizationManager::USER_NAME_FIELD_NAME
74                       << username
75                       << AuthorizationManager::USER_DB_FIELD_NAME
76                       << database
77                       << "credentials"
78                       << BSON("MONGODB-CR" << digested)
79                       << "roles"
80                       << BSONArray()
81                       << "privileges"
82                       << BSONArray());
83 }
84 
corruptEncodedPayload(const std::string & message,std::string::const_iterator begin,std::string::const_iterator end)85 std::string corruptEncodedPayload(const std::string& message,
86                                   std::string::const_iterator begin,
87                                   std::string::const_iterator end) {
88     std::string raw = base64::decode(
89         message.substr(std::distance(message.begin(), begin), std::distance(begin, end)));
90     if (raw[0] == std::numeric_limits<char>::max()) {
91         raw[0] -= 1;
92     } else {
93         raw[0] += 1;
94     }
95     return base64::encode(raw);
96 }
97 
/**
 * Names one message of a SASL exchange: which side produced it (client or server) and at
 * which stage of the conversation. States are ordered by (stage, participant), which lets
 * them be used as std::map keys.
 */
class SaslTestState {
public:
    enum Participant { kClient, kServer };
    SaslTestState() : SaslTestState(kClient, 0) {}
    SaslTestState(Participant participant, size_t stage) : participant(participant), stage(stage) {}

private:
    // These members are declared ahead of lens() because its trailing return type refers to
    // them. With C++14's decltype(auto) the trailing return type, and this ordering
    // constraint, could be dropped.
    Participant participant;
    size_t stage;

public:
    /** Returns a tuple of references, (stage, participant), that defines equality and order. */
    auto lens() const -> decltype(std::tie(this->stage, this->participant)) {
        return std::tie(stage, participant);
    }

    friend bool operator==(const SaslTestState& lhs, const SaslTestState& rhs) {
        return lhs.lens() == rhs.lens();
    }

    friend bool operator<(const SaslTestState& lhs, const SaslTestState& rhs) {
        return lhs.lens() < rhs.lens();
    }

    /**
     * Moves to the following message. The client is followed by the server in the same
     * stage, and the server is followed by the client of the next stage.
     */
    void next() {
        const bool wasClient = (participant == kClient);
        participant = wasClient ? kServer : kClient;
        if (!wasClient) {
            ++stage;
        }
    }

    /** Renders the state as "<Client|Server>Step<stage>", e.g. "ServerStep2". */
    std::string toString() const {
        const std::string side = (participant == kClient) ? "Client" : "Server";
        return side + "Step" + std::to_string(stage);
    }
};
145 
146 class SCRAMMutators {
147 public:
SCRAMMutators()148     SCRAMMutators() {}
149 
setMutator(SaslTestState state,stdx::function<void (std::string &)> fun)150     void setMutator(SaslTestState state, stdx::function<void(std::string&)> fun) {
151         mutators.insert(std::make_pair(state, fun));
152     }
153 
execute(SaslTestState state,std::string & str)154     void execute(SaslTestState state, std::string& str) {
155         auto it = mutators.find(state);
156         if (it != mutators.end()) {
157             it->second(str);
158         }
159     }
160 
161 private:
162     std::map<SaslTestState, stdx::function<void(std::string&)>> mutators;
163 };
164 
165 struct SCRAMStepsResult {
SCRAMStepsResultmongo::SCRAMStepsResult166     SCRAMStepsResult() : outcome(SaslTestState::kClient, 1), status(Status::OK()) {}
SCRAMStepsResultmongo::SCRAMStepsResult167     SCRAMStepsResult(SaslTestState outcome, Status status) : outcome(outcome), status(status) {}
operator ==mongo::SCRAMStepsResult168     bool operator==(const SCRAMStepsResult& other) const {
169         return outcome == other.outcome && status.code() == other.status.code() &&
170             status.reason() == other.status.reason();
171     }
172     SaslTestState outcome;
173     Status status;
174 
operator <<(std::ostream & os,const SCRAMStepsResult & result)175     friend std::ostream& operator<<(std::ostream& os, const SCRAMStepsResult& result) {
176         return os << "{outcome: " << result.outcome.toString() << ", status: " << result.status
177                   << "}";
178     }
179 };
180 
runSteps(NativeSaslAuthenticationSession * saslServerSession,NativeSaslClientSession * saslClientSession,SCRAMMutators interposers=SCRAMMutators{})181 SCRAMStepsResult runSteps(NativeSaslAuthenticationSession* saslServerSession,
182                           NativeSaslClientSession* saslClientSession,
183                           SCRAMMutators interposers = SCRAMMutators{}) {
184     SCRAMStepsResult result{};
185     std::string clientOutput = "";
186     std::string serverOutput = "";
187 
188     for (size_t step = 1; step <= 3; step++) {
189         ASSERT_FALSE(saslClientSession->isDone());
190         ASSERT_FALSE(saslServerSession->isDone());
191 
192         // Client step
193         result.status = saslClientSession->step(serverOutput, &clientOutput);
194         if (result.status != Status::OK()) {
195             return result;
196         }
197         interposers.execute(result.outcome, clientOutput);
198         std::cout << result.outcome.toString() << ": " << clientOutput << std::endl;
199         result.outcome.next();
200 
201         // Server step
202         result.status = saslServerSession->step(clientOutput, &serverOutput);
203         if (result.status != Status::OK()) {
204             return result;
205         }
206         interposers.execute(result.outcome, serverOutput);
207         std::cout << result.outcome.toString() << ": " << serverOutput << std::endl;
208         result.outcome.next();
209     }
210     ASSERT_TRUE(saslClientSession->isDone());
211     ASSERT_TRUE(saslServerSession->isDone());
212 
213     return result;
214 }
215 
216 class SCRAMSHA1Fixture : public mongo::unittest::Test {
217 protected:
218     const SCRAMStepsResult goalState =
219         SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 4), Status::OK());
220 
221     ServiceContextNoop serviceContext;
222     ServiceContextNoop::UniqueClient client;
223     ServiceContextNoop::UniqueOperationContext opCtx;
224 
225     AuthzManagerExternalStateMock* authzManagerExternalState;
226     std::unique_ptr<AuthorizationManager> authzManager;
227     std::unique_ptr<AuthorizationSession> authzSession;
228 
229     std::unique_ptr<NativeSaslAuthenticationSession> saslServerSession;
230     std::unique_ptr<NativeSaslClientSession> saslClientSession;
231 
setUp()232     void setUp() {
233         client = serviceContext.makeClient("test");
234         opCtx = serviceContext.makeOperationContext(client.get());
235 
236         auto uniqueAuthzManagerExternalStateMock =
237             stdx::make_unique<AuthzManagerExternalStateMock>();
238         authzManagerExternalState = uniqueAuthzManagerExternalStateMock.get();
239         authzManager =
240             stdx::make_unique<AuthorizationManager>(std::move(uniqueAuthzManagerExternalStateMock));
241         authzSession = stdx::make_unique<AuthorizationSession>(
242             stdx::make_unique<AuthzSessionExternalStateMock>(authzManager.get()));
243 
244         saslServerSession = stdx::make_unique<NativeSaslAuthenticationSession>(authzSession.get());
245         saslServerSession->setOpCtxt(opCtx.get());
246         saslServerSession->start("test", "SCRAM-SHA-1", "mongodb", "MockServer.test", 1, false)
247             .transitional_ignore();
248         saslClientSession = stdx::make_unique<NativeSaslClientSession>();
249         saslClientSession->setParameter(NativeSaslClientSession::parameterMechanism, "SCRAM-SHA-1");
250         saslClientSession->setParameter(NativeSaslClientSession::parameterServiceName, "mongodb");
251         saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostname,
252                                         "MockServer.test");
253         saslClientSession->setParameter(NativeSaslClientSession::parameterServiceHostAndPort,
254                                         "MockServer.test:27017");
255     }
256 };
257 
// The server's first message has its nonce stripped down to a bare "r=". The client is
// expected to reject it at client step 2 with BadValue.
TEST_F(SCRAMSHA1Fixture, testServerStep1DoesNotIncludeNonceFromClientStep1) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    SCRAMMutators mutator;
    mutator.setMutator(SaslTestState(SaslTestState::kServer, 1), [](std::string& serverMessage) {
        // Replace "r=<nonce>" (up to the next ',') with just "r=".
        std::string::iterator nonceBegin = serverMessage.begin() + serverMessage.find("r=");
        std::string::iterator nonceEnd = std::find(nonceBegin, serverMessage.end(), ',');
        serverMessage.replace(nonceBegin, nonceEnd, "r=");
    });
    ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kClient, 2),
                               Status(ErrorCodes::BadValue,
                                      "Server SCRAM-SHA-1 nonce does not match client nonce: r=")),
              runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
}
282 
// The client's second message has its nonce stripped down to a bare "r=". The server is
// expected to reject it at server step 2 with BadValue.
TEST_F(SCRAMSHA1Fixture, testClientStep2DoesNotIncludeNonceFromServerStep1) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    SCRAMMutators mutator;
    mutator.setMutator(SaslTestState(SaslTestState::kClient, 2), [](std::string& clientMessage) {
        // Replace "r=<nonce>" (up to the next ',') with just "r=".
        std::string::iterator nonceBegin = clientMessage.begin() + clientMessage.find("r=");
        std::string::iterator nonceEnd = std::find(nonceBegin, clientMessage.end(), ',');
        clientMessage.replace(nonceBegin, nonceEnd, "r=");
    });
    ASSERT_EQ(SCRAMStepsResult(
                  SaslTestState(SaslTestState::kServer, 2),
                  Status(ErrorCodes::BadValue, "Incorrect SCRAM-SHA-1 client|server nonce: r=")),
              runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
}
306 
// The client proof ("p=") in the client's second message is corrupted. The server is
// expected to fail authentication at server step 2 with a storedKey mismatch.
TEST_F(SCRAMSHA1Fixture, testClientStep2GivesBadProof) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    SCRAMMutators mutator;
    mutator.setMutator(SaslTestState(SaslTestState::kClient, 2), [](std::string& clientMessage) {
        // The proof value starts right after "p=" and runs to the next ',' or the end.
        std::string::iterator proofBegin = clientMessage.begin() + clientMessage.find("p=") + 2;
        std::string::iterator proofEnd = std::find(proofBegin, clientMessage.end(), ',');
        clientMessage.replace(
            proofBegin, proofEnd, corruptEncodedPayload(clientMessage, proofBegin, proofEnd));
    });

    ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
                               Status(ErrorCodes::AuthenticationFailed,
                                      "SCRAM-SHA-1 authentication failed, storedKey mismatch")),
              runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
}
333 
// The server signature ("v=") in the server's final message is corrupted. The client is
// expected to reject it at client step 3 with BadValue, quoting the value it received.
TEST_F(SCRAMSHA1Fixture, testServerStep2GivesBadVerifier) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    // Filled in by the mutator while runSteps() executes.
    std::string encodedVerifier;
    SCRAMMutators mutator;
    mutator.setMutator(
        SaslTestState(SaslTestState::kServer, 2), [&encodedVerifier](std::string& serverMessage) {
            std::string::iterator verifierBegin =
                serverMessage.begin() + serverMessage.find("v=") + 2;
            std::string::iterator verifierEnd = std::find(verifierBegin, serverMessage.end(), ',');
            encodedVerifier = corruptEncodedPayload(serverMessage, verifierBegin, verifierEnd);

            serverMessage.replace(verifierBegin, verifierEnd, encodedVerifier);
        });

    // Run the exchange first: the expected status below embeds encodedVerifier.
    auto result = runSteps(saslServerSession.get(), saslClientSession.get(), mutator);

    ASSERT_EQ(
        SCRAMStepsResult(
            SaslTestState(SaslTestState::kClient, 3),
            Status(ErrorCodes::BadValue,
                   str::stream() << "Client failed to verify SCRAM-SHA-1 ServerSignature, received "
                                 << encodedVerifier)),
        result);
}
369 
370 
// With matching credentials on both sides and no tampering, the exchange completes
// successfully.
TEST_F(SCRAMSHA1Fixture, testSCRAM) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
}
385 
// The first character of the client's first message is replaced with "y". The exchange is
// still expected to complete successfully.
TEST_F(SCRAMSHA1Fixture, testSCRAMWithChannelBindingSupportedByClient) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    SCRAMMutators mutator;
    mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
        clientMessage.replace(clientMessage.begin(), clientMessage.begin() + 1, "y");
    });

    ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
}
405 
// The first character of the client's first message is replaced with "p=tls-unique". The
// server is expected to refuse at server step 1 because it does not support channel binding.
TEST_F(SCRAMSHA1Fixture, testSCRAMWithChannelBindingRequiredByClient) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    SCRAMMutators mutator;
    mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
        clientMessage.replace(clientMessage.begin(), clientMessage.begin() + 1, "p=tls-unique");
    });

    ASSERT_EQ(
        SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1),
                         Status(ErrorCodes::BadValue, "Server does not support channel binding")),
        runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
}
428 
// The first character of the client's first message is replaced with an unrecognized
// prefix. The server is expected to reject it at server step 1 with BadValue.
TEST_F(SCRAMSHA1Fixture, testSCRAMWithInvalidChannelBinding) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    SCRAMMutators mutator;
    mutator.setMutator(SaslTestState(SaslTestState::kClient, 1), [](std::string& clientMessage) {
        clientMessage.replace(clientMessage.begin(), clientMessage.begin() + 1, "v=illegalGarbage");
    });

    ASSERT_EQ(
        SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 1),
                         Status(ErrorCodes::BadValue,
                                "Incorrect SCRAM-SHA-1 client message prefix: v=illegalGarbage")),
        runSteps(saslServerSession.get(), saslClientSession.get(), mutator));
}
452 
// A password containing an embedded NUL byte must still authenticate.
TEST_F(SCRAMSHA1Fixture, testNULLInPassword) {
    // Use the _sd literal so the StringData length includes everything after the NUL. A
    // plain "saj\0ack" converts through the const char* constructor, which stops at the first
    // NUL, so the test would silently use the password "saj" instead.
    const StringData password = "saj\0ack"_sd;

    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", password), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", password));

    ASSERT_OK(saslClientSession->initialize());

    ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
}
467 
468 
// Commas (the SCRAM attribute separator) in both username and password must not break
// the exchange.
TEST_F(SCRAMSHA1Fixture, testCommasInUsernameAndPassword) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("s,a,jack", "s,a,jack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "s,a,jack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("s,a,jack", "s,a,jack"));

    ASSERT_OK(saslClientSession->initialize());

    ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
}
483 
// No user document is inserted, so the server's first step must report UserNotFound.
TEST_F(SCRAMSHA1Fixture, testIncorrectUser) {
    const std::string userName = "sajack";
    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, userName);
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest(userName, userName));

    ASSERT_OK(saslClientSession->initialize());

    const SCRAMStepsResult expected(
        SaslTestState(SaslTestState::kServer, 1),
        Status(ErrorCodes::UserNotFound, "Could not find user sajack@test"));
    ASSERT_EQ(expected, runSteps(saslServerSession.get(), saslClientSession.get()));
}
495 
// The client uses the wrong password for an existing user. The server is expected to fail
// authentication at server step 2 with a storedKey mismatch.
TEST_F(SCRAMSHA1Fixture, testIncorrectPassword) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateSCRAMUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "invalidPassword"));

    ASSERT_OK(saslClientSession->initialize());

    ASSERT_EQ(SCRAMStepsResult(SaslTestState(SaslTestState::kServer, 2),
                               Status(ErrorCodes::AuthenticationFailed,
                                      "SCRAM-SHA-1 authentication failed, storedKey mismatch")),
              runSteps(saslServerSession.get(), saslClientSession.get()));
}
513 
// A user stored with only a MONGODB-CR credential must still complete a SCRAM-SHA-1
// exchange successfully.
TEST_F(SCRAMSHA1Fixture, testMONGODBCR) {
    ASSERT_OK(authzManagerExternalState->insertPrivilegeDocument(
        opCtx.get(), generateMONGODBCRUserDocument("sajack", "sajack"), BSONObj()));

    saslClientSession->setParameter(NativeSaslClientSession::parameterUser, "sajack");
    saslClientSession->setParameter(NativeSaslClientSession::parameterPassword,
                                    createPasswordDigest("sajack", "sajack"));

    ASSERT_OK(saslClientSession->initialize());

    ASSERT_EQ(goalState, runSteps(saslServerSession.get(), saslClientSession.get()));
}
528 
// A freshly constructed cache returns nothing for any host/presecrets lookup.
TEST(SCRAMSHA1Cache, testGetFromEmptyCache) {
    SCRAMSHA1ClientCache cache;
    const std::string saltText("saltsaltsaltsalt");
    std::vector<std::uint8_t> saltBytes(saltText.begin(), saltText.end());
    HostAndPort localHost("localhost:27017");

    const auto lookup =
        cache.getCachedSecrets(localHost, scram::SCRAMPresecrets("aaa", saltBytes, 10000));
    ASSERT_FALSE(lookup);
}
537 
538 
// Secrets stored for a (host, presecrets) key are returned with identical key material by a
// lookup with the same key.
TEST(SCRAMSHA1Cache, testSetAndGet) {
    SCRAMSHA1ClientCache cache;
    std::string saltStr("saltsaltsaltsalt");
    std::vector<std::uint8_t> salt(saltStr.begin(), saltStr.end());
    HostAndPort host("localhost:27017");

    auto secret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000));
    cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), secret);
    auto cachedSecret = cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000));
    ASSERT_TRUE(cachedSecret);
    ASSERT_TRUE(secret->clientKey == cachedSecret->clientKey);
    ASSERT_TRUE(secret->serverKey == cachedSecret->serverKey);
    ASSERT_TRUE(secret->storedKey == cachedSecret->storedKey);
}
555 
556 
// After one entry is cached, changing any single component of the lookup key must miss.
TEST(SCRAMSHA1Cache, testSetAndGetWithDifferentParameters) {
    SCRAMSHA1ClientCache cache;
    const std::string goodSaltText("saltsaltsaltsalt");
    const std::string otherSaltText("s@lts@lts@lts@lt");
    std::vector<std::uint8_t> salt(goodSaltText.begin(), goodSaltText.end());
    std::vector<std::uint8_t> badSalt(otherSaltText.begin(), otherSaltText.end());
    HostAndPort host("localhost:27017");

    auto storedSecret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", salt, 10000));
    cache.setCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10000), storedSecret);

    // Different port.
    ASSERT_FALSE(cache.getCachedSecrets(HostAndPort("localhost:27018"),
                                        scram::SCRAMPresecrets("aaa", salt, 10000)));
    // Different hashed password.
    ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aab", salt, 10000)));
    // Different salt.
    ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", badSalt, 10000)));
    // Different iteration count.
    ASSERT_FALSE(cache.getCachedSecrets(host, scram::SCRAMPresecrets("aaa", salt, 10001)));
}
574 
575 
// Caching new secrets for the same host replaces the earlier entry: the old presecrets no
// longer hit, and the new ones return the new key material.
TEST(SCRAMSHA1Cache, testSetAndReset) {
    SCRAMSHA1ClientCache cache;
    StringData saltText("saltsaltsaltsalt");
    std::vector<std::uint8_t> saltBytes(saltText.begin(), saltText.end());
    HostAndPort localHost("localhost:27017");

    auto firstSecret = scram::generateSecrets(scram::SCRAMPresecrets("aaa", saltBytes, 10000));
    cache.setCachedSecrets(localHost, scram::SCRAMPresecrets("aaa", saltBytes, 10000), firstSecret);
    auto replacementSecret =
        scram::generateSecrets(scram::SCRAMPresecrets("aab", saltBytes, 10000));
    cache.setCachedSecrets(
        localHost, scram::SCRAMPresecrets("aab", saltBytes, 10000), replacementSecret);

    ASSERT_FALSE(
        cache.getCachedSecrets(localHost, scram::SCRAMPresecrets("aaa", saltBytes, 10000)));

    auto found = cache.getCachedSecrets(localHost, scram::SCRAMPresecrets("aab", saltBytes, 10000));
    ASSERT_TRUE(found);
    ASSERT_TRUE(replacementSecret->clientKey == found->clientKey);
    ASSERT_TRUE(replacementSecret->serverKey == found->serverKey);
    ASSERT_TRUE(replacementSecret->storedKey == found->storedKey);
}
594 
595 }  // namespace mongo
596