1 /*-------------------------------------------------------------------------
2  *
3  * auth-scram.c
4  *	  Server-side implementation of the SASL SCRAM-SHA-256 mechanism.
5  *
6  * See the following RFCs for more details:
7  * - RFC 5802: https://tools.ietf.org/html/rfc5802
8  * - RFC 5803: https://tools.ietf.org/html/rfc5803
9  * - RFC 7677: https://tools.ietf.org/html/rfc7677
10  *
11  * Here are some differences:
12  *
13  * - Username from the authentication exchange is not used. The client
14  *	 should send an empty string as the username.
15  *
16  * - If the password isn't valid UTF-8, or contains characters prohibited
17  *	 by the SASLprep profile, we skip the SASLprep pre-processing and use
18  *	 the raw bytes in calculating the hash.
19  *
20  * - If channel binding is used, the channel binding type is always
21  *	 "tls-server-end-point".  The spec says the default is "tls-unique"
22  *	 (RFC 5802, section 6.1. Default Channel Binding), but there are some
23  *	 problems with that.  Firstly, not all SSL libraries provide an API to
24  *	 get the TLS Finished message, required to use "tls-unique".  Secondly,
25  *	 "tls-unique" is not specified for TLS v1.3, and as of this writing,
26  *	 it's not clear if there will be a replacement.  We could support both
27  *	 "tls-server-end-point" and "tls-unique", but for our use case,
28  *	 "tls-unique" doesn't really have any advantages.  The main advantage
29  *	 of "tls-unique" would be that it works even if the server doesn't
30  *	 have a certificate, but PostgreSQL requires a server certificate
31  *	 whenever SSL is used, anyway.
32  *
33  *
34  * The password stored in pg_authid consists of the iteration count, salt,
35  * StoredKey and ServerKey.
36  *
37  * SASLprep usage
38  * --------------
39  *
40  * One notable difference to the SCRAM specification is that while the
41  * specification dictates that the password is in UTF-8, and prohibits
42  * certain characters, we are more lenient.  If the password isn't a valid
43  * UTF-8 string, or contains prohibited characters, the raw bytes are used
44  * to calculate the hash instead, without SASLprep processing.  This is
45  * because PostgreSQL supports other encodings too, and the encoding being
46  * used during authentication is undefined (client_encoding isn't set until
47  * after authentication).  In effect, we try to interpret the password as
48  * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume
49  * that it's in some other encoding.
50  *
51  * In the worst case, we misinterpret a password that's in a different
52  * encoding as being Unicode, because it happens to consist entirely of
53  * valid UTF-8 bytes, and we apply Unicode normalization to it.  As long
54  * as we do that consistently, that will not lead to failed logins.
55  * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep
56  * don't correspond to any commonly used characters in any of the other
57  * supported encodings, so it should not lead to any significant loss in
58  * entropy, even if the normalization is incorrectly applied to a
59  * non-UTF-8 password.
60  *
61  * Error handling
62  * --------------
63  *
64  * Don't reveal user information to an unauthenticated client.  We don't
65  * want an attacker to be able to probe whether a particular username is
66  * valid.  In SCRAM, the server has to read the salt and iteration count
67  * from the user's password verifier, and send it to the client.  To avoid
68  * revealing whether a user exists, when the client tries to authenticate
69  * with a username that doesn't exist, or doesn't have a valid SCRAM
70  * verifier in pg_authid, we create a fake salt and iteration count
71  * on-the-fly, and proceed with the authentication with that.  In the end,
72  * we'll reject the attempt, as if an incorrect password was given.  When
73  * we are performing a "mock" authentication, the 'doomed' flag in
74  * scram_state is set.
75  *
76  * In the error messages, avoid printing strings from the client, unless
77  * you check that they are pure ASCII.  We don't want an unauthenticated
78  * attacker to be able to spam the logs with characters that are not valid
79  * to the encoding being used, whatever that is.  We cannot avoid that in
80  * general, after logging in, but let's do what we can here.
81  *
82  *
83  * Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group
84  * Portions Copyright (c) 1994, Regents of the University of California
85  *
86  * src/backend/libpq/auth-scram.c
87  *
88  *-------------------------------------------------------------------------
89  */
90 #include "postgres.h"
91 
92 #include <unistd.h>
93 
94 #include "access/xlog.h"
95 #include "catalog/pg_authid.h"
96 #include "catalog/pg_control.h"
97 #include "common/base64.h"
98 #include "common/saslprep.h"
99 #include "common/scram-common.h"
100 #include "common/sha2.h"
101 #include "libpq/auth.h"
102 #include "libpq/crypt.h"
103 #include "libpq/scram.h"
104 #include "miscadmin.h"
105 #include "utils/backend_random.h"
106 #include "utils/builtins.h"
107 #include "utils/timestamp.h"
108 
/*
 * Progress marker for a SCRAM authentication exchange.  This type is
 * private to this file.
 */
