1 /*
2 * PgBouncer - Lightweight connection pooler for PostgreSQL.
3 *
4 * Copyright (c) 2007-2009 Marko Kreen, Skype Technologies OÜ
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19 /*
20 * SCRAM support
21 */
22
#include "bouncer.h"
#include "scram.h"
#include "common/base64.h"
#include "common/saslprep.h"
#include "common/scram-common.h"

#include <errno.h>
#include <limits.h>
28
29
30 static bool calculate_client_proof(ScramState *scram_state,
31 const PgUser *user,
32 const char *salt,
33 int saltlen,
34 int iterations,
35 const char *client_final_message_without_proof,
36 uint8_t *result);
37
38
39 /*
40 * free SCRAM state info after auth is done
41 */
free_scram_state(ScramState * scram_state)42 void free_scram_state(ScramState *scram_state)
43 {
44 free(scram_state->client_nonce);
45 free(scram_state->client_first_message_bare);
46 free(scram_state->client_final_message_without_proof);
47 free(scram_state->server_nonce);
48 free(scram_state->server_first_message);
49 free(scram_state->SaltedPassword);
50 free(scram_state->salt);
51 memset(scram_state, 0, sizeof(*scram_state));
52 }
53
/*
 * Check that every character of 'p' is SCRAM "printable" (RFC 5802):
 *
 *     printable = %x21-2B / %x2D-7E
 *                 ;; Printable ASCII except ",".
 *                 ;; Note that any "printable" is also
 *                 ;; a valid "value".
 *
 * Bytes outside 0x21-0x7E (including non-ASCII bytes, whatever the
 * signedness of char) are rejected. The empty string is accepted.
 */
static bool is_scram_printable(char *p)
{
	const char *cur = p;

	while (*cur != '\0') {
		char ch = *cur++;
		bool in_range = ch >= 0x21 && ch <= 0x7E;

		if (!in_range || ch == ',')
			return false;
	}
	return true;
}
71
/*
 * Render one character for a log message: printable ASCII (0x21-0x7E)
 * comes out quoted, e.g. 'x'; anything else as a hex byte, e.g. 0x0a.
 *
 * The result lives in a static buffer that the next call overwrites,
 * so it must be consumed before sanitize_char() is called again.
 */
