1 /*
2  * PgBouncer - Lightweight connection pooler for PostgreSQL.
3  *
4  * Copyright (c) 2007-2009  Marko Kreen, Skype Technologies OÜ
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18 
19 /*
20  * SCRAM support
21  */
22 
#include "bouncer.h"
#include "scram.h"
#include "common/base64.h"
#include "common/saslprep.h"
#include "common/scram-common.h"

#include <limits.h>
28 
29 
/*
 * Forward declaration: build_client_final_message() calls this helper
 * before its definition appears further down in the file.
 */
static bool calculate_client_proof(ScramState *scram_state,
				   const PgUser *user,
				   const char *salt,
				   int saltlen,
				   int iterations,
				   const char *client_final_message_without_proof,
				   uint8_t *result);
37 
38 
39 /*
40  * free SCRAM state info after auth is done
41  */
free_scram_state(ScramState * scram_state)42 void free_scram_state(ScramState *scram_state)
43 {
44 	free(scram_state->client_nonce);
45 	free(scram_state->client_first_message_bare);
46 	free(scram_state->client_final_message_without_proof);
47 	free(scram_state->server_nonce);
48 	free(scram_state->server_first_message);
49 	free(scram_state->SaltedPassword);
50 	free(scram_state->salt);
51 	memset(scram_state, 0, sizeof(*scram_state));
52 }
53 
/*
 * Check that a NUL-terminated string consists only of SCRAM "printable"
 * characters, which RFC 5802 defines as:
 *
 *  printable       = %x21-2B / %x2D-7E
 *                    ;; Printable ASCII except ",".
 *
 * Returns false at the first character outside that set, true otherwise
 * (including for an empty string).
 */
static bool is_scram_printable(char *p)
{
	while (*p != '\0') {
		char ch = *p++;

		if (ch == ',')
			return false;
		if (ch < 0x21 || ch > 0x7E)
			return false;
	}

	return true;
}
71 
/*
 * Format a single character for use in a log message: characters in the
 * range 0x21-0x7E are shown in single quotes, everything else as a
 * two-digit hex byte.
 *
 * The returned string lives in a static buffer, so it is overwritten by
 * the next call.
 */
