1 /*  Copyright (C) 2021 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
2 
3     This program is free software: you can redistribute it and/or modify
4     it under the terms of the GNU General Public License as published by
5     the Free Software Foundation, either version 3 of the License, or
6     (at your option) any later version.
7 
8     This program is distributed in the hope that it will be useful,
9     but WITHOUT ANY WARRANTY; without even the implied warranty of
10     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11     GNU General Public License for more details.
12 
13     You should have received a copy of the GNU General Public License
14     along with this program.  If not, see <https://www.gnu.org/licenses/>.
15  */
16 
17 #include <assert.h>
#include <errno.h>
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <unistd.h>
22 
23 #include "contrib/conn_pool.h"
24 
25 #include "contrib/sockaddr.h"
26 
/*! \brief Globally visible connection pool pointer; starts out as NULL (no pool). */
27 conn_pool_t *global_conn_pool = NULL;
28 
/* Forward declaration: pool_pop() is defined further below but already used by get_old(). */
29 static int pool_pop(conn_pool_t *pool, size_t i);
30 
31 /*!
32  * \brief Try to get an open connection older than specified timestamp.
33  *
34  * \param pool           Pool to search in.
35  * \param older_than     Timestamp that the connection must be older than.
36  * \param next_oldest    Out: the timestamp of the oldest connection (other than the returned).
37  *
38  * \return -1 if error (no such connection), >= 0 connection file descriptor.
39  *
40  * \warning The returned connection is not necessarily the oldest one.
41  */
get_old(conn_pool_t * pool,knot_time_t older_than,knot_time_t * next_oldest)42 static int get_old(conn_pool_t *pool,
43                    knot_time_t older_than,
44                    knot_time_t *next_oldest)
45 {
46 	assert(pool);
47 
48 	*next_oldest = 0;
49 
50 	int fd = -1;
51 	pthread_mutex_lock(&pool->mutex);
52 
53 	for (size_t i = 0; i < pool->capacity; i++) {
54 		knot_time_t la = pool->conns[i].last_active;
55 		if (fd == -1 && knot_time_cmp(la, older_than) < 0) {
56 			fd = pool_pop(pool, i);
57 		} else if (knot_time_cmp(la, *next_oldest) < 0) {
58 			*next_oldest = la;
59 		}
60 	}
61 
62 	pthread_mutex_unlock(&pool->mutex);
63 	return fd;
64 }
65 
closing_thread(void * _arg)66 static void *closing_thread(void *_arg)
67 {
68 	conn_pool_t *pool = _arg;
69 
70 	while (true) {
71 		knot_time_t now = knot_time(), next = 0;
72 		knot_timediff_t timeout = conn_pool_timeout(pool, 0);
73 		assert(timeout != 0);
74 
75 		while (true) {
76 			int old_fd = get_old(pool, now - timeout + 1, &next);
77 			if (old_fd >= 0) {
78 				close(old_fd);
79 			} else {
80 				break;
81 			}
82 		}
83 
84 		if (next == 0) {
85 			sleep(timeout);
86 		} else {
87 			sleep(next + timeout - now);
88 		}
89 	}
90 
91 	return NULL; // we never get here since the thread will be cancelled instead
92 }
93 
conn_pool_init(size_t capacity,knot_timediff_t timeout)94 conn_pool_t *conn_pool_init(size_t capacity, knot_timediff_t timeout)
95 {
96 	if (capacity == 0 || timeout == 0) {
97 		return NULL;
98 	}
99 
100 	conn_pool_t *pool = calloc(1, sizeof(*pool) + capacity * sizeof(pool->conns[0]));
101 	if (pool != NULL) {
102 		pool->capacity = capacity;
103 		pool->timeout = timeout;
104 		if (pthread_mutex_init(&pool->mutex, 0) != 0) {
105 			free(pool);
106 			return NULL;
107 		}
108 		if (pthread_create(&pool->closing_thread, NULL, closing_thread, pool) != 0) {
109 			pthread_mutex_destroy(&pool->mutex);
110 			free(pool);
111 			return NULL;
112 		}
113 	}
114 	return pool;
115 }
116 
/*!
 * \brief Stop the closing thread, close all pooled connections and free the pool.
 *
 * \param pool   Pool to destroy; NULL is accepted and ignored.
 */
void conn_pool_deinit(conn_pool_t *pool)
{
	if (pool == NULL) {
		return;
	}

	// The closing thread loops forever; cancel it and wait until it is gone
	// before touching the pool contents.
	pthread_cancel(pool->closing_thread);
	pthread_join(pool->closing_thread, NULL);

	// Pull out the remaining connections one by one and close them.
	knot_time_t ignored;
	for (int fd = get_old(pool, 0, &ignored); fd >= 0;
	     fd = get_old(pool, 0, &ignored)) {
		close(fd);
	}

	pthread_mutex_destroy(&pool->mutex);
	free(pool);
}
133 
/*!
 * \brief Read the pool timeout and optionally replace it.
 *
 * \param pool          Pool to work with; may be NULL.
 * \param new_timeout   New timeout to store, or 0 to leave the timeout as is.
 *
 * \return The timeout valid before the call, or 0 if \a pool is NULL.
 */
knot_timediff_t conn_pool_timeout(conn_pool_t *pool,
                                  knot_timediff_t new_timeout)
{
	knot_timediff_t previous = 0;

	if (pool != NULL) {
		pthread_mutex_lock(&pool->mutex);

		previous = pool->timeout;
		if (new_timeout != 0) {
			pool->timeout = new_timeout;
		}

		pthread_mutex_unlock(&pool->mutex);
	}

	return previous;
}
151 
/*!
 * \brief Remove the connection stored in slot \a i and hand over its descriptor.
 *
 * The slot is wiped to all zeroes (which marks it as empty) and the pool usage
 * counter is decremented. No locking is done here.
 *
 * \param pool   Pool to modify.
 * \param i      Index of an occupied slot.
 *
 * \return File descriptor that was stored in the slot.
 */