static char *sanitize_char(char c)
{
	static char buf[5];	/* longest output is "0xNN" plus NUL */
	unsigned char byte = (unsigned char) c;

	if (c < 0x21 || c > 0x7E)
		snprintf(buf, sizeof(buf), "0x%02x", byte);
	else
		snprintf(buf, sizeof(buf), "'%c'", c);

	return buf;
}
82
83 /*
84 * Read value for an attribute part of a SCRAM message.
85 */
read_attr_value(PgSocket * sk,char ** input,char attr)86 static char *read_attr_value(PgSocket *sk, char **input, char attr)
87 {
88 char *begin = *input;
89 char *end;
90
91 if (*begin != attr)
92 {
93 slog_error(sk, "malformed SCRAM message (attribute \"%c\" expected)",
94 attr);
95 return NULL;
96 }
97 begin++;
98
99 if (*begin != '=')
100 {
101 slog_error(sk, "malformed SCRAM message (expected \"=\" after attribute \"%c\")",
102 attr);
103 return NULL;
104 }
105 begin++;
106
107 end = begin;
108 while (*end && *end != ',')
109 end++;
110
111 if (*end)
112 {
113 *end = '\0';
114 *input = end + 1;
115 }
116 else
117 *input = end;
118
119 return begin;
120 }
121
122 /*
123 * Read the next attribute and value in a SCRAM exchange message.
124 *
125 * Returns NULL if there is no attribute.
126 */
127 static char *
read_any_attr(PgSocket * sk,char ** input,char * attr_p)128 read_any_attr(PgSocket *sk, char **input, char *attr_p)
129 {
130 char *begin = *input;
131 char *end;
132 char attr = *begin;
133
134 if (!((attr >= 'A' && attr <= 'Z') ||
135 (attr >= 'a' && attr <= 'z')))
136 {
137 slog_error(sk, "malformed SCRAM message (attribute expected, but found invalid character \"%s\")",
138 sanitize_char(attr));
139 return NULL;
140 }
141 if (attr_p)
142 *attr_p = attr;
143 begin++;
144
145 if (*begin != '=')
146 {
147 slog_error(sk, "malformed SCRAM message (expected character \"=\" after attribute \"%c\")",
148 attr);
149 return NULL;
150 }
151 begin++;
152
153 end = begin;
154 while (*end && *end != ',')
155 end++;
156
157 if (*end)
158 {
159 *end = '\0';
160 *input = end + 1;
161 }
162 else
163 *input = end;
164
165 return begin;
166 }
167
168 /*
169 * Parse and validate format of given SCRAM verifier.
170 *
171 * Returns true if the SCRAM verifier has been parsed, and false otherwise.
172 */
parse_scram_verifier(const char * verifier,int * iterations,char ** salt,uint8_t * stored_key,uint8_t * server_key)173 static bool parse_scram_verifier(const char *verifier, int *iterations, char **salt,
174 uint8_t *stored_key, uint8_t *server_key)
175 {
176 char *v;
177 char *p;
178 char *scheme_str;
179 char *salt_str;
180 char *iterations_str;
181 char *storedkey_str;
182 char *serverkey_str;
183 int decoded_len;
184 char *decoded_salt_buf;
185 char *decoded_stored_buf = NULL;
186 char *decoded_server_buf = NULL;
187
188 /*
189 * The verifier is of form:
190 *
191 * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
192 */
193 v = strdup(verifier);
194 if (!v)
195 goto invalid_verifier;
196 if ((scheme_str = strtok(v, "$")) == NULL)
197 goto invalid_verifier;
198 if ((iterations_str = strtok(NULL, ":")) == NULL)
199 goto invalid_verifier;
200 if ((salt_str = strtok(NULL, "$")) == NULL)
201 goto invalid_verifier;
202 if ((storedkey_str = strtok(NULL, ":")) == NULL)
203 goto invalid_verifier;
204 if ((serverkey_str = strtok(NULL, "")) == NULL)
205 goto invalid_verifier;
206
207 /* Parse the fields */
208 if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
209 goto invalid_verifier;
210
211 errno = 0;
212 *iterations = strtol(iterations_str, &p, 10);
213 if (*p || errno != 0)
214 goto invalid_verifier;
215
216 /*
217 * Verify that the salt is in Base64-encoded format, by decoding it,
218 * although we return the encoded version to the caller.
219 */
220 decoded_salt_buf = malloc(pg_b64_dec_len(strlen(salt_str)));
221 if (!decoded_salt_buf)
222 goto invalid_verifier;
223 decoded_len = pg_b64_decode(salt_str, strlen(salt_str), decoded_salt_buf);
224 free(decoded_salt_buf);
225 if (decoded_len < 0)
226 goto invalid_verifier;
227 *salt = strdup(salt_str);
228 if (!*salt)
229 goto invalid_verifier;
230
231 /*
232 * Decode StoredKey and ServerKey.
233 */
234 decoded_stored_buf = malloc(pg_b64_dec_len(strlen(storedkey_str)));
235 if (!decoded_stored_buf)
236 goto invalid_verifier;
237 decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), decoded_stored_buf);
238 if (decoded_len != SCRAM_KEY_LEN)
239 goto invalid_verifier;
240 memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
241
242 decoded_server_buf = malloc(pg_b64_dec_len(strlen(serverkey_str)));
243 decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
244 decoded_server_buf);
245 if (decoded_len != SCRAM_KEY_LEN)
246 goto invalid_verifier;
247 memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
248
249 free(decoded_stored_buf);
250 free(decoded_server_buf);
251 free(v);
252 return true;
253
254 invalid_verifier:
255 free(decoded_stored_buf);
256 free(decoded_server_buf);
257 free(v);
258 free(*salt);
259 *salt = NULL;
260 return false;
261 }
262
263 #define MD5_PASSWD_CHARSET "0123456789abcdef"
264
265 /*
266 * What kind of a password verifier is 'shadow_pass'?
267 */
268 PasswordType
get_password_type(const char * shadow_pass)269 get_password_type(const char *shadow_pass)
270 {
271 char *encoded_salt = NULL;
272 int iterations;
273 uint8_t stored_key[SCRAM_KEY_LEN];
274 uint8_t server_key[SCRAM_KEY_LEN];
275
276 if (strncmp(shadow_pass, "md5", 3) == 0 &&
277 strlen(shadow_pass) == MD5_PASSWD_LEN &&
278 strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
279 return PASSWORD_TYPE_MD5;
280 if (parse_scram_verifier(shadow_pass, &iterations, &encoded_salt,
281 stored_key, server_key)) {
282 free(encoded_salt);
283 return PASSWORD_TYPE_SCRAM_SHA_256;
284 }
285 free(encoded_salt);
286 return PASSWORD_TYPE_PLAINTEXT;
287 }
288
289 /*
290 * Functions for communicating as a client with the server
291 */
292
build_client_first_message(ScramState * scram_state)293 char *build_client_first_message(ScramState *scram_state)
294 {
295 uint8_t raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
296 int encoded_len;
297 size_t len;
298 char *result = NULL;
299
300 get_random_bytes(raw_nonce, SCRAM_RAW_NONCE_LEN);
301
302 scram_state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
303 if (scram_state->client_nonce == NULL)
304 goto failed;
305 encoded_len = pg_b64_encode((char *) raw_nonce, SCRAM_RAW_NONCE_LEN, scram_state->client_nonce);
306 scram_state->client_nonce[encoded_len] = '\0';
307
308 len = 8 + strlen(scram_state->client_nonce) + 1;
309 result = malloc(len);
310 if (result == NULL)
311 goto failed;
312 snprintf(result, len, "n,,n=,r=%s", scram_state->client_nonce);
313
314 scram_state->client_first_message_bare = strdup(result + 3);
315 if (scram_state->client_first_message_bare == NULL)
316 goto failed;
317
318 return result;
319
320 failed:
321 free(result);
322 free(scram_state->client_nonce);
323 free(scram_state->client_first_message_bare);
324 return NULL;
325 }
326
/*
 * Build the client-final-message
 * "c=biws,r=<server-nonce>,p=<base64 ClientProof>".
 *
 * "biws" is base64 for "n,,", i.e. no channel binding. The part before
 * ",p=" is also saved in scram_state->client_final_message_without_proof,
 * because it is an input to both the client proof and the server
 * signature.
 *
 * server_nonce is received from the server and may be arbitrarily long,
 * so the message is sized from its actual length. A fixed-size buffer
 * would let a long nonce push the appended proof past the end of the
 * buffer.
 *
 * Returns a newly allocated string (the caller frees it), or NULL on
 * failure.
 */
char *build_client_final_message(ScramState *scram_state,
				 const PgUser *user,
				 const char *server_nonce,
				 const char *salt,
				 int saltlen,
				 int iterations)
{
	static const char prefix[] = "c=biws,r=";
	static const char proof_attr[] = ",p=";
	uint8_t client_proof[SCRAM_KEY_LEN];
	char *result = NULL;
	size_t without_proof_len;
	size_t proof_pos;
	size_t result_size;
	int encoded_len;

	without_proof_len = strlen(prefix) + strlen(server_nonce);
	proof_pos = without_proof_len + strlen(proof_attr);
	result_size = proof_pos + pg_b64_enc_len(SCRAM_KEY_LEN) + 1;

	result = malloc(result_size);
	if (result == NULL)
		goto failed;
	snprintf(result, result_size, "%s%s", prefix, server_nonce);

	scram_state->client_final_message_without_proof = strdup(result);
	if (scram_state->client_final_message_without_proof == NULL)
		goto failed;

	if (!calculate_client_proof(scram_state, user,
				    salt, saltlen, iterations, result,
				    client_proof))
		goto failed;

	/* Append ",p=<base64 proof>"; result_size reserved room for it. */
	memcpy(result + without_proof_len, proof_attr, strlen(proof_attr));
	encoded_len = pg_b64_encode((char *) client_proof,
				    SCRAM_KEY_LEN,
				    result + proof_pos);
	result[proof_pos + encoded_len] = '\0';

	return result;

failed:
	free(result);
	return NULL;
}
359
/*
 * Parse the server-first-message "r=<nonce>,s=<salt>,i=<iterations>"
 * received from the server.
 *
 * An unmodified copy of the message is kept in
 * server->scram_state.server_first_message, because it is hashed into the
 * client proof and server signature. 'input' itself is modified in place
 * while parsing.
 *
 * Validation:
 *  - the combined nonce must start with the client nonce we sent;
 *  - the iteration count must be a positive decimal that fits in an int.
 *
 * On success returns true and sets:
 *  - *server_nonce_p (pointing into 'input');
 *  - *salt_p (newly allocated decoded salt; the caller frees it);
 *  - *saltlen_p and *iterations_p.
 * On failure it logs and returns false.
 */
bool read_server_first_message(PgSocket *server, char *input,
			       char **server_nonce_p, char **salt_p, int *saltlen_p, int *iterations_p)
{
	char *server_nonce;
	char *encoded_salt;
	char *salt = NULL;
	int saltlen;
	char *iterations_str;
	char *endptr;
	long iterations;
	size_t client_nonce_len;

	server->scram_state.server_first_message = strdup(input);
	if (server->scram_state.server_first_message == NULL)
		goto failed;

	server_nonce = read_attr_value(server, &input, 'r');
	if (server_nonce == NULL)
		goto failed;

	client_nonce_len = strlen(server->scram_state.client_nonce);
	if (strlen(server_nonce) < client_nonce_len ||
	    memcmp(server_nonce, server->scram_state.client_nonce, client_nonce_len) != 0)
	{
		slog_error(server, "invalid SCRAM response (nonce mismatch)");
		goto failed;
	}

	encoded_salt = read_attr_value(server, &input, 's');
	if (encoded_salt == NULL)
		goto failed;
	salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
	if (salt == NULL)
		goto failed;
	saltlen = pg_b64_decode(encoded_salt,
				strlen(encoded_salt),
				salt);
	if (saltlen < 0)
	{
		slog_error(server, "malformed SCRAM message (invalid salt)");
		goto failed;
	}

	iterations_str = read_attr_value(server, &input, 'i');
	if (iterations_str == NULL)
		goto failed;

	/*
	 * Parse as long and range-check before narrowing to int, so that
	 * out-of-range values are rejected instead of silently truncated.
	 */
	errno = 0;
	iterations = strtol(iterations_str, &endptr, 10);
	if (endptr == iterations_str || *endptr != '\0' || errno != 0 ||
	    iterations < 1 || iterations > INT_MAX)
	{
		slog_error(server, "malformed SCRAM message (invalid iteration count)");
		goto failed;
	}

	if (*input != '\0')
	{
		slog_error(server, "malformed SCRAM message (garbage at end of server-first-message)");
		goto failed;
	}

	*server_nonce_p = server_nonce;
	*salt_p = salt;
	*saltlen_p = saltlen;
	*iterations_p = (int) iterations;
	return true;

failed:
	free(salt);
	return false;
}
427
/*
 * Parse the server-final-message.
 *
 * Forms handled:
 *  - "e=<error>": the server reports an error. It is logged, and the
 *    exchange fails.
 *  - "v=<base64 ServerSignature>": the signature is decoded into
 *    ServerSignature, which must have room for SCRAM_KEY_LEN bytes.
 *
 * Trailing data after the "v" attribute is logged but does not fail the
 * exchange.
 *
 * Returns true on success and false on any error.
 */
bool read_server_final_message(PgSocket *server, char *input, char *ServerSignature)
{
	char *encoded_server_signature;
	char *decoded_server_signature = NULL;
	int server_signature_len;

	if (*input == 'e')
	{
		char *errmsg = read_attr_value(server, &input, 'e');

		/*
		 * read_attr_value() returns NULL (after logging) for a malformed
		 * attribute; never pass that NULL to "%s".
		 */
		if (errmsg != NULL)
			slog_error(server, "error received from server in SCRAM exchange: %s",
				   errmsg);
		goto failed;
	}

	encoded_server_signature = read_attr_value(server, &input, 'v');
	if (encoded_server_signature == NULL)
		goto failed;

	if (*input != '\0')
		slog_error(server, "malformed SCRAM message (garbage at end of server-final-message)");

	server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
	decoded_server_signature = malloc(server_signature_len);
	if (!decoded_server_signature)
		goto failed;

	server_signature_len = pg_b64_decode(encoded_server_signature,
					     strlen(encoded_server_signature),
					     decoded_server_signature);
	if (server_signature_len != SCRAM_KEY_LEN)
	{
		slog_error(server, "malformed SCRAM message (malformed server signature)");
		goto failed;
	}
	memcpy(ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);

	free(decoded_server_signature);
	return true;

failed:
	free(decoded_server_signature);
	return false;
}
470
/*
 * Compute the ClientProof sent in the client-final-message:
 *
 *     ClientKey       = user->scram_ClientKey if the user has SCRAM keys,
 *                       else scram_ClientKey(SaltedPassword)
 *     StoredKey       = scram_H(ClientKey)
 *     ClientSignature = HMAC(StoredKey, client-first-message-bare ","
 *                            server-first-message ","
 *                            client-final-message-without-proof)
 *     ClientProof     = ClientKey XOR ClientSignature
 *
 * Without stored keys, SaltedPassword is derived from the SASLprep-
 * normalized password. If normalization fails for a reason other than
 * out-of-memory, the raw password is used instead. SaltedPassword is kept
 * in scram_state->SaltedPassword, which verify_server_signature() reads.
 *
 * Writes SCRAM_KEY_LEN bytes to 'result'. Returns false on allocation
 * failure.
 */
static bool calculate_client_proof(ScramState *scram_state,
				   const PgUser *user,
				   const char *salt,
				   int saltlen,
				   int iterations,
				   const char *client_final_message_without_proof,
				   uint8_t *result)
{
	uint8_t ClientKey[SCRAM_KEY_LEN];
	uint8_t StoredKey[SCRAM_KEY_LEN];
	uint8_t ClientSignature[SCRAM_KEY_LEN];
	scram_HMAC_ctx ctx;
	char *prep_password = NULL;
	int i;

	if (!user->has_scram_keys) {
		pg_saslprep_rc rc = pg_saslprep(user->passwd, &prep_password);

		if (rc == SASLPREP_OOM)
			return false;
		if (rc != SASLPREP_SUCCESS) {
			/* Not valid input for SASLprep: use the password unchanged. */
			prep_password = strdup(user->passwd);
			if (prep_password == NULL)
				return false;
		}

		scram_state->SaltedPassword = malloc(SCRAM_KEY_LEN);
		if (scram_state->SaltedPassword == NULL) {
			free(prep_password);
			return false;
		}
		scram_SaltedPassword(prep_password, salt, saltlen, iterations,
				     scram_state->SaltedPassword);
		scram_ClientKey(scram_state->SaltedPassword, ClientKey);
	} else {
		memcpy(ClientKey, user->scram_ClientKey, SCRAM_KEY_LEN);
	}

	scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey);

	scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN);
	scram_HMAC_update(&ctx, scram_state->client_first_message_bare,
			  strlen(scram_state->client_first_message_bare));
	scram_HMAC_update(&ctx, ",", 1);
	scram_HMAC_update(&ctx, scram_state->server_first_message,
			  strlen(scram_state->server_first_message));
	scram_HMAC_update(&ctx, ",", 1);
	scram_HMAC_update(&ctx, client_final_message_without_proof,
			  strlen(client_final_message_without_proof));
	scram_HMAC_final(ClientSignature, &ctx);

	for (i = 0; i < SCRAM_KEY_LEN; i++)
		result[i] = ClientKey[i] ^ ClientSignature[i];

	free(prep_password);
	return true;
}
539
/*
 * Check the ServerSignature received in the server-final-message.
 *
 * ServerKey is the user's stored key if it has SCRAM keys. Otherwise it
 * is derived from scram_state->SaltedPassword, which
 * calculate_client_proof() must have set. The expected signature is
 * HMAC(ServerKey, client-first-message-bare "," server-first-message ","
 * client-final-message-without-proof).
 *
 * Returns true when the first SCRAM_KEY_LEN bytes of ServerSignature
 * match the expected value.
 */
