1 /* a throttling transparent proxy. */
2 #include <string.h>
3 #include <stdlib.h>
4 #include "../gsk.h"
5 #include "../gsklistmacros.h"
6 #include "../http/gskhttpcontent.h"
7
8 typedef struct _GskThrottleProxyConnection GskThrottleProxyConnection;
9 typedef struct _Side Side;
10
11 /* configuration */
12 guint upload_per_second_base = 10*1024;
13 guint download_per_second_base = 100*1024;
14 guint upload_per_second_noise = 1*1024;
15 guint download_per_second_noise = 10*1024;
16
17 /* if TRUE, shut-down the read and write ends of the connection
18 independently. if FALSE, either propagation a read or a write
19 shutdown into both. */
20 gboolean half_shutdowns = TRUE;
21
22 GskSocketAddress *bind_addr = NULL;
23 GskSocketAddress *server_addr = NULL;
24 GskSocketAddress *bind_status_addr = NULL;
25
26 static guint n_connections_accepted = 0;
27 static guint64 n_bytes_read_total = 0;
28 static guint64 n_bytes_written_total = 0;
29
30 struct _Side
31 {
32 GskThrottleProxyConnection *connection;
33
34 GskStream *read_side; /* client for upload, server for download */
35 GskStream *write_side; /* client for upload, server for download */
36 gboolean read_side_blocked;
37 gboolean write_side_blocked;
38
39 /* sides are in this list if their xferred_in_last_second==max
40 but buffer.size < max_buffer */
41 Side *next_throttled, *prev_throttled;
42 gboolean throttled;
43
44 guint max_xfer_per_second;
45 gulong last_xfer_second;
46 guint xferred_in_last_second;
47
48 GskBuffer buffer;
49
50 guint max_buffer;/* should be set to max_xfer_per_second or a bit more */
51
52 guint total_read, total_written;
53 };
54
55 struct _GskThrottleProxyConnection
56 {
57 Side upload;
58 Side download;
59
60 guint ref_count;
61
62 GskThrottleProxyConnection *prev, *next;
63 };
64
65 static GskThrottleProxyConnection *first_conn, *last_conn;
66 #define GET_CONNECTION_LIST() \
67 GskThrottleProxyConnection *, first_conn, last_conn, prev, next
68
69 static Side *first_throttled, *last_throttled;
70 #define GET_THROTTLED_LIST() \
71 Side *, first_throttled, last_throttled, prev_throttled, next_throttled
72
73 #define CURRENT_SECOND() (gsk_main_loop_default ()->current_time.tv_sec)
74
75 /* must be called whenever side->buffer changes "emptiness" */
76 static inline void
update_write_block(Side * side)77 update_write_block (Side *side)
78 {
79 gboolean old_val = side->write_side_blocked;
80 gboolean val = (side->read_side != NULL && side->buffer.size == 0);
81 side->write_side_blocked = val;
82
83 if (old_val && !val)
84 gsk_io_unblock_write (side->write_side);
85 else if (!old_val && val)
86 gsk_io_block_write (side->write_side);
87 }
88
89 /* must be called whenever side->buffer changes "emptiness" */
90 static inline void
update_read_block(Side * side)91 update_read_block (Side *side)
92 {
93 gboolean was_throttled = side->throttled;
94 gboolean old_val = side->read_side_blocked;
95 gboolean xfer_blocked = side->xferred_in_last_second >= side->max_xfer_per_second;
96 gboolean buf_blocked = side->buffer.size >= side->max_buffer;
97 gboolean val = xfer_blocked || buf_blocked;
98
99 side->throttled = xfer_blocked && !buf_blocked;
100 side->read_side_blocked = val;
101
102 if (side->throttled && !was_throttled)
103 {
104 /* put in throttled list */
105 GSK_LIST_APPEND (GET_THROTTLED_LIST (), side);
106 }
107 else if (!side->throttled && was_throttled)
108 {
109 /* remove from throttled list */
110 GSK_LIST_REMOVE (GET_THROTTLED_LIST (), side);
111 }
112
113 if (old_val && !val)
114 gsk_io_unblock_read (side->read_side);
115 else if (!old_val && val)
116 gsk_io_block_read (side->read_side);
117 }
118
119 static void
connection_unref(GskThrottleProxyConnection * conn)120 connection_unref (GskThrottleProxyConnection *conn)
121 {
122 if (--(conn->ref_count) == 0)
123 {
124 GSK_LIST_REMOVE (GET_CONNECTION_LIST (), conn);
125 gsk_buffer_destruct (&conn->upload.buffer);
126 gsk_buffer_destruct (&conn->download.buffer);
127 g_free (conn);
128 }
129 }
130
131 static gboolean
handle_side_writable(GskStream * stream,gpointer data)132 handle_side_writable (GskStream *stream,
133 gpointer data)
134 {
135 Side *side = data;
136 GError *error = NULL;
137 guint written = gsk_stream_write_buffer (stream, &side->buffer, &error);
138 if (error)
139 {
140 g_warning ("error writing to stream %p: %s",
141 stream, error->message);
142 g_error_free (error);
143 }
144 n_bytes_written_total += written;
145 side->total_written += written;
146 update_write_block (side);
147 update_read_block (side);
148 if (written == 0 && side->read_side == NULL && side->buffer.size == 0)
149 {
150 update_write_block (side);
151 if (half_shutdowns)
152 gsk_io_write_shutdown (side->write_side, NULL);
153 else
154 gsk_io_shutdown (GSK_IO (side->write_side), NULL);
155 }
156 return TRUE;
157 }
158
159 static gboolean
handle_side_write_shutdown(GskStream * stream,gpointer data)160 handle_side_write_shutdown (GskStream *stream,
161 gpointer data)
162 {
163 Side *side = data;
164 if (side->buffer.size > 0)
165 g_warning ("write-side shut down while data still pending");
166 if (side->read_side)
167 {
168 if (half_shutdowns)
169 gsk_io_read_shutdown (side->read_side, NULL);
170 else
171 gsk_io_shutdown (GSK_IO (side->read_side), NULL);
172 }
173 return FALSE;
174 }
175
176 static void
handle_side_write_destroy(gpointer data)177 handle_side_write_destroy (gpointer data)
178 {
179 Side *side = data;
180 g_object_unref (side->write_side);
181 side->write_side = NULL;
182 connection_unref (side->connection);
183 }
184
185 static gboolean
handle_side_readable(GskStream * stream,gpointer data)186 handle_side_readable (GskStream *stream,
187 gpointer data)
188 {
189 Side *side = data;
190 gulong cur_sec = CURRENT_SECOND ();
191 GError *error = NULL;
192 guint max_read;
193 guint nread;
194 char *tmp;
195 if (cur_sec == side->last_xfer_second)
196 {
197 max_read = side->max_xfer_per_second - side->xferred_in_last_second;
198 }
199 else
200 {
201 side->xferred_in_last_second = 0;
202 side->last_xfer_second = cur_sec;
203 max_read = side->max_xfer_per_second;
204 }
205 if (max_read + side->buffer.size > side->max_buffer)
206 {
207 if (side->buffer.size > side->max_buffer)
208 max_read = 0;
209 else
210 max_read = side->max_buffer - side->buffer.size;
211 }
212
213 tmp = g_malloc (max_read);
214 nread = gsk_stream_read (stream, tmp, max_read, &error);
215 if (error != NULL)
216 {
217 g_warning ("error reading from stream %p: %s",
218 stream, error->message);
219 g_error_free (error);
220 }
221 /* TODO: use append_foreign if nread is big */
222 gsk_buffer_append (&side->buffer, tmp, nread);
223
224 g_free (tmp);
225 n_bytes_read_total += nread;
226 side->total_read += nread;
227
228 side->xferred_in_last_second += nread;
229 g_assert (side->xferred_in_last_second <= side->max_xfer_per_second);
230 update_write_block (side);
231 update_read_block (side);
232 return TRUE;
233 }
234
235 static gboolean
handle_side_read_shutdown(GskStream * stream,gpointer data)236 handle_side_read_shutdown (GskStream *stream,
237 gpointer data)
238 {
239 return FALSE;
240 }
241
242 static void
handle_side_read_destroy(gpointer data)243 handle_side_read_destroy (gpointer data)
244 {
245 Side *side = data;
246 g_object_unref (side->read_side);
247 side->read_side = NULL;
248 if (side->buffer.size == 0 && side->write_side != NULL)
249 {
250 update_write_block (side);
251 if (half_shutdowns)
252 gsk_io_write_shutdown (side->write_side, NULL);
253 else
254 gsk_io_shutdown (GSK_IO (side->write_side), NULL);
255 }
256 connection_unref (side->connection);
257 }
258
259 static void
side_init(Side * side,GskThrottleProxyConnection * conn,GskStream * read_side,GskStream * write_side,guint max_xfer_per_second)260 side_init (Side *side,
261 GskThrottleProxyConnection *conn,
262 GskStream *read_side,
263 GskStream *write_side,
264 guint max_xfer_per_second)
265 {
266 side->connection = conn;
267 side->read_side = read_side;
268 side->write_side = write_side;
269 side->read_side_blocked = FALSE;
270 side->write_side_blocked = FALSE;
271 side->throttled = FALSE;
272 side->next_throttled = side->prev_throttled = NULL;
273 side->max_xfer_per_second = max_xfer_per_second;
274 side->last_xfer_second = gsk_main_loop_default ()->current_time.tv_sec;
275 side->xferred_in_last_second = 0;
276 gsk_buffer_construct (&side->buffer);
277 side->max_buffer = max_xfer_per_second;
278 side->total_read = 0;
279 side->total_written = 0;
280
281 conn->ref_count += 2;
282
283 g_object_ref (read_side);
284 gsk_io_trap_readable (read_side,
285 handle_side_readable,
286 handle_side_read_shutdown,
287 side,
288 handle_side_read_destroy);
289
290 g_object_ref (write_side);
291 gsk_io_trap_writable (write_side,
292 handle_side_writable,
293 handle_side_write_shutdown,
294 side,
295 handle_side_write_destroy);
296 }
297
298
299 /* --- handle a new stream --- */
300 static guint
pick_rand(guint base,guint noise)301 pick_rand (guint base, guint noise)
302 {
303 return base + (noise ? g_random_int_range (0, noise) : 0);
304 }
305
306 static gboolean
handle_accept(GskStream * stream,gpointer data,GError ** error)307 handle_accept (GskStream *stream,
308 gpointer data,
309 GError **error)
310 {
311 GskThrottleProxyConnection *conn = g_new (GskThrottleProxyConnection, 1);
312 GError *e = NULL;
313 GskStream *server = gsk_stream_new_connecting (server_addr, &e);
314 if (e)
315 g_error ("gsk_stream_new_connecting failed: %s", e->message);
316 n_connections_accepted++;
317 conn->ref_count = 1;
318 GSK_LIST_APPEND (GET_CONNECTION_LIST (), conn);
319 side_init (&conn->upload, conn, stream, server,
320 pick_rand (upload_per_second_base, upload_per_second_noise));
321 side_init (&conn->download, conn, server, stream,
322 pick_rand (download_per_second_base, download_per_second_noise));
323 connection_unref (conn);
324 g_object_unref (stream);
325 g_object_unref (server);
326 return TRUE;
327 }
328
329 static void
handle_listener_error(GError * error,gpointer data)330 handle_listener_error (GError *error,
331 gpointer data)
332 {
333 g_error ("handle_listener_error: %s", error->message);
334 }
335
336 /* --- unblock throttled streams every second --- */
337 static gboolean
unblock_timer_func(gpointer data)338 unblock_timer_func (gpointer data)
339 {
340 Side *at = first_throttled;
341 gulong sec = CURRENT_SECOND ();
342 while (at)
343 {
344 Side *next = at->next_throttled;
345 g_assert (at->throttled);
346 g_assert (at->read_side_blocked);
347 if (sec > at->last_xfer_second)
348 {
349 at->last_xfer_second = sec;
350 at->xferred_in_last_second = 0;
351 update_read_block (at);
352 }
353 at = next;
354 }
355
356 /* schedule next timeout */
357 gsk_main_loop_add_timer_absolute (gsk_main_loop_default (),
358 unblock_timer_func, NULL, NULL,
359 sec + 1, 0);
360
361 return FALSE;
362 }
363
/* Print the command-line synopsis and option list to stderr, then
   exit with status 1.  Does not return. */