static int pool_pop(conn_pool_t *pool, size_t i)
{
	conn_pool_memb_t *slot = pool->conns + i;

	assert(slot->last_active != 0);
	assert(pool->usage > 0);

	const int taken_fd = slot->fd;

	pool->usage -= 1;
	memset(slot, 0, sizeof(*slot));

	return taken_fd;
}
162 
/*!
 * \brief Take a pooled connection matching the given address pair, if any.
 *
 * The first occupied slot whose stored addresses compare equal to \a src and
 * \a dst (by sockaddr_cmp()) is removed from the pool. Before the descriptor is
 * returned, the socket is probed with a non-blocking peek and is closed instead
 * of returned if it does not look like a healthy idle connection.
 *
 * \param pool   Pool to search in; may be NULL.
 * \param src    Source address to match.
 * \param dst    Destination address to match.
 *
 * \return -1 if no usable connection is available, >= 0 connection file
 *         descriptor now owned by the caller.
 */
int conn_pool_get(conn_pool_t *pool,
                  struct sockaddr_storage *src,
                  struct sockaddr_storage *dst)
{
	if (pool == NULL) {
		return -1;
	}

	int fd = -1;
	pthread_mutex_lock(&pool->mutex);

	for (size_t i = 0; i < pool->capacity; i++) {
		if (pool->conns[i].last_active != 0 &&
		    sockaddr_cmp(&pool->conns[i].dst, dst, false) == 0 &&
		    sockaddr_cmp(&pool->conns[i].src, src, true) == 0) {
			fd = pool_pop(pool, i);
			break;
		}
	}

	pthread_mutex_unlock(&pool->mutex);

	if (fd >= 0) {
		uint8_t unused;
		ssize_t peek = recv(fd, &unused, 1, MSG_PEEK | MSG_DONTWAIT);

		// An idle but alive connection has nothing to read, so the peek
		// fails with EAGAIN/EWOULDBLOCK. Anything else means the socket
		// is not reusable: 0 = closed by peer, > 0 = unexpected pending
		// data, other errors (e.g. ECONNRESET) = broken connection.
		bool idle = (peek < 0 && (errno == EAGAIN || errno == EWOULDBLOCK));
		if (!idle) {
			close(fd);
			fd = -1;
		}
	}

	return fd;
}
196 
/*!
 * \brief Store a connection into the empty slot \a i.
 *
 * The slot gets the current time as its last-activity timestamp, copies of both
 * addresses and the descriptor; the pool usage counter is incremented.
 * No locking is done here.
 *
 * \param pool   Pool to modify.
 * \param i      Index of an empty slot.
 * \param src    Source address of the connection.
 * \param dst    Destination address of the connection.
 * \param fd     Connection file descriptor.
 */
static void pool_push(conn_pool_t *pool, size_t i,
                      struct sockaddr_storage *src,
                      struct sockaddr_storage *dst,
                      int fd)
{
	conn_pool_memb_t *slot = pool->conns + i;

	assert(slot->last_active == 0);
	assert(pool->usage < pool->capacity);

	slot->last_active = knot_time();
	slot->fd = fd;
	memcpy(&slot->src, src, sizeof(slot->src));
	memcpy(&slot->dst, dst, sizeof(slot->dst));

	pool->usage += 1;
}
211 
/*!
 * \brief Put a connection into the pool, evicting the oldest one if it is full.
 *
 * \note Picking the oldest connection relies on knot_time_cmp(); the logic
 *       assumes a zero timestamp compares after any non-zero one -- confirm
 *       against its definition.
 *
 * \param pool   Pool to store into; may be NULL.
 * \param src    Source address of the connection.
 * \param dst    Destination address of the connection.
 * \param fd     Connection file descriptor to store.
 *
 * \return -1 if the connection was stored in a free slot;
 *         the descriptor of the evicted oldest connection if the pool was full;
 *         \a fd itself if it was not stored (NULL or zero-capacity pool).
 *         A returned descriptor >= 0 is no longer held by the pool.
 */
int conn_pool_put(conn_pool_t *pool,
                  struct sockaddr_storage *src,
                  struct sockaddr_storage *dst,
                  int fd)
{
	if (pool == NULL || pool->capacity == 0) {
		return fd;
	}

	pthread_mutex_lock(&pool->mutex);

	size_t victim = pool->capacity;
	knot_time_t victim_time = 0;

	// Look for a free slot, remembering the oldest occupied one on the way.
	size_t i;
	for (i = 0; i < pool->capacity; i++) {
		knot_time_t stamp = pool->conns[i].last_active;
		if (stamp == 0) {
			break;
		}
		if (knot_time_cmp(stamp, victim_time) < 0) {
			victim_time = stamp;
			victim = i;
		}
	}

	int evicted_fd = -1;
	if (i < pool->capacity) {
		// Free slot found, nothing needs to be evicted.
		pool_push(pool, i, src, dst, fd);
	} else {
		// Pool is full: replace the oldest connection with the new one.
		assert(victim < pool->capacity);
		evicted_fd = pool_pop(pool, victim);
		pool_push(pool, victim, src, dst, fd);
	}

	pthread_mutex_unlock(&pool->mutex);

	return evicted_fd;
}
244