1 /*
2  * PgBouncer - Lightweight connection pooler for PostgreSQL.
3  *
4  * Copyright (c) 2007-2009  Marko Kreen, Skype Technologies OÜ
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18 
19 /*
20  * Pieces that need to have detailed info about protocol.
21  */
22 
23 #include "bouncer.h"
24 #include "scram.h"
25 
26 /*
27  * parse protocol header from struct MBuf
28  */
29 
30 /* parses pkt header from buffer, returns false if failed */
get_header(struct MBuf * data,PktHdr * pkt)31 bool get_header(struct MBuf *data, PktHdr *pkt)
32 {
33 	unsigned type;
34 	uint32_t len;
35 	unsigned got;
36 	unsigned avail;
37 	uint16_t len16;
38 	uint8_t type8;
39 	uint32_t code;
40 	struct MBuf hdr;
41 	const uint8_t *ptr;
42 
43 	mbuf_copy(data, &hdr);
44 
45 	if (mbuf_avail_for_read(&hdr) < NEW_HEADER_LEN) {
46 		log_noise("get_header: less than 5 bytes available");
47 		return false;
48 	}
49 	if (!mbuf_get_byte(&hdr, &type8))
50 		return false;
51 	type = type8;
52 	if (type != 0) {
53 		/* wire length does not include type byte */
54 		if (!mbuf_get_uint32be(&hdr, &len))
55 			return false;
56 		len++;
57 		got = NEW_HEADER_LEN;
58 	} else {
59 		if (!mbuf_get_byte(&hdr, &type8))
60 			return false;
61 		if (type8 != 0) {
62 			log_noise("get_header: unknown special pkt");
63 			return false;
64 		}
65 		/* don't tolerate partial pkt */
66 		if (mbuf_avail_for_read(&hdr) < OLD_HEADER_LEN - 2) {
67 			log_noise("get_header: less than 8 bytes for special pkt");
68 			return false;
69 		}
70 		if (!mbuf_get_uint16be(&hdr, &len16))
71 			return false;
72 		len = len16;
73 		if (!mbuf_get_uint32be(&hdr, &code))
74 			return false;
75 		if (code == PKT_CANCEL) {
76 			type = PKT_CANCEL;
77 		} else if (code == PKT_SSLREQ) {
78 			type = PKT_SSLREQ;
79 		} else if (code == PKT_GSSENCREQ) {
80 			type = PKT_GSSENCREQ;
81 		} else if ((code >> 16) == 3 && (code & 0xFFFF) < 2) {
82 			type = PKT_STARTUP;
83 		} else if (code == PKT_STARTUP_V2) {
84 			type = PKT_STARTUP_V2;
85 		} else {
86 			log_noise("get_header: unknown special pkt: len=%u code=%u", len, code);
87 			return false;
88 		}
89 		got = OLD_HEADER_LEN;
90 	}
91 
92 	/* don't believe nonsense */
93 	if (len < got || len > cf_max_packet_size)
94 		return false;
95 
96 	/* store pkt info */
97 	pkt->type = type;
98 	pkt->len = len;
99 
100 	/* fill pkt with only data for this packet */
101 	if (len > mbuf_avail_for_read(data)) {
102 		avail = mbuf_avail_for_read(data);
103 	} else {
104 		avail = len;
105 	}
106 	if (!mbuf_slice(data, avail, &pkt->data))
107 		return false;
108 
109 	/* tag header as read */
110 	return mbuf_get_bytes(&pkt->data, got, &ptr);
111 }
112 
113 
114 /*
115  * Send error message packet to client.
116  *
117  * If level_fatal is true, use severity "FATAL", else "ERROR".  Although it is
118  * not technically part of the protocol specification, some clients expect the
119  * connection to be closed after receiving a FATAL error, and don't expect it
120  * to be closed after an ERROR-level error.  So to be nice, level_fatal should
121  * be true if the caller plans to close the connection after sending this
122  * error.
123  */
send_pooler_error(PgSocket * client,bool send_ready,bool level_fatal,const char * msg)124 bool send_pooler_error(PgSocket *client, bool send_ready, bool level_fatal, const char *msg)
125 {
126 	uint8_t tmpbuf[512];
127 	PktBuf buf;
128 
129 	if (cf_log_pooler_errors)
130 		slog_warning(client, "pooler error: %s", msg);
131 
132 	pktbuf_static(&buf, tmpbuf, sizeof(tmpbuf));
133 	pktbuf_write_generic(&buf, 'E', "cscscsc",
134 			     'S', level_fatal ? "FATAL" : "ERROR",
135 			     'C', "08P01", 'M', msg, 0);
136 	if (send_ready)
137 		pktbuf_write_ReadyForQuery(&buf);
138 	return pktbuf_send_immediate(&buf, client);
139 }
140 
141 /*
142  * Parse server error message and log it.
143  */
parse_server_error(PktHdr * pkt,const char ** level_p,const char ** msg_p)144 void parse_server_error(PktHdr *pkt, const char **level_p, const char **msg_p)
145 {
146 	const char *level = NULL, *msg = NULL, *val;
147 	uint8_t type;
148 	while (mbuf_avail_for_read(&pkt->data)) {
149 		if (!mbuf_get_byte(&pkt->data, &type))
150 			break;
151 		if (type == 0)
152 			break;
153 		if (!mbuf_get_string(&pkt->data, &val))
154 			break;
155 		if (type == 'S') {
156 			level = val;
157 		} else if (type == 'M') {
158 			msg = val;
159 		}
160 	}
161 	*level_p = level;
162 	*msg_p = msg;
163 }
164 
/*
 * Log a server error packet at error level, prefixed with 'note'.
 *
 * If either the severity or the message field could not be parsed from the
 * packet, a generic "partial error message" line is logged instead.
 */