typedef enum
{
	/* Nothing received yet; the next input is the client-first-message */
	SCRAM_AUTH_INIT,

	/*
	 * The server-first-message (salt, iteration count, nonce) has been built
	 * for the client; the next input is the client-final-message.
	 */
	SCRAM_AUTH_SALT_SENT,

	/* The client proof was accepted and the final message was built */
	SCRAM_AUTH_FINISHED
} scram_state_enum;
119 
120 typedef struct
121 {
122 	scram_state_enum state;
123 
124 	const char *username;		/* username from startup packet */
125 
126 	Port	   *port;
127 	bool		channel_binding_in_use;
128 
129 	int			iterations;
130 	char	   *salt;			/* base64-encoded */
131 	uint8		StoredKey[SCRAM_KEY_LEN];
132 	uint8		ServerKey[SCRAM_KEY_LEN];
133 
134 	/* Fields of the first message from client */
135 	char		cbind_flag;
136 	char	   *client_first_message_bare;
137 	char	   *client_username;
138 	char	   *client_nonce;
139 
140 	/* Fields from the last message from client */
141 	char	   *client_final_message_without_proof;
142 	char	   *client_final_nonce;
143 	char		ClientProof[SCRAM_KEY_LEN];
144 
145 	/* Fields generated in the server */
146 	char	   *server_first_message;
147 	char	   *server_nonce;
148 
149 	/*
150 	 * If something goes wrong during the authentication, or we are performing
151 	 * a "mock" authentication (see comments at top of file), the 'doomed'
152 	 * flag is set.  A reason for the failure, for the server log, is put in
153 	 * 'logdetail'.
154 	 */
155 	bool		doomed;
156 	char	   *logdetail;
157 } scram_state;
158 
159 static void read_client_first_message(scram_state *state, char *input);
160 static void read_client_final_message(scram_state *state, char *input);
161 static char *build_server_first_message(scram_state *state);
162 static char *build_server_final_message(scram_state *state);
163 static bool verify_client_proof(scram_state *state);
164 static bool verify_final_nonce(scram_state *state);
165 static void mock_scram_verifier(const char *username, int *iterations,
166 					char **salt, uint8 *stored_key, uint8 *server_key);
167 static bool is_scram_printable(char *p);
168 static char *sanitize_char(char c);
169 static char *sanitize_str(const char *s);
170 static char *scram_mock_salt(const char *username);
171 
172 /*
173  * pg_be_scram_get_mechanisms
174  *
175  * Get a list of SASL mechanisms that this module supports.
176  *
177  * For the convenience of building the FE/BE packet that lists the
178  * mechanisms, the names are appended to the given StringInfo buffer,
179  * separated by '\0' bytes.
180  */
181 void
pg_be_scram_get_mechanisms(Port * port,StringInfo buf)182 pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
183 {
184 	/*
185 	 * Advertise the mechanisms in decreasing order of importance.  So the
186 	 * channel-binding variants go first, if they are supported.  Channel
187 	 * binding is only supported with SSL, and only if the SSL implementation
188 	 * has a function to get the certificate's hash.
189 	 */
190 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
191 	if (port->ssl_in_use)
192 	{
193 		appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME);
194 		appendStringInfoChar(buf, '\0');
195 	}
196 #endif
197 	appendStringInfoString(buf, SCRAM_SHA_256_NAME);
198 	appendStringInfoChar(buf, '\0');
199 }
200 
201 /*
202  * pg_be_scram_init
203  *
204  * Initialize a new SCRAM authentication exchange status tracker.  This
205  * needs to be called before doing any exchange.  It will be filled later
206  * after the beginning of the exchange with verifier data.
207  *
208  * 'selected_mech' identifies the SASL mechanism that the client selected.
209  * It should be one of the mechanisms that we support, as returned by
210  * pg_be_scram_get_mechanisms().
211  *
212  * 'shadow_pass' is the role's password verifier, from pg_authid.rolpassword.
213  * The username was provided by the client in the startup message, and is
214  * available in port->user_name.  If 'shadow_pass' is NULL, we still perform
215  * an authentication exchange, but it will fail, as if an incorrect password
216  * was given.
217  */
218 void *
pg_be_scram_init(Port * port,const char * selected_mech,const char * shadow_pass)219 pg_be_scram_init(Port *port,
220 				 const char *selected_mech,
221 				 const char *shadow_pass)
222 {
223 	scram_state *state;
224 	bool		got_verifier;
225 
226 	state = (scram_state *) palloc0(sizeof(scram_state));
227 	state->port = port;
228 	state->state = SCRAM_AUTH_INIT;
229 
230 	/*
231 	 * Parse the selected mechanism.
232 	 *
233 	 * Note that if we don't support channel binding, either because the SSL
234 	 * implementation doesn't support it or we're not using SSL at all, we
235 	 * would not have advertised the PLUS variant in the first place.  If the
236 	 * client nevertheless tries to select it, it's a protocol violation like
237 	 * selecting any other SASL mechanism we don't support.
238 	 */
239 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
240 	if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use)
241 		state->channel_binding_in_use = true;
242 	else
243 #endif
244 	if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0)
245 		state->channel_binding_in_use = false;
246 	else
247 		ereport(ERROR,
248 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
249 				 errmsg("client selected an invalid SASL authentication mechanism")));
250 
251 	/*
252 	 * Parse the stored password verifier.
253 	 */
254 	if (shadow_pass)
255 	{
256 		int			password_type = get_password_type(shadow_pass);
257 
258 		if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
259 		{
260 			if (parse_scram_verifier(shadow_pass, &state->iterations, &state->salt,
261 									 state->StoredKey, state->ServerKey))
262 				got_verifier = true;
263 			else
264 			{
265 				/*
266 				 * The password looked like a SCRAM verifier, but could not be
267 				 * parsed.
268 				 */
269 				ereport(LOG,
270 						(errmsg("invalid SCRAM verifier for user \"%s\"",
271 								state->port->user_name)));
272 				got_verifier = false;
273 			}
274 		}
275 		else
276 		{
277 			/*
278 			 * The user doesn't have SCRAM verifier. (You cannot do SCRAM
279 			 * authentication with an MD5 hash.)
280 			 */
281 			state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM verifier."),
282 										state->port->user_name);
283 			got_verifier = false;
284 		}
285 	}
286 	else
287 	{
288 		/*
289 		 * The caller requested us to perform a dummy authentication.  This is
290 		 * considered normal, since the caller requested it, so don't set log
291 		 * detail.
292 		 */
293 		got_verifier = false;
294 	}
295 
296 	/*
297 	 * If the user did not have a valid SCRAM verifier, we still go through
298 	 * the motions with a mock one, and fail as if the client supplied an
299 	 * incorrect password.  This is to avoid revealing information to an
300 	 * attacker.
301 	 */
302 	if (!got_verifier)
303 	{
304 		mock_scram_verifier(state->port->user_name, &state->iterations,
305 							&state->salt, state->StoredKey, state->ServerKey);
306 		state->doomed = true;
307 	}
308 
309 	return state;
310 }
311 
312 /*
313  * Continue a SCRAM authentication exchange.
314  *
315  * 'input' is the SCRAM payload sent by the client.  On the first call,
316  * 'input' contains the "Initial Client Response" that the client sent as
317  * part of the SASLInitialResponse message, or NULL if no Initial Client
318  * Response was given.  (The SASL specification distinguishes between an
319  * empty response and non-existing one.)  On subsequent calls, 'input'
320  * cannot be NULL.  For convenience in this function, the caller must
321  * ensure that there is a null terminator at input[inputlen].
322  *
323  * The next message to send to client is saved in 'output', for a length
324  * of 'outputlen'.  In the case of an error, optionally store a palloc'd
325  * string at *logdetail that will be sent to the postmaster log (but not
326  * the client).
327  */
328 int
pg_be_scram_exchange(void * opaq,char * input,int inputlen,char ** output,int * outputlen,char ** logdetail)329 pg_be_scram_exchange(void *opaq, char *input, int inputlen,
330 					 char **output, int *outputlen, char **logdetail)
331 {
332 	scram_state *state = (scram_state *) opaq;
333 	int			result;
334 
335 	*output = NULL;
336 
337 	/*
338 	 * If the client didn't include an "Initial Client Response" in the
339 	 * SASLInitialResponse message, send an empty challenge, to which the
340 	 * client will respond with the same data that usually comes in the
341 	 * Initial Client Response.
342 	 */
343 	if (input == NULL)
344 	{
345 		Assert(state->state == SCRAM_AUTH_INIT);
346 
347 		*output = pstrdup("");
348 		*outputlen = 0;
349 		return SASL_EXCHANGE_CONTINUE;
350 	}
351 
352 	/*
353 	 * Check that the input length agrees with the string length of the input.
354 	 * We can ignore inputlen after this.
355 	 */
356 	if (inputlen == 0)
357 		ereport(ERROR,
358 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
359 				 errmsg("malformed SCRAM message"),
360 				 errdetail("The message is empty.")));
361 	if (inputlen != strlen(input))
362 		ereport(ERROR,
363 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
364 				 errmsg("malformed SCRAM message"),
365 				 errdetail("Message length does not match input length.")));
366 
367 	switch (state->state)
368 	{
369 		case SCRAM_AUTH_INIT:
370 
371 			/*
372 			 * Initialization phase.  Receive the first message from client
373 			 * and be sure that it parsed correctly.  Then send the challenge
374 			 * to the client.
375 			 */
376 			read_client_first_message(state, input);
377 
378 			/* prepare message to send challenge */
379 			*output = build_server_first_message(state);
380 
381 			state->state = SCRAM_AUTH_SALT_SENT;
382 			result = SASL_EXCHANGE_CONTINUE;
383 			break;
384 
385 		case SCRAM_AUTH_SALT_SENT:
386 
387 			/*
388 			 * Final phase for the server.  Receive the response to the
389 			 * challenge previously sent, verify, and let the client know that
390 			 * everything went well (or not).
391 			 */
392 			read_client_final_message(state, input);
393 
394 			if (!verify_final_nonce(state))
395 				ereport(ERROR,
396 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
397 						 errmsg("invalid SCRAM response"),
398 						 errdetail("Nonce does not match.")));
399 
400 			/*
401 			 * Now check the final nonce and the client proof.
402 			 *
403 			 * If we performed a "mock" authentication that we knew would fail
404 			 * from the get go, this is where we fail.
405 			 *
406 			 * The SCRAM specification includes an error code,
407 			 * "invalid-proof", for authentication failure, but it also allows
408 			 * erroring out in an application-specific way.  We choose to do
409 			 * the latter, so that the error message for invalid password is
410 			 * the same for all authentication methods.  The caller will call
411 			 * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
412 			 *
413 			 * NB: the order of these checks is intentional.  We calculate the
414 			 * client proof even in a mock authentication, even though it's
415 			 * bound to fail, to thwart timing attacks to determine if a role
416 			 * with the given name exists or not.
417 			 */
418 			if (!verify_client_proof(state) || state->doomed)
419 			{
420 				result = SASL_EXCHANGE_FAILURE;
421 				break;
422 			}
423 
424 			/* Build final message for client */
425 			*output = build_server_final_message(state);
426 
427 			/* Success! */
428 			result = SASL_EXCHANGE_SUCCESS;
429 			state->state = SCRAM_AUTH_FINISHED;
430 			break;
431 
432 		default:
433 			elog(ERROR, "invalid SCRAM exchange state");
434 			result = SASL_EXCHANGE_FAILURE;
435 	}
436 
437 	if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
438 		*logdetail = state->logdetail;
439 
440 	if (*output)
441 		*outputlen = strlen(*output);
442 
443 	return result;
444 }
445 
446 /*
447  * Construct a verifier string for SCRAM, stored in pg_authid.rolpassword.
448  *
449  * The result is palloc'd, so caller is responsible for freeing it.
450  */
451 char *
pg_be_scram_build_verifier(const char * password)452 pg_be_scram_build_verifier(const char *password)
453 {
454 	char	   *prep_password;
455 	pg_saslprep_rc rc;
456 	char		saltbuf[SCRAM_DEFAULT_SALT_LEN];
457 	char	   *result;
458 
459 	/*
460 	 * Normalize the password with SASLprep.  If that doesn't work, because
461 	 * the password isn't valid UTF-8 or contains prohibited characters, just
462 	 * proceed with the original password.  (See comments at top of file.)
463 	 */
464 	rc = pg_saslprep(password, &prep_password);
465 	if (rc == SASLPREP_SUCCESS)
466 		password = (const char *) prep_password;
467 
468 	/* Generate random salt */
469 	if (!pg_backend_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
470 		ereport(ERROR,
471 				(errcode(ERRCODE_INTERNAL_ERROR),
472 				 errmsg("could not generate random salt")));
473 
474 	result = scram_build_verifier(saltbuf, SCRAM_DEFAULT_SALT_LEN,
475 								  SCRAM_DEFAULT_ITERATIONS, password);
476 
477 	if (prep_password)
478 		pfree(prep_password);
479 
480 	return result;
481 }
482 
483 /*
484  * Verify a plaintext password against a SCRAM verifier.  This is used when
485  * performing plaintext password authentication for a user that has a SCRAM
486  * verifier stored in pg_authid.
487  */
488 bool
scram_verify_plain_password(const char * username,const char * password,const char * verifier)489 scram_verify_plain_password(const char *username, const char *password,
490 							const char *verifier)
491 {
492 	char	   *encoded_salt;
493 	char	   *salt;
494 	int			saltlen;
495 	int			iterations;
496 	uint8		salted_password[SCRAM_KEY_LEN];
497 	uint8		stored_key[SCRAM_KEY_LEN];
498 	uint8		server_key[SCRAM_KEY_LEN];
499 	uint8		computed_key[SCRAM_KEY_LEN];
500 	char	   *prep_password;
501 	pg_saslprep_rc rc;
502 
503 	if (!parse_scram_verifier(verifier, &iterations, &encoded_salt,
504 							  stored_key, server_key))
505 	{
506 		/*
507 		 * The password looked like a SCRAM verifier, but could not be parsed.
508 		 */
509 		ereport(LOG,
510 				(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
511 		return false;
512 	}
513 
514 	salt = palloc(pg_b64_dec_len(strlen(encoded_salt)));
515 	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
516 	if (saltlen == -1)
517 	{
518 		ereport(LOG,
519 				(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
520 		return false;
521 	}
522 
523 	/* Normalize the password */
524 	rc = pg_saslprep(password, &prep_password);
525 	if (rc == SASLPREP_SUCCESS)
526 		password = prep_password;
527 
528 	/* Compute Server Key based on the user-supplied plaintext password */
529 	scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
530 	scram_ServerKey(salted_password, computed_key);
531 
532 	if (prep_password)
533 		pfree(prep_password);
534 
535 	/*
536 	 * Compare the verifier's Server Key with the one computed from the
537 	 * user-supplied password.
538 	 */
539 	return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
540 }
541 
542 
543 /*
544  * Parse and validate format of given SCRAM verifier.
545  *
546  * On success, the iteration count, salt, stored key, and server key are
547  * extracted from the verifier, and returned to the caller.  For 'stored_key'
548  * and 'server_key', the caller must pass pre-allocated buffers of size
549  * SCRAM_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
550  * string.  The buffer for the salt is palloc'd by this function.
551  *
552  * Returns true if the SCRAM verifier has been parsed, and false otherwise.
553  */
554 bool
parse_scram_verifier(const char * verifier,int * iterations,char ** salt,uint8 * stored_key,uint8 * server_key)555 parse_scram_verifier(const char *verifier, int *iterations, char **salt,
556 					 uint8 *stored_key, uint8 *server_key)
557 {
558 	char	   *v;
559 	char	   *p;
560 	char	   *scheme_str;
561 	char	   *salt_str;
562 	char	   *iterations_str;
563 	char	   *storedkey_str;
564 	char	   *serverkey_str;
565 	int			decoded_len;
566 	char	   *decoded_salt_buf;
567 	char	   *decoded_stored_buf;
568 	char	   *decoded_server_buf;
569 
570 	/*
571 	 * The verifier is of form:
572 	 *
573 	 * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
574 	 */
575 	v = pstrdup(verifier);
576 	if ((scheme_str = strtok(v, "$")) == NULL)
577 		goto invalid_verifier;
578 	if ((iterations_str = strtok(NULL, ":")) == NULL)
579 		goto invalid_verifier;
580 	if ((salt_str = strtok(NULL, "$")) == NULL)
581 		goto invalid_verifier;
582 	if ((storedkey_str = strtok(NULL, ":")) == NULL)
583 		goto invalid_verifier;
584 	if ((serverkey_str = strtok(NULL, "")) == NULL)
585 		goto invalid_verifier;
586 
587 	/* Parse the fields */
588 	if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
589 		goto invalid_verifier;
590 
591 	errno = 0;
592 	*iterations = strtol(iterations_str, &p, 10);
593 	if (*p || errno != 0)
594 		goto invalid_verifier;
595 
596 	/*
597 	 * Verify that the salt is in Base64-encoded format, by decoding it,
598 	 * although we return the encoded version to the caller.
599 	 */
600 	decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str)));
601 	decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
602 								decoded_salt_buf);
603 	if (decoded_len < 0)
604 		goto invalid_verifier;
605 	*salt = pstrdup(salt_str);
606 
607 	/*
608 	 * Decode StoredKey and ServerKey.
609 	 */
610 	decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str)));
611 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
612 								decoded_stored_buf);
613 	if (decoded_len != SCRAM_KEY_LEN)
614 		goto invalid_verifier;
615 	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
616 
617 	decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str)));
618 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
619 								decoded_server_buf);
620 	if (decoded_len != SCRAM_KEY_LEN)
621 		goto invalid_verifier;
622 	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
623 
624 	return true;
625 
626 invalid_verifier:
627 	*salt = NULL;
628 	return false;
629 }
630 
631 /*
632  * Generate plausible SCRAM verifier parameters for mock authentication.
633  *
634  * In a normal authentication, these are extracted from the verifier
635  * stored in the server.  This function generates values that look
636  * realistic, for when there is no stored verifier.
637  *
638  * Like in parse_scram_verifier(), for 'stored_key' and 'server_key', the
639  * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
640  * the buffer for the salt is palloc'd by this function.
641  */
642 static void
mock_scram_verifier(const char * username,int * iterations,char ** salt,uint8 * stored_key,uint8 * server_key)643 mock_scram_verifier(const char *username, int *iterations, char **salt,
644 					uint8 *stored_key, uint8 *server_key)
645 {
646 	char	   *raw_salt;
647 	char	   *encoded_salt;
648 	int			encoded_len;
649 
650 	/* Generate deterministic salt */
651 	raw_salt = scram_mock_salt(username);
652 
653 	encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
654 	encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
655 	encoded_salt[encoded_len] = '\0';
656 
657 	*salt = encoded_salt;
658 	*iterations = SCRAM_DEFAULT_ITERATIONS;
659 
660 	/* StoredKey and ServerKey are not used in a doomed authentication */
661 	memset(stored_key, 0, SCRAM_KEY_LEN);
662 	memset(server_key, 0, SCRAM_KEY_LEN);
663 }
664 
665 /*
666  * Read the value in a given SCRAM exchange message for given attribute.
667  */
668 static char *
read_attr_value(char ** input,char attr)669 read_attr_value(char **input, char attr)
670 {
671 	char	   *begin = *input;
672 	char	   *end;
673 
674 	if (*begin != attr)
675 		ereport(ERROR,
676 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
677 				 errmsg("malformed SCRAM message"),
678 				 errdetail("Expected attribute \"%c\" but found \"%s\".",
679 						   attr, sanitize_char(*begin))));
680 	begin++;
681 
682 	if (*begin != '=')
683 		ereport(ERROR,
684 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
685 				 errmsg("malformed SCRAM message"),
686 				 errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
687 	begin++;
688 
689 	end = begin;
690 	while (*end && *end != ',')
691 		end++;
692 
693 	if (*end)
694 	{
695 		*end = '\0';
696 		*input = end + 1;
697 	}
698 	else
699 		*input = end;
700 
701 	return begin;
702 }
703 
/*
 * Check whether a null-terminated string consists solely of "printable"
 * characters in the sense of the SCRAM spec.  An empty string passes.
 */
