1 /*
2  * PgBouncer - Lightweight connection pooler for PostgreSQL.
3  *
4  * Copyright (c) 2007-2009  Marko Kreen, Skype Technologies OÜ
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18 
19 /*
20  * Connect to running bouncer process, load fds from it, shut it down
21  * and continue with them.
22  *
23  * Each row from SHOW FDS will have corresponding fd in ancillary message.
24  *
25  * Manpages: unix, sendmsg, recvmsg, cmsg, readv
26  */
27 
28 #include "bouncer.h"
29 
30 /*
31  * Takeover done, old process shut down,
32  * kick this one running.
33  */
34 
/*
 * Connection to the old bouncer process.  Set by takeover_finish_part1()
 * and reset to NULL by takeover_finish() after the connection is dropped.
 */
static PgSocket *old_bouncer = NULL;
36 
takeover_finish(void)37 void takeover_finish(void)
38 {
39 	uint8_t buf[512];
40 	int fd = sbuf_socket(&old_bouncer->sbuf);
41 	bool res;
42 	ssize_t got;
43 
44 	log_info("sending SHUTDOWN;");
45 	socket_set_nonblocking(fd, 0);
46 	SEND_generic(res, old_bouncer, 'Q', "s", "SHUTDOWN;");
47 	if (!res)
48 		die("failed to send SHUTDOWN;");
49 
50 	while (1) {
51 		got = safe_recv(fd, buf, sizeof(buf), 0);
52 		if (got == 0)
53 			break;
54 		if (got < 0)
55 			die("sky is falling - error while waiting result from SHUTDOWN: %s", strerror(errno));
56 	}
57 
58 	disconnect_server(old_bouncer, false, "disko over");
59 	old_bouncer = NULL;
60 
61 	if (cf_pidfile && cf_pidfile[0]) {
62 		log_info("waiting for old pidfile to go away");
63 		while (1) {
64 			struct stat st;
65 			if (stat(cf_pidfile, &st) < 0) {
66 				if (errno == ENOENT)
67 					break;
68 			}
69 			usleep(USEC/10);
70 		}
71 	}
72 
73 	log_info("old process killed, resuming work");
74 	resume_all();
75 }
76 
/*
 * All fds have been received: park the connection to the old bouncer.
 *
 * Pauses the connection's sbuf, stores the socket in old_bouncer (later
 * used by takeover_finish()) and clears cf_reboot.
 */
static void takeover_finish_part1(PgSocket *bouncer)
{
	/* only one old bouncer connection may be parked */
	Assert(old_bouncer == NULL);

	/* unregister bouncer from libevent */
	if (!sbuf_pause(&bouncer->sbuf))
		fatal_perror("sbuf_pause failed");

	cf_reboot = 0;
	old_bouncer = bouncer;

	log_info("disko over, going background");
}
88 
89 /* parse msg for fd and info */
takeover_load_fd(struct MBuf * pkt,const struct cmsghdr * cmsg)90 static void takeover_load_fd(struct MBuf *pkt, const struct cmsghdr *cmsg)
91 {
92 	int fd;
93 	char *task, *saddr, *user, *db;
94 	char *client_enc, *std_string, *datestyle, *timezone, *password,
95 		*scram_client_key, *scram_server_key;
96 	int scram_client_key_len, scram_server_key_len;
97 	int oldfd, port, linkfd;
98 	int got;
99 	uint64_t ckey;
100 	PgAddr addr;
101 	bool res = false;
102 
103 	memset(&addr, 0, sizeof(addr));
104 
105 	if (cmsg->cmsg_level == SOL_SOCKET
106 		&& cmsg->cmsg_type == SCM_RIGHTS
107 		&& cmsg->cmsg_len >= CMSG_LEN(sizeof(int)))
108 	{
109 		/* get the fd */
110 		memcpy(&fd, CMSG_DATA(cmsg), sizeof(int));
111 		log_debug("got fd: %d", fd);
112 	} else {
113 		fatal("broken fd packet");
114 	}
115 
116 	/* parse row contents */
117 	got = scan_text_result(pkt, "issssiqisssssbb", &oldfd, &task, &user, &db,
118 			       &saddr, &port, &ckey, &linkfd,
119 			       &client_enc, &std_string, &datestyle, &timezone,
120 			       &password,
121 			       &scram_client_key_len,
122 			       &scram_client_key,
123 			       &scram_server_key_len,
124 			       &scram_server_key);
125 	if (got < 0)
126 		die("invalid data from old process");
127 	if (task == NULL || saddr == NULL)
128 		die("incomplete data from old process");
129 
130 	log_debug("FD row: fd=%d(%d) linkfd=%d task=%s user=%s db=%s enc=%s",
131 		  oldfd, fd, linkfd, task,
132 		  user ? user : "NULL", db ? db : "NULL",
133 		  client_enc ? client_enc : "NULL");
134 
135 	if (!password)
136 		password = "";
137 
138 	/* fill address */
139 	if (strcmp(saddr, "unix") == 0) {
140 		pga_set(&addr, AF_UNIX, cf_listen_port);
141 	} else {
142 		if (!pga_pton(&addr, saddr, port))
143 			fatal("failed to convert address: %s", saddr);
144 	}
145 
146 	/* decide what to do with it */
147 	if (strcmp(task, "client") == 0) {
148 		res = use_client_socket(fd, &addr, db, user, ckey, oldfd, linkfd,
149 				  client_enc, std_string, datestyle, timezone,
150 				  password,
151 				  scram_client_key, scram_client_key_len,
152 				  scram_server_key, scram_server_key_len);
153 	} else if (strcmp(task, "server") == 0) {
154 		res = use_server_socket(fd, &addr, db, user, ckey, oldfd, linkfd,
155 				  client_enc, std_string, datestyle, timezone,
156 				  password,
157 				  scram_client_key, scram_client_key_len,
158 				  scram_server_key, scram_server_key_len);
159 	} else if (strcmp(task, "pooler") == 0) {
160 		res = use_pooler_socket(fd, pga_is_unix(&addr));
161 	} else {
162 		fatal("unknown task: %s", task);
163 	}
164 
165 	free(scram_client_key);
166 	free(scram_server_key);
167 
168 	if (!res)
169 		fatal("socket takeover failed");
170 }
171 
/*
 * Link a client with its server after takeover.
 *
 * Walks the pool's active server list for the server whose tmp_sk_oldfd
 * equals the client's tmp_sk_linkfd and points the two sockets at each
 * other.  Not finding such a server is fatal.
 */