void log_server_error(const char *note, PktHdr *pkt)
{
	const char *severity = NULL;
	const char *text = NULL;

	parse_server_error(pkt, &severity, &text);

	if (severity && text) {
		log_error("%s: %s: %s", note, severity, text);
		return;
	}
	log_error("%s: partial error message, cannot log", note);
}
177 
178 
179 /*
180  * Preparation of welcome message for client connection.
181  */
182 
183 /* add another server parameter packet to cache */
add_welcome_parameter(PgPool * pool,const char * key,const char * val)184 bool add_welcome_parameter(PgPool *pool, const char *key, const char *val)
185 {
186 	PktBuf *msg = pool->welcome_msg;
187 
188 	if (pool->welcome_msg_ready)
189 		return true;
190 
191 	if (!msg) {
192 		msg = pktbuf_dynamic(128);
193 		if (!msg)
194 			return false;
195 		pool->welcome_msg = msg;
196 	}
197 
198 	/* first packet must be AuthOk */
199 	if (msg->write_pos == 0)
200 		pktbuf_write_AuthenticationOk(msg);
201 
202 	/* if not stored in ->orig_vars, write full packet */
203 	if (!varcache_set(&pool->orig_vars, key, val))
204 		pktbuf_write_ParameterStatus(msg, key, val);
205 
206 	return !msg->failed;
207 }
208 
209 /* all parameters processed */
finish_welcome_msg(PgSocket * server)210 void finish_welcome_msg(PgSocket *server)
211 {
212 	PgPool *pool = server->pool;
213 	if (pool->welcome_msg_ready)
214 		return;
215 	pool->welcome_msg_ready = true;
216 }
217 
/*
 * Send the login-complete sequence to a client.
 *
 * The reply consists of the pool's cached welcome message, the client's
 * parameter values (after filling unset ones from pool->orig_vars), a
 * BackendKeyData packet carrying a freshly generated 8-byte cancel key, and
 * ReadyForQuery.  Everything is sent in one go.
 *
 * On failure the client is disconnected and false is returned.
 */