bool verify_server_signature(ScramState *scram_state, const PgUser *user, const char *ServerSignature)
{
	uint8_t ServerKey[SCRAM_KEY_LEN];
	uint8_t expected_ServerSignature[SCRAM_KEY_LEN];
	scram_HMAC_ctx ctx;
	const char *first_bare = scram_state->client_first_message_bare;
	const char *server_first = scram_state->server_first_message;
	const char *final_wo_proof = scram_state->client_final_message_without_proof;

	if (!user->has_scram_keys)
		scram_ServerKey(scram_state->SaltedPassword, ServerKey);
	else
		memcpy(ServerKey, user->scram_ServerKey, SCRAM_KEY_LEN);

	scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN);
	scram_HMAC_update(&ctx, first_bare, strlen(first_bare));
	scram_HMAC_update(&ctx, ",", 1);
	scram_HMAC_update(&ctx, server_first, strlen(server_first));
	scram_HMAC_update(&ctx, ",", 1);
	scram_HMAC_update(&ctx, final_wo_proof, strlen(final_wo_proof));
	scram_HMAC_final(expected_ServerSignature, &ctx);

	return memcmp(expected_ServerSignature, ServerSignature, SCRAM_KEY_LEN) == 0;
}
570
571
572 /*
573 * Functions for communicating as a server to the client
574 */
575
read_client_first_message(PgSocket * client,char * input,char * cbind_flag_p,char ** client_first_message_bare_p,char ** client_nonce_p)576 bool read_client_first_message(PgSocket *client, char *input,
577 char *cbind_flag_p,
578 char **client_first_message_bare_p,
579 char **client_nonce_p)
580 {
581 char *client_first_message_bare = NULL;
582 char *client_nonce = NULL;
583 char *client_nonce_copy = NULL;
584
585 *cbind_flag_p = *input;
586 switch (*input) {
587 case 'n':
588 /* Client does not support channel binding */
589 input++;
590 break;
591 case 'y':
592 /* Client supports channel binding, but we're not doing it today */
593 input++;
594 break;
595 case 'p':
596 /* Client requires channel binding. We don't support it. */
597 slog_error(client, "client requires SCRAM channel binding, but it is not supported");
598 goto failed;
599 default:
600 slog_error(client, "malformed SCRAM message (unexpected channel-binding flag \"%s\")",
601 sanitize_char(*input));
602 goto failed;
603 }
604
605 if (*input != ',') {
606 slog_error(client, "malformed SCRAM message (comma expected, but found character \"%s\")",
607 sanitize_char(*input));
608 goto failed;
609 }
610 input++;
611
612 if (*input == 'a') {
613 slog_error(client, "client uses authorization identity, but it is not supported");
614 goto failed;
615 }
616 if (*input != ',') {
617 slog_error(client, "malformed SCRAM message (unexpected attribute \"%s\" in client-first-message)",
618 sanitize_char(*input));
619 goto failed;
620 }
621 input++;
622
623 client_first_message_bare = strdup(input);
624 if (client_first_message_bare == NULL)
625 goto failed;
626
627 if (*input == 'm') {
628 slog_error(client, "client requires an unsupported SCRAM extension");
629 goto failed;
630 }
631
632 /* read and ignore user name */
633 read_attr_value(client, &input, 'n');
634
635 client_nonce = read_attr_value(client, &input, 'r');
636 if (client_nonce == NULL)
637 goto failed;
638 if (!is_scram_printable(client_nonce)) {
639 slog_error(client, "non-printable characters in SCRAM nonce");
640 goto failed;
641 }
642 client_nonce_copy = strdup(client_nonce);
643 if (client_nonce_copy == NULL)
644 goto failed;
645
646 /*
647 * There can be any number of optional extensions after this. We don't
648 * support any extensions, so ignore them.
649 */
650 while (*input != '\0') {
651 if (!read_any_attr(client, &input, NULL))
652 goto failed;
653 }
654
655 *client_first_message_bare_p = client_first_message_bare;
656 *client_nonce_p = client_nonce_copy;
657 return true;
658 failed:
659 free(client_first_message_bare);
660 free(client_nonce_copy);
661 return false;
662 }
663
/*
 * Parse the client-final-message "c=<cbind>,r=<nonce>[,<ext>...],p=<proof>".
 *
 * 'input' is parsed and modified in place. 'raw_input' must have the same
 * layout as 'input' had before parsing: the bytes before ",p=" are copied
 * from it into client->scram_state.client_final_message_without_proof,
 * using offsets computed on 'input'.
 *
 * Validation:
 *  - the channel-binding attribute must be "biws" (base64 of "n,,") or
 *    "eSws" ("y,,"), matching the flag from the client-first-message;
 *  - the nonce attribute must be present;
 *  - the proof must decode to exactly SCRAM_KEY_LEN bytes, because
 *    verify_client_proof() reads that many bytes from it.
 *
 * On success the function returns true and sets *client_final_nonce_p
 * (pointing into 'input') and *proof_p (newly allocated decoded proof;
 * the caller frees it).
 */
