1 /*-------------------------------------------------------------------------
2  *
3  * fe-auth-scram.c
4  *	   The front-end (client) implementation of SCRAM authentication.
5  *
6  * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
7  * Portions Copyright (c) 1994, Regents of the University of California
8  *
9  * IDENTIFICATION
10  *	  src/interfaces/libpq/fe-auth-scram.c
11  *
12  *-------------------------------------------------------------------------
13  */
14 
#include "postgres_fe.h"

#include <limits.h>

#include "common/base64.h"
#include "common/hmac.h"
#include "common/saslprep.h"
#include "common/scram-common.h"
#include "fe-auth.h"
22 
23 
/*
 * Status of exchange messages used for SCRAM authentication via the
 * SASL protocol.  Each value records the last step this client completed.
 */
typedef enum
{
	FE_SCRAM_INIT = 0,			/* nothing sent yet */
	FE_SCRAM_NONCE_SENT,		/* client-first-message (nonce) sent */
	FE_SCRAM_PROOF_SENT,		/* client-final-message (proof) sent */
	FE_SCRAM_FINISHED			/* server-final-message processed */
} fe_scram_state_enum;
35 
36 typedef struct
37 {
38 	fe_scram_state_enum state;
39 
40 	/* These are supplied by the user */
41 	PGconn	   *conn;
42 	char	   *password;
43 	char	   *sasl_mechanism;
44 
45 	/* We construct these */
46 	uint8		SaltedPassword[SCRAM_KEY_LEN];
47 	char	   *client_nonce;
48 	char	   *client_first_message_bare;
49 	char	   *client_final_message_without_proof;
50 
51 	/* These come from the server-first message */
52 	char	   *server_first_message;
53 	char	   *salt;
54 	int			saltlen;
55 	int			iterations;
56 	char	   *nonce;
57 
58 	/* These come from the server-final message */
59 	char	   *server_final_message;
60 	char		ServerSignature[SCRAM_KEY_LEN];
61 } fe_scram_state;
62 
63 static bool read_server_first_message(fe_scram_state *state, char *input);
64 static bool read_server_final_message(fe_scram_state *state, char *input);
65 static char *build_client_first_message(fe_scram_state *state);
66 static char *build_client_final_message(fe_scram_state *state);
67 static bool verify_server_signature(fe_scram_state *state, bool *match);
68 static bool calculate_client_proof(fe_scram_state *state,
69 								   const char *client_final_message_without_proof,
70 								   uint8 *result);
71 
72 /*
73  * Initialize SCRAM exchange status.
74  */
75 void *
pg_fe_scram_init(PGconn * conn,const char * password,const char * sasl_mechanism)76 pg_fe_scram_init(PGconn *conn,
77 				 const char *password,
78 				 const char *sasl_mechanism)
79 {
80 	fe_scram_state *state;
81 	char	   *prep_password;
82 	pg_saslprep_rc rc;
83 
84 	Assert(sasl_mechanism != NULL);
85 
86 	state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
87 	if (!state)
88 		return NULL;
89 	memset(state, 0, sizeof(fe_scram_state));
90 	state->conn = conn;
91 	state->state = FE_SCRAM_INIT;
92 	state->sasl_mechanism = strdup(sasl_mechanism);
93 
94 	if (!state->sasl_mechanism)
95 	{
96 		free(state);
97 		return NULL;
98 	}
99 
100 	/* Normalize the password with SASLprep, if possible */
101 	rc = pg_saslprep(password, &prep_password);
102 	if (rc == SASLPREP_OOM)
103 	{
104 		free(state->sasl_mechanism);
105 		free(state);
106 		return NULL;
107 	}
108 	if (rc != SASLPREP_SUCCESS)
109 	{
110 		prep_password = strdup(password);
111 		if (!prep_password)
112 		{
113 			free(state->sasl_mechanism);
114 			free(state);
115 			return NULL;
116 		}
117 	}
118 	state->password = prep_password;
119 
120 	return state;
121 }
122 
123 /*
124  * Return true if channel binding was employed and the SCRAM exchange
125  * completed. This should be used after a successful exchange to determine
126  * whether the server authenticated itself to the client.
127  *
128  * Note that the caller must also ensure that the exchange was actually
129  * successful.
130  */
131 bool
pg_fe_scram_channel_bound(void * opaq)132 pg_fe_scram_channel_bound(void *opaq)
133 {
134 	fe_scram_state *state = (fe_scram_state *) opaq;
135 
136 	/* no SCRAM exchange done */
137 	if (state == NULL)
138 		return false;
139 
140 	/* SCRAM exchange not completed */
141 	if (state->state != FE_SCRAM_FINISHED)
142 		return false;
143 
144 	/* channel binding mechanism not used */
145 	if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) != 0)
146 		return false;
147 
148 	/* all clear! */
149 	return true;
150 }
151 
152 /*
153  * Free SCRAM exchange status
154  */
155 void
pg_fe_scram_free(void * opaq)156 pg_fe_scram_free(void *opaq)
157 {
158 	fe_scram_state *state = (fe_scram_state *) opaq;
159 
160 	if (state->password)
161 		free(state->password);
162 	if (state->sasl_mechanism)
163 		free(state->sasl_mechanism);
164 
165 	/* client messages */
166 	if (state->client_nonce)
167 		free(state->client_nonce);
168 	if (state->client_first_message_bare)
169 		free(state->client_first_message_bare);
170 	if (state->client_final_message_without_proof)
171 		free(state->client_final_message_without_proof);
172 
173 	/* first message from server */
174 	if (state->server_first_message)
175 		free(state->server_first_message);
176 	if (state->salt)
177 		free(state->salt);
178 	if (state->nonce)
179 		free(state->nonce);
180 
181 	/* final message from server */
182 	if (state->server_final_message)
183 		free(state->server_final_message);
184 
185 	free(state);
186 }
187 
188 /*
189  * Exchange a SCRAM message with backend.
190  */
191 void
pg_fe_scram_exchange(void * opaq,char * input,int inputlen,char ** output,int * outputlen,bool * done,bool * success)192 pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
193 					 char **output, int *outputlen,
194 					 bool *done, bool *success)
195 {
196 	fe_scram_state *state = (fe_scram_state *) opaq;
197 	PGconn	   *conn = state->conn;
198 
199 	*done = false;
200 	*success = false;
201 	*output = NULL;
202 	*outputlen = 0;
203 
204 	/*
205 	 * Check that the input length agrees with the string length of the input.
206 	 * We can ignore inputlen after this.
207 	 */
208 	if (state->state != FE_SCRAM_INIT)
209 	{
210 		if (inputlen == 0)
211 		{
212 			appendPQExpBufferStr(&conn->errorMessage,
213 								 libpq_gettext("malformed SCRAM message (empty message)\n"));
214 			goto error;
215 		}
216 		if (inputlen != strlen(input))
217 		{
218 			appendPQExpBufferStr(&conn->errorMessage,
219 								 libpq_gettext("malformed SCRAM message (length mismatch)\n"));
220 			goto error;
221 		}
222 	}
223 
224 	switch (state->state)
225 	{
226 		case FE_SCRAM_INIT:
227 			/* Begin the SCRAM handshake, by sending client nonce */
228 			*output = build_client_first_message(state);
229 			if (*output == NULL)
230 				goto error;
231 
232 			*outputlen = strlen(*output);
233 			*done = false;
234 			state->state = FE_SCRAM_NONCE_SENT;
235 			break;
236 
237 		case FE_SCRAM_NONCE_SENT:
238 			/* Receive salt and server nonce, send response. */
239 			if (!read_server_first_message(state, input))
240 				goto error;
241 
242 			*output = build_client_final_message(state);
243 			if (*output == NULL)
244 				goto error;
245 
246 			*outputlen = strlen(*output);
247 			*done = false;
248 			state->state = FE_SCRAM_PROOF_SENT;
249 			break;
250 
251 		case FE_SCRAM_PROOF_SENT:
252 			/* Receive server signature */
253 			if (!read_server_final_message(state, input))
254 				goto error;
255 
256 			/*
257 			 * Verify server signature, to make sure we're talking to the
258 			 * genuine server.
259 			 */
260 			if (!verify_server_signature(state, success))
261 			{
262 				appendPQExpBufferStr(&conn->errorMessage,
263 									 libpq_gettext("could not verify server signature\n"));
264 				goto error;
265 			}
266 
267 			if (!*success)
268 			{
269 				appendPQExpBufferStr(&conn->errorMessage,
270 									 libpq_gettext("incorrect server signature\n"));
271 			}
272 			*done = true;
273 			state->state = FE_SCRAM_FINISHED;
274 			break;
275 
276 		default:
277 			/* shouldn't happen */
278 			appendPQExpBufferStr(&conn->errorMessage,
279 								 libpq_gettext("invalid SCRAM exchange state\n"));
280 			goto error;
281 	}
282 	return;
283 
284 error:
285 	*done = true;
286 	*success = false;
287 }
288 
289 /*
290  * Read value for an attribute part of a SCRAM message.
291  *
292  * The buffer at **input is destructively modified, and *input is
293  * advanced over the "attr=value" string and any following comma.
294  *
295  * On failure, append an error message to *errorMessage and return NULL.
296  */
297 static char *
read_attr_value(char ** input,char attr,PQExpBuffer errorMessage)298 read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
299 {
300 	char	   *begin = *input;
301 	char	   *end;
302 
303 	if (*begin != attr)
304 	{
305 		appendPQExpBuffer(errorMessage,
306 						  libpq_gettext("malformed SCRAM message (attribute \"%c\" expected)\n"),
307 						  attr);
308 		return NULL;
309 	}
310 	begin++;
311 
312 	if (*begin != '=')
313 	{
314 		appendPQExpBuffer(errorMessage,
315 						  libpq_gettext("malformed SCRAM message (expected character \"=\" for attribute \"%c\")\n"),
316 						  attr);
317 		return NULL;
318 	}
319 	begin++;
320 
321 	end = begin;
322 	while (*end && *end != ',')
323 		end++;
324 
325 	if (*end)
326 	{
327 		*end = '\0';
328 		*input = end + 1;
329 	}
330 	else
331 		*input = end;
332 
333 	return begin;
334 }
335 
336 /*
337  * Build the first exchange message sent by the client.
338  */
339 static char *
build_client_first_message(fe_scram_state * state)340 build_client_first_message(fe_scram_state *state)
341 {
342 	PGconn	   *conn = state->conn;
343 	char		raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
344 	char	   *result;
345 	int			channel_info_len;
346 	int			encoded_len;
347 	PQExpBufferData buf;
348 
349 	/*
350 	 * Generate a "raw" nonce.  This is converted to ASCII-printable form by
351 	 * base64-encoding it.
352 	 */
353 	if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
354 	{
355 		appendPQExpBufferStr(&conn->errorMessage,
356 							 libpq_gettext("could not generate nonce\n"));
357 		return NULL;
358 	}
359 
360 	encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
361 	/* don't forget the zero-terminator */
362 	state->client_nonce = malloc(encoded_len + 1);
363 	if (state->client_nonce == NULL)
364 	{
365 		appendPQExpBufferStr(&conn->errorMessage,
366 							 libpq_gettext("out of memory\n"));
367 		return NULL;
368 	}
369 	encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
370 								state->client_nonce, encoded_len);
371 	if (encoded_len < 0)
372 	{
373 		appendPQExpBufferStr(&conn->errorMessage,
374 							 libpq_gettext("could not encode nonce\n"));
375 		return NULL;
376 	}
377 	state->client_nonce[encoded_len] = '\0';
378 
379 	/*
380 	 * Generate message.  The username is left empty as the backend uses the
381 	 * value provided by the startup packet.  Also, as this username is not
382 	 * prepared with SASLprep, the message parsing would fail if it includes
383 	 * '=' or ',' characters.
384 	 */
385 
386 	initPQExpBuffer(&buf);
387 
388 	/*
389 	 * First build the gs2-header with channel binding information.
390 	 */
391 	if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
392 	{
393 		Assert(conn->ssl_in_use);
394 		appendPQExpBufferStr(&buf, "p=tls-server-end-point");
395 	}
396 #ifdef HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH
397 	else if (conn->channel_binding[0] != 'd' && /* disable */
398 			 conn->ssl_in_use)
399 	{
400 		/*
401 		 * Client supports channel binding, but thinks the server does not.
402 		 */
403 		appendPQExpBufferChar(&buf, 'y');
404 	}
405 #endif
406 	else
407 	{
408 		/*
409 		 * Client does not support channel binding, or has disabled it.
410 		 */
411 		appendPQExpBufferChar(&buf, 'n');
412 	}
413 
414 	if (PQExpBufferDataBroken(buf))
415 		goto oom_error;
416 
417 	channel_info_len = buf.len;
418 
419 	appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
420 	if (PQExpBufferDataBroken(buf))
421 		goto oom_error;
422 
423 	/*
424 	 * The first message content needs to be saved without channel binding
425 	 * information.
426 	 */
427 	state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
428 	if (!state->client_first_message_bare)
429 		goto oom_error;
430 
431 	result = strdup(buf.data);
432 	if (result == NULL)
433 		goto oom_error;
434 
435 	termPQExpBuffer(&buf);
436 	return result;
437 
438 oom_error:
439 	termPQExpBuffer(&buf);
440 	appendPQExpBufferStr(&conn->errorMessage,
441 						 libpq_gettext("out of memory\n"));
442 	return NULL;
443 }
444 
445 /*
446  * Build the final exchange message sent from the client.
447  */
448 static char *
build_client_final_message(fe_scram_state * state)449 build_client_final_message(fe_scram_state *state)
450 {
451 	PQExpBufferData buf;
452 	PGconn	   *conn = state->conn;
453 	uint8		client_proof[SCRAM_KEY_LEN];
454 	char	   *result;
455 	int			encoded_len;
456 
457 	initPQExpBuffer(&buf);
458 
459 	/*
460 	 * Construct client-final-message-without-proof.  We need to remember it
461 	 * for verifying the server proof in the final step of authentication.
462 	 *
463 	 * The channel binding flag handling (p/y/n) must be consistent with
464 	 * build_client_first_message(), because the server will check that it's
465 	 * the same flag both times.
466 	 */
467 	if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
468 	{
469 #ifdef HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH
470 		char	   *cbind_data = NULL;
471 		size_t		cbind_data_len = 0;
472 		size_t		cbind_header_len;
473 		char	   *cbind_input;
474 		size_t		cbind_input_len;
475 		int			encoded_cbind_len;
476 
477 		/* Fetch hash data of server's SSL certificate */
478 		cbind_data =
479 			pgtls_get_peer_certificate_hash(state->conn,
480 											&cbind_data_len);
481 		if (cbind_data == NULL)
482 		{
483 			/* error message is already set on error */
484 			termPQExpBuffer(&buf);
485 			return NULL;
486 		}
487 
488 		appendPQExpBufferStr(&buf, "c=");
489 
490 		/* p=type,, */
491 		cbind_header_len = strlen("p=tls-server-end-point,,");
492 		cbind_input_len = cbind_header_len + cbind_data_len;
493 		cbind_input = malloc(cbind_input_len);
494 		if (!cbind_input)
495 		{
496 			free(cbind_data);
497 			goto oom_error;
498 		}
499 		memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
500 		memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
501 
502 		encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
503 		if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
504 		{
505 			free(cbind_data);
506 			free(cbind_input);
507 			goto oom_error;
508 		}
509 		encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
510 										  buf.data + buf.len,
511 										  encoded_cbind_len);
512 		if (encoded_cbind_len < 0)
513 		{
514 			free(cbind_data);
515 			free(cbind_input);
516 			termPQExpBuffer(&buf);
517 			appendPQExpBufferStr(&conn->errorMessage,
518 								 "could not encode cbind data for channel binding\n");
519 			return NULL;
520 		}
521 		buf.len += encoded_cbind_len;
522 		buf.data[buf.len] = '\0';
523 
524 		free(cbind_data);
525 		free(cbind_input);
526 #else
527 		/*
528 		 * Chose channel binding, but the SSL library doesn't support it.
529 		 * Shouldn't happen.
530 		 */
531 		termPQExpBuffer(&buf);
532 		appendPQExpBufferStr(&conn->errorMessage,
533 							 "channel binding not supported by this build\n");
534 		return NULL;
535 #endif							/* HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH */
536 	}
537 #ifdef HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH
538 	else if (conn->channel_binding[0] != 'd' && /* disable */
539 			 conn->ssl_in_use)
540 		appendPQExpBufferStr(&buf, "c=eSws");	/* base64 of "y,," */
541 #endif
542 	else
543 		appendPQExpBufferStr(&buf, "c=biws");	/* base64 of "n,," */
544 
545 	if (PQExpBufferDataBroken(buf))
546 		goto oom_error;
547 
548 	appendPQExpBuffer(&buf, ",r=%s", state->nonce);
549 	if (PQExpBufferDataBroken(buf))
550 		goto oom_error;
551 
552 	state->client_final_message_without_proof = strdup(buf.data);
553 	if (state->client_final_message_without_proof == NULL)
554 		goto oom_error;
555 
556 	/* Append proof to it, to form client-final-message. */
557 	if (!calculate_client_proof(state,
558 								state->client_final_message_without_proof,
559 								client_proof))
560 	{
561 		termPQExpBuffer(&buf);
562 		appendPQExpBufferStr(&conn->errorMessage,
563 							 libpq_gettext("could not calculate client proof\n"));
564 		return NULL;
565 	}
566 
567 	appendPQExpBufferStr(&buf, ",p=");
568 	encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
569 	if (!enlargePQExpBuffer(&buf, encoded_len))
570 		goto oom_error;
571 	encoded_len = pg_b64_encode((char *) client_proof,
572 								SCRAM_KEY_LEN,
573 								buf.data + buf.len,
574 								encoded_len);
575 	if (encoded_len < 0)
576 	{
577 		termPQExpBuffer(&buf);
578 		appendPQExpBufferStr(&conn->errorMessage,
579 							 libpq_gettext("could not encode client proof\n"));
580 		return NULL;
581 	}
582 	buf.len += encoded_len;
583 	buf.data[buf.len] = '\0';
584 
585 	result = strdup(buf.data);
586 	if (result == NULL)
587 		goto oom_error;
588 
589 	termPQExpBuffer(&buf);
590 	return result;
591 
592 oom_error:
593 	termPQExpBuffer(&buf);
594 	appendPQExpBufferStr(&conn->errorMessage,
595 						 libpq_gettext("out of memory\n"));
596 	return NULL;
597 }
598 
599 /*
600  * Read the first exchange message coming from the server.
601  */
602 static bool
read_server_first_message(fe_scram_state * state,char * input)603 read_server_first_message(fe_scram_state *state, char *input)
604 {
605 	PGconn	   *conn = state->conn;
606 	char	   *iterations_str;
607 	char	   *endptr;
608 	char	   *encoded_salt;
609 	char	   *nonce;
610 	int			decoded_salt_len;
611 
612 	state->server_first_message = strdup(input);
613 	if (state->server_first_message == NULL)
614 	{
615 		appendPQExpBufferStr(&conn->errorMessage,
616 							 libpq_gettext("out of memory\n"));
617 		return false;
618 	}
619 
620 	/* parse the message */
621 	nonce = read_attr_value(&input, 'r',
622 							&conn->errorMessage);
623 	if (nonce == NULL)
624 	{
625 		/* read_attr_value() has appended an error string */
626 		return false;
627 	}
628 
629 	/* Verify immediately that the server used our part of the nonce */
630 	if (strlen(nonce) < strlen(state->client_nonce) ||
631 		memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
632 	{
633 		appendPQExpBufferStr(&conn->errorMessage,
634 							 libpq_gettext("invalid SCRAM response (nonce mismatch)\n"));
635 		return false;
636 	}
637 
638 	state->nonce = strdup(nonce);
639 	if (state->nonce == NULL)
640 	{
641 		appendPQExpBufferStr(&conn->errorMessage,
642 							 libpq_gettext("out of memory\n"));
643 		return false;
644 	}
645 
646 	encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
647 	if (encoded_salt == NULL)
648 	{
649 		/* read_attr_value() has appended an error string */
650 		return false;
651 	}
652 	decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
653 	state->salt = malloc(decoded_salt_len);
654 	if (state->salt == NULL)
655 	{
656 		appendPQExpBufferStr(&conn->errorMessage,
657 							 libpq_gettext("out of memory\n"));
658 		return false;
659 	}
660 	state->saltlen = pg_b64_decode(encoded_salt,
661 								   strlen(encoded_salt),
662 								   state->salt,
663 								   decoded_salt_len);
664 	if (state->saltlen < 0)
665 	{
666 		appendPQExpBufferStr(&conn->errorMessage,
667 							 libpq_gettext("malformed SCRAM message (invalid salt)\n"));
668 		return false;
669 	}
670 
671 	iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
672 	if (iterations_str == NULL)
673 	{
674 		/* read_attr_value() has appended an error string */
675 		return false;
676 	}
677 	state->iterations = strtol(iterations_str, &endptr, 10);
678 	if (*endptr != '\0' || state->iterations < 1)
679 	{
680 		appendPQExpBufferStr(&conn->errorMessage,
681 							 libpq_gettext("malformed SCRAM message (invalid iteration count)\n"));
682 		return false;
683 	}
684 
685 	if (*input != '\0')
686 		appendPQExpBufferStr(&conn->errorMessage,
687 							 libpq_gettext("malformed SCRAM message (garbage at end of server-first-message)\n"));
688 
689 	return true;
690 }
691 
692 /*
693  * Read the final exchange message coming from the server.
694  */
695 static bool
read_server_final_message(fe_scram_state * state,char * input)696 read_server_final_message(fe_scram_state *state, char *input)
697 {
698 	PGconn	   *conn = state->conn;
699 	char	   *encoded_server_signature;
700 	char	   *decoded_server_signature;
701 	int			server_signature_len;
702 
703 	state->server_final_message = strdup(input);
704 	if (!state->server_final_message)
705 	{
706 		appendPQExpBufferStr(&conn->errorMessage,
707 							 libpq_gettext("out of memory\n"));
708 		return false;
709 	}
710 
711 	/* Check for error result. */
712 	if (*input == 'e')
713 	{
714 		char	   *errmsg = read_attr_value(&input, 'e',
715 											 &conn->errorMessage);
716 
717 		if (errmsg == NULL)
718 		{
719 			/* read_attr_value() has appended an error message */
720 			return false;
721 		}
722 		appendPQExpBuffer(&conn->errorMessage,
723 						  libpq_gettext("error received from server in SCRAM exchange: %s\n"),
724 						  errmsg);
725 		return false;
726 	}
727 
728 	/* Parse the message. */
729 	encoded_server_signature = read_attr_value(&input, 'v',
730 											   &conn->errorMessage);
731 	if (encoded_server_signature == NULL)
732 	{
733 		/* read_attr_value() has appended an error message */
734 		return false;
735 	}
736 
737 	if (*input != '\0')
738 		appendPQExpBufferStr(&conn->errorMessage,
739 							 libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
740 
741 	server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
742 	decoded_server_signature = malloc(server_signature_len);
743 	if (!decoded_server_signature)
744 	{
745 		appendPQExpBufferStr(&conn->errorMessage,
746 							 libpq_gettext("out of memory\n"));
747 		return false;
748 	}
749 
750 	server_signature_len = pg_b64_decode(encoded_server_signature,
751 										 strlen(encoded_server_signature),
752 										 decoded_server_signature,
753 										 server_signature_len);
754 	if (server_signature_len != SCRAM_KEY_LEN)
755 	{
756 		free(decoded_server_signature);
757 		appendPQExpBufferStr(&conn->errorMessage,
758 							 libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
759 		return false;
760 	}
761 	memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);
762 	free(decoded_server_signature);
763 
764 	return true;
765 }
766 
767 /*
768  * Calculate the client proof, part of the final exchange message sent
769  * by the client.  Returns true on success, false on failure.
770  */
771 static bool
calculate_client_proof(fe_scram_state * state,const char * client_final_message_without_proof,uint8 * result)772 calculate_client_proof(fe_scram_state *state,
773 					   const char *client_final_message_without_proof,
774 					   uint8 *result)
775 {
776 	uint8		StoredKey[SCRAM_KEY_LEN];
777 	uint8		ClientKey[SCRAM_KEY_LEN];
778 	uint8		ClientSignature[SCRAM_KEY_LEN];
779 	int			i;
780 	pg_hmac_ctx *ctx;
781 
782 	ctx = pg_hmac_create(PG_SHA256);
783 	if (ctx == NULL)
784 		return false;
785 
786 	/*
787 	 * Calculate SaltedPassword, and store it in 'state' so that we can reuse
788 	 * it later in verify_server_signature.
789 	 */
790 	if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
791 							 state->iterations, state->SaltedPassword) < 0 ||
792 		scram_ClientKey(state->SaltedPassword, ClientKey) < 0 ||
793 		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey) < 0 ||
794 		pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
795 		pg_hmac_update(ctx,
796 					   (uint8 *) state->client_first_message_bare,
797 					   strlen(state->client_first_message_bare)) < 0 ||
798 		pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
799 		pg_hmac_update(ctx,
800 					   (uint8 *) state->server_first_message,
801 					   strlen(state->server_first_message)) < 0 ||
802 		pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
803 		pg_hmac_update(ctx,
804 					   (uint8 *) client_final_message_without_proof,
805 					   strlen(client_final_message_without_proof)) < 0 ||
806 		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
807 	{
808 		pg_hmac_free(ctx);
809 		return false;
810 	}
811 
812 	for (i = 0; i < SCRAM_KEY_LEN; i++)
813 		result[i] = ClientKey[i] ^ ClientSignature[i];
814 
815 	pg_hmac_free(ctx);
816 	return true;
817 }
818 
819 /*
820  * Validate the server signature, received as part of the final exchange
821  * message received from the server.  *match tracks if the server signature
822  * matched or not. Returns true if the server signature got verified, and
823  * false for a processing error.
824  */
825 static bool
verify_server_signature(fe_scram_state * state,bool * match)826 verify_server_signature(fe_scram_state *state, bool *match)
827 {
828 	uint8		expected_ServerSignature[SCRAM_KEY_LEN];
829 	uint8		ServerKey[SCRAM_KEY_LEN];
830 	pg_hmac_ctx *ctx;
831 
832 	ctx = pg_hmac_create(PG_SHA256);
833 	if (ctx == NULL)
834 		return false;
835 
836 	if (scram_ServerKey(state->SaltedPassword, ServerKey) < 0 ||
837 	/* calculate ServerSignature */
838 		pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
839 		pg_hmac_update(ctx,
840 					   (uint8 *) state->client_first_message_bare,
841 					   strlen(state->client_first_message_bare)) < 0 ||
842 		pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
843 		pg_hmac_update(ctx,
844 					   (uint8 *) state->server_first_message,
845 					   strlen(state->server_first_message)) < 0 ||
846 		pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
847 		pg_hmac_update(ctx,
848 					   (uint8 *) state->client_final_message_without_proof,
849 					   strlen(state->client_final_message_without_proof)) < 0 ||
850 		pg_hmac_final(ctx, expected_ServerSignature,
851 					  sizeof(expected_ServerSignature)) < 0)
852 	{
853 		pg_hmac_free(ctx);
854 		return false;
855 	}
856 
857 	pg_hmac_free(ctx);
858 
859 	/* signature processed, so now check after it */
860 	if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
861 		*match = false;
862 	else
863 		*match = true;
864 
865 	return true;
866 }
867 
868 /*
869  * Build a new SCRAM secret.
870  */
871 char *
pg_fe_scram_build_secret(const char * password)872 pg_fe_scram_build_secret(const char *password)
873 {
874 	char	   *prep_password;
875 	pg_saslprep_rc rc;
876 	char		saltbuf[SCRAM_DEFAULT_SALT_LEN];
877 	char	   *result;
878 
879 	/*
880 	 * Normalize the password with SASLprep.  If that doesn't work, because
881 	 * the password isn't valid UTF-8 or contains prohibited characters, just
882 	 * proceed with the original password.  (See comments at top of file.)
883 	 */
884 	rc = pg_saslprep(password, &prep_password);
885 	if (rc == SASLPREP_OOM)
886 		return NULL;
887 	if (rc == SASLPREP_SUCCESS)
888 		password = (const char *) prep_password;
889 
890 	/* Generate a random salt */
891 	if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
892 	{
893 		if (prep_password)
894 			free(prep_password);
895 		return NULL;
896 	}
897 
898 	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
899 								SCRAM_DEFAULT_ITERATIONS, password);
900 
901 	if (prep_password)
902 		free(prep_password);
903 
904 	return result;
905 }
906