bool welcome_client(PgSocket *client)
{
	PgPool *pool = client->pool;
	const PktBuf *prepared = pool->welcome_msg;
	PktBuf *reply;

	slog_noise(client, "P: welcome_client");

	/* start from a copy of the prepared per-pool part */
	reply = pktbuf_temp();
	pktbuf_put_bytes(reply, prepared->buf, prepared->write_pos);

	/* append the parameter values for this client */
	varcache_fill_unset(&pool->orig_vars, client);
	varcache_add_params(reply, &client->vars);

	/* give each client its own cancel key */
	get_random_bytes(client->cancel_key, 8);
	pktbuf_write_BackendKeyData(reply, client->cancel_key);

	/* finish */
	pktbuf_write_ReadyForQuery(reply);
	if (reply->failed) {
		disconnect_client(client, true, "failed to prepare welcome message");
		return false;
	}

	/* send all together */
	if (!pktbuf_send_immediate(reply, client)) {
		disconnect_client(client, true, "failed to send welcome message");
		return false;
	}
	return true;
}
254 
255 /*
256  * Password authentication for server
257  */
258 
get_srv_psw(PgSocket * server)259 static PgUser *get_srv_psw(PgSocket *server)
260 {
261 	PgDatabase *db = server->pool->db;
262 	PgUser *user = server->pool->user;
263 
264 	/* if forced user without password, use userlist psw */
265 	if (!user->passwd[0] && db->forced_user) {
266 		PgUser *u2 = find_user(user->name);
267 		if (u2)
268 			return u2;
269 	}
270 	return user;
271 }
272 
273 /* actual packet send */
send_password(PgSocket * server,const char * enc_psw)274 static bool send_password(PgSocket *server, const char *enc_psw)
275 {
276 	bool res;
277 	SEND_PasswordMessage(res, server, enc_psw);
278 	return res;
279 }
280 
/*
 * Answer a cleartext password request: send the stored password of the
 * user chosen by get_srv_psw() as is.
 */
static bool login_clear_psw(PgSocket *server)
{
	PgUser *creds = get_srv_psw(server);

	slog_debug(server, "P: send clear password");
	return send_password(server, creds->passwd);
}
287 
/*
 * Answer an MD5 password request using the 4-byte salt sent by the server.
 *
 * A plaintext stored password is first run through pg_md5_encrypt() with the
 * user name; a stored MD5 password is used directly.  In both cases the
 * first 3 characters are skipped (presumably the "md5" prefix -- confirm
 * against pg_md5_encrypt()) and the rest is encrypted again with the salt.
 *
 * Any other password type cannot be used: the pool's logins are killed and
 * false is returned.
 */
static bool login_md5_psw(PgSocket *server, const uint8_t *salt)
{
	char hashed[MD5_PASSWD_LEN + 1];
	char *inner;
	PgUser *creds = get_srv_psw(server);

	slog_debug(server, "P: send md5 password");

	switch (get_password_type(creds->passwd)) {
	case PASSWORD_TYPE_PLAINTEXT:
		/* first stage: password + user name */
		pg_md5_encrypt(creds->passwd, creds->name, strlen(creds->name), hashed);
		inner = hashed + 3;
		break;
	case PASSWORD_TYPE_MD5:
		/* first stage is already stored */
		inner = creds->passwd + 3;
		break;
	default:
		slog_error(server, "cannot do MD5 authentication: wrong password type");
		kill_pool_logins(server->pool, "server login failed: wrong password type");
		return false;
	}

	/* second stage: first-stage result + salt */
	pg_md5_encrypt(inner, (char *)salt, 4, hashed);

	return send_password(server, hashed);
}
313 
/*
 * Start a SCRAM-SHA-256 exchange with the server (answer to
 * AuthenticationSASL).
 *
 * The stored password must be plaintext, or a SCRAM secret for which
 * user->has_scram_keys is set; otherwise the pool's logins are killed.
 * A second AuthenticationSASL for the same connection (client nonce already
 * present in the SCRAM state) is rejected as a protocol error.
 *
 * Builds the client-first-message and sends it as SASLInitialResponse.
 * Returns false on any failure.
 */