bool read_client_final_message(PgSocket *client, const uint8_t *raw_input, char *input,
			       const char **client_final_nonce_p,
			       char **proof_p)
{
	const char *input_start = input;
	char attr;
	char *channel_binding;
	char *client_final_nonce;
	char *proof_start;
	char *value;
	char *encoded_proof;
	char *proof = NULL;
	int prooflen;

	/*
	 * Read channel-binding. We don't support channel binding, so
	 * it's expected to always be "biws", which is "n,,",
	 * base64-encoded, or "eSws", which is "y,,". We also have to
	 * check whether the flag is the same one that the client
	 * originally sent.
	 */
	channel_binding = read_attr_value(client, &input, 'c');
	if (channel_binding == NULL)
		goto failed;
	if (!(strcmp(channel_binding, "biws") == 0 && client->scram_state.cbind_flag == 'n') &&
	    !(strcmp(channel_binding, "eSws") == 0 && client->scram_state.cbind_flag == 'y')) {
		slog_error(client, "unexpected SCRAM channel-binding attribute in client-final-message");
		goto failed;
	}

	/* The nonce is mandatory; verify_final_nonce() will strlen() it. */
	client_final_nonce = read_attr_value(client, &input, 'r');
	if (client_final_nonce == NULL)
		goto failed;

	/*
	 * Skip optional extensions until the proof attribute. proof_start
	 * tracks the comma in front of the attribute being read, so that it
	 * ends up on the comma before "p=".
	 */
	do
	{
		proof_start = input - 1;
		value = read_any_attr(client, &input, &attr);
	} while (value && attr != 'p');

	if (!value) {
		slog_error(client, "could not read proof");
		goto failed;
	}

	encoded_proof = value;

	proof = malloc(pg_b64_dec_len(strlen(encoded_proof)));
	if (proof == NULL) {
		slog_error(client, "could not decode proof");
		goto failed;
	}
	prooflen = pg_b64_decode(encoded_proof,
				 strlen(encoded_proof),
				 proof);
	if (prooflen != SCRAM_KEY_LEN) {
		slog_error(client, "malformed SCRAM message (malformed proof in client-final-message)");
		goto failed;
	}

	if (*input != '\0') {
		slog_error(client, "malformed SCRAM message (garbage at the end of client-final-message)");
		goto failed;
	}

	client->scram_state.client_final_message_without_proof = malloc(proof_start - input_start + 1);
	if (!client->scram_state.client_final_message_without_proof)
		goto failed;
	memcpy(client->scram_state.client_final_message_without_proof, raw_input, proof_start - input_start);
	client->scram_state.client_final_message_without_proof[proof_start - input_start] = '\0';

	*client_final_nonce_p = client_final_nonce;
	*proof_p = proof;
	return true;

failed:
	free(proof);
	return false;
}
738
739 /*
740 * For doing SCRAM with a password stored in plain text, build a SCRAM
741 * secret on the fly.
742 */
build_adhoc_scram_secret(const char * plain_password,ScramState * scram_state)743 static bool build_adhoc_scram_secret(const char *plain_password, ScramState *scram_state)
744 {
745 const char *password;
746 char *prep_password;
747 pg_saslprep_rc rc;
748 char saltbuf[SCRAM_DEFAULT_SALT_LEN];
749 int encoded_len;
750 uint8_t salted_password[SCRAM_KEY_LEN];
751
752 rc = pg_saslprep(plain_password, &prep_password);
753 if (rc == SASLPREP_OOM)
754 goto failed;
755 else if (rc == SASLPREP_SUCCESS)
756 password = prep_password;
757 else
758 password = plain_password;
759
760 get_random_bytes((uint8_t *) saltbuf, sizeof(saltbuf));
761
762 scram_state->adhoc = true;
763
764 scram_state->iterations = SCRAM_DEFAULT_ITERATIONS;
765
766 scram_state->salt = malloc(pg_b64_enc_len(sizeof(saltbuf)) + 1);
767 if (!scram_state->salt)
768 goto failed;
769 encoded_len = pg_b64_encode(saltbuf, sizeof(saltbuf), scram_state->salt);
770 scram_state->salt[encoded_len] = '\0';
771
772 /* Calculate StoredKey and ServerKey */
773 scram_SaltedPassword(password, saltbuf, sizeof(saltbuf),
774 scram_state->iterations,
775 salted_password);
776 scram_ClientKey(salted_password, scram_state->StoredKey);
777 scram_H(scram_state->StoredKey, SCRAM_KEY_LEN, scram_state->StoredKey);
778 scram_ServerKey(salted_password, scram_state->ServerKey);
779
780 if (prep_password)
781 free(prep_password);
782 return true;
783 failed:
784 if (prep_password)
785 free(prep_password);
786 return false;
787 }
788
789 /*
790 * Deterministically generate salt for mock authentication, using a
791 * SHA256 hash based on the username and an instance-level secret key.
792 * Target buffer needs to be of size SCRAM_DEFAULT_SALT_LEN.
793 */
scram_mock_salt(const char * username,uint8_t * saltbuf)794 static void scram_mock_salt(const char *username, uint8_t *saltbuf)
795 {
796 static uint8_t mock_auth_nonce[32];
797 static bool mock_auth_nonce_initialized = false;
798 struct sha256_ctx ctx;
799 uint8_t sha_digest[PG_SHA256_DIGEST_LENGTH];
800
801 /*
802 * Generating salt using a SHA256 hash works as long as the
803 * required salt length is not larger than the SHA256 digest
804 * length.
805 */
806 static_assert(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
807 "salt length greater than SHA256 digest length");
808
809 if (!mock_auth_nonce_initialized) {
810 get_random_bytes(mock_auth_nonce, sizeof(mock_auth_nonce));
811 mock_auth_nonce_initialized = true;
812 }
813
814 sha256_reset(&ctx);
815 sha256_update(&ctx, (uint8_t *) username, strlen(username));
816 sha256_update(&ctx, mock_auth_nonce, sizeof(mock_auth_nonce));
817 sha256_final(&ctx, sha_digest);
818
819 memcpy(saltbuf, sha_digest, SCRAM_DEFAULT_SALT_LEN);
820 }
821
/*
 * Fill scram_state with a fake secret, used by
 * build_server_first_message() when there is no stored secret: the
 * default iteration count plus a deterministic mock salt (stored
 * base64-encoded in scram_state->salt). No keys are set here.
 *
 * Returns false on out-of-memory.
 */