static bool
is_scram_printable(char *p)
{
	/*------
	 * From RFC 5802:
	 *
	 *	printable		= %x21-2B / %x2D-7E
	 *					  ;; Printable ASCII except ",".
	 *					  ;; Note that any "printable" is also
	 *					  ;; a valid "value".
	 *------
	 */
	while (*p != '\0')
	{
		char		ch = *p++;
		bool		in_range = (ch >= 0x21 && ch <= 0x7E);

		/* 0x2C is the comma, the one excluded character within the range */
		if (!in_range || ch == 0x2C)
			return false;
	}

	return true;
}
723 
/*
 * Render a single byte in a form that is safe to put in an error message.
 *
 * A printable ASCII character (0x21 through 0x7E) is shown as itself
 * between single quotes; any other byte is shown as a two-digit hex value
 * with a "0x" prefix.
 *
 * The result lives in a static buffer, so it is overwritten by the next
 * call.
 */
static char *
sanitize_char(char c)
{
	static char buf[5];

	if (c < 0x21 || c > 0x7E)
	{
		/* Go through unsigned char so a high-bit byte prints as two digits */
		snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
		return buf;
	}

	snprintf(buf, sizeof(buf), "'%c'", c);
	return buf;
}
743 
/*
 * Render a string in a form that is safe to put in an error message.
 *
 * Every byte that is not printable ASCII (0x21 through 0x7E) is replaced
 * with '?', and at most 30 characters of the input are kept.
 *
 * The result lives in a static buffer, so it is overwritten by the next
 * call.
 */