static bool login_scram_sha_256(PgSocket *server)
{
	PgUser *creds = get_srv_psw(server);
	char *first_msg;
	bool sent;

	switch (get_password_type(creds->passwd)) {
	case PASSWORD_TYPE_PLAINTEXT:
		/* ok */
		break;
	case PASSWORD_TYPE_SCRAM_SHA_256:
		if (creds->has_scram_keys)
			break;
		slog_error(server, "cannot do SCRAM authentication: password is SCRAM secret but client authentication did not provide SCRAM keys");
		kill_pool_logins(server->pool, "server login failed: wrong password type");
		return false;
	default:
		slog_error(server, "cannot do SCRAM authentication: wrong password type");
		kill_pool_logins(server->pool, "server login failed: wrong password type");
		return false;
	}

	if (server->scram_state.client_nonce) {
		slog_error(server, "protocol error: duplicate AuthenticationSASL message from server");
		return false;
	}

	first_msg = build_client_first_message(&server->scram_state);
	if (first_msg == NULL)
		return false;

	slog_debug(server, "SCRAM client-first-message = \"%s\"", first_msg);
	slog_debug(server, "P: send SASLInitialResponse");
	SEND_SASLInitialResponseMessage(sent, server, "SCRAM-SHA-256", first_msg);

	/* the message was allocated for us; release it after sending */
	free(first_msg);
	return sent;
}
354 
/*
 * Continue a SCRAM-SHA-256 exchange (answer to AuthenticationSASLContinue).
 *
 * 'data' / 'datalen' is the server-first-message; it is not NUL-terminated,
 * so a terminated private copy is made before parsing.
 *
 * The message is only accepted if login_scram_sha_256() ran before (client
 * nonce present) and no server-first-message has been seen yet.  It is
 * parsed with read_server_first_message(), then the client-final-message is
 * built and sent as SASLResponse.
 *
 * Returns false on protocol errors, out of memory, parse failure, failure
 * to build the client-final-message, or send failure.
 */
static bool login_scram_sha_256_cont(PgSocket *server, unsigned datalen, const uint8_t *data)
{
	PgUser *user = get_srv_psw(server);
	char *ibuf = NULL;
	char *input;
	char *server_nonce;
	int saltlen;
	char *salt = NULL;
	int iterations;
	bool res;
	char *client_final_message = NULL;

	if (!server->scram_state.client_nonce) {
		slog_error(server, "protocol error: AuthenticationSASLContinue without prior AuthenticationSASL");
		return false;
	}

	if (server->scram_state.server_first_message) {
		slog_error(server, "SCRAM exchange protocol error: received second AuthenticationSASLContinue");
		return false;
	}

	/* NUL-terminated copy of the payload for string parsing */
	ibuf = malloc(datalen + 1);
	if (ibuf == NULL)
		return false;
	memcpy(ibuf, data, datalen);
	ibuf[datalen] = '\0';

	input = ibuf;
	slog_debug(server, "SCRAM server-first-message = \"%s\"", input);
	if (!read_server_first_message(server, input,
				       &server_nonce, &salt, &saltlen, &iterations))
		goto failed;

	client_final_message = build_client_final_message(&server->scram_state,
							  user, server_nonce,
							  salt, saltlen, iterations);

	/*
	 * Parsing results are no longer needed.  Clear the pointers so the
	 * shared cleanup below cannot free them a second time.
	 */
	free(salt);
	salt = NULL;
	free(ibuf);
	ibuf = NULL;

	/* building can fail; never log or send a NULL message */
	if (!client_final_message)
		goto failed;

	slog_debug(server, "SCRAM client-final-message = \"%s\"", client_final_message);
	slog_debug(server, "P: send SASLResponse");
	SEND_SASLResponseMessage(res, server, client_final_message);

	free(client_final_message);
	return res;
failed:
	free(salt);
	free(ibuf);
	free(client_final_message);
	return false;
}
410 
/*
 * Finish a SCRAM-SHA-256 exchange (handling of AuthenticationSASLFinal).
 *
 * 'data' / 'datalen' is the server-final-message; it is not NUL-terminated,
 * so a terminated private copy is made before parsing.
 *
 * The message is only accepted after a server-first-message has been
 * processed.  The server signature is extracted with
 * read_server_final_message() and checked with verify_server_signature();
 * a mismatch kills the pool's logins.
 *
 * Returns true if the server signature is valid, false otherwise.  The
 * temporary copy of the message is released on every path.
 */