static bool build_mock_scram_secret(const char *username, ScramState *scram_state)
{
	uint8_t saltbuf[SCRAM_DEFAULT_SALT_LEN];
	char *encoded;
	int encoded_len;

	scram_state->iterations = SCRAM_DEFAULT_ITERATIONS;
	scram_mock_salt(username, saltbuf);

	encoded = malloc(pg_b64_enc_len(sizeof(saltbuf)) + 1);
	scram_state->salt = encoded;
	if (encoded == NULL)
		return false;

	encoded_len = pg_b64_encode((char *) saltbuf, sizeof(saltbuf), encoded);
	encoded[encoded_len] = '\0';
	return true;
}
840
/*
 * Build the server-first-message
 * "r=<client-nonce><server-nonce>,s=<salt>,i=<iterations>".
 *
 * The salt and iteration count (and StoredKey/ServerKey where applicable)
 * are taken from:
 *  - a mock secret when stored_secret is NULL;
 *  - the parsed verifier when stored_secret is a SCRAM-SHA-256 verifier;
 *  - an ad hoc secret when stored_secret is a plain-text password.
 * Any other password type (e.g. MD5) fails.
 *
 * A fresh random server nonce is stored in scram_state->server_nonce. The
 * message is stored in scram_state->server_first_message, which owns it:
 * the returned pointer is that same string, and free_scram_state() frees
 * it.
 *
 * Returns NULL on failure. In that case server_nonce and
 * server_first_message are freed AND reset to NULL, so that a later
 * free_scram_state() does not free them a second time.
 */