static char *sanitize_char(char c)
{
	static char buf[5];

	if (c < 0x21 || c > 0x7E) {
		snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
		return buf;
	}

	snprintf(buf, sizeof(buf), "'%c'", c);
	return buf;
}
82 
83 /*
84  * Read value for an attribute part of a SCRAM message.
85  */
read_attr_value(PgSocket * sk,char ** input,char attr)86 static char *read_attr_value(PgSocket *sk, char **input, char attr)
87 {
88 	char *begin = *input;
89 	char *end;
90 
91 	if (*begin != attr)
92 	{
93 		slog_error(sk, "malformed SCRAM message (attribute \"%c\" expected)",
94 			   attr);
95 		return NULL;
96 	}
97 	begin++;
98 
99 	if (*begin != '=')
100 	{
101 		slog_error(sk, "malformed SCRAM message (expected \"=\" after attribute \"%c\")",
102 			   attr);
103 		return NULL;
104 	}
105 	begin++;
106 
107 	end = begin;
108 	while (*end && *end != ',')
109 		end++;
110 
111 	if (*end)
112 	{
113 		*end = '\0';
114 		*input = end + 1;
115 	}
116 	else
117 		*input = end;
118 
119 	return begin;
120 }
121 
122 /*
123  * Read the next attribute and value in a SCRAM exchange message.
124  *
125  * Returns NULL if there is no attribute.
126  */
127 static char *
read_any_attr(PgSocket * sk,char ** input,char * attr_p)128 read_any_attr(PgSocket *sk, char **input, char *attr_p)
129 {
130 	char *begin = *input;
131 	char *end;
132 	char attr = *begin;
133 
134 	if (!((attr >= 'A' && attr <= 'Z') ||
135 	      (attr >= 'a' && attr <= 'z')))
136 	{
137 		slog_error(sk, "malformed SCRAM message (attribute expected, but found invalid character \"%s\")",
138 			   sanitize_char(attr));
139 		return NULL;
140 	}
141 	if (attr_p)
142 		*attr_p = attr;
143 	begin++;
144 
145 	if (*begin != '=')
146 	{
147 		slog_error(sk, "malformed SCRAM message (expected character \"=\" after attribute \"%c\")",
148 			   attr);
149 		return NULL;
150 	}
151 	begin++;
152 
153 	end = begin;
154 	while (*end && *end != ',')
155 		end++;
156 
157 	if (*end)
158 	{
159 		*end = '\0';
160 		*input = end + 1;
161 	}
162 	else
163 		*input = end;
164 
165 	return begin;
166 }
167 
168 /*
169  * Parse and validate format of given SCRAM verifier.
170  *
171  * Returns true if the SCRAM verifier has been parsed, and false otherwise.
172  */
parse_scram_verifier(const char * verifier,int * iterations,char ** salt,uint8_t * stored_key,uint8_t * server_key)173 static bool parse_scram_verifier(const char *verifier, int *iterations, char **salt,
174 				 uint8_t *stored_key, uint8_t *server_key)
175 {
176 	char	   *v;
177 	char	   *p;
178 	char	   *scheme_str;
179 	char	   *salt_str;
180 	char	   *iterations_str;
181 	char	   *storedkey_str;
182 	char	   *serverkey_str;
183 	int			decoded_len;
184 	char	   *decoded_salt_buf;
185 	char	   *decoded_stored_buf = NULL;
186 	char	   *decoded_server_buf = NULL;
187 
188 	/*
189 	 * The verifier is of form:
190 	 *
191 	 * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
192 	 */
193 	v = strdup(verifier);
194 	if (!v)
195 		goto invalid_verifier;
196 	if ((scheme_str = strtok(v, "$")) == NULL)
197 		goto invalid_verifier;
198 	if ((iterations_str = strtok(NULL, ":")) == NULL)
199 		goto invalid_verifier;
200 	if ((salt_str = strtok(NULL, "$")) == NULL)
201 		goto invalid_verifier;
202 	if ((storedkey_str = strtok(NULL, ":")) == NULL)
203 		goto invalid_verifier;
204 	if ((serverkey_str = strtok(NULL, "")) == NULL)
205 		goto invalid_verifier;
206 
207 	/* Parse the fields */
208 	if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
209 		goto invalid_verifier;
210 
211 	errno = 0;
212 	*iterations = strtol(iterations_str, &p, 10);
213 	if (*p || errno != 0)
214 		goto invalid_verifier;
215 
216 	/*
217 	 * Verify that the salt is in Base64-encoded format, by decoding it,
218 	 * although we return the encoded version to the caller.
219 	 */
220 	decoded_salt_buf = malloc(pg_b64_dec_len(strlen(salt_str)));
221 	if (!decoded_salt_buf)
222 		goto invalid_verifier;
223 	decoded_len = pg_b64_decode(salt_str, strlen(salt_str), decoded_salt_buf);
224 	free(decoded_salt_buf);
225 	if (decoded_len < 0)
226 		goto invalid_verifier;
227 	*salt = strdup(salt_str);
228 	if (!*salt)
229 		goto invalid_verifier;
230 
231 	/*
232 	 * Decode StoredKey and ServerKey.
233 	 */
234 	decoded_stored_buf = malloc(pg_b64_dec_len(strlen(storedkey_str)));
235 	if (!decoded_stored_buf)
236 		goto invalid_verifier;
237 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), decoded_stored_buf);
238 	if (decoded_len != SCRAM_KEY_LEN)
239 		goto invalid_verifier;
240 	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
241 
242 	decoded_server_buf = malloc(pg_b64_dec_len(strlen(serverkey_str)));
243 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
244 				    decoded_server_buf);
245 	if (decoded_len != SCRAM_KEY_LEN)
246 		goto invalid_verifier;
247 	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
248 
249 	free(decoded_stored_buf);
250 	free(decoded_server_buf);
251 	free(v);
252 	return true;
253 
254 invalid_verifier:
255 	free(decoded_stored_buf);
256 	free(decoded_server_buf);
257 	free(v);
258 	free(*salt);
259 	*salt = NULL;
260 	return false;
261 }
262 
/* Characters accepted after the "md5" prefix of an MD5 password hash. */
#define MD5_PASSWD_CHARSET "0123456789abcdef"
264 
265 /*
266  * What kind of a password verifier is 'shadow_pass'?
267  */
268 PasswordType
get_password_type(const char * shadow_pass)269 get_password_type(const char *shadow_pass)
270 {
271 	char *encoded_salt = NULL;
272 	int iterations;
273 	uint8_t stored_key[SCRAM_KEY_LEN];
274 	uint8_t server_key[SCRAM_KEY_LEN];
275 
276 	if (strncmp(shadow_pass, "md5", 3) == 0 &&
277 	    strlen(shadow_pass) == MD5_PASSWD_LEN &&
278 	    strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
279 		return PASSWORD_TYPE_MD5;
280 	if (parse_scram_verifier(shadow_pass, &iterations, &encoded_salt,
281 				 stored_key, server_key)) {
282 		free(encoded_salt);
283 		return PASSWORD_TYPE_SCRAM_SHA_256;
284 	}
285 	free(encoded_salt);
286 	return PASSWORD_TYPE_PLAINTEXT;
287 }
288 
289 /*
290  * Functions for communicating as a client with the server
291  */
292 
/*
 * Build the SCRAM client-first-message, "n,,n=,r=<client nonce>".
 *
 * A fresh random nonce is generated, base64-encoded and stored in
 * scram_state->client_nonce.  The part of the message after the "n,,"
 * header is saved in scram_state->client_first_message_bare, because it is
 * needed again when the proof and signature are computed.
 *
 * Returns the message as a newly malloc'ed string, or NULL if out of
 * memory.  On failure both state fields are freed and reset to NULL, so
 * that a later free_scram_state() cannot free them a second time.
 */
char *build_client_first_message(ScramState *scram_state)
{
	uint8_t raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
	int encoded_len;
	size_t len;
	char *result = NULL;

	get_random_bytes(raw_nonce, SCRAM_RAW_NONCE_LEN);

	scram_state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
	if (scram_state->client_nonce == NULL)
		goto failed;
	encoded_len = pg_b64_encode((char *) raw_nonce, SCRAM_RAW_NONCE_LEN, scram_state->client_nonce);
	scram_state->client_nonce[encoded_len] = '\0';

	/* 8 characters for "n,,n=,r=", plus the nonce and the terminator */
	len = 8 + strlen(scram_state->client_nonce) + 1;
	result = malloc(len);
	if (result == NULL)
		goto failed;
	snprintf(result, len, "n,,n=,r=%s", scram_state->client_nonce);

	/* skip the 3-character "n,," header to get the bare message */
	scram_state->client_first_message_bare = strdup(result + 3);
	if (scram_state->client_first_message_bare == NULL)
		goto failed;

	return result;

failed:
	free(result);
	free(scram_state->client_nonce);
	scram_state->client_nonce = NULL;
	free(scram_state->client_first_message_bare);
	scram_state->client_first_message_bare = NULL;
	return NULL;
}
326 
/*
 * Build the SCRAM client-final-message, "c=biws,r=<nonce>,p=<proof>".
 *
 * scram_state   exchange state; must already hold client_first_message_bare
 *               and server_first_message, which go into the proof
 * user          supplies either the stored ClientKey or the password
 * server_nonce  combined nonce as received in the server-first-message
 * salt, saltlen, iterations
 *               key derivation parameters from the server-first-message
 *
 * The message without the proof is saved in
 * scram_state->client_final_message_without_proof.
 *
 * Returns the message as a newly malloc'ed string, or NULL on failure:
 * out of memory, proof calculation failed, or the message does not fit
 * into the fixed-size work buffer.  The nonce comes from the server, so
 * its length must not be trusted.
 */
char *build_client_final_message(ScramState *scram_state,
				 const PgUser *user,
				 const char *server_nonce,
				 const char *salt,
				 int saltlen,
				 int iterations)
{
	char buf[512];
	size_t len;
	int written;
	uint8_t	client_proof[SCRAM_KEY_LEN];
	/* upper bound for the base64 form of the proof: 4 chars per 3 bytes */
	const size_t proof_b64_max = (SCRAM_KEY_LEN + 2) * 4 / 3;

	written = snprintf(buf, sizeof(buf), "c=biws,r=%s", server_nonce);
	if (written < 0 || (size_t) written >= sizeof(buf))
		goto failed;	/* a truncated nonce could never verify */

	scram_state->client_final_message_without_proof = strdup(buf);
	if (scram_state->client_final_message_without_proof == NULL)
		goto failed;

	if (!calculate_client_proof(scram_state, user,
				    salt, saltlen, iterations, buf,
				    client_proof))
		goto failed;

	/*
	 * strlcat() returns the length it tried to create, which is only a
	 * valid offset if nothing was truncated.  Also make sure the encoded
	 * proof and the terminator still fit behind it.
	 */
	len = strlcat(buf, ",p=", sizeof(buf));
	if (len >= sizeof(buf) || sizeof(buf) - len <= proof_b64_max)
		goto failed;
	len += pg_b64_encode((char *) client_proof,
			     SCRAM_KEY_LEN,
			     buf + len);
	buf[len] = '\0';

	return strdup(buf);
failed:
	return NULL;
}
359 
/*
 * Parse the SCRAM server-first-message, "r=<nonce>,s=<salt>,i=<count>".
 *
 * server          server socket; its scram_state keeps a copy of the
 *                 unparsed message and supplies our client nonce
 * input           the message; it is modified in place while parsing
 * server_nonce_p  receives the combined nonce, pointing into 'input'
 * salt_p          receives the decoded salt in a malloc'ed buffer that
 *                 the caller frees
 * saltlen_p       receives the length of the decoded salt
 * iterations_p    receives the iteration count (at least 1)
 *
 * Returns true on success.  Returns false, with the output parameters
 * untouched, if the message is malformed, the nonce does not start with
 * our client nonce, or memory runs out.  Parse errors are logged.
 */
bool read_server_first_message(PgSocket *server, char *input,
			       char **server_nonce_p, char **salt_p, int *saltlen_p, int *iterations_p)
{
	char *server_nonce;
	char *encoded_salt;
	char *salt = NULL;
	int saltlen;
	char *iterations_str;
	char *endptr;
	long iterations_val;

	/* keep the original text: it is part of what gets signed later */
	server->scram_state.server_first_message = strdup(input);
	if (server->scram_state.server_first_message == NULL)
		goto failed;

	server_nonce = read_attr_value(server, &input, 'r');
	if (server_nonce == NULL)
		goto failed;

	/* the server nonce has to be our own nonce with something appended */
	if (strlen(server_nonce) < strlen(server->scram_state.client_nonce) ||
	    memcmp(server_nonce, server->scram_state.client_nonce, strlen(server->scram_state.client_nonce)) != 0)
	{
		slog_error(server, "invalid SCRAM response (nonce mismatch)");
		goto failed;
	}

	encoded_salt = read_attr_value(server, &input, 's');
	if (encoded_salt == NULL)
		goto failed;
	salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
	if (salt == NULL)
		goto failed;
	saltlen = pg_b64_decode(encoded_salt,
				strlen(encoded_salt),
				salt);
	if (saltlen < 0)
	{
		slog_error(server, "malformed SCRAM message (invalid salt)");
		goto failed;
	}

	iterations_str = read_attr_value(server, &input, 'i');
	if (iterations_str == NULL)
		goto failed;

	/*
	 * Parse into a long and range-check before narrowing to int, so that
	 * an oversized count is rejected instead of being silently truncated.
	 */
	errno = 0;
	iterations_val = strtol(iterations_str, &endptr, 10);
	if (errno != 0 || endptr == iterations_str || *endptr != '\0' ||
	    iterations_val < 1 || iterations_val > INT_MAX)
	{
		slog_error(server, "malformed SCRAM message (invalid iteration count)");
		goto failed;
	}

	if (*input != '\0')
	{
		slog_error(server, "malformed SCRAM message (garbage at end of server-first-message)");
		goto failed;
	}

	*server_nonce_p = server_nonce;
	*salt_p = salt;
	*saltlen_p = saltlen;
	*iterations_p = (int) iterations_val;
	return true;
failed:
	free(salt);
	return false;
}
427 
/*
 * Parse the SCRAM server-final-message, which is either "v=<signature>"
 * or, if the server rejected the exchange, "e=<error>".
 *
 * server           server socket, used for logging
 * input            the message; it is modified in place while parsing
 * ServerSignature  receives the decoded signature; must have room for
 *                  SCRAM_KEY_LEN bytes
 *
 * Returns true if a signature of the right length was decoded.  Returns
 * false if the server sent an error, the message is malformed, or memory
 * runs out.  Trailing data after the signature is logged but does not by
 * itself make the call fail.
 */
bool read_server_final_message(PgSocket *server, char *input, char *ServerSignature)
{
	char *encoded_server_signature;
	char *decoded_server_signature = NULL;
	int server_signature_len;

	if (*input == 'e')
	{
		char *errmsg = read_attr_value(server, &input, 'e');

		/*
		 * errmsg is NULL if the attribute itself was malformed;
		 * read_attr_value() has logged that already, and NULL must
		 * not be handed to "%s".
		 */
		if (errmsg != NULL)
			slog_error(server, "error received from server in SCRAM exchange: %s",
				   errmsg);
		goto failed;
	}

	encoded_server_signature = read_attr_value(server, &input, 'v');
	if (encoded_server_signature == NULL)
		goto failed;

	if (*input != '\0')
		slog_error(server, "malformed SCRAM message (garbage at end of server-final-message)");

	server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
	decoded_server_signature = malloc(server_signature_len);
	if (!decoded_server_signature)
		goto failed;

	server_signature_len = pg_b64_decode(encoded_server_signature,
					     strlen(encoded_server_signature),
					     decoded_server_signature);
	if (server_signature_len != SCRAM_KEY_LEN)
	{
		slog_error(server, "malformed SCRAM message (malformed server signature)");
		goto failed;
	}
	memcpy(ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);

	free(decoded_server_signature);
	return true;
failed:
	free(decoded_server_signature);
	return false;
}
470 
/*
 * Compute the ClientProof for the client-final-message.
 *
 * The ClientKey is taken from 'user' if it carries SCRAM keys.  Otherwise
 * it is derived from the user's password with the given salt and iteration
 * count; in that case the SaltedPassword is kept in
 * scram_state->SaltedPassword.  If the password cannot be normalized with
 * SASLprep it is used as it is.
 *
 * The proof, ClientKey XOR ClientSignature, is written to 'result', which
 * must have room for SCRAM_KEY_LEN bytes.  The signature covers
 * client_first_message_bare and server_first_message from 'scram_state'
 * plus 'client_final_message_without_proof', joined by commas.
 *
 * Returns false if memory runs out, true otherwise.
 */
static bool calculate_client_proof(ScramState *scram_state,
				   const PgUser *user,
				   const char *salt,
				   int saltlen,
				   int iterations,
				   const char *client_final_message_without_proof,
				   uint8_t *result)
{
	scram_HMAC_ctx hmac;
	uint8_t client_key[SCRAM_KEY_LEN];
	uint8_t stored_key[SCRAM_KEY_LEN];
	uint8_t client_signature[SCRAM_KEY_LEN];
	char *normalized = NULL;
	bool ok = false;
	int i;

	if (!user->has_scram_keys) {
		pg_saslprep_rc rc = pg_saslprep(user->passwd, &normalized);

		if (rc == SASLPREP_OOM)
			return false;
		if (rc != SASLPREP_SUCCESS) {
			/* not normalizable: fall back to the raw password */
			normalized = strdup(user->passwd);
			if (normalized == NULL)
				return false;
		}

		scram_state->SaltedPassword = malloc(SCRAM_KEY_LEN);
		if (scram_state->SaltedPassword == NULL)
			goto done;
		scram_SaltedPassword(normalized,
				     salt,
				     saltlen,
				     iterations,
				     scram_state->SaltedPassword);

		scram_ClientKey(scram_state->SaltedPassword, client_key);
	} else {
		memcpy(client_key, user->scram_ClientKey, SCRAM_KEY_LEN);
	}

	scram_H(client_key, SCRAM_KEY_LEN, stored_key);

	/* ClientSignature = HMAC(StoredKey, AuthMessage) */
	scram_HMAC_init(&hmac, stored_key, SCRAM_KEY_LEN);
	scram_HMAC_update(&hmac,
			  scram_state->client_first_message_bare,
			  strlen(scram_state->client_first_message_bare));
	scram_HMAC_update(&hmac, ",", 1);
	scram_HMAC_update(&hmac,
			  scram_state->server_first_message,
			  strlen(scram_state->server_first_message));
	scram_HMAC_update(&hmac, ",", 1);
	scram_HMAC_update(&hmac,
			  client_final_message_without_proof,
			  strlen(client_final_message_without_proof));
	scram_HMAC_final(client_signature, &hmac);

	for (i = 0; i < SCRAM_KEY_LEN; i++)
		result[i] = client_key[i] ^ client_signature[i];

	ok = true;
done:
	free(normalized);
	return ok;
}
539 
/*
 * Check the signature the server sent in its final message.
 *
 * The ServerKey comes from 'user' if it carries SCRAM keys, otherwise it
 * is derived from scram_state->SaltedPassword.  The expected signature is
 * the HMAC, keyed with the ServerKey, over client_first_message_bare,
 * server_first_message and client_final_message_without_proof from
 * 'scram_state', joined by commas.
 *
 * Returns true if the first SCRAM_KEY_LEN bytes of 'ServerSignature'
 * equal the expected signature.
 */
bool verify_server_signature(ScramState *scram_state, const PgUser *user, const char *ServerSignature)
{
	scram_HMAC_ctx hmac;
	uint8_t server_key[SCRAM_KEY_LEN];
	uint8_t expected[SCRAM_KEY_LEN];

	if (!user->has_scram_keys)
		scram_ServerKey(scram_state->SaltedPassword, server_key);
	else
		memcpy(server_key, user->scram_ServerKey, SCRAM_KEY_LEN);

	scram_HMAC_init(&hmac, server_key, SCRAM_KEY_LEN);
	scram_HMAC_update(&hmac,
			  scram_state->client_first_message_bare,
			  strlen(scram_state->client_first_message_bare));
	scram_HMAC_update(&hmac, ",", 1);
	scram_HMAC_update(&hmac,
			  scram_state->server_first_message,
			  strlen(scram_state->server_first_message));
	scram_HMAC_update(&hmac, ",", 1);
	scram_HMAC_update(&hmac,
			  scram_state->client_final_message_without_proof,
			  strlen(scram_state->client_final_message_without_proof));
	scram_HMAC_final(expected, &hmac);

	return memcmp(expected, ServerSignature, SCRAM_KEY_LEN) == 0;
}
570 
571 
572 /*
573  * Functions for communicating as a server to the client
574  */
575 
/*
 * Parse the SCRAM client-first-message,
 * "<cbind-flag>,[authzid],n=<user>,r=<nonce>[,extensions]".
 *
 * client                       client socket, used for logging
 * input                        the message; modified in place while parsing
 * cbind_flag_p                 receives the first character of the message,
 *                              even if the call then fails
 * client_first_message_bare_p  receives a malloc'ed copy of the message
 *                              without the channel-binding header
 * client_nonce_p               receives a malloc'ed copy of the client nonce
 *
 * Channel binding ("p"), an authorization identity and mandatory
 * extensions ("m") are not supported and are rejected.  The user name is
 * read but its value is not used.  Optional extensions after the nonce
 * are skipped.
 *
 * Returns true on success; the caller then owns both returned strings.
 * Returns false, after logging the reason where there is one, on a
 * malformed or unsupported message or if memory runs out.
 */
bool read_client_first_message(PgSocket *client, char *input,
			       char *cbind_flag_p,
			       char **client_first_message_bare_p,
			       char **client_nonce_p)
{
	char *client_first_message_bare = NULL;
	char *client_nonce = NULL;
	char *client_nonce_copy = NULL;

	*cbind_flag_p = *input;
	switch (*input) {
	case 'n':
		/* Client does not support channel binding */
		input++;
		break;
	case 'y':
		/* Client supports channel binding, but we're not doing it today */
		input++;
		break;
	case 'p':
		/* Client requires channel binding.  We don't support it. */
		slog_error(client, "client requires SCRAM channel binding, but it is not supported");
		goto failed;
	default:
		slog_error(client, "malformed SCRAM message (unexpected channel-binding flag \"%s\")",
			   sanitize_char(*input));
		goto failed;
	}

	if (*input != ',') {
		slog_error(client, "malformed SCRAM message (comma expected, but found character \"%s\")",
			   sanitize_char(*input));
		goto failed;
	}
	input++;

	if (*input == 'a') {
		slog_error(client, "client uses authorization identity, but it is not supported");
		goto failed;
	}
	if (*input != ',') {
		slog_error(client, "malformed SCRAM message (unexpected attribute \"%s\" in client-first-message)",
			   sanitize_char(*input));
		goto failed;
	}
	input++;

	/* copy before parsing, because parsing overwrites the separators */
	client_first_message_bare = strdup(input);
	if (client_first_message_bare == NULL)
		goto failed;

	if (*input == 'm') {
		slog_error(client, "client requires an unsupported SCRAM extension");
		goto failed;
	}

	/*
	 * Read and ignore the user name.  The attribute itself is required
	 * though: without this check a message lacking "n=" would log an
	 * error and then be accepted anyway.
	 */
	if (read_attr_value(client, &input, 'n') == NULL)
		goto failed;

	client_nonce = read_attr_value(client, &input, 'r');
	if (client_nonce == NULL)
		goto failed;
	if (!is_scram_printable(client_nonce)) {
		slog_error(client, "non-printable characters in SCRAM nonce");
		goto failed;
	}
	client_nonce_copy = strdup(client_nonce);
	if (client_nonce_copy == NULL)
		goto failed;

	/*
	 * There can be any number of optional extensions after this.  We don't
	 * support any extensions, so ignore them.
	 */
	while (*input != '\0') {
		if (!read_any_attr(client, &input, NULL))
			goto failed;
	}

	*client_first_message_bare_p = client_first_message_bare;
	*client_nonce_p = client_nonce_copy;
	return true;
failed:
	free(client_first_message_bare);
	free(client_nonce_copy);
	return false;
}
663 
/*
 * Parse the SCRAM client-final-message,
 * "c=<channel binding>,r=<nonce>[,extensions],p=<proof>".
 *
 * client                client socket; its scram_state supplies the
 *                       channel-binding flag from the first message and
 *                       receives client_final_message_without_proof
 * raw_input             bytes the message-without-proof is copied from;
 *                       presumably an untouched copy of 'input', since
 *                       parsing overwrites separators in 'input' -- confirm
 *                       against the caller
 * input                 the message; modified in place while parsing
 * client_final_nonce_p  receives the nonce, pointing into 'input'
 * proof_p               receives the decoded proof (SCRAM_KEY_LEN bytes)
 *                       in a malloc'ed buffer that the caller frees
 *
 * Returns true on success.  Returns false, after logging the reason where
 * there is one, if the message is malformed, the channel-binding data does
 * not match the flag sent earlier, the proof is not SCRAM_KEY_LEN bytes
 * long, or memory runs out.
 */
bool read_client_final_message(PgSocket *client, const uint8_t *raw_input, char *input,
			       const char **client_final_nonce_p,
			       char **proof_p)
{
	const char *input_start = input;
	char attr;
	char *channel_binding;
	char *client_final_nonce;
	char *proof_start;
	char *value;
	char *encoded_proof;
	char *proof = NULL;
	int prooflen;

	/*
	 * Read channel-binding.  We don't support channel binding, so
	 * it's expected to always be "biws", which is "n,,",
	 * base64-encoded, or "eSws", which is "y,,".  We also have to
	 * check whether the flag is the same one that the client
	 * originally sent.
	 */
	channel_binding = read_attr_value(client, &input, 'c');
	if (channel_binding == NULL)
		goto failed;
	if (!(strcmp(channel_binding, "biws") == 0 && client->scram_state.cbind_flag == 'n') &&
	    !(strcmp(channel_binding, "eSws") == 0 && client->scram_state.cbind_flag == 'y')) {
		slog_error(client, "unexpected SCRAM channel-binding attribute in client-final-message");
		goto failed;
	}

	/* a missing nonce must fail here, not hand NULL on to the caller */
	client_final_nonce = read_attr_value(client, &input, 'r');
	if (client_final_nonce == NULL)
		goto failed;

	/*
	 * Ignore optional extensions.  proof_start ends up on the separator
	 * in front of the "p" attribute, i.e. at the end of the
	 * message-without-proof.
	 */
	do
	{
		proof_start = input - 1;
		value = read_any_attr(client, &input, &attr);
	} while (value && attr != 'p');

	if (!value) {
		slog_error(client, "could not read proof");
		goto failed;
	}

	encoded_proof = value;

	proof = malloc(pg_b64_dec_len(strlen(encoded_proof)));
	if (proof == NULL) {
		slog_error(client, "could not decode proof");
		goto failed;
	}
	prooflen = pg_b64_decode(encoded_proof,
				 strlen(encoded_proof),
				 proof);
	/*
	 * The proof is later consumed as exactly SCRAM_KEY_LEN bytes, so any
	 * other length (including a decoding error) has to be rejected.
	 */
	if (prooflen != SCRAM_KEY_LEN) {
		slog_error(client, "malformed SCRAM message (malformed proof in client-final-message)");
		goto failed;
	}

	if (*input != '\0') {
		slog_error(client, "malformed SCRAM message (garbage at the end of client-final-message)");
		goto failed;
	}

	client->scram_state.client_final_message_without_proof = malloc(proof_start - input_start + 1);
	if (!client->scram_state.client_final_message_without_proof)
		goto failed;
	memcpy(client->scram_state.client_final_message_without_proof, raw_input, proof_start - input_start);
	client->scram_state.client_final_message_without_proof[proof_start - input_start] = '\0';

	*client_final_nonce_p = client_final_nonce;
	*proof_p = proof;
	return true;
failed:
	free(proof);
	return false;
}
738 
739 /*
740  * For doing SCRAM with a password stored in plain text, build a SCRAM
741  * secret on the fly.
742  */
build_adhoc_scram_secret(const char * plain_password,ScramState * scram_state)743 static bool build_adhoc_scram_secret(const char *plain_password, ScramState *scram_state)
744 {
745 	const char *password;
746 	char *prep_password;
747 	pg_saslprep_rc rc;
748 	char saltbuf[SCRAM_DEFAULT_SALT_LEN];
749 	int encoded_len;
750 	uint8_t salted_password[SCRAM_KEY_LEN];
751 
752 	rc = pg_saslprep(plain_password, &prep_password);
753 	if (rc == SASLPREP_OOM)
754 		goto failed;
755 	else if (rc == SASLPREP_SUCCESS)
756 		password = prep_password;
757 	else
758 		password = plain_password;
759 
760 	get_random_bytes((uint8_t *) saltbuf, sizeof(saltbuf));
761 
762 	scram_state->adhoc = true;
763 
764 	scram_state->iterations = SCRAM_DEFAULT_ITERATIONS;
765 
766 	scram_state->salt = malloc(pg_b64_enc_len(sizeof(saltbuf)) + 1);
767 	if (!scram_state->salt)
768 		goto failed;
769 	encoded_len = pg_b64_encode(saltbuf, sizeof(saltbuf), scram_state->salt);
770 	scram_state->salt[encoded_len] = '\0';
771 
772 	/* Calculate StoredKey and ServerKey */
773 	scram_SaltedPassword(password, saltbuf, sizeof(saltbuf),
774 			     scram_state->iterations,
775 			     salted_password);
776 	scram_ClientKey(salted_password, scram_state->StoredKey);
777 	scram_H(scram_state->StoredKey, SCRAM_KEY_LEN, scram_state->StoredKey);
778 	scram_ServerKey(salted_password, scram_state->ServerKey);
779 
780 	if (prep_password)
781 		free(prep_password);
782 	return true;
783 failed:
784 	if (prep_password)
785 		free(prep_password);
786 	return false;
787 }
788 
789 /*
790  * Deterministically generate salt for mock authentication, using a
791  * SHA256 hash based on the username and an instance-level secret key.
792  * Target buffer needs to be of size SCRAM_DEFAULT_SALT_LEN.
793  */
scram_mock_salt(const char * username,uint8_t * saltbuf)794 static void scram_mock_salt(const char *username, uint8_t *saltbuf)
795 {
796 	static uint8_t mock_auth_nonce[32];
797 	static bool mock_auth_nonce_initialized = false;
798 	struct sha256_ctx ctx;
799 	uint8_t sha_digest[PG_SHA256_DIGEST_LENGTH];
800 
801 	/*
802 	 * Generating salt using a SHA256 hash works as long as the
803 	 * required salt length is not larger than the SHA256 digest
804 	 * length.
805 	 */
806 	static_assert(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
807 		      "salt length greater than SHA256 digest length");
808 
809 	if (!mock_auth_nonce_initialized) {
810 		get_random_bytes(mock_auth_nonce, sizeof(mock_auth_nonce));
811 		mock_auth_nonce_initialized = true;
812 	}
813 
814 	sha256_reset(&ctx);
815 	sha256_update(&ctx, (uint8_t *) username, strlen(username));
816 	sha256_update(&ctx, mock_auth_nonce, sizeof(mock_auth_nonce));
817 	sha256_final(&ctx, sha_digest);
818 
819 	memcpy(saltbuf, sha_digest, SCRAM_DEFAULT_SALT_LEN);
820 }
821 
/*
 * Fill in the salt and iteration count of a SCRAM state for a user that
 * has no stored secret, using the mock salt for 'username' and
 * SCRAM_DEFAULT_ITERATIONS.  StoredKey and ServerKey are not touched.
 *
 * Returns false if the salt buffer cannot be allocated, true otherwise.
 */
static bool build_mock_scram_secret(const char *username, ScramState *scram_state)
{
	uint8_t mock_salt[SCRAM_DEFAULT_SALT_LEN];
	char *encoded;
	int n;

	scram_state->iterations = SCRAM_DEFAULT_ITERATIONS;
	scram_mock_salt(username, mock_salt);

	/* the state takes the salt in base64 form */
	encoded = malloc(pg_b64_enc_len(sizeof(mock_salt)) + 1);
	scram_state->salt = encoded;
	if (encoded == NULL)
		return false;

	n = pg_b64_encode((char *) mock_salt, sizeof(mock_salt), encoded);
	encoded[n] = '\0';

	return true;
}
840 
/*
 * Build the SCRAM server-first-message,
 * "r=<client nonce><server nonce>,s=<salt>,i=<iterations>".
 *
 * scram_state    exchange state; client_nonce must already be set
 * username       used to derive a mock salt when there is no secret
 * stored_secret  the user's stored password, or NULL if there is none
 *
 * Salt, iteration count and keys come from the stored secret if it is a
 * SCRAM verifier, are built on the fly if it is a plaintext password, and
 * are mocked if there is no secret at all.  Any other kind of secret makes
 * the call fail.  A fresh server nonce is generated and stored in the
 * state.
 *
 * Returns the message, which is also stored in
 * scram_state->server_first_message, or NULL on failure.  On failure the
 * nonce and message fields of the state are freed and reset to NULL, so
 * that a later free_scram_state() cannot free them a second time.
 */
char *build_server_first_message(ScramState *scram_state, const char *username, const char *stored_secret)
{
	uint8_t raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
	int encoded_len;
	size_t len;
	char *result;

	if (!stored_secret) {
		if (!build_mock_scram_secret(username, scram_state))
			goto failed;
	} else {
		switch (get_password_type(stored_secret)) {
		case PASSWORD_TYPE_SCRAM_SHA_256:
			if (!parse_scram_verifier(stored_secret,
						  &scram_state->iterations,
						  &scram_state->salt,
						  scram_state->StoredKey,
						  scram_state->ServerKey))
				goto failed;
			break;
		case PASSWORD_TYPE_PLAINTEXT:
			if (!build_adhoc_scram_secret(stored_secret, scram_state))
				goto failed;
			break;
		default:
			/* shouldn't get here */
			goto failed;
		}
	}

	get_random_bytes(raw_nonce, SCRAM_RAW_NONCE_LEN);
	scram_state->server_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
	if (scram_state->server_nonce == NULL)
		goto failed;
	encoded_len = pg_b64_encode((char *) raw_nonce, SCRAM_RAW_NONCE_LEN, scram_state->server_nonce);
	scram_state->server_nonce[encoded_len] = '\0';

	/* "r=" + nonces + ",s=" + salt + ",i=" + up to 10 digits + NUL */
	len = (2
	       + strlen(scram_state->client_nonce)
	       + strlen(scram_state->server_nonce)
	       + 3
	       + strlen(scram_state->salt)
	       + 3 + 10 + 1);
	result = malloc(len);
	if (!result)
		goto failed;
	snprintf(result, len,
		 "r=%s%s,s=%s,i=%u",
		 scram_state->client_nonce,
		 scram_state->server_nonce,
		 scram_state->salt,
		 (unsigned int) scram_state->iterations);

	scram_state->server_first_message = result;

	return result;
failed:
	free(scram_state->server_nonce);
	scram_state->server_nonce = NULL;
	free(scram_state->server_first_message);
	scram_state->server_first_message = NULL;
	return NULL;
}
902 
903 static char *
compute_server_signature(ScramState * state)904 compute_server_signature(ScramState *state)
905 {
906 	uint8_t		ServerSignature[SCRAM_KEY_LEN];
907 	char	   *server_signature_base64;
908 	int			siglen;
909 	scram_HMAC_ctx ctx;
910 
911 	/* calculate ServerSignature */
912 	scram_HMAC_init(&ctx, state->ServerKey, SCRAM_KEY_LEN);
913 	scram_HMAC_update(&ctx,
914 			  state->client_first_message_bare,
915 			  strlen(state->client_first_message_bare));
916 	scram_HMAC_update(&ctx, ",", 1);
917 	scram_HMAC_update(&ctx,
918 			  state->server_first_message,
919 			  strlen(state->server_first_message));
920 	scram_HMAC_update(&ctx, ",", 1);
921 	scram_HMAC_update(&ctx,
922 			  state->client_final_message_without_proof,
923 			  strlen(state->client_final_message_without_proof));
924 	scram_HMAC_final(ServerSignature, &ctx);
925 
926 	server_signature_base64 = malloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
927 	if (!server_signature_base64)
928 		return NULL;
929 	siglen = pg_b64_encode((const char *) ServerSignature,
930 						   SCRAM_KEY_LEN, server_signature_base64);
931 	server_signature_base64[siglen] = '\0';
932 
933 	return server_signature_base64;
934 }
935 
/*
 * Build the SCRAM server-final-message, "v=<server signature>", for the
 * exchange recorded in 'scram_state'.
 *
 * Returns the message as a newly malloc'ed string, or NULL if out of
 * memory.
 */
char *build_server_final_message(ScramState *scram_state)
{
	char *signature = compute_server_signature(scram_state);
	char *message;
	size_t size;

	if (signature == NULL)
		return NULL;

	/* "v=" + signature + terminator */
	size = 2 + strlen(signature) + 1;
	message = malloc(size);
	if (message != NULL)
		snprintf(message, size, "v=%s", signature);

	free(signature);
	return message;
}
958 
verify_final_nonce(const ScramState * scram_state,const char * client_final_nonce)959 bool verify_final_nonce(const ScramState *scram_state, const char *client_final_nonce)
960 {
961 	size_t client_nonce_len = strlen(scram_state->client_nonce);
962 	size_t server_nonce_len = strlen(scram_state->server_nonce);
963 	size_t final_nonce_len = strlen(client_final_nonce);
964 
965 	if (final_nonce_len != client_nonce_len + server_nonce_len)
966 		return false;
967 	if (memcmp(client_final_nonce, scram_state->client_nonce, client_nonce_len) != 0)
968 		return false;
969 	if (memcmp(client_final_nonce + client_nonce_len, scram_state->server_nonce, server_nonce_len) != 0)
970 		return false;
971 
972 	return true;
973 }
974 
/*
 * Check the proof the client sent in its final message.
 *
 * The ClientSignature is the HMAC, keyed with state->StoredKey, over
 * client_first_message_bare, server_first_message and
 * client_final_message_without_proof, joined by commas.  XOR-ing it with
 * the proof gives the ClientKey the client used, which is stored in
 * state->ClientKey.  The proof is valid if hashing that key gives the
 * StoredKey.
 *
 * 'ClientProof' must point to SCRAM_KEY_LEN bytes.
 */
bool verify_client_proof(ScramState *state, const char *ClientProof)
{
	scram_HMAC_ctx hmac;
	uint8_t signature[SCRAM_KEY_LEN];
	uint8_t rehashed_key[SCRAM_KEY_LEN];
	int pos;

	scram_HMAC_init(&hmac, state->StoredKey, SCRAM_KEY_LEN);
	scram_HMAC_update(&hmac,
			  state->client_first_message_bare,
			  strlen(state->client_first_message_bare));
	scram_HMAC_update(&hmac, ",", 1);
	scram_HMAC_update(&hmac,
			  state->server_first_message,
			  strlen(state->server_first_message));
	scram_HMAC_update(&hmac, ",", 1);
	scram_HMAC_update(&hmac,
			  state->client_final_message_without_proof,
			  strlen(state->client_final_message_without_proof));
	scram_HMAC_final(signature, &hmac);

	/* recover the ClientKey the client must have used */
	for (pos = 0; pos < SCRAM_KEY_LEN; pos++)
		state->ClientKey[pos] = ClientProof[pos] ^ signature[pos];

	/* its hash has to be the StoredKey */
	scram_H(state->ClientKey, SCRAM_KEY_LEN, rehashed_key);

	return memcmp(rehashed_key, state->StoredKey, SCRAM_KEY_LEN) == 0;
}
1009 
1010 /*
1011  * Verify a plaintext password against a SCRAM verifier.  This is used when
1012  * performing plaintext password authentication for a user that has a SCRAM
1013  * verifier stored in pg_authid.
1014  */
1015 bool
scram_verify_plain_password(PgSocket * client,const char * username,const char * password,const char * verifier)1016 scram_verify_plain_password(PgSocket *client,
1017 			    const char *username, const char *password,
1018 			    const char *verifier)
1019 {
1020 	char *encoded_salt = NULL;
1021 	char *salt = NULL;
1022 	int saltlen;
1023 	int iterations;
1024 	uint8_t salted_password[SCRAM_KEY_LEN];
1025 	uint8_t stored_key[SCRAM_KEY_LEN];
1026 	uint8_t server_key[SCRAM_KEY_LEN];
1027 	uint8_t computed_key[SCRAM_KEY_LEN];
1028 	char *prep_password = NULL;
1029 	pg_saslprep_rc rc;
1030 	bool result = false;
1031 
1032 	if (!parse_scram_verifier(verifier, &iterations, &encoded_salt,
1033 				  stored_key, server_key))
1034 	{
1035 		/* The password looked like a SCRAM verifier, but could not be parsed. */
1036 		slog_warning(client, "invalid SCRAM verifier for user \"%s\"", username);
1037 		goto failed;
1038 	}
1039 
1040 	salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
1041 	if (!salt)
1042 		goto failed;
1043 	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
1044 	if (saltlen == -1)
1045 	{
1046 		slog_warning(client, "invalid SCRAM verifier for user \"%s\"", username);
1047 		goto failed;
1048 	}
1049 
1050 	/* Normalize the password */
1051 	rc = pg_saslprep(password, &prep_password);
1052 	if (rc == SASLPREP_SUCCESS)
1053 		password = prep_password;
1054 
1055 	/* Compute Server Key based on the user-supplied plaintext password */
1056 	scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
1057 	scram_ServerKey(salted_password, computed_key);
1058 
1059 	/*
1060 	 * Compare the verifier's Server Key with the one computed from the
1061 	 * user-supplied password.
1062 	 */
1063 	result = memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
1064 
1065 failed:
1066 	free(encoded_salt);
1067 	free(salt);
1068 	free(prep_password);
1069 	return result;
1070 }
1071