static char *
sanitize_str(const char *s)
{
	static char buf[30 + 1];
	char	   *out = buf;
	char	   *const limit = buf + sizeof(buf) - 1;	/* terminator slot */

	/* Test the limit first, so no input beyond 30 bytes is ever read */
	while (out < limit && *s != '\0')
	{
		char		c = *s++;

		*out++ = (c >= 0x21 && c <= 0x7E) ? c : '?';
	}
	*out = '\0';

	return buf;
}
773 
774 /*
775  * Read the next attribute and value in a SCRAM exchange message.
776  *
777  * Returns NULL if there is attribute.
778  */
779 static char *
read_any_attr(char ** input,char * attr_p)780 read_any_attr(char **input, char *attr_p)
781 {
782 	char	   *begin = *input;
783 	char	   *end;
784 	char		attr = *begin;
785 
786 	/*------
787 	 * attr-val		   = ALPHA "=" value
788 	 *					 ;; Generic syntax of any attribute sent
789 	 *					 ;; by server or client
790 	 *------
791 	 */
792 	if (!((attr >= 'A' && attr <= 'Z') ||
793 		  (attr >= 'a' && attr <= 'z')))
794 		ereport(ERROR,
795 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
796 				 errmsg("malformed SCRAM message"),
797 				 errdetail("Attribute expected, but found invalid character \"%s\".",
798 						   sanitize_char(attr))));
799 	if (attr_p)
800 		*attr_p = attr;
801 	begin++;
802 
803 	if (*begin != '=')
804 		ereport(ERROR,
805 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
806 				 errmsg("malformed SCRAM message"),
807 				 errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
808 	begin++;
809 
810 	end = begin;
811 	while (*end && *end != ',')
812 		end++;
813 
814 	if (*end)
815 	{
816 		*end = '\0';
817 		*input = end + 1;
818 	}
819 	else
820 		*input = end;
821 
822 	return begin;
823 }
824 
825 /*
826  * Read and parse the first message from client in the context of a SCRAM
827  * authentication exchange message.
828  *
829  * At this stage, any errors will be reported directly with ereport(ERROR).
830  */
831 static void
read_client_first_message(scram_state * state,char * input)832 read_client_first_message(scram_state *state, char *input)
833 {
834 	char	   *channel_binding_type;
835 
836 	input = pstrdup(input);
837 
838 	/*------
839 	 * The syntax for the client-first-message is: (RFC 5802)
840 	 *
841 	 * saslname		   = 1*(value-safe-char / "=2C" / "=3D")
842 	 *					 ;; Conforms to <value>.
843 	 *
844 	 * authzid		   = "a=" saslname
845 	 *					 ;; Protocol specific.
846 	 *
847 	 * cb-name		   = 1*(ALPHA / DIGIT / "." / "-")
848 	 *					  ;; See RFC 5056, Section 7.
849 	 *					  ;; E.g., "tls-server-end-point" or
850 	 *					  ;; "tls-unique".
851 	 *
852 	 * gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
853 	 *					 ;; "n" -> client doesn't support channel binding.
854 	 *					 ;; "y" -> client does support channel binding
855 	 *					 ;;		   but thinks the server does not.
856 	 *					 ;; "p" -> client requires channel binding.
857 	 *					 ;; The selected channel binding follows "p=".
858 	 *
859 	 * gs2-header	   = gs2-cbind-flag "," [ authzid ] ","
860 	 *					 ;; GS2 header for SCRAM
861 	 *					 ;; (the actual GS2 header includes an optional
862 	 *					 ;; flag to indicate that the GSS mechanism is not
863 	 *					 ;; "standard", but since SCRAM is "standard", we
864 	 *					 ;; don't include that flag).
865 	 *
866 	 * username		   = "n=" saslname
867 	 *					 ;; Usernames are prepared using SASLprep.
868 	 *
869 	 * reserved-mext  = "m=" 1*(value-char)
870 	 *					 ;; Reserved for signaling mandatory extensions.
871 	 *					 ;; The exact syntax will be defined in
872 	 *					 ;; the future.
873 	 *
874 	 * nonce		   = "r=" c-nonce [s-nonce]
875 	 *					 ;; Second part provided by server.
876 	 *
877 	 * c-nonce		   = printable
878 	 *
879 	 * client-first-message-bare =
880 	 *					 [reserved-mext ","]
881 	 *					 username "," nonce ["," extensions]
882 	 *
883 	 * client-first-message =
884 	 *					 gs2-header client-first-message-bare
885 	 *
886 	 * For example:
887 	 * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL
888 	 *
889 	 * The "n,," in the beginning means that the client doesn't support
890 	 * channel binding, and no authzid is given.  "n=user" is the username.
891 	 * However, in PostgreSQL the username is sent in the startup packet, and
892 	 * the username in the SCRAM exchange is ignored.  libpq always sends it
893 	 * as an empty string.  The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is
894 	 * the client nonce.
895 	 *------
896 	 */
897 
898 	/*
899 	 * Read gs2-cbind-flag.  (For details see also RFC 5802 Section 6 "Channel
900 	 * Binding".)
901 	 */
902 	state->cbind_flag = *input;
903 	switch (*input)
904 	{
905 		case 'n':
906 
907 			/*
908 			 * The client does not support channel binding or has simply
909 			 * decided to not use it.  In that case just let it go.
910 			 */
911 			if (state->channel_binding_in_use)
912 				ereport(ERROR,
913 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
914 						 errmsg("malformed SCRAM message"),
915 						 errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
916 
917 			input++;
918 			if (*input != ',')
919 				ereport(ERROR,
920 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
921 						 errmsg("malformed SCRAM message"),
922 						 errdetail("Comma expected, but found character \"%s\".",
923 								   sanitize_char(*input))));
924 			input++;
925 			break;
926 		case 'y':
927 
928 			/*
929 			 * The client supports channel binding and thinks that the server
930 			 * does not.  In this case, the server must fail authentication if
931 			 * it supports channel binding.
932 			 */
933 			if (state->channel_binding_in_use)
934 				ereport(ERROR,
935 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
936 						 errmsg("malformed SCRAM message"),
937 						 errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
938 
939 #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
940 			if (state->port->ssl_in_use)
941 				ereport(ERROR,
942 						(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
943 						 errmsg("SCRAM channel binding negotiation error"),
944 						 errdetail("The client supports SCRAM channel binding but thinks the server does not.  "
945 								   "However, this server does support channel binding.")));
946 #endif
947 			input++;
948 			if (*input != ',')
949 				ereport(ERROR,
950 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
951 						 errmsg("malformed SCRAM message"),
952 						 errdetail("Comma expected, but found character \"%s\".",
953 								   sanitize_char(*input))));
954 			input++;
955 			break;
956 		case 'p':
957 
958 			/*
959 			 * The client requires channel binding.  Channel binding type
960 			 * follows, e.g., "p=tls-server-end-point".
961 			 */
962 			if (!state->channel_binding_in_use)
963 				ereport(ERROR,
964 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
965 						 errmsg("malformed SCRAM message"),
966 						 errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
967 
968 			channel_binding_type = read_attr_value(&input, 'p');
969 
970 			/*
971 			 * The only channel binding type we support is
972 			 * tls-server-end-point.
973 			 */
974 			if (strcmp(channel_binding_type, "tls-server-end-point") != 0)
975 				ereport(ERROR,
976 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
977 						 (errmsg("unsupported SCRAM channel-binding type \"%s\"",
978 								 sanitize_str(channel_binding_type)))));
979 			break;
980 		default:
981 			ereport(ERROR,
982 					(errcode(ERRCODE_PROTOCOL_VIOLATION),
983 					 errmsg("malformed SCRAM message"),
984 					 errdetail("Unexpected channel-binding flag \"%s\".",
985 							   sanitize_char(*input))));
986 	}
987 
988 	/*
989 	 * Forbid optional authzid (authorization identity).  We don't support it.
990 	 */
991 	if (*input == 'a')
992 		ereport(ERROR,
993 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
994 				 errmsg("client uses authorization identity, but it is not supported")));
995 	if (*input != ',')
996 		ereport(ERROR,
997 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
998 				 errmsg("malformed SCRAM message"),
999 				 errdetail("Unexpected attribute \"%s\" in client-first-message.",
1000 						   sanitize_char(*input))));
1001 	input++;
1002 
1003 	state->client_first_message_bare = pstrdup(input);
1004 
1005 	/*
1006 	 * Any mandatory extensions would go here.  We don't support any.
1007 	 *
1008 	 * RFC 5802 specifies error code "e=extensions-not-supported" for this,
1009 	 * but it can only be sent in the server-final message.  We prefer to fail
1010 	 * immediately (which the RFC also allows).
1011 	 */
1012 	if (*input == 'm')
1013 		ereport(ERROR,
1014 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1015 				 errmsg("client requires an unsupported SCRAM extension")));
1016 
1017 	/*
1018 	 * Read username.  Note: this is ignored.  We use the username from the
1019 	 * startup message instead, still it is kept around if provided as it
1020 	 * proves to be useful for debugging purposes.
1021 	 */
1022 	state->client_username = read_attr_value(&input, 'n');
1023 
1024 	/* read nonce and check that it is made of only printable characters */
1025 	state->client_nonce = read_attr_value(&input, 'r');
1026 	if (!is_scram_printable(state->client_nonce))
1027 		ereport(ERROR,
1028 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
1029 				 errmsg("non-printable characters in SCRAM nonce")));
1030 
1031 	/*
1032 	 * There can be any number of optional extensions after this.  We don't
1033 	 * support any extensions, so ignore them.
1034 	 */
1035 	while (*input != '\0')
1036 		read_any_attr(&input, NULL);
1037 
1038 	/* success! */
1039 }
1040 
1041 /*
1042  * Verify the final nonce contained in the last message received from
1043  * client in an exchange.
1044  */
1045 static bool
verify_final_nonce(scram_state * state)1046 verify_final_nonce(scram_state *state)
1047 {
1048 	int			client_nonce_len = strlen(state->client_nonce);
1049 	int			server_nonce_len = strlen(state->server_nonce);
1050 	int			final_nonce_len = strlen(state->client_final_nonce);
1051 
1052 	if (final_nonce_len != client_nonce_len + server_nonce_len)
1053 		return false;
1054 	if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0)
1055 		return false;
1056 	if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0)
1057 		return false;
1058 
1059 	return true;
1060 }
1061 
1062 /*
1063  * Verify the client proof contained in the last message received from
1064  * client in an exchange.
1065  */
1066 static bool
verify_client_proof(scram_state * state)1067 verify_client_proof(scram_state *state)
1068 {
1069 	uint8		ClientSignature[SCRAM_KEY_LEN];
1070 	uint8		ClientKey[SCRAM_KEY_LEN];
1071 	uint8		client_StoredKey[SCRAM_KEY_LEN];
1072 	scram_HMAC_ctx ctx;
1073 	int			i;
1074 
1075 	/* calculate ClientSignature */
1076 	scram_HMAC_init(&ctx, state->StoredKey, SCRAM_KEY_LEN);
1077 	scram_HMAC_update(&ctx,
1078 					  state->client_first_message_bare,
1079 					  strlen(state->client_first_message_bare));
1080 	scram_HMAC_update(&ctx, ",", 1);
1081 	scram_HMAC_update(&ctx,
1082 					  state->server_first_message,
1083 					  strlen(state->server_first_message));
1084 	scram_HMAC_update(&ctx, ",", 1);
1085 	scram_HMAC_update(&ctx,
1086 					  state->client_final_message_without_proof,
1087 					  strlen(state->client_final_message_without_proof));
1088 	scram_HMAC_final(ClientSignature, &ctx);
1089 
1090 	/* Extract the ClientKey that the client calculated from the proof */
1091 	for (i = 0; i < SCRAM_KEY_LEN; i++)
1092 		ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
1093 
1094 	/* Hash it one more time, and compare with StoredKey */
1095 	scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey);
1096 
1097 	if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
1098 		return false;
1099 
1100 	return true;
1101 }
1102 
1103 /*
1104  * Build the first server-side message sent to the client in a SCRAM
1105  * communication exchange.
1106  */
1107 static char *
build_server_first_message(scram_state * state)1108 build_server_first_message(scram_state *state)
1109 {
1110 	/*------
1111 	 * The syntax for the server-first-message is: (RFC 5802)
1112 	 *
1113 	 * server-first-message =
1114 	 *					 [reserved-mext ","] nonce "," salt ","
1115 	 *					 iteration-count ["," extensions]
1116 	 *
1117 	 * nonce		   = "r=" c-nonce [s-nonce]
1118 	 *					 ;; Second part provided by server.
1119 	 *
1120 	 * c-nonce		   = printable
1121 	 *
1122 	 * s-nonce		   = printable
1123 	 *
1124 	 * salt			   = "s=" base64
1125 	 *
1126 	 * iteration-count = "i=" posit-number
1127 	 *					 ;; A positive number.
1128 	 *
1129 	 * Example:
1130 	 *
1131 	 * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096
1132 	 *------
1133 	 */
1134 
1135 	/*
1136 	 * Per the spec, the nonce may consist of any printable ASCII characters.
1137 	 * For convenience, however, we don't use the whole range available,
1138 	 * rather, we generate some random bytes, and base64 encode them.
1139 	 */
1140 	char		raw_nonce[SCRAM_RAW_NONCE_LEN];
1141 	int			encoded_len;
1142 
1143 	if (!pg_backend_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
1144 		ereport(ERROR,
1145 				(errcode(ERRCODE_INTERNAL_ERROR),
1146 				 errmsg("could not generate random nonce")));
1147 
1148 	state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
1149 	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce);
1150 	state->server_nonce[encoded_len] = '\0';
1151 
1152 	state->server_first_message =
1153 		psprintf("r=%s%s,s=%s,i=%u",
1154 				 state->client_nonce, state->server_nonce,
1155 				 state->salt, state->iterations);
1156 
1157 	return pstrdup(state->server_first_message);
1158 }
1159 
1160 
/*
 * Read and parse the final message received from client.
 *
 * On return, the following fields of 'state' have been filled in:
 *	- client_final_nonce: value of the "r=" attribute
 *	- ClientProof: the base64-decoded "p=" attribute (SCRAM_KEY_LEN bytes)
 *	- client_final_message_without_proof: the leading part of 'input' that
 *	  precedes the proof attribute, as a NUL-terminated string
 *
 * The channel-binding attribute ("c=") is checked against the value this
 * server expects for the negotiated channel-binding mode.  Any mismatch or
 * syntax problem is reported with ereport(ERROR) / elog(ERROR).
 */