char *build_server_first_message(ScramState *scram_state, const char *username, const char *stored_secret)
{
	uint8_t raw_nonce[SCRAM_RAW_NONCE_LEN];
	int encoded_len;
	size_t len;
	char *result;

	if (!stored_secret) {
		if (!build_mock_scram_secret(username, scram_state))
			goto failed;
	} else {
		switch (get_password_type(stored_secret)) {
		case PASSWORD_TYPE_SCRAM_SHA_256:
			if (!parse_scram_verifier(stored_secret,
						  &scram_state->iterations,
						  &scram_state->salt,
						  scram_state->StoredKey,
						  scram_state->ServerKey))
				goto failed;
			break;
		case PASSWORD_TYPE_PLAINTEXT:
			if (!build_adhoc_scram_secret(stored_secret, scram_state))
				goto failed;
			break;
		default:
			/* shouldn't get here */
			goto failed;
		}
	}

	get_random_bytes(raw_nonce, SCRAM_RAW_NONCE_LEN);
	scram_state->server_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
	if (scram_state->server_nonce == NULL)
		goto failed;
	encoded_len = pg_b64_encode((char *) raw_nonce, SCRAM_RAW_NONCE_LEN, scram_state->server_nonce);
	scram_state->server_nonce[encoded_len] = '\0';

	/*
	 * Size: "r=" + both nonces + ",s=" + salt + ",i=" + up to 11 characters
	 * for an int (sign included) + NUL.
	 */
	len = (2
	       + strlen(scram_state->client_nonce)
	       + strlen(scram_state->server_nonce)
	       + 3
	       + strlen(scram_state->salt)
	       + 3 + 11 + 1);
	result = malloc(len);
	if (!result)
		goto failed;
	snprintf(result, len,
		 "r=%s%s,s=%s,i=%d",
		 scram_state->client_nonce,
		 scram_state->server_nonce,
		 scram_state->salt,
		 scram_state->iterations);

	scram_state->server_first_message = result;

	return result;

failed:
	free(scram_state->server_nonce);
	scram_state->server_nonce = NULL;
	free(scram_state->server_first_message);
	scram_state->server_first_message = NULL;
	return NULL;
}
902
903 static char *
compute_server_signature(ScramState * state)904 compute_server_signature(ScramState *state)
905 {
906 uint8_t ServerSignature[SCRAM_KEY_LEN];
907 char *server_signature_base64;
908 int siglen;
909 scram_HMAC_ctx ctx;
910
911 /* calculate ServerSignature */
912 scram_HMAC_init(&ctx, state->ServerKey, SCRAM_KEY_LEN);
913 scram_HMAC_update(&ctx,
914 state->client_first_message_bare,
915 strlen(state->client_first_message_bare));
916 scram_HMAC_update(&ctx, ",", 1);
917 scram_HMAC_update(&ctx,
918 state->server_first_message,
919 strlen(state->server_first_message));
920 scram_HMAC_update(&ctx, ",", 1);
921 scram_HMAC_update(&ctx,
922 state->client_final_message_without_proof,
923 strlen(state->client_final_message_without_proof));
924 scram_HMAC_final(ServerSignature, &ctx);
925
926 server_signature_base64 = malloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
927 if (!server_signature_base64)
928 return NULL;
929 siglen = pg_b64_encode((const char *) ServerSignature,
930 SCRAM_KEY_LEN, server_signature_base64);
931 server_signature_base64[siglen] = '\0';
932
933 return server_signature_base64;
934 }
935
/*
 * Build the server-final-message "v=<base64 ServerSignature>".
 *
 * Returns a newly allocated string (the caller frees it), or NULL on
 * out-of-memory.
 */