static void takeover_create_link(PgPool *pool, PgSocket *client)
{
	struct List *node;

	statlist_for_each(node, &pool->active_server_list) {
		PgSocket *candidate = container_of(node, PgSocket, head);

		if (candidate->tmp_sk_oldfd != client->tmp_sk_linkfd)
			continue;

		client->link = candidate;
		candidate->link = client;
		return;
	}

	fatal("takeover_create_link: failed to find pair");
}
187 
188 /* clean the inappropriate places the old fds got stored in */
takeover_clean_socket_list(struct StatList * list)189 static void takeover_clean_socket_list(struct StatList *list)
190 {
191 	struct List *item;
192 	PgSocket *sk;
193 	statlist_for_each(item, list) {
194 		sk = container_of(item, PgSocket, head);
195 		if (sk->suspended) {
196 			sk->tmp_sk_oldfd = get_cached_time();
197 			sk->tmp_sk_linkfd = get_cached_time();
198 		}
199 	}
200 }
201 
202 /* all fds loaded, create links */
takeover_postprocess_fds(void)203 static void takeover_postprocess_fds(void)
204 {
205 	struct List *item, *item2;
206 	PgSocket *client;
207 	PgPool *pool;
208 
209 	statlist_for_each(item, &pool_list) {
210 		pool = container_of(item, PgPool, head);
211 		if (pool->db->admin)
212 			continue;
213 		statlist_for_each(item2, &pool->active_client_list) {
214 			client = container_of(item2, PgSocket, head);
215 			if (client->suspended && client->tmp_sk_linkfd)
216 				takeover_create_link(pool, client);
217 		}
218 	}
219 	statlist_for_each(item, &pool_list) {
220 		pool = container_of(item, PgPool, head);
221 		takeover_clean_socket_list(&pool->active_client_list);
222 		takeover_clean_socket_list(&pool->active_server_list);
223 		takeover_clean_socket_list(&pool->idle_server_list);
224 	}
225 }
226 
/*
 * React to a CommandComplete ('C') packet from the old bouncer.
 *
 * "SUSPEND"  - send "SHOW FDS;" as the next query.
 * "SHOW..."  - all fds are in: post-process them and park the connection.
 * otherwise  - fatal.
 *
 * bouncer - connection to the old bouncer
 * pkt     - packet body; the command tag is read from it as a string
 */
static void next_command(PgSocket *bouncer, struct MBuf *pkt)
{
	const char *tag;

	if (!mbuf_get_string(pkt, &tag))
		fatal("bad result pkt");

	log_debug("takeover_recv_fds: 'C' body: %s", tag);

	if (strcmp(tag, "SUSPEND") == 0) {
		bool sent = true;

		log_info("SUSPEND finished, sending SHOW FDS");
		SEND_generic(sent, bouncer, 'Q', "s", "SHOW FDS;");
		if (!sent)
			fatal("command send failed");
		return;
	}

	if (strncmp(tag, "SHOW", 4) == 0) {
		/* all fds loaded, review them */
		takeover_postprocess_fds();
		log_info("SHOW FDS finished");

		takeover_finish_part1(bouncer);
		return;
	}

	fatal("got bad CMD from old bouncer: %s", tag);
}
252 
/*
 * Walk through all packets in one received chunk.
 *
 * bouncer - connection to the old bouncer
 * msg     - message header filled by recvmsg; its control messages are
 *           consumed one per DataRow
 * data    - reader over the received bytes
 *
 * Each 'D' (DataRow) packet must be accompanied by a control message,
 * which is passed to takeover_load_fd() together with the row.  'C'
 * packets drive the command sequence via next_command().  'T' and 'Z'
 * are only logged.  'E', unknown packet types, unparsable headers and
 * partial packets are fatal.
 */