static void
read_client_final_message(scram_state *state, char *input)
{
	char		attr;
	char	   *channel_binding;
	char	   *value;
	char	   *begin,
			   *proof;
	char	   *p;
	char	   *client_proof;

	/*
	 * Parse a private copy of the message; 'begin' remembers where the copy
	 * starts so that offsets into it can be computed later.
	 */
	begin = p = pstrdup(input);

	/*------
	 * The syntax for the client-final-message is: (RFC 5802)
	 *
	 * gs2-header	   = gs2-cbind-flag "," [ authzid ] ","
	 *					 ;; GS2 header for SCRAM
	 *					 ;; (the actual GS2 header includes an optional
	 *					 ;; flag to indicate that the GSS mechanism is not
	 *					 ;; "standard", but since SCRAM is "standard", we
	 *					 ;; don't include that flag).
	 *
	 * cbind-input	 = gs2-header [ cbind-data ]
	 *					 ;; cbind-data MUST be present for
	 *					 ;; gs2-cbind-flag of "p" and MUST be absent
	 *					 ;; for "y" or "n".
	 *
	 * channel-binding = "c=" base64
	 *					 ;; base64 encoding of cbind-input.
	 *
	 * proof		   = "p=" base64
	 *
	 * client-final-message-without-proof =
	 *					 channel-binding "," nonce [","
	 *					 extensions]
	 *
	 * client-final-message =
	 *					 client-final-message-without-proof "," proof
	 *------
	 */

	/*
	 * Read channel binding.  This repeats the channel-binding flags and is
	 * then followed by the actual binding data depending on the type.
	 */
	channel_binding = read_attr_value(&p, 'c');
	if (state->channel_binding_in_use)
	{
#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
		const char *cbind_data = NULL;
		size_t		cbind_data_len = 0;
		size_t		cbind_header_len;
		char	   *cbind_input;
		size_t		cbind_input_len;
		char	   *b64_message;
		int			b64_message_len;

		Assert(state->cbind_flag == 'p');

		/* Fetch hash data of server's SSL certificate */
		cbind_data = be_tls_get_certificate_hash(state->port,
												 &cbind_data_len);

		/* should not happen */
		if (cbind_data == NULL || cbind_data_len == 0)
			elog(ERROR, "could not get server certificate hash");

		/*
		 * Rebuild the cbind-input the client is expected to have encoded:
		 * the fixed GS2 header followed by the raw certificate hash.  The
		 * terminator written by snprintf() is overwritten by the memcpy(),
		 * so cbind_input is a length-counted buffer, not a C string.
		 */
		cbind_header_len = strlen("p=tls-server-end-point,,");	/* p=type,, */
		cbind_input_len = cbind_header_len + cbind_data_len;
		cbind_input = palloc(cbind_input_len);
		snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
		memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);

		/* base64-encode it, leaving room for the terminator added below */
		b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1);
		b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
										b64_message);
		b64_message[b64_message_len] = '\0';

		/*
		 * Compare the value sent by the client with the value expected by the
		 * server.
		 */
		if (strcmp(channel_binding, b64_message) != 0)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
					 (errmsg("SCRAM channel binding check failed"))));