char *build_server_final_message(ScramState *scram_state)
{
	char *server_signature;
	char *result = NULL;
	size_t len;

	server_signature = compute_server_signature(scram_state);
	if (server_signature != NULL) {
		/* "v=" + signature + NUL */
		len = 2 + strlen(server_signature) + 1;
		result = malloc(len);
		if (result != NULL)
			snprintf(result, len, "v=%s", server_signature);
	}

	free(server_signature);
	return result;
}
958
verify_final_nonce(const ScramState * scram_state,const char * client_final_nonce)959 bool verify_final_nonce(const ScramState *scram_state, const char *client_final_nonce)
960 {
961 size_t client_nonce_len = strlen(scram_state->client_nonce);
962 size_t server_nonce_len = strlen(scram_state->server_nonce);
963 size_t final_nonce_len = strlen(client_final_nonce);
964
965 if (final_nonce_len != client_nonce_len + server_nonce_len)
966 return false;
967 if (memcmp(client_final_nonce, scram_state->client_nonce, client_nonce_len) != 0)
968 return false;
969 if (memcmp(client_final_nonce + client_nonce_len, scram_state->server_nonce, server_nonce_len) != 0)
970 return false;
971
972 return true;
973 }
974
/*
 * Verify the client's proof.
 *
 * Steps:
 *  - recompute ClientSignature = HMAC(StoredKey, client-first-message-bare
 *    "," server-first-message "," client-final-message-without-proof);
 *  - recover ClientKey = ClientProof XOR ClientSignature and store it in
 *    state->ClientKey;
 *  - accept when scram_H(ClientKey) equals state->StoredKey.
 *
 * ClientProof must hold SCRAM_KEY_LEN bytes.
 */