static bool login_scram_sha_256_final(PgSocket *server, unsigned datalen, const uint8_t *data)
{
	PgUser *user = get_srv_psw(server);
	char *ibuf = NULL;
	char *input;
	char ServerSignature[SHA256_DIGEST_LENGTH];

	if (!server->scram_state.server_first_message) {
		slog_error(server, "protocol error: AuthenticationSASLFinal without prior AuthenticationSASLContinue");
		return false;
	}

	/* NUL-terminated copy of the payload for string parsing */
	ibuf = malloc(datalen + 1);
	if (ibuf == NULL)
		return false;
	memcpy(ibuf, data, datalen);
	ibuf[datalen] = '\0';

	input = ibuf;
	slog_debug(server, "SCRAM server-final-message = \"%s\"", input);
	if (!read_server_final_message(server, input, ServerSignature))
		goto failed;

	if (!verify_server_signature(&server->scram_state, user, ServerSignature)) {
		slog_error(server, "invalid server signature");
		kill_pool_logins(server->pool, "server login failed: invalid server signature");
		/* go through the cleanup path so that ibuf is not leaked */
		goto failed;
	}

	free(ibuf);
	return true;
failed:
	free(ibuf);
	return false;
}
448 
449 /* answer server authentication request */
answer_authreq(PgSocket * server,PktHdr * pkt)450 bool answer_authreq(PgSocket *server, PktHdr *pkt)
451 {
452 	uint32_t cmd;
453 	const uint8_t *salt;
454 	bool res = false;
455 
456 	/* authreq body must contain 32bit cmd */
457 	if (mbuf_avail_for_read(&pkt->data) < 4)
458 		return false;
459 
460 	if (!mbuf_get_uint32be(&pkt->data, &cmd))
461 		return false;
462 	switch (cmd) {
463 	case AUTH_OK:
464 		slog_debug(server, "S: auth ok");
465 		res = true;
466 		break;
467 	case AUTH_PLAIN:
468 		slog_debug(server, "S: req cleartext password");
469 		res = login_clear_psw(server);
470 		break;
471 	case AUTH_MD5:
472 		slog_debug(server, "S: req md5-crypted psw");
473 		if (!mbuf_get_bytes(&pkt->data, 4, &salt))
474 			return false;
475 		res = login_md5_psw(server, salt);
476 		break;
477 	case AUTH_SASL:
478 	{
479 		bool selected_mechanism = false;
480 
481 		slog_debug(server, "S: req SASL");
482 
483 		do {
484 			const char *mech;
485 
486 			if (!mbuf_get_string(&pkt->data, &mech))
487 				return false;
488 			if (!mech[0])
489 				break;
490 			slog_debug(server, "S: SASL advertised mechanism: %s", mech);
491 			if (strcmp(mech, "SCRAM-SHA-256") == 0)
492 				selected_mechanism = true;
493 		} while (!selected_mechanism);
494 
495 		if (!selected_mechanism) {
496 			slog_error(server, "none of the server's SASL authentication mechanisms are supported");
497 			kill_pool_logins(server->pool, "server login failed: none of the server's SASL authentication mechanisms are supported");
498 		} else
499 			res = login_scram_sha_256(server);
500 		break;
501 	}
502 	case AUTH_SASL_CONT:
503 	{
504 		unsigned len;
505 		const uint8_t *data;
506 
507 		slog_debug(server, "S: SASL cont");
508 		len = mbuf_avail_for_read(&pkt->data);
509 		if (!mbuf_get_bytes(&pkt->data, len, &data))
510 			return false;
511 		res = login_scram_sha_256_cont(server, len, data);
512 		break;
513 	}
514 	case AUTH_SASL_FIN:
515 	{
516 		unsigned len;
517 		const uint8_t *data;
518 
519 		slog_debug(server, "S: SASL final");
520 		len = mbuf_avail_for_read(&pkt->data);
521 		if (!mbuf_get_bytes(&pkt->data, len, &data))
522 			return false;
523 		res = login_scram_sha_256_final(server, len, data);
524 		free_scram_state(&server->scram_state);
525 		break;
526 	}
527 	default:
528 		slog_error(server, "unknown/unsupported auth method: %d", cmd);
529 		res = false;
530 		break;
531 	}
532 	return res;
533 }
534 
/*
 * Send a StartupMessage to the server, built from the pool's user name and
 * the database's prepared startup parameters.
 * Returns the result of pktbuf_send_immediate().
 */
bool send_startup_packet(PgSocket *server)
{
	PgPool *pool = server->pool;
	PgDatabase *db = pool->db;
	const char *login_name = pool->user->name;
	PktBuf *startup = pktbuf_temp();

	pktbuf_write_StartupMessage(startup, login_name,
				    db->startup_params->buf,
				    db->startup_params->write_pos);

	return pktbuf_send_immediate(startup, server);
}
547 
/*
 * Send an SSLRequest packet to the server.
 * Returns the status set by SEND_wrap().
 */