static void
usage (void)
{
  static const char option_help[] =
    "Bind to LISTEN_ADDR; whenever we receive a connection,\n"
    "proxy to CONNECT_ADDR, obeying thottling constraints.\n"
    "\n"
    "Options:\n"
    " --bind-status=STATUS_ADDR Report status on this addr.\n"
    " --upload-rate=BPS ...\n"
    " --download-rate=BPS ...\n"
    " --upload-rate-noise=BPS ...\n"
    " --download-rate-noise=BPS ...\n"
    " --full-shutdowns\n"
    " --half-shutdowns\n";

  g_printerr ("usage: %s --bind=LISTEN_ADDR --server=CONNECT_ADDR OPTIONS\n\n",
              g_get_prgname ());
  g_printerr ("%s", option_help);
  exit (1);
}
383
384 static void
dump_side_to_buffer(Side * side,GskBuffer * out)385 dump_side_to_buffer (Side *side, GskBuffer *out)
386 {
387 gsk_buffer_printf (out, "<td>%sreadable%s, %swritable%s, %u buffered [total read/written=%u/%u]</td>\n",
388 side->read_side ? "" : "NOT ",
389 side->throttled ? " [throttled]" :
390 side->read_side_blocked ? " [blocked]" : "",
391 side->write_side ? "" : "NOT ",
392 side->write_side_blocked ? " [blocked]" : "",
393 side->buffer.size,
394 side->total_read, side->total_written);
395 }
396
397 static GskHttpContentResult
create_status_page(GskHttpContent * content,GskHttpContentHandler * handler,GskHttpServer * server,GskHttpRequest * request,GskStream * post_data,gpointer data)398 create_status_page (GskHttpContent *content,
399 GskHttpContentHandler *handler,
400 GskHttpServer *server,
401 GskHttpRequest *request,
402 GskStream *post_data,
403 gpointer data)
404 {
405 GskThrottleProxyConnection *conn;
406 GskBuffer buffer = GSK_BUFFER_STATIC_INIT;
407 GskHttpResponse *response;
408 GskStream *stream;
409 gsk_buffer_printf (&buffer, "<html><head>\n");
410 gsk_buffer_printf (&buffer, "<title>GskThrottleProxy Status Page</title>\n");
411 gsk_buffer_printf (&buffer, "</head>\n");
412 gsk_buffer_printf (&buffer, "<body>\n");
413 gsk_buffer_printf (&buffer, "<h1>Statistics</h1>\n");
414 gsk_buffer_printf (&buffer, "<br>%u connections accepted.\n",
415 n_connections_accepted);
416 gsk_buffer_printf (&buffer, "<br>%"G_GUINT64_FORMAT" bytes read.\n",
417 n_bytes_read_total);
418 gsk_buffer_printf (&buffer, "<br>%"G_GUINT64_FORMAT" bytes written.\n",
419 n_bytes_written_total);
420 gsk_buffer_printf (&buffer, "<h1>Connections</h1>\n");
421 gsk_buffer_printf (&buffer, "<table>\n"
422 " <tr><th>Connection Pointer</th>"
423 "<th>RefCount</th>"
424 "<th>Upload</th>"
425 "<th>Download</th>"
426 "</tr>\n");
427 for (conn = first_conn; conn; conn = conn->next)
428 {
429 gsk_buffer_printf (&buffer,
430 " <tr><td>%p</td><td>%u</td>", conn, conn->ref_count);
431 dump_side_to_buffer (&conn->upload, &buffer);
432 dump_side_to_buffer (&conn->download, &buffer);
433 gsk_buffer_printf (&buffer, "</tr>\n");
434 }
435 gsk_buffer_printf (&buffer, "</table>\n</body>\n</html>\n");
436 response = gsk_http_response_from_request (request, 200, buffer.size);
437 gsk_http_header_set_content_type (response, "text");
438 gsk_http_header_set_content_subtype (response, "html");
439 stream = gsk_memory_buffer_source_new (&buffer);
440 gsk_http_server_respond (server, request, response, stream);
441 g_object_unref (response);
442 g_object_unref (stream);
443
444 return GSK_HTTP_CONTENT_OK;
445 }
446
447 /* --- main --- */
main(int argc,char ** argv)448 int main(int argc, char **argv)
449 {
450 guint i;
451 GskStreamListener *listener;
452 GError *error = NULL;
453 gsk_init_without_threads (&argc, &argv);
454 for (i = 1; i < (guint) argc; i++)
455 {
456 if (g_str_has_prefix (argv[i], "--bind="))
457 {
458 const char *bind_str = strchr (argv[i], '=') + 1;
459 if (bind_addr != NULL)
460 g_error ("--bind may only be given once");
461 if (g_ascii_isdigit (bind_str[0]))
462 {
463 bind_addr = gsk_socket_address_ipv4_new (gsk_ipv4_ip_address_any,
464 atoi (bind_str));
465 }
466 else
467 {
468 bind_addr = gsk_socket_address_local_new (bind_str);
469 }
470 }
471 else if (g_str_has_prefix (argv[i], "--bind-status="))
472 {
473 const char *bind_str = strchr (argv[i], '=') + 1;
474 if (bind_status_addr != NULL)
475 g_error ("--bind-status may only be given once");
476 if (g_ascii_isdigit (bind_str[0]))
477 {
478 bind_status_addr = gsk_socket_address_ipv4_new (gsk_ipv4_ip_address_any,
479 atoi (bind_str));
480 }
481 else
482 {
483 bind_status_addr = gsk_socket_address_local_new (bind_str);
484 }
485 }
486 else if (g_str_has_prefix (argv[i], "--server="))
487 {
488 const char *server_str = strchr (argv[i], '=') + 1;
489 const char *colon = strchr (server_str, ':');
490 if (server_addr != NULL)
491 g_error ("--server may only be given once");
492 if (colon != NULL && strchr (server_str, '/') == NULL)
493 {
494 /* host:port */
495 char *host = g_strndup (server_str, colon - server_str);
496 guint port = atoi (colon + 1);
497 server_addr = gsk_socket_address_symbolic_ipv4_new (host, port);
498 g_free (host);
499 }
500 else
501 {
502 /* unix */
503 server_addr = gsk_socket_address_local_new (server_str);
504 }
505 }
506 else if (g_str_has_prefix (argv[i], "--upload-rate="))
507 upload_per_second_base = atoi (strchr (argv[i], '=') + 1);
508 else if (g_str_has_prefix (argv[i], "--download-rate="))
509 download_per_second_base = atoi (strchr (argv[i], '=') + 1);
510 else if (g_str_has_prefix (argv[i], "--upload-rate-noise="))
511 upload_per_second_noise = atoi (strchr (argv[i], '=') + 1);
512 else if (g_str_has_prefix (argv[i], "--download-rate-noise="))
513 download_per_second_noise = atoi (strchr (argv[i], '=') + 1);
514 else if (strcmp (argv[i], "--half-shutdowns") == 0)
515 half_shutdowns = TRUE;
516 else if (strcmp (argv[i], "--full-shutdowns") == 0)
517 half_shutdowns = FALSE;
518 else
519 usage ();
520 }
521
522 if (server_addr == NULL)
523 g_error ("missing --server=ADDRESS: try --help");
524 if (bind_addr == NULL)
525 g_error ("missing --bind=ADDRESS: try --help");
526
527 listener = gsk_stream_listener_socket_new_bind (bind_addr, &error);
528 if (listener == NULL)
529 g_error ("bind failed: %s", error->message);
530 gsk_stream_listener_handle_accept (listener,
531 handle_accept,
532 handle_listener_error,
533 NULL, NULL);
534
535 if (bind_status_addr != NULL)
536 {
537 GskHttpContentHandler *handler;
538 GskHttpContent *content = gsk_http_content_new ();
539 GskHttpContentId id = GSK_HTTP_CONTENT_ID_INIT;
540 handler = gsk_http_content_handler_new (create_status_page, NULL, NULL);
541 id.path = "/";
542 gsk_http_content_add_handler (content, &id, handler, GSK_HTTP_CONTENT_REPLACE);
543 gsk_http_content_handler_unref (handler);
544 if (!gsk_http_content_listen (content, bind_status_addr, &error))
545 g_error ("error listening: %s", error->message);
546 }
547
548 gsk_main_loop_add_timer_absolute (gsk_main_loop_default (),
549 unblock_timer_func, NULL, NULL,
550 gsk_main_loop_default ()->current_time.tv_sec + 1, 0);
551
552 return gsk_main_run ();
553 }
554