bool verify_client_proof(ScramState *state, const char *ClientProof)
{
	uint8_t ClientSignature[SCRAM_KEY_LEN];
	uint8_t client_StoredKey[SCRAM_KEY_LEN];
	scram_HMAC_ctx ctx;
	const char *first_bare = state->client_first_message_bare;
	const char *server_first = state->server_first_message;
	const char *final_wo_proof = state->client_final_message_without_proof;
	int idx;

	scram_HMAC_init(&ctx, state->StoredKey, SCRAM_KEY_LEN);
	scram_HMAC_update(&ctx, first_bare, strlen(first_bare));
	scram_HMAC_update(&ctx, ",", 1);
	scram_HMAC_update(&ctx, server_first, strlen(server_first));
	scram_HMAC_update(&ctx, ",", 1);
	scram_HMAC_update(&ctx, final_wo_proof, strlen(final_wo_proof));
	scram_HMAC_final(ClientSignature, &ctx);

	idx = 0;
	while (idx < SCRAM_KEY_LEN) {
		state->ClientKey[idx] = ClientProof[idx] ^ ClientSignature[idx];
		idx++;
	}

	scram_H(state->ClientKey, SCRAM_KEY_LEN, client_StoredKey);

	return memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) == 0;
}
1009
1010 /*
1011 * Verify a plaintext password against a SCRAM verifier. This is used when
1012 * performing plaintext password authentication for a user that has a SCRAM
1013 * verifier stored in pg_authid.
1014 */
1015 bool
scram_verify_plain_password(PgSocket * client,const char * username,const char * password,const char * verifier)1016 scram_verify_plain_password(PgSocket *client,
1017 const char *username, const char *password,
1018 const char *verifier)
1019 {
1020 char *encoded_salt = NULL;
1021 char *salt = NULL;
1022 int saltlen;
1023 int iterations;
1024 uint8_t salted_password[SCRAM_KEY_LEN];
1025 uint8_t stored_key[SCRAM_KEY_LEN];
1026 uint8_t server_key[SCRAM_KEY_LEN];
1027 uint8_t computed_key[SCRAM_KEY_LEN];
1028 char *prep_password = NULL;
1029 pg_saslprep_rc rc;
1030 bool result = false;
1031
1032 if (!parse_scram_verifier(verifier, &iterations, &encoded_salt,
1033 stored_key, server_key))
1034 {
1035 /* The password looked like a SCRAM verifier, but could not be parsed. */
1036 slog_warning(client, "invalid SCRAM verifier for user \"%s\"", username);
1037 goto failed;
1038 }
1039
1040 salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
1041 if (!salt)
1042 goto failed;
1043 saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
1044 if (saltlen == -1)
1045 {
1046 slog_warning(client, "invalid SCRAM verifier for user \"%s\"", username);
1047 goto failed;
1048 }
1049
1050 /* Normalize the password */
1051 rc = pg_saslprep(password, &prep_password);
1052 if (rc == SASLPREP_SUCCESS)
1053 password = prep_password;
1054
1055 /* Compute Server Key based on the user-supplied plaintext password */
1056 scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
1057 scram_ServerKey(salted_password, computed_key);
1058
1059 /*
1060 * Compare the verifier's Server Key with the one computed from the
1061 * user-supplied password.
1062 */
1063 result = memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
1064
1065 failed:
1066 free(encoded_salt);
1067 free(salt);
1068 free(prep_password);
1069 return result;
1070 }
1071