1 /*-------------------------------------------------------------------------
2  * scram-common.c
3  *		Shared frontend/backend code for SCRAM authentication
4  *
5  * This contains the common low-level functions needed in both frontend and
 * backend, for implementing the Salted Challenge Response Authentication
7  * Mechanism (SCRAM), per IETF's RFC 5802.
8  *
9  * Portions Copyright (c) 2017-2021, PostgreSQL Global Development Group
10  *
11  * IDENTIFICATION
12  *	  src/common/scram-common.c
13  *
14  *-------------------------------------------------------------------------
15  */
16 #ifndef FRONTEND
17 #include "postgres.h"
18 #else
19 #include "postgres_fe.h"
20 #endif
21 
22 #include "common/base64.h"
23 #include "common/hmac.h"
24 #include "common/scram-common.h"
25 #include "port/pg_bswap.h"
26 
27 /*
28  * Calculate SaltedPassword.
29  *
30  * The password should already be normalized by SASLprep.  Returns 0 on
31  * success, -1 on failure.
32  */
33 int
scram_SaltedPassword(const char * password,const char * salt,int saltlen,int iterations,uint8 * result)34 scram_SaltedPassword(const char *password,
35 					 const char *salt, int saltlen, int iterations,
36 					 uint8 *result)
37 {
38 	int			password_len = strlen(password);
39 	uint32		one = pg_hton32(1);
40 	int			i,
41 				j;
42 	uint8		Ui[SCRAM_KEY_LEN];
43 	uint8		Ui_prev[SCRAM_KEY_LEN];
44 	pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256);
45 
46 	if (hmac_ctx == NULL)
47 		return -1;
48 
49 	/*
50 	 * Iterate hash calculation of HMAC entry using given salt.  This is
51 	 * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
52 	 * function.
53 	 */
54 
55 	/* First iteration */
56 	if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
57 		pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
58 		pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
59 		pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
60 	{
61 		pg_hmac_free(hmac_ctx);
62 		return -1;
63 	}
64 
65 	memcpy(result, Ui_prev, SCRAM_KEY_LEN);
66 
67 	/* Subsequent iterations */
68 	for (i = 2; i <= iterations; i++)
69 	{
70 		if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
71 			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
72 			pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
73 		{
74 			pg_hmac_free(hmac_ctx);
75 			return -1;
76 		}
77 
78 		for (j = 0; j < SCRAM_KEY_LEN; j++)
79 			result[j] ^= Ui[j];
80 		memcpy(Ui_prev, Ui, SCRAM_KEY_LEN);
81 	}
82 
83 	pg_hmac_free(hmac_ctx);
84 	return 0;
85 }
86 
87 
88 /*
89  * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
90  * not included in the hash).  Returns 0 on success, -1 on failure.
91  */
92 int
scram_H(const uint8 * input,int len,uint8 * result)93 scram_H(const uint8 *input, int len, uint8 *result)
94 {
95 	pg_cryptohash_ctx *ctx;
96 
97 	ctx = pg_cryptohash_create(PG_SHA256);
98 	if (ctx == NULL)
99 		return -1;
100 
101 	if (pg_cryptohash_init(ctx) < 0 ||
102 		pg_cryptohash_update(ctx, input, len) < 0 ||
103 		pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
104 	{
105 		pg_cryptohash_free(ctx);
106 		return -1;
107 	}
108 
109 	pg_cryptohash_free(ctx);
110 	return 0;
111 }
112 
113 /*
114  * Calculate ClientKey.  Returns 0 on success, -1 on failure.
115  */
116 int
scram_ClientKey(const uint8 * salted_password,uint8 * result)117 scram_ClientKey(const uint8 *salted_password, uint8 *result)
118 {
119 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
120 
121 	if (ctx == NULL)
122 		return -1;
123 
124 	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
125 		pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
126 		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
127 	{
128 		pg_hmac_free(ctx);
129 		return -1;
130 	}
131 
132 	pg_hmac_free(ctx);
133 	return 0;
134 }
135 
136 /*
137  * Calculate ServerKey.  Returns 0 on success, -1 on failure.
138  */
139 int
scram_ServerKey(const uint8 * salted_password,uint8 * result)140 scram_ServerKey(const uint8 *salted_password, uint8 *result)
141 {
142 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
143 
144 	if (ctx == NULL)
145 		return -1;
146 
147 	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
148 		pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
149 		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
150 	{
151 		pg_hmac_free(ctx);
152 		return -1;
153 	}
154 
155 	pg_hmac_free(ctx);
156 	return 0;
157 }
158 
159 
160 /*
161  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
162  *
163  * The password should already have been processed with SASLprep, if necessary!
164  *
165  * If iterations is 0, default number of iterations is used.  The result is
166  * palloc'd or malloc'd, so caller is responsible for freeing it.
167  */
168 char *
scram_build_secret(const char * salt,int saltlen,int iterations,const char * password)169 scram_build_secret(const char *salt, int saltlen, int iterations,
170 				   const char *password)
171 {
172 	uint8		salted_password[SCRAM_KEY_LEN];
173 	uint8		stored_key[SCRAM_KEY_LEN];
174 	uint8		server_key[SCRAM_KEY_LEN];
175 	char	   *result;
176 	char	   *p;
177 	int			maxlen;
178 	int			encoded_salt_len;
179 	int			encoded_stored_len;
180 	int			encoded_server_len;
181 	int			encoded_result;
182 
183 	if (iterations <= 0)
184 		iterations = SCRAM_DEFAULT_ITERATIONS;
185 
186 	/* Calculate StoredKey and ServerKey */
187 	if (scram_SaltedPassword(password, salt, saltlen, iterations,
188 							 salted_password) < 0 ||
189 		scram_ClientKey(salted_password, stored_key) < 0 ||
190 		scram_H(stored_key, SCRAM_KEY_LEN, stored_key) < 0 ||
191 		scram_ServerKey(salted_password, server_key) < 0)
192 	{
193 #ifdef FRONTEND
194 		return NULL;
195 #else
196 		elog(ERROR, "could not calculate stored key and server key");
197 #endif
198 	}
199 
200 	/*----------
201 	 * The format is:
202 	 * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
203 	 *----------
204 	 */
205 	encoded_salt_len = pg_b64_enc_len(saltlen);
206 	encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
207 	encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
208 
209 	maxlen = strlen("SCRAM-SHA-256") + 1
210 		+ 10 + 1				/* iteration count */
211 		+ encoded_salt_len + 1	/* Base64-encoded salt */
212 		+ encoded_stored_len + 1	/* Base64-encoded StoredKey */
213 		+ encoded_server_len + 1;	/* Base64-encoded ServerKey */
214 
215 #ifdef FRONTEND
216 	result = malloc(maxlen);
217 	if (!result)
218 		return NULL;
219 #else
220 	result = palloc(maxlen);
221 #endif
222 
223 	p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
224 
225 	/* salt */
226 	encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
227 	if (encoded_result < 0)
228 	{
229 #ifdef FRONTEND
230 		free(result);
231 		return NULL;
232 #else
233 		elog(ERROR, "could not encode salt");
234 #endif
235 	}
236 	p += encoded_result;
237 	*(p++) = '$';
238 
239 	/* stored key */
240 	encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
241 								   encoded_stored_len);
242 	if (encoded_result < 0)
243 	{
244 #ifdef FRONTEND
245 		free(result);
246 		return NULL;
247 #else
248 		elog(ERROR, "could not encode stored key");
249 #endif
250 	}
251 
252 	p += encoded_result;
253 	*(p++) = ':';
254 
255 	/* server key */
256 	encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
257 								   encoded_server_len);
258 	if (encoded_result < 0)
259 	{
260 #ifdef FRONTEND
261 		free(result);
262 		return NULL;
263 #else
264 		elog(ERROR, "could not encode server key");
265 #endif
266 	}
267 
268 	p += encoded_result;
269 	*(p++) = '\0';
270 
271 	Assert(p - result <= maxlen);
272 
273 	return result;
274 }
275