bool send_sslreq_packet(PgSocket *server)
{
	int sent;

	SEND_wrap(16, pktbuf_write_SSLRequest, sent, server);
	return sent;
}
554 
/*
 * decode DataRow packet (opposite of pktbuf_write_DataRow)
 *
 * tupdesc keys, one per requested column, each consuming variadic
 * argument(s) that receive the decoded value:
 * 'i' - int4            (int *)
 * 'q' - int8            (uint64_t *)
 * 's' - text to string  (const char **)
 * 'b' - bytea to bytes  (int *length, uint8_t **bytes; result is malloced)
 *
 * A NULL column, or a column requested beyond what the row contains, yields
 * 0 for 'i'/'q', NULL for 's', and length -1 with NULL bytes for 'b'.
 *
 * String results are terminated in place inside the packet buffer, so they
 * are only valid as long as that buffer is.
 *
 * Returns the number of columns in the row, or -1 if the packet is
 * truncated, a bytea value is not in "\x" hex format or contains invalid
 * hex digits, or memory allocation fails.
 */
int scan_text_result(struct MBuf *pkt, const char *tupdesc, ...)
{
	uint16_t ncol;
	unsigned asked;
	va_list ap;

	asked = strlen(tupdesc);
	if (!mbuf_get_uint16be(pkt, &ncol))
		return -1;

	va_start(ap, tupdesc);
	for (unsigned i = 0; i < asked; i++) {
		const char *val = NULL;
		uint32_t len;

		if (i < ncol) {
			if (!mbuf_get_uint32be(pkt, &len))
				goto failed;

			/* negative length marks a NULL value */
			if ((int32_t)len < 0) {
				val = NULL;
			} else {
				if (!mbuf_get_chars(pkt, len, &val))
					goto failed;
			}

			/*
			 * hack to zero-terminate the result: shift the value one
			 * byte back, over the end of the length field that was
			 * just read, which frees up one byte for the terminator
			 */
			if (val) {
				char *xval = (char *)val - 1;
				memmove(xval, val, len);
				xval[len] = 0;
				val = xval;
			}
		} else {
			/* tuple was shorter than requested */
			val = NULL;
			len = -1;
		}

		switch (tupdesc[i]) {
		case 'i': {
			int *int_p;

			int_p = va_arg(ap, int *);
			*int_p = val ? atoi(val) : 0;
			break;
		}
		case 'q': {
			uint64_t *long_p;

			long_p = va_arg(ap, uint64_t *);
			*long_p = val ? atoll(val) : 0;
			break;
		}
		case 's': {
			const char **str_p;

			str_p = va_arg(ap, const char **);
			*str_p = val;
			break;
		}
		case 'b': {
			int *len_p = va_arg(ap, int *);
			uint8_t **bytes_p = va_arg(ap, uint8_t **);

			if (val) {
				int newlen;

				if (strncmp(val, "\\x", 2) != 0) {
					log_warning("invalid bytea value");
					goto failed;
				}

				/* two hex digits per byte after the "\x" prefix */
				newlen = (len - 2) / 2;
				*len_p = newlen;
				/*
				 * malloc(0) may legitimately return NULL; ask for
				 * at least one byte so that an empty bytea is not
				 * mistaken for an allocation failure
				 */
				*bytes_p = malloc(newlen > 0 ? newlen : 1);
				if (!(*bytes_p))
					goto failed;
				for (int j = 0; j < newlen; j++) {
					unsigned int b;

					if (sscanf(val + 2 + 2 * j, "%2x", &b) != 1) {
						/* not a hex digit: don't hand out garbage */
						log_warning("invalid bytea value");
						free(*bytes_p);
						*bytes_p = NULL;
						*len_p = -1;
						goto failed;
					}
					(*bytes_p)[j] = b;
				}
			} else {
				*len_p = -1;
				*bytes_p = NULL;
			}
			break;
		}
		default:
			fatal("bad tupdesc: %s", tupdesc);
		}
	}
	va_end(ap);

	return ncol;

failed:
	/* every return after va_start() must be paired with va_end() */
	va_end(ap);
	return -1;
}
663