static void takeover_parse_data(PgSocket *bouncer,
				struct msghdr *msg, struct MBuf *data)
{
	struct cmsghdr *cmsg;
	PktHdr pkt;

	/* no control data at all -> no fds in this chunk */
	cmsg = msg->msg_controllen ? CMSG_FIRSTHDR(msg) : NULL;

	while (mbuf_avail_for_read(data) > 0) {
		if (!get_header(data, &pkt))
			fatal("cannot parse packet");

		/*
		 * There should not be partial reads from UNIX socket.
		 */
		if (incomplete_pkt(&pkt))
			fatal("unexpected partial packet");

		switch (pkt.type) {
		case 'T': /* RowDescription */
			log_debug("takeover_parse_data: 'T'");
			break;
		case 'D': /* DataRow */
			log_debug("takeover_parse_data: 'D'");
			if (cmsg) {
				takeover_load_fd(&pkt.data, cmsg);
				/* advance to the control message of the next row */
				cmsg = CMSG_NXTHDR(msg, cmsg);
			} else {
				fatal("got row without fd info");
			}
			break;
		case 'Z': /* ReadyForQuery */
			log_debug("takeover_parse_data: 'Z'");
			break;
		case 'C': /* CommandComplete */
			log_debug("takeover_parse_data: 'C'");
			next_command(bouncer, &pkt.data);
			break;
		case 'E': /* ErrorMessage */
			log_server_error("old bouncer sent", &pkt);
			fatal("something failed");
			/* explicit break: do not depend on fatal() never returning */
			break;
		default:
			fatal("takeover_parse_data: unexpected pkt: '%c'", pkt_desc(&pkt));
		}
	}
}
298 
299 /*
300  * listen for data from old bouncer.
301  *
302  * use always recvmsg, to keep code simpler
303  */
takeover_recv_cb(evutil_socket_t sock,short flags,void * arg)304 static void takeover_recv_cb(evutil_socket_t sock, short flags, void *arg)
305 {
306 	PgSocket *bouncer = container_of(arg, PgSocket, sbuf);
307 	uint8_t data_buf[STARTUP_BUF * 2];
308 	uint8_t cnt_buf[128];
309 	struct msghdr msg;
310 	struct iovec io;
311 	ssize_t res;
312 	struct MBuf data;
313 
314 	memset(&msg, 0, sizeof(msg));
315 	io.iov_base = data_buf;
316 	io.iov_len = sizeof(data_buf);
317 	msg.msg_iov = &io;
318 	msg.msg_iovlen = 1;
319 	msg.msg_control = cnt_buf;
320 	msg.msg_controllen = sizeof(cnt_buf);
321 
322 	res = safe_recvmsg(sock, &msg, 0);
323 	if (res > 0) {
324 		mbuf_init_fixed_reader(&data, data_buf, res);
325 		takeover_parse_data(bouncer, &msg, &data);
326 	} else if (res == 0) {
327 		fatal("unexpected EOF");
328 	} else {
329 		if (errno == EAGAIN)
330 			return;
331 		fatal_perror("safe_recvmsg");
332 	}
333 }
334 
335 /*
336  * login finished, send first command,
337  * replace recv callback with custom recvmsg() based one.
338  */
takeover_login(PgSocket * bouncer)339 bool takeover_login(PgSocket *bouncer)
340 {
341 	bool res;
342 
343 	slog_info(bouncer, "login OK, sending SUSPEND");
344 	SEND_generic(res, bouncer, 'Q', "s", "SUSPEND;");
345 	if (res) {
346 		/* use own callback */
347 		if (!sbuf_pause(&bouncer->sbuf))
348 			fatal("sbuf_pause failed");
349 		res = sbuf_continue_with_callback(&bouncer->sbuf, takeover_recv_cb);
350 		if (!res)
351 			fatal("takeover_login: sbuf_continue_with_callback failed");
352 	} else {
353 		fatal("takeover_login: failed to send command");
354 	}
355 	return res;
356 }
357 
358 /* launch connection to running process */
takeover_init(void)359 void takeover_init(void)
360 {
361 	PgDatabase *db = find_database("pgbouncer");
362 	PgPool *pool = get_pool(db, db->forced_user);
363 
364 	if (!pool)
365 		fatal("no admin pool?");
366 
367 	log_info("takeover_init: launching connection");
368 	launch_new_connection(pool);
369 }
370 
/*
 * Login to the old bouncer did not succeed: report it through fatal().
 */
void takeover_login_failed(void)
{
	fatal("login failed");
}
375