#else
		/* shouldn't happen, because we checked this earlier already */
		elog(ERROR, "channel binding not supported by this build");
#endif
	}
	else
	{
		/*
		 * If we are not using channel binding, the binding data is expected
		 * to always be "biws", which is "n,," base64-encoded, or "eSws",
		 * which is "y,,".  We also have to check whether the flag is the same
		 * one that the client originally sent.
		 */
		if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') &&
			!(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y'))
			ereport(ERROR,
					(errcode(ERRCODE_PROTOCOL_VIOLATION),
					 (errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))));
	}

	state->client_final_nonce = read_attr_value(&p, 'r');

	/*
	 * ignore optional extensions
	 *
	 * Keep reading attributes until the proof ("p") is found.  Before each
	 * read, 'proof' is set to the byte just before the attribute about to be
	 * read -- presumably the comma separating it from the previous attribute
	 * (NOTE(review): depends on how read_attr_value/read_any_attr advance the
	 * pointer; confirm against those helpers).  When the loop ends, 'proof'
	 * therefore marks where client-final-message-without-proof stops.
	 */
	do
	{
		proof = p - 1;
		value = read_any_attr(&p, &attr);
	} while (attr != 'p');

	/* The decoded proof must be exactly SCRAM_KEY_LEN bytes long. */
	client_proof = palloc(pg_b64_dec_len(strlen(value)));
	if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN)
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 errmsg("malformed SCRAM message"),
				 errdetail("Malformed proof in client-final-message.")));
	memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
	pfree(client_proof);

	/* The proof has to be the last attribute of the message. */
	if (*p != '\0')
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 errmsg("malformed SCRAM message"),
				 errdetail("Garbage found at the end of client-final-message.")));

	/*
	 * Save everything up to (not including) the position recorded in 'proof'.
	 * The length is measured in the parsed copy, but the bytes are taken from
	 * the original 'input'; presumably the parsing helpers modify the copy in
	 * place, so only 'input' still holds the text exactly as received.
	 */
	state->client_final_message_without_proof = palloc(proof - begin + 1);
	memcpy(state->client_final_message_without_proof, input, proof - begin);
	state->client_final_message_without_proof[proof - begin] = '\0';
}
1299 
1300 /*
1301  * Build the final server-side message of an exchange.
1302  */
1303 static char *
build_server_final_message(scram_state * state)1304 build_server_final_message(scram_state *state)
1305 {
1306 	uint8		ServerSignature[SCRAM_KEY_LEN];
1307 	char	   *server_signature_base64;
1308 	int			siglen;
1309 	scram_HMAC_ctx ctx;
1310 
1311 	/* calculate ServerSignature */
1312 	scram_HMAC_init(&ctx, state->ServerKey, SCRAM_KEY_LEN);
1313 	scram_HMAC_update(&ctx,
1314 					  state->client_first_message_bare,
1315 					  strlen(state->client_first_message_bare));
1316 	scram_HMAC_update(&ctx, ",", 1);
1317 	scram_HMAC_update(&ctx,
1318 					  state->server_first_message,
1319 					  strlen(state->server_first_message));
1320 	scram_HMAC_update(&ctx, ",", 1);
1321 	scram_HMAC_update(&ctx,
1322 					  state->client_final_message_without_proof,
1323 					  strlen(state->client_final_message_without_proof));
1324 	scram_HMAC_final(ServerSignature, &ctx);
1325 
1326 	server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
1327 	siglen = pg_b64_encode((const char *) ServerSignature,
1328 						   SCRAM_KEY_LEN, server_signature_base64);
1329 	server_signature_base64[siglen] = '\0';
1330 
1331 	/*------
1332 	 * The syntax for the server-final-message is: (RFC 5802)
1333 	 *
1334 	 * verifier		   = "v=" base64
1335 	 *					 ;; base-64 encoded ServerSignature.
1336 	 *
1337 	 * server-final-message = (server-error / verifier)
1338 	 *					 ["," extensions]
1339 	 *
1340 	 *------
1341 	 */
1342 	return psprintf("v=%s", server_signature_base64);
1343 }
1344 
1345 
1346 /*
1347  * Deterministically generate salt for mock authentication, using a SHA256
1348  * hash based on the username and a cluster-level secret key.  Returns a
1349  * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN.
1350  */
1351 static char *
scram_mock_salt(const char * username)1352 scram_mock_salt(const char *username)
1353 {
1354 	pg_sha256_ctx ctx;
1355 	static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
1356 	char	   *mock_auth_nonce = GetMockAuthenticationNonce();
1357 
1358 	/*
1359 	 * Generate salt using a SHA256 hash of the username and the cluster's
1360 	 * mock authentication nonce.  (This works as long as the salt length is
1361 	 * not larger the SHA256 digest length. If the salt is smaller, the caller
1362 	 * will just ignore the extra data.)
1363 	 */
1364 	StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
1365 					 "salt length greater than SHA256 digest length");
1366 
1367 	pg_sha256_init(&ctx);
1368 	pg_sha256_update(&ctx, (uint8 *) username, strlen(username));
1369 	pg_sha256_update(&ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN);
1370 	pg_sha256_final(&ctx, sha_digest);
1371 
1372 	return (char *) sha_digest;
1373 }
1374